diff --git a/docker/README.md b/docker/README.md index d5a472da7d..9e7ac15145 100644 --- a/docker/README.md +++ b/docker/README.md @@ -64,7 +64,7 @@ GPU_DRIVER=nvidia Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail. -## Even Moar Customizing! +## Even More Customizing! See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below. diff --git a/docs/contributing/frontend/WORKFLOWS.md b/docs/contributing/frontend/WORKFLOWS.md index e71d797b8a..533419e070 100644 --- a/docs/contributing/frontend/WORKFLOWS.md +++ b/docs/contributing/frontend/WORKFLOWS.md @@ -117,13 +117,13 @@ Stateless fields do not store their value in the node, so their field instances "Custom" fields will always be treated as stateless fields. -##### Collection and Scalar Fields +##### Single and Collection Fields -Field types have a name and two flags which may identify it as a **collection** or **collection or scalar** field. +Field types have a name and cardinality property which may identify it as a **SINGLE**, **COLLECTION** or **SINGLE_OR_COLLECTION** field. -If a field is annotated in python as a list, its field type is parsed and flagged as a **collection** type (e.g. `list[int]`). - -If it is annotated as a union of a type and list, the type will be flagged as a **collection or scalar** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed). +- If a field is annotated in python as a singular value or class, its field type is parsed as a **SINGLE** type (e.g. `int`, `ImageField`, `str`). +- If a field is annotated in python as a list, its field type is parsed as a **COLLECTION** type (e.g. `list[int]`). +- If it is annotated as a union of a type and list, the type will be parsed as a **SINGLE_OR_COLLECTION** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed). ## Implementation @@ -173,8 +173,7 @@ Field types are represented as structured objects: ```ts type FieldType = { name: string; - isCollection: boolean; - isCollectionOrScalar: boolean; + cardinality: 'SINGLE' | 'COLLECTION' | 'SINGLE_OR_COLLECTION'; }; ``` @@ -186,7 +185,7 @@ There are 4 general cases for field type parsing. When a field is annotated as a primitive values (e.g. `int`, `str`, `float`), the field type parsing is fairly straightforward. The field is represented by a simple OpenAPI **schema object**, which has a `type` property. -We create a field type name from this `type` string (e.g. `string` -> `StringField`). +We create a field type name from this `type` string (e.g. `string` -> `StringField`). The cardinality is `"SINGLE"`. ##### Complex Types @@ -200,13 +199,13 @@ We need to **dereference** the schema to pull these out. Dereferencing may requi When a field is annotated as a list of a single type, the schema object has an `items` property. They may be a schema object or reference object and must be parsed to determine the item type. -We use the item type for field type name, adding `isCollection: true` to the field type. +We use the item type for field type name. The cardinality is `"COLLECTION"`. 
-##### Collection or Scalar Types +##### Single or Collection Types When a field is annotated as a union of a type and list of that type, the schema object has an `anyOf` property, which holds a list of valid types for the union. -After verifying that the union has two members (a type and list of the same type), we use the type for field type name, adding `isCollectionOrScalar: true` to the field type. +After verifying that the union has two members (a type and list of the same type), we use the type for field type name, with cardinality `"SINGLE_OR_COLLECTION"`. ##### Optional Fields diff --git a/docs/features/CONTROLNET.md b/docs/features/CONTROLNET.md index d07353089d..718b12b0f8 100644 --- a/docs/features/CONTROLNET.md +++ b/docs/features/CONTROLNET.md @@ -165,7 +165,7 @@ Additionally, each section can be expanded with the "Show Advanced" button in o There are several ways to install IP-Adapter models with an existing InvokeAI installation: 1. Through the command line interface launched from the invoke.sh / invoke.bat scripts, option [4] to download models. -2. Through the Model Manager UI with models from the *Tools* section of [www.models.invoke.ai](https://www.models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models. +2. Through the Model Manager UI with models from the *Tools* section of [models.invoke.ai](https://models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models. 3. **Advanced -- Not recommended ** Manually downloading the IP-Adapter and Image Encoder files - Image Encoder folders shouid be placed in the `models\any\clip_vision` folders. IP Adapter Model folders should be placed in the relevant `ip-adapter` folder of relevant base model folder of Invoke root directory. For example, for the SDXL IP-Adapter, files should be added to the `model/sdxl/ip_adapter/` folder. #### Using IP-Adapter diff --git a/docs/help/diffusion.md b/docs/help/diffusion.md index 0dbb09f304..7182a51d67 100644 --- a/docs/help/diffusion.md +++ b/docs/help/diffusion.md @@ -20,7 +20,7 @@ When you generate an image using text-to-image, multiple steps occur in latent s 4. The VAE decodes the final latent image from latent space into image space. Image-to-image is a similar process, with only step 1 being different: -1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how may noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. 
The process is then the same as steps 2-4 in the text-to-image process. +1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how many noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performed. The process is then the same as steps 2-4 in the text-to-image process. Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor). diff --git a/docs/installation/020_INSTALL_MANUAL.md b/docs/installation/020_INSTALL_MANUAL.md index 36859a5795..059834eb45 100644 --- a/docs/installation/020_INSTALL_MANUAL.md +++ b/docs/installation/020_INSTALL_MANUAL.md @@ -10,7 +10,7 @@ InvokeAI is distributed as a python package on PyPI, installable with `pip`. The ### Requirements -Before you start, go through the [installation requirements]. +Before you start, go through the [installation requirements](./INSTALL_REQUIREMENTS.md). ### Installation Walkthrough @@ -79,7 +79,7 @@ Before you start, go through the [installation requirements]. 1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features. - - You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website]. Copy the `--extra-index-url` string from this and append it to your install command. + - You may need to provide an [extra index URL](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-extra-index-url). Select your platform configuration using [this tool on the PyTorch website](https://pytorch.org/get-started/locally/). Copy the `--extra-index-url` string from this and append it to your install command. !!! example "Install with an extra index URL" @@ -116,4 +116,4 @@ Before you start, go through the [installation requirements]. !!! warning - If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root_dir \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable. + If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable. diff --git a/installer/templates/invoke.bat.in b/installer/templates/invoke.bat.in index 9c9c08d82d..c8ef19710b 100644 --- a/installer/templates/invoke.bat.in +++ b/installer/templates/invoke.bat.in @@ -10,8 +10,7 @@ set INVOKEAI_ROOT=. echo Desired action: echo 1. Generate images with the browser-based interface echo 2. Open the developer console -echo 3. Run the InvokeAI image database maintenance script -echo 4. Command-line help +echo 3. Command-line help echo Q - Quit echo. echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest. @@ -34,9 +33,6 @@ IF /I "%choice%" == "1" ( echo *** Type `exit` to quit this shell and deactivate the Python virtual environment *** call cmd /k ) ELSE IF /I "%choice%" == "3" ( - echo Running the db maintenance script...
- python .venv\Scripts\invokeai-db-maintenance.exe -) ELSE IF /I "%choice%" == "4" ( echo Displaying command line help... python .venv\Scripts\invokeai-web.exe --help %* pause diff --git a/installer/templates/invoke.sh.in b/installer/templates/invoke.sh.in index 9c45eba5b0..b8d5a7af23 100644 --- a/installer/templates/invoke.sh.in +++ b/installer/templates/invoke.sh.in @@ -47,11 +47,6 @@ do_choice() { bash --init-file "$file_name" ;; 3) - clear - printf "Running the db maintenance script\n" - invokeai-db-maintenance --root ${INVOKEAI_ROOT} - ;; - 4) clear printf "Command-line help\n" invokeai-web --help @@ -71,8 +66,7 @@ do_line_input() { printf "What would you like to do?\n" printf "1: Generate images using the browser-based interface\n" printf "2: Open the developer console\n" - printf "3: Run the InvokeAI image database maintenance script\n" - printf "4: Command-line help\n" + printf "3: Command-line help\n" printf "Q: Quit\n\n" printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.\n\n" read -p "Please enter 1-4, Q: [1] " yn diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 0cfcf2f3b7..19a7bb083d 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -18,6 +18,7 @@ from ..services.boards.boards_default import BoardService from ..services.bulk_download.bulk_download_default import BulkDownloadService from ..services.config import InvokeAIAppConfig from ..services.download import DownloadQueueService +from ..services.events.events_fastapievents import FastAPIEventService from ..services.image_files.image_files_disk import DiskImageFileStorage from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage from ..services.images.images_default import ImageService @@ -29,11 +30,10 @@ from ..services.model_images.model_images_default import ModelImageFileStorageDi from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_records import ModelRecordServiceSQL from ..services.names.names_default import SimpleNameService -from ..services.session_processor.session_processor_default import DefaultSessionProcessor +from ..services.session_processor.session_processor_default import DefaultSessionProcessor, DefaultSessionRunner from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.urls.urls_default import LocalUrlService from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage -from .events import FastAPIEventService # TODO: is there a better way to achieve this? 
@@ -103,7 +103,7 @@ class ApiDependencies: ) names = SimpleNameService() performance_statistics = InvocationStatsService() - session_processor = DefaultSessionProcessor() + session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner()) session_queue = SqliteSessionQueue(db=db) urls = LocalUrlService() workflow_records = SqliteWorkflowRecordsStorage(db=db) diff --git a/invokeai/app/api/events.py b/invokeai/app/api/events.py deleted file mode 100644 index 2ac07e6dfe..0000000000 --- a/invokeai/app/api/events.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -import asyncio -import threading -from queue import Empty, Queue -from typing import Any - -from fastapi_events.dispatcher import dispatch - -from ..services.events.events_base import EventServiceBase - - -class FastAPIEventService(EventServiceBase): - event_handler_id: int - __queue: Queue - __stop_event: threading.Event - - def __init__(self, event_handler_id: int) -> None: - self.event_handler_id = event_handler_id - self.__queue = Queue() - self.__stop_event = threading.Event() - asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event)) - - super().__init__() - - def stop(self, *args, **kwargs): - self.__stop_event.set() - self.__queue.put(None) - - def dispatch(self, event_name: str, payload: Any) -> None: - self.__queue.put({"event_name": event_name, "payload": payload}) - - async def __dispatch_from_queue(self, stop_event: threading.Event): - """Get events on from the queue and dispatch them, from the correct thread""" - while not stop_event.is_set(): - try: - event = self.__queue.get(block=False) - if not event: # Probably stopping - continue - - dispatch( - event.get("event_name"), - payload=event.get("payload"), - middleware_id=self.event_handler_id, - ) - - except Empty: - await asyncio.sleep(0.1) - pass - - except asyncio.CancelledError as e: - raise e # Raise a proper error diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 9c55ff6531..a947b83abe 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -6,7 +6,7 @@ from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, from fastapi.responses import FileResponse from fastapi.routing import APIRouter from PIL import Image -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, JsonValue from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin @@ -41,14 +41,17 @@ async def upload_image( board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"), session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"), crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"), + metadata: Optional[JsonValue] = Body( + default=None, description="The metadata to associate with the image", embed=True + ), ) -> ImageDTO: """Uploads an image""" if not file.content_type or not file.content_type.startswith("image"): raise HTTPException(status_code=415, detail="Not an image") - metadata = None - workflow = None - graph = None + _metadata = None + _workflow = None + _graph = None contents = await file.read() try: @@ -62,27 +65,27 @@ async def upload_image( # TODO: retain non-invokeai metadata on upload? 
# attempt to parse metadata from image - metadata_raw = pil_image.info.get("invokeai_metadata", None) + metadata_raw = metadata if isinstance(metadata, str) else pil_image.info.get("invokeai_metadata", None) if isinstance(metadata_raw, str): - metadata = metadata_raw + _metadata = metadata_raw else: - ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image") + ApiDependencies.invoker.services.logger.debug("Failed to parse metadata for uploaded image") pass # attempt to parse workflow from image workflow_raw = pil_image.info.get("invokeai_workflow", None) if isinstance(workflow_raw, str): - workflow = workflow_raw + _workflow = workflow_raw else: - ApiDependencies.invoker.services.logger.warn("Failed to parse workflow for uploaded image") + ApiDependencies.invoker.services.logger.debug("Failed to parse workflow for uploaded image") pass # attempt to extract graph from image graph_raw = pil_image.info.get("invokeai_graph", None) if isinstance(graph_raw, str): - graph = graph_raw + _graph = graph_raw else: - ApiDependencies.invoker.services.logger.warn("Failed to parse graph for uploaded image") + ApiDependencies.invoker.services.logger.debug("Failed to parse graph for uploaded image") pass try: @@ -92,9 +95,9 @@ async def upload_image( image_category=image_category, session_id=session_id, board_id=board_id, - metadata=metadata, - workflow=workflow, - graph=graph, + metadata=_metadata, + workflow=_workflow, + graph=_graph, is_intermediate=is_intermediate, ) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 1ba3e30e07..b1221f7a34 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException from typing_extensions import Annotated from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException -from invokeai.app.services.model_install import ModelInstallJob +from invokeai.app.services.model_install.model_install_common import ModelInstallJob from invokeai.app.services.model_records import ( DuplicateModelException, InvalidModelException, diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py index 40f1f2213b..7161e54a41 100644 --- a/invokeai/app/api/routers/session_queue.py +++ b/invokeai/app/api/routers/session_queue.py @@ -203,6 +203,7 @@ async def get_batch_status( responses={ 200: {"model": SessionQueueItem}, }, + response_model_exclude_none=True, ) async def get_queue_item( queue_id: str = Path(description="The queue id to perform this operation on"), diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index 463545d9bc..b39922c69b 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -1,66 +1,125 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +from typing import Any + from fastapi import FastAPI -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event +from pydantic import BaseModel from socketio import ASGIApp, AsyncServer -from ..services.events.events_base import EventServiceBase +from invokeai.app.services.events.events_common import ( + BatchEnqueuedEvent, + BulkDownloadCompleteEvent, + BulkDownloadErrorEvent, + BulkDownloadEventBase, + BulkDownloadStartedEvent, + DownloadCancelledEvent, + DownloadCompleteEvent, + DownloadErrorEvent, + DownloadEventBase, + DownloadProgressEvent, + DownloadStartedEvent, + 
FastAPIEvent, + InvocationCompleteEvent, + InvocationDenoiseProgressEvent, + InvocationErrorEvent, + InvocationStartedEvent, + ModelEventBase, + ModelInstallCancelledEvent, + ModelInstallCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, + ModelInstallErrorEvent, + ModelInstallStartedEvent, + ModelLoadCompleteEvent, + ModelLoadStartedEvent, + QueueClearedEvent, + QueueEventBase, + QueueItemStatusChangedEvent, + register_events, +) + + +class QueueSubscriptionEvent(BaseModel): + """Event data for subscribing to the socket.io queue room. + This is a pydantic model to ensure the data is in the correct format.""" + + queue_id: str + + +class BulkDownloadSubscriptionEvent(BaseModel): + """Event data for subscribing to the socket.io bulk downloads room. + This is a pydantic model to ensure the data is in the correct format.""" + + bulk_download_id: str + + +QUEUE_EVENTS = { + InvocationStartedEvent, + InvocationDenoiseProgressEvent, + InvocationCompleteEvent, + InvocationErrorEvent, + QueueItemStatusChangedEvent, + BatchEnqueuedEvent, + QueueClearedEvent, +} + +MODEL_EVENTS = { + DownloadCancelledEvent, + DownloadCompleteEvent, + DownloadErrorEvent, + DownloadProgressEvent, + DownloadStartedEvent, + ModelLoadStartedEvent, + ModelLoadCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, + ModelInstallStartedEvent, + ModelInstallCompleteEvent, + ModelInstallCancelledEvent, + ModelInstallErrorEvent, +} + +BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent} class SocketIO: - __sio: AsyncServer - __app: ASGIApp + _sub_queue = "subscribe_queue" + _unsub_queue = "unsubscribe_queue" - __sub_queue: str = "subscribe_queue" - __unsub_queue: str = "unsubscribe_queue" - - __sub_bulk_download: str = "subscribe_bulk_download" - __unsub_bulk_download: str = "unsubscribe_bulk_download" + _sub_bulk_download = "subscribe_bulk_download" + _unsub_bulk_download = "unsubscribe_bulk_download" def __init__(self, app: FastAPI): - self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") - self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") - app.mount("/ws", self.__app) + self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") + self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io") + app.mount("/ws", self._app) - self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue) - self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue) - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) - local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event) + self._sio.on(self._sub_queue, handler=self._handle_sub_queue) + self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue) + self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download) + self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download) - self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download) - self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download) - local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event) + register_events(QUEUE_EVENTS, self._handle_queue_event) + register_events(MODEL_EVENTS, self._handle_model_event) + register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event) - async def 
_handle_queue_event(self, event: Event): - await self.__sio.emit( - event=event[1]["event"], - data=event[1]["data"], - room=event[1]["data"]["queue_id"], - ) + async def _handle_sub_queue(self, sid: str, data: Any) -> None: + await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id) - async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None: - if "queue_id" in data: - await self.__sio.enter_room(sid, data["queue_id"]) + async def _handle_unsub_queue(self, sid: str, data: Any) -> None: + await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id) - async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None: - if "queue_id" in data: - await self.__sio.leave_room(sid, data["queue_id"]) + async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None: + await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) - async def _handle_model_event(self, event: Event) -> None: - await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) + async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: + await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) - async def _handle_bulk_download_event(self, event: Event): - await self.__sio.emit( - event=event[1]["event"], - data=event[1]["data"], - room=event[1]["data"]["bulk_download_id"], - ) + async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): + await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id) - async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs): - if "bulk_download_id" in data: - await self.__sio.enter_room(sid, data["bulk_download_id"]) + async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None: + await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json")) - async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs): - if "bulk_download_id" in data: - await self.__sio.leave_room(sid, data["bulk_download_id"]) + async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None: + await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 062682f7d0..b7da548377 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -27,6 +27,7 @@ import invokeai.frontend.web as web_dir from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.events.events_common import EventBase from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.backend.util.devices import TorchDevice @@ -182,23 +183,14 @@ def custom_openapi() -> dict[str, Any]: openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type()) invoker_schema["class"] = "invocation" - # This code no longer seems to be necessary? 
- # Leave it here just in case - # - # from invokeai.backend.model_manager import get_model_config_formats - # formats = get_model_config_formats() - # for model_config_name, enum_set in formats.items(): - - # if model_config_name in openapi_schema["components"]["schemas"]: - # # print(f"Config with name {name} already defined") - # continue - - # openapi_schema["components"]["schemas"][model_config_name] = { - # "title": model_config_name, - # "description": "An enumeration.", - # "type": "string", - # "enum": [v.value for v in enum_set], - # } + # Add all event schemas + for event in sorted(EventBase.get_events(), key=lambda e: e.__name__): + json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}") + if "$defs" in json_schema: + for schema_key, schema in json_schema["$defs"].items(): + openapi_schema["components"]["schemas"][schema_key] = schema + del json_schema["$defs"] + openapi_schema["components"]["schemas"][event.__name__] = json_schema app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 158f11a58e..766b44fdc8 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -65,11 +65,7 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: tokenizer_info = context.models.load(self.clip.tokenizer) - tokenizer_model = tokenizer_info.model - assert isinstance(tokenizer_model, CLIPTokenizer) text_encoder_info = context.models.load(self.clip.text_encoder) - text_encoder_model = text_encoder_info.model - assert isinstance(text_encoder_model, CLIPTextModel) def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.clip.loras: @@ -84,19 +80,21 @@ class CompelInvocation(BaseInvocation): ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context) with ( - ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as ( - tokenizer, - ti_manager, - ), + # apply all patches while the model is on the target device text_encoder_info as text_encoder, - # Apply the LoRA after text_encoder has been moved to its target device for faster patching. + tokenizer_info as tokenizer, ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. 
- ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers), + ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as ( + patched_tokenizer, + ti_manager, + ), ): assert isinstance(text_encoder, CLIPTextModel) + assert isinstance(tokenizer, CLIPTokenizer) compel = Compel( - tokenizer=tokenizer, + tokenizer=patched_tokenizer, text_encoder=text_encoder, textual_inversion_manager=ti_manager, dtype_for_device_getter=TorchDevice.choose_torch_dtype, @@ -106,7 +104,7 @@ class CompelInvocation(BaseInvocation): conjunction = Compel.parse_prompt_string(self.prompt) if context.config.get().log_tokenization: - log_tokenization_for_conjunction(conjunction, tokenizer) + log_tokenization_for_conjunction(conjunction, patched_tokenizer) c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction) @@ -136,11 +134,7 @@ class SDXLPromptInvocationBase: zero_on_empty: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: tokenizer_info = context.models.load(clip_field.tokenizer) - tokenizer_model = tokenizer_info.model - assert isinstance(tokenizer_model, CLIPTokenizer) text_encoder_info = context.models.load(clip_field.text_encoder) - text_encoder_model = text_encoder_info.model - assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection)) # return zero on empty if prompt == "" and zero_on_empty: @@ -177,20 +171,23 @@ class SDXLPromptInvocationBase: ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context) with ( - ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as ( - tokenizer, - ti_manager, - ), + # apply all patches while the model is on the target device text_encoder_info as text_encoder, - # Apply the LoRA after text_encoder has been moved to its target device for faster patching. + tokenizer_info as tokenizer, ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. - ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers), + ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as ( + patched_tokenizer, + ti_manager, + ), ): assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)) + assert isinstance(tokenizer, CLIPTokenizer) + text_encoder = cast(CLIPTextModel, text_encoder) compel = Compel( - tokenizer=tokenizer, + tokenizer=patched_tokenizer, text_encoder=text_encoder, textual_inversion_manager=ti_manager, dtype_for_device_getter=TorchDevice.choose_torch_dtype, @@ -203,7 +200,7 @@ class SDXLPromptInvocationBase: if context.config.get().log_tokenization: # TODO: better logging for and syntax - log_tokenization_for_conjunction(conjunction, tokenizer) + log_tokenization_for_conjunction(conjunction, patched_tokenizer) # TODO: ask for optimizations? 
to not run text_encoder twice c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index e69f4b54ad..e533583829 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,7 +25,6 @@ from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.invocations.fields import ( FieldDescriptions, ImageField, - Input, InputField, OutputField, UIType, @@ -82,13 +81,13 @@ class ControlOutput(BaseInvocationOutput): control: ControlField = OutputField(description=FieldDescriptions.control) -@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1") +@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2") class ControlNetInvocation(BaseInvocation): """Collects ControlNet info to pass to other nodes""" image: ImageField = InputField(description="The control image") control_model: ModelIdentifierField = InputField( - description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel + description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel ) control_weight: Union[float, List[float]] = InputField( default=1.0, ge=-1, le=2, description="The weight given to the ControlNet" diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 34a30628da..de40879eef 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType +from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights @@ -58,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput): CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"} -@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0") +@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.1") class IPAdapterInvocation(BaseInvocation): """Collects IP-Adapter info to pass to other nodes.""" @@ -67,7 +67,6 @@ class IPAdapterInvocation(BaseInvocation): ip_adapter_model: ModelIdentifierField = InputField( description="The IP-Adapter model.", title="IP-Adapter Model", - input=Input.Direct, ui_order=-1, ui_type=UIType.IPAdapterModel, ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b3ac3973bf..a88eff0fcb 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -930,9 +930,9 @@ class DenoiseLatentsInvocation(BaseInvocation): assert isinstance(unet_info.model, UNet2DConditionModel) with ( ExitStack() as exit_stack, - ModelPatcher.apply_freeu(unet_info.model, 
self.unet.freeu_config), - set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME unet_info as unet, + ModelPatcher.apply_freeu(unet, self.unet.freeu_config), + set_seamless(unet, self.unet.seamless_axes), # FIXME # Apply the LoRA after unet has been moved to its target device for faster patching. ModelPatcher.apply_lora_unet(unet, _lora_loader()), ): diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 245034c481..94a6136fcb 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -11,6 +11,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, + Classification, invocation, invocation_output, ) @@ -93,19 +94,46 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput): pass +@invocation_output("model_identifier_output") +class ModelIdentifierOutput(BaseInvocationOutput): + """Model identifier output""" + + model: ModelIdentifierField = OutputField(description="Model identifier", title="Model") + + +@invocation( + "model_identifier", + title="Model identifier", + tags=["model"], + category="model", + version="1.0.0", + classification=Classification.Prototype, +) +class ModelIdentifierInvocation(BaseInvocation): + """Selects any model, outputting its identifier. Be careful with this one! The identifier will be accepted as + input for any model, even if the model types don't match. If you connect this to a mismatched input, you'll get an + error.""" + + model: ModelIdentifierField = InputField(description="The model to select", title="Model") + + def invoke(self, context: InvocationContext) -> ModelIdentifierOutput: + if not context.models.exists(self.model.key): + raise Exception(f"Unknown model {self.model.key}") + + return ModelIdentifierOutput(model=self.model) + + @invocation( "main_model_loader", title="Main Model", tags=["model"], category="model", - version="1.0.2", + version="1.0.3", ) class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" - model: ModelIdentifierField = InputField( - description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel - ) + model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel) # TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: @@ -134,12 +162,12 @@ class LoRALoaderOutput(BaseInvocationOutput): clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") -@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.2") +@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3") class LoRALoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" lora: ModelIdentifierField = InputField( - description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel + description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel ) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) unet: Optional[UNetField] = InputField( @@ -197,12 +225,12 @@ class LoRASelectorOutput(BaseInvocationOutput): lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA") -@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.0") +@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.1") class LoRASelectorInvocation(BaseInvocation): """Selects a LoRA model and weight.""" lora: ModelIdentifierField = InputField( - description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel + description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel ) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) @@ -273,13 +301,13 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput): title="SDXL LoRA", tags=["lora", "model"], category="model", - version="1.0.2", + version="1.0.3", ) class SDXLLoRALoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" lora: ModelIdentifierField = InputField( - description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel + description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel ) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) unet: Optional[UNetField] = InputField( @@ -414,12 +442,12 @@ class SDXLLoRACollectionLoader(BaseInvocation): return output -@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2") +@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.3") class VAELoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" vae_model: ModelIdentifierField = InputField( - description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel + description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel ) def invoke(self, context: InvocationContext) -> VAEOutput: diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 9b1ee90350..1c0817cb92 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,4 +1,4 @@ -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType +from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager import SubModelType @@ -30,12 +30,12 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): vae: 
VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") -@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.2") +@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.3") class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" model: ModelIdentifierField = InputField( - description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel + description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel ) # TODO: precision? @@ -67,13 +67,13 @@ class SDXLModelLoaderInvocation(BaseInvocation): title="SDXL Refiner Model", tags=["model", "sdxl", "refiner"], category="model", - version="1.0.2", + version="1.0.3", ) class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" model: ModelIdentifierField = InputField( - description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel + description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel ) # TODO: precision? diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index b22a089d3f..04f9a6c695 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import ( invocation, invocation_output, ) -from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext @@ -45,7 +45,7 @@ class T2IAdapterOutput(BaseInvocationOutput): @invocation( - "t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.2" + "t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.3" ) class T2IAdapterInvocation(BaseInvocation): """Collects T2I-Adapter info to pass to other nodes.""" @@ -55,7 +55,6 @@ class T2IAdapterInvocation(BaseInvocation): t2i_adapter_model: ModelIdentifierField = InputField( description="The T2I-Adapter model.", title="T2I-Adapter Model", - input=Input.Direct, ui_order=-1, ui_type=UIType.T2IAdapterModel, ) diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 04cec928f4..d4bf059b8f 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -106,9 +106,7 @@ class BulkDownloadService(BulkDownloadBase): if self._invoker: assert bulk_download_id is not None self._invoker.services.events.emit_bulk_download_started( - bulk_download_id=bulk_download_id, - bulk_download_item_id=bulk_download_item_id, - bulk_download_item_name=bulk_download_item_name, + bulk_download_id, bulk_download_item_id, bulk_download_item_name ) def _signal_job_completed( @@ -118,10 +116,8 @@ class BulkDownloadService(BulkDownloadBase): if self._invoker: assert bulk_download_id is not None assert bulk_download_item_name is not None - 
self._invoker.services.events.emit_bulk_download_completed( - bulk_download_id=bulk_download_id, - bulk_download_item_id=bulk_download_item_id, - bulk_download_item_name=bulk_download_item_name, + self._invoker.services.events.emit_bulk_download_complete( + bulk_download_id, bulk_download_item_id, bulk_download_item_name ) def _signal_job_failed( @@ -131,11 +127,8 @@ class BulkDownloadService(BulkDownloadBase): if self._invoker: assert bulk_download_id is not None assert exception is not None - self._invoker.services.events.emit_bulk_download_failed( - bulk_download_id=bulk_download_id, - bulk_download_item_id=bulk_download_item_id, - bulk_download_item_name=bulk_download_item_name, - error=str(exception), + self._invoker.services.events.emit_bulk_download_error( + bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception) ) def stop(self, *args, **kwargs): diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 4555477004..5025255c91 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -8,7 +8,7 @@ import time import traceback from pathlib import Path from queue import Empty, PriorityQueue -from typing import Any, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set import requests from pydantic.networks import AnyHttpUrl @@ -34,6 +34,9 @@ from .download_base import ( UnknownJobIDException, ) +if TYPE_CHECKING: + from invokeai.app.services.events.events_base import EventServiceBase + # Maximum number of bytes to download during each call to requests.iter_content() DOWNLOAD_CHUNK_SIZE = 100000 @@ -45,7 +48,7 @@ class DownloadQueueService(DownloadQueueServiceBase): self, max_parallel_dl: int = 5, app_config: Optional[InvokeAIAppConfig] = None, - event_bus: Optional[EventServiceBase] = None, + event_bus: Optional["EventServiceBase"] = None, requests_session: Optional[requests.sessions.Session] = None, ): """ @@ -408,28 +411,18 @@ class DownloadQueueService(DownloadQueueServiceBase): job.status = DownloadJobStatus.RUNNING self._execute_cb(job, "on_start") if self._event_bus: - assert job.download_path - self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix()) + self._event_bus.emit_download_started(job) def _signal_job_progress(self, job: DownloadJob) -> None: self._execute_cb(job, "on_progress") if self._event_bus: - assert job.download_path - self._event_bus.emit_download_progress( - str(job.source), - download_path=job.download_path.as_posix(), - current_bytes=job.bytes, - total_bytes=job.total_bytes, - ) + self._event_bus.emit_download_progress(job) def _signal_job_complete(self, job: DownloadJob) -> None: job.status = DownloadJobStatus.COMPLETED self._execute_cb(job, "on_complete") if self._event_bus: - assert job.download_path - self._event_bus.emit_download_complete( - str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes - ) + self._event_bus.emit_download_complete(job) def _signal_job_cancelled(self, job: DownloadJob) -> None: if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]: @@ -437,7 +430,7 @@ class DownloadQueueService(DownloadQueueServiceBase): job.status = DownloadJobStatus.CANCELLED self._execute_cb(job, "on_cancelled") if self._event_bus: - self._event_bus.emit_download_cancelled(str(job.source)) + self._event_bus.emit_download_cancelled(job) # if multifile download, then signal the parent if 
parent_job := self._download_part2parent.get(job.source, None): @@ -451,9 +444,7 @@ class DownloadQueueService(DownloadQueueServiceBase): self._execute_cb(job, "on_error", excp) if self._event_bus: - assert job.error_type - assert job.error - self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error) + self._event_bus.emit_download_error(job) def _cleanup_cancelled_job(self, job: DownloadJob) -> None: self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}") diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index aa91cdaec8..3c0fb0a30b 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -1,490 +1,195 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Optional -from invokeai.app.services.session_processor.session_processor_common import ProgressImage -from invokeai.app.services.session_queue.session_queue_common import ( - BatchStatus, - EnqueueBatchResult, - SessionQueueItem, - SessionQueueStatus, +from invokeai.app.services.events.events_common import ( + BatchEnqueuedEvent, + BulkDownloadCompleteEvent, + BulkDownloadErrorEvent, + BulkDownloadStartedEvent, + DownloadCancelledEvent, + DownloadCompleteEvent, + DownloadErrorEvent, + DownloadProgressEvent, + DownloadStartedEvent, + EventBase, + InvocationCompleteEvent, + InvocationDenoiseProgressEvent, + InvocationErrorEvent, + InvocationStartedEvent, + ModelInstallCancelledEvent, + ModelInstallCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, + ModelInstallErrorEvent, + ModelInstallStartedEvent, + ModelLoadCompleteEvent, + ModelLoadStartedEvent, + QueueClearedEvent, + QueueItemStatusChangedEvent, ) -from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_manager import AnyModelConfig -from invokeai.backend.model_manager.config import SubModelType +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState + +if TYPE_CHECKING: + from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput + from invokeai.app.services.download.download_base import DownloadJob + from invokeai.app.services.events.events_common import EventBase + from invokeai.app.services.model_install.model_install_common import ModelInstallJob + from invokeai.app.services.session_processor.session_processor_common import ProgressImage + from invokeai.app.services.session_queue.session_queue_common import ( + BatchStatus, + EnqueueBatchResult, + SessionQueueItem, + SessionQueueStatus, + ) + from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType class EventServiceBase: - queue_event: str = "queue_event" - bulk_download_event: str = "bulk_download_event" - download_event: str = "download_event" - model_event: str = "model_event" - """Basic event bus, to have an empty stand-in when not needed""" - def dispatch(self, event_name: str, payload: Any) -> None: + def dispatch(self, event: "EventBase") -> None: pass - def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None: - """Bulk download events are emitted to a room with queue_id as the room name""" - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.bulk_download_event, - payload={"event": event_name, "data": payload}, - ) + # region: 
Invocation - def __emit_queue_event(self, event_name: str, payload: dict) -> None: - """Queue events are emitted to a room with queue_id as the room name""" - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.queue_event, - payload={"event": event_name, "data": payload}, - ) + def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None: + """Emitted when an invocation is started""" + self.dispatch(InvocationStartedEvent.build(queue_item, invocation)) - def __emit_download_event(self, event_name: str, payload: dict) -> None: - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.download_event, - payload={"event": event_name, "data": payload}, - ) - - def __emit_model_event(self, event_name: str, payload: dict) -> None: - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.model_event, - payload={"event": event_name, "data": payload}, - ) - - # Define events here for every event in the system. - # This will make them easier to integrate until we find a schema generator. - def emit_generator_progress( + def emit_invocation_denoise_progress( self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node_id: str, - source_node_id: str, - progress_image: Optional[ProgressImage], - step: int, - order: int, - total_steps: int, + queue_item: "SessionQueueItem", + invocation: "BaseInvocation", + intermediate_state: PipelineIntermediateState, + progress_image: "ProgressImage", ) -> None: - """Emitted when there is generation progress""" - self.__emit_queue_event( - event_name="generator_progress", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node_id": node_id, - "source_node_id": source_node_id, - "progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None, - "step": step, - "order": order, - "total_steps": total_steps, - }, - ) + """Emitted at each step during denoising of an invocation.""" + self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image)) def emit_invocation_complete( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - result: dict, - node: dict, - source_node_id: str, + self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput" ) -> None: - """Emitted when an invocation has completed""" - self.__emit_queue_event( - event_name="invocation_complete", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - "result": result, - }, - ) + """Emitted when an invocation is complete""" + self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output)) def emit_invocation_error( self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node: dict, - source_node_id: str, + queue_item: "SessionQueueItem", + invocation: "BaseInvocation", error_type: str, - error: str, - user_id: str | None, - project_id: str | None, + error_message: str, + error_traceback: str, ) -> None: - """Emitted when an invocation has completed""" - self.__emit_queue_event( - event_name="invocation_error", - payload={ - 
"queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - "error_type": error_type, - "error": error, - "user_id": user_id, - "project_id": project_id, - }, - ) + """Emitted when an invocation encounters an error""" + self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback)) - def emit_invocation_started( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node: dict, - source_node_id: str, - ) -> None: - """Emitted when an invocation has started""" - self.__emit_queue_event( - event_name="invocation_started", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - }, - ) + # endregion - def emit_graph_execution_complete( - self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str - ) -> None: - """Emitted when a session has completed all invocations""" - self.__emit_queue_event( - event_name="graph_execution_state_complete", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - }, - ) - - def emit_model_load_started( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> None: - """Emitted when a model is requested""" - self.__emit_queue_event( - event_name="model_load_started", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "model_config": model_config.model_dump(mode="json"), - "submodel_type": submodel_type, - }, - ) - - def emit_model_load_completed( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> None: - """Emitted when a model is correctly loaded (returns model info)""" - self.__emit_queue_event( - event_name="model_load_completed", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "model_config": model_config.model_dump(mode="json"), - "submodel_type": submodel_type, - }, - ) - - def emit_session_canceled( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - ) -> None: - """Emitted when a session is canceled""" - self.__emit_queue_event( - event_name="session_canceled", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - }, - ) + # region Queue def emit_queue_item_status_changed( - self, - session_queue_item: SessionQueueItem, - batch_status: BatchStatus, - queue_status: SessionQueueStatus, + self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus" ) -> None: """Emitted when a queue item's status changes""" - self.__emit_queue_event( - event_name="queue_item_status_changed", - payload={ - "queue_id": queue_status.queue_id, - 
"queue_item": { - "queue_id": session_queue_item.queue_id, - "item_id": session_queue_item.item_id, - "status": session_queue_item.status, - "batch_id": session_queue_item.batch_id, - "session_id": session_queue_item.session_id, - "error": session_queue_item.error, - "created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None, - "updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None, - "started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None, - "completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None, - }, - "batch_status": batch_status.model_dump(mode="json"), - "queue_status": queue_status.model_dump(mode="json"), - }, - ) + self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status)) - def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None: + def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None: """Emitted when a batch is enqueued""" - self.__emit_queue_event( - event_name="batch_enqueued", - payload={ - "queue_id": enqueue_result.queue_id, - "batch_id": enqueue_result.batch.batch_id, - "enqueued": enqueue_result.enqueued, - }, - ) + self.dispatch(BatchEnqueuedEvent.build(enqueue_result)) def emit_queue_cleared(self, queue_id: str) -> None: - """Emitted when the queue is cleared""" - self.__emit_queue_event( - event_name="queue_cleared", - payload={"queue_id": queue_id}, - ) + """Emitted when a queue is cleared""" + self.dispatch(QueueClearedEvent.build(queue_id)) - def emit_download_started(self, source: str, download_path: str) -> None: - """ - Emit when a download job is started. + # endregion - :param url: The downloaded url - """ - self.__emit_download_event( - event_name="download_started", - payload={"source": source, "download_path": download_path}, - ) + # region Download - def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None: - """ - Emit "download_progress" events at regular intervals during a download job. + def emit_download_started(self, job: "DownloadJob") -> None: + """Emitted when a download is started""" + self.dispatch(DownloadStartedEvent.build(job)) - :param source: The downloaded source - :param download_path: The local downloaded file - :param current_bytes: Number of bytes downloaded so far - :param total_bytes: The size of the file being downloaded (if known) - """ - self.__emit_download_event( - event_name="download_progress", - payload={ - "source": source, - "download_path": download_path, - "current_bytes": current_bytes, - "total_bytes": total_bytes, - }, - ) + def emit_download_progress(self, job: "DownloadJob") -> None: + """Emitted at intervals during a download""" + self.dispatch(DownloadProgressEvent.build(job)) - def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None: - """ - Emit a "download_complete" event at the end of a successful download. 
+ def emit_download_complete(self, job: "DownloadJob") -> None: + """Emitted when a download is completed""" + self.dispatch(DownloadCompleteEvent.build(job)) - :param source: Source URL - :param download_path: Path to the locally downloaded file - :param total_bytes: The size of the downloaded file - """ - self.__emit_download_event( - event_name="download_complete", - payload={ - "source": source, - "download_path": download_path, - "total_bytes": total_bytes, - }, - ) + def emit_download_cancelled(self, job: "DownloadJob") -> None: + """Emitted when a download is cancelled""" + self.dispatch(DownloadCancelledEvent.build(job)) - def emit_download_cancelled(self, source: str) -> None: - """Emit a "download_cancelled" event in the event that the download was cancelled by user.""" - self.__emit_download_event( - event_name="download_cancelled", - payload={ - "source": source, - }, - ) + def emit_download_error(self, job: "DownloadJob") -> None: + """Emitted when a download encounters an error""" + self.dispatch(DownloadErrorEvent.build(job)) - def emit_download_error(self, source: str, error_type: str, error: str) -> None: - """ - Emit a "download_error" event when an download job encounters an exception. + # endregion - :param source: Source URL - :param error_type: The name of the exception that raised the error - :param error: The traceback from this error - """ - self.__emit_download_event( - event_name="download_error", - payload={ - "source": source, - "error_type": error_type, - "error": error, - }, - ) + # region Model loading - def emit_model_install_downloading( - self, - source: str, - local_path: str, - bytes: int, - total_bytes: int, - parts: List[Dict[str, Union[str, int]]], - id: int, + def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None: + """Emitted when a model load is started.""" + self.dispatch(ModelLoadStartedEvent.build(config, submodel_type)) + + def emit_model_load_complete( + self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None ) -> None: - """ - Emit at intervals while the install job is in progress (remote models only). + """Emitted when a model load is complete.""" + self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type)) - :param source: Source of the model - :param local_path: Where model is downloading to - :param parts: Progress of downloading URLs that comprise the model, if any. - :param bytes: Number of bytes downloaded so far. - :param total_bytes: Total size of download, including all files. - This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes". - """ - self.__emit_model_event( - event_name="model_install_downloading", - payload={ - "source": source, - "local_path": local_path, - "bytes": bytes, - "total_bytes": total_bytes, - "parts": parts, - "id": id, - }, - ) + # endregion - def emit_model_install_downloads_done(self, source: str) -> None: - """ - Emit once when all parts are downloaded, but before the probing and registration start. 
+ # region Model install - :param source: Source of the model; local path, repo_id or url - """ - self.__emit_model_event( - event_name="model_install_downloads_done", - payload={"source": source}, - ) + def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None: + """Emitted at intervals while the install job is in progress (remote models only).""" + self.dispatch(ModelInstallDownloadProgressEvent.build(job)) - def emit_model_install_running(self, source: str) -> None: - """ - Emit once when an install job becomes active. + def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None: + self.dispatch(ModelInstallDownloadsCompleteEvent.build(job)) - :param source: Source of the model; local path, repo_id or url - """ - self.__emit_model_event( - event_name="model_install_running", - payload={"source": source}, - ) + def emit_model_install_started(self, job: "ModelInstallJob") -> None: + """Emitted once when an install job is started (after any download).""" + self.dispatch(ModelInstallStartedEvent.build(job)) - def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None: - """ - Emit when an install job is completed successfully. + def emit_model_install_complete(self, job: "ModelInstallJob") -> None: + """Emitted when an install job is completed successfully.""" + self.dispatch(ModelInstallCompleteEvent.build(job)) - :param source: Source of the model; local path, repo_id or url - :param key: Model config record key - :param total_bytes: Size of the model (may be None for installation of a local path) - """ - self.__emit_model_event( - event_name="model_install_completed", - payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id}, - ) + def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None: + """Emitted when an install job is cancelled.""" + self.dispatch(ModelInstallCancelledEvent.build(job)) - def emit_model_install_cancelled(self, source: str, id: int) -> None: - """ - Emit when an install job is cancelled. + def emit_model_install_error(self, job: "ModelInstallJob") -> None: + """Emitted when an install job encounters an exception.""" + self.dispatch(ModelInstallErrorEvent.build(job)) - :param source: Source of the model; local path, repo_id or url - """ - self.__emit_model_event( - event_name="model_install_cancelled", - payload={"source": source, "id": id}, - ) + # endregion - def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None: - """ - Emit when an install job encounters an exception. 
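By this point the pattern is consistent: each emitter now accepts the relevant domain object (queue item, download job, install job) and its body is a single build-and-dispatch call. The snippet below is a self-contained analogue of that shape rather than InvokeAI code; the `Toy*` names are invented, while the `build()` classmethod and `dispatch()` call mirror the methods in this diff.

```py
# Illustrative analogue (not InvokeAI code) of the new build-and-dispatch emitter shape.
from time import time

from pydantic import BaseModel, Field


class ToyEvent(BaseModel):
    """Stand-in for EventBase: a typed, timestamped payload."""

    timestamp: int = Field(default_factory=lambda: int(time()))


class ToyJob(BaseModel):
    """Stand-in for a domain object such as DownloadJob."""

    source: str
    bytes: int = 0
    total_bytes: int = 0


class ToyDownloadProgress(ToyEvent):
    __event_name__ = "download_progress"

    source: str
    current_bytes: int
    total_bytes: int

    @classmethod
    def build(cls, job: ToyJob) -> "ToyDownloadProgress":
        # Mirrors DownloadProgressEvent.build(job): lift the fields off the job.
        return cls(source=job.source, current_bytes=job.bytes, total_bytes=job.total_bytes)


class ToyEventService:
    """Stand-in for EventServiceBase: each emitter is a one-line build-and-dispatch."""

    def dispatch(self, event: ToyDownloadProgress) -> None:
        print(event.__event_name__, event.model_dump())

    def emit_download_progress(self, job: ToyJob) -> None:
        self.dispatch(ToyDownloadProgress.build(job))


ToyEventService().emit_download_progress(
    ToyJob(source="https://example.com/model.safetensors", bytes=512, total_bytes=2048)
)
```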
- - :param source: Source of the model - :param error_type: The name of the exception - :param error: A text description of the exception - """ - self.__emit_model_event( - event_name="model_install_error", - payload={"source": source, "error_type": error_type, "error": error, "id": id}, - ) + # region Bulk image download def emit_bulk_download_started( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str ) -> None: - """Emitted when a bulk download starts""" - self._emit_bulk_download_event( - event_name="bulk_download_started", - payload={ - "bulk_download_id": bulk_download_id, - "bulk_download_item_id": bulk_download_item_id, - "bulk_download_item_name": bulk_download_item_name, - }, - ) + """Emitted when a bulk image download is started""" + self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) - def emit_bulk_download_completed( + def emit_bulk_download_complete( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str ) -> None: - """Emitted when a bulk download completes""" - self._emit_bulk_download_event( - event_name="bulk_download_completed", - payload={ - "bulk_download_id": bulk_download_id, - "bulk_download_item_id": bulk_download_item_id, - "bulk_download_item_name": bulk_download_item_name, - }, - ) + """Emitted when a bulk image download is complete""" + self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) - def emit_bulk_download_failed( + def emit_bulk_download_error( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str ) -> None: - """Emitted when a bulk download fails""" - self._emit_bulk_download_event( - event_name="bulk_download_failed", - payload={ - "bulk_download_id": bulk_download_id, - "bulk_download_item_id": bulk_download_item_id, - "bulk_download_item_name": bulk_download_item_name, - "error": error, - }, + """Emitted when a bulk image download has an error""" + self.dispatch( + BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error) ) + + # endregion diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py new file mode 100644 index 0000000000..7d3d489bf5 --- /dev/null +++ b/invokeai/app/services/events/events_common.py @@ -0,0 +1,591 @@ +from math import floor +from typing import TYPE_CHECKING, Any, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar + +from fastapi_events.handlers.local import local_handler +from fastapi_events.registry.payload_schema import registry as payload_schema +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny + +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput +from invokeai.app.services.session_processor.session_processor_common import ProgressImage +from invokeai.app.services.session_queue.session_queue_common import ( + QUEUE_ITEM_STATUS, + BatchStatus, + EnqueueBatchResult, + SessionQueueItem, + SessionQueueStatus, +) +from invokeai.app.util.misc import get_timestamp +from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState + +if TYPE_CHECKING: + from invokeai.app.services.download.download_base import DownloadJob + from invokeai.app.services.model_install.model_install_common import ModelInstallJob + + +class EventBase(BaseModel): + 
"""Base class for all events. All events must inherit from this class. + + Events must define a class attribute `__event_name__` to identify the event. + + All other attributes should be defined as normal for a pydantic model. + + A timestamp is automatically added to the event when it is created. + """ + + timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp) + + model_config = ConfigDict(json_schema_serialization_defaults_required=True) + + @classmethod + def get_events(cls) -> set[type["EventBase"]]: + """Get a set of all event models.""" + + event_subclasses: set[type["EventBase"]] = set() + for subclass in cls.__subclasses__(): + # We only want to include subclasses that are event models, not intermediary classes + if hasattr(subclass, "__event_name__"): + event_subclasses.add(subclass) + event_subclasses.update(subclass.get_events()) + + return event_subclasses + + +TEvent = TypeVar("TEvent", bound=EventBase, contravariant=True) + +FastAPIEvent: TypeAlias = tuple[str, TEvent] +""" +A tuple representing a `fastapi-events` event, with the event name and payload. +Provide a generic type to `TEvent` to specify the payload type. +""" + + +class FastAPIEventFunc(Protocol, Generic[TEvent]): + def __call__(self, event: FastAPIEvent[TEvent]) -> Optional[Coroutine[Any, Any, None]]: ... + + +def register_events(events: set[type[TEvent]] | type[TEvent], func: FastAPIEventFunc[TEvent]) -> None: + """Register a function to handle specific events. + + :param events: An event or set of events to handle + :param func: The function to handle the events + """ + events = events if isinstance(events, set) else {events} + for event in events: + assert hasattr(event, "__event_name__") + local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue] + + +class QueueEventBase(EventBase): + """Base class for queue events""" + + queue_id: str = Field(description="The ID of the queue") + + +class QueueItemEventBase(QueueEventBase): + """Base class for queue item events""" + + item_id: int = Field(description="The ID of the queue item") + batch_id: str = Field(description="The ID of the queue batch") + + +class InvocationEventBase(QueueItemEventBase): + """Base class for invocation events""" + + session_id: str = Field(description="The ID of the session (aka graph execution state)") + queue_id: str = Field(description="The ID of the queue") + item_id: int = Field(description="The ID of the queue item") + batch_id: str = Field(description="The ID of the queue batch") + session_id: str = Field(description="The ID of the session (aka graph execution state)") + invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation") + invocation_source_id: str = Field(description="The ID of the prepared invocation's source node") + + +@payload_schema.register +class InvocationStartedEvent(InvocationEventBase): + """Event model for invocation_started""" + + __event_name__ = "invocation_started" + + @classmethod + def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation=invocation, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + ) + + +@payload_schema.register +class InvocationDenoiseProgressEvent(InvocationEventBase): + 
"""Event model for invocation_denoise_progress""" + + __event_name__ = "invocation_denoise_progress" + + progress_image: ProgressImage = Field(description="The progress image sent at each step during processing") + step: int = Field(description="The current step of the invocation") + total_steps: int = Field(description="The total number of steps in the invocation") + order: int = Field(description="The order of the invocation in the session") + percentage: float = Field(description="The percentage of completion of the invocation") + + @classmethod + def build( + cls, + queue_item: SessionQueueItem, + invocation: BaseInvocation, + intermediate_state: PipelineIntermediateState, + progress_image: ProgressImage, + ) -> "InvocationDenoiseProgressEvent": + step = intermediate_state.step + total_steps = intermediate_state.total_steps + order = intermediate_state.order + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation=invocation, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + progress_image=progress_image, + step=step, + total_steps=total_steps, + order=order, + percentage=cls.calc_percentage(step, total_steps, order), + ) + + @staticmethod + def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float: + """Calculate the percentage of completion of denoising.""" + if total_steps == 0: + return 0.0 + if scheduler_order == 2: + return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2) + # order == 1 + return (step + 1 + 1) / (total_steps + 1) + + +@payload_schema.register +class InvocationCompleteEvent(InvocationEventBase): + """Event model for invocation_complete""" + + __event_name__ = "invocation_complete" + + result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation") + + @classmethod + def build( + cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput + ) -> "InvocationCompleteEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation=invocation, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + result=result, + ) + + +@payload_schema.register +class InvocationErrorEvent(InvocationEventBase): + """Event model for invocation_error""" + + __event_name__ = "invocation_error" + + error_type: str = Field(description="The error type") + error_message: str = Field(description="The error message") + error_traceback: str = Field(description="The error traceback") + user_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation") + project_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation") + + @classmethod + def build( + cls, + queue_item: SessionQueueItem, + invocation: BaseInvocation, + error_type: str, + error_message: str, + error_traceback: str, + ) -> "InvocationErrorEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation=invocation, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + user_id=getattr(queue_item, "user_id", None), + project_id=getattr(queue_item, "project_id", None), + ) + + 
+@payload_schema.register +class QueueItemStatusChangedEvent(QueueItemEventBase): + """Event model for queue_item_status_changed""" + + __event_name__ = "queue_item_status_changed" + + status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item") + error_type: Optional[str] = Field(default=None, description="The error type, if any") + error_message: Optional[str] = Field(default=None, description="The error message, if any") + error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any") + created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created") + updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated") + started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started") + completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed") + batch_status: BatchStatus = Field(description="The status of the batch") + queue_status: SessionQueueStatus = Field(description="The status of the queue") + session_id: str = Field(description="The ID of the session (aka graph execution state)") + + @classmethod + def build( + cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus + ) -> "QueueItemStatusChangedEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + status=queue_item.status, + error_type=queue_item.error_type, + error_message=queue_item.error_message, + error_traceback=queue_item.error_traceback, + created_at=str(queue_item.created_at) if queue_item.created_at else None, + updated_at=str(queue_item.updated_at) if queue_item.updated_at else None, + started_at=str(queue_item.started_at) if queue_item.started_at else None, + completed_at=str(queue_item.completed_at) if queue_item.completed_at else None, + batch_status=batch_status, + queue_status=queue_status, + ) + + +@payload_schema.register +class BatchEnqueuedEvent(QueueEventBase): + """Event model for batch_enqueued""" + + __event_name__ = "batch_enqueued" + + batch_id: str = Field(description="The ID of the batch") + enqueued: int = Field(description="The number of invocations enqueued") + requested: int = Field( + description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)" + ) + priority: int = Field(description="The priority of the batch") + + @classmethod + def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent": + return cls( + queue_id=enqueue_result.queue_id, + batch_id=enqueue_result.batch.batch_id, + enqueued=enqueue_result.enqueued, + requested=enqueue_result.requested, + priority=enqueue_result.priority, + ) + + +@payload_schema.register +class QueueClearedEvent(QueueEventBase): + """Event model for queue_cleared""" + + __event_name__ = "queue_cleared" + + @classmethod + def build(cls, queue_id: str) -> "QueueClearedEvent": + return cls(queue_id=queue_id) + + +class DownloadEventBase(EventBase): + """Base class for events associated with a download""" + + source: str = Field(description="The source of the download") + + +@payload_schema.register +class DownloadStartedEvent(DownloadEventBase): + """Event model for download_started""" + + __event_name__ = "download_started" + + download_path: str = Field(description="The local path where the download is saved") + + 
@classmethod + def build(cls, job: "DownloadJob") -> "DownloadStartedEvent": + assert job.download_path + return cls(source=str(job.source), download_path=job.download_path.as_posix()) + + +@payload_schema.register +class DownloadProgressEvent(DownloadEventBase): + """Event model for download_progress""" + + __event_name__ = "download_progress" + + download_path: str = Field(description="The local path where the download is saved") + current_bytes: int = Field(description="The number of bytes downloaded so far") + total_bytes: int = Field(description="The total number of bytes to be downloaded") + + @classmethod + def build(cls, job: "DownloadJob") -> "DownloadProgressEvent": + assert job.download_path + return cls( + source=str(job.source), + download_path=job.download_path.as_posix(), + current_bytes=job.bytes, + total_bytes=job.total_bytes, + ) + + +@payload_schema.register +class DownloadCompleteEvent(DownloadEventBase): + """Event model for download_complete""" + + __event_name__ = "download_complete" + + download_path: str = Field(description="The local path where the download is saved") + total_bytes: int = Field(description="The total number of bytes downloaded") + + @classmethod + def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent": + assert job.download_path + return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes) + + +@payload_schema.register +class DownloadCancelledEvent(DownloadEventBase): + """Event model for download_cancelled""" + + __event_name__ = "download_cancelled" + + @classmethod + def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent": + return cls(source=str(job.source)) + + +@payload_schema.register +class DownloadErrorEvent(DownloadEventBase): + """Event model for download_error""" + + __event_name__ = "download_error" + + error_type: str = Field(description="The type of error") + error: str = Field(description="The error message") + + @classmethod + def build(cls, job: "DownloadJob") -> "DownloadErrorEvent": + assert job.error_type + assert job.error + return cls(source=str(job.source), error_type=job.error_type, error=job.error) + + +class ModelEventBase(EventBase): + """Base class for events associated with a model""" + + +@payload_schema.register +class ModelLoadStartedEvent(ModelEventBase): + """Event model for model_load_started""" + + __event_name__ = "model_load_started" + + config: AnyModelConfig = Field(description="The model's config") + submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any") + + @classmethod + def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent": + return cls(config=config, submodel_type=submodel_type) + + +@payload_schema.register +class ModelLoadCompleteEvent(ModelEventBase): + """Event model for model_load_complete""" + + __event_name__ = "model_load_complete" + + config: AnyModelConfig = Field(description="The model's config") + submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any") + + @classmethod + def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent": + return cls(config=config, submodel_type=submodel_type) + + +@payload_schema.register +class ModelInstallDownloadProgressEvent(ModelEventBase): + """Event model for model_install_download_progress""" + + __event_name__ = "model_install_download_progress" + + id: int = Field(description="The ID of the 
install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + local_path: str = Field(description="Where model is downloading to") + bytes: int = Field(description="Number of bytes downloaded so far") + total_bytes: int = Field(description="Total size of download, including all files") + parts: list[dict[str, int | str]] = Field( + description="Progress of downloading URLs that comprise the model, if any" + ) + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent": + parts: list[dict[str, str | int]] = [ + { + "url": str(x.source), + "local_path": str(x.download_path), + "bytes": x.bytes, + "total_bytes": x.total_bytes, + } + for x in job.download_parts + ] + return cls( + id=job.id, + source=str(job.source), + local_path=job.local_path.as_posix(), + parts=parts, + bytes=job.bytes, + total_bytes=job.total_bytes, + ) + + +@payload_schema.register +class ModelInstallDownloadsCompleteEvent(ModelEventBase): + """Emitted once when an install job becomes active.""" + + __event_name__ = "model_install_downloads_complete" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent": + return cls(id=job.id, source=str(job.source)) + + +@payload_schema.register +class ModelInstallStartedEvent(ModelEventBase): + """Event model for model_install_started""" + + __event_name__ = "model_install_started" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent": + return cls(id=job.id, source=str(job.source)) + + +@payload_schema.register +class ModelInstallCompleteEvent(ModelEventBase): + """Event model for model_install_complete""" + + __event_name__ = "model_install_complete" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + key: str = Field(description="Model config record key") + total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent": + assert job.config_out is not None + return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes) + + +@payload_schema.register +class ModelInstallCancelledEvent(ModelEventBase): + """Event model for model_install_cancelled""" + + __event_name__ = "model_install_cancelled" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent": + return cls(id=job.id, source=str(job.source)) + + +@payload_schema.register +class ModelInstallErrorEvent(ModelEventBase): + """Event model for model_install_error""" + + __event_name__ = "model_install_error" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + error_type: str = Field(description="The name of the exception") + error: str = Field(description="A text description of the exception") + + @classmethod + def build(cls, job: "ModelInstallJob") -> 
"ModelInstallErrorEvent": + assert job.error_type is not None + assert job.error is not None + return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error) + + +class BulkDownloadEventBase(EventBase): + """Base class for events associated with a bulk image download""" + + bulk_download_id: str = Field(description="The ID of the bulk image download") + bulk_download_item_id: str = Field(description="The ID of the bulk image download item") + bulk_download_item_name: str = Field(description="The name of the bulk image download item") + + +@payload_schema.register +class BulkDownloadStartedEvent(BulkDownloadEventBase): + """Event model for bulk_download_started""" + + __event_name__ = "bulk_download_started" + + @classmethod + def build( + cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + ) -> "BulkDownloadStartedEvent": + return cls( + bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, + ) + + +@payload_schema.register +class BulkDownloadCompleteEvent(BulkDownloadEventBase): + """Event model for bulk_download_complete""" + + __event_name__ = "bulk_download_complete" + + @classmethod + def build( + cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + ) -> "BulkDownloadCompleteEvent": + return cls( + bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, + ) + + +@payload_schema.register +class BulkDownloadErrorEvent(BulkDownloadEventBase): + """Event model for bulk_download_error""" + + __event_name__ = "bulk_download_error" + + error: str = Field(description="The error message") + + @classmethod + def build( + cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str + ) -> "BulkDownloadErrorEvent": + return cls( + bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, + error=error, + ) diff --git a/invokeai/app/services/events/events_fastapievents.py b/invokeai/app/services/events/events_fastapievents.py new file mode 100644 index 0000000000..8279d3bb34 --- /dev/null +++ b/invokeai/app/services/events/events_fastapievents.py @@ -0,0 +1,47 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +import asyncio +import threading +from queue import Empty, Queue + +from fastapi_events.dispatcher import dispatch + +from invokeai.app.services.events.events_common import ( + EventBase, +) + +from .events_base import EventServiceBase + + +class FastAPIEventService(EventServiceBase): + def __init__(self, event_handler_id: int) -> None: + self.event_handler_id = event_handler_id + self._queue = Queue[EventBase | None]() + self._stop_event = threading.Event() + asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event)) + + super().__init__() + + def stop(self, *args, **kwargs): + self._stop_event.set() + self._queue.put(None) + + def dispatch(self, event: EventBase) -> None: + self._queue.put(event) + + async def _dispatch_from_queue(self, stop_event: threading.Event): + """Get events on from the queue and dispatch them, from the correct thread""" + while not stop_event.is_set(): + try: + event = self._queue.get(block=False) + if not event: # Probably stopping + continue + # Leave the payloads as live pydantic models + dispatch(event, middleware_id=self.event_handler_id, 
payload_schema_dump=False) + + except Empty: + await asyncio.sleep(0.1) + pass + + except asyncio.CancelledError as e: + raise e # Raise a proper error diff --git a/invokeai/app/services/model_install/__init__.py b/invokeai/app/services/model_install/__init__.py index 00a33c203e..941485a134 100644 --- a/invokeai/app/services/model_install/__init__.py +++ b/invokeai/app/services/model_install/__init__.py @@ -1,11 +1,13 @@ """Initialization file for model install service package.""" from .model_install_base import ( + ModelInstallServiceBase, +) +from .model_install_common import ( HFModelSource, InstallStatus, LocalModelSource, ModelInstallJob, - ModelInstallServiceBase, ModelSource, UnknownInstallJobException, URLModelSource, diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index b622c8dade..76b77f0419 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -1,242 +1,17 @@ # Copyright 2023 Lincoln D. Stein and the InvokeAI development team """Baseclass definitions for the model installer.""" -import re -import traceback from abc import ABC, abstractmethod -from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Union - -from pydantic import BaseModel, Field, PrivateAttr, field_validator -from pydantic.networks import AnyHttpUrl -from typing_extensions import Annotated +from typing import Any, Dict, List, Optional, Union from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob +from invokeai.app.services.download import DownloadQueueServiceBase from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker +from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource from invokeai.app.services.model_records import ModelRecordServiceBase -from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant -from invokeai.backend.model_manager.config import ModelSourceType -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata - - -class InstallStatus(str, Enum): - """State of an install job running in the background.""" - - WAITING = "waiting" # waiting to be dequeued - DOWNLOADING = "downloading" # downloading of model files in process - DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run - RUNNING = "running" # being processed - COMPLETED = "completed" # finished running - ERROR = "error" # terminated with an error message - CANCELLED = "cancelled" # terminated with an error message - - -class ModelInstallPart(BaseModel): - url: AnyHttpUrl - path: Path - bytes: int = 0 - total_bytes: int = 0 - - -class UnknownInstallJobException(Exception): - """Raised when the status of an unknown job is requested.""" - - -class StringLikeSource(BaseModel): - """ - Base class for model sources, implements functions that lets the source be sorted and indexed. 
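The string-like behaviour these source classes implement (sorting, hashing and equality on their stringified form) moves verbatim into `model_install_common.py` below. A short sketch of what that buys callers, assuming an InvokeAI environment; the path and repo ID are made up.

```py
# Sketch of the string-like behaviour of model sources.
from invokeai.app.services.model_install import HFModelSource, LocalModelSource

local = LocalModelSource(path="C:/users/mort/foo.safetensors")
assert str(local) == "C:/users/mort/foo.safetensors"
assert local == "C:/users/mort/foo.safetensors"  # __eq__ compares the stringified forms
assert {local: "model 1"}["C:/users/mort/foo.safetensors"] == "model 1"  # __hash__ is hash(str(...))

hf = HFModelSource(repo_id="stabilityai/stable-diffusion-2", subfolder="vae")
print(str(hf))  # repo_id plus ":<variant>" (defaults to fp16) and ":<subfolder>" when set
```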
- - These shenanigans let this stuff work: - - source1 = LocalModelSource(path='C:/users/mort/foo.safetensors') - mydict = {source1: 'model 1'} - assert mydict['C:/users/mort/foo.safetensors'] == 'model 1' - assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1' - - source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors')) - assert source1 == source2 - assert source1 == 'C:/users/mort/foo.safetensors' - """ - - def __hash__(self) -> int: - """Return hash of the path field, for indexing.""" - return hash(str(self)) - - def __lt__(self, other: object) -> int: - """Return comparison of the stringified version, for sorting.""" - return str(self) < str(other) - - def __eq__(self, other: object) -> bool: - """Return equality on the stringified version.""" - if isinstance(other, Path): - return str(self) == other.as_posix() - else: - return str(self) == str(other) - - -class LocalModelSource(StringLikeSource): - """A local file or directory path.""" - - path: str | Path - inplace: Optional[bool] = False - type: Literal["local"] = "local" - - # these methods allow the source to be used in a string-like way, - # for example as an index into a dict - def __str__(self) -> str: - """Return string version of path when string rep needed.""" - return Path(self.path).as_posix() - - -class HFModelSource(StringLikeSource): - """ - A HuggingFace repo_id with optional variant, sub-folder and access token. - Note that the variant option, if not provided to the constructor, will default to fp16, which is - what people (almost) always want. - """ - - repo_id: str - variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16 - subfolder: Optional[Path] = None - access_token: Optional[str] = None - type: Literal["hf"] = "hf" - - @field_validator("repo_id") - @classmethod - def proper_repo_id(cls, v: str) -> str: # noqa D102 - if not re.match(r"^([.\w-]+/[.\w-]+)$", v): - raise ValueError(f"{v}: invalid repo_id format") - return v - - def __str__(self) -> str: - """Return string version of repoid when string rep needed.""" - base: str = self.repo_id - if self.variant: - base += f":{self.variant or ''}" - if self.subfolder: - base += f":{self.subfolder}" - return base - - -class URLModelSource(StringLikeSource): - """A generic URL point to a checkpoint file.""" - - url: AnyHttpUrl - access_token: Optional[str] = None - type: Literal["url"] = "url" - - def __str__(self) -> str: - """Return string version of the url when string rep needed.""" - return str(self.url) - - -ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")] - -MODEL_SOURCE_TO_TYPE_MAP = { - URLModelSource: ModelSourceType.Url, - HFModelSource: ModelSourceType.HFRepoID, - LocalModelSource: ModelSourceType.Path, -} - - -class ModelInstallJob(BaseModel): - """Object that tracks the current status of an install request.""" - - id: int = Field(description="Unique ID for this job") - status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") - error_reason: Optional[str] = Field(default=None, description="Information about why the job failed") - config_in: Dict[str, Any] = Field( - default_factory=dict, description="Configuration information (e.g. 'description') to apply to model." - ) - config_out: Optional[AnyModelConfig] = Field( - default=None, description="After successful installation, this will hold the configuration object." 
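Apart from the new `download_parts` field added in `model_install_common.py`, `ModelInstallJob` and its status helpers move unchanged. A schematic sketch of the error helpers, assuming an InvokeAI environment; the job values are invented.

```py
# Sketch of the ModelInstallJob error helpers.
from pathlib import Path

from invokeai.app.services.model_install import LocalModelSource, ModelInstallJob

job = ModelInstallJob(
    id=1,
    source=LocalModelSource(path="/tmp/foo.safetensors"),
    local_path=Path("/tmp/foo.safetensors"),
)

try:
    raise ValueError("unrecognized checkpoint format")
except ValueError as e:
    job.set_error(e)  # records error, error_traceback, error_reason and sets status=ERROR

assert job.errored and job.in_terminal_state
assert job.error_type == "ValueError"  # class name of the stored exception
```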
- ) - inplace: bool = Field( - default=False, description="Leave model in its current location; otherwise install under models directory" - ) - source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model") - local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source") - bytes: int = Field( - default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)" - ) - total_bytes: int = Field(default=0, description="Total size of the model to be installed") - source_metadata: Optional[AnyModelRepoMetadata] = Field( - default=None, description="Metadata provided by the model source" - ) - error: Optional[str] = Field( - default=None, description="On an error condition, this field will contain the text of the exception" - ) - error_traceback: Optional[str] = Field( - default=None, description="On an error condition, this field will contain the exception traceback" - ) - # internal flags and transitory settings - _install_tmpdir: Optional[Path] = PrivateAttr(default=None) - _download_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None) - _exception: Optional[Exception] = PrivateAttr(default=None) - - def set_error(self, e: Exception) -> None: - """Record the error and traceback from an exception.""" - self._exception = e - self.error = str(e) - self.error_traceback = self._format_error(e) - self.status = InstallStatus.ERROR - self.error_reason = self._exception.__class__.__name__ if self._exception else None - - def cancel(self) -> None: - """Call to cancel the job.""" - self.status = InstallStatus.CANCELLED - - @property - def error_type(self) -> Optional[str]: - """Class name of the exception that led to status==ERROR.""" - return self._exception.__class__.__name__ if self._exception else None - - def _format_error(self, exception: Exception) -> str: - """Error traceback.""" - return "".join(traceback.format_exception(exception)) - - @property - def cancelled(self) -> bool: - """Set status to CANCELLED.""" - return self.status == InstallStatus.CANCELLED - - @property - def errored(self) -> bool: - """Return true if job has errored.""" - return self.status == InstallStatus.ERROR - - @property - def waiting(self) -> bool: - """Return true if job is waiting to run.""" - return self.status == InstallStatus.WAITING - - @property - def downloading(self) -> bool: - """Return true if job is downloading.""" - return self.status == InstallStatus.DOWNLOADING - - @property - def downloads_done(self) -> bool: - """Return true if job's downloads ae done.""" - return self.status == InstallStatus.DOWNLOADS_DONE - - @property - def running(self) -> bool: - """Return true if job is running.""" - return self.status == InstallStatus.RUNNING - - @property - def complete(self) -> bool: - """Return true if job completed without errors.""" - return self.status == InstallStatus.COMPLETED - - @property - def in_terminal_state(self) -> bool: - """Return true if job is in a terminal state.""" - return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED] +from invokeai.backend.model_manager import AnyModelConfig class ModelInstallServiceBase(ABC): @@ -280,7 +55,7 @@ class ModelInstallServiceBase(ABC): @property @abstractmethod - def event_bus(self) -> Optional[EventServiceBase]: + def event_bus(self) -> Optional["EventServiceBase"]: """Return the event service base object associated with the installer.""" @abstractmethod diff --git 
a/invokeai/app/services/model_install/model_install_common.py b/invokeai/app/services/model_install/model_install_common.py new file mode 100644 index 0000000000..751b5baa4b --- /dev/null +++ b/invokeai/app/services/model_install/model_install_common.py @@ -0,0 +1,227 @@ +import re +import traceback +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Literal, Optional, Set, Union + +from pydantic import BaseModel, Field, PrivateAttr, field_validator +from pydantic.networks import AnyHttpUrl +from typing_extensions import Annotated + +from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob +from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant +from invokeai.backend.model_manager.config import ModelSourceType +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + + +class InstallStatus(str, Enum): + """State of an install job running in the background.""" + + WAITING = "waiting" # waiting to be dequeued + DOWNLOADING = "downloading" # downloading of model files in process + DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run + RUNNING = "running" # being processed + COMPLETED = "completed" # finished running + ERROR = "error" # terminated with an error message + CANCELLED = "cancelled" # terminated with an error message + + +class UnknownInstallJobException(Exception): + """Raised when the status of an unknown job is requested.""" + + +class StringLikeSource(BaseModel): + """ + Base class for model sources, implements functions that lets the source be sorted and indexed. + + These shenanigans let this stuff work: + + source1 = LocalModelSource(path='C:/users/mort/foo.safetensors') + mydict = {source1: 'model 1'} + assert mydict['C:/users/mort/foo.safetensors'] == 'model 1' + assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1' + + source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors')) + assert source1 == source2 + assert source1 == 'C:/users/mort/foo.safetensors' + """ + + def __hash__(self) -> int: + """Return hash of the path field, for indexing.""" + return hash(str(self)) + + def __lt__(self, other: object) -> int: + """Return comparison of the stringified version, for sorting.""" + return str(self) < str(other) + + def __eq__(self, other: object) -> bool: + """Return equality on the stringified version.""" + if isinstance(other, Path): + return str(self) == other.as_posix() + else: + return str(self) == str(other) + + +class LocalModelSource(StringLikeSource): + """A local file or directory path.""" + + path: str | Path + inplace: Optional[bool] = False + type: Literal["local"] = "local" + + # these methods allow the source to be used in a string-like way, + # for example as an index into a dict + def __str__(self) -> str: + """Return string version of path when string rep needed.""" + return Path(self.path).as_posix() + + +class HFModelSource(StringLikeSource): + """ + A HuggingFace repo_id with optional variant, sub-folder and access token. + Note that the variant option, if not provided to the constructor, will default to fp16, which is + what people (almost) always want. 
+ """ + + repo_id: str + variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16 + subfolder: Optional[Path] = None + access_token: Optional[str] = None + type: Literal["hf"] = "hf" + + @field_validator("repo_id") + @classmethod + def proper_repo_id(cls, v: str) -> str: # noqa D102 + if not re.match(r"^([.\w-]+/[.\w-]+)$", v): + raise ValueError(f"{v}: invalid repo_id format") + return v + + def __str__(self) -> str: + """Return string version of repoid when string rep needed.""" + base: str = self.repo_id + if self.variant: + base += f":{self.variant or ''}" + if self.subfolder: + base += f":{self.subfolder}" + return base + + +class URLModelSource(StringLikeSource): + """A generic URL point to a checkpoint file.""" + + url: AnyHttpUrl + access_token: Optional[str] = None + type: Literal["url"] = "url" + + def __str__(self) -> str: + """Return string version of the url when string rep needed.""" + return str(self.url) + + +ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")] + +MODEL_SOURCE_TO_TYPE_MAP = { + URLModelSource: ModelSourceType.Url, + HFModelSource: ModelSourceType.HFRepoID, + LocalModelSource: ModelSourceType.Path, +} + + +class ModelInstallJob(BaseModel): + """Object that tracks the current status of an install request.""" + + id: int = Field(description="Unique ID for this job") + status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") + error_reason: Optional[str] = Field(default=None, description="Information about why the job failed") + config_in: Dict[str, Any] = Field( + default_factory=dict, description="Configuration information (e.g. 'description') to apply to model." + ) + config_out: Optional[AnyModelConfig] = Field( + default=None, description="After successful installation, this will hold the configuration object." 
+ ) + inplace: bool = Field( + default=False, description="Leave model in its current location; otherwise install under models directory" + ) + source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model") + local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source") + bytes: int = Field( + default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)" + ) + total_bytes: int = Field(default=0, description="Total size of the model to be installed") + source_metadata: Optional[AnyModelRepoMetadata] = Field( + default=None, description="Metadata provided by the model source" + ) + download_parts: Set[DownloadJob] = Field( + default_factory=set, description="Download jobs contributing to this install" + ) + error: Optional[str] = Field( + default=None, description="On an error condition, this field will contain the text of the exception" + ) + error_traceback: Optional[str] = Field( + default=None, description="On an error condition, this field will contain the exception traceback" + ) + # internal flags and transitory settings + _install_tmpdir: Optional[Path] = PrivateAttr(default=None) + _download_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None) + _exception: Optional[Exception] = PrivateAttr(default=None) + + def set_error(self, e: Exception) -> None: + """Record the error and traceback from an exception.""" + self._exception = e + self.error = str(e) + self.error_traceback = self._format_error(e) + self.status = InstallStatus.ERROR + self.error_reason = self._exception.__class__.__name__ if self._exception else None + + def cancel(self) -> None: + """Call to cancel the job.""" + self.status = InstallStatus.CANCELLED + + @property + def error_type(self) -> Optional[str]: + """Class name of the exception that led to status==ERROR.""" + return self._exception.__class__.__name__ if self._exception else None + + def _format_error(self, exception: Exception) -> str: + """Error traceback.""" + return "".join(traceback.format_exception(exception)) + + @property + def cancelled(self) -> bool: + """Set status to CANCELLED.""" + return self.status == InstallStatus.CANCELLED + + @property + def errored(self) -> bool: + """Return true if job has errored.""" + return self.status == InstallStatus.ERROR + + @property + def waiting(self) -> bool: + """Return true if job is waiting to run.""" + return self.status == InstallStatus.WAITING + + @property + def downloading(self) -> bool: + """Return true if job is downloading.""" + return self.status == InstallStatus.DOWNLOADING + + @property + def downloads_done(self) -> bool: + """Return true if job's downloads ae done.""" + return self.status == InstallStatus.DOWNLOADS_DONE + + @property + def running(self) -> bool: + """Return true if job is running.""" + return self.status == InstallStatus.RUNNING + + @property + def complete(self) -> bool: + """Return true if job completed without errors.""" + return self.status == InstallStatus.COMPLETED + + @property + def in_terminal_state(self) -> bool: + """Return true if job is in a terminal state.""" + return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED] diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index cde9a6502e..c78a09ce87 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ 
b/invokeai/app/services/model_install/model_install_default.py @@ -9,7 +9,7 @@ from pathlib import Path from queue import Empty, Queue from shutil import copyfile, copytree, move, rmtree from tempfile import mkdtemp -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import torch import yaml @@ -21,6 +21,7 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker +from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.backend.model_manager.config import ( @@ -45,13 +46,12 @@ from invokeai.backend.util.catch_sigint import catch_sigint from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.util import slugify -from .model_install_base import ( +from .model_install_common import ( MODEL_SOURCE_TO_TYPE_MAP, HFModelSource, InstallStatus, LocalModelSource, ModelInstallJob, - ModelInstallServiceBase, ModelSource, StringLikeSource, URLModelSource, @@ -59,6 +59,9 @@ from .model_install_base import ( TMPDIR_PREFIX = "tmpinstall_" +if TYPE_CHECKING: + from invokeai.app.services.events.events_base import EventServiceBase + class ModelInstallService(ModelInstallServiceBase): """class for InvokeAI model installation.""" @@ -68,7 +71,7 @@ class ModelInstallService(ModelInstallServiceBase): app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, - event_bus: Optional[EventServiceBase] = None, + event_bus: Optional["EventServiceBase"] = None, session: Optional[Session] = None, ): """ @@ -104,7 +107,7 @@ class ModelInstallService(ModelInstallServiceBase): return self._record_store @property - def event_bus(self) -> Optional[EventServiceBase]: # noqa D102 + def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102 return self._event_bus # make the invoker optional here because we don't need it and it @@ -825,6 +828,7 @@ class ModelInstallService(ModelInstallServiceBase): else: # update sizes install_job.bytes = sum(x.bytes for x in download_job.download_parts) + install_job.download_parts = download_job.download_parts self._signal_job_downloading(install_job) def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None: @@ -864,36 +868,20 @@ class ModelInstallService(ModelInstallServiceBase): job.status = InstallStatus.RUNNING self._logger.info(f"Model install started: {job.source}") if self._event_bus: - self._event_bus.emit_model_install_running(str(job.source)) + self._event_bus.emit_model_install_started(job) def _signal_job_downloading(self, job: ModelInstallJob) -> None: if self._event_bus: assert job._download_job is not None - parts: List[Dict[str, str | int]] = [ - { - "url": str(x.source), - "local_path": str(x.download_path), - "bytes": x.bytes, - "total_bytes": x.total_bytes, - } - for x in job._download_job.download_parts - ] assert job.bytes is not None assert job.total_bytes is not None - self._event_bus.emit_model_install_downloading( - str(job.source), - local_path=job.local_path.as_posix(), - parts=parts, - bytes=sum(x["bytes"] for x in parts), - 
total_bytes=sum(x["total_bytes"] for x in parts), - id=job.id, - ) + self._event_bus.emit_model_install_download_progress(job) def _signal_job_downloads_done(self, job: ModelInstallJob) -> None: job.status = InstallStatus.DOWNLOADS_DONE self._logger.info(f"Model download complete: {job.source}") if self._event_bus: - self._event_bus.emit_model_install_downloads_done(str(job.source)) + self._event_bus.emit_model_install_downloads_complete(job) def _signal_job_completed(self, job: ModelInstallJob) -> None: job.status = InstallStatus.COMPLETED @@ -903,24 +891,19 @@ class ModelInstallService(ModelInstallServiceBase): if self._event_bus: assert job.local_path is not None assert job.config_out is not None - key = job.config_out.key - self._event_bus.emit_model_install_completed( - source=str(job.source), key=key, id=job.id, total_bytes=job.bytes - ) + self._event_bus.emit_model_install_complete(job) def _signal_job_errored(self, job: ModelInstallJob) -> None: self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}") if self._event_bus: - error_type = job.error_type - error = job.error - assert error_type is not None - assert error is not None - self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id) + assert job.error_type is not None + assert job.error is not None + self._event_bus.emit_model_install_error(job) def _signal_job_cancelled(self, job: ModelInstallJob) -> None: self._logger.info(f"Model install canceled: {job.source}") if self._event_bus: - self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id) + self._event_bus.emit_model_install_cancelled(job) @staticmethod def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]: diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 7de36793fb..22d815483e 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -7,7 +7,6 @@ from typing import Callable, Dict, Optional from torch import Tensor -from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase @@ -18,18 +17,12 @@ class ModelLoadServiceBase(ABC): """Wrapper around AnyModelLoader.""" @abstractmethod - def load_model( - self, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Given a model's configuration, load it and return the LoadedModel object. :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param submodel: For main (pipeline models), the submodel to fetch. 
- :param context_data: Invocation context data used for event reporting """ @property diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index cd14235ee0..221a042da5 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -11,7 +11,6 @@ from torch import load as torch_load from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.invoker import Invoker -from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import ( LoadedModel, @@ -59,25 +58,18 @@ class ModelLoadService(ModelLoadServiceBase): """Return the checkpoint convert cache used by this loader.""" return self._convert_cache - def load_model( - self, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Given a model's configuration, load it and return the LoadedModel object. :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param submodel: For main (pipeline models), the submodel to fetch. - :param context: Invocation context used for event reporting """ - if context_data: - self._emit_load_event( - context_data=context_data, - model_config=model_config, - submodel_type=submodel_type, - ) + + # We don't have an invoker during testing + # TODO(psyche): Mock this method on the invoker in the tests + if hasattr(self, "_invoker"): + self._invoker.services.events.emit_model_load_started(model_config, submodel_type) implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore loaded_model: LoadedModel = implementation( @@ -87,13 +79,9 @@ class ModelLoadService(ModelLoadServiceBase): convert_cache=self._convert_cache, ).load_model(model_config, submodel_type) - if context_data: - self._emit_load_event( - context_data=context_data, - model_config=model_config, - submodel_type=submodel_type, - loaded=True, - ) + if hasattr(self, "_invoker"): + self._invoker.services.events.emit_model_load_complete(model_config, submodel_type) + return loaded_model def load_model_from_path( @@ -150,32 +138,3 @@ class ModelLoadService(ModelLoadServiceBase): raw_model = loader(model_path) ram_cache.put(key=cache_key, model=raw_model) return LoadedModel(_locker=ram_cache.get(key=cache_key)) - - def _emit_load_event( - self, - context_data: InvocationContextData, - model_config: AnyModelConfig, - loaded: Optional[bool] = False, - submodel_type: Optional[SubModelType] = None, - ) -> None: - if not self._invoker: - return - - if not loaded: - self._invoker.services.events.emit_model_load_started( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - graph_execution_state_id=context_data.queue_item.session_id, - model_config=model_config, - submodel_type=submodel_type, - ) - else: - self._invoker.services.events.emit_model_load_completed( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - 
graph_execution_state_id=context_data.queue_item.session_id, - model_config=model_config, - submodel_type=submodel_type, - ) diff --git a/invokeai/app/services/session_processor/session_processor_base.py b/invokeai/app/services/session_processor/session_processor_base.py index 485ef2f8c3..15611bb5f8 100644 --- a/invokeai/app/services/session_processor/session_processor_base.py +++ b/invokeai/app/services/session_processor/session_processor_base.py @@ -1,6 +1,49 @@ from abc import ABC, abstractmethod +from threading import Event +from typing import Optional, Protocol +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput +from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus +from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem +from invokeai.app.util.profiler import Profiler + + +class SessionRunnerBase(ABC): + """ + Base class for session runner. + """ + + @abstractmethod + def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None: + """Starts the session runner. + + Args: + services: The invocation services. + cancel_event: The cancel event. + profiler: The profiler to use for session profiling via cProfile. Omit to disable profiling. Basic session + stats will still be recorded and logged when profiling is disabled. + """ + pass + + @abstractmethod + def run(self, queue_item: SessionQueueItem) -> None: + """Runs a session. + + Args: + queue_item: The session to run. + """ + pass + + @abstractmethod + def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None: + """Run a single node in the graph. + + Args: + invocation: The invocation to run. + queue_item: The session queue item. + """ + pass class SessionProcessorBase(ABC): @@ -26,3 +69,85 @@ class SessionProcessorBase(ABC): def get_status(self) -> SessionProcessorStatus: """Gets the status of the session processor""" pass + + +class OnBeforeRunNode(Protocol): + def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None: + """Callback to run before executing a node. + + Args: + invocation: The invocation that will be executed. + queue_item: The session queue item. + """ + ... + + +class OnAfterRunNode(Protocol): + def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None: + """Callback to run after executing a node. + + Args: + invocation: The invocation that was executed. + queue_item: The session queue item. + output: The output produced by the invocation. + """ + ... + + +class OnNodeError(Protocol): + def __call__( + self, + invocation: BaseInvocation, + queue_item: SessionQueueItem, + error_type: str, + error_message: str, + error_traceback: str, + ) -> None: + """Callback to run when a node has an error. + + Args: + invocation: The invocation that errored. + queue_item: The session queue item. + error_type: The type of error, e.g. "ValueError". + error_message: The error message, e.g. "Invalid value". + error_traceback: The stringified error traceback. + """ + ... + + +class OnBeforeRunSession(Protocol): + def __call__(self, queue_item: SessionQueueItem) -> None: + """Callback to run before executing a session. + + Args: + queue_item: The session queue item. + """ + ... + + +class OnAfterRunSession(Protocol): + def __call__(self, queue_item: SessionQueueItem) -> None: + """Callback to run after executing a session.
+ + Args: + queue_item: The session queue item. + """ + ... + + +class OnNonFatalProcessorError(Protocol): + def __call__( + self, + queue_item: Optional[SessionQueueItem], + error_type: str, + error_message: str, + error_traceback: str, + ) -> None: + """Callback to run when a non-fatal error occurs in the processor. + + Args: + queue_item: The session queue item, if one was being executed when the error occurred. + error_type: The type of error, e.g. "ValueError". + error_message: The error message, e.g. "Invalid value". + error_traceback: The stringified error traceback. + """ + ... diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 894996b1e6..3f348fb239 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -4,24 +4,325 @@ from threading import BoundedSemaphore, Thread from threading import Event as ThreadEvent from typing import Optional -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event as FastAPIEvent - -from invokeai.app.invocations.baseinvocation import BaseInvocation -from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput +from invokeai.app.services.events.events_common import ( + BatchEnqueuedEvent, + FastAPIEvent, + QueueClearedEvent, + QueueItemStatusChangedEvent, + register_events, +) from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError +from invokeai.app.services.session_processor.session_processor_base import ( + OnAfterRunNode, + OnAfterRunSession, + OnBeforeRunNode, + OnBeforeRunSession, + OnNodeError, + OnNonFatalProcessorError, +) from invokeai.app.services.session_processor.session_processor_common import CanceledException -from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem +from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem, SessionQueueItemNotFoundError +from invokeai.app.services.shared.graph import NodeInputError from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.util.profiler import Profiler from ..invoker import Invoker -from .session_processor_base import SessionProcessorBase +from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase from .session_processor_common import SessionProcessorStatus +class DefaultSessionRunner(SessionRunnerBase): + """Processes a single session's invocations.""" + + def __init__( + self, + on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None, + on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None, + on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None, + on_node_error_callbacks: Optional[list[OnNodeError]] = None, + on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None, + ): + """ + Args: + on_before_run_session_callbacks: Callbacks to run before the session starts. + on_before_run_node_callbacks: Callbacks to run before each node starts. + on_after_run_node_callbacks: Callbacks to run after each node completes. + on_node_error_callbacks: Callbacks to run when a node errors. + on_after_run_session_callbacks: Callbacks to run after the session completes. 
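These hooks let downstream code extend session handling without subclassing the runner. A minimal sketch of wiring a custom callback into the runner and processor defined in this file (the logging callback itself is hypothetical):

```py
from invokeai.app.services.session_processor.session_processor_default import (
    DefaultSessionProcessor,
    DefaultSessionRunner,
)
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem


def announce_session(queue_item: SessionQueueItem) -> None:
    # Matches the OnBeforeRunSession protocol; invoked as callback(queue_item=queue_item).
    print(f"Starting session {queue_item.session_id} (queue item {queue_item.item_id})")


runner = DefaultSessionRunner(on_before_run_session_callbacks=[announce_session])
processor = DefaultSessionProcessor(session_runner=runner)
```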
+ """ + + self._on_before_run_session_callbacks = on_before_run_session_callbacks or [] + self._on_before_run_node_callbacks = on_before_run_node_callbacks or [] + self._on_after_run_node_callbacks = on_after_run_node_callbacks or [] + self._on_node_error_callbacks = on_node_error_callbacks or [] + self._on_after_run_session_callbacks = on_after_run_session_callbacks or [] + + def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None): + self._services = services + self._cancel_event = cancel_event + self._profiler = profiler + + def _is_canceled(self) -> bool: + """Check if the cancel event is set. This is also passed to the invocation context builder and called during + denoising to check if the session has been canceled.""" + return self._cancel_event.is_set() + + def run(self, queue_item: SessionQueueItem): + # Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here. + + self._on_before_run_session(queue_item=queue_item) + + # Loop over invocations until the session is complete or canceled + while True: + try: + invocation = queue_item.session.next() + # Anything other than a `NodeInputError` is handled as a processor error + except NodeInputError as e: + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._on_node_error( + invocation=e.node, + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) + break + + if invocation is None or self._is_canceled(): + break + + self.run_node(invocation, queue_item) + + # The session is complete if all invocations have been run or there is an error on the session. + # At this time, the queue item may be canceled, but the object itself here won't be updated yet. We must + # use the cancel event to check if the session is canceled. + if ( + queue_item.session.is_complete() + or self._is_canceled() + or queue_item.status in ["failed", "canceled", "completed"] + ): + break + + self._on_after_run_session(queue_item=queue_item) + + def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): + try: + # Any unhandled exception in this scope is an invocation error & will fail the graph + with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id): + self._on_before_run_node(invocation, queue_item) + + data = InvocationContextData( + invocation=invocation, + source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id], + queue_item=queue_item, + ) + context = build_invocation_context( + data=data, + services=self._services, + is_canceled=self._is_canceled, + ) + + # Invoke the node + output = invocation.invoke_internal(context=context, services=self._services) + # Save output and history + queue_item.session.complete(invocation.id, output) + + self._on_after_run_node(invocation, queue_item, output) + + except KeyboardInterrupt: + # TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here? + pass + except CanceledException: + # A CanceledException is raised during the denoising step callback if the cancel event is set. We don't need + # to do any handling here, and no error should be set - just pass and the cancellation will be handled + # correctly in the next iteration of the session runner loop. + # + # See the comment in the processor's `_on_queue_item_status_changed()` method for more details on how we + # handle cancellation. 
+ pass + except Exception as e: + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._on_node_error( + invocation=invocation, + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) + + def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: + """Called before a session is run. + + - Start the profiler if profiling is enabled. + - Run any callbacks registered for this event. + """ + + self._services.logger.debug( + f"On before run session: queue item {queue_item.item_id}, session {queue_item.session_id}" + ) + + # If profiling is enabled, start the profiler + if self._profiler is not None: + self._profiler.start(profile_id=queue_item.session_id) + + for callback in self._on_before_run_session_callbacks: + callback(queue_item=queue_item) + + def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: + """Called after a session is run. + + - Stop the profiler if profiling is enabled. + - Update the queue item's session object in the database. + - If not already canceled or failed, complete the queue item. + - Log and reset performance statistics. + - Run any callbacks registered for this event. + """ + + self._services.logger.debug( + f"On after run session: queue item {queue_item.item_id}, session {queue_item.session_id}" + ) + + # If we are profiling, stop the profiler and dump the profile & stats + if self._profiler is not None: + profile_path = self._profiler.stop() + stats_path = profile_path.with_suffix(".json") + self._services.performance_statistics.dump_stats( + graph_execution_state_id=queue_item.session.id, output_path=stats_path + ) + + try: + # Update the queue item with the completed session. If the queue item has been removed from the queue, + # we'll get a SessionQueueItemNotFoundError and we can ignore it. This can happen if the queue is cleared + # while the session is running. + queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) + + # The queue item may have been canceled or failed while the session was running. We should only complete it + # if it is not already canceled or failed. + if queue_item.status not in ["canceled", "failed"]: + queue_item = self._services.session_queue.complete_queue_item(queue_item.item_id) + + # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor + # we don't care about that - suppress the error. + with suppress(GESStatsNotFoundError): + self._services.performance_statistics.log_stats(queue_item.session.id) + self._services.performance_statistics.reset_stats() + + for callback in self._on_after_run_session_callbacks: + callback(queue_item=queue_item) + except SessionQueueItemNotFoundError: + pass + + def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): + """Called before a node is run. + + - Emits an invocation started event. + - Run any callbacks registered for this event. 
+ """ + + self._services.logger.debug( + f"On before run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})" + ) + + # Send starting event + self._services.events.emit_invocation_started(queue_item=queue_item, invocation=invocation) + + for callback in self._on_before_run_node_callbacks: + callback(invocation=invocation, queue_item=queue_item) + + def _on_after_run_node( + self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput + ): + """Called after a node is run. + + - Emits an invocation complete event. + - Run any callbacks registered for this event. + """ + + self._services.logger.debug( + f"On after run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})" + ) + + # Send complete event on successful runs + self._services.events.emit_invocation_complete(invocation=invocation, queue_item=queue_item, output=output) + + for callback in self._on_after_run_node_callbacks: + callback(invocation=invocation, queue_item=queue_item, output=output) + + def _on_node_error( + self, + invocation: BaseInvocation, + queue_item: SessionQueueItem, + error_type: str, + error_message: str, + error_traceback: str, + ): + """Called when a node errors. Node errors may occur when running or preparing the node. + + - Set the node error on the session object. + - Log the error. + - Fail the queue item. + - Emits an invocation error event. + - Run any callbacks registered for this event. + """ + + self._services.logger.debug( + f"On node error: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})" + ) + + # Node errors do not get the full traceback. Only the queue item gets the full traceback.
+ node_error = f"{error_type}: {error_message}" + queue_item.session.set_node_error(invocation.id, node_error) + self._services.logger.error( + f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}" + ) + self._services.logger.error(error_traceback) + + # Fail the queue item + queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) + queue_item = self._services.session_queue.fail_queue_item( + queue_item.item_id, error_type, error_message, error_traceback + ) + + # Send error event + self._services.events.emit_invocation_error( + queue_item=queue_item, + invocation=invocation, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) + + for callback in self._on_node_error_callbacks: + callback( + invocation=invocation, + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) + + class DefaultSessionProcessor(SessionProcessorBase): - def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None: + def __init__( + self, + session_runner: Optional[SessionRunnerBase] = None, + on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None, + thread_limit: int = 1, + polling_interval: int = 1, + ) -> None: + super().__init__() + + self.session_runner = session_runner if session_runner else DefaultSessionRunner() + self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or [] + self._thread_limit = thread_limit + self._polling_interval = polling_interval + + def start(self, invoker: Invoker) -> None: self._invoker: Invoker = invoker self._queue_item: Optional[SessionQueueItem] = None self._invocation: Optional[BaseInvocation] = None @@ -31,11 +332,11 @@ class DefaultSessionProcessor(SessionProcessorBase): self._poll_now_event = ThreadEvent() self._cancel_event = ThreadEvent() - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event) + register_events(QueueClearedEvent, self._on_queue_cleared) + register_events(BatchEnqueuedEvent, self._on_batch_enqueued) + register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed) - self._thread_limit = thread_limit - self._thread_semaphore = BoundedSemaphore(thread_limit) - self._polling_interval = polling_interval + self._thread_semaphore = BoundedSemaphore(self._thread_limit) # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, # the profiler will create a new profile for each session. 
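Rather than one string-matched queue-event handler, each event type now gets its own registration and a typed payload. A small sketch of the same pattern with a hypothetical standalone handler:

```py
from invokeai.app.services.events.events_common import FastAPIEvent, QueueClearedEvent, register_events


async def on_queue_cleared(event: FastAPIEvent[QueueClearedEvent]) -> None:
    # event[1] is the typed event payload - no matching on event-name strings needed.
    print(f"Queue {event[1].queue_id} was cleared")


register_events(QueueClearedEvent, on_queue_cleared)
```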
@@ -49,6 +350,7 @@ class DefaultSessionProcessor(SessionProcessorBase): else None ) + self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler) self._thread = Thread( name="session_processor", target=self._process, @@ -67,30 +369,25 @@ class DefaultSessionProcessor(SessionProcessorBase): def _poll_now(self) -> None: self._poll_now_event.set() - async def _on_queue_event(self, event: FastAPIEvent) -> None: - event_name = event[1]["event"] + async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None: + if self._queue_item and self._queue_item.queue_id == event[1].queue_id: + self._cancel_event.set() + self._poll_now() - if ( - event_name == "session_canceled" - and self._queue_item - and self._queue_item.item_id == event[1]["data"]["queue_item_id"] - ): - self._cancel_event.set() - self._poll_now() - elif ( - event_name == "queue_cleared" - and self._queue_item - and self._queue_item.queue_id == event[1]["data"]["queue_id"] - ): - self._cancel_event.set() - self._poll_now() - elif event_name == "batch_enqueued": - self._poll_now() - elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [ - "completed", - "failed", - "canceled", - ]: + async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None: + self._poll_now() + + async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None: + if self._queue_item and event[1].status in ["completed", "failed", "canceled"]: + # When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is + # emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel + # event, which the session runner checks between invocations. If set, the session runner loop is broken. + # + # Long-running nodes that cannot be interrupted easily present a challenge. `denoise_latents` is one such + # node, but it gets a step callback, called on each step of denoising. This callback checks if the queue item + # is canceled, and if it is, raises a `CanceledException` to stop execution immediately. + if event[1].status == "canceled": + self._cancel_event.set() self._poll_now() def resume(self) -> SessionProcessorStatus: @@ -116,8 +413,8 @@ class DefaultSessionProcessor(SessionProcessorBase): resume_event: ThreadEvent, cancel_event: ThreadEvent, ): - # Outermost processor try block; any unhandled exception is a fatal processor error try: + # Any unhandled exception in this block is a fatal processor error and will stop the processor. self._thread_semaphore.acquire() stop_event.clear() resume_event.set() @@ -125,8 +422,8 @@ class DefaultSessionProcessor(SessionProcessorBase): while not stop_event.is_set(): poll_now_event.clear() - # Middle processor try block; any unhandled exception is a non-fatal processor error try: + # Any unhandled exception in this block is a nonfatal processor error and will be handled. 
# If we are paused, wait for resume event resume_event.wait() @@ -142,159 +439,69 @@ class DefaultSessionProcessor(SessionProcessorBase): self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") cancel_event.clear() - # If profiling is enabled, start the profiler - if self._profiler is not None: - self._profiler.start(profile_id=self._queue_item.session_id) + # Run the graph + self.session_runner.run(queue_item=self._queue_item) - # Prepare invocations and take the first - self._invocation = self._queue_item.session.next() - - # Loop over invocations until the session is complete or canceled - while self._invocation is not None and not cancel_event.is_set(): - # get the source node id to provide to clients (the prepared node id is not as useful) - source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id] - - # Send starting event - self._invoker.services.events.emit_invocation_started( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session_id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - ) - - # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph - try: - with self._invoker.services.performance_statistics.collect_stats( - self._invocation, self._queue_item.session.id - ): - # Build invocation context (the node-facing API) - data = InvocationContextData( - invocation=self._invocation, - source_invocation_id=source_invocation_id, - queue_item=self._queue_item, - ) - context = build_invocation_context( - data=data, - services=self._invoker.services, - cancel_event=self._cancel_event, - ) - - # Invoke the node - outputs = self._invocation.invoke_internal( - context=context, services=self._invoker.services - ) - - # Save outputs and history - self._queue_item.session.complete(self._invocation.id, outputs) - - # Send complete event - self._invoker.services.events.emit_invocation_complete( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - result=outputs.model_dump(), - ) - - except KeyboardInterrupt: - # TODO(MM2): Create an event for this - pass - - except CanceledException: - # When the user cancels the graph, we first set the cancel event. The event is checked - # between invocations, in this loop. Some invocations are long-running, and we need to - # be able to cancel them mid-execution. - # - # For example, denoising is a long-running invocation with many steps. A step callback - # is executed after each step. This step callback checks if the canceled event is set, - # then raises a CanceledException to stop execution immediately. - # - # When we get a CanceledException, we don't need to do anything - just pass and let the - # loop go to its next iteration, and the cancel event will be handled correctly. 
- pass - - except Exception as e: - error = traceback.format_exc() - - # Save error - self._queue_item.session.set_node_error(self._invocation.id, error) - self._invoker.services.logger.error( - f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}" - ) - self._invoker.services.logger.error(error) - - # Send error event - self._invoker.services.events.emit_invocation_error( - queue_batch_id=self._queue_item.session_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - error_type=e.__class__.__name__, - error=error, - user_id=None, - project_id=None, - ) - pass - - # The session is complete if the all invocations are complete or there was an error - if self._queue_item.session.is_complete() or cancel_event.is_set(): - # Send complete event - self._invoker.services.events.emit_graph_execution_complete( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - ) - # If we are profiling, stop the profiler and dump the profile & stats - if self._profiler: - profile_path = self._profiler.stop() - stats_path = profile_path.with_suffix(".json") - self._invoker.services.performance_statistics.dump_stats( - graph_execution_state_id=self._queue_item.session.id, output_path=stats_path - ) - # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor - # we don't care about that - suppress the error. - with suppress(GESStatsNotFoundError): - self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id) - self._invoker.services.performance_statistics.reset_stats() - - # Set the invocation to None to prepare for the next session - self._invocation = None - else: - # Prepare the next invocation - self._invocation = self._queue_item.session.next() - else: - # The queue was empty, wait for next polling interval or event to try again - self._invoker.services.logger.debug("Waiting for next polling interval or event") - poll_now_event.wait(self._polling_interval) - continue - except Exception: - # Non-fatal error in processor - self._invoker.services.logger.error( - f"Non-fatal error in session processor:\n{traceback.format_exc()}" + except Exception as e: + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._on_non_fatal_processor_error( + queue_item=self._queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, ) - # Cancel the queue item - if self._queue_item is not None: - self._invoker.services.session_queue.cancel_queue_item( - self._queue_item.item_id, error=traceback.format_exc() - ) - # Reset the invocation to None to prepare for the next session - self._invocation = None - # Immediately poll for next queue item + # Wait for next polling interval or event to try again poll_now_event.wait(self._polling_interval) continue - except Exception: + except Exception as e: # Fatal error in processor, log and pass - we're done here - self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}") + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._invoker.services.logger.error(f"Fatal 
Error in session processor {error_type}: {error_message}") + self._invoker.services.logger.error(error_traceback) pass finally: stop_event.clear() poll_now_event.clear() self._queue_item = None self._thread_semaphore.release() + + def _on_non_fatal_processor_error( + self, + queue_item: Optional[SessionQueueItem], + error_type: str, + error_message: str, + error_traceback: str, + ) -> None: + """Called when a non-fatal error occurs in the processor. + + - Log the error. + - If a queue item is provided, update the queue item with the completed session & fail it. + - Run any callbacks registered for this event. + """ + + self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}") + self._invoker.services.logger.error(error_traceback) + + if queue_item is not None: + # Update the queue item with the completed session & fail it + queue_item = self._invoker.services.session_queue.set_queue_item_session( + queue_item.item_id, queue_item.session + ) + queue_item = self._invoker.services.session_queue.fail_queue_item( + item_id=queue_item.item_id, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) + + for callback in self._on_non_fatal_processor_error_callbacks: + callback( + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index e0b6e4f528..341e034487 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -16,6 +16,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( SessionQueueItemDTO, SessionQueueStatus, ) +from invokeai.app.services.shared.graph import GraphExecutionState from invokeai.app.services.shared.pagination import CursorPaginatedResults @@ -73,10 +74,22 @@ class SessionQueueBase(ABC): pass @abstractmethod - def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem: + def complete_queue_item(self, item_id: int) -> SessionQueueItem: + """Completes a session queue item""" + pass + + @abstractmethod + def cancel_queue_item(self, item_id: int) -> SessionQueueItem: """Cancels a session queue item""" pass + @abstractmethod + def fail_queue_item( + self, item_id: int, error_type: str, error_message: str, error_traceback: str + ) -> SessionQueueItem: + """Fails a session queue item""" + pass + @abstractmethod def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: """Cancels all queue items with matching batch IDs""" @@ -103,3 +116,8 @@ class SessionQueueBase(ABC): def get_queue_item(self, item_id: int) -> SessionQueueItem: """Gets a session queue item by ID""" pass + + @abstractmethod + def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem: + """Sets the session for a session queue item. 
Use this to update the session state.""" + pass diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 94db6999c2..7f4601eba7 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -3,7 +3,16 @@ import json from itertools import chain, product from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast -from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator +from pydantic import ( + AliasChoices, + BaseModel, + ConfigDict, + Field, + StrictStr, + TypeAdapter, + field_validator, + model_validator, +) from pydantic_core import to_jsonable_python from invokeai.app.invocations.baseinvocation import BaseInvocation @@ -189,7 +198,13 @@ class SessionQueueItemWithoutGraph(BaseModel): session_id: str = Field( description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed." ) - error: Optional[str] = Field(default=None, description="The error message if this queue item errored") + error_type: Optional[str] = Field(default=None, description="The error type if this queue item errored") + error_message: Optional[str] = Field(default=None, description="The error message if this queue item errored") + error_traceback: Optional[str] = Field( + default=None, + description="The error traceback if this queue item errored", + validation_alias=AliasChoices("error_traceback", "error"), + ) created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created") updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated") started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started") diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index ffcd7c40ca..467853aae4 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -2,10 +2,6 @@ import sqlite3 import threading from typing import Optional, Union, cast -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event as FastAPIEvent - -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase from invokeai.app.services.session_queue.session_queue_common import ( @@ -27,6 +23,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( calc_session_count, prepare_values_to_insert, ) +from invokeai.app.services.shared.graph import GraphExecutionState from invokeai.app.services.shared.pagination import CursorPaginatedResults from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase @@ -41,7 +38,7 @@ class SqliteSessionQueue(SessionQueueBase): self.__invoker = invoker self._set_in_progress_to_canceled() prune_result = self.prune(DEFAULT_QUEUE_ID) - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event) + if prune_result.deleted > 0: self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") @@ -51,52 +48,6 @@ class SqliteSessionQueue(SessionQueueBase): self.__conn = db.conn 
self.__cursor = self.__conn.cursor() - def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool: - return event[1]["event"] in match_in - - async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent: - event_name = event[1]["event"] - - # This was a match statement, but match is not supported on python 3.9 - if event_name == "graph_execution_state_complete": - await self._handle_complete_event(event) - elif event_name == "invocation_error": - await self._handle_error_event(event) - elif event_name == "session_canceled": - await self._handle_cancel_event(event) - return event - - async def _handle_complete_event(self, event: FastAPIEvent) -> None: - try: - item_id = event[1]["data"]["queue_item_id"] - # When a queue item has an error, we get an error event, then a completed event. - # Mark the queue item completed only if it isn't already marked completed, e.g. - # by a previously-handled error event. - queue_item = self.get_queue_item(item_id) - if queue_item.status not in ["completed", "failed", "canceled"]: - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed") - except SessionQueueItemNotFoundError: - return - - async def _handle_error_event(self, event: FastAPIEvent) -> None: - try: - item_id = event[1]["data"]["queue_item_id"] - error = event[1]["data"]["error"] - queue_item = self.get_queue_item(item_id) - # always set to failed if have an error, even if previously the item was marked completed or canceled - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error) - except SessionQueueItemNotFoundError: - return - - async def _handle_cancel_event(self, event: FastAPIEvent) -> None: - try: - item_id = event[1]["data"]["queue_item_id"] - queue_item = self.get_queue_item(item_id) - if queue_item.status not in ["completed", "failed", "canceled"]: - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled") - except SessionQueueItemNotFoundError: - return - def _set_in_progress_to_canceled(self) -> None: """ Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue. @@ -271,17 +222,22 @@ class SqliteSessionQueue(SessionQueueBase): return SessionQueueItem.queue_item_from_dict(dict(result)) def _set_queue_item_status( - self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None + self, + item_id: int, + status: QUEUE_ITEM_STATUS, + error_type: Optional[str] = None, + error_message: Optional[str] = None, + error_traceback: Optional[str] = None, ) -> SessionQueueItem: try: self.__lock.acquire() self.__cursor.execute( """--sql UPDATE session_queue - SET status = ?, error = ? + SET status = ?, error_type = ?, error_message = ?, error_traceback = ? WHERE item_id = ? 
""", - (status, error, item_id), + (status, error_type, error_message, error_traceback, item_id), ) self.__conn.commit() except Exception: @@ -292,11 +248,7 @@ class SqliteSessionQueue(SessionQueueBase): queue_item = self.get_queue_item(item_id) batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_item.queue_id) - self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=queue_item, - batch_status=batch_status, - queue_status=queue_status, - ) + self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status) return queue_item def is_empty(self, queue_id: str) -> IsEmptyResult: @@ -338,26 +290,6 @@ class SqliteSessionQueue(SessionQueueBase): self.__lock.release() return IsFullResult(is_full=is_full) - def delete_queue_item(self, item_id: int) -> SessionQueueItem: - queue_item = self.get_queue_item(item_id=item_id) - try: - self.__lock.acquire() - self.__cursor.execute( - """--sql - DELETE FROM session_queue - WHERE - item_id = ? - """, - (item_id,), - ) - self.__conn.commit() - except Exception: - self.__conn.rollback() - raise - finally: - self.__lock.release() - return queue_item - def clear(self, queue_id: str) -> ClearResult: try: self.__lock.acquire() @@ -424,17 +356,28 @@ class SqliteSessionQueue(SessionQueueBase): self.__lock.release() return PruneResult(deleted=count) - def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem: - queue_item = self.get_queue_item(item_id) - if queue_item.status not in ["canceled", "failed", "completed"]: - status = "failed" if error is not None else "canceled" - queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here - self.__invoker.services.events.emit_session_canceled( - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, - queue_batch_id=queue_item.batch_id, - graph_execution_state_id=queue_item.session_id, - ) + def cancel_queue_item(self, item_id: int) -> SessionQueueItem: + queue_item = self._set_queue_item_status(item_id=item_id, status="canceled") + return queue_item + + def complete_queue_item(self, item_id: int) -> SessionQueueItem: + queue_item = self._set_queue_item_status(item_id=item_id, status="completed") + return queue_item + + def fail_queue_item( + self, + item_id: int, + error_type: str, + error_message: str, + error_traceback: str, + ) -> SessionQueueItem: + queue_item = self._set_queue_item_status( + item_id=item_id, + status="failed", + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) return queue_item def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: @@ -470,18 +413,10 @@ class SqliteSessionQueue(SessionQueueBase): ) self.__conn.commit() if current_queue_item is not None and current_queue_item.batch_id in batch_ids: - self.__invoker.services.events.emit_session_canceled( - queue_item_id=current_queue_item.item_id, - queue_id=current_queue_item.queue_id, - queue_batch_id=current_queue_item.batch_id, - graph_execution_state_id=current_queue_item.session_id, - ) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_id) self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=current_queue_item, - batch_status=batch_status, 
- queue_status=queue_status, + current_queue_item, batch_status, queue_status ) except Exception: self.__conn.rollback() @@ -521,18 +456,10 @@ class SqliteSessionQueue(SessionQueueBase): ) self.__conn.commit() if current_queue_item is not None and current_queue_item.queue_id == queue_id: - self.__invoker.services.events.emit_session_canceled( - queue_item_id=current_queue_item.item_id, - queue_id=current_queue_item.queue_id, - queue_batch_id=current_queue_item.batch_id, - graph_execution_state_id=current_queue_item.session_id, - ) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_id) self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=current_queue_item, - batch_status=batch_status, - queue_status=queue_status, + current_queue_item, batch_status, queue_status ) except Exception: self.__conn.rollback() @@ -562,6 +489,29 @@ class SqliteSessionQueue(SessionQueueBase): raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}") return SessionQueueItem.queue_item_from_dict(dict(result)) + def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem: + try: + # Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors + # when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced + # during execution. + session_json = session.model_dump_json(warnings=False, exclude_none=True) + self.__lock.acquire() + self.__cursor.execute( + """--sql + UPDATE session_queue + SET session = ? + WHERE item_id = ? + """, + (session_json, item_id), + ) + self.__conn.commit() + except Exception: + self.__conn.rollback() + raise + finally: + self.__lock.release() + return self.get_queue_item(item_id) + def list_queue_items( self, queue_id: str, @@ -578,7 +528,9 @@ class SqliteSessionQueue(SessionQueueBase): status, priority, field_values, - error, + error_type, + error_message, + error_traceback, created_at, updated_at, completed_at, diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index cc2ea5cedb..8508d2484c 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -8,6 +8,7 @@ import networkx as nx from pydantic import ( BaseModel, GetJsonSchemaHandler, + ValidationError, field_validator, ) from pydantic.fields import Field @@ -190,6 +191,39 @@ class UnknownGraphValidationError(ValueError): pass +class NodeInputError(ValueError): + """Raised when a node fails preparation. This occurs when a node's inputs are being set from its incomers, but an + input fails validation. + + Attributes: + node: The node that failed preparation. Note: only successfully set fields will be accurate. Review the error to + determine which field caused the failure. + """ + + def __init__(self, node: BaseInvocation, e: ValidationError): + self.original_error = e + self.node = node + # When preparing a node, we set each input one-at-a-time. We may thus safely assume that the first error + # represents the first input that failed. + self.failed_input = loc_to_dot_sep(e.errors()[0]["loc"]) + super().__init__(f"Node {node.id} has invalid incoming input for {self.failed_input}") + + +def loc_to_dot_sep(loc: tuple[Union[str, int], ...]) -> str: + """Helper to pretty-print pydantic error locations as dot-separated strings. 
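A quick worked example of what this helper produces from a pydantic error location tuple, assuming the implementation shown in this hunk:

```py
from invokeai.app.services.shared.graph import loc_to_dot_sep

# String elements become dot-separated attributes; integers become list indices.
assert loc_to_dot_sep(("batch", "runs", 0, "graph")) == "batch.runs[0].graph"
assert loc_to_dot_sep(("image", "width")) == "image.width"
```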
+ Taken from https://docs.pydantic.dev/latest/errors/errors/#customize-error-messages + """ + path = "" + for i, x in enumerate(loc): + if isinstance(x, str): + if i > 0: + path += "." + path += x + else: + path += f"[{x}]" + return path + + @invocation_output("iterate_output") class IterateInvocationOutput(BaseInvocationOutput): """Used to connect iteration outputs. Will be expanded to a specific output.""" @@ -821,7 +855,10 @@ class GraphExecutionState(BaseModel): # Get values from edges if next_node is not None: - self._prepare_inputs(next_node) + try: + self._prepare_inputs(next_node) + except ValidationError as e: + raise NodeInputError(next_node, e) # If next is still none, there's no next node, return None return next_node diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 1acc5c725d..c932e66989 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,4 +1,3 @@ -import threading from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Callable, Dict, Optional, Union @@ -359,12 +358,11 @@ class ModelsInterface(InvocationContextInterface): if isinstance(identifier, str): model = self._services.model_manager.store.get_model(identifier) - result: LoadedModel = self._services.model_manager.load.load_model(model, submodel_type, self._data) + return self._services.model_manager.load.load_model(model, submodel_type) else: _submodel_type = submodel_type or identifier.submodel_type model = self._services.model_manager.store.get_model(identifier.key) - result = self._services.model_manager.load.load_model(model, _submodel_type, self._data) - return result + return self._services.model_manager.load.load_model(model, _submodel_type) def load_by_attrs( self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None @@ -388,8 +386,7 @@ class ModelsInterface(InvocationContextInterface): if len(configs) > 1: raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}") - result: LoadedModel = self._services.model_manager.load.load_model(configs[0], submodel_type, self._data) - return result + return self._services.model_manager.load.load_model(configs[0], submodel_type) def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig: """Get a model's config. @@ -516,10 +513,10 @@ class ConfigInterface(InvocationContextInterface): class UtilInterface(InvocationContextInterface): def __init__( - self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event + self, services: InvocationServices, data: InvocationContextData, is_canceled: Callable[[], bool] ) -> None: super().__init__(services, data) - self._cancel_event = cancel_event + self._is_canceled = is_canceled def is_canceled(self) -> bool: """Checks if the current session has been canceled. @@ -527,7 +524,7 @@ class UtilInterface(InvocationContextInterface): Returns: True if the current session has been canceled, False if not. 
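Because cancellation is now exposed to nodes as a plain callable rather than a threading event, a long-running invocation can poll it without touching threading primitives. A small sketch of how a node body might use it (the step loop is hypothetical):

```py
from invokeai.app.services.shared.invocation_context import InvocationContext


def run_steps(context: InvocationContext, total_steps: int) -> int:
    """Do up to `total_steps` units of work, stopping early if the session is canceled."""
    completed = 0
    for _ in range(total_steps):
        if context.util.is_canceled():
            # The session runner set the cancel event; stop as soon as possible.
            break
        completed += 1  # ...one unit of work would go here...
    return completed
```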
""" - return self._cancel_event.is_set() + return self._is_canceled() def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: """ @@ -602,7 +599,7 @@ class InvocationContext: def build_invocation_context( services: InvocationServices, data: InvocationContextData, - cancel_event: threading.Event, + is_canceled: Callable[[], bool], ) -> InvocationContext: """Builds the invocation context for a specific invocation execution. @@ -619,7 +616,7 @@ def build_invocation_context( tensors = TensorsInterface(services=services, data=data) models = ModelsInterface(services=services, data=data) config = ConfigInterface(services=services, data=data) - util = UtilInterface(services=services, data=data, cancel_event=cancel_event) + util = UtilInterface(services=services, data=data, is_canceled=is_canceled) conditioning = ConditioningInterface(services=services, data=data) boards = BoardsInterface(services=services, data=data) diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 61f35a3b4e..cadf09f457 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -42,7 +42,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_7()) migrator.register_migration(build_migration_8(app_config=config)) migrator.register_migration(build_migration_9()) - migrator.register_migration(build_migration_10(app_config=config, logger=logger)) + migrator.register_migration(build_migration_10()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py index 4c4f742d4c..ce2cd2e965 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py @@ -1,75 +1,35 @@ -import shutil import sqlite3 -from logging import Logger -from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration -LEGACY_CORE_MODELS = [ - # OpenPose - "any/annotators/dwpose/yolox_l.onnx", - "any/annotators/dwpose/dw-ll_ucoco_384.onnx", - # DepthAnything - "any/annotators/depth_anything/depth_anything_vitl14.pth", - "any/annotators/depth_anything/depth_anything_vitb14.pth", - "any/annotators/depth_anything/depth_anything_vits14.pth", - # Lama inpaint - "core/misc/lama/lama.pt", - # RealESRGAN upscale - "core/upscaling/realesrgan/RealESRGAN_x4plus.pth", - "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", - "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", - "core/upscaling/realesrgan/RealESRGAN_x2plus.pth", -] - class Migration10Callback: - def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None: - self._app_config = app_config - self._logger = logger - def __call__(self, cursor: sqlite3.Cursor) -> None: - self._remove_convert_cache() - self._remove_downloaded_models() - self._remove_unused_core_models() + self._update_error_cols(cursor) - def _remove_convert_cache(self) -> None: - """Rename models/.cache to models/.convert_cache.""" - self._logger.info("Removing .cache directory. 
Converted models will now be cached in .convert_cache.") - legacy_convert_path = self._app_config.root_path / "models" / ".cache" - shutil.rmtree(legacy_convert_path, ignore_errors=True) + def _update_error_cols(self, cursor: sqlite3.Cursor) -> None: + """ + - Adds `error_type` and `error_message` columns to the session queue table. + - Renames the `error` column to `error_traceback`. + """ - def _remove_downloaded_models(self) -> None: - """Remove models from their old locations; they will re-download when needed.""" - self._logger.info( - "Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache." - ) - for model_path in LEGACY_CORE_MODELS: - legacy_dest_path = self._app_config.models_path / model_path - legacy_dest_path.unlink(missing_ok=True) - - def _remove_unused_core_models(self) -> None: - """Remove unused core models and their directories.""" - self._logger.info("Removing defunct core models.") - for dir in ["face_restoration", "misc", "upscaling"]: - path_to_remove = self._app_config.models_path / "core" / dir - shutil.rmtree(path_to_remove, ignore_errors=True) - shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True) + cursor.execute("ALTER TABLE session_queue ADD COLUMN error_type TEXT;") + cursor.execute("ALTER TABLE session_queue ADD COLUMN error_message TEXT;") + cursor.execute("ALTER TABLE session_queue RENAME COLUMN error TO error_traceback;") -def build_migration_10(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: +def build_migration_10() -> Migration: """ Build the migration from database version 9 to 10. This migration does the following: - - Moves "core" models previously downloaded with download_with_progress_bar() into new - "models/.download_cache" directory. - - Renames "models/.cache" to "models/.convert_cache". + - Adds `error_type` and `error_message` columns to the session queue table. + - Renames the `error` column to `error_traceback`. 
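The rename stays backward compatible at the validation layer: the `validation_alias=AliasChoices("error_traceback", "error")` added to the queue item model earlier in this diff lets data stored under the old `error` key populate the new field. A minimal sketch with a simplified stand-in model (not the real `SessionQueueItem`):

```py
from typing import Optional

from pydantic import AliasChoices, BaseModel, Field


class QueueItemErrors(BaseModel):
    error_type: Optional[str] = None
    error_message: Optional[str] = None
    error_traceback: Optional[str] = Field(
        default=None, validation_alias=AliasChoices("error_traceback", "error")
    )


# A record serialized before the rename still validates into error_traceback.
old_row = {"error_type": "ValueError", "error": "Traceback (most recent call last): ..."}
assert QueueItemErrors.model_validate(old_row).error_traceback is not None
```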
""" migration_10 = Migration( from_version=9, to_version=10, - callback=Migration10Callback(app_config=app_config, logger=logger), + callback=Migration10Callback(), ) return migration_10 diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py new file mode 100644 index 0000000000..3b616e2b82 --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py @@ -0,0 +1,77 @@ +import shutil +import sqlite3 +from logging import Logger + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + +LEGACY_CORE_MODELS = [ + # OpenPose + "any/annotators/dwpose/yolox_l.onnx", + "any/annotators/dwpose/dw-ll_ucoco_384.onnx", + # DepthAnything + "any/annotators/depth_anything/depth_anything_vitl14.pth", + "any/annotators/depth_anything/depth_anything_vitb14.pth", + "any/annotators/depth_anything/depth_anything_vits14.pth", + # Lama inpaint + "core/misc/lama/lama.pt", + # RealESRGAN upscale + "core/upscaling/realesrgan/RealESRGAN_x4plus.pth", + "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", + "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", + "core/upscaling/realesrgan/RealESRGAN_x2plus.pth", +] + + +class Migration11Callback: + def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None: + self._app_config = app_config + self._logger = logger + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._remove_convert_cache() + self._remove_downloaded_models() + self._remove_unused_core_models() + + def _remove_convert_cache(self) -> None: + """Rename models/.cache to models/.convert_cache.""" + self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.") + legacy_convert_path = self._app_config.root_path / "models" / ".cache" + shutil.rmtree(legacy_convert_path, ignore_errors=True) + + def _remove_downloaded_models(self) -> None: + """Remove models from their old locations; they will re-download when needed.""" + self._logger.info( + "Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache." + ) + for model_path in LEGACY_CORE_MODELS: + legacy_dest_path = self._app_config.models_path / model_path + legacy_dest_path.unlink(missing_ok=True) + + def _remove_unused_core_models(self) -> None: + """Remove unused core models and their directories.""" + self._logger.info("Removing defunct core models.") + for dir in ["face_restoration", "misc", "upscaling"]: + path_to_remove = self._app_config.models_path / "core" / dir + shutil.rmtree(path_to_remove, ignore_errors=True) + shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True) + + +def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: + """ + Build the migration from database version 9 to 10. + + This migration does the following: + - Moves "core" models previously downloaded with download_with_progress_bar() into new + "models/.download_cache" directory. + - Renames "models/.cache" to "models/.convert_cache". + - Adds `error_type` and `error_message` columns to the session queue table. + - Renames the `error` column to `error_traceback`. 
+    """
+    migration_11 = Migration(
+        from_version=10,
+        to_version=11,
+        callback=Migration11Callback(app_config=app_config, logger=logger),
+    )
+
+    return migration_11
diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py
index 8cb59f5b3a..8992e59ace 100644
--- a/invokeai/app/util/step_callback.py
+++ b/invokeai/app/util/step_callback.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable, Optional

 import torch
 from PIL import Image
@@ -13,8 +13,36 @@ if TYPE_CHECKING:
     from invokeai.app.services.events.events_base import EventServiceBase
     from invokeai.app.services.shared.invocation_context import InvocationContextData

+# fast latents preview matrix for sdxl
+# generated by @StAlKeR7779
+SDXL_LATENT_RGB_FACTORS = [
+    # R        G        B
+    [0.3816, 0.4930, 0.5320],
+    [-0.3753, 0.1631, 0.1739],
+    [0.1770, 0.3588, -0.2048],
+    [-0.4350, -0.2644, -0.4289],
+]
+SDXL_SMOOTH_MATRIX = [
+    [0.0358, 0.0964, 0.0358],
+    [0.0964, 0.4711, 0.0964],
+    [0.0358, 0.0964, 0.0358],
+]

-def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
+# originally adapted from code by @erucipe and @keturn here:
+# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
+# these updated numbers for v1.5 are from @torridgristle
+SD1_5_LATENT_RGB_FACTORS = [
+    # R        G        B
+    [0.3444, 0.1385, 0.0670],  # L1
+    [0.1247, 0.4027, 0.1494],  # L2
+    [-0.3192, 0.2513, 0.2103],  # L3
+    [-0.1307, -0.1874, -0.7445],  # L4
+]
+
+
+def sample_to_lowres_estimated_image(
+    samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
+):
     latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors

     if smooth_matrix is not None:
@@ -47,64 +75,12 @@ def stable_diffusion_step_callback(
     else:
         sample = intermediate_state.latents

-        # TODO: This does not seem to be needed any more?
- # # txt2img provides a Tensor in the step_callback - # # img2img provides a PipelineIntermediateState - # if isinstance(sample, PipelineIntermediateState): - # # this was an img2img - # print('img2img') - # latents = sample.latents - # step = sample.step - # else: - # print('txt2img') - # latents = sample - # step = intermediate_state.step - - # TODO: only output a preview image when requested - if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]: - # fast latents preview matrix for sdxl - # generated by @StAlKeR7779 - sdxl_latent_rgb_factors = torch.tensor( - [ - # R G B - [0.3816, 0.4930, 0.5320], - [-0.3753, 0.1631, 0.1739], - [0.1770, 0.3588, -0.2048], - [-0.4350, -0.2644, -0.4289], - ], - dtype=sample.dtype, - device=sample.device, - ) - - sdxl_smooth_matrix = torch.tensor( - [ - [0.0358, 0.0964, 0.0358], - [0.0964, 0.4711, 0.0964], - [0.0358, 0.0964, 0.0358], - ], - dtype=sample.dtype, - device=sample.device, - ) - + sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device) + sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device) image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix) else: - # origingally adapted from code by @erucipe and @keturn here: - # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7 - - # these updated numbers for v1.5 are from @torridgristle - v1_5_latent_rgb_factors = torch.tensor( - [ - # R G B - [0.3444, 0.1385, 0.0670], # L1 - [0.1247, 0.4027, 0.1494], # L2 - [-0.3192, 0.2513, 0.2103], # L3 - [-0.1307, -0.1874, -0.7445], # L4 - ], - dtype=sample.dtype, - device=sample.device, - ) - + v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device) image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors) (width, height) = image.size @@ -113,15 +89,9 @@ def stable_diffusion_step_callback( dataURL = image_to_dataURL(image, image_format="JPEG") - events.emit_generator_progress( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - graph_execution_state_id=context_data.queue_item.session_id, - node_id=context_data.invocation.id, - source_node_id=context_data.source_invocation_id, - progress_image=ProgressImage(width=width, height=height, dataURL=dataURL), - step=intermediate_state.step, - order=intermediate_state.order, - total_steps=intermediate_state.total_steps, + events.emit_invocation_denoise_progress( + context_data.queue_item, + context_data.invocation, + intermediate_state, + ProgressImage(dataURL=dataURL, width=width, height=height), ) diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index ec77bbe477..bb7fd6f1d4 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -42,10 +42,26 @@ T = TypeVar("T") @dataclass class CacheRecord(Generic[T]): - """Elements of the cache.""" + """ + Elements of the cache: + + key: Unique key for each model, same as used in the models database. + model: Model in memory. + state_dict: A read-only copy of the model's state dict in RAM. It will be + used as a template for creating a copy in the VRAM. 
+ size: Size of the model + loaded: True if the model's state dict is currently in VRAM + + Before a model is executed, the state_dict template is copied into VRAM, + and then injected into the model. When the model is finished, the VRAM + copy of the state dict is deleted, and the RAM version is reinjected + into the model. + """ key: str model: T + device: torch.device + state_dict: Optional[Dict[str, torch.Tensor]] size: int loaded: bool = False _locks: int = 0 diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index bd7b2ffc7a..10c0210052 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -20,7 +20,6 @@ context. Use like this: import gc import math -import sys import time from contextlib import suppress from logging import Logger @@ -163,7 +162,9 @@ class ModelCache(ModelCacheBase[AnyModel]): return size = calc_model_size_by_data(model) self.make_room(size) - cache_record = CacheRecord(key, model, size) + + state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None + cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size) self._cached_models[key] = cache_record self._cache_stack.append(key) @@ -253,21 +254,40 @@ class ModelCache(ModelCacheBase[AnyModel]): May raise a torch.cuda.OutOfMemoryError """ self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") - model = cache_entry.model + source_device = cache_entry.device - source_device = model.device if hasattr(model, "device") else self.storage_device + # Note: We compare device types only so that 'cuda' == 'cuda:0'. + # This would need to be revised to support multi-GPU. if torch.device(source_device).type == torch.device(target_device).type: return + if not hasattr(cache_entry.model, "to"): + return + + # This roundabout method for moving the model around is done to avoid + # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM. + # When moving to VRAM, we copy (not move) each element of the state dict from + # RAM to a new state dict in VRAM, and then inject it into the model. + # This operation is slightly faster than running `to()` on the whole model. + # + # When the model needs to be removed from VRAM we simply delete the copy + # of the state dict in VRAM, and reinject the state dict that is cached + # in RAM into the model. So this operation is very fast. 
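The comment above is the heart of the change. As a rough standalone sketch (not part of the patch, and assuming a PyTorch version that supports `load_state_dict(..., assign=True)`, i.e. torch >= 2.1), the two directions of the copy-and-reinject pattern look roughly like this:

```py
# Illustrative sketch of the RAM <-> VRAM state-dict pattern described in the comment above.
# Hypothetical helper names; only the cached RAM state dict is ever treated as the source of truth.
from typing import Dict

import torch


def copy_model_to_vram(model: torch.nn.Module, ram_state_dict: Dict[str, torch.Tensor], vram: torch.device) -> None:
    # Copy (not move) each tensor so the cached RAM template stays intact.
    vram_state_dict = {k: v.to(vram, copy=True) for k, v in ram_state_dict.items()}
    model.load_state_dict(vram_state_dict, assign=True)  # inject the VRAM copies into the module
    model.to(vram)


def return_model_to_ram(model: torch.nn.Module, ram_state_dict: Dict[str, torch.Tensor], ram: torch.device) -> None:
    # Reinject the cached RAM tensors; the VRAM copies become unreferenced and are freed.
    model.load_state_dict(ram_state_dict, assign=True)
    model.to(ram)
```

The trade-off: moving to VRAM still pays for one host-to-device copy per tensor, but moving back is nearly free because the RAM template was never discarded.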
start_model_to_time = time.time() snapshot_before = self._capture_memory_snapshot() + try: - if hasattr(model, "to"): - model.to(target_device) - elif isinstance(model, dict): - for _, v in model.items(): - if hasattr(v, "to"): - v.to(target_device) + if cache_entry.state_dict is not None: + assert hasattr(cache_entry.model, "load_state_dict") + if target_device == self.storage_device: + cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True) + else: + new_dict: Dict[str, torch.Tensor] = {} + for k, v in cache_entry.state_dict.items(): + new_dict[k] = v.to(torch.device(target_device), copy=True) + cache_entry.model.load_state_dict(new_dict, assign=True) + cache_entry.model.to(target_device) + cache_entry.device = target_device except Exception as e: # blow away cache entry self._delete_cache_entry(cache_entry) raise e @@ -347,43 +367,12 @@ class ModelCache(ModelCacheBase[AnyModel]): while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): model_key = self._cache_stack[pos] cache_entry = self._cached_models[model_key] - - refs = sys.getrefcount(cache_entry.model) - - # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly - # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way: - # https://docs.python.org/3/library/gc.html#gc.get_referrers - - # manualy clear local variable references of just finished function calls - # for some reason python don't want to collect it even by gc.collect() immidiately - if refs > 2: - while True: - cleared = False - for referrer in gc.get_referrers(cache_entry.model): - if type(referrer).__name__ == "frame": - # RuntimeError: cannot clear an executing frame - with suppress(RuntimeError): - referrer.clear() - cleared = True - # break - - # repeat if referrers changes(due to frame clear), else exit loop - if cleared: - gc.collect() - else: - break - device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None self.logger.debug( - f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}," - f" refs: {refs}" + f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}" ) - # Expected refs: - # 1 from cache_entry - # 1 from getrefcount function - # 1 from onnx runtime object - if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2): + if not cache_entry.locked: self.logger.debug( f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" ) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 7de7a8e01c..f7a91ef756 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2,6 +2,7 @@ "accessibility": { "about": "About", "createIssue": "Create Issue", + "submitSupportTicket": "Submit Support Ticket", "invokeProgressBar": "Invoke progress bar", "menu": "Menu", "mode": "Mode", @@ -146,7 +147,9 @@ "viewing": "Viewing", "viewingDesc": "Review images in a large gallery view", "editing": "Editing", - "editingDesc": "Edit on the Control Layers canvas" + "editingDesc": "Edit on the Control Layers canvas", + "enabled": "Enabled", + "disabled": "Disabled" }, "controlnet": { "controlAdapter_one": "Control Adapter", @@ -775,10 +778,14 @@ "cannotConnectToSelf": "Cannot connect to self", "cannotDuplicateConnection": "Cannot create duplicate connections", 
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types", + "missingNode": "Missing invocation node", + "missingInvocationTemplate": "Missing invocation template", + "missingFieldTemplate": "Missing field template", "nodePack": "Node pack", "collection": "Collection", - "collectionFieldType": "{{name}} Collection", - "collectionOrScalarFieldType": "{{name}} Collection|Scalar", + "singleFieldType": "{{name}} (Single)", + "collectionFieldType": "{{name}} (Collection)", + "collectionOrScalarFieldType": "{{name}} (Single or Collection)", "colorCodeEdges": "Color-Code Edges", "colorCodeEdgesHelp": "Color-code edges according to their connected fields", "connectionWouldCreateCycle": "Connection would create a cycle", @@ -893,7 +900,10 @@ "zoomInNodes": "Zoom In", "zoomOutNodes": "Zoom Out", "betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.", - "prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time." + "prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time.", + "imageAccessError": "Unable to find image {{image_name}}, resetting to default", + "boardAccessError": "Unable to find board {{board_id}}, resetting to default", + "modelAccessError": "Unable to find model {{key}}, resetting to default" }, "parameters": { "aspect": "Aspect", @@ -948,7 +958,7 @@ "controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model", "controlAdapterNoImageSelected": "no Control Adapter image selected", "controlAdapterImageNotProcessed": "Control Adapter image not processed", - "t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of 64", + "t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of {{multiple}}", "ipAdapterNoModelSelected": "no IP adapter selected", "ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model", "ipAdapterNoImageSelected": "no IP Adapter image selected", @@ -1066,8 +1076,9 @@ }, "toast": { "addedToBoard": "Added to board", - "baseModelChangedCleared_one": "Base model changed, cleared or disabled {{count}} incompatible submodel", - "baseModelChangedCleared_other": "Base model changed, cleared or disabled {{count}} incompatible submodels", + "baseModelChanged": "Base Model Changed", + "baseModelChangedCleared_one": "Cleared or disabled {{count}} incompatible submodel", + "baseModelChangedCleared_other": "Cleared or disabled {{count}} incompatible submodels", "canceled": "Processing Canceled", "canvasCopiedClipboard": "Canvas Copied to Clipboard", "canvasDownloaded": "Canvas Downloaded", @@ -1088,10 +1099,17 @@ "metadataLoadFailed": "Failed to load metadata", "modelAddedSimple": "Model Added to Queue", "modelImportCanceled": "Model Import Canceled", + "outOfMemoryError": "Out of Memory Error", + "outOfMemoryErrorDesc": "Your current generation settings exceed system capacity. 
Please adjust your settings and try again.", "parameters": "Parameters", - "parameterNotSet": "{{parameter}} not set", - "parameterSet": "{{parameter}} set", - "parametersNotSet": "Parameters Not Set", + "parameterSet": "Parameter Recalled", + "parameterSetDesc": "Recalled {{parameter}}", + "parameterNotSet": "Parameter Not Recalled", + "parameterNotSetDesc": "Unable to recall {{parameter}}", + "parameterNotSetDescWithMessage": "Unable to recall {{parameter}}: {{message}}", + "parametersSet": "Parameters Recalled", + "parametersNotSet": "Parameters Not Recalled", + "errorCopied": "Error Copied", "problemCopyingCanvas": "Problem Copying Canvas", "problemCopyingCanvasDesc": "Unable to export base layer", "problemCopyingImage": "Unable to Copy Image", @@ -1111,11 +1129,13 @@ "sentToImageToImage": "Sent To Image To Image", "sentToUnifiedCanvas": "Sent to Unified Canvas", "serverError": "Server Error", + "sessionRef": "Session: {{sessionId}}", "setAsCanvasInitialImage": "Set as canvas initial image", "setCanvasInitialImage": "Set canvas initial image", "setControlImage": "Set as control image", "setInitialImage": "Set as initial image", "setNodeField": "Set as node field", + "somethingWentWrong": "Something Went Wrong", "uploadFailed": "Upload failed", "uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image", "uploadInitialImage": "Upload Initial Image", @@ -1555,7 +1575,6 @@ "controlLayers": "Control Layers", "globalMaskOpacity": "Global Mask Opacity", "autoNegative": "Auto Negative", - "toggleVisibility": "Toggle Layer Visibility", "deletePrompt": "Delete Prompt", "resetRegion": "Reset Region", "debugLayers": "Debug Layers", diff --git a/invokeai/frontend/web/public/locales/es.json b/invokeai/frontend/web/public/locales/es.json index dbdda8e209..169bfdb066 100644 --- a/invokeai/frontend/web/public/locales/es.json +++ b/invokeai/frontend/web/public/locales/es.json @@ -382,7 +382,7 @@ "canvasMerged": "Lienzo consolidado", "sentToImageToImage": "Enviar hacia Imagen a Imagen", "sentToUnifiedCanvas": "Enviar hacia Lienzo Consolidado", - "parametersNotSet": "Parámetros no establecidos", + "parametersNotSet": "Parámetros no recuperados", "metadataLoadFailed": "Error al cargar metadatos", "serverError": "Error en el servidor", "canceled": "Procesando la cancelación", @@ -390,7 +390,8 @@ "uploadFailedInvalidUploadDesc": "Debe ser una sola imagen PNG o JPEG", "parameterSet": "Conjunto de parámetros", "parameterNotSet": "Parámetro no configurado", - "problemCopyingImage": "No se puede copiar la imagen" + "problemCopyingImage": "No se puede copiar la imagen", + "errorCopied": "Error al copiar" }, "tooltip": { "feature": { diff --git a/invokeai/frontend/web/public/locales/it.json b/invokeai/frontend/web/public/locales/it.json index f365b43e10..bd82dd9a5b 100644 --- a/invokeai/frontend/web/public/locales/it.json +++ b/invokeai/frontend/web/public/locales/it.json @@ -524,7 +524,20 @@ "missingNodeTemplate": "Modello di nodo mancante", "missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} ingresso mancante", "missingFieldTemplate": "Modello di campo mancante", - "imageNotProcessedForControlAdapter": "L'immagine dell'adattatore di controllo #{{number}} non è stata elaborata" + "imageNotProcessedForControlAdapter": "L'immagine dell'adattatore di controllo #{{number}} non è stata elaborata", + "layer": { + "initialImageNoImageSelected": "Nessuna immagine iniziale selezionata", + "t2iAdapterIncompatibleDimensions": "L'adattatore T2I richiede che la dimensione dell'immagine sia un multiplo di 
{{multiple}}", + "controlAdapterNoModelSelected": "Nessun modello di Adattatore di Controllo selezionato", + "controlAdapterIncompatibleBaseModel": "Il modello base dell'adattatore di controllo non è compatibile", + "controlAdapterNoImageSelected": "Nessuna immagine dell'adattatore di controllo selezionata", + "controlAdapterImageNotProcessed": "Immagine dell'adattatore di controllo non elaborata", + "ipAdapterNoModelSelected": "Nessun adattatore IP selezionato", + "ipAdapterIncompatibleBaseModel": "Il modello base dell'adattatore IP non è compatibile", + "ipAdapterNoImageSelected": "Nessuna immagine dell'adattatore IP selezionata", + "rgNoPromptsOrIPAdapters": "Nessun prompt o adattatore IP", + "rgNoRegion": "Nessuna regione selezionata" + } }, "useCpuNoise": "Usa la CPU per generare rumore", "iterations": "Iterazioni", @@ -824,8 +837,8 @@ "unableToUpdateNodes_other": "Impossibile aggiornare {{count}} nodi", "addLinearView": "Aggiungi alla vista Lineare", "unknownErrorValidatingWorkflow": "Errore sconosciuto durante la convalida del flusso di lavoro", - "collectionFieldType": "{{name}} Raccolta", - "collectionOrScalarFieldType": "{{name}} Raccolta|Scalare", + "collectionFieldType": "{{name}} (Raccolta)", + "collectionOrScalarFieldType": "{{name}} (Singola o Raccolta)", "nodeVersion": "Versione Nodo", "inputFieldTypeParseError": "Impossibile analizzare il tipo di campo di input {{node}}.{{field}} ({{message}})", "unsupportedArrayItemType": "Tipo di elemento dell'array non supportato \"{{type}}\"", @@ -863,7 +876,13 @@ "edit": "Modifica", "graph": "Grafico", "showEdgeLabelsHelp": "Mostra etichette sui collegamenti, che indicano i nodi collegati", - "showEdgeLabels": "Mostra le etichette del collegamento" + "showEdgeLabels": "Mostra le etichette del collegamento", + "cannotMixAndMatchCollectionItemTypes": "Impossibile combinare e abbinare i tipi di elementi della raccolta", + "noGraph": "Nessun grafico", + "missingNode": "Nodo di invocazione mancante", + "missingInvocationTemplate": "Modello di invocazione mancante", + "missingFieldTemplate": "Modello di campo mancante", + "singleFieldType": "{{name}} (Singola)" }, "boards": { "autoAddBoard": "Aggiungi automaticamente bacheca", @@ -1034,7 +1053,16 @@ "graphFailedToQueue": "Impossibile mettere in coda il grafico", "batchFieldValues": "Valori Campi Lotto", "time": "Tempo", - "openQueue": "Apri coda" + "openQueue": "Apri coda", + "iterations_one": "Iterazione", + "iterations_many": "Iterazioni", + "iterations_other": "Iterazioni", + "prompts_one": "Prompt", + "prompts_many": "Prompt", + "prompts_other": "Prompt", + "generations_one": "Generazione", + "generations_many": "Generazioni", + "generations_other": "Generazioni" }, "models": { "noMatchingModels": "Nessun modello corrispondente", @@ -1563,7 +1591,6 @@ "brushSize": "Dimensioni del pennello", "globalMaskOpacity": "Opacità globale della maschera", "autoNegative": "Auto Negativo", - "toggleVisibility": "Attiva/disattiva la visibilità dei livelli", "deletePrompt": "Cancella il prompt", "debugLayers": "Debug dei Livelli", "rectangle": "Rettangolo", diff --git a/invokeai/frontend/web/public/locales/nl.json b/invokeai/frontend/web/public/locales/nl.json index 76377bd215..afcce62163 100644 --- a/invokeai/frontend/web/public/locales/nl.json +++ b/invokeai/frontend/web/public/locales/nl.json @@ -6,7 +6,7 @@ "settingsLabel": "Instellingen", "img2img": "Afbeelding naar afbeelding", "unifiedCanvas": "Centraal canvas", - "nodes": "Werkstroom-editor", + "nodes": "Werkstromen", "upload": "Upload", 
"load": "Laad", "statusDisconnected": "Niet verbonden", @@ -34,7 +34,60 @@ "controlNet": "ControlNet", "imageFailedToLoad": "Kan afbeelding niet laden", "learnMore": "Meer informatie", - "advanced": "Uitgebreid" + "advanced": "Uitgebreid", + "file": "Bestand", + "installed": "Geïnstalleerd", + "notInstalled": "Niet $t(common.installed)", + "simple": "Eenvoudig", + "somethingWentWrong": "Er ging iets mis", + "add": "Voeg toe", + "checkpoint": "Checkpoint", + "details": "Details", + "outputs": "Uitvoeren", + "save": "Bewaar", + "nextPage": "Volgende pagina", + "blue": "Blauw", + "alpha": "Alfa", + "red": "Rood", + "editor": "Editor", + "folder": "Map", + "format": "structuur", + "goTo": "Ga naar", + "template": "Sjabloon", + "input": "Invoer", + "loglevel": "Logboekniveau", + "safetensors": "Safetensors", + "saveAs": "Bewaar als", + "created": "Gemaakt", + "green": "Groen", + "tab": "Tab", + "positivePrompt": "Positieve prompt", + "negativePrompt": "Negatieve prompt", + "selected": "Geselecteerd", + "orderBy": "Sorteer op", + "prevPage": "Vorige pagina", + "beta": "Bèta", + "copyError": "$t(gallery.copy) Fout", + "toResolve": "Op te lossen", + "aboutDesc": "Gebruik je Invoke voor het werk? Kijk dan naar:", + "aboutHeading": "Creatieve macht voor jou", + "copy": "Kopieer", + "data": "Gegevens", + "or": "of", + "updated": "Bijgewerkt", + "outpaint": "outpainten", + "viewing": "Bekijken", + "viewingDesc": "Beoordeel afbeelding in een grote galerijweergave", + "editing": "Bewerken", + "editingDesc": "Bewerk op het canvas Stuurlagen", + "ai": "ai", + "inpaint": "inpainten", + "unknown": "Onbekend", + "delete": "Verwijder", + "direction": "Richting", + "error": "Fout", + "localSystem": "Lokaal systeem", + "unknownError": "Onbekende fout" }, "gallery": { "galleryImageSize": "Afbeeldingsgrootte", @@ -310,10 +363,41 @@ "modelSyncFailed": "Synchronisatie modellen mislukt", "modelDeleteFailed": "Model kon niet verwijderd worden", "convertingModelBegin": "Model aan het converteren. Even geduld.", - "predictionType": "Soort voorspelling (voor Stable Diffusion 2.x-modellen en incidentele Stable Diffusion 1.x-modellen)", + "predictionType": "Soort voorspelling", "advanced": "Uitgebreid", "modelType": "Soort model", - "vaePrecision": "Nauwkeurigheid VAE" + "vaePrecision": "Nauwkeurigheid VAE", + "loraTriggerPhrases": "LoRA-triggerzinnen", + "urlOrLocalPathHelper": "URL's zouden moeten wijzen naar een los bestand. Lokale paden kunnen wijzen naar een los bestand of map voor een individueel Diffusers-model.", + "modelName": "Modelnaam", + "path": "Pad", + "triggerPhrases": "Triggerzinnen", + "typePhraseHere": "Typ zin hier in", + "useDefaultSettings": "Gebruik standaardinstellingen", + "modelImageDeleteFailed": "Fout bij verwijderen modelafbeelding", + "modelImageUpdated": "Modelafbeelding bijgewerkt", + "modelImageUpdateFailed": "Fout bij bijwerken modelafbeelding", + "noMatchingModels": "Geen overeenkomende modellen", + "scanPlaceholder": "Pad naar een lokale map", + "noModelsInstalled": "Geen modellen geïnstalleerd", + "noModelsInstalledDesc1": "Installeer modellen met de", + "noModelSelected": "Geen model geselecteerd", + "starterModels": "Beginnermodellen", + "textualInversions": "Tekstuele omkeringen", + "upcastAttention": "Upcast-aandacht", + "uploadImage": "Upload afbeelding", + "mainModelTriggerPhrases": "Triggerzinnen hoofdmodel", + "urlOrLocalPath": "URL of lokaal pad", + "scanFolderHelper": "De map zal recursief worden ingelezen voor modellen. 
Dit kan enige tijd in beslag nemen voor erg grote mappen.", + "simpleModelPlaceholder": "URL of pad naar een lokaal pad of Diffusers-map", + "modelSettings": "Modelinstellingen", + "pathToConfig": "Pad naar configuratie", + "prune": "Snoei", + "pruneTooltip": "Snoei voltooide importeringen uit wachtrij", + "repoVariant": "Repovariant", + "scanFolder": "Lees map in", + "scanResults": "Resultaten inlezen", + "source": "Bron" }, "parameters": { "images": "Afbeeldingen", @@ -353,13 +437,13 @@ "copyImage": "Kopieer afbeelding", "denoisingStrength": "Sterkte ontruisen", "scheduler": "Planner", - "seamlessXAxis": "X-as", - "seamlessYAxis": "Y-as", + "seamlessXAxis": "Naadloze tegels in x-as", + "seamlessYAxis": "Naadloze tegels in y-as", "clipSkip": "Overslaan CLIP", "negativePromptPlaceholder": "Negatieve prompt", "controlNetControlMode": "Aansturingsmodus", "positivePromptPlaceholder": "Positieve prompt", - "maskBlur": "Vervaag", + "maskBlur": "Vervaging van masker", "invoke": { "noNodesInGraph": "Geen knooppunten in graaf", "noModelSelected": "Geen model ingesteld", @@ -369,11 +453,25 @@ "missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} invoer ontbreekt", "noControlImageForControlAdapter": "Controle-adapter #{{number}} heeft geen controle-afbeelding", "noModelForControlAdapter": "Control-adapter #{{number}} heeft geen model ingesteld staan.", - "incompatibleBaseModelForControlAdapter": "Model van controle-adapter #{{number}} is ongeldig in combinatie met het hoofdmodel.", + "incompatibleBaseModelForControlAdapter": "Model van controle-adapter #{{number}} is niet compatibel met het hoofdmodel.", "systemDisconnected": "Systeem is niet verbonden", "missingNodeTemplate": "Knooppuntsjabloon ontbreekt", "missingFieldTemplate": "Veldsjabloon ontbreekt", - "addingImagesTo": "Bezig met toevoegen van afbeeldingen aan" + "addingImagesTo": "Bezig met toevoegen van afbeeldingen aan", + "layer": { + "initialImageNoImageSelected": "geen initiële afbeelding geselecteerd", + "controlAdapterNoModelSelected": "geen controle-adaptermodel geselecteerd", + "controlAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor controle-adapter", + "controlAdapterNoImageSelected": "geen afbeelding voor controle-adapter geselecteerd", + "controlAdapterImageNotProcessed": "Afbeelding voor controle-adapter niet verwerkt", + "ipAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor IP-adapter", + "ipAdapterNoImageSelected": "geen afbeelding voor IP-adapter geselecteerd", + "rgNoRegion": "geen gebied geselecteerd", + "rgNoPromptsOrIPAdapters": "geen tekstprompts of IP-adapters", + "t2iAdapterIncompatibleDimensions": "T2I-adapter vereist een afbeelding met afmetingen met een veelvoud van 64", + "ipAdapterNoModelSelected": "geen IP-adapter geselecteerd" + }, + "imageNotProcessedForControlAdapter": "De afbeelding van controle-adapter #{{number}} is niet verwerkt" }, "isAllowedToUpscale": { "useX2Model": "Afbeelding is te groot om te vergroten met het x4-model. 
Gebruik hiervoor het x2-model", @@ -383,7 +481,26 @@ "useCpuNoise": "Gebruik CPU-ruis", "imageActions": "Afbeeldingshandeling", "iterations": "Iteraties", - "coherenceMode": "Modus" + "coherenceMode": "Modus", + "infillColorValue": "Vulkleur", + "remixImage": "Meng afbeelding opnieuw", + "setToOptimalSize": "Optimaliseer grootte voor het model", + "setToOptimalSizeTooSmall": "$t(parameters.setToOptimalSize) (is mogelijk te klein)", + "aspect": "Beeldverhouding", + "infillMosaicTileWidth": "Breedte tegel", + "setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (is mogelijk te groot)", + "lockAspectRatio": "Zet beeldverhouding vast", + "infillMosaicTileHeight": "Hoogte tegel", + "globalNegativePromptPlaceholder": "Globale negatieve prompt", + "globalPositivePromptPlaceholder": "Globale positieve prompt", + "useSize": "Gebruik grootte", + "swapDimensions": "Wissel afmetingen om", + "globalSettings": "Globale instellingen", + "coherenceEdgeSize": "Randgrootte", + "coherenceMinDenoise": "Min. ontruising", + "infillMosaicMinColor": "Min. kleur", + "infillMosaicMaxColor": "Max. kleur", + "cfgRescaleMultiplier": "Vermenigvuldiger voor CFG-herschaling" }, "settings": { "models": "Modellen", @@ -410,7 +527,12 @@ "intermediatesCleared_one": "{{count}} tussentijdse afbeelding gewist", "intermediatesCleared_other": "{{count}} tussentijdse afbeeldingen gewist", "clearIntermediatesDesc1": "Als je tussentijdse afbeeldingen wist, dan wordt de staat hersteld van je canvas en van ControlNet.", - "intermediatesClearedFailed": "Fout bij wissen van tussentijdse afbeeldingen" + "intermediatesClearedFailed": "Fout bij wissen van tussentijdse afbeeldingen", + "clearIntermediatesDisabled": "Wachtrij moet leeg zijn om tussentijdse afbeeldingen te kunnen leegmaken", + "enableInformationalPopovers": "Schakel informatieve hulpballonnen in", + "enableInvisibleWatermark": "Schakel onzichtbaar watermerk in", + "enableNSFWChecker": "Schakel NSFW-controle in", + "reloadingIn": "Opnieuw laden na" }, "toast": { "uploadFailed": "Upload mislukt", @@ -425,8 +547,8 @@ "connected": "Verbonden met server", "canceled": "Verwerking geannuleerd", "uploadFailedInvalidUploadDesc": "Moet een enkele PNG- of JPEG-afbeelding zijn", - "parameterNotSet": "Parameter niet ingesteld", - "parameterSet": "Instellen parameters", + "parameterNotSet": "{{parameter}} niet ingesteld", + "parameterSet": "{{parameter}} ingesteld", "problemCopyingImage": "Kan Afbeelding Niet Kopiëren", "baseModelChangedCleared_one": "Basismodel is gewijzigd: {{count}} niet-compatibel submodel weggehaald of uitgeschakeld", "baseModelChangedCleared_other": "Basismodel is gewijzigd: {{count}} niet-compatibele submodellen weggehaald of uitgeschakeld", @@ -443,11 +565,11 @@ "maskSavedAssets": "Masker bewaard in Assets", "problemDownloadingCanvas": "Fout bij downloaden van canvas", "problemMergingCanvas": "Fout bij samenvoegen canvas", - "setCanvasInitialImage": "Ingesteld als initiële canvasafbeelding", + "setCanvasInitialImage": "Initiële canvasafbeelding ingesteld", "imageUploaded": "Afbeelding geüpload", "addedToBoard": "Toegevoegd aan bord", "workflowLoaded": "Werkstroom geladen", - "modelAddedSimple": "Model toegevoegd", + "modelAddedSimple": "Model toegevoegd aan wachtrij", "problemImportingMaskDesc": "Kan masker niet exporteren", "problemCopyingCanvas": "Fout bij kopiëren canvas", "problemSavingCanvas": "Fout bij bewaren canvas", @@ -459,7 +581,18 @@ "maskSentControlnetAssets": "Masker gestuurd naar ControlNet en Assets", "canvasSavedGallery": "Canvas bewaard in 
galerij", "imageUploadFailed": "Fout bij uploaden afbeelding", - "problemImportingMask": "Fout bij importeren masker" + "problemImportingMask": "Fout bij importeren masker", + "workflowDeleted": "Werkstroom verwijderd", + "invalidUpload": "Ongeldige upload", + "uploadInitialImage": "Initiële afbeelding uploaden", + "setAsCanvasInitialImage": "Ingesteld als initiële afbeelding voor canvas", + "problemRetrievingWorkflow": "Fout bij ophalen van werkstroom", + "parameters": "Parameters", + "modelImportCanceled": "Importeren model geannuleerd", + "problemDeletingWorkflow": "Fout bij verwijderen van werkstroom", + "prunedQueue": "Wachtrij gesnoeid", + "problemDownloadingImage": "Fout bij downloaden afbeelding", + "resetInitialImage": "Initiële afbeelding hersteld" }, "tooltip": { "feature": { @@ -533,7 +666,11 @@ "showOptionsPanel": "Toon zijscherm", "menu": "Menu", "showGalleryPanel": "Toon deelscherm Galerij", - "loadMore": "Laad meer" + "loadMore": "Laad meer", + "about": "Over", + "mode": "Modus", + "resetUI": "$t(accessibility.reset) UI", + "createIssue": "Maak probleem aan" }, "nodes": { "zoomOutNodes": "Uitzoomen", @@ -547,7 +684,7 @@ "loadWorkflow": "Laad werkstroom", "downloadWorkflow": "Download JSON van werkstroom", "scheduler": "Planner", - "missingTemplate": "Ontbrekende sjabloon", + "missingTemplate": "Ongeldig knooppunt: knooppunt {{node}} van het soort {{type}} heeft een ontbrekend sjabloon (niet geïnstalleerd?)", "workflowDescription": "Korte beschrijving", "versionUnknown": " Versie onbekend", "noNodeSelected": "Geen knooppunt gekozen", @@ -563,7 +700,7 @@ "integer": "Geheel getal", "nodeTemplate": "Sjabloon knooppunt", "nodeOpacity": "Dekking knooppunt", - "unableToLoadWorkflow": "Kan werkstroom niet valideren", + "unableToLoadWorkflow": "Fout bij laden werkstroom", "snapToGrid": "Lijn uit op raster", "noFieldsLinearview": "Geen velden toegevoegd aan lineaire weergave", "nodeSearch": "Zoek naar knooppunten", @@ -614,11 +751,56 @@ "unknownField": "Onbekend veld", "colorCodeEdges": "Kleurgecodeerde randen", "unknownNode": "Onbekend knooppunt", - "mismatchedVersion": "Heeft niet-overeenkomende versie", + "mismatchedVersion": "Ongeldig knooppunt: knooppunt {{node}} van het soort {{type}} heeft een niet-overeenkomende versie (probeer het bij te werken?)", "addNodeToolTip": "Voeg knooppunt toe (Shift+A, spatie)", "loadingNodes": "Bezig met laden van knooppunten...", "snapToGridHelp": "Lijn knooppunten uit op raster bij verplaatsing", - "workflowSettings": "Instellingen werkstroomeditor" + "workflowSettings": "Instellingen werkstroomeditor", + "addLinearView": "Voeg toe aan lineaire weergave", + "nodePack": "Knooppuntpakket", + "unknownInput": "Onbekende invoer: {{name}}", + "sourceNodeFieldDoesNotExist": "Ongeldige rand: bron-/uitvoerveld {{node}}.{{field}} bestaat niet", + "collectionFieldType": "Verzameling {{name}}", + "deletedInvalidEdge": "Ongeldige hoek {{source}} -> {{target}} verwijderd", + "graph": "Grafiek", + "targetNodeDoesNotExist": "Ongeldige rand: doel-/invoerknooppunt {{node}} bestaat niet", + "resetToDefaultValue": "Herstel naar standaardwaarden", + "editMode": "Bewerk in Werkstroom-editor", + "showEdgeLabels": "Toon randlabels", + "showEdgeLabelsHelp": "Toon labels aan randen, waarmee de verbonden knooppunten mee worden aangegeven", + "clearWorkflowDesc2": "Je huidige werkstroom heeft niet-bewaarde wijzigingen.", + "unableToParseFieldType": "fout bij bepalen soort veld", + "sourceNodeDoesNotExist": "Ongeldige rand: bron-/uitvoerknooppunt {{node}} bestaat niet", + 
"unsupportedArrayItemType": "niet-ondersteunde soort van het array-onderdeel \"{{type}}\"", + "targetNodeFieldDoesNotExist": "Ongeldige rand: doel-/invoerveld {{node}}.{{field}} bestaat niet", + "reorderLinearView": "Herorden lineaire weergave", + "newWorkflowDesc": "Een nieuwe werkstroom aanmaken?", + "collectionOrScalarFieldType": "Verzameling|scalair {{name}}", + "newWorkflow": "Nieuwe werkstroom", + "unknownErrorValidatingWorkflow": "Onbekende fout bij valideren werkstroom", + "unsupportedAnyOfLength": "te veel union-leden ({{count}})", + "unknownOutput": "Onbekende uitvoer: {{name}}", + "viewMode": "Gebruik in lineaire weergave", + "unableToExtractSchemaNameFromRef": "fout bij het extraheren van de schemanaam via de ref", + "unsupportedMismatchedUnion": "niet-overeenkomende soort CollectionOrScalar met basissoorten {{firstType}} en {{secondType}}", + "unknownNodeType": "Onbekend soort knooppunt", + "edit": "Bewerk", + "updateAllNodes": "Werk knooppunten bij", + "allNodesUpdated": "Alle knooppunten bijgewerkt", + "nodeVersion": "Knooppuntversie", + "newWorkflowDesc2": "Je huidige werkstroom heeft niet-bewaarde wijzigingen.", + "clearWorkflow": "Maak werkstroom leeg", + "clearWorkflowDesc": "Deze werkstroom leegmaken en met een nieuwe beginnen?", + "inputFieldTypeParseError": "Fout bij bepalen van het soort invoerveld {{node}}.{{field}} ({{message}})", + "outputFieldTypeParseError": "Fout bij het bepalen van het soort uitvoerveld {{node}}.{{field}} ({{message}})", + "unableToExtractEnumOptions": "fout bij extraheren enumeratie-opties", + "unknownFieldType": "Soort $t(nodes.unknownField): {{type}}", + "unableToGetWorkflowVersion": "Fout bij ophalen schemaversie van werkstroom", + "betaDesc": "Deze uitvoering is in bèta. Totdat deze stabiel is kunnen er wijzigingen voorkomen gedurende app-updates die zaken kapotmaken. We zijn van plan om deze uitvoering op lange termijn te gaan ondersteunen.", + "prototypeDesc": "Deze uitvoering is een prototype. Er kunnen wijzigingen voorkomen gedurende app-updates die zaken kapotmaken. Deze kunnen op een willekeurig moment verwijderd worden.", + "noFieldsViewMode": "Deze werkstroom heeft geen geselecteerde velden om te tonen. Bekijk de volledige werkstroom om de waarden te configureren.", + "unableToUpdateNodes_one": "Fout bij bijwerken van {{count}} knooppunt", + "unableToUpdateNodes_other": "Fout bij bijwerken van {{count}} knooppunten" }, "controlnet": { "amult": "a_mult", @@ -691,9 +873,28 @@ "canny": "Canny", "depthZoeDescription": "Genereer diepteblad via Zoe", "hedDescription": "Herkenning van holistisch-geneste randen", - "setControlImageDimensions": "Stel afmetingen controle-afbeelding in op B/H", + "setControlImageDimensions": "Kopieer grootte naar B/H (optimaliseer voor model)", "scribble": "Krabbel", - "maxFaces": "Max. gezichten" + "maxFaces": "Max. gezichten", + "dwOpenpose": "DW Openpose", + "depthAnything": "Depth Anything", + "base": "Basis", + "hands": "Handen", + "selectCLIPVisionModel": "Selecteer een CLIP Vision-model", + "modelSize": "Modelgrootte", + "small": "Klein", + "large": "Groot", + "resizeSimple": "Wijzig grootte (eenvoudig)", + "beginEndStepPercentShort": "Begin-/eind-%", + "depthAnythingDescription": "Genereren dieptekaart d.m.v. de techniek Depth Anything", + "face": "Gezicht", + "body": "Lichaam", + "dwOpenposeDescription": "Schatting menselijke pose d.m.v. 
DW Openpose", + "ipAdapterMethod": "Methode", + "full": "Volledig", + "style": "Alleen stijl", + "composition": "Alleen samenstelling", + "setControlImageDimensionsForce": "Kopieer grootte naar B/H (negeer model)" }, "dynamicPrompts": { "seedBehaviour": { @@ -706,7 +907,10 @@ "maxPrompts": "Max. prompts", "promptsWithCount_one": "{{count}} prompt", "promptsWithCount_other": "{{count}} prompts", - "dynamicPrompts": "Dynamische prompts" + "dynamicPrompts": "Dynamische prompts", + "showDynamicPrompts": "Toon dynamische prompts", + "loading": "Genereren van dynamische prompts...", + "promptsPreview": "Voorvertoning prompts" }, "popovers": { "noiseUseCPU": { @@ -719,7 +923,7 @@ }, "paramScheduler": { "paragraphs": [ - "De planner bepaalt hoe ruis per iteratie wordt toegevoegd aan een afbeelding of hoe een monster wordt bijgewerkt op basis van de uitvoer van een model." + "De planner gebruikt gedurende het genereringsproces." ], "heading": "Planner" }, @@ -806,8 +1010,8 @@ }, "clipSkip": { "paragraphs": [ - "Kies hoeveel CLIP-modellagen je wilt overslaan.", - "Bepaalde modellen werken beter met bepaalde Overslaan CLIP-instellingen." + "Aantal over te slaan CLIP-modellagen.", + "Bepaalde modellen zijn beter geschikt met bepaalde Overslaan CLIP-instellingen." ], "heading": "Overslaan CLIP" }, @@ -991,17 +1195,26 @@ "denoisingStrength": "Sterkte ontruising", "refinermodel": "Verfijningsmodel", "posAestheticScore": "Positieve esthetische score", - "concatPromptStyle": "Plak prompt- en stijltekst aan elkaar", + "concatPromptStyle": "Koppelen van prompt en stijl", "loading": "Bezig met laden...", "steps": "Stappen", - "posStylePrompt": "Positieve-stijlprompt" + "posStylePrompt": "Positieve-stijlprompt", + "freePromptStyle": "Handmatige stijlprompt", + "refinerSteps": "Aantal stappen verfijner" }, "models": { "noMatchingModels": "Geen overeenkomend modellen", "loading": "bezig met laden", "noMatchingLoRAs": "Geen overeenkomende LoRA's", "noModelsAvailable": "Geen modellen beschikbaar", - "selectModel": "Kies een model" + "selectModel": "Kies een model", + "noLoRAsInstalled": "Geen LoRA's geïnstalleerd", + "noRefinerModelsInstalled": "Geen SDXL-verfijningsmodellen geïnstalleerd", + "defaultVAE": "Standaard-VAE", + "lora": "LoRA", + "esrganModel": "ESRGAN-model", + "addLora": "Voeg LoRA toe", + "concepts": "Concepten" }, "boards": { "autoAddBoard": "Voeg automatisch bord toe", @@ -1019,7 +1232,13 @@ "downloadBoard": "Download bord", "changeBoard": "Wijzig bord", "loading": "Bezig met laden...", - "clearSearch": "Maak zoekopdracht leeg" + "clearSearch": "Maak zoekopdracht leeg", + "deleteBoard": "Verwijder bord", + "deleteBoardAndImages": "Verwijder bord en afbeeldingen", + "deleteBoardOnly": "Verwijder alleen bord", + "deletedBoardsCannotbeRestored": "Verwijderde borden kunnen niet worden hersteld", + "movingImagesToBoard_one": "Verplaatsen van {{count}} afbeelding naar bord:", + "movingImagesToBoard_other": "Verplaatsen van {{count}} afbeeldingen naar bord:" }, "invocationCache": { "disable": "Schakel uit", @@ -1036,5 +1255,39 @@ "clear": "Wis", "maxCacheSize": "Max. 
grootte cache", "cacheSize": "Grootte cache" + }, + "accordions": { + "generation": { + "title": "Genereren" + }, + "image": { + "title": "Afbeelding" + }, + "advanced": { + "title": "Geavanceerd", + "options": "$t(accordions.advanced.title) Opties" + }, + "control": { + "title": "Besturing" + }, + "compositing": { + "title": "Samenstellen", + "coherenceTab": "Coherentiefase", + "infillTab": "Invullen" + } + }, + "hrf": { + "upscaleMethod": "Opschaalmethode", + "metadata": { + "strength": "Sterkte oplossing voor hoge resolutie", + "method": "Methode oplossing voor hoge resolutie", + "enabled": "Oplossing voor hoge resolutie ingeschakeld" + }, + "hrf": "Oplossing voor hoge resolutie", + "enableHrf": "Schakel oplossing in voor hoge resolutie" + }, + "prompt": { + "addPromptTrigger": "Voeg prompttrigger toe", + "compatibleEmbeddings": "Compatibele embeddings" } } diff --git a/invokeai/frontend/web/public/locales/ru.json b/invokeai/frontend/web/public/locales/ru.json index 7fa1b73e7a..03ff7eb706 100644 --- a/invokeai/frontend/web/public/locales/ru.json +++ b/invokeai/frontend/web/public/locales/ru.json @@ -1594,7 +1594,6 @@ "deleteAll": "Удалить всё", "addLayer": "Добавить слой", "moveToFront": "На передний план", - "toggleVisibility": "Переключить видимость слоя", "addPositivePrompt": "Добавить $t(common.positivePrompt)", "addIPAdapter": "Добавить $t(common.ipAdapter)", "regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)", diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index 30d8f41200..2d878d96e7 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -21,10 +21,10 @@ import i18n from 'i18n'; import { size } from 'lodash-es'; import { memo, useCallback, useEffect } from 'react'; import { ErrorBoundary } from 'react-error-boundary'; +import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo'; import AppErrorBoundaryFallback from './AppErrorBoundaryFallback'; import PreselectedImage from './PreselectedImage'; -import Toaster from './Toaster'; const DEFAULT_CONFIG = {}; @@ -46,6 +46,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => { useSocketIO(); useGlobalModifiersInit(); useGlobalHotkeys(); + useGetOpenAPISchemaQuery(); const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone(); @@ -94,7 +95,6 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => { - ); diff --git a/invokeai/frontend/web/src/app/components/AppErrorBoundaryFallback.tsx b/invokeai/frontend/web/src/app/components/AppErrorBoundaryFallback.tsx index d2992a8cd9..ced3037a40 100644 --- a/invokeai/frontend/web/src/app/components/AppErrorBoundaryFallback.tsx +++ b/invokeai/frontend/web/src/app/components/AppErrorBoundaryFallback.tsx @@ -1,5 +1,8 @@ -import { Button, Flex, Heading, Link, Text, useToast } from '@invoke-ai/ui-library'; +import { Button, Flex, Heading, Image, Link, Text } from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { toast } from 'features/toast/toast'; import newGithubIssueUrl from 'new-github-issue-url'; +import InvokeLogoYellow from 'public/assets/images/invoke-symbol-ylw-lrg.svg'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiArrowCounterClockwiseBold, PiArrowSquareOutBold, PiCopyBold } from 'react-icons/pi'; @@ -11,31 +14,39 @@ type Props = { }; const 
AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => { - const toast = useToast(); const { t } = useTranslation(); + const isLocal = useAppSelector((s) => s.config.isLocal); const handleCopy = useCallback(() => { const text = JSON.stringify(serializeError(error), null, 2); navigator.clipboard.writeText(`\`\`\`\n${text}\n\`\`\``); toast({ - title: 'Error Copied', + id: 'ERROR_COPIED', + title: t('toast.errorCopied'), }); - }, [error, toast]); + }, [error, t]); - const url = useMemo( - () => - newGithubIssueUrl({ + const url = useMemo(() => { + if (isLocal) { + return newGithubIssueUrl({ user: 'invoke-ai', repo: 'InvokeAI', template: 'BUG_REPORT.yml', title: `[bug]: ${error.name}: ${error.message}`, - }), - [error.message, error.name] - ); + }); + } else { + return 'https://support.invoke.ai/support/tickets/new'; + } + }, [error.message, error.name, isLocal]); + return ( - {t('common.somethingWentWrong')} + + + {t('common.somethingWentWrong')} + + { {t('common.copyError')} - }>{t('accessibility.createIssue')} + }> + {isLocal ? t('accessibility.createIssue') : t('accessibility.submitSupportTicket')} + diff --git a/invokeai/frontend/web/src/app/components/Toaster.ts b/invokeai/frontend/web/src/app/components/Toaster.ts deleted file mode 100644 index c86fd5060d..0000000000 --- a/invokeai/frontend/web/src/app/components/Toaster.ts +++ /dev/null @@ -1,44 +0,0 @@ -import { useToast } from '@invoke-ai/ui-library'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { addToast, clearToastQueue } from 'features/system/store/systemSlice'; -import type { MakeToastArg } from 'features/system/util/makeToast'; -import { makeToast } from 'features/system/util/makeToast'; -import { memo, useCallback, useEffect } from 'react'; - -/** - * Logical component. Watches the toast queue and makes toasts when the queue is not empty. - * @returns null - */ -const Toaster = () => { - const dispatch = useAppDispatch(); - const toastQueue = useAppSelector((s) => s.system.toastQueue); - const toast = useToast(); - useEffect(() => { - toastQueue.forEach((t) => { - toast(t); - }); - toastQueue.length > 0 && dispatch(clearToastQueue()); - }, [dispatch, toast, toastQueue]); - - return null; -}; - -/** - * Returns a function that can be used to make a toast. - * @example - * const toaster = useAppToaster(); - * toaster('Hello world!'); - * toaster({ title: 'Hello world!', status: 'success' }); - * @returns A function that can be used to make a toast. 
- * @see makeToast - * @see MakeToastArg - * @see UseToastOptions - */ -export const useAppToaster = () => { - const dispatch = useAppDispatch(); - const toaster = useCallback((arg: MakeToastArg) => dispatch(addToast(makeToast(arg))), [dispatch]); - - return toaster; -}; - -export default memo(Toaster); diff --git a/invokeai/frontend/web/src/app/hooks/useSocketIO.ts b/invokeai/frontend/web/src/app/hooks/useSocketIO.ts index aaa3b8f6f2..d3baf5f452 100644 --- a/invokeai/frontend/web/src/app/hooks/useSocketIO.ts +++ b/invokeai/frontend/web/src/app/hooks/useSocketIO.ts @@ -6,8 +6,8 @@ import { useAppDispatch } from 'app/store/storeHooks'; import type { MapStore } from 'nanostores'; import { atom, map } from 'nanostores'; import { useEffect, useMemo } from 'react'; +import { setEventListeners } from 'services/events/setEventListeners'; import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; -import { setEventListeners } from 'services/events/util/setEventListeners'; import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client'; import { io } from 'socket.io-client'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 0c0c8ed2bc..0fd2f1b79c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -35,28 +35,22 @@ import { addImageUploadedFulfilledListener } from 'app/store/middleware/listener import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected'; import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded'; import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged'; +import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings'; import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected'; import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected'; import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress'; -import { addGraphExecutionStateCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete'; import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete'; import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError'; -import { addInvocationRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError'; import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted'; import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall'; import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad'; import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged'; -import { 
addSessionRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError'; -import { addSocketSubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed'; -import { addSocketUnsubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed'; import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved'; import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested'; import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested'; import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested'; import type { AppDispatch, RootState } from 'app/store/store'; -import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings'; - export const listenerMiddleware = createListenerMiddleware(); export type AppStartListening = TypedStartListening; @@ -104,18 +98,13 @@ addCommitStagingAreaImageListener(startAppListening); // Socket.IO addGeneratorProgressEventListener(startAppListening); -addGraphExecutionStateCompleteEventListener(startAppListening); addInvocationCompleteEventListener(startAppListening); addInvocationErrorEventListener(startAppListening); addInvocationStartedEventListener(startAppListening); addSocketConnectedEventListener(startAppListening); addSocketDisconnectedEventListener(startAppListening); -addSocketSubscribedEventListener(startAppListening); -addSocketUnsubscribedEventListener(startAppListening); addModelLoadEventListener(startAppListening); addModelInstallEventListener(startAppListening); -addSessionRetrievalErrorEventListener(startAppListening); -addInvocationRetrievalErrorEventListener(startAppListening); addSocketQueueItemStatusChangedEventListener(startAppListening); addBulkDownloadListeners(startAppListening); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts index ae26531722..9095a08431 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts @@ -8,7 +8,7 @@ import { resetCanvas, setInitialCanvasImage, } from 'features/canvas/store/canvasSlice'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { queueApi } from 'services/api/endpoints/queue'; @@ -30,22 +30,20 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis req.reset(); if (canceled > 0) { log.debug(`Canceled ${canceled} canvas batches`); - dispatch( - addToast({ - title: t('queue.cancelBatchSucceeded'), - status: 'success', - }) - ); + toast({ + id: 'CANCEL_BATCH_SUCCEEDED', + title: t('queue.cancelBatchSucceeded'), + status: 'success', + }); } dispatch(canvasBatchIdsReset()); } catch { log.error('Failed to cancel canvas batches'); - dispatch( - addToast({ - title: t('queue.cancelBatchFailed'), - status: 'error', - }) - ); + toast({ + id: 'CANCEL_BATCH_FAILED', + title: t('queue.cancelBatchFailed'), + status: 'error', + }); 
} }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts index 68eda997b7..3f74bf9b61 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts @@ -1,8 +1,8 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { parseify } from 'common/util/serialize'; -import { toast } from 'common/util/toast'; import { zPydanticValidationError } from 'features/system/store/zodSchemas'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { truncate, upperFirst } from 'lodash-es'; import { queueApi } from 'services/api/endpoints/queue'; @@ -16,18 +16,15 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) = const arg = action.meta.arg.originalArgs; logger('queue').debug({ enqueueResult: parseify(response) }, 'Batch enqueued'); - if (!toast.isActive('batch-queued')) { - toast({ - id: 'batch-queued', - title: t('queue.batchQueued'), - description: t('queue.batchQueuedDesc', { - count: response.enqueued, - direction: arg.prepend ? t('queue.front') : t('queue.back'), - }), - duration: 1000, - status: 'success', - }); - } + toast({ + id: 'QUEUE_BATCH_SUCCEEDED', + title: t('queue.batchQueued'), + status: 'success', + description: t('queue.batchQueuedDesc', { + count: response.enqueued, + direction: arg.prepend ? t('queue.front') : t('queue.back'), + }), + }); }, }); @@ -40,9 +37,10 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) = if (!response) { toast({ + id: 'QUEUE_BATCH_FAILED', title: t('queue.batchFailedToQueue'), status: 'error', - description: 'Unknown Error', + description: t('common.unknownError'), }); logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue')); return; @@ -52,7 +50,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) = if (result.success) { result.data.data.detail.map((e) => { toast({ - id: 'batch-failed-to-queue', + id: 'QUEUE_BATCH_FAILED', title: truncate(upperFirst(e.msg), { length: 128 }), status: 'error', description: truncate( @@ -64,9 +62,10 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) = }); } else if (response.status !== 403) { toast({ + id: 'QUEUE_BATCH_FAILED', title: t('queue.batchFailedToQueue'), - description: t('common.unknownError'), status: 'error', + description: t('common.unknownError'), }); } logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue')); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx index 38a0fd7911..489f218370 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx @@ -1,13 +1,12 @@ -import type { UseToastOptions } from '@invoke-ai/ui-library'; import { ExternalLink } from '@invoke-ai/ui-library'; import { logger } from 'app/logging/logger'; import type { AppStartListening } from 
'app/store/middleware/listenerMiddleware'; -import { toast } from 'common/util/toast'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; import { - socketBulkDownloadCompleted, - socketBulkDownloadFailed, + socketBulkDownloadComplete, + socketBulkDownloadError, socketBulkDownloadStarted, } from 'services/events/actions'; @@ -28,7 +27,6 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) = // Show the response message if it exists, otherwise show the default message description: action.payload.response || t('gallery.bulkDownloadRequestedDesc'), duration: null, - isClosable: true, }); }, }); @@ -40,9 +38,9 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) = // There isn't any toast to update if we get this event. toast({ + id: 'BULK_DOWNLOAD_REQUEST_FAILED', title: t('gallery.bulkDownloadRequestFailed'), - status: 'success', - isClosable: true, + status: 'error', }); }, }); @@ -56,7 +54,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) = }); startAppListening({ - actionCreator: socketBulkDownloadCompleted, + actionCreator: socketBulkDownloadComplete, effect: async (action) => { log.debug(action.payload.data, 'Bulk download preparation completed'); @@ -65,7 +63,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) = // TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first const url = `/api/v1/images/download/${bulk_download_item_name}`; - const toastOptions: UseToastOptions = { + toast({ id: bulk_download_item_name, title: t('gallery.bulkDownloadReady', 'Download ready'), status: 'success', @@ -77,38 +75,24 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) = /> ), duration: null, - isClosable: true, - }; - - if (toast.isActive(bulk_download_item_name)) { - toast.update(bulk_download_item_name, toastOptions); - } else { - toast(toastOptions); - } + }); }, }); startAppListening({ - actionCreator: socketBulkDownloadFailed, + actionCreator: socketBulkDownloadError, effect: async (action) => { log.debug(action.payload.data, 'Bulk download preparation failed'); const { bulk_download_item_name } = action.payload.data; - const toastOptions: UseToastOptions = { + toast({ id: bulk_download_item_name, title: t('gallery.bulkDownloadFailed'), status: 'error', description: action.payload.data.error, duration: null, - isClosable: true, - }; - - if (toast.isActive(bulk_download_item_name)) { - toast.update(bulk_download_item_name, toastOptions); - } else { - toast(toastOptions); - } + }); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard.ts index e1f4804d56..311dda3e2e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard.ts @@ -2,14 +2,14 @@ import { $logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { canvasCopiedToClipboard } from 'features/canvas/store/actions'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; -import { addToast } from 'features/system/store/systemSlice'; 
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartListening) => { startAppListening({ actionCreator: canvasCopiedToClipboard, - effect: async (action, { dispatch, getState }) => { + effect: async (action, { getState }) => { const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' }); const state = getState(); @@ -19,22 +19,20 @@ export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartLi copyBlobToClipboard(blob); } catch (err) { moduleLog.error(String(err)); - dispatch( - addToast({ - title: t('toast.problemCopyingCanvas'), - description: t('toast.problemCopyingCanvasDesc'), - status: 'error', - }) - ); + toast({ + id: 'CANVAS_COPY_FAILED', + title: t('toast.problemCopyingCanvas'), + description: t('toast.problemCopyingCanvasDesc'), + status: 'error', + }); return; } - dispatch( - addToast({ - title: t('toast.canvasCopiedClipboard'), - status: 'success', - }) - ); + toast({ + id: 'CANVAS_COPY_SUCCEEDED', + title: t('toast.canvasCopiedClipboard'), + status: 'success', + }); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage.ts index 5b8150bd20..71e616b9ea 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage.ts @@ -3,13 +3,13 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { canvasDownloadedAsImage } from 'features/canvas/store/actions'; import { downloadBlob } from 'features/canvas/util/downloadBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartListening) => { startAppListening({ actionCreator: canvasDownloadedAsImage, - effect: async (action, { dispatch, getState }) => { + effect: async (action, { getState }) => { const moduleLog = $logger.get().child({ namespace: 'canvasSavedToGalleryListener' }); const state = getState(); @@ -18,18 +18,17 @@ export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartLi blob = await getBaseLayerBlob(state); } catch (err) { moduleLog.error(String(err)); - dispatch( - addToast({ - title: t('toast.problemDownloadingCanvas'), - description: t('toast.problemDownloadingCanvasDesc'), - status: 'error', - }) - ); + toast({ + id: 'CANVAS_DOWNLOAD_FAILED', + title: t('toast.problemDownloadingCanvas'), + description: t('toast.problemDownloadingCanvasDesc'), + status: 'error', + }); return; } downloadBlob(blob, 'canvas.png'); - dispatch(addToast({ title: t('toast.canvasDownloaded'), status: 'success' })); + toast({ id: 'CANVAS_DOWNLOAD_SUCCEEDED', title: t('toast.canvasDownloaded'), status: 'success' }); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet.ts index 55392ebff4..2aa1f52d6c 100644 --- 
a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet.ts @@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { canvasImageToControlAdapter } from 'features/canvas/store/actions'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; @@ -20,13 +20,12 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi blob = await getBaseLayerBlob(state, true); } catch (err) { log.error(String(err)); - dispatch( - addToast({ - title: t('toast.problemSavingCanvas'), - description: t('toast.problemSavingCanvasDesc'), - status: 'error', - }) - ); + toast({ + id: 'PROBLEM_SAVING_CANVAS', + title: t('toast.problemSavingCanvas'), + description: t('toast.problemSavingCanvasDesc'), + status: 'error', + }); return; } @@ -43,7 +42,7 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi crop_visible: false, postUploadAction: { type: 'TOAST', - toastOptions: { title: t('toast.canvasSentControlnetAssets') }, + title: t('toast.canvasSentControlnetAssets'), }, }) ).unwrap(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery.ts index af0c3878fc..454342b997 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery.ts @@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { canvasMaskSavedToGallery } from 'features/canvas/store/actions'; import { getCanvasData } from 'features/canvas/util/getCanvasData'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; @@ -29,13 +29,12 @@ export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartL if (!maskBlob) { log.error('Problem getting mask layer blob'); - dispatch( - addToast({ - title: t('toast.problemSavingMask'), - description: t('toast.problemSavingMaskDesc'), - status: 'error', - }) - ); + toast({ + id: 'PROBLEM_SAVING_MASK', + title: t('toast.problemSavingMask'), + description: t('toast.problemSavingMaskDesc'), + status: 'error', + }); return; } @@ -52,7 +51,7 @@ export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartL crop_visible: true, postUploadAction: { type: 'TOAST', - toastOptions: { title: t('toast.maskSavedAssets') }, + title: t('toast.maskSavedAssets'), }, }) ); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet.ts index 569b4badc7..2e6ca61d8a 100644 --- 
a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet.ts @@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { canvasMaskToControlAdapter } from 'features/canvas/store/actions'; import { getCanvasData } from 'features/canvas/util/getCanvasData'; import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; @@ -30,13 +30,12 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis if (!maskBlob) { log.error('Problem getting mask layer blob'); - dispatch( - addToast({ - title: t('toast.problemImportingMask'), - description: t('toast.problemImportingMaskDesc'), - status: 'error', - }) - ); + toast({ + id: 'PROBLEM_IMPORTING_MASK', + title: t('toast.problemImportingMask'), + description: t('toast.problemImportingMaskDesc'), + status: 'error', + }); return; } @@ -53,7 +52,7 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis crop_visible: false, postUploadAction: { type: 'TOAST', - toastOptions: { title: t('toast.maskSentControlnetAssets') }, + title: t('toast.maskSentControlnetAssets'), }, }) ).unwrap(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts index 71b0e62b44..9ae6de2e76 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts @@ -4,7 +4,7 @@ import { canvasMerged } from 'features/canvas/store/actions'; import { $canvasBaseLayer } from 'features/canvas/store/canvasNanostore'; import { setMergedCanvas } from 'features/canvas/store/canvasSlice'; import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; @@ -17,13 +17,12 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) => if (!blob) { moduleLog.error('Problem getting base layer blob'); - dispatch( - addToast({ - title: t('toast.problemMergingCanvas'), - description: t('toast.problemMergingCanvasDesc'), - status: 'error', - }) - ); + toast({ + id: 'PROBLEM_MERGING_CANVAS', + title: t('toast.problemMergingCanvas'), + description: t('toast.problemMergingCanvasDesc'), + status: 'error', + }); return; } @@ -31,13 +30,12 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) => if (!canvasBaseLayer) { moduleLog.error('Problem getting canvas base layer'); - dispatch( - addToast({ - title: t('toast.problemMergingCanvas'), - description: t('toast.problemMergingCanvasDesc'), - status: 'error', - }) - ); + toast({ + id: 'PROBLEM_MERGING_CANVAS', + title: t('toast.problemMergingCanvas'), + description: t('toast.problemMergingCanvasDesc'), + status: 'error', + }); return; } @@ -54,7 +52,7 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) => is_intermediate: true, 
postUploadAction: { type: 'TOAST', - toastOptions: { title: t('toast.canvasMerged') }, + title: t('toast.canvasMerged'), }, }) ).unwrap(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts index e3ba988886..71586b5f6e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts @@ -1,8 +1,9 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import { parseify } from 'common/util/serialize'; import { canvasSavedToGallery } from 'features/canvas/store/actions'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; @@ -18,13 +19,12 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe blob = await getBaseLayerBlob(state); } catch (err) { log.error(String(err)); - dispatch( - addToast({ - title: t('toast.problemSavingCanvas'), - description: t('toast.problemSavingCanvasDesc'), - status: 'error', - }) - ); + toast({ + id: 'CANVAS_SAVE_FAILED', + title: t('toast.problemSavingCanvas'), + description: t('toast.problemSavingCanvasDesc'), + status: 'error', + }); return; } @@ -41,7 +41,10 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe crop_visible: true, postUploadAction: { type: 'TOAST', - toastOptions: { title: t('toast.canvasSavedGallery') }, + title: t('toast.canvasSavedGallery'), + }, + metadata: { + _canvas_objects: parseify(state.canvas.layerState.objects), }, }) ); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts index 2a59cc0317..581146c25c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts @@ -14,8 +14,9 @@ import { } from 'features/controlLayers/store/controlLayersSlice'; import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters'; import { isImageOutput } from 'features/nodes/types/common'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; +import { isEqual } from 'lodash-es'; import { getImageDTO } from 'services/api/endpoints/images'; import { queueApi } from 'services/api/endpoints/queue'; import type { BatchConfig } from 'services/api/types'; @@ -47,8 +48,10 @@ const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batc export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => { startAppListening({ matcher, - effect: async (action, { dispatch, getState, cancelActiveListeners, delay, take, signal }) => { + effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take, signal }) => { const layerId = caLayerRecalled.match(action) ? 
action.payload.id : action.payload.layerId; + const state = getState(); + const originalState = getOriginalState(); // Cancel any in-progress instances of this listener cancelActiveListeners(); @@ -57,21 +60,33 @@ export const addControlAdapterPreprocessor = (startAppListeni // Delay before starting actual work await delay(DEBOUNCE_MS); - // Double-check that we are still eligible for processing - const state = getState(); const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId); - // If we have no image or there is no processor config, bail if (!layer) { return; } + // We should only process if the processor settings or image have changed + const originalLayer = originalState.controlLayers.present.layers + .filter(isControlAdapterLayer) + .find((l) => l.id === layerId); + const originalImage = originalLayer?.controlAdapter.image; + const originalConfig = originalLayer?.controlAdapter.processorConfig; + const image = layer.controlAdapter.image; const config = layer.controlAdapter.processorConfig; + if (isEqual(config, originalConfig) && isEqual(image, originalImage)) { + // Neither config nor image have changed, we can bail + return; + } + if (!image || !config) { - // The user has reset the image or config, so we should clear the processed image + // - If we have no image, we have nothing to process + // - If we have no processor config, we have nothing to process + // Clear the processed image and bail dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null })); + return; } // At this point, the user has stopped fiddling with the processor settings and there is a processor selected. @@ -81,8 +96,8 @@ export const addControlAdapterPreprocessor = (startAppListeni cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId); } - // @ts-expect-error: TS isn't able to narrow the typing of buildNode and `config` will error... - const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config); + // TODO(psyche): I can't get TS to be happy, it thinks `config` is `never` but it should be inferred from the generic... 
I'll just cast it for now + const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never); const enqueueBatchArg: BatchConfig = { prepend: true, batch: { @@ -118,8 +133,8 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni const [invocationCompleteAction] = await take( (action): action is ReturnType => socketInvocationComplete.match(action) && - action.payload.data.queue_batch_id === enqueueResult.batch.batch_id && - action.payload.data.source_node_id === processorNode.id + action.payload.data.batch_id === enqueueResult.batch.batch_id && + action.payload.data.invocation_source_id === processorNode.id ); // We still have to check the output type @@ -159,12 +174,11 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni } } - dispatch( - addToast({ - title: t('queue.graphFailedToQueue'), - status: 'error', - }) - ); + toast({ + id: 'GRAPH_QUEUE_FAILED', + title: t('queue.graphFailedToQueue'), + status: 'error', + }); } } finally { req.reset(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts index 0055866aa7..1e485b31d5 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts @@ -10,7 +10,7 @@ import { } from 'features/controlAdapters/store/controlAdaptersSlice'; import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; import { isImageOutput } from 'features/nodes/types/common'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; import { queueApi } from 'services/api/endpoints/queue'; @@ -69,8 +69,8 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL const [invocationCompleteAction] = await take( (action): action is ReturnType => socketInvocationComplete.match(action) && - action.payload.data.queue_batch_id === enqueueResult.batch.batch_id && - action.payload.data.source_node_id === nodeId + action.payload.data.batch_id === enqueueResult.batch.batch_id && + action.payload.data.invocation_source_id === nodeId ); // We still have to check the output type @@ -108,12 +108,11 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL } } - dispatch( - addToast({ - title: t('queue.graphFailedToQueue'), - status: 'error', - }) - ); + toast({ + id: 'GRAPH_QUEUE_FAILED', + title: t('queue.graphFailedToQueue'), + status: 'error', + }); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index d5d74bf668..cd5304c32b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -1,4 +1,3 @@ -import type { UseToastOptions } from '@invoke-ai/ui-library'; import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; @@ -14,7 
+13,7 @@ import { } from 'features/controlLayers/store/controlLayersSlice'; import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { omit } from 'lodash-es'; import { boardsApi } from 'services/api/endpoints/boards'; @@ -42,16 +41,17 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis return; } - const DEFAULT_UPLOADED_TOAST: UseToastOptions = { + const DEFAULT_UPLOADED_TOAST = { + id: 'IMAGE_UPLOADED', title: t('toast.imageUploaded'), status: 'success', - }; + } as const; // default action - just upload and alert user if (postUploadAction?.type === 'TOAST') { - const { toastOptions } = postUploadAction; if (!autoAddBoardId || autoAddBoardId === 'none') { - dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions })); + const title = postUploadAction.title || DEFAULT_UPLOADED_TOAST.title; + toast({ ...DEFAULT_UPLOADED_TOAST, title }); } else { // Add this image to the board dispatch( @@ -70,24 +70,20 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis ? `${t('toast.addedToBoard')} ${board.board_name}` : `${t('toast.addedToBoard')} ${autoAddBoardId}`; - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description, - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description, + }); } return; } if (postUploadAction?.type === 'SET_CANVAS_INITIAL_IMAGE') { dispatch(setInitialCanvasImage(imageDTO, selectOptimalDimension(state))); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: t('toast.setAsCanvasInitialImage'), - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description: t('toast.setAsCanvasInitialImage'), + }); return; } @@ -105,68 +101,56 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis controlImage: imageDTO.image_name, }) ); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: t('toast.setControlImage'), - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description: t('toast.setControlImage'), + }); return; } if (postUploadAction?.type === 'SET_CA_LAYER_IMAGE') { const { layerId } = postUploadAction; dispatch(caLayerImageChanged({ layerId, imageDTO })); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: t('toast.setControlImage'), - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description: t('toast.setControlImage'), + }); } if (postUploadAction?.type === 'SET_IPA_LAYER_IMAGE') { const { layerId } = postUploadAction; dispatch(ipaLayerImageChanged({ layerId, imageDTO })); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: t('toast.setControlImage'), - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description: t('toast.setControlImage'), + }); } if (postUploadAction?.type === 'SET_RG_LAYER_IP_ADAPTER_IMAGE') { const { layerId, ipAdapterId } = postUploadAction; dispatch(rgLayerIPAdapterImageChanged({ layerId, ipAdapterId, imageDTO })); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: t('toast.setControlImage'), - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description: t('toast.setControlImage'), + }); } if (postUploadAction?.type === 'SET_II_LAYER_IMAGE') { const { layerId } = postUploadAction; dispatch(iiLayerImageChanged({ layerId, imageDTO })); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: 
t('toast.setControlImage'), - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description: t('toast.setControlImage'), + }); } if (postUploadAction?.type === 'SET_NODES_IMAGE') { const { nodeId, fieldName } = postUploadAction; dispatch(fieldImageValueChanged({ nodeId, fieldName, value: imageDTO })); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: `${t('toast.setNodeField')} ${fieldName}`, - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description: `${t('toast.setNodeField')} ${fieldName}`, + }); return; } }, @@ -174,7 +158,7 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis startAppListening({ matcher: imagesApi.endpoints.uploadImage.matchRejected, - effect: (action, { dispatch }) => { + effect: (action) => { const log = logger('images'); const sanitizedData = { arg: { @@ -183,13 +167,11 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis }, }; log.error({ ...sanitizedData }, 'Image upload failed'); - dispatch( - addToast({ - title: t('toast.imageUploadFailed'), - description: action.error.message, - status: 'error', - }) - ); + toast({ + title: t('toast.imageUploadFailed'), + description: action.error.message, + status: 'error', + }); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index bc049cf498..239a5b863d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -8,8 +8,7 @@ import { loraRemoved } from 'features/lora/store/loraSlice'; import { modelSelected } from 'features/parameters/store/actions'; import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice'; import { zParameterModel } from 'features/parameters/types/parameterSchemas'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { forEach } from 'lodash-es'; @@ -60,16 +59,14 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = }); if (modelsCleared > 0) { - dispatch( - addToast( - makeToast({ - title: t('toast.baseModelChangedCleared', { - count: modelsCleared, - }), - status: 'warning', - }) - ) - ); + toast({ + id: 'BASE_MODEL_CHANGED', + title: t('toast.baseModelChanged'), + description: t('toast.baseModelChangedCleared', { + count: modelsCleared, + }), + status: 'warning', + }); } } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts index 61a978d576..415c359d70 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts @@ -19,8 +19,7 @@ import { isParameterWidth, zParameterVAEModel, } from 'features/parameters/types/parameterSchemas'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { modelConfigsAdapterSelectors, modelsApi } from 
'services/api/endpoints/models'; import { isNonRefinerMainModelConfig } from 'services/api/types'; @@ -109,7 +108,7 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni } } - dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) }))); + toast({ id: 'PARAMETER_SET', title: t('toast.parameterSet', { parameter: 'Default settings' }) }); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts index 2dd598396a..08ad830ba4 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts @@ -1,7 +1,8 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { deepClone } from 'common/util/deepClone'; -import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState'; +import { parseify } from 'common/util/serialize'; +import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; import { socketGeneratorProgress } from 'services/events/actions'; @@ -11,13 +12,14 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis startAppListening({ actionCreator: socketGeneratorProgress, effect: (action) => { - log.trace(action.payload, `Generator progress`); - const { source_node_id, step, total_steps, progress_image } = action.payload.data; - const nes = deepClone($nodeExecutionStates.get()[source_node_id]); + log.trace(parseify(action.payload), `Generator progress`); + const { invocation_source_id, step, total_steps, progress_image } = action.payload.data; + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); if (nes) { nes.status = zNodeStatus.enum.IN_PROGRESS; nes.progress = (step + 1) / total_steps; nes.progressImage = progress_image ?? 
null; + upsertExecutionState(nes.nodeId, nes); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts deleted file mode 100644 index 5221679232..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { logger } from 'app/logging/logger'; -import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { socketGraphExecutionStateComplete } from 'services/events/actions'; - -const log = logger('socketio'); - -export const addGraphExecutionStateCompleteEventListener = (startAppListening: AppStartListening) => { - startAppListening({ - actionCreator: socketGraphExecutionStateComplete, - effect: (action) => { - log.debug(action.payload, 'Session complete'); - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index 06dc08d846..1a04f9493a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -29,12 +29,12 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi actionCreator: socketInvocationComplete, effect: async (action, { dispatch, getState }) => { const { data } = action.payload; - log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`); + log.debug({ data: parseify(data) }, `Invocation complete (${data.invocation.type})`); - const { result, node, queue_batch_id, source_node_id } = data; + const { result, invocation_source_id } = data; // This complete event has an associated image output - if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) { - const { image_name } = result.image; + if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation.type)) { + const { image_name } = data.result.image; const { canvas, gallery } = getState(); // This populates the `getImageDTO` cache @@ -48,7 +48,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi imageDTORequest.unsubscribe(); // Add canvas images to the staging area - if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) { + if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) { dispatch(addImageToStagingArea(imageDTO)); } @@ -114,7 +114,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi } } - const nes = deepClone($nodeExecutionStates.get()[source_node_id]); + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); if (nes) { nes.status = zNodeStatus.enum.COMPLETED; if (nes.progress !== null) { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts index ce26c4dd7d..b34f34a079 100644 --- 
a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts @@ -1,6 +1,7 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { deepClone } from 'common/util/deepClone'; +import { parseify } from 'common/util/serialize'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; import { socketInvocationError } from 'services/events/actions'; @@ -11,14 +12,18 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe startAppListening({ actionCreator: socketInvocationError, effect: (action) => { - log.error(action.payload, `Invocation error (${action.payload.data.node.type})`); - const { source_node_id } = action.payload.data; - const nes = deepClone($nodeExecutionStates.get()[source_node_id]); + const { invocation_source_id, invocation, error_type, error_message, error_traceback } = action.payload.data; + log.error(parseify(action.payload), `Invocation error (${invocation.type})`); + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); if (nes) { nes.status = zNodeStatus.enum.FAILED; - nes.error = action.payload.data.error; nes.progress = null; nes.progressImage = null; + nes.error = { + error_type, + error_message, + error_traceback, + }; upsertExecutionState(nes.nodeId, nes); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts deleted file mode 100644 index 44da4c0ddb..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { logger } from 'app/logging/logger'; -import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { socketInvocationRetrievalError } from 'services/events/actions'; - -const log = logger('socketio'); - -export const addInvocationRetrievalErrorEventListener = (startAppListening: AppStartListening) => { - startAppListening({ - actionCreator: socketInvocationRetrievalError, - effect: (action) => { - log.error(action.payload, `Invocation retrieval error (${action.payload.data.graph_execution_state_id})`); - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts index 9d6e0ac14d..7dae869ce2 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts @@ -1,6 +1,7 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { deepClone } from 'common/util/deepClone'; +import { parseify } from 'common/util/serialize'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; 
import { socketInvocationStarted } from 'services/events/actions'; @@ -11,9 +12,9 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis startAppListening({ actionCreator: socketInvocationStarted, effect: (action) => { - log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`); - const { source_node_id } = action.payload.data; - const nes = deepClone($nodeExecutionStates.get()[source_node_id]); + log.debug(parseify(action.payload), `Invocation started (${action.payload.data.invocation.type})`); + const { invocation_source_id } = action.payload.data; + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); if (nes) { nes.status = zNodeStatus.enum.IN_PROGRESS; upsertExecutionState(nes.nodeId, nes); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts index f474c2736b..7fafb8302c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts @@ -3,14 +3,14 @@ import { api, LIST_TAG } from 'services/api'; import { modelsApi } from 'services/api/endpoints/models'; import { socketModelInstallCancelled, - socketModelInstallCompleted, - socketModelInstallDownloading, + socketModelInstallComplete, + socketModelInstallDownloadProgress, socketModelInstallError, } from 'services/events/actions'; export const addModelInstallEventListener = (startAppListening: AppStartListening) => { startAppListening({ - actionCreator: socketModelInstallDownloading, + actionCreator: socketModelInstallDownloadProgress, effect: async (action, { dispatch }) => { const { bytes, total_bytes, id } = action.payload.data; @@ -29,7 +29,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin }); startAppListening({ - actionCreator: socketModelInstallCompleted, + actionCreator: socketModelInstallComplete, effect: (action, { dispatch }) => { const { id } = action.payload.data; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad.ts index 4f4ec7635e..0240fe219a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad.ts @@ -1,6 +1,6 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { socketModelLoadCompleted, socketModelLoadStarted } from 'services/events/actions'; +import { socketModelLoadComplete, socketModelLoadStarted } from 'services/events/actions'; const log = logger('socketio'); @@ -8,10 +8,11 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening) startAppListening({ actionCreator: socketModelLoadStarted, effect: (action) => { - const { model_config, submodel_type } = action.payload.data; - const { name, base, type } = model_config; + const { config, submodel_type } = action.payload.data; + const { name, base, type } = config; const extras: string[] = [base, type]; + if (submodel_type) { extras.push(submodel_type); } @@ -23,10 +24,10 @@ 
export const addModelLoadEventListener = (startAppListening: AppStartListening) }); startAppListening({ - actionCreator: socketModelLoadCompleted, + actionCreator: socketModelLoadComplete, effect: (action) => { - const { model_config, submodel_type } = action.payload.data; - const { name, base, type } = model_config; + const { config, submodel_type } = action.payload.data; + const { name, base, type } = config; const extras: string[] = [base, type]; if (submodel_type) { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.tsx similarity index 55% rename from invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts rename to invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.tsx index 2adc529766..8a83609b3c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.tsx @@ -3,6 +3,8 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { deepClone } from 'common/util/deepClone'; import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; +import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription'; +import { toast } from 'features/toast/toast'; import { forEach } from 'lodash-es'; import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; import { socketQueueItemStatusChanged } from 'services/events/actions'; @@ -12,18 +14,38 @@ const log = logger('socketio'); export const addSocketQueueItemStatusChangedEventListener = (startAppListening: AppStartListening) => { startAppListening({ actionCreator: socketQueueItemStatusChanged, - effect: async (action, { dispatch }) => { + effect: async (action, { dispatch, getState }) => { // we've got new status for the queue item, batch and queue - const { queue_item, batch_status, queue_status } = action.payload.data; + const { + item_id, + session_id, + status, + started_at, + updated_at, + completed_at, + batch_status, + queue_status, + error_type, + error_message, + error_traceback, + } = action.payload.data; - log.debug(action.payload, `Queue item ${queue_item.item_id} status updated: ${queue_item.status}`); + log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`); // Update this specific queue item in the list of queue items (this is the queue item DTO, without the session) dispatch( queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => { queueItemsAdapter.updateOne(draft, { - id: String(queue_item.item_id), - changes: queue_item, + id: String(item_id), + changes: { + status, + started_at, + updated_at: updated_at ?? undefined, + completed_at: completed_at ?? 
undefined, + error_type, + error_message, + error_traceback, + }, }); }) ); @@ -43,23 +65,18 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening: queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status) ); - // Update the queue item status (this is the full queue item, including the session) - dispatch( - queueApi.util.updateQueryData('getQueueItem', queue_item.item_id, (draft) => { - if (!draft) { - return; - } - Object.assign(draft, queue_item); - }) - ); - // Invalidate caches for things we cannot update // TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again dispatch( - queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus']) + queueApi.util.invalidateTags([ + 'CurrentSessionQueueItem', + 'NextSessionQueueItem', + 'InvocationCacheStatus', + { type: 'SessionQueueItem', id: item_id }, + ]) ); - if (['in_progress'].includes(action.payload.data.queue_item.status)) { + if (status === 'in_progress') { forEach($nodeExecutionStates.get(), (nes) => { if (!nes) { return; @@ -72,6 +89,25 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening: clone.outputs = []; $nodeExecutionStates.setKey(clone.nodeId, clone); }); + } else if (status === 'failed' && error_type) { + const isLocal = getState().config.isLocal ?? true; + const sessionId = session_id; + + toast({ + id: `INVOCATION_ERROR_${error_type}`, + title: getTitleFromErrorType(error_type), + status: 'error', + duration: null, + updateDescription: isLocal, + description: ( + + ), + }); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts deleted file mode 100644 index a1a497dc08..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { logger } from 'app/logging/logger'; -import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { socketSessionRetrievalError } from 'services/events/actions'; - -const log = logger('socketio'); - -export const addSessionRetrievalErrorEventListener = (startAppListening: AppStartListening) => { - startAppListening({ - actionCreator: socketSessionRetrievalError, - effect: (action) => { - log.error(action.payload, `Session retrieval error (${action.payload.data.graph_execution_state_id})`); - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts deleted file mode 100644 index 48324cb652..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { logger } from 'app/logging/logger'; -import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { socketSubscribedSession } from 'services/events/actions'; - -const log = logger('socketio'); - -export const addSocketSubscribedEventListener = (startAppListening: AppStartListening) => { - startAppListening({ - actionCreator: socketSubscribedSession, - effect: (action) => { - log.debug(action.payload, 
'Subscribed'); - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts deleted file mode 100644 index 7a76a809d6..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts +++ /dev/null @@ -1,13 +0,0 @@ -import { logger } from 'app/logging/logger'; -import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { socketUnsubscribedSession } from 'services/events/actions'; -const log = logger('socketio'); - -export const addSocketUnsubscribedEventListener = (startAppListening: AppStartListening) => { - startAppListening({ - actionCreator: socketUnsubscribedSession, - effect: (action) => { - log.debug(action.payload, 'Unsubscribed'); - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts index 6816e25bc1..6c4c2a9df1 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts @@ -1,6 +1,6 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { stagingAreaImageSaved } from 'features/canvas/store/actions'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; @@ -29,15 +29,14 @@ export const addStagingAreaImageSavedListener = (startAppListening: AppStartList }) ); } - dispatch(addToast({ title: t('toast.imageSaved'), status: 'success' })); + toast({ id: 'IMAGE_SAVED', title: t('toast.imageSaved'), status: 'success' }); } catch (error) { - dispatch( - addToast({ - title: t('toast.imageSavingFailed'), - description: (error as Error)?.message, - status: 'error', - }) - ); + toast({ + id: 'IMAGE_SAVE_FAILED', + title: t('toast.imageSavingFailed'), + description: (error as Error)?.message, + status: 'error', + }); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts index 63d960b406..07df2a4f42 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts @@ -1,12 +1,11 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { updateAllNodesRequested } from 'features/nodes/store/actions'; -import { $templates, nodeReplaced } from 'features/nodes/store/nodesSlice'; +import { $templates, nodesChanged } from 'features/nodes/store/nodesSlice'; import { NodeUpdateError } from 'features/nodes/types/error'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 
'features/toast/toast'; import { t } from 'i18next'; export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartListening) => { @@ -31,7 +30,12 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi } try { const updatedNode = updateNode(node, template); - dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode })); + dispatch( + nodesChanged([ + { type: 'remove', id: updatedNode.id }, + { type: 'add', item: updatedNode }, + ]) + ); } catch (e) { if (e instanceof NodeUpdateError) { unableToUpdateCount++; @@ -45,24 +49,18 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi count: unableToUpdateCount, }) ); - dispatch( - addToast( - makeToast({ - title: t('nodes.unableToUpdateNodes', { - count: unableToUpdateCount, - }), - }) - ) - ); + toast({ + id: 'UNABLE_TO_UPDATE_NODES', + title: t('nodes.unableToUpdateNodes', { + count: unableToUpdateCount, + }), + }); } else { - dispatch( - addToast( - makeToast({ - title: t('nodes.allNodesUpdated'), - status: 'success', - }) - ) - ); + toast({ + id: 'ALL_NODES_UPDATED', + title: t('nodes.allNodesUpdated'), + status: 'success', + }); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts index ff5d5f24be..ce480a3573 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts @@ -4,7 +4,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { parseify } from 'common/util/serialize'; import { buildAdHocUpscaleGraph } from 'features/nodes/util/graph/buildAdHocUpscaleGraph'; import { createIsAllowedToUpscaleSelector } from 'features/parameters/hooks/useIsAllowedToUpscale'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { queueApi } from 'services/api/endpoints/queue'; import type { BatchConfig, ImageDTO } from 'services/api/types'; @@ -29,12 +29,11 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening { imageDTO }, t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge') // should never coalesce ); - dispatch( - addToast({ - title: t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge'), // should never coalesce - status: 'error', - }) - ); + toast({ + id: 'NOT_ALLOWED_TO_UPSCALE', + title: t(detailTKey ?? 
'parameters.isAllowedToUpscale.tooLarge'), // should never coalesce + status: 'error', + }); return; } @@ -65,12 +64,11 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening if (error instanceof Object && 'status' in error && error.status === 403) { return; } else { - dispatch( - addToast({ - title: t('queue.graphFailedToQueue'), - status: 'error', - }) - ); + toast({ + id: 'GRAPH_QUEUE_FAILED', + title: t('queue.graphFailedToQueue'), + status: 'error', + }); } } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts index a680bbca97..e6fc5a526a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts @@ -8,23 +8,23 @@ import type { Templates } from 'features/nodes/store/types'; import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error'; import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow'; import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; +import { checkBoardAccess, checkImageAccess, checkModelAccess } from 'services/api/hooks/accessChecks'; import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types'; import { z } from 'zod'; import { fromZodError } from 'zod-validation-error'; -const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => { +const getWorkflow = async (data: GraphAndWorkflowResponse, templates: Templates) => { if (data.workflow) { // Prefer to load the workflow if it's available - it has more information const parsed = JSON.parse(data.workflow); - return validateWorkflow(parsed, templates); + return await validateWorkflow(parsed, templates, checkImageAccess, checkBoardAccess, checkModelAccess); } else if (data.graph) { // Else we fall back on the graph, using the graphToWorkflow function to convert and do layout const parsed = JSON.parse(data.graph); const workflow = graphToWorkflow(parsed as NonNullableGraph, true); - return validateWorkflow(workflow, templates); + return await validateWorkflow(workflow, templates, checkImageAccess, checkBoardAccess, checkModelAccess); } else { throw new Error('No workflow or graph provided'); } @@ -33,13 +33,13 @@ const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => { export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => { startAppListening({ actionCreator: workflowLoadRequested, - effect: (action, { dispatch }) => { + effect: async (action, { dispatch }) => { const log = logger('nodes'); const { data, asCopy } = action.payload; const nodeTemplates = $templates.get(); try { - const { workflow, warnings } = getWorkflow(data, nodeTemplates); + const { workflow, warnings } = await getWorkflow(data, nodeTemplates); if (asCopy) { // If we're loading a copy, we need to remove the ID so that the backend will create a new workflow @@ -48,23 +48,18 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList dispatch(workflowLoaded(workflow)); if (!warnings.length) { - dispatch( - 
addToast( - makeToast({ - title: t('toast.workflowLoaded'), - status: 'success', - }) - ) - ); + toast({ + id: 'WORKFLOW_LOADED', + title: t('toast.workflowLoaded'), + status: 'success', + }); } else { - dispatch( - addToast( - makeToast({ - title: t('toast.loadedWithWarnings'), - status: 'warning', - }) - ) - ); + toast({ + id: 'WORKFLOW_LOADED', + title: t('toast.loadedWithWarnings'), + status: 'warning', + }); + warnings.forEach(({ message, ...rest }) => { log.warn(rest, message); }); @@ -77,54 +72,42 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList if (e instanceof WorkflowVersionError) { // The workflow version was not recognized in the valid list of versions log.error({ error: parseify(e) }, e.message); - dispatch( - addToast( - makeToast({ - title: t('nodes.unableToValidateWorkflow'), - status: 'error', - description: e.message, - }) - ) - ); + toast({ + id: 'UNABLE_TO_VALIDATE_WORKFLOW', + title: t('nodes.unableToValidateWorkflow'), + status: 'error', + description: e.message, + }); } else if (e instanceof WorkflowMigrationError) { // There was a problem migrating the workflow to the latest version log.error({ error: parseify(e) }, e.message); - dispatch( - addToast( - makeToast({ - title: t('nodes.unableToValidateWorkflow'), - status: 'error', - description: e.message, - }) - ) - ); + toast({ + id: 'UNABLE_TO_VALIDATE_WORKFLOW', + title: t('nodes.unableToValidateWorkflow'), + status: 'error', + description: e.message, + }); } else if (e instanceof z.ZodError) { // There was a problem validating the workflow itself const { message } = fromZodError(e, { prefix: t('nodes.workflowValidation'), }); log.error({ error: parseify(e) }, message); - dispatch( - addToast( - makeToast({ - title: t('nodes.unableToValidateWorkflow'), - status: 'error', - description: message, - }) - ) - ); + toast({ + id: 'UNABLE_TO_VALIDATE_WORKFLOW', + title: t('nodes.unableToValidateWorkflow'), + status: 'error', + description: message, + }); } else { // Some other error occurred log.error({ error: parseify(e) }, t('nodes.unknownErrorValidatingWorkflow')); - dispatch( - addToast( - makeToast({ - title: t('nodes.unableToValidateWorkflow'), - status: 'error', - description: t('nodes.unknownErrorValidatingWorkflow'), - }) - ) - ); + toast({ + id: 'UNABLE_TO_VALIDATE_WORKFLOW', + title: t('nodes.unableToValidateWorkflow'), + status: 'error', + description: t('nodes.unknownErrorValidatingWorkflow'), + }); } } }, diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index 4982dbb83f..21636ada49 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -74,6 +74,7 @@ export type AppConfig = { maxUpscalePixels?: number; metadataFetchDebounce?: number; workflowFetchDebounce?: number; + isLocal?: boolean; sd: { defaultModel?: string; disabledControlNetModels: string[]; diff --git a/invokeai/frontend/web/src/common/hooks/useCopyImageToClipboard.ts b/invokeai/frontend/web/src/common/hooks/useCopyImageToClipboard.ts index ef9db44a9d..233b841034 100644 --- a/invokeai/frontend/web/src/common/hooks/useCopyImageToClipboard.ts +++ b/invokeai/frontend/web/src/common/hooks/useCopyImageToClipboard.ts @@ -1,11 +1,10 @@ -import { useAppToaster } from 'app/components/Toaster'; import { useImageUrlToBlob } from 'common/hooks/useImageUrlToBlob'; import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard'; +import { toast } from 'features/toast/toast'; import { 
useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; export const useCopyImageToClipboard = () => { - const toaster = useAppToaster(); const { t } = useTranslation(); const imageUrlToBlob = useImageUrlToBlob(); @@ -16,12 +15,11 @@ export const useCopyImageToClipboard = () => { const copyImageToClipboard = useCallback( async (image_url: string) => { if (!isClipboardAPIAvailable) { - toaster({ + toast({ + id: 'PROBLEM_COPYING_IMAGE', title: t('toast.problemCopyingImage'), description: "Your browser doesn't support the Clipboard API.", status: 'error', - duration: 2500, - isClosable: true, }); } try { @@ -33,23 +31,21 @@ export const useCopyImageToClipboard = () => { copyBlobToClipboard(blob); - toaster({ + toast({ + id: 'IMAGE_COPIED', title: t('toast.imageCopied'), status: 'success', - duration: 2500, - isClosable: true, }); } catch (err) { - toaster({ + toast({ + id: 'PROBLEM_COPYING_IMAGE', title: t('toast.problemCopyingImage'), description: String(err), status: 'error', - duration: 2500, - isClosable: true, }); } }, - [imageUrlToBlob, isClipboardAPIAvailable, t, toaster] + [imageUrlToBlob, isClipboardAPIAvailable, t] ); return { isClipboardAPIAvailable, copyImageToClipboard }; diff --git a/invokeai/frontend/web/src/common/hooks/useDownloadImage.ts b/invokeai/frontend/web/src/common/hooks/useDownloadImage.ts index 26a17e1d0c..ede247b9fb 100644 --- a/invokeai/frontend/web/src/common/hooks/useDownloadImage.ts +++ b/invokeai/frontend/web/src/common/hooks/useDownloadImage.ts @@ -1,13 +1,12 @@ import { useStore } from '@nanostores/react'; -import { useAppToaster } from 'app/components/Toaster'; import { $authToken } from 'app/store/nanostores/authToken'; import { useAppDispatch } from 'app/store/storeHooks'; import { imageDownloaded } from 'features/gallery/store/actions'; +import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; export const useDownloadImage = () => { - const toaster = useAppToaster(); const { t } = useTranslation(); const dispatch = useAppDispatch(); const authToken = useStore($authToken); @@ -37,16 +36,15 @@ export const useDownloadImage = () => { window.URL.revokeObjectURL(url); dispatch(imageDownloaded()); } catch (err) { - toaster({ + toast({ + id: 'PROBLEM_DOWNLOADING_IMAGE', title: t('toast.problemDownloadingImage'), description: String(err), status: 'error', - duration: 2500, - isClosable: true, }); } }, - [t, toaster, dispatch, authToken] + [t, dispatch, authToken] ); return { downloadImage }; diff --git a/invokeai/frontend/web/src/common/hooks/useFullscreenDropzone.ts b/invokeai/frontend/web/src/common/hooks/useFullscreenDropzone.ts index 0334294e98..5b1bf1f5b3 100644 --- a/invokeai/frontend/web/src/common/hooks/useFullscreenDropzone.ts +++ b/invokeai/frontend/web/src/common/hooks/useFullscreenDropzone.ts @@ -1,6 +1,6 @@ -import { useAppToaster } from 'app/components/Toaster'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { toast } from 'features/toast/toast'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { useCallback, useEffect, useState } from 'react'; import type { Accept, FileRejection } from 'react-dropzone'; @@ -26,7 +26,6 @@ const selectPostUploadAction = createMemoizedSelector(activeTabNameSelector, (ac export const useFullscreenDropzone = () => { const { t } = useTranslation(); - const toaster = useAppToaster(); const 
postUploadAction = useAppSelector(selectPostUploadAction); const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId); const [isHandlingUpload, setIsHandlingUpload] = useState(false); @@ -37,13 +36,14 @@ export const useFullscreenDropzone = () => { (rejection: FileRejection) => { setIsHandlingUpload(true); - toaster({ + toast({ + id: 'UPLOAD_FAILED', title: t('toast.uploadFailed'), description: rejection.errors.map((error) => error.message).join('\n'), status: 'error', }); }, - [t, toaster] + [t] ); const fileAcceptedCallback = useCallback( @@ -62,7 +62,8 @@ export const useFullscreenDropzone = () => { const onDrop = useCallback( (acceptedFiles: Array, fileRejections: Array) => { if (fileRejections.length > 1) { - toaster({ + toast({ + id: 'UPLOAD_FAILED', title: t('toast.uploadFailed'), description: t('toast.uploadFailedInvalidUploadDesc'), status: 'error', @@ -78,7 +79,7 @@ export const useFullscreenDropzone = () => { fileAcceptedCallback(file); }); }, - [t, toaster, fileAcceptedCallback, fileRejectionCallback] + [t, fileAcceptedCallback, fileRejectionCallback] ); const onDragOver = useCallback(() => { diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index 41d6f4607e..dbf3c41480 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -137,7 +137,7 @@ const createSelector = (templates: Templates) => if (l.controlAdapter.type === 't2i_adapter') { const multiple = model?.base === 'sdxl' ? 32 : 64; if (size.width % multiple !== 0 || size.height % multiple !== 0) { - problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions')); + problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions', { multiple })); } } } diff --git a/invokeai/frontend/web/src/common/util/toast.ts b/invokeai/frontend/web/src/common/util/toast.ts deleted file mode 100644 index ac61a4a12d..0000000000 --- a/invokeai/frontend/web/src/common/util/toast.ts +++ /dev/null @@ -1,6 +0,0 @@ -import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library'; - -export const { toast } = createStandaloneToast({ - theme: theme, - defaultOptions: TOAST_OPTIONS.defaultOptions, -}); diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index a22f23d9d3..fbb6378166 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -613,7 +613,7 @@ export const canvasSlice = createSlice({ state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id); } - const queueItemStatus = action.payload.data.queue_item.status; + const queueItemStatus = action.payload.data.status; if (queueItemStatus === 'canceled' || queueItemStatus === 'failed') { resetStagingAreaIfEmpty(state); } diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayer.tsx index 36509ec1d3..9e71ad943c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayer.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayer.tsx @@ -4,7 +4,7 @@ import { CALayerControlAdapterWrapper } from 'features/controlLayers/components/ import { LayerDeleteButton } from 
'features/controlLayers/components/LayerCommon/LayerDeleteButton'; import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu'; import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle'; -import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; +import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper'; import { layerSelected, selectCALayerOrThrow } from 'features/controlLayers/store/controlLayersSlice'; import { memo, useCallback } from 'react'; @@ -26,7 +26,7 @@ export const CALayer = memo(({ layerId }: Props) => { return ( - + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/IILayer/IILayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/IILayer/IILayer.tsx index c6efd041ca..c53c4c7631 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/IILayer/IILayer.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/IILayer/IILayer.tsx @@ -5,7 +5,7 @@ import { InitialImagePreview } from 'features/controlLayers/components/IILayer/I import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton'; import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu'; import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle'; -import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; +import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper'; import { iiLayerDenoisingStrengthChanged, @@ -66,7 +66,7 @@ export const IILayer = memo(({ layerId }: Props) => { return ( - + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/IPALayer/IPALayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/IPALayer/IPALayer.tsx index 2077700104..e8f60c8d07 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/IPALayer/IPALayer.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/IPALayer/IPALayer.tsx @@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { IPALayerIPAdapterWrapper } from 'features/controlLayers/components/IPALayer/IPALayerIPAdapterWrapper'; import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton'; import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle'; -import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; +import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper'; import { layerSelected, selectIPALayerOrThrow } from 'features/controlLayers/store/controlLayersSlice'; import { memo, useCallback } from 'react'; @@ -22,7 +22,7 @@ export const IPALayer = memo(({ layerId }: Props) => { return ( - + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/LayerCommon/LayerVisibilityToggle.tsx b/invokeai/frontend/web/src/features/controlLayers/components/LayerCommon/LayerVisibilityToggle.tsx index d2dab39e36..227d74c35a 100644 --- 
a/invokeai/frontend/web/src/features/controlLayers/components/LayerCommon/LayerVisibilityToggle.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/LayerCommon/LayerVisibilityToggle.tsx @@ -1,8 +1,8 @@ import { IconButton } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import { stopPropagation } from 'common/util/stopPropagation'; -import { useLayerIsVisible } from 'features/controlLayers/hooks/layerStateHooks'; -import { layerVisibilityToggled } from 'features/controlLayers/store/controlLayersSlice'; +import { useLayerIsEnabled } from 'features/controlLayers/hooks/layerStateHooks'; +import { layerIsEnabledToggled } from 'features/controlLayers/store/controlLayersSlice'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiCheckBold } from 'react-icons/pi'; @@ -11,21 +11,21 @@ type Props = { layerId: string; }; -export const LayerVisibilityToggle = memo(({ layerId }: Props) => { +export const LayerIsEnabledToggle = memo(({ layerId }: Props) => { const { t } = useTranslation(); const dispatch = useAppDispatch(); - const isVisible = useLayerIsVisible(layerId); + const isEnabled = useLayerIsEnabled(layerId); const onClick = useCallback(() => { - dispatch(layerVisibilityToggled(layerId)); + dispatch(layerIsEnabledToggled(layerId)); }, [dispatch, layerId]); return ( : undefined} + icon={isEnabled ? : undefined} onClick={onClick} colorScheme="base" onDoubleClick={stopPropagation} // double click expands the layer @@ -33,4 +33,4 @@ export const LayerVisibilityToggle = memo(({ layerId }: Props) => { ); }); -LayerVisibilityToggle.displayName = 'LayerVisibilityToggle'; +LayerIsEnabledToggle.displayName = 'LayerVisibilityToggle'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayer.tsx index a6bce5316e..cc331017d3 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayer.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayer.tsx @@ -6,7 +6,7 @@ import { AddPromptButtons } from 'features/controlLayers/components/AddPromptBut import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton'; import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu'; import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle'; -import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; +import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper'; import { isRegionalGuidanceLayer, @@ -55,7 +55,7 @@ export const RGLayer = memo(({ layerId }: Props) => { return ( - + {autoNegative === 'invert' && ( diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerNegativePrompt.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerNegativePrompt.tsx index ce02811ebf..ba02aa9242 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerNegativePrompt.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerNegativePrompt.tsx @@ -45,7 +45,6 @@ export const RGLayerNegativePrompt = memo(({ layerId }: Props) => { variant="darkFilled" paddingRight={30} fontSize="sm" - 
spellCheck={false} /> diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerPositivePrompt.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerPositivePrompt.tsx index 56d3953e25..6f85ea077c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerPositivePrompt.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerPositivePrompt.tsx @@ -45,7 +45,6 @@ export const RGLayerPositivePrompt = memo(({ layerId }: Props) => { variant="darkFilled" paddingRight={30} minH={28} - spellCheck={false} /> diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/layerStateHooks.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/layerStateHooks.ts index f2054779d4..21e49ba15e 100644 --- a/invokeai/frontend/web/src/features/controlLayers/hooks/layerStateHooks.ts +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/layerStateHooks.ts @@ -39,7 +39,7 @@ export const useLayerNegativePrompt = (layerId: string) => { return prompt; }; -export const useLayerIsVisible = (layerId: string) => { +export const useLayerIsEnabled = (layerId: string) => { const selectLayer = useMemo( () => createSelector(selectControlLayersSlice, (controlLayers) => { diff --git a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts index 32e29918ae..5fa8cc3dfb 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts @@ -139,7 +139,7 @@ export const controlLayersSlice = createSlice({ layerSelected: (state, action: PayloadAction) => { exclusivelySelectLayer(state, action.payload); }, - layerVisibilityToggled: (state, action: PayloadAction) => { + layerIsEnabledToggled: (state, action: PayloadAction) => { const layer = state.layers.find((l) => l.id === action.payload); if (layer) { layer.isEnabled = !layer.isEnabled; @@ -616,12 +616,24 @@ export const controlLayersSlice = createSlice({ iiLayerAdded: { reducer: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => { const { layerId, imageDTO } = action.payload; + + // Retain opacity and denoising strength of existing initial image layer if exists + let opacity = 1; + let denoisingStrength = 0.75; + const iiLayer = state.layers.find((l) => l.id === layerId); + if (iiLayer) { + assert(isInitialImageLayer(iiLayer)); + opacity = iiLayer.opacity; + denoisingStrength = iiLayer.denoisingStrength; + } + // Highlander! There can be only one! state.layers = state.layers.filter((l) => (isInitialImageLayer(l) ? false : true)); + const layer: InitialImageLayer = { id: layerId, type: 'initial_image_layer', - opacity: 1, + opacity, x: 0, y: 0, bbox: null, @@ -629,7 +641,7 @@ export const controlLayersSlice = createSlice({ isEnabled: true, image: imageDTO ? 
imageDTOToImageWithDims(imageDTO) : null, isSelected: true, - denoisingStrength: 0.75, + denoisingStrength, }; state.layers.push(layer); exclusivelySelectLayer(state, layer.id); @@ -779,7 +791,7 @@ class LayerColors { export const { // Any Layer Type layerSelected, - layerVisibilityToggled, + layerIsEnabledToggled, layerTranslated, layerBboxChanged, layerReset, diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx index a25f6d8c0e..b3119aa8fa 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx @@ -1,6 +1,5 @@ import { Flex, MenuDivider, MenuItem, Spinner } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { useAppToaster } from 'app/components/Toaster'; import { $customStarUI } from 'app/store/nanostores/customStarUI'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useCopyImageToClipboard } from 'common/hooks/useCopyImageToClipboard'; @@ -11,10 +10,13 @@ import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice'; import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; import { useImageActions } from 'features/gallery/hooks/useImageActions'; import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/actions'; +import { $templates } from 'features/nodes/store/nodesSlice'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { toast } from 'features/toast/toast'; import { setActiveTab } from 'features/ui/store/uiSlice'; import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow'; +import { size } from 'lodash-es'; import { memo, useCallback } from 'react'; import { flushSync } from 'react-dom'; import { useTranslation } from 'react-i18next'; @@ -44,10 +46,10 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { const optimalDimension = useAppSelector(selectOptimalDimension); const dispatch = useAppDispatch(); const { t } = useTranslation(); - const toaster = useAppToaster(); const isCanvasEnabled = useFeatureStatus('canvas'); const customStarUi = useStore($customStarUI); const { downloadImage } = useDownloadImage(); + const templates = useStore($templates); const { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, isLoadingMetadata } = useImageActions(imageDTO?.image_name); @@ -83,13 +85,12 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { }); dispatch(setInitialCanvasImage(imageDTO, optimalDimension)); - toaster({ + toast({ + id: 'SENT_TO_CANVAS', title: t('toast.sentToUnifiedCanvas'), status: 'success', - duration: 2500, - isClosable: true, }); - }, [dispatch, imageDTO, t, toaster, optimalDimension]); + }, [dispatch, imageDTO, t, optimalDimension]); const handleChangeBoard = useCallback(() => { dispatch(imagesToChangeSelected([imageDTO])); @@ -133,7 +134,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { : } onClickCapture={handleLoadWorkflow} - isDisabled={!imageDTO.has_workflow} + isDisabled={!imageDTO.has_workflow || !size(templates)} > {t('nodes.loadWorkflow')} 
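Taken together, the hunks above and below replace the removed `useAppToaster` / `dispatch(addToast(makeToast({ ... })))` pattern with the imperative `toast()` helper imported from 'features/toast/toast'. A minimal usage sketch, inferred only from the call sites visible in this diff (the helper's exact option type lives in the toast module and is not shown here); the stable `id` is presumably what lets a repeated notification update in place rather than stack:

import { toast } from 'features/toast/toast';
import { t } from 'i18next';

// Same shape as the call sites above: a stable id, a translated title, and a status.
toast({
  id: 'GRAPH_QUEUE_FAILED',
  title: t('queue.graphFailedToQueue'),
  status: 'error',
});

Unlike the removed toaster calls, no `duration` or `isClosable` options are passed; those defaults appear to be handled centrally by the shared toast module.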
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx index ada9c35d28..d500d692fe 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx @@ -1,4 +1,5 @@ import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { skipToken } from '@reduxjs/toolkit/query'; import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested'; @@ -12,12 +13,14 @@ import { sentImageToImg2Img } from 'features/gallery/store/actions'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; import { selectGallerySlice } from 'features/gallery/store/gallerySlice'; import { parseAndRecallImageDimensions } from 'features/metadata/util/handlers'; +import { $templates } from 'features/nodes/store/nodesSlice'; import ParamUpscalePopover from 'features/parameters/components/Upscale/ParamUpscaleSettings'; import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { selectSystemSlice } from 'features/system/store/systemSlice'; import { setActiveTab } from 'features/ui/store/uiSlice'; import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow'; +import { size } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; @@ -48,7 +51,7 @@ const CurrentImageButtons = () => { const lastSelectedImage = useAppSelector(selectLastSelectedImage); const selection = useAppSelector((s) => s.gallery.selection); const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons); - + const templates = useStore($templates); const isUpscalingEnabled = useFeatureStatus('upscaling'); const isQueueMutationInProgress = useIsQueueMutationInProgress(); const { t } = useTranslation(); @@ -143,7 +146,7 @@ const CurrentImageButtons = () => { icon={} tooltip={`${t('nodes.loadWorkflow')} (W)`} aria-label={`${t('nodes.loadWorkflow')} (W)`} - isDisabled={!imageDTO?.has_workflow} + isDisabled={!imageDTO?.has_workflow || !size(templates)} onClick={handleLoadWorkflow} isLoading={getAndLoadEmbeddedWorkflowResult.isLoading} /> diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ToggleMetadataViewerButton.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ToggleMetadataViewerButton.tsx index 4bf55116db..df3fbe2765 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ToggleMetadataViewerButton.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ToggleMetadataViewerButton.tsx @@ -1,6 +1,5 @@ import { IconButton } from '@invoke-ai/ui-library'; import { skipToken } from '@reduxjs/toolkit/query'; -import { useAppToaster } from 'app/components/Toaster'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; import { setShouldShowImageDetails } from 'features/ui/store/uiSlice'; @@ -14,7 +13,6 @@ export const 
ToggleMetadataViewerButton = memo(() => { const dispatch = useAppDispatch(); const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails); const lastSelectedImage = useAppSelector(selectLastSelectedImage); - const toaster = useAppToaster(); const { t } = useTranslation(); const { currentData: imageDTO } = useGetImageDTOQuery(lastSelectedImage?.image_name ?? skipToken); @@ -24,7 +22,7 @@ export const ToggleMetadataViewerButton = memo(() => { [dispatch, shouldShowImageDetails] ); - useHotkeys('i', toggleMetadataViewer, { enabled: Boolean(imageDTO) }, [imageDTO, shouldShowImageDetails, toaster]); + useHotkeys('i', toggleMetadataViewer, { enabled: Boolean(imageDTO) }, [imageDTO, shouldShowImageDetails]); return ( { const recallSeed = useCallback(() => { handlers.seed.parse(metadata).then((seed) => { - handlers.seed.recall && handlers.seed.recall(seed); + handlers.seed.recall && handlers.seed.recall(seed, true); }); }, [metadata]); diff --git a/invokeai/frontend/web/src/features/metadata/util/handlers.ts b/invokeai/frontend/web/src/features/metadata/util/handlers.ts index b0d0e22688..2829507dcd 100644 --- a/invokeai/frontend/web/src/features/metadata/util/handlers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/handlers.ts @@ -1,5 +1,4 @@ import { objectKeys } from 'common/util/objectKeys'; -import { toast } from 'common/util/toast'; import type { Layer } from 'features/controlLayers/store/types'; import type { LoRA } from 'features/lora/store/loraSlice'; import type { @@ -15,6 +14,7 @@ import type { import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers'; import { validators } from 'features/metadata/util/validators'; import type { ModelIdentifierField } from 'features/nodes/types/common'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { assert } from 'tsafe'; @@ -89,23 +89,23 @@ const renderLayersValue: MetadataRenderValueFunc = async (layers) => { return `${layers.length} ${t('controlLayers.layers', { count: layers.length })}`; }; -const parameterSetToast = (parameter: string, description?: string) => { +const parameterSetToast = (parameter: string) => { toast({ - title: t('toast.parameterSet', { parameter }), - description, + id: 'PARAMETER_SET', + title: t('toast.parameterSet'), + description: t('toast.parameterSetDesc', { parameter }), status: 'info', - duration: 2500, - isClosable: true, }); }; -const parameterNotSetToast = (parameter: string, description?: string) => { +const parameterNotSetToast = (parameter: string, message?: string) => { toast({ - title: t('toast.parameterNotSet', { parameter }), - description, + id: 'PARAMETER_NOT_SET', + title: t('toast.parameterNotSet'), + description: message + ? 
t('toast.parameterNotSetDescWithMessage', { parameter, message }) + : t('toast.parameterNotSetDesc', { parameter }), status: 'warning', - duration: 2500, - isClosable: true, }); }; @@ -458,7 +458,18 @@ export const parseAndRecallAllMetadata = async ( }); }) ); + if (results.some((result) => result.status === 'fulfilled')) { - parameterSetToast(t('toast.parameters')); + toast({ + id: 'PARAMETER_SET', + title: t('toast.parametersSet'), + status: 'info', + }); + } else { + toast({ + id: 'PARAMETER_SET', + title: t('toast.parametersNotSet'), + status: 'warning', + }); } }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useInstallModel.ts b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useInstallModel.ts new file mode 100644 index 0000000000..7636b9f314 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useInstallModel.ts @@ -0,0 +1,48 @@ +import { toast } from 'features/toast/toast'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useInstallModelMutation } from 'services/api/endpoints/models'; + +type InstallModelArg = { + source: string; + inplace?: boolean; + onSuccess?: () => void; + onError?: (error: unknown) => void; +}; + +export const useInstallModel = () => { + const { t } = useTranslation(); + const [_installModel, request] = useInstallModelMutation(); + + const installModel = useCallback( + ({ source, inplace, onSuccess, onError }: InstallModelArg) => { + _installModel({ source, inplace }) + .unwrap() + .then((_) => { + if (onSuccess) { + onSuccess(); + } + toast({ + id: 'MODEL_INSTALL_QUEUED', + title: t('toast.modelAddedSimple'), + status: 'success', + }); + }) + .catch((error) => { + if (onError) { + onError(error); + } + if (error) { + toast({ + id: 'MODEL_INSTALL_QUEUE_FAILED', + title: `${error.data.detail} `, + status: 'error', + }); + } + }); + }, + [_installModel, t] + ); + + return [installModel, request] as const; +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx index 6106264b78..6da320aa0b 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx @@ -17,7 +17,11 @@ export const useStarterModelsToast = () => { useEffect(() => { if (toast.isActive(TOAST_ID)) { - return; + if (mainModels.length === 0) { + return; + } else { + toast.close(TOAST_ID); + } } if (data && mainModels.length === 0 && !didToast && isEnabled) { toast({ diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx index 184429478e..ee5960f7d2 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx @@ -1,11 +1,9 @@ import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; 
import type { ChangeEventHandler } from 'react'; import { useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { useInstallModelMutation, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models'; +import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models'; import { HuggingFaceResults } from './HuggingFaceResults'; @@ -14,50 +12,19 @@ export const HuggingFaceForm = () => { const [displayResults, setDisplayResults] = useState(false); const [errorMessage, setErrorMessage] = useState(''); const { t } = useTranslation(); - const dispatch = useAppDispatch(); const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery(); - const [installModel] = useInstallModelMutation(); - - const handleInstallModel = useCallback( - (source: string) => { - installModel({ source }) - .unwrap() - .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelAddedSimple'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); - } - }); - }, - [installModel, dispatch, t] - ); + const [installModel] = useInstallModel(); const getModels = useCallback(async () => { _getHuggingFaceModels(huggingFaceRepo) .unwrap() .then((response) => { if (response.is_diffusers) { - handleInstallModel(huggingFaceRepo); + installModel({ source: huggingFaceRepo }); setDisplayResults(false); } else if (response.urls?.length === 1 && response.urls[0]) { - handleInstallModel(response.urls[0]); + installModel({ source: response.urls[0] }); setDisplayResults(false); } else { setDisplayResults(true); @@ -66,7 +33,7 @@ export const HuggingFaceForm = () => { .catch((error) => { setErrorMessage(error.data.detail || ''); }); - }, [_getHuggingFaceModels, handleInstallModel, huggingFaceRepo]); + }, [_getHuggingFaceModels, installModel, huggingFaceRepo]); const handleSetHuggingFaceRepo: ChangeEventHandler = useCallback((e) => { setHuggingFaceRepo(e.target.value); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResultItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResultItem.tsx index 1595a61147..32970a3666 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResultItem.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResultItem.tsx @@ -1,47 +1,20 @@ import { Flex, IconButton, Text } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiPlusBold } from 'react-icons/pi'; -import { useInstallModelMutation } from 'services/api/endpoints/models'; type Props = { result: string; }; export const HuggingFaceResultItem = ({ result }: Props) => { const { t } = useTranslation(); - const dispatch = useAppDispatch(); - const [installModel] = useInstallModelMutation(); + const [installModel] = useInstallModel(); - const handleInstall = useCallback(() => { - installModel({ source: result }) - .unwrap() - .then((_) => { - dispatch( - addToast( 
- makeToast({ - title: t('toast.modelAddedSimple'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); - } - }); - }, [installModel, result, dispatch, t]); + const onClick = useCallback(() => { + installModel({ source: result }); + }, [installModel, result]); return ( @@ -51,7 +24,7 @@ export const HuggingFaceResultItem = ({ result }: Props) => { {result} - } onClick={handleInstall} size="sm" /> + } onClick={onClick} size="sm" /> ); }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx index 8144accf3f..826fd177ea 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx @@ -8,15 +8,12 @@ import { InputGroup, InputRightElement, } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; import type { ChangeEventHandler } from 'react'; import { useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { PiXBold } from 'react-icons/pi'; -import { useInstallModelMutation } from 'services/api/endpoints/models'; import { HuggingFaceResultItem } from './HuggingFaceResultItem'; @@ -27,9 +24,8 @@ type HuggingFaceResultsProps = { export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => { const { t } = useTranslation(); const [searchTerm, setSearchTerm] = useState(''); - const dispatch = useAppDispatch(); - const [installModel] = useInstallModelMutation(); + const [installModel] = useInstallModel(); const filteredResults = useMemo(() => { return results.filter((result) => { @@ -46,34 +42,11 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => { setSearchTerm(''); }, []); - const handleAddAll = useCallback(() => { + const onClickAddAll = useCallback(() => { for (const result of filteredResults) { - installModel({ source: result }) - .unwrap() - .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelAddedSimple'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); - } - }); + installModel({ source: result }); } - }, [filteredResults, installModel, dispatch, t]); + }, [filteredResults, installModel]); return ( <> @@ -82,7 +55,7 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => { {t('modelManager.availableModels')} - + {t('modelManager.installAll')} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/InstallModelForm.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/InstallModelForm.tsx index 282e07ee27..cc052878bf 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/InstallModelForm.tsx +++ 
b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/InstallModelForm.tsx @@ -1,12 +1,9 @@ import { Button, Checkbox, Flex, FormControl, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; import { t } from 'i18next'; import { useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; -import { useInstallModelMutation } from 'services/api/endpoints/models'; type SimpleImportModelConfig = { location: string; @@ -14,9 +11,7 @@ type SimpleImportModelConfig = { }; export const InstallModelForm = () => { - const dispatch = useAppDispatch(); - - const [installModel, { isLoading }] = useInstallModelMutation(); + const [installModel, { isLoading }] = useInstallModel(); const { register, handleSubmit, formState, reset } = useForm({ defaultValues: { @@ -26,40 +21,22 @@ export const InstallModelForm = () => { mode: 'onChange', }); + const resetForm = useCallback(() => reset(undefined, { keepValues: true }), [reset]); + const onSubmit = useCallback>( (values) => { if (!values?.location) { return; } - installModel({ source: values.location, inplace: values.inplace }) - .unwrap() - .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelAddedSimple'), - status: 'success', - }) - ) - ); - reset(undefined, { keepValues: true }); - }) - .catch((error) => { - reset(undefined, { keepValues: true }); - if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); - } - }); + installModel({ + source: values.location, + inplace: values.inplace, + onSuccess: resetForm, + onError: resetForm, + }); }, - [dispatch, reset, installModel] + [installModel, resetForm] ); return ( diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueue.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueue.tsx index 5db2743669..b3544af5b3 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueue.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueue.tsx @@ -1,8 +1,6 @@ import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { useCallback, useMemo } from 'react'; import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models'; @@ -10,8 +8,6 @@ import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } fro import { ModelInstallQueueItem } from './ModelInstallQueueItem'; export const ModelInstallQueue = () => { - const dispatch = useAppDispatch(); - const { data } = useListModelInstallsQuery(); const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation(); @@ -20,28 +16,22 @@ export const ModelInstallQueue = () => { 
_pruneCompletedModelInstalls() .unwrap() .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.prunedQueue'), - status: 'success', - }) - ) - ); + toast({ + id: 'MODEL_INSTALL_QUEUE_PRUNED', + title: t('toast.prunedQueue'), + status: 'success', + }); }) .catch((error) => { if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); + toast({ + id: 'MODEL_INSTALL_QUEUE_PRUNE_FAILED', + title: `${error.data.detail} `, + status: 'error', + }); } }); - }, [_pruneCompletedModelInstalls, dispatch]); + }, [_pruneCompletedModelInstalls]); const pruneAvailable = useMemo(() => { return data?.some( diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueueItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueueItem.tsx index d1fc600510..82a28b2d75 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueueItem.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueueItem.tsx @@ -1,7 +1,5 @@ import { Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { isNil } from 'lodash-es'; import { useCallback, useMemo } from 'react'; @@ -29,7 +27,6 @@ const formatBytes = (bytes: number) => { export const ModelInstallQueueItem = (props: ModelListItemProps) => { const { installJob } = props; - const dispatch = useAppDispatch(); const [deleteImportModel] = useCancelModelInstallMutation(); @@ -37,28 +34,22 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => { deleteImportModel(installJob.id) .unwrap() .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelImportCanceled'), - status: 'success', - }) - ) - ); + toast({ + id: 'MODEL_INSTALL_CANCELED', + title: t('toast.modelImportCanceled'), + status: 'success', + }); }) .catch((error) => { if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); + toast({ + id: 'MODEL_INSTALL_CANCEL_FAILED', + title: `${error.data.detail} `, + status: 'error', + }); } }); - }, [deleteImportModel, installJob, dispatch]); + }, [deleteImportModel, installJob]); const sourceLocation = useMemo(() => { switch (installJob.source.type) { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ScanFolder/ScanFolderResults.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ScanFolder/ScanFolderResults.tsx index 360d6c1403..749ef4c8e0 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ScanFolder/ScanFolderResults.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ScanFolder/ScanFolderResults.tsx @@ -11,15 +11,13 @@ import { InputGroup, InputRightElement, } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { 
useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; import type { ChangeEvent, ChangeEventHandler } from 'react'; import { useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { PiXBold } from 'react-icons/pi'; -import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models'; +import type { ScanFolderResponse } from 'services/api/endpoints/models'; import { ScanModelResultItem } from './ScanFolderResultItem'; @@ -30,9 +28,8 @@ type ScanModelResultsProps = { export const ScanModelsResults = ({ results }: ScanModelResultsProps) => { const { t } = useTranslation(); const [searchTerm, setSearchTerm] = useState(''); - const dispatch = useAppDispatch(); const [inplace, setInplace] = useState(true); - const [installModel] = useInstallModelMutation(); + const [installModel] = useInstallModel(); const filteredResults = useMemo(() => { return results.filter((result) => { @@ -58,61 +55,15 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => { if (result.is_installed) { continue; } - installModel({ source: result.path, inplace }) - .unwrap() - .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelAddedSimple'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); - } - }); + installModel({ source: result.path, inplace }); } - }, [filteredResults, installModel, inplace, dispatch, t]); + }, [filteredResults, installModel, inplace]); const handleInstallOne = useCallback( (source: string) => { - installModel({ source, inplace }) - .unwrap() - .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelAddedSimple'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); - } - }); + installModel({ source, inplace }); }, - [installModel, inplace, dispatch, t] + [installModel, inplace] ); return ( diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx index a32fad810f..98e1e39640 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx @@ -1,20 +1,16 @@ import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; +import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiPlusBold } from 'react-icons/pi'; import type { GetStarterModelsResponse } from 'services/api/endpoints/models'; -import { useInstallModelMutation } from 'services/api/endpoints/models'; type Props = { result: GetStarterModelsResponse[number]; }; export const StarterModelsResultItem = ({ result }: Props) => { const { t } = useTranslation(); - const 
dispatch = useAppDispatch(); const allSources = useMemo(() => { const _allSources = [result.source]; if (result.dependencies) { @@ -22,36 +18,13 @@ export const StarterModelsResultItem = ({ result }: Props) => { } return _allSources; }, [result]); - const [installModel] = useInstallModelMutation(); + const [installModel] = useInstallModel(); - const handleQuickAdd = useCallback(() => { + const onClick = useCallback(() => { for (const source of allSources) { - installModel({ source }) - .unwrap() - .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelAddedSimple'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); - } - }); + installModel({ source }); } - }, [allSources, installModel, dispatch, t]); + }, [allSources, installModel]); return ( @@ -67,7 +40,7 @@ export const StarterModelsResultItem = ({ result }: Props) => { {result.is_installed ? ( {t('common.installed')} ) : ( - } onClick={handleQuickAdd} size="sm" /> + } onClick={onClick} size="sm" /> )} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx index c9d0c03ed8..a4a6d5c833 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -4,8 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice'; import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge'; import ModelFormatBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import type { MouseEvent } from 'react'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -53,25 +52,19 @@ const ModelListItem = (props: ModelListItemProps) => { deleteModel({ key: model.key }) .unwrap() .then((_) => { - dispatch( - addToast( - makeToast({ - title: `${t('modelManager.modelDeleted')}: ${model.name}`, - status: 'success', - }) - ) - ); + toast({ + id: 'MODEL_DELETED', + title: `${t('modelManager.modelDeleted')}: ${model.name}`, + status: 'success', + }); }) .catch((error) => { if (error) { - dispatch( - addToast( - makeToast({ - title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`, - status: 'error', - }) - ) - ); + toast({ + id: 'MODEL_DELETE_FAILED', + title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`, + status: 'error', + }); } }); dispatch(setSelectedModelKey(null)); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings.tsx index dcdc4e2a36..9a84fbc726 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings.tsx +++ 
b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings.tsx @@ -1,10 +1,9 @@ import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppSelector } from 'app/store/storeHooks'; import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings'; import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor'; import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; @@ -19,7 +18,6 @@ export type ControlNetOrT2IAdapterDefaultSettingsFormData = { export const ControlNetOrT2IAdapterDefaultSettings = () => { const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const { t } = useTranslation(); - const dispatch = useAppDispatch(); const { defaultSettingsDefaults, isLoading: isLoadingDefaultSettings } = useControlNetOrT2IAdapterDefaultSettings(selectedModelKey); @@ -46,30 +44,24 @@ export const ControlNetOrT2IAdapterDefaultSettings = () => { }) .unwrap() .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('modelManager.defaultSettingsSaved'), - status: 'success', - }) - ) - ); + toast({ + id: 'DEFAULT_SETTINGS_SAVED', + title: t('modelManager.defaultSettingsSaved'), + status: 'success', + }); reset(data); }) .catch((error) => { if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); + toast({ + id: 'DEFAULT_SETTINGS_SAVE_FAILED', + title: `${error.data.detail} `, + status: 'error', + }); } }); }, - [selectedModelKey, dispatch, reset, updateModel, t] + [selectedModelKey, reset, updateModel, t] ); if (isLoadingDefaultSettings) { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelImageUpload.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelImageUpload.tsx index 0d7920ef77..292835a7b7 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelImageUpload.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelImageUpload.tsx @@ -1,8 +1,6 @@ import { Box, Button, Flex, Icon, IconButton, Image, Tooltip } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; import { typedMemo } from 'common/util/typedMemo'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { useCallback, useState } from 'react'; import { useDropzone } from 'react-dropzone'; import { useTranslation } from 'react-i18next'; @@ -15,7 +13,6 @@ type Props = { }; const ModelImageUpload = ({ model_key, model_image }: Props) => { - const dispatch = useAppDispatch(); const [image, setImage] = useState(model_image || null); const { t } = useTranslation(); @@ -34,27 +31,21 @@ const ModelImageUpload = ({ model_key, model_image }: 
Props) => { .unwrap() .then(() => { setImage(URL.createObjectURL(file)); - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelImageUpdated'), - status: 'success', - }) - ) - ); + toast({ + id: 'MODEL_IMAGE_UPDATED', + title: t('modelManager.modelImageUpdated'), + status: 'success', + }); }) - .catch((_) => { - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelImageUpdateFailed'), - status: 'error', - }) - ) - ); + .catch(() => { + toast({ + id: 'MODEL_IMAGE_UPDATE_FAILED', + title: t('modelManager.modelImageUpdateFailed'), + status: 'error', + }); }); }, - [dispatch, model_key, t, updateModelImage] + [model_key, t, updateModelImage] ); const handleResetImage = useCallback(() => { @@ -65,26 +56,20 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => { deleteModelImage(model_key) .unwrap() .then(() => { - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelImageDeleted'), - status: 'success', - }) - ) - ); + toast({ + id: 'MODEL_IMAGE_DELETED', + title: t('modelManager.modelImageDeleted'), + status: 'success', + }); }) - .catch((_) => { - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelImageDeleteFailed'), - status: 'error', - }) - ) - ); + .catch(() => { + toast({ + id: 'MODEL_IMAGE_DELETE_FAILED', + title: t('modelManager.modelImageDeleteFailed'), + status: 'error', + }); }); - }, [dispatch, model_key, t, deleteModelImage]); + }, [model_key, t, deleteModelImage]); const { getInputProps, getRootProps } = useDropzone({ accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] }, diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings.tsx index e096b11209..233fc7bc6b 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings.tsx @@ -1,11 +1,10 @@ import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppSelector } from 'app/store/storeHooks'; import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings'; import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight'; import { DefaultWidth } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultWidth'; import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; @@ -39,7 +38,6 @@ export type MainModelDefaultSettingsFormData = { export const MainModelDefaultSettings = () => { const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const { t } = useTranslation(); - const dispatch = useAppDispatch(); const { defaultSettingsDefaults, @@ -76,30 +74,24 @@ export const MainModelDefaultSettings = () => { }) .unwrap() .then((_) => { - dispatch( - addToast( - makeToast({ - title: 
t('modelManager.defaultSettingsSaved'), - status: 'success', - }) - ) - ); + toast({ + id: 'DEFAULT_SETTINGS_SAVED', + title: t('modelManager.defaultSettingsSaved'), + status: 'success', + }); reset(data); }) .catch((error) => { if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); + toast({ + id: 'DEFAULT_SETTINGS_SAVE_FAILED', + title: `${error.data.detail} `, + status: 'error', + }); } }); }, - [selectedModelKey, dispatch, reset, updateModel, t] + [selectedModelKey, reset, updateModel, t] ); if (isLoadingDefaultSettings) { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx index d95eed8d24..fa7ca4c394 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx @@ -4,8 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton'; import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; @@ -47,25 +46,19 @@ export const Model = () => { .then((payload) => { form.reset(payload, { keepDefaultValues: true }); dispatch(setSelectedModelMode('view')); - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdated'), - status: 'success', - }) - ) - ); + toast({ + id: 'MODEL_UPDATED', + title: t('modelManager.modelUpdated'), + status: 'success', + }); }) .catch((_) => { form.reset(); - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdateFailed'), - status: 'error', - }) - ) - ); + toast({ + id: 'MODEL_UPDATE_FAILED', + title: t('modelManager.modelUpdateFailed'), + status: 'error', + }); }); }, [dispatch, data?.key, form, t, updateModel] diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton.tsx index bcff0451d6..40ffca76b4 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton.tsx @@ -9,9 +9,7 @@ import { useDisclosure, } from '@invoke-ai/ui-library'; import { skipToken } from '@reduxjs/toolkit/query'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useConvertModelMutation, useGetModelConfigQuery } from 'services/api/endpoints/models'; @@ -22,7 +20,6 @@ interface ModelConvertProps { export const ModelConvertButton = (props: ModelConvertProps) => { const { modelKey } = props; - const dispatch = useAppDispatch(); const { t } = useTranslation(); const { data } = 
useGetModelConfigQuery(modelKey ?? skipToken); const [convertModel, { isLoading }] = useConvertModelMutation(); @@ -33,38 +30,26 @@ export const ModelConvertButton = (props: ModelConvertProps) => { return; } - dispatch( - addToast( - makeToast({ - title: `${t('modelManager.convertingModelBegin')}: ${data?.name}`, - status: 'info', - }) - ) - ); + const toastId = `CONVERTING_MODEL_${data.key}`; + toast({ + id: toastId, + title: `${t('modelManager.convertingModelBegin')}: ${data?.name}`, + status: 'info', + }); convertModel(data?.key) .unwrap() .then(() => { - dispatch( - addToast( - makeToast({ - title: `${t('modelManager.modelConverted')}: ${data?.name}`, - status: 'success', - }) - ) - ); + toast({ id: toastId, title: `${t('modelManager.modelConverted')}: ${data?.name}`, status: 'success' }); }) .catch(() => { - dispatch( - addToast( - makeToast({ - title: `${t('modelManager.modelConversionFailed')}: ${data?.name}`, - status: 'error', - }) - ) - ); + toast({ + id: toastId, + title: `${t('modelManager.modelConversionFailed')}: ${data?.name}`, + status: 'error', + }); }); - }, [data, isLoading, dispatch, t, convertModel]); + }, [data, isLoading, t, convertModel]); if (data?.format !== 'checkpoint') { return; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx index 1f8e50b9da..8bc775c872 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx @@ -72,10 +72,12 @@ export const ModelEdit = ({ form }: Props) => { {t('modelManager.baseModel')} - - {t('modelManager.variant')} - - + {data.type === 'main' && ( + + {t('modelManager.variant')} + + + )} {data.type === 'main' && data.format === 'checkpoint' && ( <> diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 95104c683c..6da87f4e98 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -3,33 +3,35 @@ import 'reactflow/dist/style.css'; import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { useAppToaster } from 'app/components/Toaster'; import { useAppDispatch, useAppStore } from 'app/store/storeHooks'; import type { SelectInstance } from 'chakra-react-select'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; import { $cursorPos, + $edgePendingUpdate, $isAddNodePopoverOpen, $pendingConnection, $templates, closeAddNodePopover, - connectionMade, - nodeAdded, + edgesChanged, + nodesChanged, openAddNodePopover, } from 'features/nodes/store/nodesSlice'; -import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle'; -import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; +import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition'; +import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; +import { connectionToEdge } from 
'features/nodes/store/util/reactFlowUtil'; +import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import type { AnyNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation'; +import { toast } from 'features/toast/toast'; import { filter, map, memoize, some } from 'lodash-es'; -import type { KeyboardEventHandler } from 'react'; import { memo, useCallback, useMemo, useRef } from 'react'; import { flushSync } from 'react-dom'; import { useHotkeys } from 'react-hotkeys-hook'; import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types'; import { useTranslation } from 'react-i18next'; import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters'; -import { assert } from 'tsafe'; +import type { EdgeChange, NodeChange } from 'reactflow'; const createRegex = memoize( (inputValue: string) => @@ -58,7 +60,6 @@ const filterOption = memoize((option: FilterOptionOption, inputV const AddNodePopover = () => { const dispatch = useAppDispatch(); const buildInvocation = useBuildNode(); - const toaster = useAppToaster(); const { t } = useTranslation(); const selectRef = useRef | null>(null); const inputRef = useRef(null); @@ -69,17 +70,19 @@ const AddNodePopover = () => { const filteredTemplates = useMemo(() => { // If we have a connection in progress, we need to filter the node choices + const templatesArray = map(templates); if (!pendingConnection) { - return map(templates); + return templatesArray; } return filter(templates, (template) => { - const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind; - const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs; - return some(fields, (field) => { - const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type; - const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type; - return validateSourceAndTargetTypes(sourceType, targetType); + const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs; + return some(candidateFields, (field) => { + const sourceType = + pendingConnection.handleType === 'source' ? field.type : pendingConnection.fieldTemplate.type; + const targetType = + pendingConnection.handleType === 'target' ? field.type : pendingConnection.fieldTemplate.type; + return validateConnectionTypes(sourceType, targetType); }); }); }, [templates, pendingConnection]); @@ -123,17 +126,43 @@ const AddNodePopover = () => { const errorMessage = t('nodes.unknownNode', { nodeType: nodeType, }); - toaster({ + toast({ status: 'error', title: errorMessage, }); return null; } + + // Find a cozy spot for the node const cursorPos = $cursorPos.get(); - dispatch(nodeAdded({ node, cursorPos })); + const { nodes, edges } = store.getState().nodes.present; + node.position = findUnoccupiedPosition(nodes, cursorPos?.x ?? node.position.x, cursorPos?.y ?? node.position.y); + node.selected = true; + + // Deselect all other nodes and edges + const nodeChanges: NodeChange[] = [{ type: 'add', item: node }]; + const edgeChanges: EdgeChange[] = []; + nodes.forEach(({ id, selected }) => { + if (selected) { + nodeChanges.push({ type: 'select', id, selected: false }); + } + }); + edges.forEach(({ id, selected }) => { + if (selected) { + edgeChanges.push({ type: 'select', id, selected: false }); + } + }); + + // Onwards! 
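(For orientation: the `nodesChanged` / `edgesChanged` actions used throughout this PR take plain reactflow change arrays. The slice internals are not part of this hunk, so the following is only a minimal sketch of reducers that could consume such changes, assuming reactflow's standard `applyNodeChanges` / `applyEdgeChanges` helpers; every name not shown in the diff is illustrative.)

```ts
// Minimal sketch, not the project's actual slice: shows how 'add' / 'remove' /
// 'select' change arrays like the ones dispatched here can be folded into state
// with reactflow's applyNodeChanges / applyEdgeChanges helpers.
import { applyEdgeChanges, applyNodeChanges } from 'reactflow';
import type { Edge, EdgeChange, Node, NodeChange } from 'reactflow';

type FlowState = { nodes: Node[]; edges: Edge[] };

// Hypothetical case-reducer bodies (immer-style mutation as in Redux Toolkit).
export const nodesChangedReducer = (state: FlowState, changes: NodeChange[]) => {
  state.nodes = applyNodeChanges(changes, state.nodes);
};

export const edgesChangedReducer = (state: FlowState, changes: EdgeChange[]) => {
  state.edges = applyEdgeChanges(changes, state.edges);
};
```

Routing every mutation through one change-based pair like this is what lets the add-node, select-all, and delete paths in this PR share a single code path instead of dedicated `nodeAdded` / `edgeDeleted` style actions.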
+ if (nodeChanges.length > 0) { + dispatch(nodesChanged(nodeChanges)); + } + if (edgeChanges.length > 0) { + dispatch(edgesChanged(edgeChanges)); + } return node; }, - [dispatch, buildInvocation, toaster, t] + [buildInvocation, store, dispatch, t] ); const onChange = useCallback( @@ -145,12 +174,28 @@ const AddNodePopover = () => { // Auto-connect an edge if we just added a node and have a pending connection if (pendingConnection && isInvocationNode(node)) { - const template = templates[node.data.type]; - assert(template, 'Template not found'); + const edgePendingUpdate = $edgePendingUpdate.get(); + const { handleType } = pendingConnection; + + const source = handleType === 'source' ? pendingConnection.nodeId : node.id; + const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null; + const target = handleType === 'target' ? pendingConnection.nodeId : node.id; + const targetHandle = handleType === 'target' ? pendingConnection.handleId : null; + const { nodes, edges } = store.getState().nodes.present; - const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template); + const connection = getFirstValidConnection( + source, + sourceHandle, + target, + targetHandle, + nodes, + edges, + templates, + edgePendingUpdate + ); if (connection) { - dispatch(connectionMade(connection)); + const newEdge = connectionToEdge(connection); + dispatch(edgesChanged([{ type: 'add', item: newEdge }])); } } @@ -160,25 +205,24 @@ const AddNodePopover = () => { ); const handleHotkeyOpen: HotkeyCallback = useCallback((e) => { - e.preventDefault(); - openAddNodePopover(); - flushSync(() => { - selectRef.current?.inputRef?.focus(); - }); + if (!$isAddNodePopoverOpen.get()) { + e.preventDefault(); + openAddNodePopover(); + flushSync(() => { + selectRef.current?.inputRef?.focus(); + }); + } }, []); const handleHotkeyClose: HotkeyCallback = useCallback(() => { - closeAddNodePopover(); - }, []); - - useHotkeys(['shift+a', 'space'], handleHotkeyOpen); - useHotkeys(['escape'], handleHotkeyClose); - const onKeyDown: KeyboardEventHandler = useCallback((e) => { - if (e.key === 'Escape') { + if ($isAddNodePopoverOpen.get()) { closeAddNodePopover(); } }, []); + useHotkeys(['shift+a', 'space'], handleHotkeyOpen); + useHotkeys(['escape'], handleHotkeyClose, { enableOnFormTags: ['TEXTAREA'] }); + const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]); return ( @@ -215,7 +259,6 @@ const AddNodePopover = () => { filterOption={filterOption} onChange={onChange} onMenuClose={closeAddNodePopover} - onKeyDown={onKeyDown} inputRef={inputRef} closeMenuOnSelect={false} /> diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 656de737c7..1748989394 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -1,6 +1,6 @@ import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks'; import { useConnection } from 'features/nodes/hooks/useConnection'; import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste'; import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState'; @@ -8,38 +8,35 @@ import { useIsValidConnection } from 
'features/nodes/hooks/useIsValidConnection' import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher'; import { $cursorPos, + $didUpdateEdge, + $edgePendingUpdate, $isAddNodePopoverOpen, - $isUpdatingEdge, + $lastEdgeUpdateMouseEvent, $pendingConnection, $viewport, - connectionMade, - edgeAdded, - edgeDeleted, edgesChanged, - edgesDeleted, nodesChanged, - nodesDeleted, redo, - selectedAll, undo, } from 'features/nodes/store/nodesSlice'; import { $flow } from 'features/nodes/store/reactFlowInstance'; +import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import type { CSSProperties, MouseEvent } from 'react'; import { memo, useCallback, useMemo, useRef } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import type { + EdgeChange, + NodeChange, OnEdgesChange, - OnEdgesDelete, OnEdgeUpdateFunc, OnInit, OnMoveEnd, OnNodesChange, - OnNodesDelete, ProOptions, ReactFlowProps, ReactFlowState, } from 'reactflow'; -import { Background, ReactFlow, useStore as useReactFlowStore } from 'reactflow'; +import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from 'reactflow'; import CustomConnectionLine from './connectionLines/CustomConnectionLine'; import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge'; @@ -48,8 +45,6 @@ import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode'; import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper'; import NotesNode from './nodes/Notes/NotesNode'; -const DELETE_KEYS = ['Delete', 'Backspace']; - const edgeTypes = { collapsed: InvocationCollapsedEdge, default: InvocationDefaultEdge, @@ -81,6 +76,8 @@ export const Flow = memo(() => { const flowWrapper = useRef(null); const isValidConnection = useIsValidConnection(); const cancelConnection = useReactFlowStore(selectCancelConnection); + const updateNodeInternals = useUpdateNodeInternals(); + const store = useAppStore(); useWorkflowWatcher(); useSyncExecutionState(); const [borderRadius] = useToken('radii', ['base']); @@ -93,29 +90,17 @@ export const Flow = memo(() => { ); const onNodesChange: OnNodesChange = useCallback( - (changes) => { - dispatch(nodesChanged(changes)); + (nodeChanges) => { + dispatch(nodesChanged(nodeChanges)); }, [dispatch] ); const onEdgesChange: OnEdgesChange = useCallback( (changes) => { - dispatch(edgesChanged(changes)); - }, - [dispatch] - ); - - const onEdgesDelete: OnEdgesDelete = useCallback( - (edges) => { - dispatch(edgesDeleted(edges)); - }, - [dispatch] - ); - - const onNodesDelete: OnNodesDelete = useCallback( - (nodes) => { - dispatch(nodesDeleted(nodes)); + if (changes.length > 0) { + dispatch(edgesChanged(changes)); + } }, [dispatch] ); @@ -157,45 +142,50 @@ export const Flow = memo(() => { * where the edge is deleted if you click it accidentally). */ - // We have a ref for cursor position, but it is the *projected* cursor position. 
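(The edge-update handlers below coordinate through a handful of nanostores imported from `nodesSlice`. Their declarations are outside this hunk; assuming they are plain nanostores atoms, they would look roughly like the sketch below — only the names come from the diff, the exact shapes are an assumption.)

```ts
// Sketch of the assumed atoms backing the edge-update flow; the .get()/.set()
// calls in onEdgeUpdateStart / onEdgeUpdate / onEdgeUpdateEnd match this API.
import { atom } from 'nanostores';
import type { Edge } from 'reactflow';

// The edge currently being dragged onto a new handle, or null when idle.
export const $edgePendingUpdate = atom<Edge | null>(null);
// Set to true by onEdgeUpdate when the drag lands on a valid handle.
export const $didUpdateEdge = atom(false);
// Mouse event captured when the drag started, used to tell a click from a drag.
export const $lastEdgeUpdateMouseEvent = atom<MouseEvent | null>(null);
```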
- // Easiest to just keep track of the last mouse event for this particular feature - const edgeUpdateMouseEvent = useRef(); - - const onEdgeUpdateStart: NonNullable = useCallback( - (e, edge, _handleType) => { - $isUpdatingEdge.set(true); - // update mouse event - edgeUpdateMouseEvent.current = e; - // always delete the edge when starting an updated - dispatch(edgeDeleted(edge.id)); - }, - [dispatch] - ); + const onEdgeUpdateStart: NonNullable = useCallback((e, edge, _handleType) => { + $edgePendingUpdate.set(edge); + $didUpdateEdge.set(false); + $lastEdgeUpdateMouseEvent.set(e); + }, []); const onEdgeUpdate: OnEdgeUpdateFunc = useCallback( - (_oldEdge, newConnection) => { - // Because we deleted the edge when the update started, we must create a new edge from the connection - dispatch(connectionMade(newConnection)); + (oldEdge, newConnection) => { + // This event is fired when an edge update is successful + $didUpdateEdge.set(true); + // When an edge update is successful, we need to delete the old edge and create a new one + const newEdge = connectionToEdge(newConnection); + dispatch( + edgesChanged([ + { type: 'remove', id: oldEdge.id }, + { type: 'add', item: newEdge }, + ]) + ); + // Because we shift the position of handles depending on whether a field is connected or not, we must use + // updateNodeInternals to tell reactflow to recalculate the positions of the handles + updateNodeInternals([oldEdge.source, oldEdge.target, newEdge.source, newEdge.target]); }, - [dispatch] + [dispatch, updateNodeInternals] ); const onEdgeUpdateEnd: NonNullable = useCallback( (e, edge, _handleType) => { - $isUpdatingEdge.set(false); - $pendingConnection.set(null); - // Handle the case where user begins a drag but didn't move the cursor - we deleted the edge when starting - // the edge update - we need to add it back - if ( - // ignore touch events - !('touches' in e) && - edgeUpdateMouseEvent.current?.clientX === e.clientX && - edgeUpdateMouseEvent.current?.clientY === e.clientY - ) { - dispatch(edgeAdded(edge)); + const didUpdateEdge = $didUpdateEdge.get(); + // Fall back to a reasonable default event + const lastEvent = $lastEdgeUpdateMouseEvent.get() ?? 
{ clientX: 0, clientY: 0 }; + // We have to narrow this event down to MouseEvents - could be TouchEvent + const didMouseMove = + !('touches' in e) && Math.hypot(e.clientX - lastEvent.clientX, e.clientY - lastEvent.clientY) > 5; + + // If we got this far and did not successfully update an edge, and the mouse moved away from the handle, + // the user probably intended to delete the edge + if (!didUpdateEdge && didMouseMove) { + dispatch(edgesChanged([{ type: 'remove', id: edge.id }])); } - // reset mouse event - edgeUpdateMouseEvent.current = undefined; + + $edgePendingUpdate.set(null); + $didUpdateEdge.set(false); + $pendingConnection.set(null); + $lastEdgeUpdateMouseEvent.set(null); }, [dispatch] ); @@ -216,9 +206,27 @@ export const Flow = memo(() => { const onSelectAllHotkey = useCallback( (e: KeyboardEvent) => { e.preventDefault(); - dispatch(selectedAll()); + const { nodes, edges } = store.getState().nodes.present; + const nodeChanges: NodeChange[] = []; + const edgeChanges: EdgeChange[] = []; + nodes.forEach(({ id, selected }) => { + if (!selected) { + nodeChanges.push({ type: 'select', id, selected: true }); + } + }); + edges.forEach(({ id, selected }) => { + if (!selected) { + edgeChanges.push({ type: 'select', id, selected: true }); + } + }); + if (nodeChanges.length > 0) { + dispatch(nodesChanged(nodeChanges)); + } + if (edgeChanges.length > 0) { + dispatch(edgesChanged(edgeChanges)); + } }, - [dispatch] + [dispatch, store] ); useHotkeys(['Ctrl+a', 'Meta+a'], onSelectAllHotkey); @@ -255,12 +263,37 @@ export const Flow = memo(() => { useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey); const onEscapeHotkey = useCallback(() => { - $pendingConnection.set(null); - $isAddNodePopoverOpen.set(false); - cancelConnection(); + if (!$edgePendingUpdate.get()) { + $pendingConnection.set(null); + $isAddNodePopoverOpen.set(false); + cancelConnection(); + } }, [cancelConnection]); useHotkeys('esc', onEscapeHotkey); + const onDeleteHotkey = useCallback(() => { + const { nodes, edges } = store.getState().nodes.present; + const nodeChanges: NodeChange[] = []; + const edgeChanges: EdgeChange[] = []; + nodes + .filter((n) => n.selected) + .forEach(({ id }) => { + nodeChanges.push({ type: 'remove', id }); + }); + edges + .filter((e) => e.selected) + .forEach(({ id }) => { + edgeChanges.push({ type: 'remove', id }); + }); + if (nodeChanges.length > 0) { + dispatch(nodesChanged(nodeChanges)); + } + if (edgeChanges.length > 0) { + dispatch(edgesChanged(edgeChanges)); + } + }, [dispatch, store]); + useHotkeys(['delete', 'backspace'], onDeleteHotkey); + return ( { onMouseMove={onMouseMove} onNodesChange={onNodesChange} onEdgesChange={onEdgesChange} - onEdgesDelete={onEdgesDelete} onEdgeUpdate={onEdgeUpdate} onEdgeUpdateStart={onEdgeUpdateStart} onEdgeUpdateEnd={onEdgeUpdateEnd} - onNodesDelete={onNodesDelete} onConnectStart={onConnectStart} onConnect={onConnect} onConnectEnd={onConnectEnd} @@ -292,9 +323,10 @@ export const Flow = memo(() => { proOptions={proOptions} style={flowStyles} onPaneClick={handlePaneClick} - deleteKeyCode={DELETE_KEYS} + deleteKeyCode={null} selectionMode={selectionMode} elevateEdgesOnSelect + nodeDragThreshold={1} > diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx index 2e2fb31154..0d7e7b7d5e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx +++ 
b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx @@ -2,13 +2,13 @@ import { Badge, Flex } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import { useAppSelector } from 'app/store/storeHooks'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; +import { getEdgeStyles } from 'features/nodes/components/flow/edges/util/getEdgeColor'; +import { makeEdgeSelector } from 'features/nodes/components/flow/edges/util/makeEdgeSelector'; import { $templates } from 'features/nodes/store/nodesSlice'; import { memo, useMemo } from 'react'; import type { EdgeProps } from 'reactflow'; import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow'; -import { makeEdgeSelector } from './util/makeEdgeSelector'; - const InvocationCollapsedEdge = ({ sourceX, sourceY, @@ -18,19 +18,19 @@ const InvocationCollapsedEdge = ({ targetPosition, markerEnd, data, - selected, + selected = false, source, - target, sourceHandleId, + target, targetHandleId, }: EdgeProps<{ count: number }>) => { const templates = useStore($templates); const selector = useMemo( - () => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected), - [templates, selected, source, sourceHandleId, target, targetHandleId] + () => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId), + [templates, source, sourceHandleId, target, targetHandleId] ); - const { isSelected, shouldAnimate } = useAppSelector(selector); + const { shouldAnimateEdges, areConnectedNodesSelected } = useAppSelector(selector); const [edgePath, labelX, labelY] = getBezierPath({ sourceX, @@ -44,14 +44,8 @@ const InvocationCollapsedEdge = ({ const { base500 } = useChakraThemeTokens(); const edgeStyles = useMemo( - () => ({ - strokeWidth: isSelected ? 3 : 2, - stroke: base500, - opacity: isSelected ? 0.8 : 0.5, - animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined, - strokeDasharray: shouldAnimate ? 
5 : 'none', - }), - [base500, isSelected, shouldAnimate] + () => getEdgeStyles(base500, selected, shouldAnimateEdges, areConnectedNodesSelected), + [areConnectedNodesSelected, base500, selected, shouldAnimateEdges] ); return ( @@ -60,11 +54,15 @@ const InvocationCollapsedEdge = ({ {data?.count && data.count > 1 && ( - + {data.count} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx index 2e4340975b..5a27e974e5 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx @@ -1,8 +1,8 @@ import { Flex, Text } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import { useAppSelector } from 'app/store/storeHooks'; +import { getEdgeStyles } from 'features/nodes/components/flow/edges/util/getEdgeColor'; import { $templates } from 'features/nodes/store/nodesSlice'; -import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; import type { EdgeProps } from 'reactflow'; import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow'; @@ -17,7 +17,7 @@ const InvocationDefaultEdge = ({ sourcePosition, targetPosition, markerEnd, - selected, + selected = false, source, target, sourceHandleId, @@ -25,11 +25,11 @@ const InvocationDefaultEdge = ({ }: EdgeProps) => { const templates = useStore($templates); const selector = useMemo( - () => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected), - [templates, source, sourceHandleId, target, targetHandleId, selected] + () => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId), + [templates, source, sourceHandleId, target, targetHandleId] ); - const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector); + const { shouldAnimateEdges, areConnectedNodesSelected, stroke, label } = useAppSelector(selector); const shouldShowEdgeLabels = useAppSelector((s) => s.workflowSettings.shouldShowEdgeLabels); const [edgePath, labelX, labelY] = getBezierPath({ @@ -41,15 +41,9 @@ const InvocationDefaultEdge = ({ targetPosition, }); - const edgeStyles = useMemo( - () => ({ - strokeWidth: isSelected ? 3 : 2, - stroke, - opacity: isSelected ? 0.8 : 0.5, - animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined, - strokeDasharray: shouldAnimate ? 5 : 'none', - }), - [isSelected, shouldAnimate, stroke] + const edgeStyles = useMemo( + () => getEdgeStyles(stroke, selected, shouldAnimateEdges, areConnectedNodesSelected), + [areConnectedNodesSelected, stroke, selected, shouldAnimateEdges] ); return ( @@ -65,13 +59,13 @@ const InvocationDefaultEdge = ({ bg="base.800" borderRadius="base" borderWidth={1} - borderColor={isSelected ? 'undefined' : 'transparent'} - opacity={isSelected ? 1 : 0.5} + borderColor={selected ? 'undefined' : 'transparent'} + opacity={selected ? 
1 : 0.5} py={1} px={3} shadow="md" > - + {label} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts index e7fa43015b..b5801c45ed 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts @@ -1,6 +1,7 @@ import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { FIELD_COLORS } from 'features/nodes/types/constants'; import type { FieldType } from 'features/nodes/types/field'; +import type { CSSProperties } from 'react'; export const getFieldColor = (fieldType: FieldType | null): string => { if (!fieldType) { @@ -10,3 +11,16 @@ export const getFieldColor = (fieldType: FieldType | null): string => { return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500'); }; + +export const getEdgeStyles = ( + stroke: string, + selected: boolean, + shouldAnimateEdges: boolean, + areConnectedNodesSelected: boolean +): CSSProperties => ({ + strokeWidth: 3, + stroke, + opacity: selected ? 1 : 0.5, + animation: shouldAnimateEdges ? 'dashdraw 0.5s linear infinite' : undefined, + strokeDasharray: selected || areConnectedNodesSelected ? 5 : 'none', +}); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts index 87ef8eb629..9c67728722 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts @@ -1,5 +1,6 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; +import { deepClone } from 'common/util/deepClone'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import type { Templates } from 'features/nodes/store/types'; import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice'; @@ -8,8 +9,8 @@ import { isInvocationNode } from 'features/nodes/types/invocation'; import { getFieldColor } from './getEdgeColor'; const defaultReturnValue = { - isSelected: false, - shouldAnimate: false, + areConnectedNodesSelected: false, + shouldAnimateEdges: false, stroke: colorTokenToCssVar('base.500'), label: '', }; @@ -19,21 +20,27 @@ export const makeEdgeSelector = ( source: string, sourceHandleId: string | null | undefined, target: string, - targetHandleId: string | null | undefined, - selected?: boolean + targetHandleId: string | null | undefined ) => createMemoizedSelector( selectNodesSlice, selectWorkflowSettingsSlice, - (nodes, workflowSettings): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => { + ( + nodes, + workflowSettings + ): { areConnectedNodesSelected: boolean; shouldAnimateEdges: boolean; stroke: string; label: string } => { + const { shouldAnimateEdges, shouldColorEdges } = workflowSettings; const sourceNode = nodes.nodes.find((node) => node.id === source); const targetNode = nodes.nodes.find((node) => node.id === target); + const returnValue = deepClone(defaultReturnValue); + returnValue.shouldAnimateEdges = shouldAnimateEdges; + const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode); - const isSelected = Boolean(sourceNode?.selected || 
targetNode?.selected || selected); + returnValue.areConnectedNodesSelected = Boolean(sourceNode?.selected || targetNode?.selected); if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) { - return defaultReturnValue; + return returnValue; } const sourceNodeTemplate = templates[sourceNode.data.type]; @@ -42,16 +49,10 @@ export const makeEdgeSelector = ( const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId]; const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined; - const stroke = - sourceType && workflowSettings.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); + returnValue.stroke = sourceType && shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); - const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`; + returnValue.label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`; - return { - isSelected, - shouldAnimate: workflowSettings.shouldAnimateEdges && isSelected, - stroke, - label, - }; + return returnValue; } ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx index 0147bcaed2..baa7fc262a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx @@ -1,7 +1,7 @@ import { Flex, Grid, GridItem } from '@invoke-ai/ui-library'; import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper'; -import { useAnyOrDirectInputFieldNames } from 'features/nodes/hooks/useAnyOrDirectInputFieldNames'; -import { useConnectionInputFieldNames } from 'features/nodes/hooks/useConnectionInputFieldNames'; +import { InvocationInputFieldCheck } from 'features/nodes/components/flow/nodes/Invocation/fields/InvocationFieldCheck'; +import { useFieldNames } from 'features/nodes/hooks/useFieldNames'; import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames'; import { useWithFooter } from 'features/nodes/hooks/useWithFooter'; import { memo } from 'react'; @@ -20,8 +20,7 @@ type Props = { }; const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => { - const inputConnectionFieldNames = useConnectionInputFieldNames(nodeId); - const inputAnyOrDirectFieldNames = useAnyOrDirectInputFieldNames(nodeId); + const fieldNames = useFieldNames(nodeId); const withFooter = useWithFooter(nodeId); const outputFieldNames = useOutputFieldNames(nodeId); @@ -41,9 +40,11 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => { > - {inputConnectionFieldNames.map((fieldName, i) => ( + {fieldNames.connectionFields.map((fieldName, i) => ( - + + + ))} {outputFieldNames.map((fieldName, i) => ( @@ -52,8 +53,23 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => { ))} - {inputAnyOrDirectFieldNames.map((fieldName) => ( - + {fieldNames.anyOrDirectFields.map((fieldName) => ( + + + + ))} + {fieldNames.missingFields.map((fieldName) => ( + + + ))} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx index 959b13c2d0..143dee983f 100644 --- 
a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx @@ -2,10 +2,12 @@ import { Tooltip } from '@invoke-ai/ui-library'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor'; import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType'; +import type { ValidationResult } from 'features/nodes/store/util/validateConnection'; import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants'; -import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; +import { type FieldInputTemplate, type FieldOutputTemplate, isSingle } from 'features/nodes/types/field'; import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; import type { HandleType } from 'reactflow'; import { Handle, Position } from 'reactflow'; @@ -14,11 +16,12 @@ type FieldHandleProps = { handleType: HandleType; isConnectionInProgress: boolean; isConnectionStartField: boolean; - connectionError?: string; + validationResult: ValidationResult; }; const FieldHandle = (props: FieldHandleProps) => { - const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, connectionError } = props; + const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, validationResult } = props; + const { t } = useTranslation(); const { name } = fieldTemplate; const type = fieldTemplate.type; const fieldTypeName = useFieldTypeName(type); @@ -26,11 +29,11 @@ const FieldHandle = (props: FieldHandleProps) => { const isModelType = MODEL_TYPES.some((t) => t === type.name); const color = getFieldColor(type); const s: CSSProperties = { - backgroundColor: type.isCollection || type.isCollectionOrScalar ? colorTokenToCssVar('base.900') : color, + backgroundColor: !isSingle(type) ? colorTokenToCssVar('base.900') : color, position: 'absolute', width: '1rem', height: '1rem', - borderWidth: type.isCollection || type.isCollectionOrScalar ? 4 : 0, + borderWidth: !isSingle(type) ? 4 : 0, borderStyle: 'solid', borderColor: color, borderRadius: isModelType ? 
4 : '100%', @@ -43,11 +46,11 @@ const FieldHandle = (props: FieldHandleProps) => { s.insetInlineEnd = '-1rem'; } - if (isConnectionInProgress && !isConnectionStartField && connectionError) { + if (isConnectionInProgress && !isConnectionStartField && !validationResult.isValid) { s.filter = 'opacity(0.4) grayscale(0.7)'; } - if (isConnectionInProgress && connectionError) { + if (isConnectionInProgress && !validationResult.isValid) { if (isConnectionStartField) { s.cursor = 'grab'; } else { @@ -58,14 +61,14 @@ const FieldHandle = (props: FieldHandleProps) => { } return s; - }, [connectionError, handleType, isConnectionInProgress, isConnectionStartField, type]); + }, [handleType, isConnectionInProgress, isConnectionStartField, type, validationResult.isValid]); const tooltip = useMemo(() => { - if (isConnectionInProgress && connectionError) { - return connectionError; + if (isConnectionInProgress && validationResult.messageTKey) { + return t(validationResult.messageTKey); } return fieldTypeName; - }, [connectionError, fieldTypeName, isConnectionInProgress]); + }, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]); return ( { - const { t } = useTranslation(); const fieldTemplate = useFieldInputTemplate(nodeId, fieldName); - const fieldInstance = useFieldInputInstance(nodeId, fieldName); const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName); const [isHovered, setIsHovered] = useState(false); - const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = + const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } = useConnectionState({ nodeId, fieldName, kind: 'inputs' }); const isMissingInput = useMemo(() => { @@ -55,20 +51,6 @@ const InputField = ({ nodeId, fieldName }: Props) => { setIsHovered(false); }, []); - if (!fieldTemplate || !fieldInstance) { - return ( - - - - {t('nodes.unknownInput', { - name: fieldInstance?.label ?? fieldTemplate?.title ?? 
fieldName, - })} - - - - ); - } - if (fieldTemplate.input === 'connection' || isConnected) { return ( @@ -88,7 +70,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { handleType="target" isConnectionInProgress={isConnectionInProgress} isConnectionStartField={isConnectionStartField} - connectionError={connectionError} + validationResult={validationResult} /> ); @@ -126,7 +108,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { handleType="target" isConnectionInProgress={isConnectionInProgress} isConnectionStartField={isConnectionStartField} - connectionError={connectionError} + validationResult={validationResult} /> )} @@ -134,27 +116,3 @@ const InputField = ({ nodeId, fieldName }: Props) => { }; export default memo(InputField); - -type InputFieldWrapperProps = PropsWithChildren<{ - shouldDim: boolean; -}>; - -const InputFieldWrapper = memo(({ shouldDim, children }: InputFieldWrapperProps) => { - return ( - - {children} - - ); -}); - -InputFieldWrapper.displayName = 'InputFieldWrapper'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index b6e331c114..99937ceec4 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -1,3 +1,4 @@ +import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent'; import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance'; import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate'; import { @@ -23,6 +24,8 @@ import { isLoRAModelFieldInputTemplate, isMainModelFieldInputInstance, isMainModelFieldInputTemplate, + isModelIdentifierFieldInputInstance, + isModelIdentifierFieldInputTemplate, isSchedulerFieldInputInstance, isSchedulerFieldInputTemplate, isSDXLMainModelFieldInputInstance, @@ -95,6 +98,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return ; } + if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) { + return ; + } + if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) { return ; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldWrapper.tsx new file mode 100644 index 0000000000..8723538f85 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldWrapper.tsx @@ -0,0 +1,27 @@ +import { Flex } from '@invoke-ai/ui-library'; +import type { PropsWithChildren } from 'react'; +import { memo } from 'react'; + +type InputFieldWrapperProps = PropsWithChildren<{ + shouldDim: boolean; +}>; + +export const InputFieldWrapper = memo(({ shouldDim, children }: InputFieldWrapperProps) => { + return ( + + {children} + + ); +}); + +InputFieldWrapper.displayName = 'InputFieldWrapper'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InvocationFieldCheck.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InvocationFieldCheck.tsx new file mode 100644 index 
0000000000..f4b6be0cd6 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InvocationFieldCheck.tsx @@ -0,0 +1,59 @@ +import { Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppSelector } from 'app/store/storeHooks'; +import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { selectInvocationNode } from 'features/nodes/store/selectors'; +import type { PropsWithChildren } from 'react'; +import { memo, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; + +type Props = PropsWithChildren<{ + nodeId: string; + fieldName: string; +}>; + +export const InvocationInputFieldCheck = memo(({ nodeId, fieldName, children }: Props) => { + const { t } = useTranslation(); + const templates = useStore($templates); + const selector = useMemo( + () => + createSelector(selectNodesSlice, (nodesSlice) => { + const node = selectInvocationNode(nodesSlice, nodeId); + const instance = node.data.inputs[fieldName]; + const template = templates[node.data.type]; + const fieldTemplate = template?.inputs[fieldName]; + return { + name: instance?.label || fieldTemplate?.title || fieldName, + hasInstance: Boolean(instance), + hasTemplate: Boolean(fieldTemplate), + }; + }), + [fieldName, nodeId, templates] + ); + const { hasInstance, hasTemplate, name } = useAppSelector(selector); + + if (!hasTemplate || !hasInstance) { + return ( + + + + {t('nodes.unknownInput', { name })} + + + + ); + } + + return children; +}); + +InvocationInputFieldCheck.displayName = 'InvocationInputFieldCheck'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx index 0cd199f7a4..ef466b2882 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx @@ -3,6 +3,7 @@ import { CSS } from '@dnd-kit/utilities'; import { Flex, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay'; +import { InvocationInputFieldCheck } from 'features/nodes/components/flow/nodes/Invocation/fields/InvocationFieldCheck'; import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue'; import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; import { workflowExposedFieldRemoved } from 'features/nodes/store/workflowSlice'; @@ -20,7 +21,7 @@ type Props = { fieldName: string; }; -const LinearViewField = ({ nodeId, fieldName }: Props) => { +const LinearViewFieldInternal = ({ nodeId, fieldName }: Props) => { const dispatch = useAppDispatch(); const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName); const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId); @@ -99,4 +100,12 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => { ); }; +const LinearViewField = ({ nodeId, fieldName }: Props) => { + return ( + + + + ); +}; + export default memo(LinearViewField); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx 
b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx index f2d776a2da..94e8b62744 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx @@ -18,7 +18,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => { const { t } = useTranslation(); const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName); - const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = + const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } = useConnectionState({ nodeId, fieldName, kind: 'outputs' }); if (!fieldTemplate) { @@ -52,7 +52,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => { handleType="source" isConnectionInProgress={isConnectionInProgress} isConnectionStartField={isConnectionStartField} - connectionError={connectionError} + validationResult={validationResult} /> ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx new file mode 100644 index 0000000000..4019689978 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx @@ -0,0 +1,66 @@ +import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library'; +import { EMPTY_ARRAY } from 'app/store/constants'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { fieldModelIdentifierValueChanged } from 'features/nodes/store/nodesSlice'; +import type { ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate } from 'features/nodes/types/field'; +import { memo, useCallback, useMemo } from 'react'; +import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const ModelIdentifierFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const dispatch = useAppDispatch(); + const { data, isLoading } = useGetModelConfigsQuery(); + const _onChange = useCallback( + (value: AnyModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldModelIdentifierValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + const modelConfigs = useMemo(() => { + if (!data) { + return EMPTY_ARRAY; + } + + return modelConfigsAdapterSelectors.selectAll(data); + }, [data]); + + const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + modelConfigs, + onChange: _onChange, + isLoading, + selectedModel: field.value, + groupByType: true, + }); + + return ( + + + + + + ); +}; + +export default memo(ModelIdentifierFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx index 966809cb0e..76666af396 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx +++ 
b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx @@ -48,7 +48,7 @@ const NotesNode = (props: NodeProps) => { gap={1} > - + > diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx index 57426982ef..983aee1d48 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx @@ -1,14 +1,15 @@ import type { ChakraProps } from '@invoke-ai/ui-library'; import { Box, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks'; import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay'; import { useExecutionState } from 'features/nodes/hooks/useExecutionState'; import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; -import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice'; +import { nodesChanged } from 'features/nodes/store/nodesSlice'; import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from 'features/nodes/types/constants'; import { zNodeStatus } from 'features/nodes/types/invocation'; import type { MouseEvent, PropsWithChildren } from 'react'; import { memo, useCallback } from 'react'; +import type { NodeChange } from 'reactflow'; type NodeWrapperProps = PropsWithChildren & { nodeId: string; @@ -18,6 +19,7 @@ type NodeWrapperProps = PropsWithChildren & { const NodeWrapper = (props: NodeWrapperProps) => { const { nodeId, width, children, selected } = props; + const store = useAppStore(); const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId); const executionState = useExecutionState(nodeId); @@ -37,11 +39,20 @@ const NodeWrapper = (props: NodeWrapperProps) => { const handleClick = useCallback( (e: MouseEvent) => { if (!e.ctrlKey && !e.altKey && !e.metaKey && !e.shiftKey) { - dispatch(nodeExclusivelySelected(nodeId)); + const { nodes } = store.getState().nodes.present; + const nodeChanges: NodeChange[] = []; + nodes.forEach(({ id, selected }) => { + if (selected !== (id === nodeId)) { + nodeChanges.push({ type: 'select', id, selected: id === nodeId }); + } + }); + if (nodeChanges.length > 0) { + dispatch(nodesChanged(nodeChanges)); + } } onCloseGlobal(); }, - [dispatch, onCloseGlobal, nodeId] + [onCloseGlobal, store, dispatch, nodeId] ); return ( diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopPanel/ClearFlowButton.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopPanel/ClearFlowButton.tsx index 9a675c7214..7ceb991bd8 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopPanel/ClearFlowButton.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopPanel/ClearFlowButton.tsx @@ -1,8 +1,7 @@ import { ConfirmationAlertDialog, Flex, IconButton, Text, useDisclosure } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { memo, useCallback } from 'react'; import { useTranslation } from 
'react-i18next'; import { PiTrashSimpleFill } from 'react-icons/pi'; @@ -16,14 +15,11 @@ const ClearFlowButton = () => { const handleNewWorkflow = useCallback(() => { dispatch(nodeEditorReset()); - dispatch( - addToast( - makeToast({ - title: t('workflows.workflowCleared'), - status: 'success', - }) - ) - ); + toast({ + id: 'WORKFLOW_CLEARED', + title: t('workflows.workflowCleared'), + status: 'success', + }); onClose(); }, [dispatch, onClose, t]); diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/WorkflowName.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/WorkflowName.tsx index 14852945ab..b983e12e11 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/WorkflowName.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/WorkflowName.tsx @@ -7,8 +7,10 @@ import WorkflowInfoTooltipContent from './viewMode/WorkflowInfoTooltipContent'; import { WorkflowWarning } from './viewMode/WorkflowWarning'; export const WorkflowName = () => { - const { name, isTouched, mode } = useAppSelector((s) => s.workflow); const { t } = useTranslation(); + const name = useAppSelector((s) => s.workflow.name); + const isTouched = useAppSelector((s) => s.workflow.isTouched); + const mode = useAppSelector((s) => s.workflow.mode); return ( diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx index e707dd4f54..482de6693e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx @@ -1,6 +1,7 @@ import { Flex, FormLabel, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library'; import FieldTooltipContent from 'features/nodes/components/flow/nodes/Invocation/fields/FieldTooltipContent'; import InputFieldRenderer from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer'; +import { InvocationInputFieldCheck } from 'features/nodes/components/flow/nodes/Invocation/fields/InvocationFieldCheck'; import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel'; import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue'; import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle'; @@ -14,7 +15,7 @@ type Props = { fieldName: string; }; -const WorkflowField = ({ nodeId, fieldName }: Props) => { +const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => { const label = useFieldLabel(nodeId, fieldName); const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'inputs'); const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName); @@ -50,4 +51,12 @@ const WorkflowField = ({ nodeId, fieldName }: Props) => { ); }; +const WorkflowField = ({ nodeId, fieldName }: Props) => { + return ( + + + + ); +}; + export default memo(WorkflowField); diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLinearTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLinearTab.tsx index fa1767138e..9b0e5bb9d6 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLinearTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLinearTab.tsx @@ -6,10 +6,10 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import 
ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import DndSortable from 'features/dnd/components/DndSortable'; import type { DragEndEvent } from 'features/dnd/types'; -import LinearViewField from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField'; +import LinearViewFieldInternal from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField'; import { selectWorkflowSlice, workflowExposedFieldsReordered } from 'features/nodes/store/workflowSlice'; import type { FieldIdentifier } from 'features/nodes/types/field'; -import { memo, useCallback } from 'react'; +import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo'; @@ -40,16 +40,18 @@ const WorkflowLinearTab = () => { [dispatch, fields] ); + const items = useMemo(() => fields.map((field) => `${field.nodeId}.${field.fieldName}`), [fields]); + return ( - `${field.nodeId}.${field.fieldName}`)}> + {isLoading ? ( ) : fields.length ? ( fields.map(({ nodeId, fieldName }) => ( - + )) ) : ( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts deleted file mode 100644 index 3b7a1b74c1..0000000000 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ /dev/null @@ -1,26 +0,0 @@ -import { EMPTY_ARRAY } from 'app/store/constants'; -import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; -import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; -import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; -import { keys, map } from 'lodash-es'; -import { useMemo } from 'react'; - -export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => { - const template = useNodeTemplate(nodeId); - - const fieldNames = useMemo(() => { - const fields = map(template.inputs).filter((field) => { - return ( - (['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) && - keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) - ); - }); - const _fieldNames = getSortedFilteredFieldNames(fields); - if (_fieldNames.length === 0) { - return EMPTY_ARRAY; - } - return _fieldNames; - }, [template.inputs]); - - return fieldNames; -}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index df628ba5af..0bca73731e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -2,58 +2,69 @@ import { useStore } from '@nanostores/react'; import { useAppStore } from 'app/store/storeHooks'; import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; import { + $didUpdateEdge, + $edgePendingUpdate, $isAddNodePopoverOpen, - $isUpdatingEdge, $pendingConnection, $templates, - connectionMade, + edgesChanged, } from 'features/nodes/store/nodesSlice'; -import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; +import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import { useCallback, useMemo } from 'react'; -import type { 
OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow'; +import type { EdgeChange, OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow'; +import { useUpdateNodeInternals } from 'reactflow'; import { assert } from 'tsafe'; export const useConnection = () => { const store = useAppStore(); const templates = useStore($templates); + const updateNodeInternals = useUpdateNodeInternals(); const onConnectStart = useCallback( - (event, params) => { + (event, { nodeId, handleId, handleType }) => { + assert(nodeId && handleId && handleType, 'Invalid connection start event'); const nodes = store.getState().nodes.present.nodes; - const { nodeId, handleId, handleType } = params; - assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`); + const node = nodes.find((n) => n.id === nodeId); - assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`); + if (!node) { + return; + } + const template = templates[node.data.type]; - assert(template, `Template not found for node type: ${node.data.type}`); - const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId]; - assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`); - $pendingConnection.set({ - node, - template, - fieldTemplate, - }); + if (!template) { + return; + } + + const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs']; + const fieldTemplate = fieldTemplates[handleId]; + if (!fieldTemplate) { + return; + } + + $pendingConnection.set({ nodeId, handleId, handleType, fieldTemplate }); }, [store, templates] ); const onConnect = useCallback( (connection) => { const { dispatch } = store; - dispatch(connectionMade(connection)); + const newEdge = connectionToEdge(connection); + dispatch(edgesChanged([{ type: 'add', item: newEdge }])); + updateNodeInternals([newEdge.source, newEdge.target]); $pendingConnection.set(null); }, - [store] + [store, updateNodeInternals] ); const onConnectEnd = useCallback(() => { const { dispatch } = store; const pendingConnection = $pendingConnection.get(); - const isUpdatingEdge = $isUpdatingEdge.get(); + const edgePendingUpdate = $edgePendingUpdate.get(); const mouseOverNodeId = $mouseOverNode.get(); // If we are in the middle of an edge update, and the mouse isn't over a node, we should just bail so the edge // update logic can finish up - if (isUpdatingEdge && !mouseOverNodeId) { + if (edgePendingUpdate && !mouseOverNodeId) { $pendingConnection.set(null); return; } @@ -63,30 +74,41 @@ export const useConnection = () => { } const { nodes, edges } = store.getState().nodes.present; if (mouseOverNodeId) { - const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId); - if (!candidateNode) { - // The mouse is over a non-invocation node - bail - return; - } - const candidateTemplate = templates[candidateNode.data.type]; - assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`); + const { handleType } = pendingConnection; + const source = handleType === 'source' ? pendingConnection.nodeId : mouseOverNodeId; + const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null; + const target = handleType === 'target' ? pendingConnection.nodeId : mouseOverNodeId; + const targetHandle = handleType === 'target' ? 
pendingConnection.handleId : null; + const connection = getFirstValidConnection( - templates, + source, + sourceHandle, + target, + targetHandle, nodes, edges, - pendingConnection, - candidateNode, - candidateTemplate + templates, + edgePendingUpdate ); if (connection) { - dispatch(connectionMade(connection)); + const newEdge = connectionToEdge(connection); + const edgeChanges: EdgeChange[] = [{ type: 'add', item: newEdge }]; + + const nodesToUpdate = [newEdge.source, newEdge.target]; + if (edgePendingUpdate) { + $didUpdateEdge.set(true); + edgeChanges.push({ type: 'remove', id: edgePendingUpdate.id }); + nodesToUpdate.push(edgePendingUpdate.source, edgePendingUpdate.target); + } + dispatch(edgesChanged(edgeChanges)); + updateNodeInternals(nodesToUpdate); } $pendingConnection.set(null); } else { // The mouse is not over a node - we should open the add node popover $isAddNodePopoverOpen.set(true); } - }, [store, templates]); + }, [store, templates, updateNodeInternals]); const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]); return api; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts deleted file mode 100644 index d071ac76d2..0000000000 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts +++ /dev/null @@ -1,28 +0,0 @@ -import { EMPTY_ARRAY } from 'app/store/constants'; -import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; -import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; -import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; -import { keys, map } from 'lodash-es'; -import { useMemo } from 'react'; - -export const useConnectionInputFieldNames = (nodeId: string): string[] => { - const template = useNodeTemplate(nodeId); - const fieldNames = useMemo(() => { - // get the visible fields - const fields = map(template.inputs).filter( - (field) => - (field.input === 'connection' && !field.type.isCollectionOrScalar) || - !keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) - ); - - const _fieldNames = getSortedFilteredFieldNames(fields); - - if (_fieldNames.length === 0) { - return EMPTY_ARRAY; - } - - return _fieldNames; - }, [template.inputs]); - - return fieldNames; -}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index 728b492453..64bb72c54e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -1,12 +1,10 @@ import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; -import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector'; +import { $edgePendingUpdate, $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector'; import { useMemo } from 'react'; -import { useFieldType } from './useFieldType.ts'; - type UseConnectionStateProps = { nodeId: string; fieldName: string; @@ -16,7 +14,7 @@ type 
UseConnectionStateProps = { export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => { const pendingConnection = useStore($pendingConnection); const templates = useStore($templates); - const fieldType = useFieldType(nodeId, fieldName, kind); + const edgePendingUpdate = useStore($edgePendingUpdate); const selectIsConnected = useMemo( () => @@ -33,17 +31,9 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta [fieldName, kind, nodeId] ); - const selectConnectionError = useMemo( - () => - makeConnectionErrorSelector( - templates, - pendingConnection, - nodeId, - fieldName, - kind === 'inputs' ? 'target' : 'source', - fieldType - ), - [templates, pendingConnection, nodeId, fieldName, kind, fieldType] + const selectValidationResult = useMemo( + () => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'), + [templates, nodeId, fieldName, kind] ); const isConnected = useAppSelector(selectIsConnected); @@ -53,23 +43,23 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta return false; } return ( - pendingConnection.node.id === nodeId && - pendingConnection.fieldTemplate.name === fieldName && + pendingConnection.nodeId === nodeId && + pendingConnection.handleId === fieldName && pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind] ); }, [fieldName, kind, nodeId, pendingConnection]); - const connectionError = useAppSelector(selectConnectionError); + const validationResult = useAppSelector((s) => selectValidationResult(s, pendingConnection, edgePendingUpdate)); const shouldDim = useMemo( - () => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField), - [connectionError, isConnectionInProgress, isConnectionStartField] + () => Boolean(isConnectionInProgress && !validationResult.isValid && !isConnectionStartField), + [validationResult, isConnectionInProgress, isConnectionStartField] ); return { isConnected, isConnectionInProgress, isConnectionStartField, - connectionError, + validationResult, shouldDim, }; }; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts b/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts index 08def1514c..32db806cde 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts @@ -5,11 +5,13 @@ import { $copiedNodes, $cursorPos, $edgesToCopiedNodes, - selectionPasted, + edgesChanged, + nodesChanged, selectNodesSlice, } from 'features/nodes/store/nodesSlice'; import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition'; import { isEqual, uniqWith } from 'lodash-es'; +import type { EdgeChange, NodeChange } from 'reactflow'; import { v4 as uuidv4 } from 'uuid'; const copySelection = () => { @@ -26,7 +28,7 @@ const copySelection = () => { const pasteSelection = (withEdgesToCopiedNodes?: boolean) => { const { getState, dispatch } = getStore(); - const currentNodes = selectNodesSlice(getState()).nodes; + const { nodes, edges } = selectNodesSlice(getState()); const cursorPos = $cursorPos.get(); const copiedNodes = deepClone($copiedNodes.get()); @@ -46,7 +48,7 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => { const offsetY = cursorPos ? 
cursorPos.y - minY : 50; copiedNodes.forEach((node) => { - const { x, y } = findUnoccupiedPosition(currentNodes, node.position.x + offsetX, node.position.y + offsetY); + const { x, y } = findUnoccupiedPosition(nodes, node.position.x + offsetX, node.position.y + offsetY); node.position.x = x; node.position.y = y; // Pasted nodes are selected @@ -68,7 +70,48 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => { node.data.id = id; }); - dispatch(selectionPasted({ nodes: copiedNodes, edges: copiedEdges })); + const nodeChanges: NodeChange[] = []; + const edgeChanges: EdgeChange[] = []; + // Deselect existing nodes + nodes.forEach(({ id, selected }) => { + if (selected) { + nodeChanges.push({ + type: 'select', + id, + selected: false, + }); + } + }); + // Add new nodes + copiedNodes.forEach((n) => { + nodeChanges.push({ + type: 'add', + item: n, + }); + }); + // Deselect existing edges + edges.forEach(({ id, selected }) => { + if (selected) { + edgeChanges.push({ + type: 'select', + id, + selected: false, + }); + } + }); + // Add new edges + copiedEdges.forEach((e) => { + edgeChanges.push({ + type: 'add', + item: e, + }); + }); + if (nodeChanges.length > 0) { + dispatch(nodesChanged(nodeChanges)); + } + if (edgeChanges.length > 0) { + dispatch(edgesChanged(edgeChanges)); + } }; const api = { copySelection, pasteSelection }; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts index 4b70847ad1..729319e0dd 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts @@ -1,9 +1,14 @@ import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; import type { FieldInputTemplate } from 'features/nodes/types/field'; import { useMemo } from 'react'; +import { assert } from 'tsafe'; -export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate | null => { +export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate => { const template = useNodeTemplate(nodeId); - const fieldTemplate = useMemo(() => template.inputs[fieldName] ?? 
null, [fieldName, template.inputs]); + const fieldTemplate = useMemo(() => { + const _fieldTemplate = template.inputs[fieldName]; + assert(_fieldTemplate, `Field template for field ${fieldName} not found`); + return _fieldTemplate; + }, [fieldName, template.inputs]); return fieldTemplate; }; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldNames.ts new file mode 100644 index 0000000000..19849fb296 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldNames.ts @@ -0,0 +1,39 @@ +import { useNodeData } from 'features/nodes/hooks/useNodeData'; +import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; +import type { FieldInputTemplate } from 'features/nodes/types/field'; +import { isSingleOrCollection } from 'features/nodes/types/field'; +import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; +import { difference, filter, keys } from 'lodash-es'; +import { useMemo } from 'react'; + +const isConnectionInputField = (field: FieldInputTemplate) => { + return ( + (field.input === 'connection' && !isSingleOrCollection(field.type)) || + !keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) + ); +}; + +const isAnyOrDirectInputField = (field: FieldInputTemplate) => { + return ( + (['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) && + keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) + ); +}; + +export const useFieldNames = (nodeId: string) => { + const template = useNodeTemplate(nodeId); + const node = useNodeData(nodeId); + const fieldNames = useMemo(() => { + const instanceFields = keys(node.inputs); + const allTemplateFields = keys(template.inputs); + const missingFields = difference(instanceFields, allTemplateFields); + const connectionFields = filter(template.inputs, isConnectionInputField).map((f) => f.name); + const anyOrDirectFields = filter(template.inputs, isAnyOrDirectInputField).map((f) => f.name); + return { + missingFields, + connectionFields, + anyOrDirectFields, + }; + }, [node.inputs, template.inputs]); + return fieldNames; +}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts deleted file mode 100644 index 90c08a94aa..0000000000 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts +++ /dev/null @@ -1,9 +0,0 @@ -import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; -import type { FieldType } from 'features/nodes/types/field'; -import { useMemo } from 'react'; - -export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType => { - const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind); - const fieldType = useMemo(() => fieldTemplate.type, [fieldTemplate]); - return fieldType; -}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 00b4b40176..9a978b09a8 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -1,14 +1,10 @@ // TODO: enable this at some point import { useStore } from '@nanostores/react'; import { useAppSelector, useAppStore } from 'app/store/storeHooks'; -import { $templates } from 'features/nodes/store/nodesSlice'; -import { getIsGraphAcyclic } from 
'features/nodes/store/util/getIsGraphAcyclic'; -import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector'; -import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; -import type { InvocationNodeData } from 'features/nodes/types/invocation'; -import { isEqual } from 'lodash-es'; +import { $edgePendingUpdate, $templates } from 'features/nodes/store/nodesSlice'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; import { useCallback } from 'react'; -import type { Connection, Node } from 'reactflow'; +import type { Connection } from 'reactflow'; /** * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts` @@ -25,75 +21,21 @@ export const useIsValidConnection = () => { if (!(source && sourceHandle && target && targetHandle)) { return false; } + const edgePendingUpdate = $edgePendingUpdate.get(); + const { nodes, edges } = store.getState().nodes.present; - if (source === target) { - // Don't allow nodes to connect to themselves, even if validation is disabled - return false; - } + const validationResult = validateConnection( + { source, sourceHandle, target, targetHandle }, + nodes, + edges, + templates, + edgePendingUpdate, + shouldValidateGraph + ); - const state = store.getState(); - const { nodes, edges } = state.nodes.present; - - // Find the source and target nodes - const sourceNode = nodes.find((node) => node.id === source) as Node; - const targetNode = nodes.find((node) => node.id === target) as Node; - const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle]; - const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle]; - - // Conditional guards against undefined nodes/handles - if (!(sourceFieldTemplate && targetFieldTemplate)) { - return false; - } - - if (targetFieldTemplate.input === 'direct') { - return false; - } - - if (!shouldValidateGraph) { - // manual override! - return true; - } - - if ( - edges.find((edge) => { - edge.target === target && - edge.targetHandle === targetHandle && - edge.source === source && - edge.sourceHandle === sourceHandle; - }) - ) { - // We already have a connection from this source to this target - return false; - } - - if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') { - // Collect nodes shouldn't mix and match field types - const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); - if (collectItemType) { - return isEqual(sourceFieldTemplate.type, collectItemType); - } - } - - // Connection is invalid if target already has a connection - if ( - edges.find((edge) => { - return edge.target === target && edge.targetHandle === targetHandle; - }) && - // except CollectionItem inputs can have multiples - targetFieldTemplate.type.name !== 'CollectionItemField' - ) { - return false; - } - - // Must use the originalType here if it exists - if (!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { - return false; - } - - // Graphs much be acyclic (no loops!) 
- return getIsGraphAcyclic(source, target, nodes, edges); + return validationResult.isValid; }, - [shouldValidateGraph, templates, store] + [templates, shouldValidateGraph, store] ); return isValidConnection; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts index 31dcb9c466..56e77a39e8 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts @@ -1,14 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; export const useNodeLabel = (nodeId: string) => { const selector = useMemo( () => - createSelector(selectNodesSlice, (nodes) => { - return selectNodeData(nodes, nodeId)?.label ?? null; + createSelector(selectNodesSlice, (nodesSlice) => { + const node = nodesSlice.nodes.find((node) => node.id === nodeId); + return node?.data.label; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts index a63e0433aa..39ae617460 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts @@ -1,8 +1,24 @@ -import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; +import { useStore } from '@nanostores/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppSelector } from 'app/store/storeHooks'; +import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { useMemo } from 'react'; export const useNodeTemplateTitle = (nodeId: string): string | null => { - const template = useNodeTemplate(nodeId); - const title = useMemo(() => template.title, [template.title]); + const templates = useStore($templates); + const selector = useMemo( + () => + createSelector(selectNodesSlice, (nodesSlice) => { + const node = nodesSlice.nodes.find((node) => node.id === nodeId); + if (!isInvocationNode(node)) { + return null; + } + const template = templates[node.data.type]; + return template?.title ?? 
null; + }), + [nodeId, templates] + ); + const title = useAppSelector(selector); return title; }; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts b/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts index df4b742842..7f531c3dba 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts @@ -1,4 +1,4 @@ -import type { FieldType } from 'features/nodes/types/field'; +import { type FieldType, isCollection, isSingleOrCollection } from 'features/nodes/types/field'; import { useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -10,13 +10,13 @@ export const useFieldTypeName = (fieldType?: FieldType): string => { return ''; } const { name } = fieldType; - if (fieldType.isCollection) { + if (isCollection(fieldType)) { return t('nodes.collectionFieldType', { name }); } - if (fieldType.isCollectionOrScalar) { + if (isSingleOrCollection(fieldType)) { return t('nodes.collectionOrScalarFieldType', { name }); } - return name; + return t('nodes.singleFieldType', { name }); }, [fieldType, t]); return name; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 1f61c77e83..5ebc5de147 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -16,6 +16,7 @@ import type { IPAdapterModelFieldValue, LoRAModelFieldValue, MainModelFieldValue, + ModelIdentifierFieldValue, SchedulerFieldValue, SDXLRefinerModelFieldValue, StatefulFieldValue, @@ -35,6 +36,7 @@ import { zIPAdapterModelFieldValue, zLoRAModelFieldValue, zMainModelFieldValue, + zModelIdentifierFieldValue, zSchedulerFieldValue, zSDXLRefinerModelFieldValue, zStatefulFieldValue, @@ -45,13 +47,13 @@ import { import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; import { atom } from 'nanostores'; -import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow'; -import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow'; +import type { MouseEvent } from 'react'; +import type { Edge, EdgeChange, NodeChange, Viewport, XYPosition } from 'reactflow'; +import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow'; import type { UndoableOptions } from 'redux-undo'; import type { z } from 'zod'; import type { NodesState, PendingConnection, Templates } from './types'; -import { findUnoccupiedPosition } from './util/findUnoccupiedPosition'; const initialNodesState: NodesState = { _version: 1, @@ -90,44 +92,47 @@ export const nodesSlice = createSlice({ reducers: { nodesChanged: (state, action: PayloadAction) => { state.nodes = applyNodeChanges(action.payload, state.nodes); - }, - nodeReplaced: (state, action: PayloadAction<{ nodeId: string; node: Node }>) => { - const nodeIndex = state.nodes.findIndex((n) => n.id === action.payload.nodeId); - if (nodeIndex < 0) { - return; - } - state.nodes[nodeIndex] = action.payload.node; - }, - nodeAdded: (state, action: PayloadAction<{ node: AnyNode; cursorPos: XYPosition | null }>) => { - const { node, cursorPos } = action.payload; - const position = findUnoccupiedPosition( - state.nodes, - cursorPos?.x ?? 
node.position.x, - cursorPos?.y ?? node.position.y - ); - node.position = position; - node.selected = true; - - state.nodes = applyNodeChanges( - state.nodes.map((n) => ({ id: n.id, type: 'select', selected: false })), - state.nodes - ); - - state.edges = applyEdgeChanges( - state.edges.map((e) => ({ id: e.id, type: 'select', selected: false })), - state.edges - ); - - state.nodes.push(node); + // Remove edges that are no longer valid, due to a removed or otherwise changed node + const edgeChanges: EdgeChange[] = []; + state.edges.forEach((e) => { + const sourceExists = state.nodes.some((n) => n.id === e.source); + const targetExists = state.nodes.some((n) => n.id === e.target); + if (!(sourceExists && targetExists)) { + edgeChanges.push({ type: 'remove', id: e.id }); + } + }); + state.edges = applyEdgeChanges(edgeChanges, state.edges); }, edgesChanged: (state, action: PayloadAction) => { - state.edges = applyEdgeChanges(action.payload, state.edges); - }, - edgeAdded: (state, action: PayloadAction) => { - state.edges = addEdge(action.payload, state.edges); - }, - connectionMade: (state, action: PayloadAction) => { - state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges); + const changes: EdgeChange[] = []; + // We may need to massage the edge changes or otherwise handle them + action.payload.forEach((change) => { + if (change.type === 'remove' || change.type === 'select') { + const edge = state.edges.find((e) => e.id === change.id); + // If we deleted or selected a collapsed edge, we need to find its "hidden" edges and do the same to them + if (edge && edge.type === 'collapsed') { + const hiddenEdges = state.edges.filter((e) => e.source === edge.source && e.target === edge.target); + if (change.type === 'remove') { + hiddenEdges.forEach(({ id }) => { + changes.push({ type: 'remove', id }); + }); + } + if (change.type === 'select') { + hiddenEdges.forEach(({ id }) => { + changes.push({ type: 'select', id, selected: change.selected }); + }); + } + } + } + if (change.type === 'add') { + if (!change.item.type) { + // We must add the edge type! 
+ change.item.type = 'default'; + } + } + changes.push(change); + }); + state.edges = applyEdgeChanges(changes, state.edges); }, fieldLabelChanged: ( state, @@ -232,6 +237,7 @@ export const nodesSlice = createSlice({ type: 'collapsed', data: { count: 1 }, updatable: false, + selected: edge.selected, }); } } @@ -252,6 +258,7 @@ export const nodesSlice = createSlice({ type: 'collapsed', data: { count: 1 }, updatable: false, + selected: edge.selected, }); } } @@ -264,41 +271,13 @@ export const nodesSlice = createSlice({ } } }, - edgeDeleted: (state, action: PayloadAction) => { - state.edges = state.edges.filter((e) => e.id !== action.payload); - }, - edgesDeleted: (state, action: PayloadAction) => { - const edges = action.payload; - const collapsedEdges = edges.filter((e) => e.type === 'collapsed'); - - // if we delete a collapsed edge, we need to delete all collapsed edges between the same nodes - if (collapsedEdges.length) { - const edgeChanges: EdgeRemoveChange[] = []; - collapsedEdges.forEach((collapsedEdge) => { - state.edges.forEach((edge) => { - if (edge.source === collapsedEdge.source && edge.target === collapsedEdge.target) { - edgeChanges.push({ id: edge.id, type: 'remove' }); - } - }); - }); - state.edges = applyEdgeChanges(edgeChanges, state.edges); - } - }, - nodesDeleted: (state, action: PayloadAction) => { - action.payload.forEach((node) => { - if (!isInvocationNode(node)) { - return; - } - }); - }, nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => { const { nodeId, label } = action.payload; const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); const node = state.nodes?.[nodeIndex]; - if (!isInvocationNode(node)) { - return; + if (isInvocationNode(node) || isNotesNode(node)) { + node.data.label = label; } - node.data.label = label; }, nodeNotesChanged: (state, action: PayloadAction<{ nodeId: string; notes: string }>) => { const { nodeId, notes } = action.payload; @@ -309,17 +288,6 @@ export const nodesSlice = createSlice({ } node.data.notes = notes; }, - nodeExclusivelySelected: (state, action: PayloadAction) => { - const nodeId = action.payload; - state.nodes = applyNodeChanges( - state.nodes.map((n) => ({ - id: n.id, - type: 'select', - selected: n.id === nodeId ? 
true : false, - })), - state.nodes - ); - }, fieldValueReset: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zStatefulFieldValue); }, @@ -344,6 +312,9 @@ export const nodesSlice = createSlice({ fieldMainModelValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zMainModelFieldValue); }, + fieldModelIdentifierValueChanged: (state, action: FieldValueAction) => { + fieldValueReducer(state, action, zModelIdentifierFieldValue); + }, fieldRefinerModelValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zSDXLRefinerModelFieldValue); }, @@ -381,57 +352,6 @@ export const nodesSlice = createSlice({ state.nodes = []; state.edges = []; }, - selectedAll: (state) => { - state.nodes = applyNodeChanges( - state.nodes.map((n) => ({ id: n.id, type: 'select', selected: true })), - state.nodes - ); - state.edges = applyEdgeChanges( - state.edges.map((e) => ({ id: e.id, type: 'select', selected: true })), - state.edges - ); - }, - selectionPasted: (state, action: PayloadAction<{ nodes: AnyNode[]; edges: InvocationNodeEdge[] }>) => { - const { nodes, edges } = action.payload; - - const nodeChanges: NodeChange[] = []; - - // Deselect existing nodes - state.nodes.forEach((n) => { - nodeChanges.push({ - id: n.data.id, - type: 'select', - selected: false, - }); - }); - // Add new nodes - nodes.forEach((n) => { - nodeChanges.push({ - item: n, - type: 'add', - }); - }); - - const edgeChanges: EdgeChange[] = []; - // Deselect existing edges - state.edges.forEach((e) => { - edgeChanges.push({ - id: e.id, - type: 'select', - selected: false, - }); - }); - // Add new edges - edges.forEach((e) => { - edgeChanges.push({ - item: e, - type: 'add', - }); - }); - - state.nodes = applyNodeChanges(nodeChanges, state.nodes); - state.edges = applyEdgeChanges(edgeChanges, state.edges); - }, undo: (state) => state, redo: (state) => state, }, @@ -440,13 +360,13 @@ export const nodesSlice = createSlice({ const { nodes, edges } = action.payload; state.nodes = applyNodeChanges( nodes.map((node) => ({ - item: { ...node, ...SHARED_NODE_PROPERTIES }, type: 'add', + item: { ...node, ...SHARED_NODE_PROPERTIES }, })), [] ); state.edges = applyEdgeChanges( - edges.map((edge) => ({ item: edge, type: 'add' })), + edges.map((edge) => ({ type: 'add', item: edge })), [] ); }); @@ -454,10 +374,7 @@ export const nodesSlice = createSlice({ }); export const { - connectionMade, - edgeDeleted, edgesChanged, - edgesDeleted, fieldValueReset, fieldBoardValueChanged, fieldBooleanValueChanged, @@ -469,27 +386,21 @@ export const { fieldT2IAdapterModelValueChanged, fieldLabelChanged, fieldLoRAModelValueChanged, + fieldModelIdentifierValueChanged, fieldMainModelValueChanged, fieldNumberValueChanged, fieldRefinerModelValueChanged, fieldSchedulerValueChanged, fieldStringValueChanged, fieldVaeModelValueChanged, - nodeAdded, - nodeReplaced, nodeEditorReset, - nodeExclusivelySelected, nodeIsIntermediateChanged, nodeIsOpenChanged, nodeLabelChanged, nodeNotesChanged, nodesChanged, - nodesDeleted, nodeUseCacheChanged, notesNodeValueChanged, - selectedAll, - selectionPasted, - edgeAdded, undo, redo, } = nodesSlice.actions; @@ -500,7 +411,10 @@ export const $copiedNodes = atom([]); export const $copiedEdges = atom([]); export const $edgesToCopiedNodes = atom([]); export const $pendingConnection = atom(null); -export const $isUpdatingEdge = atom(false); +export const $edgePendingUpdate = atom(null); +export const $didUpdateEdge = atom(false); +export const $lastEdgeUpdateMouseEvent = 
atom(null); + export const $viewport = atom({ x: 0, y: 0, zoom: 1 }); export const $isAddNodePopoverOpen = atom(false); export const closeAddNodePopover = () => { @@ -528,17 +442,17 @@ export const nodesPersistConfig: PersistConfig = { persistDenylist: [], }; -const selectionMatcher = isAnyOf(selectedAll, selectionPasted, nodeExclusivelySelected); - const isSelectionAction = (action: UnknownAction) => { - if (selectionMatcher(action)) { - return true; - } if (nodesChanged.match(action)) { if (action.payload.every((change) => change.type === 'select')) { return true; } } + if (edgesChanged.match(action)) { + if (action.payload.every((change) => change.type === 'select')) { + return true; + } + } return false; }; @@ -574,10 +488,7 @@ export const nodesUndoableConfig: UndoableOptions = { // This is used for tracking `state.workflow.isTouched` export const isAnyNodeOrEdgeMutation = isAnyOf( - connectionMade, - edgeDeleted, edgesChanged, - edgesDeleted, fieldBoardValueChanged, fieldBooleanValueChanged, fieldColorValueChanged, @@ -594,15 +505,11 @@ export const isAnyNodeOrEdgeMutation = isAnyOf( fieldSchedulerValueChanged, fieldStringValueChanged, fieldVaeModelValueChanged, - nodeAdded, - nodeReplaced, + nodesChanged, nodeIsIntermediateChanged, nodeIsOpenChanged, nodeLabelChanged, nodeNotesChanged, - nodesDeleted, nodeUseCacheChanged, - notesNodeValueChanged, - selectionPasted, - edgeAdded + notesNodeValueChanged ); diff --git a/invokeai/frontend/web/src/features/nodes/store/selectors.ts b/invokeai/frontend/web/src/features/nodes/store/selectors.ts index 4739a77e1c..be8cfafa8b 100644 --- a/invokeai/frontend/web/src/features/nodes/store/selectors.ts +++ b/invokeai/frontend/web/src/features/nodes/store/selectors.ts @@ -4,7 +4,7 @@ import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/in import { isInvocationNode } from 'features/nodes/types/invocation'; import { assert } from 'tsafe'; -const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => { +export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => { const node = nodesSlice.nodes.find((node) => node.id === nodeId); assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`); return node; diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index 2f514bdb5b..6dcf70cfad 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -6,19 +6,20 @@ import type { } from 'features/nodes/types/field'; import type { AnyNode, - InvocationNode, InvocationNodeEdge, InvocationTemplate, NodeExecutionState, } from 'features/nodes/types/invocation'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; +import type { HandleType } from 'reactflow'; export type Templates = Record; export type NodeExecutionStates = Record; export type PendingConnection = { - node: InvocationNode; - template: InvocationTemplate; + nodeId: string; + handleId: string; + handleType: HandleType; fieldTemplate: FieldInputTemplate | FieldOutputTemplate; }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.test.ts new file mode 100644 index 0000000000..ae9d4f6742 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.test.ts @@ -0,0 +1,86 @@ +import type { FieldType } from 
'features/nodes/types/field'; +import { describe, expect, it } from 'vitest'; + +import { areTypesEqual } from './areTypesEqual'; + +describe(areTypesEqual.name, () => { + it('should handle equal source and target type', () => { + const sourceType: FieldType = { + name: 'IntegerField', + cardinality: 'SINGLE', + originalType: { + name: 'Foo', + cardinality: 'SINGLE', + }, + }; + const targetType: FieldType = { + name: 'IntegerField', + cardinality: 'SINGLE', + originalType: { + name: 'Bar', + cardinality: 'SINGLE', + }, + }; + expect(areTypesEqual(sourceType, targetType)).toBe(true); + }); + + it('should handle equal source type and original target type', () => { + const sourceType: FieldType = { + name: 'IntegerField', + cardinality: 'SINGLE', + originalType: { + name: 'Foo', + cardinality: 'SINGLE', + }, + }; + const targetType: FieldType = { + name: 'MainModelField', + cardinality: 'SINGLE', + originalType: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + }; + expect(areTypesEqual(sourceType, targetType)).toBe(true); + }); + + it('should handle equal original source type and target type', () => { + const sourceType: FieldType = { + name: 'MainModelField', + cardinality: 'SINGLE', + originalType: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + }; + const targetType: FieldType = { + name: 'IntegerField', + cardinality: 'SINGLE', + originalType: { + name: 'Bar', + cardinality: 'SINGLE', + }, + }; + expect(areTypesEqual(sourceType, targetType)).toBe(true); + }); + + it('should handle equal original source type and original target type', () => { + const sourceType: FieldType = { + name: 'MainModelField', + cardinality: 'SINGLE', + originalType: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + }; + const targetType: FieldType = { + name: 'LoRAModelField', + cardinality: 'SINGLE', + originalType: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + }; + expect(areTypesEqual(sourceType, targetType)).toBe(true); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.ts b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.ts new file mode 100644 index 0000000000..8502cb563c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.ts @@ -0,0 +1,29 @@ +import type { FieldType } from 'features/nodes/types/field'; +import { isEqual, omit } from 'lodash-es'; + +/** + * Checks if two types are equal. If the field types have original types, those are also compared. Any match is + * considered equal. For example, if the first type and original second type match, the types are considered equal. + * @param firstType The first type to compare. + * @param secondType The second type to compare. + * @returns True if the types are equal, false otherwise. + */ +export const areTypesEqual = (firstType: FieldType, secondType: FieldType) => { + const _firstType = 'originalType' in firstType ? omit(firstType, 'originalType') : firstType; + const _secondType = 'originalType' in secondType ? omit(secondType, 'originalType') : secondType; + const _originalFirstType = 'originalType' in firstType ? firstType.originalType : null; + const _originalSecondType = 'originalType' in secondType ? 
secondType.originalType : null; + if (isEqual(_firstType, _secondType)) { + return true; + } + if (_originalSecondType && isEqual(_firstType, _originalSecondType)) { + return true; + } + if (_originalFirstType && isEqual(_originalFirstType, _secondType)) { + return true; + } + if (_originalFirstType && _originalSecondType && isEqual(_originalFirstType, _originalSecondType)) { + return true; + } + return false; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts deleted file mode 100644 index 1f33c52371..0000000000 --- a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts +++ /dev/null @@ -1,105 +0,0 @@ -import type { PendingConnection, Templates } from 'features/nodes/store/types'; -import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector'; -import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; -import { differenceWith, isEqual, map } from 'lodash-es'; -import type { Connection } from 'reactflow'; -import { assert } from 'tsafe'; - -import { getIsGraphAcyclic } from './getIsGraphAcyclic'; -import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; - -export const getFirstValidConnection = ( - templates: Templates, - nodes: AnyNode[], - edges: InvocationNodeEdge[], - pendingConnection: PendingConnection, - candidateNode: InvocationNode, - candidateTemplate: InvocationTemplate -): Connection | null => { - if (pendingConnection.node.id === candidateNode.id) { - // Cannot connect to self - return null; - } - - const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 
'target' : 'source'; - - if (pendingFieldKind === 'source') { - // Connecting from a source to a target - if (!getIsGraphAcyclic(pendingConnection.node.id, candidateNode.id, nodes, edges)) { - return null; - } - if (candidateNode.data.type === 'collect') { - // Special handling for collect node - the `item` field takes any number of connections - return { - source: pendingConnection.node.id, - sourceHandle: pendingConnection.fieldTemplate.name, - target: candidateNode.id, - targetHandle: 'item', - }; - } - // Only one connection per target field is allowed - look for an unconnected target field - const candidateFields = map(candidateTemplate.inputs).filter((i) => i.input !== 'direct'); - const candidateConnectedFields = edges - .filter((edge) => edge.target === candidateNode.id) - .map((edge) => { - // Edges must always have a targetHandle, safe to assert here - assert(edge.targetHandle); - return edge.targetHandle; - }); - const candidateUnconnectedFields = differenceWith( - candidateFields, - candidateConnectedFields, - (field, connectedFieldName) => field.name === connectedFieldName - ); - const candidateField = candidateUnconnectedFields.find((field) => - validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type) - ); - if (candidateField) { - return { - source: pendingConnection.node.id, - sourceHandle: pendingConnection.fieldTemplate.name, - target: candidateNode.id, - targetHandle: candidateField.name, - }; - } - } else { - // Connecting from a target to a source - // Ensure we there is not already an edge to the target, except for collect nodes - const isCollect = pendingConnection.node.data.type === 'collect'; - const isTargetAlreadyConnected = edges.some( - (e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name - ); - if (!isCollect && isTargetAlreadyConnected) { - return null; - } - - if (!getIsGraphAcyclic(candidateNode.id, pendingConnection.node.id, nodes, edges)) { - return null; - } - - // Sources/outputs can have any number of edges, we can take the first matching output field - let candidateFields = map(candidateTemplate.outputs); - if (isCollect) { - // Narrow candidates to same field type as already is connected to the collect node - const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id); - if (collectItemType) { - candidateFields = candidateFields.filter((field) => isEqual(field.type, collectItemType)); - } - } - const candidateField = candidateFields.find((field) => { - const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type); - const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name); - return isValid && !isAlreadyConnected; - }); - if (candidateField) { - return { - source: candidateNode.id, - sourceHandle: candidateField.name, - target: pendingConnection.node.id, - targetHandle: pendingConnection.fieldTemplate.name, - }; - } - } - - return null; -}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts new file mode 100644 index 0000000000..be0b553d8b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts @@ -0,0 +1,44 @@ +import { deepClone } from 'common/util/deepClone'; +import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; +import { add, buildEdge, buildNode, collect, templates 
} from 'features/nodes/store/util/testUtils'; +import type { FieldType } from 'features/nodes/types/field'; +import { unset } from 'lodash-es'; +import { describe, expect, it } from 'vitest'; + +describe(getCollectItemType.name, () => { + it('should return the type of the items the collect node collects', () => { + const n1 = buildNode(add); + const n2 = buildNode(collect); + const e1 = buildEdge(n1.id, 'value', n2.id, 'item'); + const result = getCollectItemType(templates, [n1, n2], [e1], n2.id); + expect(result).toEqual({ name: 'IntegerField', cardinality: 'SINGLE' }); + }); + it('should return null if the collect node does not have any connections', () => { + const n1 = buildNode(collect); + const result = getCollectItemType(templates, [n1], [], n1.id); + expect(result).toBeNull(); + }); + it("should return null if the first edge to collect's node doesn't exist", () => { + const n1 = buildNode(collect); + const n2 = buildNode(add); + const e1 = buildEdge(n2.id, 'value', n1.id, 'item'); + const result = getCollectItemType(templates, [n1], [e1], n1.id); + expect(result).toBeNull(); + }); + it("should return null if the first edge to collect's node template doesn't exist", () => { + const n1 = buildNode(collect); + const n2 = buildNode(add); + const e1 = buildEdge(n2.id, 'value', n1.id, 'item'); + const result = getCollectItemType({ collect }, [n1, n2], [e1], n1.id); + expect(result).toBeNull(); + }); + it("should return null if the first edge to the collect's field template doesn't exist", () => { + const n1 = buildNode(collect); + const n2 = buildNode(add); + const addWithoutOutputValue = deepClone(add); + unset(addWithoutOutputValue, 'outputs.value'); + const e1 = buildEdge(n2.id, 'value', n1.id, 'item'); + const result = getCollectItemType({ add: addWithoutOutputValue, collect }, [n2, n1], [e1], n1.id); + expect(result).toBeNull(); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts new file mode 100644 index 0000000000..e6c117d91e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts @@ -0,0 +1,38 @@ +import type { Templates } from 'features/nodes/store/types'; +import type { FieldType } from 'features/nodes/types/field'; +import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; + +/** + * Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and + * field connected to the collector's `item` input. The field type of that field is returned, else null if there is no + * input field. 
+ * @param templates The current invocation templates + * @param nodes The current nodes + * @param edges The current edges + * @param nodeId The collect node's id + * @returns The type of the items the collect node collects, or null if there is no input field + */ +export const getCollectItemType = ( + templates: Templates, + nodes: AnyNode[], + edges: InvocationNodeEdge[], + nodeId: string +): FieldType | null => { + const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item'); + if (!firstEdgeToCollect?.sourceHandle) { + return null; + } + const node = nodes.find((n) => n.id === firstEdgeToCollect.source); + if (!node) { + return null; + } + const template = templates[node.data.type]; + if (!template) { + return null; + } + const fieldTemplate = template.outputs[firstEdgeToCollect.sourceHandle]; + if (!fieldTemplate) { + return null; + } + return fieldTemplate.type; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts new file mode 100644 index 0000000000..7d04ea8a58 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts @@ -0,0 +1,203 @@ +import { deepClone } from 'common/util/deepClone'; +import { + getFirstValidConnection, + getSourceCandidateFields, + getTargetCandidateFields, +} from 'features/nodes/store/util/getFirstValidConnection'; +import { add, buildEdge, buildNode, img_resize, templates } from 'features/nodes/store/util/testUtils'; +import { unset } from 'lodash-es'; +import { describe, expect, it } from 'vitest'; + +describe('getFirstValidConnection', () => { + it('should return null if the pending and candidate nodes are the same node', () => { + const n = buildNode(add); + expect(getFirstValidConnection(n.id, 'value', n.id, null, [n], [], templates, null)).toBe(null); + }); + + it('should return null if the sourceHandle and targetHandle are null', () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + expect(getFirstValidConnection(n1.id, null, n2.id, null, [n1, n2], [], templates, null)).toBe(null); + }); + + it('should return itself if both sourceHandle and targetHandle are provided', () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + expect(getFirstValidConnection(n1.id, 'value', n2.id, 'a', [n1, n2], [], templates, null)).toEqual({ + source: n1.id, + sourceHandle: 'value', + target: n2.id, + targetHandle: 'a', + }); + }); + + describe('connecting from a source to a target', () => { + const n1 = buildNode(img_resize); + const n2 = buildNode(img_resize); + + it('should return the first valid connection if there are no connected fields', () => { + const r = getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [], templates, null); + const c = { + source: n1.id, + sourceHandle: 'width', + target: n2.id, + targetHandle: 'width', + }; + expect(r).toEqual(c); + }); + it('should return the first valid connection if there is a connected field', () => { + const e = buildEdge(n1.id, 'height', n2.id, 'width'); + const r = getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [e], templates, null); + const c = { + source: n1.id, + sourceHandle: 'width', + target: n2.id, + targetHandle: 'height', + }; + expect(r).toEqual(c); + }); + it('should return the first valid connection if there is an edgePendingUpdate', () => { + const e = buildEdge(n1.id, 'width', n2.id, 'width'); + const r = 
getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [e], templates, e); + const c = { + source: n1.id, + sourceHandle: 'width', + target: n2.id, + targetHandle: 'width', + }; + expect(r).toEqual(c); + }); + it('should return null if the target has no valid fields', () => { + const e1 = buildEdge(n1.id, 'width', n2.id, 'width'); + const e2 = buildEdge(n1.id, 'height', n2.id, 'height'); + const n3 = buildNode(add); + const r = getFirstValidConnection(n3.id, 'value', n2.id, null, [n1, n2, n3], [e1, e2], templates, null); + expect(r).toEqual(null); + }); + }); + + describe('connecting from a target to a source', () => { + const n1 = buildNode(img_resize); + const n2 = buildNode(img_resize); + + it('should return the first valid connection if there are no connected fields', () => { + const r = getFirstValidConnection(n1.id, null, n2.id, 'width', [n1, n2], [], templates, null); + const c = { + source: n1.id, + sourceHandle: 'width', + target: n2.id, + targetHandle: 'width', + }; + expect(r).toEqual(c); + }); + it('should return the first valid connection if there is a connected field', () => { + const e = buildEdge(n1.id, 'height', n2.id, 'width'); + const r = getFirstValidConnection(n1.id, null, n2.id, 'height', [n1, n2], [e], templates, null); + const c = { + source: n1.id, + sourceHandle: 'width', + target: n2.id, + targetHandle: 'height', + }; + expect(r).toEqual(c); + }); + it('should return the first valid connection if there is an edgePendingUpdate', () => { + const e = buildEdge(n1.id, 'width', n2.id, 'width'); + const r = getFirstValidConnection(n1.id, null, n2.id, 'width', [n1, n2], [e], templates, e); + const c = { + source: n1.id, + sourceHandle: 'width', + target: n2.id, + targetHandle: 'width', + }; + expect(r).toEqual(c); + }); + it('should return null if the target has no valid fields', () => { + const e1 = buildEdge(n1.id, 'width', n2.id, 'width'); + const e2 = buildEdge(n1.id, 'height', n2.id, 'height'); + const n3 = buildNode(add); + const r = getFirstValidConnection(n3.id, null, n2.id, 'a', [n1, n2, n3], [e1, e2], templates, null); + expect(r).toEqual(null); + }); + }); +}); + +describe('getTargetCandidateFields', () => { + it('should return an empty array if the nodes cannot be found', () => { + const r = getTargetCandidateFields('missing', 'value', 'missing', [], [], templates, null); + expect(r).toEqual([]); + }); + it('should return an empty array if the templates cannot be found', () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + const nodes = [n1, n2]; + const r = getTargetCandidateFields(n1.id, 'value', n2.id, nodes, [], {}, null); + expect(r).toEqual([]); + }); + it('should return an empty array if the source field template cannot be found', () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + const nodes = [n1, n2]; + + const addWithoutOutputValue = deepClone(add); + unset(addWithoutOutputValue, 'outputs.value'); + + const r = getTargetCandidateFields(n1.id, 'value', n2.id, nodes, [], { add: addWithoutOutputValue }, null); + expect(r).toEqual([]); + }); + it('should return all valid target fields if there are no connected fields', () => { + const n1 = buildNode(img_resize); + const n2 = buildNode(img_resize); + const nodes = [n1, n2]; + const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, null); + expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]); + }); + it('should ignore the edgePendingUpdate if provided', () => { + const n1 = buildNode(img_resize); + const n2 =
buildNode(img_resize); + const nodes = [n1, n2]; + const edgePendingUpdate = buildEdge(n1.id, 'width', n2.id, 'width'); + const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, edgePendingUpdate); + expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]); + }); +}); + +describe('getSourceCandidateFields', () => { + it('should return an empty array if the nodes cannot be found', () => { + const r = getSourceCandidateFields('missing', 'value', 'missing', [], [], templates, null); + expect(r).toEqual([]); + }); + it('should return an empty array if the templates cannot be found', () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + const nodes = [n1, n2]; + const r = getSourceCandidateFields(n2.id, 'a', n1.id, nodes, [], {}, null); + expect(r).toEqual([]); + }); + it('should return an empty array if the target field template cannot be found', () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + const nodes = [n1, n2]; + + const addWithoutInputA = deepClone(add); + unset(addWithoutInputA, 'inputs.a'); + + const r = getSourceCandidateFields(n1.id, 'a', n2.id, nodes, [], { add: addWithoutInputA }, null); + expect(r).toEqual([]); + }); + it('should return all valid source fields if there are no connected fields', () => { + const n1 = buildNode(img_resize); + const n2 = buildNode(img_resize); + const nodes = [n1, n2]; + const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, null); + expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]); + }); + it('should ignore the edgePendingUpdate if provided', () => { + const n1 = buildNode(img_resize); + const n2 = buildNode(img_resize); + const nodes = [n1, n2]; + const edgePendingUpdate = buildEdge(n1.id, 'width', n2.id, 'width'); + const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, edgePendingUpdate); + expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts new file mode 100644 index 0000000000..adc51341d7 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts @@ -0,0 +1,149 @@ +import type { Templates } from 'features/nodes/store/types'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; +import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; +import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; +import { map } from 'lodash-es'; +import type { Connection, Edge } from 'reactflow'; + +/** + * + * @param source The source (node id) + * @param sourceHandle The source handle (field name), if any + * @param target The target (node id) + * @param targetHandle The target handle (field name), if any + * @param nodes The current nodes + * @param edges The current edges + * @param templates The current templates + * @param edgePendingUpdate The edge pending update, if any + * @returns The first valid connection, or null if there is none + */ +export const getFirstValidConnection = ( + source: string, + sourceHandle: string | null, + target: string, + targetHandle: string | null, + nodes: AnyNode[], + edges: InvocationNodeEdge[], + templates: Templates, + edgePendingUpdate: Edge | null +): Connection | null => { + if (source === target) { + return null; + } + + if (sourceHandle && targetHandle) { + return { source,
sourceHandle, target, targetHandle }; + } + + if (sourceHandle && !targetHandle) { + const candidates = getTargetCandidateFields( + source, + sourceHandle, + target, + nodes, + edges, + templates, + edgePendingUpdate + ); + + const firstCandidate = candidates[0]; + if (!firstCandidate) { + return null; + } + + return { source, sourceHandle, target, targetHandle: firstCandidate.name }; + } + + if (!sourceHandle && targetHandle) { + const candidates = getSourceCandidateFields( + target, + targetHandle, + source, + nodes, + edges, + templates, + edgePendingUpdate + ); + + const firstCandidate = candidates[0]; + if (!firstCandidate) { + return null; + } + + return { source, sourceHandle: firstCandidate.name, target, targetHandle }; + } + + return null; +}; + +export const getTargetCandidateFields = ( + source: string, + sourceHandle: string, + target: string, + nodes: AnyNode[], + edges: Edge[], + templates: Templates, + edgePendingUpdate: Edge | null +): FieldInputTemplate[] => { + const sourceNode = nodes.find((n) => n.id === source); + const targetNode = nodes.find((n) => n.id === target); + if (!sourceNode || !targetNode) { + return []; + } + + const sourceTemplate = templates[sourceNode.data.type]; + const targetTemplate = templates[targetNode.data.type]; + if (!sourceTemplate || !targetTemplate) { + return []; + } + + const sourceField = sourceTemplate.outputs[sourceHandle]; + + if (!sourceField) { + return []; + } + + const targetCandidateFields = map(targetTemplate.inputs).filter((field) => { + const c = { source, sourceHandle, target, targetHandle: field.name }; + const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true); + return r.isValid; + }); + + return targetCandidateFields; +}; + +export const getSourceCandidateFields = ( + target: string, + targetHandle: string, + source: string, + nodes: AnyNode[], + edges: Edge[], + templates: Templates, + edgePendingUpdate: Edge | null +): FieldOutputTemplate[] => { + const targetNode = nodes.find((n) => n.id === target); + const sourceNode = nodes.find((n) => n.id === source); + if (!sourceNode || !targetNode) { + return []; + } + + const sourceTemplate = templates[sourceNode.data.type]; + const targetTemplate = templates[targetNode.data.type]; + if (!sourceTemplate || !targetTemplate) { + return []; + } + + const targetField = targetTemplate.inputs[targetHandle]; + + if (!targetField) { + return []; + } + + const sourceCandidateFields = map(sourceTemplate.outputs).filter((field) => { + const c = { source, sourceHandle: field.name, target, targetHandle }; + const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true); + return r.isValid; + }); + + return sourceCandidateFields; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.test.ts new file mode 100644 index 0000000000..5b3a31de09 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.test.ts @@ -0,0 +1,22 @@ +import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; +import { add, buildEdge, buildNode } from 'features/nodes/store/util/testUtils'; +import { describe, expect, it } from 'vitest'; + +describe(getHasCycles.name, () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + const n3 = buildNode(add); + const nodes = [n1, n2, n3]; + + it('should return true if the graph WOULD have cycles after adding the edge', () => { + const edges = [buildEdge(n1.id, 'value', n2.id, 'a'), 
buildEdge(n2.id, 'value', n3.id, 'a')]; + const result = getHasCycles(n3.id, n1.id, nodes, edges); + expect(result).toBe(true); + }); + + it('should return false if the graph WOULD NOT have cycles after adding the edge', () => { + const edges = [buildEdge(n1.id, 'value', n2.id, 'a')]; + const result = getHasCycles(n2.id, n3.id, nodes, edges); + expect(result).toBe(false); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.ts b/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.ts new file mode 100644 index 0000000000..c1a4e51f0c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.ts @@ -0,0 +1,30 @@ +import graphlib from '@dagrejs/graphlib'; +import type { Edge, Node } from 'reactflow'; + +/** + * Check if adding an edge between the source and target nodes would create a cycle in the graph. + * @param source The source node id + * @param target The target node id + * @param nodes The graph's current nodes + * @param edges The graph's current edges + * @returns True if the graph would contain a cycle after adding the edge, false otherwise + */ + +export const getHasCycles = (source: string, target: string, nodes: Node[], edges: Edge[]) => { + // construct graphlib graph from editor state + const g = new graphlib.Graph(); + + nodes.forEach((n) => { + g.setNode(n.id); + }); + + edges.forEach((e) => { + g.setEdge(e.source, e.target); + }); + + // add the candidate edge + g.setEdge(source, target); + + // the graph has a cycle if it is no longer acyclic + return !graphlib.alg.isAcyclic(g); +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getIsGraphAcyclic.ts b/invokeai/frontend/web/src/features/nodes/store/util/getIsGraphAcyclic.ts deleted file mode 100644 index 2ef1c64c0e..0000000000 --- a/invokeai/frontend/web/src/features/nodes/store/util/getIsGraphAcyclic.ts +++ /dev/null @@ -1,21 +0,0 @@ -import graphlib from '@dagrejs/graphlib'; -import type { Edge, Node } from 'reactflow'; - -export const getIsGraphAcyclic = (source: string, target: string, nodes: Node[], edges: Edge[]) => { - // construct graphlib graph from editor state - const g = new graphlib.Graph(); - - nodes.forEach((n) => { - g.setNode(n.id); - }); - - edges.forEach((e) => { - g.setEdge(e.source, e.target); - }); - - // add the candidate edge - g.setEdge(source, target); - - // check if the graph is acyclic - return graphlib.alg.isAcyclic(g); -}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts new file mode 100644 index 0000000000..ec607c60c5 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts @@ -0,0 +1,67 @@ +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import type { RootState } from 'app/store/store'; +import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; +import { buildRejectResult, validateConnection } from 'features/nodes/store/util/validateConnection'; +import type { Edge, HandleType } from 'reactflow'; + +/** + * Creates a selector that validates a pending connection.
+ * + * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts` + * TODO: Figure out how to do this without duplicating all the logic + * + * @param templates The invocation templates + * @param nodeId The id of the node for which the selector is being created + * @param fieldName The name of the field for which the selector is being created + * @param handleType The type of the handle for which the selector is being created + * @returns + */ +export const makeConnectionErrorSelector = ( + templates: Templates, + nodeId: string, + fieldName: string, + handleType: HandleType +) => { + return createMemoizedSelector( + selectNodesSlice, + (state: RootState, pendingConnection: PendingConnection | null) => pendingConnection, + (state: RootState, pendingConnection: PendingConnection | null, edgePendingUpdate: Edge | null) => + edgePendingUpdate, + (nodesSlice: NodesState, pendingConnection: PendingConnection | null, edgePendingUpdate: Edge | null) => { + const { nodes, edges } = nodesSlice; + + if (!pendingConnection) { + return buildRejectResult('nodes.noConnectionInProgress'); + } + + if (handleType === pendingConnection.handleType) { + if (handleType === 'source') { + return buildRejectResult('nodes.cannotConnectOutputToOutput'); + } + return buildRejectResult('nodes.cannotConnectInputToInput'); + } + + // we have to figure out which is the target and which is the source + const source = handleType === 'source' ? nodeId : pendingConnection.nodeId; + const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.handleId; + const target = handleType === 'target' ? nodeId : pendingConnection.nodeId; + const targetHandle = handleType === 'target' ? fieldName : pendingConnection.handleId; + + const validationResult = validateConnection( + { + source, + sourceHandle, + target, + targetHandle, + }, + nodes, + edges, + templates, + edgePendingUpdate + ); + + return validationResult; + } + ); +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts deleted file mode 100644 index 90e75e0d87..0000000000 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ /dev/null @@ -1,147 +0,0 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import type { PendingConnection, Templates } from 'features/nodes/store/types'; -import type { FieldType } from 'features/nodes/types/field'; -import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; -import i18n from 'i18next'; -import { isEqual } from 'lodash-es'; -import type { HandleType } from 'reactflow'; -import { assert } from 'tsafe'; - -import { getIsGraphAcyclic } from './getIsGraphAcyclic'; -import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; - -export const getCollectItemType = ( - templates: Templates, - nodes: AnyNode[], - edges: InvocationNodeEdge[], - nodeId: string -): FieldType | null => { - const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item'); - if (!firstEdgeToCollect?.sourceHandle) { - return null; - } - const node = nodes.find((n) => n.id === firstEdgeToCollect.source); - if (!node) { - return null; - } - const template = templates[node.data.type]; - if (!template) { - return null; - } - const fieldType = 
template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null; - return fieldType; -}; - -/** - * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts` - * TODO: Figure out how to do this without duplicating all the logic - */ - -export const makeConnectionErrorSelector = ( - templates: Templates, - pendingConnection: PendingConnection | null, - nodeId: string, - fieldName: string, - handleType: HandleType, - fieldType?: FieldType | null -) => { - return createSelector(selectNodesSlice, (nodesSlice) => { - const { nodes, edges } = nodesSlice; - - if (!fieldType) { - return i18n.t('nodes.noFieldType'); - } - - if (!pendingConnection) { - return i18n.t('nodes.noConnectionInProgress'); - } - - const connectionNodeId = pendingConnection.node.id; - const connectionFieldName = pendingConnection.fieldTemplate.name; - const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; - const connectionStartFieldType = pendingConnection.fieldTemplate.type; - - if (!connectionHandleType || !connectionNodeId || !connectionFieldName) { - return i18n.t('nodes.noConnectionData'); - } - - const targetType = handleType === 'target' ? fieldType : connectionStartFieldType; - const sourceType = handleType === 'source' ? fieldType : connectionStartFieldType; - - if (nodeId === connectionNodeId) { - return i18n.t('nodes.cannotConnectToSelf'); - } - - if (handleType === connectionHandleType) { - if (handleType === 'source') { - return i18n.t('nodes.cannotConnectOutputToOutput'); - } - return i18n.t('nodes.cannotConnectInputToInput'); - } - - // we have to figure out which is the target and which is the source - const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId; - const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName; - const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId; - const sourceFieldName = handleType === 'source' ? 
fieldName : connectionFieldName; - - if ( - edges.find((edge) => { - edge.target === targetNodeId && - edge.targetHandle === targetFieldName && - edge.source === sourceNodeId && - edge.sourceHandle === sourceFieldName; - }) - ) { - // We already have a connection from this source to this target - return i18n.t('nodes.cannotDuplicateConnection'); - } - - const targetNode = nodes.find((node) => node.id === targetNodeId); - assert(targetNode, `Target node not found: ${targetNodeId}`); - const targetTemplate = templates[targetNode.data.type]; - assert(targetTemplate, `Target template not found: ${targetNode.data.type}`); - - if (targetTemplate.inputs[targetFieldName]?.input === 'direct') { - return i18n.t('nodes.cannotConnectToDirectInput'); - } - - if (targetNode.data.type === 'collect' && targetFieldName === 'item') { - // Collect nodes shouldn't mix and match field types - const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); - if (collectItemType) { - if (!isEqual(sourceType, collectItemType)) { - return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes'); - } - } - } - - if ( - edges.find((edge) => { - return edge.target === targetNodeId && edge.targetHandle === targetFieldName; - }) && - // except CollectionItem inputs can have multiples - targetType.name !== 'CollectionItemField' - ) { - return i18n.t('nodes.inputMayOnlyHaveOneConnection'); - } - - if (!validateSourceAndTargetTypes(sourceType, targetType)) { - return i18n.t('nodes.fieldTypesMustMatch'); - } - - const isGraphAcyclic = getIsGraphAcyclic( - connectionHandleType === 'source' ? connectionNodeId : nodeId, - connectionHandleType === 'source' ? nodeId : connectionNodeId, - nodes, - edges - ); - - if (!isGraphAcyclic) { - return i18n.t('nodes.connectionWouldCreateCycle'); - } - - return; - }); -}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts new file mode 100644 index 0000000000..89be7951a2 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts @@ -0,0 +1,32 @@ +import type { Connection, Edge } from 'reactflow'; +import { assert } from 'tsafe'; + +/** + * Gets the edge id for a connection + * Copied from: https://github.com/xyflow/xyflow/blob/v11/packages/core/src/utils/graph.ts#L44-L45 + * Requested for this to be exported in: https://github.com/xyflow/xyflow/issues/4290 + * @param connection The connection to get the id for + * @returns The edge id + */ +const getEdgeId = (connection: Connection): string => { + const { source, sourceHandle, target, targetHandle } = connection; + return `reactflow__edge-${source}${sourceHandle || ''}-${target}${targetHandle || ''}`; +}; + +/** + * Converts a connection to an edge + * @param connection The connection to convert to an edge + * @returns The edge + * @throws If the connection is invalid (e.g. 
missing source, sourcehandle, target, or targetHandle) + */ +export const connectionToEdge = (connection: Connection): Edge => { + const { source, sourceHandle, target, targetHandle } = connection; + assert(source && sourceHandle && target && targetHandle, 'Invalid connection'); + return { + source, + sourceHandle, + target, + targetHandle, + id: getEdgeId({ source, sourceHandle, target, targetHandle }), + }; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts new file mode 100644 index 0000000000..83988d55ea --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -0,0 +1,1562 @@ +import type { Templates } from 'features/nodes/store/types'; +import type { InvocationTemplate } from 'features/nodes/types/invocation'; +import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; +import type { OpenAPIV3_1 } from 'openapi-types'; +import type { Edge } from 'reactflow'; + +export const buildEdge = (source: string, sourceHandle: string, target: string, targetHandle: string): Edge => ({ + source, + sourceHandle, + target, + targetHandle, + type: 'default', + id: `reactflow__edge-${source}${sourceHandle}-${target}${targetHandle}`, +}); + +export const buildNode = (template: InvocationTemplate) => buildInvocationNode({ x: 0, y: 0 }, template); + +export const add: InvocationTemplate = { + title: 'Add Integers', + type: 'add', + version: '1.0.1', + tags: ['math', 'add'], + description: 'Adds two numbers', + outputType: 'integer_output', + inputs: { + a: { + name: 'a', + title: 'A', + required: false, + description: 'The first number', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + default: 0, + }, + b: { + name: 'b', + title: 'B', + required: false, + description: 'The second number', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + default: 0, + }, + }, + outputs: { + value: { + fieldKind: 'output', + name: 'value', + title: 'Value', + description: 'The output integer', + type: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + ui_hidden: false, + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +}; + +export const sub: InvocationTemplate = { + title: 'Subtract Integers', + type: 'sub', + version: '1.0.1', + tags: ['math', 'subtract'], + description: 'Subtracts two numbers', + outputType: 'integer_output', + inputs: { + a: { + name: 'a', + title: 'A', + required: false, + description: 'The first number', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + default: 0, + }, + b: { + name: 'b', + title: 'B', + required: false, + description: 'The second number', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + default: 0, + }, + }, + outputs: { + value: { + fieldKind: 'output', + name: 'value', + title: 'Value', + description: 'The output integer', + type: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + ui_hidden: false, + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +}; + +export const collect: InvocationTemplate = { + title: 'Collect', + type: 'collect', + version: '1.0.0', + tags: [], + description: 'Collects values into a collection', + outputType: 'collect_output', + inputs: { 
+ item: { + name: 'item', + title: 'Collection Item', + required: false, + description: 'The item to collect (all inputs must be of the same type)', + fieldKind: 'input', + input: 'connection', + ui_hidden: false, + ui_type: 'CollectionItemField', + type: { + name: 'CollectionItemField', + cardinality: 'SINGLE', + }, + }, + }, + outputs: { + collection: { + fieldKind: 'output', + name: 'collection', + title: 'Collection', + description: 'The collection of input items', + type: { + name: 'CollectionField', + cardinality: 'COLLECTION', + }, + ui_hidden: false, + ui_type: 'CollectionField', + }, + }, + useCache: true, + classification: 'stable', +}; + +const scheduler: InvocationTemplate = { + title: 'Scheduler', + type: 'scheduler', + version: '1.0.0', + tags: ['scheduler'], + description: 'Selects a scheduler.', + outputType: 'scheduler_output', + inputs: { + scheduler: { + name: 'scheduler', + title: 'Scheduler', + required: false, + description: 'Scheduler to use during inference', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + ui_type: 'SchedulerField', + type: { + name: 'SchedulerField', + cardinality: 'SINGLE', + + originalType: { + name: 'EnumField', + cardinality: 'SINGLE', + }, + }, + default: 'euler', + }, + }, + outputs: { + scheduler: { + fieldKind: 'output', + name: 'scheduler', + title: 'Scheduler', + description: 'Scheduler to use during inference', + type: { + name: 'SchedulerField', + cardinality: 'SINGLE', + + originalType: { + name: 'EnumField', + cardinality: 'SINGLE', + }, + }, + ui_hidden: false, + ui_type: 'SchedulerField', + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +}; + +export const main_model_loader: InvocationTemplate = { + title: 'Main Model', + type: 'main_model_loader', + version: '1.0.2', + tags: ['model'], + description: 'Loads a main model, outputting its submodels.', + outputType: 'model_loader_output', + inputs: { + model: { + name: 'model', + title: 'Model', + required: true, + description: 'Main model (UNet, VAE, CLIP) to load', + fieldKind: 'input', + input: 'direct', + ui_hidden: false, + ui_type: 'MainModelField', + type: { + name: 'MainModelField', + cardinality: 'SINGLE', + + originalType: { + name: 'ModelIdentifierField', + cardinality: 'SINGLE', + }, + }, + }, + }, + outputs: { + vae: { + fieldKind: 'output', + name: 'vae', + title: 'VAE', + description: 'VAE', + type: { + name: 'VAEField', + cardinality: 'SINGLE', + }, + ui_hidden: false, + }, + clip: { + fieldKind: 'output', + name: 'clip', + title: 'CLIP', + description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count', + type: { + name: 'CLIPField', + cardinality: 'SINGLE', + }, + ui_hidden: false, + }, + unet: { + fieldKind: 'output', + name: 'unet', + title: 'UNet', + description: 'UNet (scheduler, LoRAs)', + type: { + name: 'UNetField', + cardinality: 'SINGLE', + }, + ui_hidden: false, + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +}; + +export const img_resize: InvocationTemplate = { + title: 'Resize Image', + type: 'img_resize', + version: '1.2.2', + tags: ['image', 'resize'], + description: 'Resizes an image to specific dimensions', + outputType: 'image_output', + inputs: { + board: { + name: 'board', + title: 'Board', + required: false, + description: 'The board to save the image to', + fieldKind: 'input', + input: 'direct', + ui_hidden: false, + type: { + name: 'BoardField', + cardinality: 'SINGLE', + }, + }, + metadata: { + name: 'metadata', + title: 'Metadata', + required: false, + 
description: 'Optional metadata to be saved with the image', + fieldKind: 'input', + input: 'connection', + ui_hidden: false, + type: { + name: 'MetadataField', + cardinality: 'SINGLE', + }, + }, + image: { + name: 'image', + title: 'Image', + required: true, + description: 'The image to resize', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'ImageField', + cardinality: 'SINGLE', + }, + }, + width: { + name: 'width', + title: 'Width', + required: false, + description: 'The width to resize to (px)', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + default: 512, + exclusiveMinimum: 0, + }, + height: { + name: 'height', + title: 'Height', + required: false, + description: 'The height to resize to (px)', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + default: 512, + exclusiveMinimum: 0, + }, + resample_mode: { + name: 'resample_mode', + title: 'Resample Mode', + required: false, + description: 'The resampling mode', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'EnumField', + cardinality: 'SINGLE', + }, + options: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'], + default: 'bicubic', + }, + }, + outputs: { + image: { + fieldKind: 'output', + name: 'image', + title: 'Image', + description: 'The output image', + type: { + name: 'ImageField', + cardinality: 'SINGLE', + }, + ui_hidden: false, + }, + width: { + fieldKind: 'output', + name: 'width', + title: 'Width', + description: 'The width of the image in pixels', + type: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + ui_hidden: false, + }, + height: { + fieldKind: 'output', + name: 'height', + title: 'Height', + description: 'The height of the image in pixels', + type: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + ui_hidden: false, + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +}; + +const iterate: InvocationTemplate = { + title: 'Iterate', + type: 'iterate', + version: '1.1.0', + tags: [], + description: 'Iterates over a list of items', + outputType: 'iterate_output', + inputs: { + collection: { + name: 'collection', + title: 'Collection', + required: false, + description: 'The list of items to iterate over', + fieldKind: 'input', + input: 'connection', + ui_hidden: false, + ui_type: 'CollectionField', + type: { + name: 'CollectionField', + cardinality: 'COLLECTION', + }, + }, + }, + outputs: { + item: { + fieldKind: 'output', + name: 'item', + title: 'Collection Item', + description: 'The item being iterated over', + type: { + name: 'CollectionItemField', + cardinality: 'SINGLE', + }, + ui_hidden: false, + ui_type: 'CollectionItemField', + }, + index: { + fieldKind: 'output', + name: 'index', + title: 'Index', + description: 'The index of the item', + type: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + ui_hidden: false, + }, + total: { + fieldKind: 'output', + name: 'total', + title: 'Total', + description: 'The total number of items', + type: { + name: 'IntegerField', + cardinality: 'SINGLE', + }, + ui_hidden: false, + }, + }, + useCache: true, + classification: 'stable', +}; + +export const templates: Templates = { + add, + sub, + collect, + iterate, + scheduler, + main_model_loader, + img_resize, +}; + +export const schema = { + openapi: '3.1.0', + info: { + title: 'Invoke - Community Edition', + description: 'An API for invoking AI image 
operations', + version: '1.0.0', + }, + components: { + schemas: { + AddInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + a: { + type: 'integer', + title: 'A', + description: 'The first number', + default: 0, + field_kind: 'input', + input: 'any', + orig_default: 0, + orig_required: false, + ui_hidden: false, + }, + b: { + type: 'integer', + title: 'B', + description: 'The second number', + default: 0, + field_kind: 'input', + input: 'any', + orig_default: 0, + orig_required: false, + ui_hidden: false, + }, + type: { + type: 'string', + enum: ['add'], + const: 'add', + title: 'type', + default: 'add', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'Add Integers', + description: 'Adds two numbers', + category: 'math', + classification: 'stable', + node_pack: 'invokeai', + tags: ['math', 'add'], + version: '1.0.1', + output: { + $ref: '#/components/schemas/IntegerOutput', + }, + class: 'invocation', + }, + IntegerOutput: { + description: 'Base class for nodes that output a single integer', + properties: { + value: { + description: 'The output integer', + field_kind: 'output', + title: 'Value', + type: 'integer', + ui_hidden: false, + }, + type: { + const: 'integer_output', + default: 'integer_output', + enum: ['integer_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['value', 'type', 'type'], + title: 'IntegerOutput', + type: 'object', + class: 'output', + }, + SchedulerInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. 
Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + scheduler: { + type: 'string', + enum: [ + 'ddim', + 'ddpm', + 'deis', + 'lms', + 'lms_k', + 'pndm', + 'heun', + 'heun_k', + 'euler', + 'euler_k', + 'euler_a', + 'kdpm_2', + 'kdpm_2_a', + 'dpmpp_2s', + 'dpmpp_2s_k', + 'dpmpp_2m', + 'dpmpp_2m_k', + 'dpmpp_2m_sde', + 'dpmpp_2m_sde_k', + 'dpmpp_sde', + 'dpmpp_sde_k', + 'unipc', + 'lcm', + 'tcd', + ], + title: 'Scheduler', + description: 'Scheduler to use during inference', + default: 'euler', + field_kind: 'input', + input: 'any', + orig_default: 'euler', + orig_required: false, + ui_hidden: false, + ui_type: 'SchedulerField', + }, + type: { + type: 'string', + enum: ['scheduler'], + const: 'scheduler', + title: 'type', + default: 'scheduler', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'Scheduler', + description: 'Selects a scheduler.', + category: 'latents', + classification: 'stable', + node_pack: 'invokeai', + tags: ['scheduler'], + version: '1.0.0', + output: { + $ref: '#/components/schemas/SchedulerOutput', + }, + class: 'invocation', + }, + SchedulerOutput: { + properties: { + scheduler: { + description: 'Scheduler to use during inference', + enum: [ + 'ddim', + 'ddpm', + 'deis', + 'lms', + 'lms_k', + 'pndm', + 'heun', + 'heun_k', + 'euler', + 'euler_k', + 'euler_a', + 'kdpm_2', + 'kdpm_2_a', + 'dpmpp_2s', + 'dpmpp_2s_k', + 'dpmpp_2m', + 'dpmpp_2m_k', + 'dpmpp_2m_sde', + 'dpmpp_2m_sde_k', + 'dpmpp_sde', + 'dpmpp_sde_k', + 'unipc', + 'lcm', + 'tcd', + ], + field_kind: 'output', + title: 'Scheduler', + type: 'string', + ui_hidden: false, + ui_type: 'SchedulerField', + }, + type: { + const: 'scheduler_output', + default: 'scheduler_output', + enum: ['scheduler_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['scheduler', 'type', 'type'], + title: 'SchedulerOutput', + type: 'object', + class: 'output', + }, + MainModelLoaderInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. 
Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + model: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Main model (UNet, VAE, CLIP) to load', + field_kind: 'input', + input: 'direct', + orig_required: true, + ui_hidden: false, + ui_type: 'MainModelField', + }, + type: { + type: 'string', + enum: ['main_model_loader'], + const: 'main_model_loader', + title: 'type', + default: 'main_model_loader', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['model', 'type', 'id'], + title: 'Main Model', + description: 'Loads a main model, outputting its submodels.', + category: 'model', + classification: 'stable', + node_pack: 'invokeai', + tags: ['model'], + version: '1.0.2', + output: { + $ref: '#/components/schemas/ModelLoaderOutput', + }, + class: 'invocation', + }, + ModelIdentifierField: { + properties: { + key: { + description: "The model's unique key", + title: 'Key', + type: 'string', + }, + hash: { + description: "The model's BLAKE3 hash", + title: 'Hash', + type: 'string', + }, + name: { + description: "The model's name", + title: 'Name', + type: 'string', + }, + base: { + allOf: [ + { + $ref: '#/components/schemas/BaseModelType', + }, + ], + description: "The model's base model type", + }, + type: { + allOf: [ + { + $ref: '#/components/schemas/ModelType', + }, + ], + description: "The model's type", + }, + submodel_type: { + anyOf: [ + { + $ref: '#/components/schemas/SubModelType', + }, + { + type: 'null', + }, + ], + default: null, + description: 'The submodel to load, if this is a main model', + }, + }, + required: ['key', 'hash', 'name', 'base', 'type'], + title: 'ModelIdentifierField', + type: 'object', + }, + BaseModelType: { + description: 'Base model type.', + enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], + title: 'BaseModelType', + type: 'string', + }, + ModelType: { + description: 'Model type.', + enum: ['onnx', 'main', 'vae', 'lora', 'controlnet', 'embedding', 'ip_adapter', 'clip_vision', 't2i_adapter'], + title: 'ModelType', + type: 'string', + }, + SubModelType: { + description: 'Submodel type.', + enum: [ + 'unet', + 'text_encoder', + 'text_encoder_2', + 'tokenizer', + 'tokenizer_2', + 'vae', + 'vae_decoder', + 'vae_encoder', + 'scheduler', + 'safety_checker', + ], + title: 'SubModelType', + type: 'string', + }, + ModelLoaderOutput: { + description: 'Model loader output', + properties: { + vae: { + allOf: [ + { + $ref: '#/components/schemas/VAEField', + }, + ], + description: 'VAE', + field_kind: 'output', + title: 'VAE', + ui_hidden: false, + }, + type: { + const: 'model_loader_output', + default: 'model_loader_output', + enum: ['model_loader_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + clip: { + allOf: [ + { + $ref: '#/components/schemas/CLIPField', + }, + ], + description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count', + field_kind: 'output', + title: 'CLIP', + ui_hidden: false, + }, + unet: { + allOf: [ + { + $ref: '#/components/schemas/UNetField', + }, + ], + description: 'UNet (scheduler, LoRAs)', + field_kind: 'output', 
+ title: 'UNet', + ui_hidden: false, + }, + }, + required: ['vae', 'type', 'clip', 'unet', 'type'], + title: 'ModelLoaderOutput', + type: 'object', + class: 'output', + }, + UNetField: { + properties: { + unet: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load unet submodel', + }, + scheduler: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load scheduler submodel', + }, + loras: { + description: 'LoRAs to apply on model loading', + items: { + $ref: '#/components/schemas/LoRAField', + }, + title: 'Loras', + type: 'array', + }, + seamless_axes: { + description: 'Axes("x" and "y") to which apply seamless', + items: { + type: 'string', + }, + title: 'Seamless Axes', + type: 'array', + }, + freeu_config: { + anyOf: [ + { + $ref: '#/components/schemas/FreeUConfig', + }, + { + type: 'null', + }, + ], + default: null, + description: 'FreeU configuration', + }, + }, + required: ['unet', 'scheduler', 'loras'], + title: 'UNetField', + type: 'object', + class: 'output', + }, + LoRAField: { + properties: { + lora: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load lora model', + }, + weight: { + description: 'Weight to apply to lora model', + title: 'Weight', + type: 'number', + }, + }, + required: ['lora', 'weight'], + title: 'LoRAField', + type: 'object', + class: 'output', + }, + FreeUConfig: { + description: + 'Configuration for the FreeU hyperparameters.\n- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu\n- https://github.com/ChenyangSi/FreeU', + properties: { + s1: { + description: + 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.', + maximum: 3, + minimum: -1, + title: 'S1', + type: 'number', + }, + s2: { + description: + 'Scaling factor for stage 2 to attenuate the contributions of the skip features. 
This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.', + maximum: 3, + minimum: -1, + title: 'S2', + type: 'number', + }, + b1: { + description: 'Scaling factor for stage 1 to amplify the contributions of backbone features.', + maximum: 3, + minimum: -1, + title: 'B1', + type: 'number', + }, + b2: { + description: 'Scaling factor for stage 2 to amplify the contributions of backbone features.', + maximum: 3, + minimum: -1, + title: 'B2', + type: 'number', + }, + }, + required: ['s1', 's2', 'b1', 'b2'], + title: 'FreeUConfig', + type: 'object', + class: 'output', + }, + VAEField: { + properties: { + vae: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load vae submodel', + }, + seamless_axes: { + description: 'Axes("x" and "y") to which apply seamless', + items: { + type: 'string', + }, + title: 'Seamless Axes', + type: 'array', + }, + }, + required: ['vae'], + title: 'VAEField', + type: 'object', + class: 'output', + }, + CLIPField: { + properties: { + tokenizer: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load tokenizer submodel', + }, + text_encoder: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load text_encoder submodel', + }, + skipped_layers: { + description: 'Number of skipped layers in text_encoder', + title: 'Skipped Layers', + type: 'integer', + }, + loras: { + description: 'LoRAs to apply on model loading', + items: { + $ref: '#/components/schemas/LoRAField', + }, + title: 'Loras', + type: 'array', + }, + }, + required: ['tokenizer', 'text_encoder', 'skipped_layers', 'loras'], + title: 'CLIPField', + type: 'object', + class: 'output', + }, + CollectInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. 
Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + item: { + anyOf: [ + {}, + { + type: 'null', + }, + ], + title: 'Collection Item', + description: 'The item to collect (all inputs must be of the same type)', + field_kind: 'input', + input: 'connection', + orig_required: false, + ui_hidden: false, + ui_type: 'CollectionItemField', + }, + collection: { + items: {}, + type: 'array', + title: 'Collection', + description: 'The collection, will be provided on execution', + default: [], + field_kind: 'input', + input: 'any', + orig_default: [], + orig_required: false, + ui_hidden: true, + }, + type: { + type: 'string', + enum: ['collect'], + const: 'collect', + title: 'type', + default: 'collect', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'CollectInvocation', + description: 'Collects values into a collection', + classification: 'stable', + version: '1.0.0', + output: { + $ref: '#/components/schemas/CollectInvocationOutput', + }, + class: 'invocation', + }, + CollectInvocationOutput: { + properties: { + collection: { + description: 'The collection of input items', + field_kind: 'output', + items: {}, + title: 'Collection', + type: 'array', + ui_hidden: false, + ui_type: 'CollectionField', + }, + type: { + const: 'collect_output', + default: 'collect_output', + enum: ['collect_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['collection', 'type', 'type'], + title: 'CollectInvocationOutput', + type: 'object', + class: 'output', + }, + SubtractInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. 
Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + a: { + type: 'integer', + title: 'A', + description: 'The first number', + default: 0, + field_kind: 'input', + input: 'any', + orig_default: 0, + orig_required: false, + ui_hidden: false, + }, + b: { + type: 'integer', + title: 'B', + description: 'The second number', + default: 0, + field_kind: 'input', + input: 'any', + orig_default: 0, + orig_required: false, + ui_hidden: false, + }, + type: { + type: 'string', + enum: ['sub'], + const: 'sub', + title: 'type', + default: 'sub', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'Subtract Integers', + description: 'Subtracts two numbers', + category: 'math', + classification: 'stable', + node_pack: 'invokeai', + tags: ['math', 'subtract'], + version: '1.0.1', + output: { + $ref: '#/components/schemas/IntegerOutput', + }, + class: 'invocation', + }, + ImageResizeInvocation: { + properties: { + board: { + anyOf: [ + { + $ref: '#/components/schemas/BoardField', + }, + { + type: 'null', + }, + ], + description: 'The board to save the image to', + field_kind: 'internal', + input: 'direct', + orig_required: false, + ui_hidden: false, + }, + metadata: { + anyOf: [ + { + $ref: '#/components/schemas/MetadataField', + }, + { + type: 'null', + }, + ], + description: 'Optional metadata to be saved with the image', + field_kind: 'internal', + input: 'connection', + orig_required: false, + ui_hidden: false, + }, + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. 
Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + image: { + allOf: [ + { + $ref: '#/components/schemas/ImageField', + }, + ], + description: 'The image to resize', + field_kind: 'input', + input: 'any', + orig_required: true, + ui_hidden: false, + }, + width: { + type: 'integer', + exclusiveMinimum: 0, + title: 'Width', + description: 'The width to resize to (px)', + default: 512, + field_kind: 'input', + input: 'any', + orig_default: 512, + orig_required: false, + ui_hidden: false, + }, + height: { + type: 'integer', + exclusiveMinimum: 0, + title: 'Height', + description: 'The height to resize to (px)', + default: 512, + field_kind: 'input', + input: 'any', + orig_default: 512, + orig_required: false, + ui_hidden: false, + }, + resample_mode: { + type: 'string', + enum: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'], + title: 'Resample Mode', + description: 'The resampling mode', + default: 'bicubic', + field_kind: 'input', + input: 'any', + orig_default: 'bicubic', + orig_required: false, + ui_hidden: false, + }, + type: { + type: 'string', + enum: ['img_resize'], + const: 'img_resize', + title: 'type', + default: 'img_resize', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'Resize Image', + description: 'Resizes an image to specific dimensions', + category: 'image', + classification: 'stable', + node_pack: 'invokeai', + tags: ['image', 'resize'], + version: '1.2.2', + output: { + $ref: '#/components/schemas/ImageOutput', + }, + class: 'invocation', + }, + ImageField: { + description: 'An image primitive field', + properties: { + image_name: { + description: 'The name of the image', + title: 'Image Name', + type: 'string', + }, + }, + required: ['image_name'], + title: 'ImageField', + type: 'object', + class: 'output', + }, + ImageOutput: { + description: 'Base class for nodes that output a single image', + properties: { + image: { + allOf: [ + { + $ref: '#/components/schemas/ImageField', + }, + ], + description: 'The output image', + field_kind: 'output', + ui_hidden: false, + }, + width: { + description: 'The width of the image in pixels', + field_kind: 'output', + title: 'Width', + type: 'integer', + ui_hidden: false, + }, + height: { + description: 'The height of the image in pixels', + field_kind: 'output', + title: 'Height', + type: 'integer', + ui_hidden: false, + }, + type: { + const: 'image_output', + default: 'image_output', + enum: ['image_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['image', 'width', 'height', 'type', 'type'], + title: 'ImageOutput', + type: 'object', + class: 'output', + }, + MetadataField: { + description: + 'Pydantic model for metadata with custom root of type dict[str, Any].\nMetadata is stored without a strict schema.', + title: 'MetadataField', + type: 'object', + class: 'output', + }, + BoardField: { + properties: { + board_id: { + type: 'string', + title: 'Board Id', + description: 'The id of the board', + }, + }, + type: 'object', + required: ['board_id'], + title: 'BoardField', + description: 'A board primitive 
field', + }, + IterateInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + collection: { + items: {}, + type: 'array', + title: 'Collection', + description: 'The list of items to iterate over', + default: [], + field_kind: 'input', + input: 'any', + orig_default: [], + orig_required: false, + ui_hidden: false, + ui_type: 'CollectionField', + }, + index: { + type: 'integer', + title: 'Index', + description: 'The index, will be provided on executed iterators', + default: 0, + field_kind: 'input', + input: 'any', + orig_default: 0, + orig_required: false, + ui_hidden: true, + }, + type: { + type: 'string', + enum: ['iterate'], + const: 'iterate', + title: 'type', + default: 'iterate', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'IterateInvocation', + description: 'Iterates over a list of items', + classification: 'stable', + version: '1.1.0', + output: { + $ref: '#/components/schemas/IterateInvocationOutput', + }, + class: 'invocation', + }, + IterateInvocationOutput: { + description: 'Used to connect iteration outputs. Will be expanded to a specific output.', + properties: { + item: { + description: 'The item being iterated over', + field_kind: 'output', + title: 'Collection Item', + ui_hidden: false, + ui_type: 'CollectionItemField', + }, + index: { + description: 'The index of the item', + field_kind: 'output', + title: 'Index', + type: 'integer', + ui_hidden: false, + }, + total: { + description: 'The total number of items', + field_kind: 'output', + title: 'Total', + type: 'integer', + ui_hidden: false, + }, + type: { + const: 'iterate_output', + default: 'iterate_output', + enum: ['iterate_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['item', 'index', 'total', 'type', 'type'], + title: 'IterateInvocationOutput', + type: 'object', + class: 'output', + }, + }, + }, +} as OpenAPIV3_1.Document; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts new file mode 100644 index 0000000000..19035afd54 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -0,0 +1,194 @@ +import { deepClone } from 'common/util/deepClone'; +import { set } from 'lodash-es'; +import { describe, expect, it } from 'vitest'; + +import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils'; +import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection'; + +describe(validateConnection.name, () => { + it('should reject invalid connection to self', () => { + const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; + const r = validateConnection(c, [], [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf')); + }); + + describe('missing nodes', () => { + const n1 = 
buildNode(add); + const n2 = buildNode(sub); + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + + it('should reject missing source node', () => { + const r = validateConnection(c, [n2], [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.missingNode')); + }); + + it('should reject missing target node', () => { + const r = validateConnection(c, [n1], [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.missingNode')); + }); + }); + + describe('missing invocation templates', () => { + const n1 = buildNode(add); + const n2 = buildNode(sub); + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const nodes = [n1, n2]; + + it('should reject missing source template', () => { + const r = validateConnection(c, nodes, [], { sub }, null); + expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate')); + }); + + it('should reject missing target template', () => { + const r = validateConnection(c, nodes, [], { add }, null); + expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate')); + }); + }); + + describe('missing field templates', () => { + const n1 = buildNode(add); + const n2 = buildNode(sub); + const nodes = [n1, n2]; + + it('should reject missing source field template', () => { + const c = { source: n1.id, sourceHandle: 'invalid', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate')); + }); + + it('should reject missing target field template', () => { + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'invalid' }; + const r = validateConnection(c, nodes, [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate')); + }); + }); + + describe('duplicate connections', () => { + const n1 = buildNode(add); + const n2 = buildNode(sub); + it('should accept non-duplicate connections', () => { + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, [n1, n2], [], templates, null); + expect(r).toEqual(buildAcceptResult()); + }); + it('should reject duplicate connections', () => { + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const e = buildEdge(n1.id, 'value', n2.id, 'a'); + const r = validateConnection(c, [n1, n2], [e], templates, null); + expect(r).toEqual(buildRejectResult('nodes.cannotDuplicateConnection')); + }); + it('should accept duplicate connections if the duplicate is an ignored edge', () => { + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const e = buildEdge(n1.id, 'value', n2.id, 'a'); + const r = validateConnection(c, [n1, n2], [e], templates, e); + expect(r).toEqual(buildAcceptResult()); + }); + }); + + it('should reject connection to direct input', () => { + // Create cloned add template w/ a direct input + const addWithDirectAField = deepClone(add); + set(addWithDirectAField, 'inputs.a.input', 'direct'); + set(addWithDirectAField, 'type', 'addWithDirectAField'); + + const n1 = buildNode(add); + const n2 = buildNode(addWithDirectAField); + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null); + expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput')); + }); + + it('should reject connection to a collect node with 
mismatched item types', () => { + const n1 = buildNode(add); + const n2 = buildNode(collect); + const n3 = buildNode(main_model_loader); + const nodes = [n1, n2, n3]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'item'); + const edges = [e1]; + const c = { source: n3.id, sourceHandle: 'vae', target: n2.id, targetHandle: 'item' }; + const r = validateConnection(c, nodes, edges, templates, null); + expect(r).toEqual(buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes')); + }); + + it('should accept connection to a collect node with matching item types', () => { + const n1 = buildNode(add); + const n2 = buildNode(collect); + const n3 = buildNode(sub); + const nodes = [n1, n2, n3]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'item'); + const edges = [e1]; + const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'item' }; + const r = validateConnection(c, nodes, edges, templates, null); + expect(r).toEqual(buildAcceptResult()); + }); + + it('should reject connections to target field that is already connected', () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + const n3 = buildNode(add); + const nodes = [n1, n2, n3]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); + const edges = [e1]; + const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, edges, templates, null); + expect(r).toEqual(buildRejectResult('nodes.inputMayOnlyHaveOneConnection')); + }); + + it('should accept connections to target field that is already connected (ignored edge)', () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + const n3 = buildNode(add); + const nodes = [n1, n2, n3]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); + const edges = [e1]; + const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, edges, templates, e1); + expect(r).toEqual(buildAcceptResult()); + }); + + it('should reject connections between invalid types', () => { + const n1 = buildNode(add); + const n2 = buildNode(img_resize); + const nodes = [n1, n2]; + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' }; + const r = validateConnection(c, nodes, [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.fieldTypesMustMatch')); + }); + + it('should reject connections that would create cycles', () => { + const n1 = buildNode(add); + const n2 = buildNode(sub); + const nodes = [n1, n2]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); + const edges = [e1]; + const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, edges, templates, null); + expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle')); + }); + + describe('non-strict mode', () => { + it('should reject connections from self to self in non-strict mode', () => { + const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; + const r = validateConnection(c, [], [], templates, null, false); + expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf')); + }); + it('should reject connections that create cycles in non-strict mode', () => { + const n1 = buildNode(add); + const n2 = buildNode(sub); + const nodes = [n1, n2]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); + const edges = [e1]; + const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, 
edges, templates, null, false); + expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle')); + }); + it('should otherwise allow invalid connections in non-strict mode', () => { + const n1 = buildNode(add); + const n2 = buildNode(img_resize); + const nodes = [n1, n2]; + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' }; + const r = validateConnection(c, nodes, [], templates, null, false); + expect(r).toEqual(buildAcceptResult()); + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts new file mode 100644 index 0000000000..8ece852b07 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -0,0 +1,130 @@ +import type { Templates } from 'features/nodes/store/types'; +import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; +import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; +import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; +import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; +import type { AnyNode } from 'features/nodes/types/invocation'; +import type { Connection as NullableConnection, Edge } from 'reactflow'; +import type { O } from 'ts-toolbelt'; + +type Connection = O.NonNullable; + +export type ValidationResult = + | { + isValid: true; + messageTKey?: string; + } + | { + isValid: false; + messageTKey: string; + }; + +type ValidateConnectionFunc = ( + connection: Connection, + nodes: AnyNode[], + edges: Edge[], + templates: Templates, + ignoreEdge: Edge | null, + strict?: boolean +) => ValidationResult; + +const getEqualityPredicate = + (c: Connection) => + (e: Edge): boolean => { + return ( + e.target === c.target && + e.targetHandle === c.targetHandle && + e.source === c.source && + e.sourceHandle === c.sourceHandle + ); + }; + +const getTargetEqualityPredicate = + (c: Connection) => + (e: Edge): boolean => { + return e.target === c.target && e.targetHandle === c.targetHandle; + }; + +export const buildAcceptResult = (): ValidationResult => ({ isValid: true }); +export const buildRejectResult = (messageTKey: string): ValidationResult => ({ isValid: false, messageTKey }); + +export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => { + if (c.source === c.target) { + return buildRejectResult('nodes.cannotConnectToSelf'); + } + + if (strict) { + /** + * We may need to ignore an edge when validating a connection. + * + * For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection, + * the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it, else + * the validation will fail unexpectedly. 
+ */ + const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id); + + if (filteredEdges.some(getEqualityPredicate(c))) { + // We already have a connection from this source to this target + return buildRejectResult('nodes.cannotDuplicateConnection'); + } + + const sourceNode = nodes.find((n) => n.id === c.source); + if (!sourceNode) { + return buildRejectResult('nodes.missingNode'); + } + + const targetNode = nodes.find((n) => n.id === c.target); + if (!targetNode) { + return buildRejectResult('nodes.missingNode'); + } + + const sourceTemplate = templates[sourceNode.data.type]; + if (!sourceTemplate) { + return buildRejectResult('nodes.missingInvocationTemplate'); + } + + const targetTemplate = templates[targetNode.data.type]; + if (!targetTemplate) { + return buildRejectResult('nodes.missingInvocationTemplate'); + } + + const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle]; + if (!sourceFieldTemplate) { + return buildRejectResult('nodes.missingFieldTemplate'); + } + + const targetFieldTemplate = targetTemplate.inputs[c.targetHandle]; + if (!targetFieldTemplate) { + return buildRejectResult('nodes.missingFieldTemplate'); + } + + if (targetFieldTemplate.input === 'direct') { + return buildRejectResult('nodes.cannotConnectToDirectInput'); + } + + if (targetNode.data.type === 'collect' && c.targetHandle === 'item') { + // Collect nodes shouldn't mix and match field types. + const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) { + return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'); + } + } + + if (filteredEdges.find(getTargetEqualityPredicate(c))) { + // CollectionItemField inputs can have multiple input connections + if (targetFieldTemplate.type.name !== 'CollectionItemField') { + return buildRejectResult('nodes.inputMayOnlyHaveOneConnection'); + } + } + + if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { + return buildRejectResult('nodes.fieldTypesMustMatch'); + } + } + + if (getHasCycles(c.source, c.target, nodes, edges)) { + return buildRejectResult('nodes.connectionWouldCreateCycle'); + } + + return buildAcceptResult(); +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts new file mode 100644 index 0000000000..56d4cfe70a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts @@ -0,0 +1,222 @@ +import { describe, expect, it } from 'vitest'; + +import { validateConnectionTypes } from './validateConnectionTypes'; + +describe(validateConnectionTypes.name, () => { + describe('generic cases', () => { + it('should accept SINGLE to SINGLE of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', cardinality: 'SINGLE' }, + { name: 'FooField', cardinality: 'SINGLE' } + ); + expect(r).toBe(true); + }); + it('should accept COLLECTION to COLLECTION of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', cardinality: 'COLLECTION' }, + { name: 'FooField', cardinality: 'COLLECTION' } + ); + expect(r).toBe(true); + }); + it('should accept SINGLE to SINGLE_OR_COLLECTION of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', cardinality: 'SINGLE' }, + { name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' } + ); + expect(r).toBe(true); + }); + it('should accept 
COLLECTION to SINGLE_OR_COLLECTION of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', cardinality: 'COLLECTION' }, + { name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' } + ); + expect(r).toBe(true); + }); + it('should reject COLLECTION to SINGLE of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', cardinality: 'COLLECTION' }, + { name: 'FooField', cardinality: 'SINGLE' } + ); + expect(r).toBe(false); + }); + it('should reject SINGLE_OR_COLLECTION to SINGLE of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' }, + { name: 'FooField', cardinality: 'SINGLE' } + ); + expect(r).toBe(false); + }); + it('should reject mismatched types', () => { + const r = validateConnectionTypes( + { name: 'FooField', cardinality: 'SINGLE' }, + { name: 'BarField', cardinality: 'SINGLE' } + ); + expect(r).toBe(false); + }); + }); + + describe('special cases', () => { + it('should reject a COLLECTION input to a COLLECTION input', () => { + const r = validateConnectionTypes( + { name: 'CollectionField', cardinality: 'COLLECTION' }, + { name: 'CollectionField', cardinality: 'COLLECTION' } + ); + expect(r).toBe(false); + }); + + it('should accept equal types', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', cardinality: 'SINGLE' }, + { name: 'IntegerField', cardinality: 'SINGLE' } + ); + expect(r).toBe(true); + }); + + describe('CollectionItemField', () => { + it('should accept CollectionItemField to any SINGLE target', () => { + const r = validateConnectionTypes( + { name: 'CollectionItemField', cardinality: 'SINGLE' }, + { name: 'IntegerField', cardinality: 'SINGLE' } + ); + expect(r).toBe(true); + }); + it('should accept CollectionItemField to any SINGLE_OR_COLLECTION target', () => { + const r = validateConnectionTypes( + { name: 'CollectionItemField', cardinality: 'SINGLE' }, + { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' } + ); + expect(r).toBe(true); + }); + it('should accept any SINGLE to CollectionItemField', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', cardinality: 'SINGLE' }, + { name: 'CollectionItemField', cardinality: 'SINGLE' } + ); + expect(r).toBe(true); + }); + it('should reject any COLLECTION to CollectionItemField', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', cardinality: 'COLLECTION' }, + { name: 'CollectionItemField', cardinality: 'SINGLE' } + ); + expect(r).toBe(false); + }); + it('should reject any SINGLE_OR_COLLECTION to CollectionItemField', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }, + { name: 'CollectionItemField', cardinality: 'SINGLE' } + ); + expect(r).toBe(false); + }); + }); + + describe('SINGLE_OR_COLLECTION', () => { + it('should accept any SINGLE of same type to SINGLE_OR_COLLECTION', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', cardinality: 'SINGLE' }, + { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' } + ); + expect(r).toBe(true); + }); + it('should accept any COLLECTION of same type to SINGLE_OR_COLLECTION', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', cardinality: 'COLLECTION' }, + { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' } + ); + expect(r).toBe(true); + }); + it('should accept any SINGLE_OR_COLLECTION of same type to SINGLE_OR_COLLECTION', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', 
cardinality: 'SINGLE_OR_COLLECTION' }, + { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' } + ); + expect(r).toBe(true); + }); + }); + + describe('CollectionField', () => { + it('should accept any CollectionField to any COLLECTION type', () => { + const r = validateConnectionTypes( + { name: 'CollectionField', cardinality: 'SINGLE' }, + { name: 'IntegerField', cardinality: 'COLLECTION' } + ); + expect(r).toBe(true); + }); + it('should accept any CollectionField to any SINGLE_OR_COLLECTION type', () => { + const r = validateConnectionTypes( + { name: 'CollectionField', cardinality: 'SINGLE' }, + { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' } + ); + expect(r).toBe(true); + }); + }); + + describe('subtype handling', () => { + type TypePair = { t1: string; t2: string }; + const typePairs = [ + { t1: 'IntegerField', t2: 'FloatField' }, + { t1: 'IntegerField', t2: 'StringField' }, + { t1: 'FloatField', t2: 'StringField' }, + ]; + it.each(typePairs)('should accept SINGLE $t1 to SINGLE $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes({ name: t1, cardinality: 'SINGLE' }, { name: t2, cardinality: 'SINGLE' }); + expect(r).toBe(true); + }); + it.each(typePairs)('should accept SINGLE $t1 to SINGLE_OR_COLLECTION $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, cardinality: 'SINGLE' }, + { name: t2, cardinality: 'SINGLE_OR_COLLECTION' } + ); + expect(r).toBe(true); + }); + it.each(typePairs)('should accept COLLECTION $t1 to COLLECTION $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, cardinality: 'COLLECTION' }, + { name: t2, cardinality: 'COLLECTION' } + ); + expect(r).toBe(true); + }); + it.each(typePairs)('should accept COLLECTION $t1 to SINGLE_OR_COLLECTION $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, cardinality: 'COLLECTION' }, + { name: t2, cardinality: 'SINGLE_OR_COLLECTION' } + ); + expect(r).toBe(true); + }); + it.each(typePairs)( + 'should accept SINGLE_OR_COLLECTION $t1 to SINGLE_OR_COLLECTION $t2', + ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, cardinality: 'SINGLE_OR_COLLECTION' }, + { name: t2, cardinality: 'SINGLE_OR_COLLECTION' } + ); + expect(r).toBe(true); + } + ); + }); + + describe('AnyField', () => { + it('should accept any SINGLE type to AnyField', () => { + const r = validateConnectionTypes( + { name: 'FooField', cardinality: 'SINGLE' }, + { name: 'AnyField', cardinality: 'SINGLE' } + ); + expect(r).toBe(true); + }); + it('should accept any COLLECTION type to AnyField', () => { + const r = validateConnectionTypes( + { name: 'FooField', cardinality: 'SINGLE' }, + { name: 'AnyField', cardinality: 'COLLECTION' } + ); + expect(r).toBe(true); + }); + it('should accept any SINGLE_OR_COLLECTION type to AnyField', () => { + const r = validateConnectionTypes( + { name: 'FooField', cardinality: 'SINGLE' }, + { name: 'AnyField', cardinality: 'SINGLE_OR_COLLECTION' } + ); + expect(r).toBe(true); + }); + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts new file mode 100644 index 0000000000..d5dee6dbaf --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts @@ -0,0 +1,74 @@ +import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; +import { type FieldType, isCollection, isSingle, isSingleOrCollection } from 
'features/nodes/types/field';
+
+/**
+ * Validates that the source and target types are compatible for a connection.
+ * @param sourceType The type of the source field.
+ * @param targetType The type of the target field.
+ * @returns True if the connection is valid, false otherwise.
+ */
+export const validateConnectionTypes = (sourceType: FieldType, targetType: FieldType) => {
+  // TODO: There's a bug with Collect -> Iterate nodes:
+  // https://github.com/invoke-ai/InvokeAI/issues/3956
+  // Once this is resolved, we can remove this check.
+  if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') {
+    return false;
+  }
+
+  if (areTypesEqual(sourceType, targetType)) {
+    return true;
+  }
+
+  /**
+   * Connection types must be the same for a connection, with exceptions:
+   * - CollectionItem can connect to any non-COLLECTION (e.g. SINGLE or SINGLE_OR_COLLECTION)
+   * - SINGLE can connect to CollectionItem
+   * - Anything (SINGLE, COLLECTION, SINGLE_OR_COLLECTION) can connect to SINGLE_OR_COLLECTION of the same base type
+   * - Generic CollectionField can connect to any other COLLECTION or SINGLE_OR_COLLECTION
+   * - Any COLLECTION can connect to a Generic Collection
+   */
+  const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !isCollection(targetType);
+
+  const isNonCollectionToCollectionItem = isSingle(sourceType) && targetType.name === 'CollectionItemField';
+
+  const isAnythingToSingleOrCollectionOfSameBaseType =
+    isSingleOrCollection(targetType) && sourceType.name === targetType.name;
+
+  const isGenericCollectionToAnyCollectionOrSingleOrCollection =
+    sourceType.name === 'CollectionField' && !isSingle(targetType);
+
+  const isCollectionToGenericCollection = targetType.name === 'CollectionField' && isCollection(sourceType);
+
+  const isSourceSingle = isSingle(sourceType);
+  const isTargetSingle = isSingle(targetType);
+  const isSingleToSingle = isSourceSingle && isTargetSingle;
+  const isSingleToSingleOrCollection = isSourceSingle && isSingleOrCollection(targetType);
+  const isCollectionToCollection = isCollection(sourceType) && isCollection(targetType);
+  const isCollectionToSingleOrCollection = isCollection(sourceType) && isSingleOrCollection(targetType);
+  const isSingleOrCollectionToSingleOrCollection = isSingleOrCollection(sourceType) && isSingleOrCollection(targetType);
+  const doesCardinalityMatch =
+    isSingleToSingle ||
+    isCollectionToCollection ||
+    isCollectionToSingleOrCollection ||
+    isSingleOrCollectionToSingleOrCollection ||
+    isSingleToSingleOrCollection;
+
+  const isIntToFloat = sourceType.name === 'IntegerField' && targetType.name === 'FloatField';
+  const isIntToString = sourceType.name === 'IntegerField' && targetType.name === 'StringField';
+  const isFloatToString = sourceType.name === 'FloatField' && targetType.name === 'StringField';
+
+  const isSubTypeMatch = doesCardinalityMatch && (isIntToFloat || isIntToString || isFloatToString);
+
+  const isTargetAnyType = targetType.name === 'AnyField';
+
+  // One of these must be true for the connection to be valid
+  return (
+    isCollectionItemToNonCollection ||
+    isNonCollectionToCollectionItem ||
+    isAnythingToSingleOrCollectionOfSameBaseType ||
+    isGenericCollectionToAnyCollectionOrSingleOrCollection ||
+    isCollectionToGenericCollection ||
+    isSubTypeMatch ||
+    isTargetAnyType
+  );
+};
diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts
deleted file mode 100644 index 3cbfb5b89c..0000000000 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts +++ /dev/null @@ -1,70 +0,0 @@ -import type { FieldType } from 'features/nodes/types/field'; -import { isEqual } from 'lodash-es'; - -/** - * Validates that the source and target types are compatible for a connection. - * @param sourceType The type of the source field. - * @param targetType The type of the target field. - * @returns True if the connection is valid, false otherwise. - */ -export const validateSourceAndTargetTypes = (sourceType: FieldType, targetType: FieldType) => { - // TODO: There's a bug with Collect -> Iterate nodes: - // https://github.com/invoke-ai/InvokeAI/issues/3956 - // Once this is resolved, we can remove this check. - if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') { - return false; - } - - if (isEqual(sourceType, targetType)) { - return true; - } - - /** - * Connection types must be the same for a connection, with exceptions: - * - CollectionItem can connect to any non-Collection - * - Non-Collections can connect to CollectionItem - * - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type - * - Generic Collection can connect to any other Collection or CollectionOrScalar - * - Any Collection can connect to a Generic Collection - */ - - const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection; - - const isNonCollectionToCollectionItem = - targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar; - - const isAnythingToCollectionOrScalarOfSameBaseType = - targetType.isCollectionOrScalar && sourceType.name === targetType.name; - - const isGenericCollectionToAnyCollectionOrCollectionOrScalar = - sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar); - - const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection; - - const areBothTypesSingle = - !sourceType.isCollection && - !sourceType.isCollectionOrScalar && - !targetType.isCollection && - !targetType.isCollectionOrScalar; - - const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField'; - - const isIntOrFloatToString = - areBothTypesSingle && - (sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') && - targetType.name === 'StringField'; - - const isTargetAnyType = targetType.name === 'AnyField'; - - // One of these must be true for the connection to be valid - return ( - isCollectionItemToNonCollection || - isNonCollectionToCollectionItem || - isAnythingToCollectionOrScalarOfSameBaseType || - isGenericCollectionToAnyCollectionOrCollectionOrScalar || - isCollectionToGenericCollection || - isIntToFloat || - isIntOrFloatToString || - isTargetAnyType - ); -}; diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts index 6293d3cce5..0d358f56e4 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts @@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import { deepClone } from 'common/util/deepClone'; import { workflowLoaded } from 'features/nodes/store/actions'; -import 
{ isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged, nodesDeleted } from 'features/nodes/store/nodesSlice'; +import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice'; import type { FieldIdentifierWithValue, WorkflowMode, @@ -139,15 +139,31 @@ export const workflowSlice = createSlice({ }; }); - builder.addCase(nodesDeleted, (state, action) => { - action.payload.forEach((node) => { - state.exposedFields = state.exposedFields.filter((f) => f.nodeId !== node.id); - }); - }); - builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState)); builder.addCase(nodesChanged, (state, action) => { + // If a node was removed, we should remove any exposed fields that were associated with it. However, node changes + // may remove and then add the same node back. For example, when updating a workflow, we replace old nodes with + // updated nodes. In this case, we should not remove the exposed fields. To handle this, we find the last remove + // and add changes for each exposed field. If the remove change comes after the add change, we remove the exposed + // field. + const exposedFieldsToRemove: FieldIdentifier[] = []; + state.exposedFields.forEach((field) => { + const removeIndex = action.payload.findLastIndex( + (change) => change.type === 'remove' && change.id === field.nodeId + ); + const addIndex = action.payload.findLastIndex( + (change) => change.type === 'add' && change.item.id === field.nodeId + ); + if (removeIndex > addIndex) { + exposedFieldsToRemove.push({ nodeId: field.nodeId, fieldName: field.fieldName }); + } + }); + + state.exposedFields = state.exposedFields.filter( + (field) => !exposedFieldsToRemove.some((f) => isEqual(f, field)) + ); + // Not all changes to nodes should result in the workflow being marked touched const filteredChanges = action.payload.filter((change) => { // We always want to mark the workflow as touched if a node is added, removed, or reset @@ -165,7 +181,7 @@ export const workflowSlice = createSlice({ return false; }); - if (filteredChanges.length > 0) { + if (filteredChanges.length > 0 || exposedFieldsToRemove.length > 0) { state.isTouched = true; } }); diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 87b0839bc3..e2a84e3390 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -54,9 +54,10 @@ const zFieldOutputTemplateBase = zFieldTemplateBase.extend({ fieldKind: z.literal('output'), }); +const zCardinality = z.enum(['SINGLE', 'COLLECTION', 'SINGLE_OR_COLLECTION']); + const zFieldTypeBase = z.object({ - isCollection: z.boolean(), - isCollectionOrScalar: z.boolean(), + cardinality: zCardinality, }); export const zFieldIdentifier = z.object({ @@ -66,16 +67,124 @@ export const zFieldIdentifier = z.object({ export type FieldIdentifier = z.infer; // #endregion -// #region IntegerField +// #region Field Types +const zStatelessFieldType = zFieldTypeBase.extend({ + name: z.string().min(1), // stateless --> we accept the field's name as the type +}); const zIntegerFieldType = zFieldTypeBase.extend({ name: z.literal('IntegerField'), + originalType: zStatelessFieldType.optional(), }); +const zFloatFieldType = zFieldTypeBase.extend({ + name: z.literal('FloatField'), + originalType: zStatelessFieldType.optional(), +}); +const zStringFieldType = zFieldTypeBase.extend({ + name: z.literal('StringField'), + originalType: zStatelessFieldType.optional(), +}); 
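The workflowSlice hunk above decides whether an exposed field should be dropped by comparing the positions of the last `remove` and `add` changes for its node. A minimal standalone sketch of that ordering check follows; the `SimpleNodeChange` type and `wasNodeRemoved` helper are simplified stand-ins for illustration, not part of this patch or of reactflow's real `NodeChange` union.

```ts
// Minimal sketch of the "last remove vs. last add" ordering check.
type SimpleNodeChange =
  | { type: 'remove'; id: string }
  | { type: 'add'; item: { id: string } };

// True when the node was removed and not re-added afterwards, i.e. the last
// 'remove' change for this node comes after the last 'add' change.
const wasNodeRemoved = (changes: SimpleNodeChange[], nodeId: string): boolean => {
  const removeIndex = changes.findLastIndex((c) => c.type === 'remove' && c.id === nodeId);
  const addIndex = changes.findLastIndex((c) => c.type === 'add' && c.item.id === nodeId);
  return removeIndex > addIndex;
};

// A plain removal should drop the exposed field...
wasNodeRemoved([{ type: 'remove', id: 'n1' }], 'n1'); // true
// ...but a remove followed by an add (e.g. a workflow update replacing the node) keeps it.
wasNodeRemoved([{ type: 'remove', id: 'n1' }, { type: 'add', item: { id: 'n1' } }], 'n1'); // false
```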
+const zBooleanFieldType = zFieldTypeBase.extend({ + name: z.literal('BooleanField'), + originalType: zStatelessFieldType.optional(), +}); +const zEnumFieldType = zFieldTypeBase.extend({ + name: z.literal('EnumField'), + originalType: zStatelessFieldType.optional(), +}); +const zImageFieldType = zFieldTypeBase.extend({ + name: z.literal('ImageField'), + originalType: zStatelessFieldType.optional(), +}); +const zBoardFieldType = zFieldTypeBase.extend({ + name: z.literal('BoardField'), + originalType: zStatelessFieldType.optional(), +}); +const zColorFieldType = zFieldTypeBase.extend({ + name: z.literal('ColorField'), + originalType: zStatelessFieldType.optional(), +}); +const zMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('MainModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zModelIdentifierFieldType = zFieldTypeBase.extend({ + name: z.literal('ModelIdentifierField'), + originalType: zStatelessFieldType.optional(), +}); +const zSDXLMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLMainModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLRefinerModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zVAEModelFieldType = zFieldTypeBase.extend({ + name: z.literal('VAEModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zLoRAModelFieldType = zFieldTypeBase.extend({ + name: z.literal('LoRAModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zControlNetModelFieldType = zFieldTypeBase.extend({ + name: z.literal('ControlNetModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zIPAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('IPAdapterModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('T2IAdapterModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zSchedulerFieldType = zFieldTypeBase.extend({ + name: z.literal('SchedulerField'), + originalType: zStatelessFieldType.optional(), +}); +const zStatefulFieldType = z.union([ + zIntegerFieldType, + zFloatFieldType, + zStringFieldType, + zBooleanFieldType, + zEnumFieldType, + zImageFieldType, + zBoardFieldType, + zModelIdentifierFieldType, + zMainModelFieldType, + zSDXLMainModelFieldType, + zSDXLRefinerModelFieldType, + zVAEModelFieldType, + zLoRAModelFieldType, + zControlNetModelFieldType, + zIPAdapterModelFieldType, + zT2IAdapterModelFieldType, + zColorFieldType, + zSchedulerFieldType, +]); +export type StatefulFieldType = z.infer; +const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value); +export const isStatefulFieldType = (fieldType: FieldType): fieldType is StatefulFieldType => + (statefulFieldTypeNames as string[]).includes(fieldType.name); +const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]); +export type FieldType = z.infer; + +export const isSingle = (fieldType: FieldType): boolean => fieldType.cardinality === zCardinality.enum.SINGLE; +export const isCollection = (fieldType: FieldType): boolean => fieldType.cardinality === zCardinality.enum.COLLECTION; +export const isSingleOrCollection = (fieldType: FieldType): boolean => + fieldType.cardinality === zCardinality.enum.SINGLE_OR_COLLECTION; +// #endregion + +// #region IntegerField + export const zIntegerFieldValue = z.number().int(); const zIntegerFieldInputInstance = 
zFieldInputInstanceBase.extend({ value: zIntegerFieldValue, }); const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zIntegerFieldType, + originalType: zFieldType.optional(), default: zIntegerFieldValue, multipleOf: z.number().int().optional(), maximum: z.number().int().optional(), @@ -96,15 +205,14 @@ export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldIn // #endregion // #region FloatField -const zFloatFieldType = zFieldTypeBase.extend({ - name: z.literal('FloatField'), -}); + export const zFloatFieldValue = z.number(); const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({ value: zFloatFieldValue, }); const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zFloatFieldType, + originalType: zFieldType.optional(), default: zFloatFieldValue, multipleOf: z.number().optional(), maximum: z.number().optional(), @@ -125,15 +233,14 @@ export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputT // #endregion // #region StringField -const zStringFieldType = zFieldTypeBase.extend({ - name: z.literal('StringField'), -}); + export const zStringFieldValue = z.string(); const zStringFieldInputInstance = zFieldInputInstanceBase.extend({ value: zStringFieldValue, }); const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zStringFieldType, + originalType: zFieldType.optional(), default: zStringFieldValue, maxLength: z.number().int().optional(), minLength: z.number().int().optional(), @@ -152,15 +259,14 @@ export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInpu // #endregion // #region BooleanField -const zBooleanFieldType = zFieldTypeBase.extend({ - name: z.literal('BooleanField'), -}); + export const zBooleanFieldValue = z.boolean(); const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({ value: zBooleanFieldValue, }); const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zBooleanFieldType, + originalType: zFieldType.optional(), default: zBooleanFieldValue, }); const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ @@ -176,15 +282,14 @@ export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldIn // #endregion // #region EnumField -const zEnumFieldType = zFieldTypeBase.extend({ - name: z.literal('EnumField'), -}); + export const zEnumFieldValue = z.string(); const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({ value: zEnumFieldValue, }); const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zEnumFieldType, + originalType: zFieldType.optional(), default: zEnumFieldValue, options: z.array(z.string()), labels: z.record(z.string()).optional(), @@ -202,15 +307,14 @@ export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTem // #endregion // #region ImageField -const zImageFieldType = zFieldTypeBase.extend({ - name: z.literal('ImageField'), -}); + export const zImageFieldValue = zImageField.optional(); const zImageFieldInputInstance = zFieldInputInstanceBase.extend({ value: zImageFieldValue, }); const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zImageFieldType, + originalType: zFieldType.optional(), default: zImageFieldValue, }); const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ @@ -226,15 +330,14 @@ export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputT // #endregion // #region BoardField -const zBoardFieldType = zFieldTypeBase.extend({ - name: z.literal('BoardField'), -}); + export const 
zBoardFieldValue = zBoardField.optional(); const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({ value: zBoardFieldValue, }); const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zBoardFieldType, + originalType: zFieldType.optional(), default: zBoardFieldValue, }); const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ @@ -250,15 +353,14 @@ export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputT // #endregion // #region ColorField -const zColorFieldType = zFieldTypeBase.extend({ - name: z.literal('ColorField'), -}); + export const zColorFieldValue = zColorField.optional(); const zColorFieldInputInstance = zFieldInputInstanceBase.extend({ value: zColorFieldValue, }); const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zColorFieldType, + originalType: zFieldType.optional(), default: zColorFieldValue, }); const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ @@ -274,15 +376,14 @@ export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputT // #endregion // #region MainModelField -const zMainModelFieldType = zFieldTypeBase.extend({ - name: z.literal('MainModelField'), -}); + export const zMainModelFieldValue = zModelIdentifierField.optional(); const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zMainModelFieldValue, }); const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zMainModelFieldType, + originalType: zFieldType.optional(), default: zMainModelFieldValue, }); const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ @@ -297,16 +398,37 @@ export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFie zMainModelFieldInputTemplate.safeParse(val).success; // #endregion -// #region SDXLMainModelField -const zSDXLMainModelFieldType = zFieldTypeBase.extend({ - name: z.literal('SDXLMainModelField'), +// #region ModelIdentifierField +export const zModelIdentifierFieldValue = zModelIdentifierField.optional(); +const zModelIdentifierFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zModelIdentifierFieldValue, }); +const zModelIdentifierFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zModelIdentifierFieldType, + originalType: zFieldType.optional(), + default: zModelIdentifierFieldValue, +}); +const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zModelIdentifierFieldType, +}); +export type ModelIdentifierFieldValue = z.infer; +export type ModelIdentifierFieldInputInstance = z.infer; +export type ModelIdentifierFieldInputTemplate = z.infer; +export const isModelIdentifierFieldInputInstance = (val: unknown): val is ModelIdentifierFieldInputInstance => + zModelIdentifierFieldInputInstance.safeParse(val).success; +export const isModelIdentifierFieldInputTemplate = (val: unknown): val is ModelIdentifierFieldInputTemplate => + zModelIdentifierFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SDXLMainModelField + const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. 
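The removed validateSourceAndTargetTypes.ts modelled cardinality as two booleans, `isCollection` and `isCollectionOrScalar`, while the new schemas collapse these into a single `cardinality` value. A hypothetical conversion helper, shown only to make that mapping explicit (it is not part of this patch):

```ts
type Cardinality = 'SINGLE' | 'COLLECTION' | 'SINGLE_OR_COLLECTION';

// Hypothetical helper: maps the legacy boolean flags onto the new enum.
// The two flags were treated as mutually exclusive, so three states cover the valid inputs.
const flagsToCardinality = (isCollection: boolean, isCollectionOrScalar: boolean): Cardinality => {
  if (isCollection) {
    return 'COLLECTION';
  }
  if (isCollectionOrScalar) {
    return 'SINGLE_OR_COLLECTION';
  }
  return 'SINGLE';
};

flagsToCardinality(false, false); // 'SINGLE'
flagsToCardinality(true, false); // 'COLLECTION'
flagsToCardinality(false, true); // 'SINGLE_OR_COLLECTION'
```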
const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zSDXLMainModelFieldValue, }); const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSDXLMainModelFieldType, + originalType: zFieldType.optional(), default: zSDXLMainModelFieldValue, }); const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ @@ -321,9 +443,7 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain // #endregion // #region SDXLRefinerModelField -const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ - name: z.literal('SDXLRefinerModelField'), -}); + /** @alias */ // tells knip to ignore this duplicate export export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({ @@ -331,6 +451,7 @@ const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({ }); const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSDXLRefinerModelFieldType, + originalType: zFieldType.optional(), default: zSDXLRefinerModelFieldValue, }); const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ @@ -346,15 +467,14 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR // #endregion // #region VAEModelField -const zVAEModelFieldType = zFieldTypeBase.extend({ - name: z.literal('VAEModelField'), -}); + export const zVAEModelFieldValue = zModelIdentifierField.optional(); const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zVAEModelFieldValue, }); const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zVAEModelFieldType, + originalType: zFieldType.optional(), default: zVAEModelFieldValue, }); const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ @@ -370,15 +490,14 @@ export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelField // #endregion // #region LoRAModelField -const zLoRAModelFieldType = zFieldTypeBase.extend({ - name: z.literal('LoRAModelField'), -}); + export const zLoRAModelFieldValue = zModelIdentifierField.optional(); const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zLoRAModelFieldValue, }); const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zLoRAModelFieldType, + originalType: zFieldType.optional(), default: zLoRAModelFieldValue, }); const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ @@ -394,15 +513,14 @@ export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFie // #endregion // #region ControlNetModelField -const zControlNetModelFieldType = zFieldTypeBase.extend({ - name: z.literal('ControlNetModelField'), -}); + export const zControlNetModelFieldValue = zModelIdentifierField.optional(); const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zControlNetModelFieldValue, }); const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zControlNetModelFieldType, + originalType: zFieldType.optional(), default: zControlNetModelFieldValue, }); const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ @@ -418,15 +536,14 @@ export const isControlNetModelFieldInputTemplate = (val: unknown): val is Contro // #endregion // #region IPAdapterModelField -const zIPAdapterModelFieldType = zFieldTypeBase.extend({ - name: z.literal('IPAdapterModelField'), -}); + export const 
zIPAdapterModelFieldValue = zModelIdentifierField.optional(); const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zIPAdapterModelFieldValue, }); const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zIPAdapterModelFieldType, + originalType: zFieldType.optional(), default: zIPAdapterModelFieldValue, }); const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ @@ -442,15 +559,14 @@ export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapt // #endregion // #region T2IAdapterField -const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ - name: z.literal('T2IAdapterModelField'), -}); + export const zT2IAdapterModelFieldValue = zModelIdentifierField.optional(); const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zT2IAdapterModelFieldValue, }); const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zT2IAdapterModelFieldType, + originalType: zFieldType.optional(), default: zT2IAdapterModelFieldValue, }); const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ @@ -466,15 +582,14 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda // #endregion // #region SchedulerField -const zSchedulerFieldType = zFieldTypeBase.extend({ - name: z.literal('SchedulerField'), -}); + export const zSchedulerFieldValue = zSchedulerField.optional(); const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({ value: zSchedulerFieldValue, }); const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSchedulerFieldType, + originalType: zFieldType.optional(), default: zSchedulerFieldValue, }); const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ @@ -501,15 +616,14 @@ export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFie * - Reserved fields like IsIntermediate * - Any other field we don't have full-on schemas for */ -const zStatelessFieldType = zFieldTypeBase.extend({ - name: z.string().min(1), // stateless --> we accept the field's name as the type -}); + const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({ value: zStatelessFieldValue, }); const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zStatelessFieldType, + originalType: zFieldType.optional(), default: zStatelessFieldValue, input: z.literal('connection'), // stateless --> only accepts connection inputs }); @@ -535,34 +649,6 @@ export type StatelessFieldInputTemplate = z.infer; -export const isStatefulFieldType = (val: unknown): val is StatefulFieldType => - zStatefulFieldType.safeParse(val).success; - -const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]); -export type FieldType = z.infer; -// #endregion - // #region StatefulFieldValue & FieldValue export const zStatefulFieldValue = z.union([ zIntegerFieldValue, @@ -572,6 +658,7 @@ export const zStatefulFieldValue = z.union([ zEnumFieldValue, zImageFieldValue, zBoardFieldValue, + zModelIdentifierFieldValue, zMainModelFieldValue, zSDXLMainModelFieldValue, zSDXLRefinerModelFieldValue, @@ -598,6 +685,7 @@ const zStatefulFieldInputInstance = z.union([ zEnumFieldInputInstance, zImageFieldInputInstance, zBoardFieldInputInstance, + zModelIdentifierFieldInputInstance, zMainModelFieldInputInstance, zSDXLMainModelFieldInputInstance, 
zSDXLRefinerModelFieldInputInstance, @@ -625,6 +713,7 @@ const zStatefulFieldInputTemplate = z.union([ zEnumFieldInputTemplate, zImageFieldInputTemplate, zBoardFieldInputTemplate, + zModelIdentifierFieldInputTemplate, zMainModelFieldInputTemplate, zSDXLMainModelFieldInputTemplate, zSDXLRefinerModelFieldInputTemplate, @@ -653,6 +742,7 @@ const zStatefulFieldOutputTemplate = z.union([ zEnumFieldOutputTemplate, zImageFieldOutputTemplate, zBoardFieldOutputTemplate, + zModelIdentifierFieldOutputTemplate, zMainModelFieldOutputTemplate, zSDXLMainModelFieldOutputTemplate, zSDXLRefinerModelFieldOutputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index 66a3db62bf..0a7149bd6b 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -70,13 +70,18 @@ export const isInvocationNodeData = (node?: AnyNodeData | null): node is Invocat // #region NodeExecutionState export const zNodeStatus = z.enum(['PENDING', 'IN_PROGRESS', 'COMPLETED', 'FAILED']); +const zNodeError = z.object({ + error_type: z.string(), + error_message: z.string(), + error_traceback: z.string(), +}); const zNodeExecutionState = z.object({ nodeId: z.string().trim().min(1), status: zNodeStatus, progress: z.number().nullable(), progressImage: zProgressImage.nullable(), - error: z.string().nullable(), outputs: z.array(z.any()), + error: zNodeError.nullable(), }); export type NodeExecutionState = z.infer; // #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts index 79946cd8d5..f1d4e61300 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts @@ -1,4 +1,4 @@ -import type { FieldType, StatefulFieldType } from 'features/nodes/types/field'; +import type { StatefulFieldType, StatelessFieldType } from 'features/nodes/types/v2/field'; import type { FieldTypeV1 } from './workflowV1'; @@ -165,7 +165,7 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: { * Thus, this object was manually edited to ensure it is correct. 
*/ const FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2: { - [key in FieldTypeV1]?: FieldType; + [key in FieldTypeV1]?: StatelessFieldType; } = { Any: { name: 'AnyField', isCollection: false, isCollectionOrScalar: false }, ClipField: { diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts index 1e464fa76d..4b680d1de3 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts @@ -316,6 +316,7 @@ const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({ const zStatelessFieldType = zFieldTypeBase.extend({ name: z.string().min(1), // stateless --> we accept the field's name as the type }); +export type StatelessFieldType = z.infer; const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({ type: zStatelessFieldType, @@ -327,6 +328,27 @@ const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({ // #endregion +const zStatefulFieldType = z.union([ + zIntegerFieldType, + zFloatFieldType, + zStringFieldType, + zBooleanFieldType, + zEnumFieldType, + zImageFieldType, + zBoardFieldType, + zMainModelFieldType, + zSDXLMainModelFieldType, + zSDXLRefinerModelFieldType, + zVAEModelFieldType, + zLoRAModelFieldType, + zControlNetModelFieldType, + zIPAdapterModelFieldType, + zT2IAdapterModelFieldType, + zColorFieldType, + zSchedulerFieldType, +]); +export type StatefulFieldType = z.infer; + /** * Here we define the main field unions: * - FieldType diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.ts index a424bf8d4b..9805edfaf2 100644 --- a/invokeai/frontend/web/src/features/nodes/types/workflow.ts +++ b/invokeai/frontend/web/src/features/nodes/types/workflow.ts @@ -47,6 +47,7 @@ const zWorkflowEdgeDefault = zWorkflowEdgeBase.extend({ type: z.literal('default'), sourceHandle: z.string().trim().min(1), targetHandle: z.string().trim().min(1), + hidden: z.boolean().optional(), }); const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({ type: z.literal('collapsed'), diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addControlNetToLinearGraph.ts index 2feba262c2..110a20e5a7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addControlNetToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addControlNetToLinearGraph.ts @@ -29,7 +29,7 @@ export const addControlNetToLinearGraph = async ( assert(activeTabName !== 'generation', 'Tried to use addControlNetToLinearGraph on generation tab'); if (controlNets.length) { - // Even though denoise_latents' control input is collection or scalar, keep it simple and always use a collect + // Even though denoise_latents' control input is SINGLE_OR_COLLECTION, keep it simple and always use a collect const controlNetIterateNode: Invocation<'collect'> = { id: CONTROL_NET_COLLECT, type: 'collect', diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addIPAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addIPAdapterToLinearGraph.ts index e9d9bd4663..1f24463419 100644 --- 
a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addIPAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addIPAdapterToLinearGraph.ts @@ -25,7 +25,7 @@ export const addIPAdapterToLinearGraph = async ( }); if (ipAdapters.length) { - // Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect + // Even though denoise_latents' ip adapter input is SINGLE_OR_COLLECTION, keep it simple and always use a collect const ipAdapterCollectNode: Invocation<'collect'> = { id: IP_ADAPTER_COLLECT, type: 'collect', diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addT2IAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addT2IAdapterToLinearGraph.ts index 7c51d9488f..72cf9ca0f8 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addT2IAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addT2IAdapterToLinearGraph.ts @@ -28,7 +28,7 @@ export const addT2IAdaptersToLinearGraph = async ( ); if (t2iAdapters.length) { - // Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect + // Even though denoise_latents' t2i adapter input is SINGLE_OR_COLLECTION, keep it simple and always use a collect const t2iAdapterCollectNode: Invocation<'collect'> = { id: T2I_ADAPTER_COLLECT, type: 'collect', diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasImageToImageGraph.ts index 8f5fe9f2b8..5c89dcbf29 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasImageToImageGraph.ts @@ -330,6 +330,7 @@ export const buildCanvasImageToImageGraph = async ( clip_skip: clipSkip, strength, init_image: initialImage.image_name, + _canvas_objects: state.canvas.layerState.objects, }, CANVAS_OUTPUT ); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasInpaintGraph.ts index c995c38a3c..20304b8830 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasInpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata'; import { CANVAS_INPAINT_GRAPH, CANVAS_OUTPUT, @@ -421,6 +422,15 @@ export const buildCanvasInpaintGraph = async ( }); } + addCoreMetadataNode( + graph, + { + generation_mode: 'inpaint', + _canvas_objects: state.canvas.layerState.objects, + }, + CANVAS_OUTPUT + ); + // Add Seamless To Graph if (seamlessXAxis || seamlessYAxis) { addSeamlessToLinearGraph(state, graph, modelLoaderNodeId); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasOutpaintGraph.ts index e4a9b11b96..2c85b20222 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasOutpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; 
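The graph builders above note that denoise_latents' control, IP-Adapter, and T2I-Adapter inputs are SINGLE_OR_COLLECTION, yet they always insert a collect node and connect its output to that input. A rough sketch of that wiring is below; the node and edge shapes are simplified stand-ins, and the field names on the collect node are assumptions for illustration.

```ts
// Rough sketch: regardless of how many ControlNets are enabled, gather them into
// one collection and feed that to the denoiser's SINGLE_OR_COLLECTION input.
type SketchNode = { id: string; type: string };
type SketchEdge = {
  source: { node_id: string; field: string };
  destination: { node_id: string; field: string };
};

const CONTROL_NET_COLLECT = 'control_net_collect'; // illustrative IDs
const DENOISE_LATENTS = 'denoise_latents';

const buildControlNetCollect = (controlNetIds: string[]): { nodes: SketchNode[]; edges: SketchEdge[] } => {
  const nodes: SketchNode[] = [{ id: CONTROL_NET_COLLECT, type: 'collect' }];
  const edges: SketchEdge[] = controlNetIds.map((id) => ({
    // Each ControlNet feeds the collect node's item input...
    source: { node_id: id, field: 'control' },
    destination: { node_id: CONTROL_NET_COLLECT, field: 'item' },
  }));
  // ...and the collected output goes to the denoiser, even when there is only one ControlNet.
  edges.push({
    source: { node_id: CONTROL_NET_COLLECT, field: 'collection' },
    destination: { node_id: DENOISE_LATENTS, field: 'control' },
  });
  return { nodes, edges };
};
```

Using a collect node unconditionally keeps the linear-graph builders uniform: the same wiring works for one adapter or many, at the cost of one extra node in the graph.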
import type { RootState } from 'app/store/store'; +import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata'; import { CANVAS_OUTPAINT_GRAPH, CANVAS_OUTPUT, @@ -579,6 +580,15 @@ export const buildCanvasOutpaintGraph = async ( ); } + addCoreMetadataNode( + graph, + { + generation_mode: 'outpaint', + _canvas_objects: state.canvas.layerState.objects, + }, + CANVAS_OUTPUT + ); + // Add Seamless To Graph if (seamlessXAxis || seamlessYAxis) { addSeamlessToLinearGraph(state, graph, modelLoaderNodeId); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLImageToImageGraph.ts index 186dfa53b3..b4549ff582 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLImageToImageGraph.ts @@ -332,6 +332,7 @@ export const buildCanvasSDXLImageToImageGraph = async ( init_image: initialImage.image_name, positive_style_prompt: positiveStylePrompt, negative_style_prompt: negativeStylePrompt, + _canvas_objects: state.canvas.layerState.objects, }, CANVAS_OUTPUT ); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLInpaintGraph.ts index 277b713079..dfbe2436d2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLInpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata'; import { CANVAS_OUTPUT, INPAINT_CREATE_MASK, @@ -432,6 +433,15 @@ export const buildCanvasSDXLInpaintGraph = async ( }); } + addCoreMetadataNode( + graph, + { + generation_mode: 'sdxl_inpaint', + _canvas_objects: state.canvas.layerState.objects, + }, + CANVAS_OUTPUT + ); + // Add Seamless To Graph if (seamlessXAxis || seamlessYAxis) { addSeamlessToLinearGraph(state, graph, modelLoaderNodeId); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLOutpaintGraph.ts index b09d7d8b90..d58796575c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLOutpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata'; import { CANVAS_OUTPUT, INPAINT_CREATE_MASK, @@ -588,6 +589,15 @@ export const buildCanvasSDXLOutpaintGraph = async ( ); } + addCoreMetadataNode( + graph, + { + generation_mode: 'sdxl_outpaint', + _canvas_objects: state.canvas.layerState.objects, + }, + CANVAS_OUTPUT + ); + // Add Seamless To Graph if (seamlessXAxis || seamlessYAxis) { addSeamlessToLinearGraph(state, graph, modelLoaderNodeId); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLTextToImageGraph.ts index b2a8aa6ada..b9e8e011b3 100644 --- 
a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLTextToImageGraph.ts @@ -291,6 +291,7 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise steps, rand_device: use_cpu ? 'cpu' : 'cuda', scheduler, + _canvas_objects: state.canvas.layerState.objects, }, CANVAS_OUTPUT ); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasTextToImageGraph.ts index 8ce5134480..fe33ab5cf3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasTextToImageGraph.ts @@ -280,6 +280,7 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise { + if (!layer.isEnabled) { + return false; + } if (isControlAdapterLayer(layer)) { - if (!layer.isEnabled) { - return false; - } return isValidControlAdapter(layer.controlAdapter, base); } if (isIPAdapterLayer(layer)) { - if (!layer.isEnabled) { - return false; - } return isValidIPAdapter(layer.ipAdapter, base); } if (isInitialImageLayer(layer)) { - if (!layer.isEnabled) { - return false; - } - if (!layer.image) { - return false; - } - return true; + const hasImage = Boolean(layer.image); + return hasImage; } if (isRegionalGuidanceLayer(layer)) { const hasTextPrompt = Boolean(layer.positivePrompt || layer.negativePrompt); diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index f8097566c9..597779fd61 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -11,6 +11,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record = IntegerField: 0, IPAdapterModelField: undefined, LoRAModelField: undefined, + ModelIdentifierField: undefined, MainModelField: undefined, SchedulerField: 'euler', SDXLMainModelField: undefined, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index 3e8278ea6a..2b77274526 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -13,6 +13,7 @@ import type { IPAdapterModelFieldInputTemplate, LoRAModelFieldInputTemplate, MainModelFieldInputTemplate, + ModelIdentifierFieldInputTemplate, SchedulerFieldInputTemplate, SDXLMainModelFieldInputTemplate, SDXLRefinerModelFieldInputTemplate, @@ -30,26 +31,16 @@ import { isNumber, startCase } from 'lodash-es'; // eslint-disable-next-line @typescript-eslint/no-explicit-any type FieldInputTemplateBuilder = // valid `any`! 
- (arg: { - schemaObject: InvocationFieldSchema; - baseField: Omit; - isCollection: boolean; - isCollectionOrScalar: boolean; - }) => T; + (arg: { schemaObject: InvocationFieldSchema; baseField: Omit; fieldType: T['type'] }) => T; const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: IntegerFieldInputTemplate = { ...baseField, - type: { - name: 'IntegerField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? 0, }; @@ -79,16 +70,11 @@ const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: FloatFieldInputTemplate = { ...baseField, - type: { - name: 'FloatField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? 0, }; @@ -118,16 +104,11 @@ const buildFloatFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: StringFieldInputTemplate = { ...baseField, - type: { - name: 'StringField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? '', }; @@ -145,35 +126,39 @@ const buildStringFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: BooleanFieldInputTemplate = { ...baseField, - type: { - name: 'BooleanField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? false, }; return template; }; +const buildModelIdentifierFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: ModelIdentifierFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: MainModelFieldInputTemplate = { ...baseField, - type: { - name: 'MainModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -183,16 +168,11 @@ const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: SDXLMainModelFieldInputTemplate = { ...baseField, - type: { - name: 'SDXLMainModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -202,16 +182,11 @@ const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: SDXLRefinerModelFieldInputTemplate = { ...baseField, - type: { - name: 'SDXLRefinerModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? 
undefined, }; @@ -221,16 +196,11 @@ const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: VAEModelFieldInputTemplate = { ...baseField, - type: { - name: 'VAEModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -240,16 +210,11 @@ const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: LoRAModelFieldInputTemplate = { ...baseField, - type: { - name: 'LoRAModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -259,16 +224,11 @@ const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: ControlNetModelFieldInputTemplate = { ...baseField, - type: { - name: 'ControlNetModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -278,16 +238,11 @@ const buildControlNetModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: IPAdapterModelFieldInputTemplate = { ...baseField, - type: { - name: 'IPAdapterModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -297,16 +252,11 @@ const buildIPAdapterModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: T2IAdapterModelFieldInputTemplate = { ...baseField, - type: { - name: 'T2IAdapterModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -316,16 +266,11 @@ const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: BoardFieldInputTemplate = { ...baseField, - type: { - name: 'BoardField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -335,16 +280,11 @@ const buildBoardFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: ImageFieldInputTemplate = { ...baseField, - type: { - name: 'ImageField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -354,8 +294,7 @@ const buildImageFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { let options: EnumFieldInputTemplate['options'] = []; if (schemaObject.anyOf) { @@ -383,11 +322,7 @@ const buildEnumFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: ColorFieldInputTemplate = { ...baseField, - type: { - name: 'ColorField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? 
{ r: 127, g: 127, b: 127, a: 255 }, }; @@ -418,16 +348,11 @@ const buildColorFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: SchedulerFieldInputTemplate = { ...baseField, - type: { - name: 'SchedulerField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? 'euler', }; @@ -445,6 +370,7 @@ export const TEMPLATE_BUILDER_MAP: Record connection only inputs - type: fieldType, - default: undefined, // stateless --> no default value - }; - return template; + return template; + } else { + // This is a StatelessField, create it directly. + const template: StatelessFieldInputTemplate = { + ...baseField, + input: 'connection', // stateless --> connection only inputs + type: fieldType, + default: undefined, // stateless --> no default value + }; + + return template; + } }; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts index 8c789493ad..abbe2c3488 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts @@ -9,7 +9,7 @@ export const buildFieldOutputTemplate = ( ): FieldOutputTemplate => { const { title, description, ui_hidden, ui_type, ui_order } = fieldSchema; - const fieldOutputTemplate: FieldOutputTemplate = { + const template: FieldOutputTemplate = { fieldKind: 'output', name: fieldName, title: title ?? (fieldName ? startCase(fieldName) : ''), @@ -20,5 +20,5 @@ export const buildFieldOutputTemplate = ( ui_order, }; - return fieldOutputTemplate; + return template; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts index d7011ad6f8..3d3aff3cd6 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts @@ -4,6 +4,7 @@ import { UnsupportedPrimitiveTypeError, UnsupportedUnionError, } from 'features/nodes/types/error'; +import type { FieldType } from 'features/nodes/types/field'; import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; import { parseFieldType, refObjectToSchemaName } from 'features/nodes/util/schema/parseFieldType'; import { describe, expect, it } from 'vitest'; @@ -11,52 +12,52 @@ import { describe, expect, it } from 'vitest'; type ParseFieldTypeTestCase = { name: string; schema: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema; - expected: { name: string; isCollection: boolean; isCollectionOrScalar: boolean }; + expected: FieldType; }; const primitiveTypes: ParseFieldTypeTestCase[] = [ { - name: 'Scalar IntegerField', + name: 'SINGLE IntegerField', schema: { type: 'integer' }, - expected: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'IntegerField', cardinality: 'SINGLE' }, }, { - name: 'Scalar FloatField', + name: 'SINGLE FloatField', schema: { type: 'number' }, - expected: { name: 'FloatField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'FloatField', cardinality: 'SINGLE' }, }, { - name: 'Scalar StringField', + name: 'SINGLE StringField', schema: { type: 'string' }, - expected: { name: 'StringField', isCollection: false, isCollectionOrScalar: 
false }, + expected: { name: 'StringField', cardinality: 'SINGLE' }, }, { - name: 'Scalar BooleanField', + name: 'SINGLE BooleanField', schema: { type: 'boolean' }, - expected: { name: 'BooleanField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'BooleanField', cardinality: 'SINGLE' }, }, { - name: 'Collection IntegerField', + name: 'COLLECTION IntegerField', schema: { items: { type: 'integer' }, type: 'array' }, - expected: { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }, + expected: { name: 'IntegerField', cardinality: 'COLLECTION' }, }, { - name: 'Collection FloatField', + name: 'COLLECTION FloatField', schema: { items: { type: 'number' }, type: 'array' }, - expected: { name: 'FloatField', isCollection: true, isCollectionOrScalar: false }, + expected: { name: 'FloatField', cardinality: 'COLLECTION' }, }, { - name: 'Collection StringField', + name: 'COLLECTION StringField', schema: { items: { type: 'string' }, type: 'array' }, - expected: { name: 'StringField', isCollection: true, isCollectionOrScalar: false }, + expected: { name: 'StringField', cardinality: 'COLLECTION' }, }, { - name: 'Collection BooleanField', + name: 'COLLECTION BooleanField', schema: { items: { type: 'boolean' }, type: 'array' }, - expected: { name: 'BooleanField', isCollection: true, isCollectionOrScalar: false }, + expected: { name: 'BooleanField', cardinality: 'COLLECTION' }, }, { - name: 'CollectionOrScalar IntegerField', + name: 'SINGLE_OR_COLLECTION IntegerField', schema: { anyOf: [ { @@ -70,10 +71,10 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }, + expected: { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }, }, { - name: 'CollectionOrScalar FloatField', + name: 'SINGLE_OR_COLLECTION FloatField', schema: { anyOf: [ { @@ -87,10 +88,10 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'FloatField', isCollection: false, isCollectionOrScalar: true }, + expected: { name: 'FloatField', cardinality: 'SINGLE_OR_COLLECTION' }, }, { - name: 'CollectionOrScalar StringField', + name: 'SINGLE_OR_COLLECTION StringField', schema: { anyOf: [ { @@ -104,10 +105,10 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'StringField', isCollection: false, isCollectionOrScalar: true }, + expected: { name: 'StringField', cardinality: 'SINGLE_OR_COLLECTION' }, }, { - name: 'CollectionOrScalar BooleanField', + name: 'SINGLE_OR_COLLECTION BooleanField', schema: { anyOf: [ { @@ -121,13 +122,13 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'BooleanField', isCollection: false, isCollectionOrScalar: true }, + expected: { name: 'BooleanField', cardinality: 'SINGLE_OR_COLLECTION' }, }, ]; const complexTypes: ParseFieldTypeTestCase[] = [ { - name: 'Scalar ConditioningField', + name: 'SINGLE ConditioningField', schema: { allOf: [ { @@ -135,10 +136,10 @@ const complexTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'ConditioningField', cardinality: 'SINGLE' }, }, { - name: 'Nullable Scalar ConditioningField', + name: 'Nullable SINGLE ConditioningField', schema: { anyOf: [ { @@ -149,10 +150,10 @@ const complexTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 
'ConditioningField', cardinality: 'SINGLE' }, }, { - name: 'Collection ConditioningField', + name: 'COLLECTION ConditioningField', schema: { anyOf: [ { @@ -163,7 +164,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'ConditioningField', isCollection: true, isCollectionOrScalar: false }, + expected: { name: 'ConditioningField', cardinality: 'COLLECTION' }, }, { name: 'Nullable Collection ConditioningField', @@ -180,10 +181,10 @@ const complexTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'ConditioningField', isCollection: true, isCollectionOrScalar: false }, + expected: { name: 'ConditioningField', cardinality: 'COLLECTION' }, }, { - name: 'CollectionOrScalar ConditioningField', + name: 'SINGLE_OR_COLLECTION ConditioningField', schema: { anyOf: [ { @@ -197,10 +198,10 @@ const complexTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: true }, + expected: { name: 'ConditioningField', cardinality: 'SINGLE_OR_COLLECTION' }, }, { - name: 'Nullable CollectionOrScalar ConditioningField', + name: 'Nullable SINGLE_OR_COLLECTION ConditioningField', schema: { anyOf: [ { @@ -217,7 +218,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: true }, + expected: { name: 'ConditioningField', cardinality: 'SINGLE_OR_COLLECTION' }, }, ]; @@ -228,14 +229,14 @@ const specialCases: ParseFieldTypeTestCase[] = [ type: 'string', enum: ['large', 'base', 'small'], }, - expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'EnumField', cardinality: 'SINGLE' }, }, { name: 'String EnumField with one value', schema: { const: 'Some Value', }, - expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'EnumField', cardinality: 'SINGLE' }, }, { name: 'Explicit ui_type (SchedulerField)', @@ -244,7 +245,7 @@ const specialCases: ParseFieldTypeTestCase[] = [ enum: ['ddim', 'ddpm', 'deis'], ui_type: 'SchedulerField', }, - expected: { name: 'SchedulerField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'EnumField', cardinality: 'SINGLE' }, }, { name: 'Explicit ui_type (AnyField)', @@ -253,7 +254,7 @@ const specialCases: ParseFieldTypeTestCase[] = [ enum: ['ddim', 'ddpm', 'deis'], ui_type: 'AnyField', }, - expected: { name: 'AnyField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'EnumField', cardinality: 'SINGLE' }, }, { name: 'Explicit ui_type (CollectionField)', @@ -262,7 +263,7 @@ const specialCases: ParseFieldTypeTestCase[] = [ enum: ['ddim', 'ddpm', 'deis'], ui_type: 'CollectionField', }, - expected: { name: 'CollectionField', isCollection: true, isCollectionOrScalar: false }, + expected: { name: 'EnumField', cardinality: 'SINGLE' }, }, ]; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts index 13da6b3831..18dcd8fb21 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts @@ -6,14 +6,8 @@ import { UnsupportedUnionError, } from 'features/nodes/types/error'; import type { FieldType } from 'features/nodes/types/field'; -import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; -import { - isArraySchemaObject, - 
isInvocationFieldSchema, - isNonArraySchemaObject, - isRefObject, - isSchemaObject, -} from 'features/nodes/types/openapi'; +import type { OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; +import { isArraySchemaObject, isNonArraySchemaObject, isRefObject, isSchemaObject } from 'features/nodes/types/openapi'; import { t } from 'i18next'; import { isArray } from 'lodash-es'; import type { OpenAPIV3_1 } from 'openapi-types'; @@ -35,7 +29,7 @@ const OPENAPI_TO_FIELD_TYPE_MAP: Record = { boolean: 'BooleanField', }; -const isCollectionFieldType = (fieldType: string) => { +export const isCollectionFieldType = (fieldType: string) => { /** * CollectionField is `list[Any]` in the pydantic schema, but we need to distinguish between * it and other `list[Any]` fields, due to its special internal handling. @@ -48,25 +42,13 @@ const isCollectionFieldType = (fieldType: string) => { return false; }; -export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema): FieldType => { - if (isInvocationFieldSchema(schemaObject)) { - // Check if this field has an explicit type provided by the node schema - const { ui_type } = schemaObject; - if (ui_type) { - return { - name: ui_type, - isCollection: isCollectionFieldType(ui_type), - isCollectionOrScalar: false, - }; - } - } +export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType => { if (isSchemaObject(schemaObject)) { if (schemaObject.const) { // Fields with a single const value are defined as `Literal["value"]` in the pydantic schema - it's actually an enum return { name: 'EnumField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }; } if (!schemaObject.type) { @@ -82,8 +64,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | Invocation } return { name, - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }; } } else if (schemaObject.anyOf) { @@ -106,15 +87,14 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | Invocation return { name, - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }; } else if (isSchemaObject(filteredAnyOf[0])) { return parseFieldType(filteredAnyOf[0]); } } /** - * Handle CollectionOrScalar inputs, eg string | string[]. In OpenAPI, this is: + * Handle SINGLE_OR_COLLECTION inputs, eg string | string[]. In OpenAPI, this is: * - an `anyOf` with two items * - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items` * - the other is a `SchemaObject` or `ReferenceObject` of type T @@ -160,8 +140,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | Invocation if (firstType && firstType === secondType) { return { name: OPENAPI_TO_FIELD_TYPE_MAP[firstType] ?? firstType, - isCollection: false, - isCollectionOrScalar: true, // <-- don't forget, CollectionOrScalar type! + cardinality: 'SINGLE_OR_COLLECTION', }; } @@ -175,8 +154,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | Invocation } else if (schemaObject.enum) { return { name: 'EnumField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }; } else if (schemaObject.type) { if (schemaObject.type === 'array') { @@ -202,8 +180,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | Invocation } return { name, - isCollection: true, // <-- don't forget, collection! 
- isCollectionOrScalar: false, + cardinality: 'COLLECTION', }; } @@ -214,8 +191,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | Invocation } return { name, - isCollection: true, // <-- don't forget, collection! - isCollectionOrScalar: false, + cardinality: 'COLLECTION', }; } else if (!isArray(schemaObject.type)) { // This is an OpenAPI primitive - 'null', 'object', 'array', 'integer', 'number', 'string', 'boolean' @@ -230,8 +206,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | Invocation } return { name, - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }; } } @@ -242,8 +217,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | Invocation } return { name, - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }; } throw new FieldParseError(t('nodes.unableToParseFieldType')); diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts index 6c0a6635c7..656bdc9d64 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts @@ -1,790 +1,19 @@ +import { schema, templates } from 'features/nodes/store/util/testUtils'; import { parseSchema } from 'features/nodes/util/schema/parseSchema'; import { omit, pick } from 'lodash-es'; -import type { OpenAPIV3_1 } from 'openapi-types'; import { describe, expect, it } from 'vitest'; describe('parseSchema', () => { it('should parse the schema', () => { - const templates = parseSchema(schema); - expect(templates).toEqual(expected); + const parsed = parseSchema(schema); + expect(parsed).toEqual(templates); }); it('should omit denied nodes', () => { - const templates = parseSchema(schema, undefined, ['add']); - expect(templates).toEqual(omit(expected, 'add')); + const parsed = parseSchema(schema, undefined, ['add']); + expect(parsed).toEqual(omit(templates, 'add')); }); it('should include only allowed nodes', () => { - const templates = parseSchema(schema, ['add']); - expect(templates).toEqual(pick(expected, 'add')); + const parsed = parseSchema(schema, ['add']); + expect(parsed).toEqual(pick(templates, 'add')); }); }); - -const expected = { - add: { - title: 'Add Integers', - type: 'add', - version: '1.0.1', - tags: ['math', 'add'], - description: 'Adds two numbers', - outputType: 'integer_output', - inputs: { - a: { - name: 'a', - title: 'A', - required: false, - description: 'The first number', - fieldKind: 'input', - input: 'any', - ui_hidden: false, - type: { - name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, - }, - default: 0, - }, - b: { - name: 'b', - title: 'B', - required: false, - description: 'The second number', - fieldKind: 'input', - input: 'any', - ui_hidden: false, - type: { - name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, - }, - default: 0, - }, - }, - outputs: { - value: { - fieldKind: 'output', - name: 'value', - title: 'Value', - description: 'The output integer', - type: { - name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, - }, - ui_hidden: false, - }, - }, - useCache: true, - nodePack: 'invokeai', - classification: 'stable', - }, - scheduler: { - title: 'Scheduler', - type: 'scheduler', - version: '1.0.0', - tags: ['scheduler'], - description: 'Selects a scheduler.', - outputType: 'scheduler_output', - inputs: { - scheduler: { - 
name: 'scheduler', - title: 'Scheduler', - required: false, - description: 'Scheduler to use during inference', - fieldKind: 'input', - input: 'any', - ui_hidden: false, - ui_type: 'SchedulerField', - type: { - name: 'SchedulerField', - isCollection: false, - isCollectionOrScalar: false, - }, - default: 'euler', - }, - }, - outputs: { - scheduler: { - fieldKind: 'output', - name: 'scheduler', - title: 'Scheduler', - description: 'Scheduler to use during inference', - type: { - name: 'SchedulerField', - isCollection: false, - isCollectionOrScalar: false, - }, - ui_hidden: false, - ui_type: 'SchedulerField', - }, - }, - useCache: true, - nodePack: 'invokeai', - classification: 'stable', - }, - main_model_loader: { - title: 'Main Model', - type: 'main_model_loader', - version: '1.0.2', - tags: ['model'], - description: 'Loads a main model, outputting its submodels.', - outputType: 'model_loader_output', - inputs: { - model: { - name: 'model', - title: 'Model', - required: true, - description: 'Main model (UNet, VAE, CLIP) to load', - fieldKind: 'input', - input: 'direct', - ui_hidden: false, - ui_type: 'MainModelField', - type: { - name: 'MainModelField', - isCollection: false, - isCollectionOrScalar: false, - }, - }, - }, - outputs: { - vae: { - fieldKind: 'output', - name: 'vae', - title: 'VAE', - description: 'VAE', - type: { - name: 'VAEField', - isCollection: false, - isCollectionOrScalar: false, - }, - ui_hidden: false, - }, - clip: { - fieldKind: 'output', - name: 'clip', - title: 'CLIP', - description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count', - type: { - name: 'CLIPField', - isCollection: false, - isCollectionOrScalar: false, - }, - ui_hidden: false, - }, - unet: { - fieldKind: 'output', - name: 'unet', - title: 'UNet', - description: 'UNet (scheduler, LoRAs)', - type: { - name: 'UNetField', - isCollection: false, - isCollectionOrScalar: false, - }, - ui_hidden: false, - }, - }, - useCache: true, - nodePack: 'invokeai', - classification: 'stable', - }, -}; - -const schema = { - openapi: '3.1.0', - info: { - title: 'Invoke - Community Edition', - description: 'An API for invoking AI image operations', - version: '1.0.0', - }, - components: { - schemas: { - AddInvocation: { - properties: { - id: { - type: 'string', - title: 'Id', - description: 'The id of this instance of an invocation. 
Must be unique among all instances of invocations.', - field_kind: 'node_attribute', - }, - is_intermediate: { - type: 'boolean', - title: 'Is Intermediate', - description: 'Whether or not this is an intermediate invocation.', - default: false, - field_kind: 'node_attribute', - ui_type: 'IsIntermediate', - }, - use_cache: { - type: 'boolean', - title: 'Use Cache', - description: 'Whether or not to use the cache', - default: true, - field_kind: 'node_attribute', - }, - a: { - type: 'integer', - title: 'A', - description: 'The first number', - default: 0, - field_kind: 'input', - input: 'any', - orig_default: 0, - orig_required: false, - ui_hidden: false, - }, - b: { - type: 'integer', - title: 'B', - description: 'The second number', - default: 0, - field_kind: 'input', - input: 'any', - orig_default: 0, - orig_required: false, - ui_hidden: false, - }, - type: { - type: 'string', - enum: ['add'], - const: 'add', - title: 'type', - default: 'add', - field_kind: 'node_attribute', - }, - }, - type: 'object', - required: ['type', 'id'], - title: 'Add Integers', - description: 'Adds two numbers', - category: 'math', - classification: 'stable', - node_pack: 'invokeai', - tags: ['math', 'add'], - version: '1.0.1', - output: { - $ref: '#/components/schemas/IntegerOutput', - }, - class: 'invocation', - }, - IntegerOutput: { - description: 'Base class for nodes that output a single integer', - properties: { - value: { - description: 'The output integer', - field_kind: 'output', - title: 'Value', - type: 'integer', - ui_hidden: false, - }, - type: { - const: 'integer_output', - default: 'integer_output', - enum: ['integer_output'], - field_kind: 'node_attribute', - title: 'type', - type: 'string', - }, - }, - required: ['value', 'type', 'type'], - title: 'IntegerOutput', - type: 'object', - class: 'output', - }, - SchedulerInvocation: { - properties: { - id: { - type: 'string', - title: 'Id', - description: 'The id of this instance of an invocation. 
Must be unique among all instances of invocations.', - field_kind: 'node_attribute', - }, - is_intermediate: { - type: 'boolean', - title: 'Is Intermediate', - description: 'Whether or not this is an intermediate invocation.', - default: false, - field_kind: 'node_attribute', - ui_type: 'IsIntermediate', - }, - use_cache: { - type: 'boolean', - title: 'Use Cache', - description: 'Whether or not to use the cache', - default: true, - field_kind: 'node_attribute', - }, - scheduler: { - type: 'string', - enum: [ - 'ddim', - 'ddpm', - 'deis', - 'lms', - 'lms_k', - 'pndm', - 'heun', - 'heun_k', - 'euler', - 'euler_k', - 'euler_a', - 'kdpm_2', - 'kdpm_2_a', - 'dpmpp_2s', - 'dpmpp_2s_k', - 'dpmpp_2m', - 'dpmpp_2m_k', - 'dpmpp_2m_sde', - 'dpmpp_2m_sde_k', - 'dpmpp_sde', - 'dpmpp_sde_k', - 'unipc', - 'lcm', - 'tcd', - ], - title: 'Scheduler', - description: 'Scheduler to use during inference', - default: 'euler', - field_kind: 'input', - input: 'any', - orig_default: 'euler', - orig_required: false, - ui_hidden: false, - ui_type: 'SchedulerField', - }, - type: { - type: 'string', - enum: ['scheduler'], - const: 'scheduler', - title: 'type', - default: 'scheduler', - field_kind: 'node_attribute', - }, - }, - type: 'object', - required: ['type', 'id'], - title: 'Scheduler', - description: 'Selects a scheduler.', - category: 'latents', - classification: 'stable', - node_pack: 'invokeai', - tags: ['scheduler'], - version: '1.0.0', - output: { - $ref: '#/components/schemas/SchedulerOutput', - }, - class: 'invocation', - }, - SchedulerOutput: { - properties: { - scheduler: { - description: 'Scheduler to use during inference', - enum: [ - 'ddim', - 'ddpm', - 'deis', - 'lms', - 'lms_k', - 'pndm', - 'heun', - 'heun_k', - 'euler', - 'euler_k', - 'euler_a', - 'kdpm_2', - 'kdpm_2_a', - 'dpmpp_2s', - 'dpmpp_2s_k', - 'dpmpp_2m', - 'dpmpp_2m_k', - 'dpmpp_2m_sde', - 'dpmpp_2m_sde_k', - 'dpmpp_sde', - 'dpmpp_sde_k', - 'unipc', - 'lcm', - 'tcd', - ], - field_kind: 'output', - title: 'Scheduler', - type: 'string', - ui_hidden: false, - ui_type: 'SchedulerField', - }, - type: { - const: 'scheduler_output', - default: 'scheduler_output', - enum: ['scheduler_output'], - field_kind: 'node_attribute', - title: 'type', - type: 'string', - }, - }, - required: ['scheduler', 'type', 'type'], - title: 'SchedulerOutput', - type: 'object', - class: 'output', - }, - MainModelLoaderInvocation: { - properties: { - id: { - type: 'string', - title: 'Id', - description: 'The id of this instance of an invocation. 
Must be unique among all instances of invocations.', - field_kind: 'node_attribute', - }, - is_intermediate: { - type: 'boolean', - title: 'Is Intermediate', - description: 'Whether or not this is an intermediate invocation.', - default: false, - field_kind: 'node_attribute', - ui_type: 'IsIntermediate', - }, - use_cache: { - type: 'boolean', - title: 'Use Cache', - description: 'Whether or not to use the cache', - default: true, - field_kind: 'node_attribute', - }, - model: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Main model (UNet, VAE, CLIP) to load', - field_kind: 'input', - input: 'direct', - orig_required: true, - ui_hidden: false, - ui_type: 'MainModelField', - }, - type: { - type: 'string', - enum: ['main_model_loader'], - const: 'main_model_loader', - title: 'type', - default: 'main_model_loader', - field_kind: 'node_attribute', - }, - }, - type: 'object', - required: ['model', 'type', 'id'], - title: 'Main Model', - description: 'Loads a main model, outputting its submodels.', - category: 'model', - classification: 'stable', - node_pack: 'invokeai', - tags: ['model'], - version: '1.0.2', - output: { - $ref: '#/components/schemas/ModelLoaderOutput', - }, - class: 'invocation', - }, - ModelIdentifierField: { - properties: { - key: { - description: "The model's unique key", - title: 'Key', - type: 'string', - }, - hash: { - description: "The model's BLAKE3 hash", - title: 'Hash', - type: 'string', - }, - name: { - description: "The model's name", - title: 'Name', - type: 'string', - }, - base: { - allOf: [ - { - $ref: '#/components/schemas/BaseModelType', - }, - ], - description: "The model's base model type", - }, - type: { - allOf: [ - { - $ref: '#/components/schemas/ModelType', - }, - ], - description: "The model's type", - }, - submodel_type: { - anyOf: [ - { - $ref: '#/components/schemas/SubModelType', - }, - { - type: 'null', - }, - ], - default: null, - description: 'The submodel to load, if this is a main model', - }, - }, - required: ['key', 'hash', 'name', 'base', 'type'], - title: 'ModelIdentifierField', - type: 'object', - }, - BaseModelType: { - description: 'Base model type.', - enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], - title: 'BaseModelType', - type: 'string', - }, - ModelType: { - description: 'Model type.', - enum: ['onnx', 'main', 'vae', 'lora', 'controlnet', 'embedding', 'ip_adapter', 'clip_vision', 't2i_adapter'], - title: 'ModelType', - type: 'string', - }, - SubModelType: { - description: 'Submodel type.', - enum: [ - 'unet', - 'text_encoder', - 'text_encoder_2', - 'tokenizer', - 'tokenizer_2', - 'vae', - 'vae_decoder', - 'vae_encoder', - 'scheduler', - 'safety_checker', - ], - title: 'SubModelType', - type: 'string', - }, - ModelLoaderOutput: { - description: 'Model loader output', - properties: { - vae: { - allOf: [ - { - $ref: '#/components/schemas/VAEField', - }, - ], - description: 'VAE', - field_kind: 'output', - title: 'VAE', - ui_hidden: false, - }, - type: { - const: 'model_loader_output', - default: 'model_loader_output', - enum: ['model_loader_output'], - field_kind: 'node_attribute', - title: 'type', - type: 'string', - }, - clip: { - allOf: [ - { - $ref: '#/components/schemas/CLIPField', - }, - ], - description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count', - field_kind: 'output', - title: 'CLIP', - ui_hidden: false, - }, - unet: { - allOf: [ - { - $ref: '#/components/schemas/UNetField', - }, - ], - description: 'UNet (scheduler, LoRAs)', - field_kind: 'output', 
- title: 'UNet', - ui_hidden: false, - }, - }, - required: ['vae', 'type', 'clip', 'unet', 'type'], - title: 'ModelLoaderOutput', - type: 'object', - class: 'output', - }, - UNetField: { - properties: { - unet: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load unet submodel', - }, - scheduler: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load scheduler submodel', - }, - loras: { - description: 'LoRAs to apply on model loading', - items: { - $ref: '#/components/schemas/LoRAField', - }, - title: 'Loras', - type: 'array', - }, - seamless_axes: { - description: 'Axes("x" and "y") to which apply seamless', - items: { - type: 'string', - }, - title: 'Seamless Axes', - type: 'array', - }, - freeu_config: { - anyOf: [ - { - $ref: '#/components/schemas/FreeUConfig', - }, - { - type: 'null', - }, - ], - default: null, - description: 'FreeU configuration', - }, - }, - required: ['unet', 'scheduler', 'loras'], - title: 'UNetField', - type: 'object', - class: 'output', - }, - LoRAField: { - properties: { - lora: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load lora model', - }, - weight: { - description: 'Weight to apply to lora model', - title: 'Weight', - type: 'number', - }, - }, - required: ['lora', 'weight'], - title: 'LoRAField', - type: 'object', - class: 'output', - }, - FreeUConfig: { - description: - 'Configuration for the FreeU hyperparameters.\n- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu\n- https://github.com/ChenyangSi/FreeU', - properties: { - s1: { - description: - 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.', - maximum: 3.0, - minimum: -1.0, - title: 'S1', - type: 'number', - }, - s2: { - description: - 'Scaling factor for stage 2 to attenuate the contributions of the skip features. 
This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.', - maximum: 3.0, - minimum: -1.0, - title: 'S2', - type: 'number', - }, - b1: { - description: 'Scaling factor for stage 1 to amplify the contributions of backbone features.', - maximum: 3.0, - minimum: -1.0, - title: 'B1', - type: 'number', - }, - b2: { - description: 'Scaling factor for stage 2 to amplify the contributions of backbone features.', - maximum: 3.0, - minimum: -1.0, - title: 'B2', - type: 'number', - }, - }, - required: ['s1', 's2', 'b1', 'b2'], - title: 'FreeUConfig', - type: 'object', - class: 'output', - }, - VAEField: { - properties: { - vae: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load vae submodel', - }, - seamless_axes: { - description: 'Axes("x" and "y") to which apply seamless', - items: { - type: 'string', - }, - title: 'Seamless Axes', - type: 'array', - }, - }, - required: ['vae'], - title: 'VAEField', - type: 'object', - class: 'output', - }, - CLIPField: { - properties: { - tokenizer: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load tokenizer submodel', - }, - text_encoder: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load text_encoder submodel', - }, - skipped_layers: { - description: 'Number of skipped layers in text_encoder', - title: 'Skipped Layers', - type: 'integer', - }, - loras: { - description: 'LoRAs to apply on model loading', - items: { - $ref: '#/components/schemas/LoRAField', - }, - title: 'Loras', - type: 'array', - }, - }, - required: ['tokenizer', 'text_encoder', 'skipped_layers', 'loras'], - title: 'CLIPField', - type: 'object', - class: 'output', - }, - }, - }, -} as OpenAPIV3_1.Document; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts index 3178209f93..3981b759db 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts @@ -1,23 +1,29 @@ import { logger } from 'app/logging/logger'; +import { deepClone } from 'common/util/deepClone'; import { parseify } from 'common/util/serialize'; import type { Templates } from 'features/nodes/store/types'; import { FieldParseError } from 'features/nodes/types/error'; -import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; +import { + type FieldInputTemplate, + type FieldOutputTemplate, + type FieldType, + isStatefulFieldType, +} from 'features/nodes/types/field'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; -import type { InvocationSchemaObject } from 'features/nodes/types/openapi'; +import type { InvocationFieldSchema, InvocationSchemaObject } from 'features/nodes/types/openapi'; import { isInvocationFieldSchema, isInvocationOutputSchemaObject, isInvocationSchemaObject, } from 'features/nodes/types/openapi'; import { t } from 'i18next'; -import { reduce } from 'lodash-es'; +import { isEqual, reduce } from 'lodash-es'; import type { OpenAPIV3_1 } from 'openapi-types'; import { serializeError } from 'serialize-error'; import { buildFieldInputTemplate } from './buildFieldInputTemplate'; import { buildFieldOutputTemplate } from './buildFieldOutputTemplate'; -import { parseFieldType } from './parseFieldType'; +import { isCollectionFieldType, parseFieldType } from './parseFieldType'; 
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache']; const RESERVED_OUTPUT_FIELD_NAMES = ['type']; @@ -94,51 +100,39 @@ export const parseSchema = ( return inputsAccumulator; } - try { - const fieldType = parseFieldType(property); + const fieldTypeOverride: FieldType | null = property.ui_type + ? { + name: property.ui_type, + cardinality: isCollectionFieldType(property.ui_type) ? 'COLLECTION' : 'SINGLE', + } + : null; - if (isReservedFieldType(fieldType.name)) { - logger('nodes').trace( - { node: type, field: propertyName, schema: parseify(property) }, - 'Skipped reserved input field' - ); - return inputsAccumulator; - } + const originalFieldType = getFieldType(property, propertyName, type, 'input'); - const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType); - - inputsAccumulator[propertyName] = fieldInputTemplate; - } catch (e) { - if (e instanceof FieldParseError) { - logger('nodes').warn( - { - node: type, - field: propertyName, - schema: parseify(property), - }, - t('nodes.inputFieldTypeParseError', { - node: type, - field: propertyName, - message: e.message, - }) - ); - } else { - logger('nodes').warn( - { - node: type, - field: propertyName, - schema: parseify(property), - error: serializeError(e), - }, - t('nodes.inputFieldTypeParseError', { - node: type, - field: propertyName, - message: 'unknown error', - }) - ); - } + const fieldType = fieldTypeOverride ?? originalFieldType; + if (!fieldType) { + logger('nodes').trace( + { node: type, field: propertyName, schema: parseify(property) }, + 'Unable to parse field type' + ); + return inputsAccumulator; } + if (isReservedFieldType(fieldType.name)) { + logger('nodes').trace( + { node: type, field: propertyName, schema: parseify(property) }, + 'Skipped reserved input field' + ); + return inputsAccumulator; + } + + if (isStatefulFieldType(fieldType) && originalFieldType && !isEqual(originalFieldType, fieldType)) { + fieldType.originalType = deepClone(originalFieldType); + } + + const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType); + inputsAccumulator[propertyName] = fieldInputTemplate; + return inputsAccumulator; }, {} @@ -183,54 +177,31 @@ export const parseSchema = ( return outputsAccumulator; } - try { - const fieldType = parseFieldType(property); + const fieldTypeOverride: FieldType | null = property.ui_type + ? { + name: property.ui_type, + cardinality: isCollectionFieldType(property.ui_type) ? 'COLLECTION' : 'SINGLE', + } + : null; - if (!fieldType) { - logger('nodes').warn( - { - node: type, - field: propertyName, - schema: parseify(property), - }, - 'Missing output field type' - ); - return outputsAccumulator; - } + const originalFieldType = getFieldType(property, propertyName, type, 'output'); - const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType); - - outputsAccumulator[propertyName] = fieldOutputTemplate; - } catch (e) { - if (e instanceof FieldParseError) { - logger('nodes').warn( - { - node: type, - field: propertyName, - schema: parseify(property), - }, - t('nodes.outputFieldTypeParseError', { - node: type, - field: propertyName, - message: e.message, - }) - ); - } else { - logger('nodes').warn( - { - node: type, - field: propertyName, - schema: parseify(property), - error: serializeError(e), - }, - t('nodes.outputFieldTypeParseError', { - node: type, - field: propertyName, - message: 'unknown error', - }) - ); - } + const fieldType = fieldTypeOverride ?? 
originalFieldType; + if (!fieldType) { + logger('nodes').trace( + { node: type, field: propertyName, schema: parseify(property) }, + 'Unable to parse field type' + ); + return outputsAccumulator; } + + if (isStatefulFieldType(fieldType) && originalFieldType && !isEqual(originalFieldType, fieldType)) { + fieldType.originalType = deepClone(originalFieldType); + } + + const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType); + + outputsAccumulator[propertyName] = fieldOutputTemplate; return outputsAccumulator; }, {} as Record @@ -259,3 +230,45 @@ export const parseSchema = ( return invocations; }; + +const getFieldType = ( + property: InvocationFieldSchema, + propertyName: string, + type: string, + kind: 'input' | 'output' +): FieldType | null => { + try { + return parseFieldType(property); + } catch (e) { + const tKey = kind === 'input' ? 'nodes.inputFieldTypeParseError' : 'nodes.outputFieldTypeParseError'; + if (e instanceof FieldParseError) { + logger('nodes').warn( + { + node: type, + field: propertyName, + schema: parseify(property), + }, + t(tKey, { + node: type, + field: propertyName, + message: e.message, + }) + ); + } else { + logger('nodes').warn( + { + node: type, + field: propertyName, + schema: parseify(property), + error: serializeError(e), + }, + t(tKey, { + node: type, + field: propertyName, + message: 'unknown error', + }) + ); + } + return null; + } +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts index b164dde90e..cec8b0a2b7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts @@ -66,6 +66,7 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo target: edge.target, sourceHandle: edge.sourceHandle, targetHandle: edge.targetHandle, + hidden: edge.hidden, }); } else if (edge.type === 'collapsed') { newWorkflow.edges.push({ diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts index 32369b88c9..c7bcbf0953 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts @@ -1,12 +1,12 @@ import { deepClone } from 'common/util/deepClone'; import { $templates } from 'features/nodes/store/nodesSlice'; import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error'; -import type { FieldType } from 'features/nodes/types/field'; import type { InvocationNodeData } from 'features/nodes/types/invocation'; import { zSemVer } from 'features/nodes/types/semver'; import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from 'features/nodes/types/v1/fieldTypeMap'; import type { WorkflowV1 } from 'features/nodes/types/v1/workflowV1'; import { zWorkflowV1 } from 'features/nodes/types/v1/workflowV1'; +import type { StatelessFieldType } from 'features/nodes/types/v2/field'; import type { WorkflowV2 } from 'features/nodes/types/v2/workflow'; import { zWorkflowV2 } from 'features/nodes/types/v2/workflow'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; @@ -43,14 +43,14 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { if (!newFieldType) { throw new WorkflowMigrationError(t('nodes.unknownFieldType', { type: input.type })); } - (input.type as unknown as FieldType) = 
newFieldType; + (input.type as unknown as StatelessFieldType) = newFieldType; }); forEach(node.data.outputs, (output) => { const newFieldType = FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[output.type]; if (!newFieldType) { throw new WorkflowMigrationError(t('nodes.unknownFieldType', { type: output.type })); } - (output.type as unknown as FieldType) = newFieldType; + (output.type as unknown as StatelessFieldType) = newFieldType; }); // Add node pack const invocationTemplate = templates[node.data.type]; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts new file mode 100644 index 0000000000..6c74acd894 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts @@ -0,0 +1,116 @@ +import { img_resize, main_model_loader } from 'features/nodes/store/util/testUtils'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; +import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow'; +import { get } from 'lodash-es'; +import { describe, expect, it } from 'vitest'; + +describe('validateWorkflow', () => { + const workflow: WorkflowV3 = { + name: '', + author: '', + description: '', + version: '', + contact: '', + tags: '', + notes: '', + exposedFields: [], + meta: { version: '3.0.0', category: 'user' }, + nodes: [ + { + id: '94b1d596-f2f2-4c1c-bd5b-a79c62d947ad', + type: 'invocation', + data: { + id: '94b1d596-f2f2-4c1c-bd5b-a79c62d947ad', + type: 'main_model_loader', + version: '1.0.2', + label: '', + notes: '', + isOpen: true, + isIntermediate: true, + useCache: true, + inputs: { + model: { + name: 'model', + label: '', + value: { + key: '2c85d9e7-12cd-4e59-bb94-96d4502e99d4', + hash: 'random:aadc6641321ba17a324788ef1691f3584b382f0e7fa4a90be169f2a4ac77435c', + name: 'Analog-Diffusion2', + base: 'sd-1', + type: 'main', + }, + }, + }, + }, + position: { x: 394.62314170481613, y: -424.6962537790139 }, + }, + { + id: 'afad11b4-bb5c-45d1-b956-6c8e2357ee11', + type: 'invocation', + data: { + id: 'afad11b4-bb5c-45d1-b956-6c8e2357ee11', + type: 'img_resize', + version: '1.2.2', + label: '', + notes: '', + isOpen: true, + isIntermediate: true, + useCache: true, + inputs: { + board: { + name: 'board', + label: '', + value: { board_id: '99a08f09-8232-4b74-94a2-f8e136d62f8c' }, + }, + metadata: { name: 'metadata', label: 'Metadata' }, + image: { + name: 'image', + label: '', + value: { image_name: '96c124c8-f62f-4d4f-9788-72218469f298.png' }, + }, + width: { name: 'width', label: '', value: 512 }, + height: { name: 'height', label: '', value: 512 }, + resample_mode: { name: 'resample_mode', label: '', value: 'bicubic' }, + }, + }, + position: { x: -46.428806920557236, y: -479.6641524207518 }, + }, + ], + edges: [], + }; + const resolveTrue = async (): Promise => new Promise((resolve) => resolve(true)); + const resolveFalse = async (): Promise => new Promise((resolve) => resolve(false)); + it('should reset images that are inaccessible', async () => { + const validationResult = await validateWorkflow( + workflow, + { img_resize, main_model_loader }, + resolveFalse, + resolveTrue, + resolveTrue + ); + expect(validationResult.warnings.length).toBe(1); + expect(get(validationResult, 'workflow.nodes[1].data.inputs.image.value')).toBeUndefined(); + }); + it('should reset boards that are inaccessible', async () => { + const validationResult = await validateWorkflow( + workflow, + { img_resize, main_model_loader }, + resolveTrue, 
+ resolveFalse, + resolveTrue + ); + expect(validationResult.warnings.length).toBe(1); + expect(get(validationResult, 'workflow.nodes[1].data.inputs.board.value')).toBeUndefined(); + }); + it('should reset models that are inaccessible', async () => { + const validationResult = await validateWorkflow( + workflow, + { img_resize, main_model_loader }, + resolveTrue, + resolveTrue, + resolveFalse + ); + expect(validationResult.warnings.length).toBe(1); + expect(get(validationResult, 'workflow.nodes[0].data.inputs.model.value')).toBeUndefined(); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts index d2d3d64cb0..e757ab8e13 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts @@ -1,6 +1,11 @@ import type { JSONObject } from 'common/types'; import { parseify } from 'common/util/serialize'; import type { Templates } from 'features/nodes/store/types'; +import { + isBoardFieldInputInstance, + isImageFieldInputInstance, + isModelIdentifierFieldInputInstance, +} from 'features/nodes/types/field'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { isWorkflowInvocationNode } from 'features/nodes/types/workflow'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; @@ -20,6 +25,18 @@ type ValidateWorkflowResult = { warnings: WorkflowWarning[]; }; +const MODEL_FIELD_TYPES = [ + 'ModelIdentifier', + 'MainModelField', + 'SDXLMainModelField', + 'SDXLRefinerModelField', + 'VAEModelField', + 'LoRAModelField', + 'ControlNetModelField', + 'IPAdapterModelField', + 'T2IAdapterModelField', +]; + /** * Parses and validates a workflow: * - Parses the workflow schema, and migrates it to the latest version if necessary. @@ -27,11 +44,17 @@ type ValidateWorkflowResult = { * - Attempts to update nodes which have a mismatched version. * - Removes edges which are invalid. * @param workflow The raw workflow object (e.g. JSON.parse(stringifiedWorklow)) - * @param invocationTemplates The node templates to validate against. + * @param templates The node templates to validate against. * @throws {WorkflowVersionError} If the workflow version is not recognized. * @throws {z.ZodError} If there is a validation error. 
*/ -export const validateWorkflow = (workflow: unknown, invocationTemplates: Templates): ValidateWorkflowResult => { +export const validateWorkflow = async ( + workflow: unknown, + templates: Templates, + checkImageAccess: (name: string) => Promise, + checkBoardAccess: (id: string) => Promise, + checkModelAccess: (key: string) => Promise +): Promise => { // Parse the raw workflow data & migrate it to the latest version const _workflow = parseAndMigrateWorkflow(workflow); @@ -50,8 +73,8 @@ export const validateWorkflow = (workflow: unknown, invocationTemplat const invocationNodes = nodes.filter(isWorkflowInvocationNode); const keyedNodes = keyBy(invocationNodes, 'id'); - invocationNodes.forEach((node) => { - const template = invocationTemplates[node.data.type]; + for (const node of Object.values(invocationNodes)) { + const template = templates[node.data.type]; if (!template) { // This node's type template does not exist const message = t('nodes.missingTemplate', { @@ -62,7 +85,7 @@ export const validateWorkflow = (workflow: unknown, invocationTemplat message, data: parseify(node), }); - return; + continue; } if (getNeedsUpdate(node.data, template)) { @@ -75,15 +98,59 @@ export const validateWorkflow = (workflow: unknown, invocationTemplat message, data: parseify({ node, nodeTemplate: template }), }); - return; + continue; } - }); + + for (const input of Object.values(node.data.inputs)) { + const fieldTemplate = template.inputs[input.name]; + + if (!fieldTemplate) { + const message = t('nodes.missingFieldTemplate'); + warnings.push({ + message, + data: parseify({ node, nodeTemplate: template, input }), + }); + continue; + } + + // We need to confirm that all images, boards and models are accessible before loading, + // else the workflow could end up with stale data and an error state. + if (fieldTemplate.type.name === 'ImageField' && isImageFieldInputInstance(input) && input.value) { + const hasAccess = await checkImageAccess(input.value.image_name); + if (!hasAccess) { + const message = t('nodes.imageAccessError', { image_name: input.value.image_name }); + warnings.push({ message, data: parseify({ node, nodeTemplate: template, input }) }); + input.value = undefined; + } + } + if (fieldTemplate.type.name === 'BoardField' && isBoardFieldInputInstance(input) && input.value) { + const hasAccess = await checkBoardAccess(input.value.board_id); + if (!hasAccess) { + const message = t('nodes.boardAccessError', { board_id: input.value.board_id }); + warnings.push({ message, data: parseify({ node, nodeTemplate: template, input }) }); + input.value = undefined; + } + } + if ( + MODEL_FIELD_TYPES.includes(fieldTemplate.type.name) && + isModelIdentifierFieldInputInstance(input) && + input.value + ) { + const hasAccess = await checkModelAccess(input.value.key); + if (!hasAccess) { + const message = t('nodes.modelAccessError', { key: input.value.key }); + warnings.push({ message, data: parseify({ node, nodeTemplate: template, input }) }); + input.value = undefined; + } + } + } + } edges.forEach((edge, i) => { // Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow. const sourceNode = keyedNodes[edge.source]; const targetNode = keyedNodes[edge.target]; - const sourceTemplate = sourceNode ? invocationTemplates[sourceNode.data.type] : undefined; - const targetTemplate = targetNode ? invocationTemplates[targetNode.data.type] : undefined; + const sourceTemplate = sourceNode ? 
templates[sourceNode.data.type] : undefined; + const targetTemplate = targetNode ? templates[targetNode.data.type] : undefined; const issues: string[] = []; if (!sourceNode) { diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx index 4dfbbd1f7f..d75c98c064 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx @@ -39,7 +39,6 @@ export const ParamNegativePrompt = memo(() => { fontSize="sm" variant="darkFilled" paddingRight={30} - spellCheck={false} /> diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx index 9e00bea079..ebe64ea7dd 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx @@ -55,7 +55,6 @@ export const ParamPositivePrompt = memo(() => { onKeyDown={onKeyDown} variant="darkFilled" paddingRight={30} - spellCheck={false} /> diff --git a/invokeai/frontend/web/src/features/parameters/hooks/usePreselectedImage.ts b/invokeai/frontend/web/src/features/parameters/hooks/usePreselectedImage.ts index d892906fcd..683f5479f9 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/usePreselectedImage.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/usePreselectedImage.ts @@ -1,10 +1,10 @@ import { skipToken } from '@reduxjs/toolkit/query'; -import { useAppToaster } from 'app/components/Toaster'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice'; import { parseAndRecallAllMetadata } from 'features/metadata/util/handlers'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; +import { toast } from 'features/toast/toast'; import { setActiveTab } from 'features/ui/store/uiSlice'; import { t } from 'i18next'; import { useCallback, useEffect } from 'react'; @@ -16,7 +16,6 @@ export const usePreselectedImage = (selectedImage?: { }) => { const dispatch = useAppDispatch(); const optimalDimension = useAppSelector(selectOptimalDimension); - const toaster = useAppToaster(); const { currentData: selectedImageDto } = useGetImageDTOQuery(selectedImage?.imageName ?? 
skipToken); @@ -26,14 +25,13 @@ export const usePreselectedImage = (selectedImage?: { if (selectedImageDto) { dispatch(setInitialCanvasImage(selectedImageDto, optimalDimension)); dispatch(setActiveTab('canvas')); - toaster({ + toast({ + id: 'SENT_TO_CANVAS', title: t('toast.sentToUnifiedCanvas'), status: 'info', - duration: 2500, - isClosable: true, }); } - }, [selectedImageDto, dispatch, optimalDimension, toaster]); + }, [selectedImageDto, dispatch, optimalDimension]); const handleSendToImg2Img = useCallback(() => { if (selectedImageDto) { diff --git a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemDetail.tsx b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemDetail.tsx index b719ae0a92..d5b1e7dc59 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemDetail.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemDetail.tsx @@ -76,7 +76,7 @@ const QueueItemComponent = ({ queueItemDTO }: Props) => { - {queueItem?.error && ( + {(queueItem?.error_traceback || queueItem?.error_message) && ( { {t('common.error')} - {queueItem.error} + {queueItem?.error_traceback || queueItem?.error_message} )} diff --git a/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts b/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts index 8600525dae..9d92eabff8 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useCancelBatch.ts @@ -1,5 +1,5 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; +import { useAppSelector } from 'app/store/storeHooks'; +import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useCancelByBatchIdsMutation, useGetBatchStatusQuery } from 'services/api/endpoints/queue'; @@ -23,7 +23,6 @@ export const useCancelBatch = (batch_id: string) => { const [trigger, { isLoading }] = useCancelByBatchIdsMutation({ fixedCacheKey: 'cancelByBatchIds', }); - const dispatch = useAppDispatch(); const { t } = useTranslation(); const cancelBatch = useCallback(async () => { if (isCanceled) { @@ -31,21 +30,19 @@ export const useCancelBatch = (batch_id: string) => { } try { await trigger({ batch_ids: [batch_id] }).unwrap(); - dispatch( - addToast({ - title: t('queue.cancelBatchSucceeded'), - status: 'success', - }) - ); + toast({ + id: 'CANCEL_BATCH_SUCCEEDED', + title: t('queue.cancelBatchSucceeded'), + status: 'success', + }); } catch { - dispatch( - addToast({ - title: t('queue.cancelBatchFailed'), - status: 'error', - }) - ); + toast({ + id: 'CANCEL_BATCH_FAILED', + title: t('queue.cancelBatchFailed'), + status: 'error', + }); } - }, [batch_id, dispatch, isCanceled, t, trigger]); + }, [batch_id, isCanceled, t, trigger]); return { cancelBatch, isLoading, isCanceled, isDisabled: !isConnected }; }; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts b/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts index a0275076e3..057490ed99 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts @@ -1,5 +1,5 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; +import { useAppSelector } from 
'app/store/storeHooks'; +import { toast } from 'features/toast/toast'; import { isNil } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -9,7 +9,6 @@ export const useCancelCurrentQueueItem = () => { const isConnected = useAppSelector((s) => s.system.isConnected); const { data: queueStatus } = useGetQueueStatusQuery(); const [trigger, { isLoading }] = useCancelQueueItemMutation(); - const dispatch = useAppDispatch(); const { t } = useTranslation(); const currentQueueItemId = useMemo(() => queueStatus?.queue.item_id, [queueStatus?.queue.item_id]); const cancelQueueItem = useCallback(async () => { @@ -18,21 +17,19 @@ export const useCancelCurrentQueueItem = () => { } try { await trigger(currentQueueItemId).unwrap(); - dispatch( - addToast({ - title: t('queue.cancelSucceeded'), - status: 'success', - }) - ); + toast({ + id: 'QUEUE_CANCEL_SUCCEEDED', + title: t('queue.cancelSucceeded'), + status: 'success', + }); } catch { - dispatch( - addToast({ - title: t('queue.cancelFailed'), - status: 'error', - }) - ); + toast({ + id: 'QUEUE_CANCEL_FAILED', + title: t('queue.cancelFailed'), + status: 'error', + }); } - }, [currentQueueItemId, dispatch, t, trigger]); + }, [currentQueueItemId, t, trigger]); const isDisabled = useMemo(() => !isConnected || isNil(currentQueueItemId), [isConnected, currentQueueItemId]); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts b/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts index f22b98a7ee..268eca75cc 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useCancelQueueItem.ts @@ -1,5 +1,5 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; +import { useAppSelector } from 'app/store/storeHooks'; +import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useCancelQueueItemMutation } from 'services/api/endpoints/queue'; @@ -7,26 +7,23 @@ import { useCancelQueueItemMutation } from 'services/api/endpoints/queue'; export const useCancelQueueItem = (item_id: number) => { const isConnected = useAppSelector((s) => s.system.isConnected); const [trigger, { isLoading }] = useCancelQueueItemMutation(); - const dispatch = useAppDispatch(); const { t } = useTranslation(); const cancelQueueItem = useCallback(async () => { try { await trigger(item_id).unwrap(); - dispatch( - addToast({ - title: t('queue.cancelSucceeded'), - status: 'success', - }) - ); + toast({ + id: 'QUEUE_CANCEL_SUCCEEDED', + title: t('queue.cancelSucceeded'), + status: 'success', + }); } catch { - dispatch( - addToast({ - title: t('queue.cancelFailed'), - status: 'error', - }) - ); + toast({ + id: 'QUEUE_CANCEL_FAILED', + title: t('queue.cancelFailed'), + status: 'error', + }); } - }, [dispatch, item_id, t, trigger]); + }, [item_id, t, trigger]); return { cancelQueueItem, isLoading, isDisabled: !isConnected }; }; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts b/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts index 34b46d79b4..7ef9d93742 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useClearInvocationCache.ts @@ -1,12 +1,11 @@ -import { useAppDispatch, useAppSelector } from 
'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; +import { useAppSelector } from 'app/store/storeHooks'; +import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useClearInvocationCacheMutation, useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo'; export const useClearInvocationCache = () => { const { t } = useTranslation(); - const dispatch = useAppDispatch(); const { data: cacheStatus } = useGetInvocationCacheStatusQuery(); const isConnected = useAppSelector((s) => s.system.isConnected); const [trigger, { isLoading }] = useClearInvocationCacheMutation({ @@ -22,21 +21,19 @@ export const useClearInvocationCache = () => { try { await trigger().unwrap(); - dispatch( - addToast({ - title: t('invocationCache.clearSucceeded'), - status: 'success', - }) - ); + toast({ + id: 'INVOCATION_CACHE_CLEAR_SUCCEEDED', + title: t('invocationCache.clearSucceeded'), + status: 'success', + }); } catch { - dispatch( - addToast({ - title: t('invocationCache.clearFailed'), - status: 'error', - }) - ); + toast({ + id: 'INVOCATION_CACHE_CLEAR_FAILED', + title: t('invocationCache.clearFailed'), + status: 'error', + }); } - }, [isDisabled, trigger, dispatch, t]); + }, [isDisabled, trigger, t]); return { clearInvocationCache, isLoading, cacheStatus, isDisabled }; }; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useClearQueue.ts b/invokeai/frontend/web/src/features/queue/hooks/useClearQueue.ts index 0ca2528dab..ca7d1e4894 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useClearQueue.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useClearQueue.ts @@ -1,6 +1,6 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useClearQueueMutation, useGetQueueStatusQuery } from 'services/api/endpoints/queue'; @@ -21,21 +21,19 @@ export const useClearQueue = () => { try { await trigger().unwrap(); - dispatch( - addToast({ - title: t('queue.clearSucceeded'), - status: 'success', - }) - ); + toast({ + id: 'QUEUE_CLEAR_SUCCEEDED', + title: t('queue.clearSucceeded'), + status: 'success', + }); dispatch(listCursorChanged(undefined)); dispatch(listPriorityChanged(undefined)); } catch { - dispatch( - addToast({ - title: t('queue.clearFailed'), - status: 'error', - }) - ); + toast({ + id: 'QUEUE_CLEAR_FAILED', + title: t('queue.clearFailed'), + status: 'error', + }); } }, [queueStatus?.queue.total, trigger, dispatch, t]); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts b/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts index c7d5a575d2..371e9198e7 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useDisableInvocationCache.ts @@ -1,12 +1,11 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; +import { useAppSelector } from 'app/store/storeHooks'; +import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { 
useDisableInvocationCacheMutation, useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo'; export const useDisableInvocationCache = () => { const { t } = useTranslation(); - const dispatch = useAppDispatch(); const { data: cacheStatus } = useGetInvocationCacheStatusQuery(); const isConnected = useAppSelector((s) => s.system.isConnected); const [trigger, { isLoading }] = useDisableInvocationCacheMutation({ @@ -25,21 +24,19 @@ export const useDisableInvocationCache = () => { try { await trigger().unwrap(); - dispatch( - addToast({ - title: t('invocationCache.disableSucceeded'), - status: 'success', - }) - ); + toast({ + id: 'INVOCATION_CACHE_DISABLE_SUCCEEDED', + title: t('invocationCache.disableSucceeded'), + status: 'success', + }); } catch { - dispatch( - addToast({ - title: t('invocationCache.disableFailed'), - status: 'error', - }) - ); + toast({ + id: 'INVOCATION_CACHE_DISABLE_FAILED', + title: t('invocationCache.disableFailed'), + status: 'error', + }); } - }, [isDisabled, trigger, dispatch, t]); + }, [isDisabled, trigger, t]); return { disableInvocationCache, isLoading, cacheStatus, isDisabled }; }; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts index 22bb7aa97d..fb39cf7347 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnableInvocationCache.ts @@ -1,12 +1,11 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; +import { useAppSelector } from 'app/store/storeHooks'; +import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useEnableInvocationCacheMutation, useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo'; export const useEnableInvocationCache = () => { const { t } = useTranslation(); - const dispatch = useAppDispatch(); const { data: cacheStatus } = useGetInvocationCacheStatusQuery(); const isConnected = useAppSelector((s) => s.system.isConnected); const [trigger, { isLoading }] = useEnableInvocationCacheMutation({ @@ -25,21 +24,19 @@ export const useEnableInvocationCache = () => { try { await trigger().unwrap(); - dispatch( - addToast({ - title: t('invocationCache.enableSucceeded'), - status: 'success', - }) - ); + toast({ + id: 'INVOCATION_CACHE_ENABLE_SUCCEEDED', + title: t('invocationCache.enableSucceeded'), + status: 'success', + }); } catch { - dispatch( - addToast({ - title: t('invocationCache.enableFailed'), - status: 'error', - }) - ); + toast({ + id: 'INVOCATION_CACHE_ENABLE_FAILED', + title: t('invocationCache.enableFailed'), + status: 'error', + }); } - }, [isDisabled, trigger, dispatch, t]); + }, [isDisabled, trigger, t]); return { enableInvocationCache, isLoading, cacheStatus, isDisabled }; }; diff --git a/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts b/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts index 25c3423bcf..f5424c6b18 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/usePauseProcessor.ts @@ -1,11 +1,10 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; +import { useAppSelector } from 'app/store/storeHooks'; +import { toast } 
from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery, usePauseProcessorMutation } from 'services/api/endpoints/queue'; export const usePauseProcessor = () => { - const dispatch = useAppDispatch(); const { t } = useTranslation(); const isConnected = useAppSelector((s) => s.system.isConnected); const { data: queueStatus } = useGetQueueStatusQuery(); @@ -21,21 +20,19 @@ export const usePauseProcessor = () => { } try { await trigger().unwrap(); - dispatch( - addToast({ - title: t('queue.pauseSucceeded'), - status: 'success', - }) - ); + toast({ + id: 'PAUSE_SUCCEEDED', + title: t('queue.pauseSucceeded'), + status: 'success', + }); } catch { - dispatch( - addToast({ - title: t('queue.pauseFailed'), - status: 'error', - }) - ); + toast({ + id: 'PAUSE_FAILED', + title: t('queue.pauseFailed'), + status: 'error', + }); } - }, [isStarted, trigger, dispatch, t]); + }, [isStarted, trigger, t]); const isDisabled = useMemo(() => !isConnected || !isStarted, [isConnected, isStarted]); diff --git a/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts b/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts index 2cfec364fa..eaeabe5423 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/usePruneQueue.ts @@ -1,6 +1,6 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { listCursorChanged, listPriorityChanged } from 'features/queue/store/queueSlice'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery, usePruneQueueMutation } from 'services/api/endpoints/queue'; @@ -30,21 +30,19 @@ export const usePruneQueue = () => { } try { const data = await trigger().unwrap(); - dispatch( - addToast({ - title: t('queue.pruneSucceeded', { item_count: data.deleted }), - status: 'success', - }) - ); + toast({ + id: 'PRUNE_SUCCEEDED', + title: t('queue.pruneSucceeded', { item_count: data.deleted }), + status: 'success', + }); dispatch(listCursorChanged(undefined)); dispatch(listPriorityChanged(undefined)); } catch { - dispatch( - addToast({ - title: t('queue.pruneFailed'), - status: 'error', - }) - ); + toast({ + id: 'PRUNE_FAILED', + title: t('queue.pruneFailed'), + status: 'error', + }); } }, [finishedCount, trigger, dispatch, t]); diff --git a/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts b/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts index 6e3ea83d7d..851b268416 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useResumeProcessor.ts @@ -1,11 +1,10 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; +import { useAppSelector } from 'app/store/storeHooks'; +import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery, useResumeProcessorMutation } from 'services/api/endpoints/queue'; export const useResumeProcessor = () => { - const dispatch = useAppDispatch(); const isConnected = useAppSelector((s) => s.system.isConnected); const { data: queueStatus } = useGetQueueStatusQuery(); const { t } = 
useTranslation(); @@ -21,21 +20,19 @@ export const useResumeProcessor = () => { } try { await trigger().unwrap(); - dispatch( - addToast({ - title: t('queue.resumeSucceeded'), - status: 'success', - }) - ); + toast({ + id: 'PROCESSOR_RESUMED', + title: t('queue.resumeSucceeded'), + status: 'success', + }); } catch { - dispatch( - addToast({ - title: t('queue.resumeFailed'), - status: 'error', - }) - ); + toast({ + id: 'PROCESSOR_RESUME_FAILED', + title: t('queue.resumeFailed'), + status: 'error', + }); } - }, [isStarted, trigger, dispatch, t]); + }, [isStarted, trigger, t]); const isDisabled = useMemo(() => !isConnected || isStarted, [isConnected, isStarted]); diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/ParamSDXLNegativeStylePrompt.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/ParamSDXLNegativeStylePrompt.tsx index afc116a903..bba9e0b32d 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/ParamSDXLNegativeStylePrompt.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/ParamSDXLNegativeStylePrompt.tsx @@ -42,7 +42,6 @@ export const ParamSDXLNegativeStylePrompt = memo(() => { fontSize="sm" variant="darkFilled" paddingRight={30} - spellCheck={false} /> diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/ParamSDXLPositiveStylePrompt.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/ParamSDXLPositiveStylePrompt.tsx index b16730db45..3828136c74 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/ParamSDXLPositiveStylePrompt.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLPrompts/ParamSDXLPositiveStylePrompt.tsx @@ -39,7 +39,6 @@ export const ParamSDXLPositiveStylePrompt = memo(() => { fontSize="sm" variant="darkFilled" paddingRight={30} - spellCheck={false} /> diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/useClearIntermediates.ts b/invokeai/frontend/web/src/features/system/components/SettingsModal/useClearIntermediates.ts index e9f1debcf8..f392acb521 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/useClearIntermediates.ts +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/useClearIntermediates.ts @@ -1,7 +1,7 @@ import { useAppDispatch } from 'app/store/storeHooks'; import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useClearIntermediatesMutation, useGetIntermediatesCountQuery } from 'services/api/endpoints/images'; @@ -42,20 +42,18 @@ export const useClearIntermediates = (shouldShowClearIntermediates: boolean): Us .then((clearedCount) => { dispatch(controlAdaptersReset()); dispatch(resetCanvas()); - dispatch( - addToast({ - title: t('settings.intermediatesCleared', { count: clearedCount }), - status: 'info', - }) - ); + toast({ + id: 'INTERMEDIATES_CLEARED', + title: t('settings.intermediatesCleared', { count: clearedCount }), + status: 'info', + }); }) .catch(() => { - dispatch( - addToast({ - title: t('settings.intermediatesClearedFailed'), - status: 'error', - }) - ); + toast({ + id: 'INTERMEDIATES_CLEAR_FAILED', + title: t('settings.intermediatesClearedFailed'), + status: 'error', + }); }); }, [t, 
_clearIntermediates, dispatch, hasPendingItems]); diff --git a/invokeai/frontend/web/src/features/system/store/configSlice.ts b/invokeai/frontend/web/src/features/system/store/configSlice.ts index 76280df1ce..7d26dbd34c 100644 --- a/invokeai/frontend/web/src/features/system/store/configSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/configSlice.ts @@ -15,6 +15,7 @@ const baseDimensionConfig: NumericalParameterConfig = { }; const initialConfigState: AppConfig = { + isLocal: true, shouldUpdateImagesOnConnect: false, shouldFetchMetadataFromApi: false, disabledTabs: [], diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 17ddec5471..488410d5f3 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -1,25 +1,16 @@ -import type { UseToastOptions } from '@invoke-ai/ui-library'; import type { PayloadAction } from '@reduxjs/toolkit'; -import { createSlice, isAnyOf } from '@reduxjs/toolkit'; +import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; -import { calculateStepPercentage } from 'features/system/util/calculateStepPercentage'; -import { makeToast } from 'features/system/util/makeToast'; -import { t } from 'i18next'; -import { startCase } from 'lodash-es'; import type { LogLevelName } from 'roarr'; import { socketConnected, socketDisconnected, socketGeneratorProgress, - socketGraphExecutionStateComplete, socketInvocationComplete, - socketInvocationError, - socketInvocationRetrievalError, socketInvocationStarted, - socketModelLoadCompleted, + socketModelLoadComplete, socketModelLoadStarted, socketQueueItemStatusChanged, - socketSessionRetrievalError, } from 'services/events/actions'; import type { Language, SystemState } from './types'; @@ -29,7 +20,6 @@ const initialSystemState: SystemState = { isConnected: false, shouldConfirmOnDelete: true, enableImageDebugging: false, - toastQueue: [], denoiseProgress: null, shouldAntialiasProgressImage: false, consoleLogLevel: 'debug', @@ -39,6 +29,7 @@ const initialSystemState: SystemState = { shouldUseWatermarker: false, shouldEnableInformationalPopovers: false, status: 'DISCONNECTED', + cancellations: [], }; export const systemSlice = createSlice({ @@ -51,12 +42,6 @@ export const systemSlice = createSlice({ setEnableImageDebugging: (state, action: PayloadAction) => { state.enableImageDebugging = action.payload; }, - addToast: (state, action: PayloadAction) => { - state.toastQueue.push(action.payload); - }, - clearToastQueue: (state) => { - state.toastQueue = []; - }, consoleLogLevelChanged: (state, action: PayloadAction) => { state.consoleLogLevel = action.payload; }, @@ -102,6 +87,7 @@ export const systemSlice = createSlice({ * Invocation Started */ builder.addCase(socketInvocationStarted, (state) => { + state.cancellations = []; state.denoiseProgress = null; state.status = 'PROCESSING'; }); @@ -110,20 +96,18 @@ export const systemSlice = createSlice({ * Generator Progress */ builder.addCase(socketGeneratorProgress, (state, action) => { - const { - step, - total_steps, - order, - progress_image, - graph_execution_state_id: session_id, - queue_batch_id: batch_id, - } = action.payload.data; + const { step, total_steps, progress_image, session_id, batch_id, percentage } = action.payload.data; + + if (state.cancellations.includes(session_id)) { + // Do not update the progress if this session has been 
cancelled. This prevents a race condition where we get a + // progress update after the session has been cancelled. + return; + } state.denoiseProgress = { step, total_steps, - order, - percentage: calculateStepPercentage(step, total_steps, order), + percentage, progress_image, session_id, batch_id, @@ -140,51 +124,27 @@ export const systemSlice = createSlice({ state.status = 'CONNECTED'; }); - /** - * Graph Execution State Complete - */ - builder.addCase(socketGraphExecutionStateComplete, (state) => { - state.denoiseProgress = null; - state.status = 'CONNECTED'; - }); - builder.addCase(socketModelLoadStarted, (state) => { state.status = 'LOADING_MODEL'; }); - builder.addCase(socketModelLoadCompleted, (state) => { + builder.addCase(socketModelLoadComplete, (state) => { state.status = 'CONNECTED'; }); builder.addCase(socketQueueItemStatusChanged, (state, action) => { - if (['completed', 'canceled', 'failed'].includes(action.payload.data.queue_item.status)) { + if (['completed', 'canceled', 'failed'].includes(action.payload.data.status)) { state.status = 'CONNECTED'; state.denoiseProgress = null; + state.cancellations.push(action.payload.data.session_id); } }); - - // *** Matchers - must be after all cases *** - - /** - * Any server error - */ - builder.addMatcher(isAnyServerError, (state, action) => { - state.toastQueue.push( - makeToast({ - title: t('toast.serverError'), - status: 'error', - description: startCase(action.payload.data.error_type), - }) - ); - }); }, }); export const { setShouldConfirmOnDelete, setEnableImageDebugging, - addToast, - clearToastQueue, consoleLogLevelChanged, shouldLogToConsoleChanged, shouldAntialiasProgressImageChanged, @@ -194,8 +154,6 @@ export const { setShouldEnableInformationalPopovers, } = systemSlice.actions; -const isAnyServerError = isAnyOf(socketInvocationError, socketSessionRetrievalError, socketInvocationRetrievalError); - export const selectSystemSlice = (state: RootState) => state.system; /* eslint-disable-next-line @typescript-eslint/no-explicit-any */ @@ -210,5 +168,5 @@ export const systemPersistConfig: PersistConfig = { name: systemSlice.name, initialState: initialSystemState, migrate: migrateSystemState, - persistDenylist: ['isConnected', 'denoiseProgress', 'status'], + persistDenylist: ['isConnected', 'denoiseProgress', 'status', 'cancellations'], }; diff --git a/invokeai/frontend/web/src/features/system/store/types.ts b/invokeai/frontend/web/src/features/system/store/types.ts index 430df9aa7d..d896dee5f5 100644 --- a/invokeai/frontend/web/src/features/system/store/types.ts +++ b/invokeai/frontend/web/src/features/system/store/types.ts @@ -1,4 +1,3 @@ -import type { UseToastOptions } from '@invoke-ai/ui-library'; import type { LogLevel } from 'app/logging/logger'; import type { ProgressImage } from 'services/events/types'; import { z } from 'zod'; @@ -11,7 +10,6 @@ type DenoiseProgress = { progress_image: ProgressImage | null | undefined; step: number; total_steps: number; - order: number; percentage: number; }; @@ -47,7 +45,6 @@ export interface SystemState { isConnected: boolean; shouldConfirmOnDelete: boolean; enableImageDebugging: boolean; - toastQueue: UseToastOptions[]; denoiseProgress: DenoiseProgress | null; consoleLogLevel: LogLevel; shouldLogToConsole: boolean; @@ -57,4 +54,5 @@ export interface SystemState { shouldUseWatermarker: boolean; status: SystemStatus; shouldEnableInformationalPopovers: boolean; + cancellations: string[]; } diff --git a/invokeai/frontend/web/src/features/system/util/calculateStepPercentage.ts 
b/invokeai/frontend/web/src/features/system/util/calculateStepPercentage.ts deleted file mode 100644 index 70902e4a92..0000000000 --- a/invokeai/frontend/web/src/features/system/util/calculateStepPercentage.ts +++ /dev/null @@ -1,13 +0,0 @@ -export const calculateStepPercentage = (step: number, total_steps: number, order: number) => { - if (total_steps === 0) { - return 0; - } - - // we add one extra to step so that the progress bar will be full when denoise completes - - if (order === 2) { - return Math.floor((step + 1 + 1) / 2) / Math.floor((total_steps + 1) / 2); - } - - return (step + 1 + 1) / (total_steps + 1); -}; diff --git a/invokeai/frontend/web/src/features/system/util/makeToast.ts b/invokeai/frontend/web/src/features/system/util/makeToast.ts deleted file mode 100644 index aa77fd60ae..0000000000 --- a/invokeai/frontend/web/src/features/system/util/makeToast.ts +++ /dev/null @@ -1,20 +0,0 @@ -import type { UseToastOptions } from '@invoke-ai/ui-library'; - -export type MakeToastArg = string | UseToastOptions; - -/** - * Makes a toast from a string or a UseToastOptions object. - * If a string is passed, the toast will have the status 'info' and will be closable with a duration of 2500ms. - */ -export const makeToast = (arg: MakeToastArg): UseToastOptions => { - if (typeof arg === 'string') { - return { - title: arg, - status: 'info', - isClosable: true, - duration: 2500, - }; - } - - return { status: 'info', isClosable: true, duration: 2500, ...arg }; -}; diff --git a/invokeai/frontend/web/src/features/toast/ErrorToastDescription.tsx b/invokeai/frontend/web/src/features/toast/ErrorToastDescription.tsx new file mode 100644 index 0000000000..7b23e5534b --- /dev/null +++ b/invokeai/frontend/web/src/features/toast/ErrorToastDescription.tsx @@ -0,0 +1,59 @@ +import { Flex, IconButton, Text } from '@invoke-ai/ui-library'; +import { t } from 'i18next'; +import { useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiCopyBold } from 'react-icons/pi'; + +function onCopy(sessionId: string) { + navigator.clipboard.writeText(sessionId); +} + +const ERROR_TYPE_TO_TITLE: Record = { + OutOfMemoryError: 'toast.outOfMemoryError', +}; + +const COMMERCIAL_ERROR_TYPE_TO_DESC: Record = { + OutOfMemoryError: 'toast.outOfMemoryErrorDesc', +}; + +export const getTitleFromErrorType = (errorType: string) => { + return t(ERROR_TYPE_TO_TITLE[errorType] ?? 'toast.serverError'); +}; + +type Props = { errorType: string; errorMessage?: string | null; sessionId: string; isLocal: boolean }; + +export default function ErrorToastDescription({ errorType, errorMessage, sessionId, isLocal }: Props) { + const { t } = useTranslation(); + const description = useMemo(() => { + // Special handling for commercial error types + const descriptionTKey = isLocal ? 
null : COMMERCIAL_ERROR_TYPE_TO_DESC[errorType]; + if (descriptionTKey) { + return t(descriptionTKey); + } + if (errorMessage) { + return `${errorType}: ${errorMessage}`; + } + }, [errorMessage, errorType, isLocal, t]); + return ( + + {description && {description}} + {!isLocal && ( + + + {t('toast.sessionRef', { sessionId })} + + } + onClick={onCopy.bind(null, sessionId)} + variant="ghost" + sx={sx} + /> + + )} + + ); +} + +const sx = { svg: { fill: 'base.50' } }; diff --git a/invokeai/frontend/web/src/features/toast/toast.ts b/invokeai/frontend/web/src/features/toast/toast.ts new file mode 100644 index 0000000000..2bb499a854 --- /dev/null +++ b/invokeai/frontend/web/src/features/toast/toast.ts @@ -0,0 +1,122 @@ +import type { UseToastOptions } from '@invoke-ai/ui-library'; +import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library'; +import { map } from 'nanostores'; + +const toastApi = createStandaloneToast({ + theme: theme, + defaultOptions: TOAST_OPTIONS.defaultOptions, +}).toast; + +// Slightly modified version of UseToastOptions +type ToastConfig = Omit & { + // Only string - Chakra allows numbers + id?: string; +}; + +type ToastArg = ToastConfig & { + /** + * Whether to append the number of times this toast has been shown to the title. Defaults to true. + * @example + * toast({ title: 'Hello', withCount: true }); + * // first toast: 'Hello' + * // second toast: 'Hello (2)' + */ + withCount?: boolean; + /** + * Whether to update the description when updating the toast. Defaults to true. + * @example + * // updateDescription: true + * toast({ title: 'Hello', description: 'Foo' }); // Foo + * toast({ title: 'Hello', description: 'Bar' }); // Bar + * @example + * // updateDescription: false + * toast({ title: 'Hello', description: 'Foo' }); // Foo + * toast({ title: 'Hello', description: 'Bar' }); // Foo + */ + updateDescription?: boolean; +}; + +type ToastInternalState = { + id: string; + config: ToastConfig; + count: number; +}; + +// We expose a limited API for the toast +type ToastApi = { + getState: () => ToastInternalState | null; + close: () => void; + isActive: () => boolean; +}; + +// Store each toast state by id, allowing toast consumers to not worry about persistent ids and updating and such +const $toastMap = map>({}); + +// Helpers to get the getters for the toast API +const getIsActive = (id: string) => () => toastApi.isActive(id); +const getClose = (id: string) => () => toastApi.close(id); +const getGetState = (id: string) => () => $toastMap.get()[id] ?? null; + +/** + * Creates a toast with the given config. If the toast with the same id already exists, it will be updated. + * When a toast is updated, its title, description, status and duration will be overwritten by the new config. + * Use `updateDescription: false` to keep the description when updating. + * Set duration to `null` to make the toast persistent. + * @param arg The toast config. + * @returns An object with methods to get the toast state, close the toast and check if the toast is active + */ +export const toast = (arg: ToastArg): ToastApi => { + // All toasts need an id, set a random one if not provided + const id = arg.id ?? 
crypto.randomUUID(); + if (!arg.id) { + arg.id = id; + } + if (arg.withCount === undefined) { + arg.withCount = true; + } + if (arg.updateDescription === undefined) { + arg.updateDescription = true; + } + let state = $toastMap.get()[arg.id]; + if (!state) { + // First time caller, create and set the state + state = { id, config: parseConfig(null, id, arg, 1), count: 1 }; + $toastMap.setKey(id, state); + // Create the toast + toastApi(state.config); + } else { + // This toast is already active, update its state + state.count += 1; + state.config = parseConfig(state, id, arg, state.count); + $toastMap.setKey(id, state); + // Update the toast itself + toastApi.update(id, state.config); + } + return { getState: getGetState(id), close: getClose(id), isActive: getIsActive(id) }; +}; + +/** + * Give a toast id, arg and current count, returns the parsed toast config (including dynamic title and description) + * @param state The current state of the toast or null if it doesn't exist + * @param id The id of the toast + * @param arg The arg passed to the toast function + * @param count The current call count of the toast + * @returns The parsed toast config + */ +const parseConfig = (state: ToastInternalState | null, id: string, arg: ToastArg, count: number): ToastConfig => { + const onCloseComplete = () => { + $toastMap.setKey(id, undefined); + if (arg.onCloseComplete) { + arg.onCloseComplete(); + } + }; + const title = arg.withCount && count > 1 ? `${arg.title} (${count})` : arg.title; + const description = !arg.updateDescription && state ? state.config.description : arg.description; + const config: ToastConfig = { + ...arg, + title, + description, + onCloseComplete, + }; + return config; +}; diff --git a/invokeai/frontend/web/src/features/workflowLibrary/components/NewWorkflowConfirmationAlertDialog.tsx b/invokeai/frontend/web/src/features/workflowLibrary/components/NewWorkflowConfirmationAlertDialog.tsx index b01d259da7..701441b093 100644 --- a/invokeai/frontend/web/src/features/workflowLibrary/components/NewWorkflowConfirmationAlertDialog.tsx +++ b/invokeai/frontend/web/src/features/workflowLibrary/components/NewWorkflowConfirmationAlertDialog.tsx @@ -2,8 +2,7 @@ import { ConfirmationAlertDialog, Flex, Text, useDisclosure } from '@invoke-ai/u import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; import { workflowModeChanged } from 'features/nodes/store/workflowSlice'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; @@ -21,14 +20,11 @@ export const NewWorkflowConfirmationAlertDialog = memo((props: Props) => { dispatch(nodeEditorReset()); dispatch(workflowModeChanged('edit')); - dispatch( - addToast( - makeToast({ - title: t('workflows.newWorkflowCreated'), - status: 'success', - }) - ) - ); + toast({ + id: 'NEW_WORKFLOW_CREATED', + title: t('workflows.newWorkflowCreated'), + status: 'success', + }); onClose(); }, [dispatch, onClose, t]); diff --git a/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryMenu/LoadWorkflowFromGraphMenuItem.tsx b/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryMenu/LoadWorkflowFromGraphMenuItem.tsx index 8f3cb0c6f6..8006ca937f 100644 --- 
a/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryMenu/LoadWorkflowFromGraphMenuItem.tsx +++ b/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryMenu/LoadWorkflowFromGraphMenuItem.tsx @@ -1,15 +1,19 @@ import { MenuItem } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { $templates } from 'features/nodes/store/nodesSlice'; import { useLoadWorkflowFromGraphModal } from 'features/workflowLibrary/components/LoadWorkflowFromGraphModal/LoadWorkflowFromGraphModal'; +import { size } from 'lodash-es'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiFlaskBold } from 'react-icons/pi'; const LoadWorkflowFromGraphMenuItem = () => { const { t } = useTranslation(); + const templates = useStore($templates); const { onOpen } = useLoadWorkflowFromGraphModal(); return ( - } onClick={onOpen}> + } onClick={onOpen} isDisabled={!size(templates)}> {t('workflows.loadFromGraph')} ); diff --git a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useDeleteLibraryWorkflow.ts b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useDeleteLibraryWorkflow.ts index 7d2c636e9c..7b93a34f83 100644 --- a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useDeleteLibraryWorkflow.ts +++ b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useDeleteLibraryWorkflow.ts @@ -1,5 +1,4 @@ -import { useToast } from '@invoke-ai/ui-library'; -import { useAppToaster } from 'app/components/Toaster'; +import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useDeleteWorkflowMutation, workflowsApi } from 'services/api/endpoints/workflows'; @@ -17,8 +16,6 @@ type UseDeleteLibraryWorkflowReturn = { type UseDeleteLibraryWorkflow = (arg: UseDeleteLibraryWorkflowOptions) => UseDeleteLibraryWorkflowReturn; export const useDeleteLibraryWorkflow: UseDeleteLibraryWorkflow = ({ onSuccess, onError }) => { - const toaster = useAppToaster(); - const toast = useToast(); const { t } = useTranslation(); const [_deleteWorkflow, deleteWorkflowResult] = useDeleteWorkflowMutation(); @@ -26,21 +23,21 @@ export const useDeleteLibraryWorkflow: UseDeleteLibraryWorkflow = ({ onSuccess, async (workflow_id: string) => { try { await _deleteWorkflow(workflow_id).unwrap(); - toaster({ + toast({ + id: 'WORKFLOW_DELETED', title: t('toast.workflowDeleted'), }); onSuccess && onSuccess(); } catch { - if (!toast.isActive(`auth-error-toast-${workflowsApi.endpoints.deleteWorkflow.name}`)) { - toaster({ - title: t('toast.problemDeletingWorkflow'), - status: 'error', - }); - } + toast({ + id: `AUTH_ERROR_TOAST_${workflowsApi.endpoints.deleteWorkflow.name}`, + title: t('toast.problemDeletingWorkflow'), + status: 'error', + }); onError && onError(); } }, - [_deleteWorkflow, toaster, t, onSuccess, onError, toast] + [_deleteWorkflow, t, onSuccess, onError] ); return { deleteWorkflow, deleteWorkflowResult }; diff --git a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow.ts b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow.ts index 7ea9329540..12c302f9c9 100644 --- a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow.ts +++ b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow.ts @@ -1,6 +1,6 @@ -import { useAppToaster } from 'app/components/Toaster'; import { useAppDispatch } from 
'app/store/storeHooks'; import { workflowLoadRequested } from 'features/nodes/store/actions'; +import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useLazyGetImageWorkflowQuery } from 'services/api/endpoints/images'; @@ -21,7 +21,6 @@ type UseGetAndLoadEmbeddedWorkflow = ( export const useGetAndLoadEmbeddedWorkflow: UseGetAndLoadEmbeddedWorkflow = ({ onSuccess, onError }) => { const dispatch = useAppDispatch(); - const toaster = useAppToaster(); const { t } = useTranslation(); const [_getAndLoadEmbeddedWorkflow, getAndLoadEmbeddedWorkflowResult] = useLazyGetImageWorkflowQuery(); const getAndLoadEmbeddedWorkflow = useCallback( @@ -33,20 +32,22 @@ export const useGetAndLoadEmbeddedWorkflow: UseGetAndLoadEmbeddedWorkflow = ({ o // No toast - the listener for this action does that after the workflow is loaded onSuccess && onSuccess(); } else { - toaster({ + toast({ + id: 'PROBLEM_RETRIEVING_WORKFLOW', title: t('toast.problemRetrievingWorkflow'), status: 'error', }); } } catch { - toaster({ + toast({ + id: 'PROBLEM_RETRIEVING_WORKFLOW', title: t('toast.problemRetrievingWorkflow'), status: 'error', }); onError && onError(); } }, - [_getAndLoadEmbeddedWorkflow, dispatch, onSuccess, toaster, t, onError] + [_getAndLoadEmbeddedWorkflow, dispatch, onSuccess, t, onError] ); return { getAndLoadEmbeddedWorkflow, getAndLoadEmbeddedWorkflowResult }; diff --git a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow.ts b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow.ts index f616812175..89933999bd 100644 --- a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow.ts +++ b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow.ts @@ -1,5 +1,4 @@ import { useToast } from '@invoke-ai/ui-library'; -import { useAppToaster } from 'app/components/Toaster'; import { useAppDispatch } from 'app/store/storeHooks'; import { workflowLoadRequested } from 'features/nodes/store/actions'; import { useCallback } from 'react'; @@ -20,7 +19,6 @@ type UseGetAndLoadLibraryWorkflow = (arg: UseGetAndLoadLibraryWorkflowOptions) = export const useGetAndLoadLibraryWorkflow: UseGetAndLoadLibraryWorkflow = ({ onSuccess, onError }) => { const dispatch = useAppDispatch(); - const toaster = useAppToaster(); const toast = useToast(); const { t } = useTranslation(); const [_getAndLoadWorkflow, getAndLoadWorkflowResult] = useLazyGetWorkflowQuery(); @@ -33,16 +31,15 @@ export const useGetAndLoadLibraryWorkflow: UseGetAndLoadLibraryWorkflow = ({ onS // No toast - the listener for this action does that after the workflow is loaded onSuccess && onSuccess(); } catch { - if (!toast.isActive(`auth-error-toast-${workflowsApi.endpoints.getWorkflow.name}`)) { - toaster({ - title: t('toast.problemRetrievingWorkflow'), - status: 'error', - }); - } + toast({ + id: `AUTH_ERROR_TOAST_${workflowsApi.endpoints.getWorkflow.name}`, + title: t('toast.problemRetrievingWorkflow'), + status: 'error', + }); onError && onError(); } }, - [_getAndLoadWorkflow, dispatch, onSuccess, toaster, t, onError, toast] + [_getAndLoadWorkflow, dispatch, onSuccess, t, onError, toast] ); return { getAndLoadWorkflow, getAndLoadWorkflowResult }; diff --git a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useLoadWorkflowFromFile.tsx b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useLoadWorkflowFromFile.tsx index 7a39d4ecd0..94a1ef5c51 100644 
--- a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useLoadWorkflowFromFile.tsx +++ b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useLoadWorkflowFromFile.tsx @@ -1,8 +1,7 @@ import { useLogger } from 'app/logging/useLogger'; import { useAppDispatch } from 'app/store/storeHooks'; import { workflowLoadRequested } from 'features/nodes/store/actions'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { workflowLoadedFromFile } from 'features/workflowLibrary/store/actions'; import type { RefObject } from 'react'; import { useCallback } from 'react'; @@ -35,14 +34,11 @@ export const useLoadWorkflowFromFile: UseLoadWorkflowFromFile = ({ resetRef, onS } catch (e) { // There was a problem reading the file logger.error(t('nodes.unableToLoadWorkflow')); - dispatch( - addToast( - makeToast({ - title: t('nodes.unableToLoadWorkflow'), - status: 'error', - }) - ) - ); + toast({ + id: 'UNABLE_TO_LOAD_WORKFLOW', + title: t('nodes.unableToLoadWorkflow'), + status: 'error', + }); reader.abort(); } }; diff --git a/invokeai/frontend/web/src/services/api/authToastMiddleware.ts b/invokeai/frontend/web/src/services/api/authToastMiddleware.ts index 002cc96174..3a906a613b 100644 --- a/invokeai/frontend/web/src/services/api/authToastMiddleware.ts +++ b/invokeai/frontend/web/src/services/api/authToastMiddleware.ts @@ -1,6 +1,6 @@ -import type { Middleware, MiddlewareAPI } from '@reduxjs/toolkit'; +import type { Middleware } from '@reduxjs/toolkit'; import { isRejectedWithValue } from '@reduxjs/toolkit'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { z } from 'zod'; @@ -22,7 +22,7 @@ const zRejectedForbiddenAction = z.object({ .optional(), }); -export const authToastMiddleware: Middleware = (api: MiddlewareAPI) => (next) => (action) => { +export const authToastMiddleware: Middleware = () => (next) => (action) => { if (isRejectedWithValue(action)) { try { const parsed = zRejectedForbiddenAction.parse(action); @@ -32,16 +32,13 @@ export const authToastMiddleware: Middleware = (api: MiddlewareAPI) => (next) => return; } - const { dispatch } = api; const customMessage = parsed.payload.data.detail !== 'Forbidden' ? 
parsed.payload.data.detail : undefined; - dispatch( - addToast({ - id: `auth-error-toast-${endpointName}`, - title: t('common.somethingWentWrong'), - status: 'error', - description: customMessage, - }) - ); + toast({ + id: `auth-error-toast-${endpointName}`, + title: t('toast.somethingWentWrong'), + status: 'error', + description: customMessage, + }); } catch (error) { // no-op } diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index 98c253b479..c9052a607d 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -4,7 +4,7 @@ import { getStore } from 'app/store/nanostores/store'; import type { JSONObject } from 'common/types'; import type { BoardId } from 'features/gallery/store/types'; import { ASSETS_CATEGORIES, IMAGE_CATEGORIES, IMAGE_LIMIT } from 'features/gallery/store/types'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { keyBy } from 'lodash-es'; import type { components, paths } from 'services/api/schema'; @@ -206,13 +206,12 @@ export const imagesApi = api.injectEndpoints({ const { data } = await queryFulfilled; if (data.deleted_images.length < imageDTOs.length) { - dispatch( - addToast({ - title: t('gallery.problemDeletingImages'), - description: t('gallery.problemDeletingImagesDesc'), - status: 'warning', - }) - ); + toast({ + id: 'problem-deleting-images', + title: t('gallery.problemDeletingImages'), + description: t('gallery.problemDeletingImagesDesc'), + status: 'warning', + }); } // convert to an object so we can access the successfully delete image DTOs by name @@ -571,11 +570,15 @@ export const imagesApi = api.injectEndpoints({ session_id?: string; board_id?: string; crop_visible?: boolean; + metadata?: JSONObject; } >({ - query: ({ file, image_category, is_intermediate, session_id, board_id, crop_visible }) => { + query: ({ file, image_category, is_intermediate, session_id, board_id, crop_visible, metadata }) => { const formData = new FormData(); formData.append('file', file); + if (metadata) { + formData.append('metadata', JSON.stringify(metadata)); + } return { url: buildImagesUrl('upload'), method: 'POST', diff --git a/invokeai/frontend/web/src/services/api/hooks/accessChecks.ts b/invokeai/frontend/web/src/services/api/hooks/accessChecks.ts new file mode 100644 index 0000000000..00e27d49c6 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/hooks/accessChecks.ts @@ -0,0 +1,55 @@ +import { getStore } from 'app/store/nanostores/store'; +import { boardsApi } from 'services/api/endpoints/boards'; +import { imagesApi } from 'services/api/endpoints/images'; +import { modelsApi } from 'services/api/endpoints/models'; + +/** + * Checks if the client has access to a model. + * @param key The model key. + * @returns A promise that resolves to true if the client has access, else false. + */ +export const checkModelAccess = async (key: string): Promise => { + const { dispatch } = getStore(); + try { + const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key)); + req.unsubscribe(); + const result = await req.unwrap(); + return Boolean(result); + } catch { + return false; + } +}; + +/** + * Checks if the client has access to an image. + * @param name The image name. + * @returns A promise that resolves to true if the client has access, else false. 
+ */ +export const checkImageAccess = async (name: string): Promise => { + const { dispatch } = getStore(); + try { + const req = dispatch(imagesApi.endpoints.getImageDTO.initiate(name)); + req.unsubscribe(); + const result = await req.unwrap(); + return Boolean(result); + } catch { + return false; + } +}; + +/** + * Checks if the client has access to a board. + * @param id The board id. + * @returns A promise that resolves to true if the client has access, else false. + */ +export const checkBoardAccess = async (id: string): Promise => { + const { dispatch } = getStore(); + try { + const req = dispatch(boardsApi.endpoints.listAllBoards.initiate()); + req.unsubscribe(); + const result = await req.unwrap(); + return result.some((b) => b.board_id === id); + } catch { + return false; + } +}; diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index c1f9486bc7..67b39237b1 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -581,7 +581,6 @@ export type components = { * type * @default add * @constant - * @enum {string} */ type: "add"; }; @@ -619,7 +618,6 @@ export type components = { * type * @default alpha_mask_to_tensor * @constant - * @enum {string} */ type: "alpha_mask_to_tensor"; }; @@ -745,7 +743,6 @@ export type components = { * Type * @default basemetadata * @constant - * @enum {string} */ type?: "basemetadata"; }; @@ -898,7 +895,6 @@ export type components = { * type * @default blank_image * @constant - * @enum {string} */ type: "blank_image"; }; @@ -938,7 +934,6 @@ export type components = { * type * @default lblend * @constant - * @enum {string} */ type: "lblend"; }; @@ -1175,6 +1170,8 @@ export type components = { * Format: binary */ file: Blob; + /** @description The metadata to associate with the image */ + metadata?: components["schemas"]["JsonValue"] | null; }; /** * Boolean Collection Primitive @@ -1208,7 +1205,6 @@ export type components = { * type * @default boolean_collection * @constant - * @enum {string} */ type: "boolean_collection"; }; @@ -1226,7 +1222,6 @@ export type components = { * type * @default boolean_collection_output * @constant - * @enum {string} */ type: "boolean_collection_output"; }; @@ -1262,7 +1257,6 @@ export type components = { * type * @default boolean * @constant - * @enum {string} */ type: "boolean"; }; @@ -1280,7 +1274,6 @@ export type components = { * type * @default boolean_output * @constant - * @enum {string} */ type: "boolean_output"; }; @@ -1315,7 +1308,6 @@ export type components = { * type * @default clip_output * @constant - * @enum {string} */ type: "clip_output"; }; @@ -1356,7 +1348,6 @@ export type components = { * type * @default clip_skip * @constant - * @enum {string} */ type: "clip_skip"; }; @@ -1375,7 +1366,6 @@ export type components = { * type * @default clip_skip_output * @constant - * @enum {string} */ type: "clip_skip_output"; }; @@ -1409,6 +1399,7 @@ export type components = { /** * Description * @description Model description + * @default null */ description?: string | null; /** @@ -1421,17 +1412,18 @@ export type components = { /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
+ * @default null */ source_api_response?: string | null; /** * Cover Image * @description Url for image to preview model + * @default null */ cover_image?: string | null; /** * Format * @constant - * @enum {string} */ format: "diffusers"; /** @default */ @@ -1440,7 +1432,6 @@ export type components = { * Type * @default clip_vision * @constant - * @enum {string} */ type: "clip_vision"; }; @@ -1476,7 +1467,6 @@ export type components = { * type * @default infill_cv2 * @constant - * @enum {string} */ type: "infill_cv2"; }; @@ -1536,7 +1526,6 @@ export type components = { * type * @default calculate_image_tiles_even_split * @constant - * @enum {string} */ type: "calculate_image_tiles_even_split"; }; @@ -1596,7 +1585,6 @@ export type components = { * type * @default calculate_image_tiles * @constant - * @enum {string} */ type: "calculate_image_tiles"; }; @@ -1656,7 +1644,6 @@ export type components = { * type * @default calculate_image_tiles_min_overlap * @constant - * @enum {string} */ type: "calculate_image_tiles_min_overlap"; }; @@ -1671,7 +1658,6 @@ export type components = { * type * @default calculate_image_tiles_output * @constant - * @enum {string} */ type: "calculate_image_tiles_output"; }; @@ -1742,7 +1728,6 @@ export type components = { * type * @default canny_image_processor * @constant - * @enum {string} */ type: "canny_image_processor"; }; @@ -1788,7 +1773,6 @@ export type components = { * type * @default canvas_paste_back * @constant - * @enum {string} */ type: "canvas_paste_back"; }; @@ -1844,7 +1828,6 @@ export type components = { * type * @default img_pad_crop * @constant - * @enum {string} */ type: "img_pad_crop"; }; @@ -1896,7 +1879,6 @@ export type components = { * type * @default collect * @constant - * @enum {string} */ type: "collect"; }; @@ -1911,7 +1893,6 @@ export type components = { * type * @default collect_output * @constant - * @enum {string} */ type: "collect_output"; }; @@ -1929,7 +1910,6 @@ export type components = { * type * @default color_collection_output * @constant - * @enum {string} */ type: "color_collection_output"; }; @@ -1976,7 +1956,6 @@ export type components = { * type * @default color_correct * @constant - * @enum {string} */ type: "color_correct"; }; @@ -2042,7 +2021,6 @@ export type components = { * type * @default color * @constant - * @enum {string} */ type: "color"; }; @@ -2084,7 +2062,6 @@ export type components = { * type * @default color_map_image_processor * @constant - * @enum {string} */ type: "color_map_image_processor"; }; @@ -2099,7 +2076,6 @@ export type components = { * type * @default color_output * @constant - * @enum {string} */ type: "color_output"; }; @@ -2142,7 +2118,6 @@ export type components = { * type * @default compel * @constant - * @enum {string} */ type: "compel"; }; @@ -2178,7 +2153,6 @@ export type components = { * type * @default conditioning_collection * @constant - * @enum {string} */ type: "conditioning_collection"; }; @@ -2196,7 +2170,6 @@ export type components = { * type * @default conditioning_collection_output * @constant - * @enum {string} */ type: "conditioning_collection_output"; }; @@ -2244,7 +2217,6 @@ export type components = { * type * @default conditioning * @constant - * @enum {string} */ type: "conditioning"; }; @@ -2259,7 +2231,6 @@ export type components = { * type * @default conditioning_output * @constant - * @enum {string} */ type: "conditioning_output"; }; @@ -2325,7 +2296,6 @@ export type components = { * type * @default content_shuffle_image_processor * @constant - * @enum {string} */ 
type: "content_shuffle_image_processor"; }; @@ -2378,7 +2348,10 @@ export type components = { * @description Model config for ControlNet models (diffusers version). */ ControlNetCheckpointConfig: { - /** @description Default settings for this model */ + /** + * @description Default settings for this model + * @default null + */ default_settings?: components["schemas"]["ControlAdapterDefaultSettings"] | null; /** * Key @@ -2405,6 +2378,7 @@ export type components = { /** * Description * @description Model description + * @default null */ description?: string | null; /** @@ -2417,18 +2391,19 @@ export type components = { /** * Source Api Response * @description The original API response from the source, as stringified JSON. + * @default null */ source_api_response?: string | null; /** * Cover Image * @description Url for image to preview model + * @default null */ cover_image?: string | null; /** * Format * @default checkpoint * @constant - * @enum {string} */ format: "checkpoint"; /** @@ -2445,7 +2420,6 @@ export type components = { * Type * @default controlnet * @constant - * @enum {string} */ type: "controlnet"; }; @@ -2454,7 +2428,10 @@ export type components = { * @description Model config for ControlNet models (diffusers version). */ ControlNetDiffusersConfig: { - /** @description Default settings for this model */ + /** + * @description Default settings for this model + * @default null + */ default_settings?: components["schemas"]["ControlAdapterDefaultSettings"] | null; /** * Key @@ -2481,6 +2458,7 @@ export type components = { /** * Description * @description Model description + * @default null */ description?: string | null; /** @@ -2493,18 +2471,19 @@ export type components = { /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
+ * @default null */ source_api_response?: string | null; /** * Cover Image * @description Url for image to preview model + * @default null */ cover_image?: string | null; /** * Format * @default diffusers * @constant - * @enum {string} */ format: "diffusers"; /** @default */ @@ -2513,7 +2492,6 @@ export type components = { * Type * @default controlnet * @constant - * @enum {string} */ type: "controlnet"; }; @@ -2542,7 +2520,7 @@ export type components = { /** @description The control image */ image?: components["schemas"]["ImageField"]; /** @description ControlNet model to load */ - control_model: components["schemas"]["ModelIdentifierField"]; + control_model?: components["schemas"]["ModelIdentifierField"]; /** * Control Weight * @description The weight given to the ControlNet @@ -2579,7 +2557,6 @@ export type components = { * type * @default controlnet * @constant - * @enum {string} */ type: "controlnet"; }; @@ -2635,7 +2612,6 @@ export type components = { * type * @default control_output * @constant - * @enum {string} */ type: "control_output"; }; @@ -2826,7 +2802,6 @@ export type components = { * type * @default core_metadata * @constant - * @enum {string} */ type: "core_metadata"; [key: string]: unknown; @@ -2875,7 +2850,6 @@ export type components = { * type * @default create_denoise_mask * @constant - * @enum {string} */ type: "create_denoise_mask"; }; @@ -2952,7 +2926,6 @@ export type components = { * type * @default create_gradient_mask * @constant - * @enum {string} */ type: "create_gradient_mask"; }; @@ -3005,7 +2978,6 @@ export type components = { * type * @default crop_latents * @constant - * @enum {string} */ type: "crop_latents"; }; @@ -3061,7 +3033,6 @@ export type components = { * type * @default cv_inpaint * @constant - * @enum {string} */ type: "cv_inpaint"; }; @@ -3118,7 +3089,6 @@ export type components = { * type * @default dw_openpose_image_processor * @constant - * @enum {string} */ type: "dw_openpose_image_processor"; }; @@ -3241,7 +3211,6 @@ export type components = { * type * @default denoise_latents * @constant - * @enum {string} */ type: "denoise_latents"; }; @@ -3279,7 +3248,6 @@ export type components = { * type * @default denoise_mask_output * @constant - * @enum {string} */ type: "denoise_mask_output"; }; @@ -3328,7 +3296,6 @@ export type components = { * type * @default depth_anything_image_processor * @constant - * @enum {string} */ type: "depth_anything_image_processor"; }; @@ -3370,7 +3337,6 @@ export type components = { * type * @default div * @constant - * @enum {string} */ type: "div"; }; @@ -3505,7 +3471,6 @@ export type components = { * type * @default dynamic_prompt * @constant - * @enum {string} */ type: "dynamic_prompt"; }; @@ -3561,7 +3526,6 @@ export type components = { * type * @default esrgan * @constant - * @enum {string} */ type: "esrgan"; }; @@ -3661,7 +3625,6 @@ export type components = { * type * @default face_identifier * @constant - * @enum {string} */ type: "face_identifier"; }; @@ -3731,7 +3694,6 @@ export type components = { * type * @default face_mask_detection * @constant - * @enum {string} */ type: "face_mask_detection"; }; @@ -3756,7 +3718,6 @@ export type components = { * type * @default face_mask_output * @constant - * @enum {string} */ type: "face_mask_output"; /** @description The output mask */ @@ -3828,7 +3789,6 @@ export type components = { * type * @default face_off * @constant - * @enum {string} */ type: "face_off"; }; @@ -3853,7 +3813,6 @@ export type components = { * type * @default face_off_output * @constant - * 
@enum {string} */ type: "face_off_output"; /** @description The output mask */ @@ -3901,7 +3860,6 @@ export type components = { * type * @default float_collection * @constant - * @enum {string} */ type: "float_collection"; }; @@ -3919,7 +3877,6 @@ export type components = { * type * @default float_collection_output * @constant - * @enum {string} */ type: "float_collection_output"; }; @@ -3955,7 +3912,6 @@ export type components = { * type * @default float * @constant - * @enum {string} */ type: "float"; }; @@ -4003,7 +3959,6 @@ export type components = { * type * @default float_range * @constant - * @enum {string} */ type: "float_range"; }; @@ -4052,7 +4007,6 @@ export type components = { * type * @default float_math * @constant - * @enum {string} */ type: "float_math"; }; @@ -4070,7 +4024,6 @@ export type components = { * type * @default float_output * @constant - * @enum {string} */ type: "float_output"; }; @@ -4119,7 +4072,6 @@ export type components = { * type * @default float_to_int * @constant - * @enum {string} */ type: "float_to_int"; }; @@ -4223,7 +4175,6 @@ export type components = { * type * @default freeu * @constant - * @enum {string} */ type: "freeu"; }; @@ -4240,7 +4191,6 @@ export type components = { * type * @default gradient_mask_output * @constant - * @enum {string} */ type: "gradient_mask_output"; }; @@ -4256,7 +4206,7 @@ export type components = { * @description The nodes in this graph */ nodes: { - [key: string]: components["schemas"]["IdealSizeInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] 
| components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | 
components["schemas"]["ControlNetInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["ImageScaleInvocation"]; + [key: string]: components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SeamlessModeInvocation"] | 
components["schemas"]["SaveImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["IdealSizeInvocation"] | 
components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["IterateInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["StringReplaceInvocation"]; }; /** * Edges @@ -4293,7 +4243,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["LoRALoaderOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["String2Output"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["ColorOutput"] | 
components["schemas"]["VAEOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["CLIPSkipInvocationOutput"]; + [key: string]: components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["String2Output"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["ConditioningCollectionOutput"]; }; /** * Errors @@ -4336,7 +4286,6 @@ export type components = { * Type * @default hf * @constant - * @enum {string} */ type?: "hf"; }; @@ -4395,7 +4344,6 @@ export type components = { * type * @default hed_image_processor * @constant - * @enum {string} */ type: "hed_image_processor"; }; @@ -4439,7 +4387,6 @@ export type components = { * type * @default heuristic_resize * @constant - * @enum {string} */ type: "heuristic_resize"; }; @@ -4462,7 +4409,6 @@ export type components = { * Type * @default huggingface * @constant - * @enum {string} */ type?: "huggingface"; /** @@ -4530,6 +4476,7 @@ export type components = { /** * Description * @description Model description + * @default null */ description?: string | null; /** @@ -4542,24 +4489,24 @@ export type components = { /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
+ * @default null */ source_api_response?: string | null; /** * Cover Image * @description Url for image to preview model + * @default null */ cover_image?: string | null; /** * Type * @default ip_adapter * @constant - * @enum {string} */ type: "ip_adapter"; /** * Format * @constant - * @enum {string} */ format: "checkpoint"; }; @@ -4635,7 +4582,7 @@ export type components = { * IP-Adapter Model * @description The IP-Adapter model. */ - ip_adapter_model: components["schemas"]["ModelIdentifierField"]; + ip_adapter_model?: components["schemas"]["ModelIdentifierField"]; /** * Clip Vision Model * @description CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models. @@ -4674,7 +4621,6 @@ export type components = { * type * @default ip_adapter * @constant - * @enum {string} */ type: "ip_adapter"; }; @@ -4708,6 +4654,7 @@ export type components = { /** * Description * @description Model description + * @default null */ description?: string | null; /** @@ -4720,18 +4667,19 @@ export type components = { /** * Source Api Response * @description The original API response from the source, as stringified JSON. + * @default null */ source_api_response?: string | null; /** * Cover Image * @description Url for image to preview model + * @default null */ cover_image?: string | null; /** * Type * @default ip_adapter * @constant - * @enum {string} */ type: "ip_adapter"; /** Image Encoder Model Id */ @@ -4739,7 +4687,6 @@ export type components = { /** * Format * @constant - * @enum {string} */ format: "invokeai"; }; @@ -4791,7 +4738,6 @@ export type components = { * type * @default ip_adapter_output * @constant - * @enum {string} */ type: "ip_adapter_output"; }; @@ -4841,7 +4787,6 @@ export type components = { * type * @default ideal_size * @constant - * @enum {string} */ type: "ideal_size"; }; @@ -4864,7 +4809,6 @@ export type components = { * type * @default ideal_size_output * @constant - * @enum {string} */ type: "ideal_size_output"; }; @@ -4913,7 +4857,6 @@ export type components = { * type * @default img_blur * @constant - * @enum {string} */ type: "img_blur"; }; @@ -4968,7 +4911,6 @@ export type components = { * type * @default img_chan * @constant - * @enum {string} */ type: "img_chan"; }; @@ -5022,7 +4964,6 @@ export type components = { * type * @default img_channel_multiply * @constant - * @enum {string} */ type: "img_channel_multiply"; }; @@ -5070,7 +5011,6 @@ export type components = { * type * @default img_channel_offset * @constant - * @enum {string} */ type: "img_channel_offset"; }; @@ -5105,7 +5045,6 @@ export type components = { * type * @default image_collection * @constant - * @enum {string} */ type: "image_collection"; }; @@ -5123,7 +5062,6 @@ export type components = { * type * @default image_collection_output * @constant - * @enum {string} */ type: "image_collection_output"; }; @@ -5166,7 +5104,6 @@ export type components = { * type * @default img_conv * @constant - * @enum {string} */ type: "img_conv"; }; @@ -5226,7 +5163,6 @@ export type components = { * type * @default img_crop * @constant - * @enum {string} */ type: "img_crop"; }; @@ -5359,7 +5295,6 @@ export type components = { * type * @default img_hue_adjust * @constant - * @enum {string} */ type: "img_hue_adjust"; }; @@ -5407,7 +5342,6 @@ export type components = { * type * @default img_ilerp * @constant - * @enum {string} */ type: "img_ilerp"; }; @@ -5439,7 +5373,6 @@ export type components = { * type * @default image * @constant - * @enum {string} */ type: "image"; }; @@ -5487,7 +5420,6 @@ export 
type components = { * type * @default img_lerp * @constant - * @enum {string} */ type: "img_lerp"; }; @@ -5533,7 +5465,6 @@ export type components = { * type * @default image_mask_to_tensor * @constant - * @enum {string} */ type: "image_mask_to_tensor"; }; @@ -5571,7 +5502,6 @@ export type components = { * type * @default img_mul * @constant - * @enum {string} */ type: "img_mul"; }; @@ -5607,7 +5537,6 @@ export type components = { * type * @default img_nsfw * @constant - * @enum {string} */ type: "img_nsfw"; }; @@ -5632,7 +5561,6 @@ export type components = { * type * @default image_output * @constant - * @enum {string} */ type: "image_output"; }; @@ -5690,7 +5618,6 @@ export type components = { * type * @default img_paste * @constant - * @enum {string} */ type: "img_paste"; }; @@ -5775,7 +5702,6 @@ export type components = { * type * @default img_resize * @constant - * @enum {string} */ type: "img_resize"; }; @@ -5824,7 +5750,6 @@ export type components = { * type * @default img_scale * @constant - * @enum {string} */ type: "img_scale"; }; @@ -5870,7 +5795,6 @@ export type components = { * type * @default i2l * @constant - * @enum {string} */ type: "i2l"; }; @@ -5933,7 +5857,6 @@ export type components = { * type * @default img_watermark * @constant - * @enum {string} */ type: "img_watermark"; }; @@ -6000,7 +5923,6 @@ export type components = { * type * @default infill_rgba * @constant - * @enum {string} */ type: "infill_rgba"; }; @@ -6049,7 +5971,6 @@ export type components = { * type * @default infill_patchmatch * @constant - * @enum {string} */ type: "infill_patchmatch"; }; @@ -6097,7 +6018,6 @@ export type components = { * type * @default infill_tile * @constant - * @enum {string} */ type: "infill_tile"; }; @@ -6139,7 +6059,6 @@ export type components = { * type * @default integer_collection * @constant - * @enum {string} */ type: "integer_collection"; }; @@ -6157,7 +6076,6 @@ export type components = { * type * @default integer_collection_output * @constant - * @enum {string} */ type: "integer_collection_output"; }; @@ -6193,7 +6111,6 @@ export type components = { * type * @default integer * @constant - * @enum {string} */ type: "integer"; }; @@ -6242,7 +6159,6 @@ export type components = { * type * @default integer_math * @constant - * @enum {string} */ type: "integer_math"; }; @@ -6260,7 +6176,6 @@ export type components = { * type * @default integer_output * @constant - * @enum {string} */ type: "integer_output"; }; @@ -6292,7 +6207,6 @@ export type components = { * type * @default invert_tensor_mask * @constant - * @enum {string} */ type: "invert_tensor_mask"; }; @@ -6362,7 +6276,6 @@ export type components = { * type * @default iterate * @constant - * @enum {string} */ type: "iterate"; }; @@ -6390,7 +6303,6 @@ export type components = { * type * @default iterate_output * @constant - * @enum {string} */ type: "iterate_output"; }; @@ -6427,7 +6339,6 @@ export type components = { * type * @default infill_lama * @constant - * @enum {string} */ type: "infill_lama"; }; @@ -6462,7 +6373,6 @@ export type components = { * type * @default latents_collection * @constant - * @enum {string} */ type: "latents_collection"; }; @@ -6480,7 +6390,6 @@ export type components = { * type * @default latents_collection_output * @constant - * @enum {string} */ type: "latents_collection_output"; }; @@ -6529,7 +6438,6 @@ export type components = { * type * @default latents * @constant - * @enum {string} */ type: "latents"; }; @@ -6554,7 +6462,6 @@ export type components = { * type * @default 
latents_output * @constant - * @enum {string} */ type: "latents_output"; }; @@ -6604,7 +6511,6 @@ export type components = { * type * @default l2i * @constant - * @enum {string} */ type: "l2i"; }; @@ -6670,7 +6576,6 @@ export type components = { * type * @default leres_image_processor * @constant - * @enum {string} */ type: "leres_image_processor"; }; @@ -6718,7 +6623,6 @@ export type components = { * type * @default lineart_anime_image_processor * @constant - * @enum {string} */ type: "lineart_anime_image_processor"; }; @@ -6772,7 +6676,6 @@ export type components = { * type * @default lineart_image_processor * @constant - * @enum {string} */ type: "lineart_image_processor"; }; @@ -6817,7 +6720,6 @@ export type components = { * type * @default lora_collection_loader * @constant - * @enum {string} */ type: "lora_collection_loader"; }; @@ -6851,6 +6753,7 @@ export type components = { /** * Description * @description Model description + * @default null */ description?: string | null; /** @@ -6863,30 +6766,31 @@ export type components = { /** * Source Api Response * @description The original API response from the source, as stringified JSON. + * @default null */ source_api_response?: string | null; /** * Cover Image * @description Url for image to preview model + * @default null */ cover_image?: string | null; /** * Type * @default lora * @constant - * @enum {string} */ type: "lora"; /** * Trigger Phrases * @description Set of trigger phrases for this model + * @default null */ trigger_phrases?: string[] | null; /** * Format * @default diffusers * @constant - * @enum {string} */ format: "diffusers"; }; @@ -6926,7 +6830,7 @@ export type components = { * LoRA * @description LoRA model to load */ - lora: components["schemas"]["ModelIdentifierField"]; + lora?: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description The weight at which the LoRA is applied to each model @@ -6947,7 +6851,6 @@ export type components = { * type * @default lora_loader * @constant - * @enum {string} */ type: "lora_loader"; }; @@ -6972,7 +6875,6 @@ export type components = { * type * @default lora_loader_output * @constant - * @enum {string} */ type: "lora_loader_output"; }; @@ -7006,6 +6908,7 @@ export type components = { /** * Description * @description Model description + * @default null */ description?: string | null; /** @@ -7018,30 +6921,31 @@ export type components = { /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
+ * @default null */ source_api_response?: string | null; /** * Cover Image * @description Url for image to preview model + * @default null */ cover_image?: string | null; /** * Type * @default lora * @constant - * @enum {string} */ type: "lora"; /** * Trigger Phrases * @description Set of trigger phrases for this model + * @default null */ trigger_phrases?: string[] | null; /** * Format * @default lycoris * @constant - * @enum {string} */ format: "lycoris"; }; @@ -7084,7 +6988,7 @@ export type components = { * LoRA * @description LoRA model to load */ - lora: components["schemas"]["ModelIdentifierField"]; + lora?: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description The weight at which the LoRA is applied to each model @@ -7095,7 +6999,6 @@ export type components = { * type * @default lora_selector * @constant - * @enum {string} */ type: "lora_selector"; }; @@ -7113,7 +7016,6 @@ export type components = { * type * @default lora_selector_output * @constant - * @enum {string} */ type: "lora_selector_output"; }; @@ -7133,7 +7035,6 @@ export type components = { * Type * @default local * @constant - * @enum {string} */ type?: "local"; }; @@ -7172,6 +7073,7 @@ export type components = { /** * Description * @description Model description + * @default null */ description?: string | null; /** @@ -7184,26 +7086,31 @@ export type components = { /** * Source Api Response * @description The original API response from the source, as stringified JSON. + * @default null */ source_api_response?: string | null; /** * Cover Image * @description Url for image to preview model + * @default null */ cover_image?: string | null; /** * Type * @default main * @constant - * @enum {string} */ type: "main"; /** * Trigger Phrases * @description Set of trigger phrases for this model + * @default null */ trigger_phrases?: string[] | null; - /** @description Default settings for this model */ + /** + * @description Default settings for this model + * @default null + */ default_settings?: components["schemas"]["MainModelDefaultSettings"] | null; /** @default normal */ variant?: components["schemas"]["ModelVariantType"]; @@ -7211,7 +7118,6 @@ export type components = { * Format * @default checkpoint * @constant - * @enum {string} */ format: "checkpoint"; /** @@ -7262,6 +7168,7 @@ export type components = { /** * Description * @description Model description + * @default null */ description?: string | null; /** @@ -7274,26 +7181,31 @@ export type components = { /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
+ * @default null */ source_api_response?: string | null; /** * Cover Image * @description Url for image to preview model + * @default null */ cover_image?: string | null; /** * Type * @default main * @constant - * @enum {string} */ type: "main"; /** * Trigger Phrases * @description Set of trigger phrases for this model + * @default null */ trigger_phrases?: string[] | null; - /** @description Default settings for this model */ + /** + * @description Default settings for this model + * @default null + */ default_settings?: components["schemas"]["MainModelDefaultSettings"] | null; /** @default normal */ variant?: components["schemas"]["ModelVariantType"]; @@ -7301,7 +7213,6 @@ export type components = { * Format * @default diffusers * @constant - * @enum {string} */ format: "diffusers"; /** @default */ @@ -7312,41 +7223,49 @@ export type components = { /** * Vae * @description Default VAE for this model (model key) + * @default null */ vae?: string | null; /** * Vae Precision * @description Default VAE precision for this model + * @default null */ vae_precision?: ("fp16" | "fp32") | null; /** * Scheduler * @description Default scheduler for this model + * @default null */ scheduler?: ("ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd") | null; /** * Steps * @description Default number of steps for this model + * @default null */ steps?: number | null; /** * Cfg Scale * @description Default CFG Scale for this model + * @default null */ cfg_scale?: number | null; /** * Cfg Rescale Multiplier * @description Default CFG Rescale Multiplier for this model + * @default null */ cfg_rescale_multiplier?: number | null; /** * Width * @description Default width for this model + * @default null */ width?: number | null; /** * Height * @description Default height for this model + * @default null */ height?: number | null; }; @@ -7373,12 +7292,11 @@ export type components = { */ use_cache?: boolean; /** @description Main model (UNet, VAE, CLIP) to load */ - model: components["schemas"]["ModelIdentifierField"]; + model?: components["schemas"]["ModelIdentifierField"]; /** * type * @default main_model_loader * @constant - * @enum {string} */ type: "main_model_loader"; }; @@ -7416,7 +7334,6 @@ export type components = { * type * @default mask_combine * @constant - * @enum {string} */ type: "mask_combine"; }; @@ -7472,7 +7389,6 @@ export type components = { * type * @default mask_edge * @constant - * @enum {string} */ type: "mask_edge"; }; @@ -7514,7 +7430,6 @@ export type components = { * type * @default tomask * @constant - * @enum {string} */ type: "tomask"; }; @@ -7564,7 +7479,6 @@ export type components = { * type * @default mask_from_id * @constant - * @enum {string} */ type: "mask_from_id"; }; @@ -7589,7 +7503,6 @@ export type components = { * type * @default mask_output * @constant - * @enum {string} */ type: "mask_output"; }; @@ -7649,7 +7562,6 @@ export type components = { * type * @default mediapipe_face_processor * @constant - * @enum {string} */ type: "mediapipe_face_processor"; }; @@ -7684,7 +7596,6 @@ export type components = { * type * @default merge_metadata * @constant - * @enum {string} */ type: "merge_metadata"; }; @@ -7736,7 +7647,6 @@ export type components = { * type * @default merge_tiles_to_image * @constant - * @enum {string} */ type: "merge_tiles_to_image"; }; 
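A recurring change in the hunks above and below is that model-selection inputs in the generated types (`model`, `control_model`, `ip_adapter_model`, `lora`, `vae_model`, `t2i_adapter_model`) move from required to optional. As an illustrative aside rather than part of the diff, the sketch below shows what that means for client code; it assumes the generated `components` types are imported via the frontend's `services/api/schema` path alias, and the node id is hypothetical.

```ts
// Sketch only -- not part of this PR. Assumes the generated types from the
// schema.ts file shown in this diff, imported via an assumed path alias.
import type { components } from 'services/api/schema';

type MainModelLoaderInvocation = components['schemas']['MainModelLoaderInvocation'];

// Previously `model` was required here; with this change it is optional, so a
// placeholder loader node with only `id` and `type` type-checks and the model
// can be filled in once the user selects one.
const draftLoader: MainModelLoaderInvocation = {
  id: 'main_model_loader_1',
  type: 'main_model_loader',
};
```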
@@ -7777,7 +7687,6 @@ export type components = { * type * @default metadata * @constant - * @enum {string} */ type: "metadata"; }; @@ -7830,7 +7739,6 @@ export type components = { * type * @default metadata_item * @constant - * @enum {string} */ type: "metadata_item"; }; @@ -7845,7 +7753,6 @@ export type components = { * type * @default metadata_item_output * @constant - * @enum {string} */ type: "metadata_item_output"; }; @@ -7857,7 +7764,6 @@ export type components = { * type * @default metadata_output * @constant - * @enum {string} */ type: "metadata_output"; }; @@ -7917,7 +7823,6 @@ export type components = { * type * @default midas_depth_image_processor * @constant - * @enum {string} */ type: "midas_depth_image_processor"; }; @@ -7977,7 +7882,6 @@ export type components = { * type * @default mlsd_image_processor * @constant - * @enum {string} */ type: "mlsd_image_processor"; }; @@ -8014,6 +7918,59 @@ export type components = { */ submodel_type?: components["schemas"]["SubModelType"] | null; }; + /** + * Model identifier + * @description Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as + * input for any model, even if the model types don't match. If you connect this to a mismatched input, you'll get an + * error. + */ + ModelIdentifierInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. + * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** + * Model + * @description The model to select + */ + model?: components["schemas"]["ModelIdentifierField"]; + /** + * type + * @default model_identifier + * @constant + */ + type: "model_identifier"; + }; + /** + * ModelIdentifierOutput + * @description Model identifier output + */ + ModelIdentifierOutput: { + /** + * Model + * @description Model identifier + */ + model: components["schemas"]["ModelIdentifierField"]; + /** + * type + * @default model_identifier_output + * @constant + */ + type: "model_identifier_output"; + }; /** * ModelInstallJob * @description Object that tracks the current status of an install request. 
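The hunk above introduces two new schemas, `ModelIdentifierInvocation` and `ModelIdentifierOutput`. As an illustrative aside rather than part of the diff, the sketch below shows how the new node type might be constructed from the generated types; the import path alias and the node id are assumptions.

```ts
// Sketch only -- not part of this PR. Assumes the generated `components` types
// from the schema.ts file shown in this diff.
import type { components } from 'services/api/schema';

type ModelIdentifierInvocation = components['schemas']['ModelIdentifierInvocation'];

// A minimal "model_identifier" node: per the new schema only `id` and `type`
// are required, and `model` (a ModelIdentifierField) may be attached later.
const node: ModelIdentifierInvocation = {
  id: 'model_identifier_1',
  type: 'model_identifier',
};

// Per the schema description, the resulting identifier is accepted by any
// model input, so a mismatched connection only fails at invocation time.
```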
@@ -8108,7 +8065,6 @@ export type components = { * type * @default model_loader_output * @constant - * @enum {string} */ type: "model_loader_output"; /** @@ -8239,7 +8195,6 @@ export type components = { * type * @default mul * @constant - * @enum {string} */ type: "mul"; }; @@ -8311,7 +8266,6 @@ export type components = { * type * @default noise * @constant - * @enum {string} */ type: "noise"; }; @@ -8336,7 +8290,6 @@ export type components = { * type * @default noise_output * @constant - * @enum {string} */ type: "noise_output"; }; @@ -8384,7 +8337,6 @@ export type components = { * type * @default normalbae_image_processor * @constant - * @enum {string} */ type: "normalbae_image_processor"; }; @@ -8492,7 +8444,6 @@ export type components = { * type * @default pair_tile_image * @constant - * @enum {string} */ type: "pair_tile_image"; }; @@ -8504,7 +8455,6 @@ export type components = { * type * @default pair_tile_image_output * @constant - * @enum {string} */ type: "pair_tile_image_output"; }; @@ -8564,7 +8514,6 @@ export type components = { * type * @default pidi_image_processor * @constant - * @enum {string} */ type: "pidi_image_processor"; }; @@ -8621,7 +8570,6 @@ export type components = { * type * @default prompt_from_file * @constant - * @enum {string} */ type: "prompt_from_file"; }; @@ -8680,7 +8628,6 @@ export type components = { * type * @default rand_float * @constant - * @enum {string} */ type: "rand_float"; }; @@ -8722,7 +8669,6 @@ export type components = { * type * @default rand_int * @constant - * @enum {string} */ type: "rand_int"; }; @@ -8776,7 +8722,6 @@ export type components = { * type * @default random_range * @constant - * @enum {string} */ type: "random_range"; }; @@ -8824,7 +8769,6 @@ export type components = { * type * @default range * @constant - * @enum {string} */ type: "range"; }; @@ -8872,7 +8816,6 @@ export type components = { * type * @default range_of_size * @constant - * @enum {string} */ type: "range_of_size"; }; @@ -8934,7 +8877,6 @@ export type components = { * type * @default rectangle_mask * @constant - * @enum {string} */ type: "rectangle_mask"; }; @@ -9025,7 +8967,6 @@ export type components = { * type * @default lresize * @constant - * @enum {string} */ type: "lresize"; }; @@ -9077,7 +9018,6 @@ export type components = { * type * @default round_float * @constant - * @enum {string} */ type: "round_float"; }; @@ -9161,7 +9101,6 @@ export type components = { * type * @default sdxl_compel_prompt * @constant - * @enum {string} */ type: "sdxl_compel_prompt"; }; @@ -9211,7 +9150,6 @@ export type components = { * type * @default sdxl_lora_collection_loader * @constant - * @enum {string} */ type: "sdxl_lora_collection_loader"; }; @@ -9241,7 +9179,7 @@ export type components = { * LoRA * @description LoRA model to load */ - lora: components["schemas"]["ModelIdentifierField"]; + lora?: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description The weight at which the LoRA is applied to each model @@ -9267,7 +9205,6 @@ export type components = { * type * @default sdxl_lora_loader * @constant - * @enum {string} */ type: "sdxl_lora_loader"; }; @@ -9298,7 +9235,6 @@ export type components = { * type * @default sdxl_lora_loader_output * @constant - * @enum {string} */ type: "sdxl_lora_loader_output"; }; @@ -9325,12 +9261,11 @@ export type components = { */ use_cache?: boolean; /** @description SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load */ - model: components["schemas"]["ModelIdentifierField"]; + model?: 
components["schemas"]["ModelIdentifierField"]; /** * type * @default sdxl_model_loader * @constant - * @enum {string} */ type: "sdxl_model_loader"; }; @@ -9363,7 +9298,6 @@ export type components = { * type * @default sdxl_model_loader_output * @constant - * @enum {string} */ type: "sdxl_model_loader_output"; }; @@ -9427,7 +9361,6 @@ export type components = { * type * @default sdxl_refiner_compel_prompt * @constant - * @enum {string} */ type: "sdxl_refiner_compel_prompt"; }; @@ -9454,12 +9387,11 @@ export type components = { */ use_cache?: boolean; /** @description SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load */ - model: components["schemas"]["ModelIdentifierField"]; + model?: components["schemas"]["ModelIdentifierField"]; /** * type * @default sdxl_refiner_model_loader * @constant - * @enum {string} */ type: "sdxl_refiner_model_loader"; }; @@ -9487,7 +9419,6 @@ export type components = { * type * @default sdxl_refiner_model_loader_output * @constant - * @enum {string} */ type: "sdxl_refiner_model_loader_output"; }; @@ -9528,7 +9459,6 @@ export type components = { * type * @default save_image * @constant - * @enum {string} */ type: "save_image"; }; @@ -9578,7 +9508,6 @@ export type components = { * type * @default lscale * @constant - * @enum {string} */ type: "lscale"; }; @@ -9615,7 +9544,6 @@ export type components = { * type * @default scheduler * @constant - * @enum {string} */ type: "scheduler"; }; @@ -9631,7 +9559,6 @@ export type components = { * type * @default scheduler_output * @constant - * @enum {string} */ type: "scheduler_output"; }; @@ -9689,7 +9616,6 @@ export type components = { * type * @default seamless * @constant - * @enum {string} */ type: "seamless"; }; @@ -9714,7 +9640,6 @@ export type components = { * type * @default seamless_output * @constant - * @enum {string} */ type: "seamless_output"; }; @@ -9762,7 +9687,6 @@ export type components = { * type * @default segment_anything_processor * @constant - * @enum {string} */ type: "segment_anything_processor"; }; @@ -9818,10 +9742,20 @@ export type components = { */ session_id: string; /** - * Error + * Error Type + * @description The error type if this queue item errored + */ + error_type?: string | null; + /** + * Error Message * @description The error message if this queue item errored */ - error?: string | null; + error_message?: string | null; + /** + * Error Traceback + * @description The error traceback if this queue item errored + */ + error_traceback?: string | null; /** * Created At * @description When this queue item was created @@ -9888,10 +9822,20 @@ export type components = { */ session_id: string; /** - * Error + * Error Type + * @description The error type if this queue item errored + */ + error_type?: string | null; + /** + * Error Message * @description The error message if this queue item errored */ - error?: string | null; + error_message?: string | null; + /** + * Error Traceback + * @description The error traceback if this queue item errored + */ + error_traceback?: string | null; /** * Created At * @description When this queue item was created @@ -10004,7 +9948,6 @@ export type components = { * type * @default show_image * @constant - * @enum {string} */ type: "show_image"; }; @@ -10127,7 +10070,6 @@ export type components = { * type * @default step_param_easing * @constant - * @enum {string} */ type: "step_param_easing"; }; @@ -10150,7 +10092,6 @@ export type components = { * type * @default string_2_output * @constant - * @enum {string} */ type: "string_2_output"; }; @@ -10186,7 +10127,6 
@@ export type components = { * type * @default string_collection * @constant - * @enum {string} */ type: "string_collection"; }; @@ -10204,7 +10144,6 @@ export type components = { * type * @default string_collection_output * @constant - * @enum {string} */ type: "string_collection_output"; }; @@ -10240,7 +10179,6 @@ export type components = { * type * @default string * @constant - * @enum {string} */ type: "string"; }; @@ -10282,7 +10220,6 @@ export type components = { * type * @default string_join * @constant - * @enum {string} */ type: "string_join"; }; @@ -10330,7 +10267,6 @@ export type components = { * type * @default string_join_three * @constant - * @enum {string} */ type: "string_join_three"; }; @@ -10348,7 +10284,6 @@ export type components = { * type * @default string_output * @constant - * @enum {string} */ type: "string_output"; }; @@ -10371,7 +10306,6 @@ export type components = { * type * @default string_pos_neg_output * @constant - * @enum {string} */ type: "string_pos_neg_output"; }; @@ -10425,7 +10359,6 @@ export type components = { * type * @default string_replace * @constant - * @enum {string} */ type: "string_replace"; }; @@ -10467,7 +10400,6 @@ export type components = { * type * @default string_split * @constant - * @enum {string} */ type: "string_split"; }; @@ -10503,7 +10435,6 @@ export type components = { * type * @default string_split_neg * @constant - * @enum {string} */ type: "string_split_neg"; }; @@ -10551,7 +10482,6 @@ export type components = { * type * @default sub * @constant - * @enum {string} */ type: "sub"; }; @@ -10560,7 +10490,10 @@ export type components = { * @description Model config for T2I. */ T2IAdapterConfig: { - /** @description Default settings for this model */ + /** + * @description Default settings for this model + * @default null + */ default_settings?: components["schemas"]["ControlAdapterDefaultSettings"] | null; /** * Key @@ -10587,6 +10520,7 @@ export type components = { /** * Description * @description Model description + * @default null */ description?: string | null; /** @@ -10599,17 +10533,18 @@ export type components = { /** * Source Api Response * @description The original API response from the source, as stringified JSON. + * @default null */ source_api_response?: string | null; /** * Cover Image * @description Url for image to preview model + * @default null */ cover_image?: string | null; /** * Format * @constant - * @enum {string} */ format: "diffusers"; /** @default */ @@ -10618,7 +10553,6 @@ export type components = { * Type * @default t2i_adapter * @constant - * @enum {string} */ type: "t2i_adapter"; }; @@ -10682,7 +10616,7 @@ export type components = { * T2I-Adapter Model * @description The T2I-Adapter model. */ - t2i_adapter_model: components["schemas"]["ModelIdentifierField"]; + t2i_adapter_model?: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description The weight given to the T2I-Adapter @@ -10712,7 +10646,6 @@ export type components = { * type * @default t2i_adapter * @constant - * @enum {string} */ type: "t2i_adapter"; }; @@ -10761,7 +10694,6 @@ export type components = { * type * @default t2i_adapter_output * @constant - * @enum {string} */ type: "t2i_adapter_output"; }; @@ -10817,6 +10749,7 @@ export type components = { /** * Description * @description Model description + * @default null */ description?: string | null; /** @@ -10829,25 +10762,25 @@ export type components = { /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
+ * @default null */ source_api_response?: string | null; /** * Cover Image * @description Url for image to preview model + * @default null */ cover_image?: string | null; /** * Type * @default embedding * @constant - * @enum {string} */ type: "embedding"; /** * Format * @default embedding_file * @constant - * @enum {string} */ format: "embedding_file"; }; @@ -10881,6 +10814,7 @@ export type components = { /** * Description * @description Model description + * @default null */ description?: string | null; /** @@ -10893,25 +10827,25 @@ export type components = { /** * Source Api Response * @description The original API response from the source, as stringified JSON. + * @default null */ source_api_response?: string | null; /** * Cover Image * @description Url for image to preview model + * @default null */ cover_image?: string | null; /** * Type * @default embedding * @constant - * @enum {string} */ type: "embedding"; /** * Format * @default embedding_folder * @constant - * @enum {string} */ format: "embedding_folder"; }; @@ -10960,7 +10894,6 @@ export type components = { * type * @default tile_image_processor * @constant - * @enum {string} */ type: "tile_image_processor"; }; @@ -10992,7 +10925,6 @@ export type components = { * type * @default tile_to_properties * @constant - * @enum {string} */ type: "tile_to_properties"; }; @@ -11052,7 +10984,6 @@ export type components = { * type * @default tile_to_properties_output * @constant - * @enum {string} */ type: "tile_to_properties_output"; }; @@ -11097,7 +11028,6 @@ export type components = { * type * @default unet_output * @constant - * @enum {string} */ type: "unet_output"; }; @@ -11117,7 +11047,6 @@ export type components = { * Type * @default url * @constant - * @enum {string} */ type?: "url"; }; @@ -11165,7 +11094,6 @@ export type components = { * type * @default unsharp_mask * @constant - * @enum {string} */ type: "unsharp_mask"; }; @@ -11212,6 +11140,7 @@ export type components = { /** * Description * @description Model description + * @default null */ description?: string | null; /** @@ -11224,18 +11153,19 @@ export type components = { /** * Source Api Response * @description The original API response from the source, as stringified JSON. + * @default null */ source_api_response?: string | null; /** * Cover Image * @description Url for image to preview model + * @default null */ cover_image?: string | null; /** * Format * @default checkpoint * @constant - * @enum {string} */ format: "checkpoint"; /** @@ -11252,7 +11182,6 @@ export type components = { * Type * @default vae * @constant - * @enum {string} */ type: "vae"; }; @@ -11286,6 +11215,7 @@ export type components = { /** * Description * @description Model description + * @default null */ description?: string | null; /** @@ -11298,25 +11228,25 @@ export type components = { /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
+ * @default null */ source_api_response?: string | null; /** * Cover Image * @description Url for image to preview model + * @default null */ cover_image?: string | null; /** * Type * @default vae * @constant - * @enum {string} */ type: "vae"; /** * Format * @default diffusers * @constant - * @enum {string} */ format: "diffusers"; }; @@ -11356,12 +11286,11 @@ export type components = { * VAE * @description VAE model to load */ - vae_model: components["schemas"]["ModelIdentifierField"]; + vae_model?: components["schemas"]["ModelIdentifierField"]; /** * type * @default vae_loader * @constant - * @enum {string} */ type: "vae_loader"; }; @@ -11379,7 +11308,6 @@ export type components = { * type * @default vae_output * @constant - * @enum {string} */ type: "vae_output"; }; @@ -11649,7 +11577,6 @@ export type components = { * type * @default zoe_depth_image_processor * @constant - * @enum {string} */ type: "zoe_depth_image_processor"; }; @@ -11841,143 +11768,925 @@ export type components = { */ UIType: "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; InvocationOutputMap: { - ideal_size: components["schemas"]["IdealSizeOutput"]; - lora_collection_loader: components["schemas"]["LoRALoaderOutput"]; - color_map_image_processor: components["schemas"]["ImageOutput"]; - img_resize: components["schemas"]["ImageOutput"]; - calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"]; - lineart_image_processor: components["schemas"]["ImageOutput"]; - boolean_collection: components["schemas"]["BooleanCollectionOutput"]; - ip_adapter: components["schemas"]["IPAdapterOutput"]; - face_mask_detection: components["schemas"]["FaceMaskOutput"]; - string_replace: components["schemas"]["StringOutput"]; - infill_lama: components["schemas"]["ImageOutput"]; - calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; - tile_image_processor: components["schemas"]["ImageOutput"]; - calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"]; - img_blur: components["schemas"]["ImageOutput"]; - scheduler: components["schemas"]["SchedulerOutput"]; - range: 
components["schemas"]["IntegerCollectionOutput"]; + float_collection: components["schemas"]["FloatCollectionOutput"]; + infill_patchmatch: components["schemas"]["ImageOutput"]; lora_selector: components["schemas"]["LoRASelectorOutput"]; + img_conv: components["schemas"]["ImageOutput"]; + midas_depth_image_processor: components["schemas"]["ImageOutput"]; + invert_tensor_mask: components["schemas"]["MaskOutput"]; + integer: components["schemas"]["IntegerOutput"]; + color_map_image_processor: components["schemas"]["ImageOutput"]; + color_correct: components["schemas"]["ImageOutput"]; + string_collection: components["schemas"]["StringCollectionOutput"]; + merge_metadata: components["schemas"]["MetadataOutput"]; + img_hue_adjust: components["schemas"]["ImageOutput"]; + string_split_neg: components["schemas"]["StringPosNegOutput"]; + face_identifier: components["schemas"]["ImageOutput"]; + controlnet: components["schemas"]["ControlOutput"]; + float_to_int: components["schemas"]["IntegerOutput"]; + lora_collection_loader: components["schemas"]["LoRALoaderOutput"]; + freeu: components["schemas"]["UNetOutput"]; + img_scale: components["schemas"]["ImageOutput"]; + calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; + sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"]; + unsharp_mask: components["schemas"]["ImageOutput"]; + dw_openpose_image_processor: components["schemas"]["ImageOutput"]; + img_blur: components["schemas"]["ImageOutput"]; + infill_cv2: components["schemas"]["ImageOutput"]; + face_mask_detection: components["schemas"]["FaceMaskOutput"]; + t2i_adapter: components["schemas"]["T2IAdapterOutput"]; + core_metadata: components["schemas"]["MetadataOutput"]; + rand_float: components["schemas"]["FloatOutput"]; + mediapipe_face_processor: components["schemas"]["ImageOutput"]; + img_resize: components["schemas"]["ImageOutput"]; + latents_collection: components["schemas"]["LatentsCollectionOutput"]; + float_math: components["schemas"]["FloatOutput"]; + range: components["schemas"]["IntegerCollectionOutput"]; + zoe_depth_image_processor: components["schemas"]["ImageOutput"]; + image_mask_to_tensor: components["schemas"]["MaskOutput"]; + sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"]; + i2l: components["schemas"]["LatentsOutput"]; + integer_math: components["schemas"]["IntegerOutput"]; + sdxl_compel_prompt: components["schemas"]["ConditioningOutput"]; + seamless: components["schemas"]["SeamlessModeOutput"]; + save_image: components["schemas"]["ImageOutput"]; + lresize: components["schemas"]["LatentsOutput"]; + color: components["schemas"]["ColorOutput"]; + img_chan: components["schemas"]["ImageOutput"]; + l2i: components["schemas"]["ImageOutput"]; + lblend: components["schemas"]["LatentsOutput"]; + img_watermark: components["schemas"]["ImageOutput"]; + image: components["schemas"]["ImageOutput"]; + lineart_anime_image_processor: components["schemas"]["ImageOutput"]; + sub: components["schemas"]["IntegerOutput"]; + rand_int: components["schemas"]["IntegerOutput"]; + main_model_loader: components["schemas"]["ModelLoaderOutput"]; + calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"]; + face_off: components["schemas"]["FaceOffOutput"]; + image_collection: components["schemas"]["ImageCollectionOutput"]; + mlsd_image_processor: components["schemas"]["ImageOutput"]; + boolean_collection: components["schemas"]["BooleanCollectionOutput"]; + string: components["schemas"]["StringOutput"]; + mask_from_id: 
components["schemas"]["ImageOutput"]; + noise: components["schemas"]["NoiseOutput"]; + img_mul: components["schemas"]["ImageOutput"]; + pair_tile_image: components["schemas"]["PairTileImageOutput"]; + content_shuffle_image_processor: components["schemas"]["ImageOutput"]; + range_of_size: components["schemas"]["IntegerCollectionOutput"]; + latents: components["schemas"]["LatentsOutput"]; + add: components["schemas"]["IntegerOutput"]; + div: components["schemas"]["IntegerOutput"]; + blank_image: components["schemas"]["ImageOutput"]; + dynamic_prompt: components["schemas"]["StringCollectionOutput"]; + mask_combine: components["schemas"]["ImageOutput"]; + img_nsfw: components["schemas"]["ImageOutput"]; + sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"]; + step_param_easing: components["schemas"]["FloatCollectionOutput"]; + img_channel_offset: components["schemas"]["ImageOutput"]; + create_denoise_mask: components["schemas"]["DenoiseMaskOutput"]; + rectangle_mask: components["schemas"]["MaskOutput"]; + vae_loader: components["schemas"]["VAEOutput"]; + integer_collection: components["schemas"]["IntegerCollectionOutput"]; + leres_image_processor: components["schemas"]["ImageOutput"]; + lineart_image_processor: components["schemas"]["ImageOutput"]; + round_float: components["schemas"]["FloatOutput"]; + infill_lama: components["schemas"]["ImageOutput"]; + string_join_three: components["schemas"]["StringOutput"]; + collect: components["schemas"]["CollectInvocationOutput"]; + boolean: components["schemas"]["BooleanOutput"]; + create_gradient_mask: components["schemas"]["GradientMaskOutput"]; + string_split: components["schemas"]["String2Output"]; + show_image: components["schemas"]["ImageOutput"]; + mask_edge: components["schemas"]["ImageOutput"]; + random_range: components["schemas"]["IntegerCollectionOutput"]; + float_range: components["schemas"]["FloatCollectionOutput"]; + conditioning: components["schemas"]["ConditioningOutput"]; + cv_inpaint: components["schemas"]["ImageOutput"]; + string_join: components["schemas"]["StringOutput"]; + sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"]; + lora_loader: components["schemas"]["LoRALoaderOutput"]; + compel: components["schemas"]["ConditioningOutput"]; + tomask: components["schemas"]["ImageOutput"]; + esrgan: components["schemas"]["ImageOutput"]; + denoise_latents: components["schemas"]["LatentsOutput"]; + img_ilerp: components["schemas"]["ImageOutput"]; + crop_latents: components["schemas"]["LatentsOutput"]; + prompt_from_file: components["schemas"]["StringCollectionOutput"]; + merge_tiles_to_image: components["schemas"]["ImageOutput"]; + tile_to_properties: components["schemas"]["TileToPropertiesOutput"]; + infill_tile: components["schemas"]["ImageOutput"]; + segment_anything_processor: components["schemas"]["ImageOutput"]; + lscale: components["schemas"]["LatentsOutput"]; + scheduler: components["schemas"]["SchedulerOutput"]; + ideal_size: components["schemas"]["IdealSizeOutput"]; + img_paste: components["schemas"]["ImageOutput"]; + img_channel_multiply: components["schemas"]["ImageOutput"]; + mul: components["schemas"]["IntegerOutput"]; + model_identifier: components["schemas"]["ModelIdentifierOutput"]; + depth_anything_image_processor: components["schemas"]["ImageOutput"]; + metadata_item: components["schemas"]["MetadataItemOutput"]; + alpha_mask_to_tensor: components["schemas"]["MaskOutput"]; + tile_image_processor: components["schemas"]["ImageOutput"]; + pidi_image_processor: 
components["schemas"]["ImageOutput"]; + calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"]; + canny_image_processor: components["schemas"]["ImageOutput"]; + hed_image_processor: components["schemas"]["ImageOutput"]; metadata: components["schemas"]["MetadataOutput"]; clip_skip: components["schemas"]["CLIPSkipInvocationOutput"]; - rand_float: components["schemas"]["FloatOutput"]; - float_collection: components["schemas"]["FloatCollectionOutput"]; - zoe_depth_image_processor: components["schemas"]["ImageOutput"]; - create_gradient_mask: components["schemas"]["GradientMaskOutput"]; - i2l: components["schemas"]["LatentsOutput"]; - dynamic_prompt: components["schemas"]["StringCollectionOutput"]; - create_denoise_mask: components["schemas"]["DenoiseMaskOutput"]; - img_ilerp: components["schemas"]["ImageOutput"]; - tile_to_properties: components["schemas"]["TileToPropertiesOutput"]; - infill_cv2: components["schemas"]["ImageOutput"]; - string_join_three: components["schemas"]["StringOutput"]; - denoise_latents: components["schemas"]["LatentsOutput"]; - iterate: components["schemas"]["IterateInvocationOutput"]; - step_param_easing: components["schemas"]["FloatCollectionOutput"]; - img_nsfw: components["schemas"]["ImageOutput"]; - infill_patchmatch: components["schemas"]["ImageOutput"]; - pair_tile_image: components["schemas"]["PairTileImageOutput"]; - alpha_mask_to_tensor: components["schemas"]["MaskOutput"]; - lora_loader: components["schemas"]["LoRALoaderOutput"]; - normalbae_image_processor: components["schemas"]["ImageOutput"]; - img_hue_adjust: components["schemas"]["ImageOutput"]; - conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; - image_mask_to_tensor: components["schemas"]["MaskOutput"]; - t2i_adapter: components["schemas"]["T2IAdapterOutput"]; - infill_rgba: components["schemas"]["ImageOutput"]; - vae_loader: components["schemas"]["VAEOutput"]; - blank_image: components["schemas"]["ImageOutput"]; - latents: components["schemas"]["LatentsOutput"]; - sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"]; - boolean: components["schemas"]["BooleanOutput"]; - float_range: components["schemas"]["FloatCollectionOutput"]; - integer: components["schemas"]["IntegerOutput"]; - mul: components["schemas"]["IntegerOutput"]; - img_crop: components["schemas"]["ImageOutput"]; - face_identifier: components["schemas"]["ImageOutput"]; - main_model_loader: components["schemas"]["ModelLoaderOutput"]; - mlsd_image_processor: components["schemas"]["ImageOutput"]; - esrgan: components["schemas"]["ImageOutput"]; - integer_math: components["schemas"]["IntegerOutput"]; - sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"]; - img_chan: components["schemas"]["ImageOutput"]; - round_float: components["schemas"]["FloatOutput"]; - random_range: components["schemas"]["IntegerCollectionOutput"]; - image_collection: components["schemas"]["ImageCollectionOutput"]; - sub: components["schemas"]["IntegerOutput"]; - lblend: components["schemas"]["LatentsOutput"]; - sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"]; - cv_inpaint: components["schemas"]["ImageOutput"]; - sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"]; - invert_tensor_mask: components["schemas"]["MaskOutput"]; - image: components["schemas"]["ImageOutput"]; - img_mul: components["schemas"]["ImageOutput"]; - l2i: components["schemas"]["ImageOutput"]; - canny_image_processor: components["schemas"]["ImageOutput"]; - 
save_image: components["schemas"]["ImageOutput"]; - string_split: components["schemas"]["String2Output"]; - segment_anything_processor: components["schemas"]["ImageOutput"]; - heuristic_resize: components["schemas"]["ImageOutput"]; - face_off: components["schemas"]["FaceOffOutput"]; - img_channel_offset: components["schemas"]["ImageOutput"]; - img_conv: components["schemas"]["ImageOutput"]; - add: components["schemas"]["IntegerOutput"]; - infill_tile: components["schemas"]["ImageOutput"]; - color: components["schemas"]["ColorOutput"]; - mediapipe_face_processor: components["schemas"]["ImageOutput"]; - freeu: components["schemas"]["UNetOutput"]; - pidi_image_processor: components["schemas"]["ImageOutput"]; - depth_anything_image_processor: components["schemas"]["ImageOutput"]; - noise: components["schemas"]["NoiseOutput"]; - collect: components["schemas"]["CollectInvocationOutput"]; - content_shuffle_image_processor: components["schemas"]["ImageOutput"]; - string_split_neg: components["schemas"]["StringPosNegOutput"]; - img_lerp: components["schemas"]["ImageOutput"]; - leres_image_processor: components["schemas"]["ImageOutput"]; - div: components["schemas"]["IntegerOutput"]; - lscale: components["schemas"]["LatentsOutput"]; - metadata_item: components["schemas"]["MetadataItemOutput"]; - seamless: components["schemas"]["SeamlessModeOutput"]; - img_paste: components["schemas"]["ImageOutput"]; - string: components["schemas"]["StringOutput"]; - mask_combine: components["schemas"]["ImageOutput"]; - float_math: components["schemas"]["FloatOutput"]; - tomask: components["schemas"]["ImageOutput"]; - img_channel_multiply: components["schemas"]["ImageOutput"]; - sdxl_compel_prompt: components["schemas"]["ConditioningOutput"]; - mask_edge: components["schemas"]["ImageOutput"]; - merge_tiles_to_image: components["schemas"]["ImageOutput"]; - range_of_size: components["schemas"]["IntegerCollectionOutput"]; - sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"]; canvas_paste_back: components["schemas"]["ImageOutput"]; - controlnet: components["schemas"]["ControlOutput"]; - dw_openpose_image_processor: components["schemas"]["ImageOutput"]; - string_collection: components["schemas"]["StringCollectionOutput"]; - float_to_int: components["schemas"]["IntegerOutput"]; - color_correct: components["schemas"]["ImageOutput"]; - unsharp_mask: components["schemas"]["ImageOutput"]; - float: components["schemas"]["FloatOutput"]; - rand_int: components["schemas"]["IntegerOutput"]; - mask_from_id: components["schemas"]["ImageOutput"]; - latents_collection: components["schemas"]["LatentsCollectionOutput"]; - conditioning: components["schemas"]["ConditioningOutput"]; - integer_collection: components["schemas"]["IntegerCollectionOutput"]; - string_join: components["schemas"]["StringOutput"]; - compel: components["schemas"]["ConditioningOutput"]; - crop_latents: components["schemas"]["LatentsOutput"]; - img_watermark: components["schemas"]["ImageOutput"]; - rectangle_mask: components["schemas"]["MaskOutput"]; - prompt_from_file: components["schemas"]["StringCollectionOutput"]; - merge_metadata: components["schemas"]["MetadataOutput"]; + sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"]; + iterate: components["schemas"]["IterateInvocationOutput"]; + heuristic_resize: components["schemas"]["ImageOutput"]; + infill_rgba: components["schemas"]["ImageOutput"]; + ip_adapter: components["schemas"]["IPAdapterOutput"]; img_pad_crop: components["schemas"]["ImageOutput"]; - 
midas_depth_image_processor: components["schemas"]["ImageOutput"]; - core_metadata: components["schemas"]["MetadataOutput"]; - show_image: components["schemas"]["ImageOutput"]; - hed_image_processor: components["schemas"]["ImageOutput"]; - lresize: components["schemas"]["LatentsOutput"]; - lineart_anime_image_processor: components["schemas"]["ImageOutput"]; - img_scale: components["schemas"]["ImageOutput"]; + img_lerp: components["schemas"]["ImageOutput"]; + img_crop: components["schemas"]["ImageOutput"]; + float: components["schemas"]["FloatOutput"]; + conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; + normalbae_image_processor: components["schemas"]["ImageOutput"]; + string_replace: components["schemas"]["StringOutput"]; + }; + /** + * BatchEnqueuedEvent + * @description Event model for batch_enqueued + */ + BatchEnqueuedEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Queue Id + * @description The ID of the queue + */ + queue_id: string; + /** + * Batch Id + * @description The ID of the batch + */ + batch_id: string; + /** + * Enqueued + * @description The number of invocations enqueued + */ + enqueued: number; + /** + * Requested + * @description The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full) + */ + requested: number; + /** + * Priority + * @description The priority of the batch + */ + priority: number; + }; + /** + * BulkDownloadCompleteEvent + * @description Event model for bulk_download_complete + */ + BulkDownloadCompleteEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Bulk Download Id + * @description The ID of the bulk image download + */ + bulk_download_id: string; + /** + * Bulk Download Item Id + * @description The ID of the bulk image download item + */ + bulk_download_item_id: string; + /** + * Bulk Download Item Name + * @description The name of the bulk image download item + */ + bulk_download_item_name: string; + }; + /** + * BulkDownloadErrorEvent + * @description Event model for bulk_download_error + */ + BulkDownloadErrorEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Bulk Download Id + * @description The ID of the bulk image download + */ + bulk_download_id: string; + /** + * Bulk Download Item Id + * @description The ID of the bulk image download item + */ + bulk_download_item_id: string; + /** + * Bulk Download Item Name + * @description The name of the bulk image download item + */ + bulk_download_item_name: string; + /** + * Error + * @description The error message + */ + error: string; + }; + /** + * BulkDownloadStartedEvent + * @description Event model for bulk_download_started + */ + BulkDownloadStartedEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Bulk Download Id + * @description The ID of the bulk image download + */ + bulk_download_id: string; + /** + * Bulk Download Item Id + * @description The ID of the bulk image download item + */ + bulk_download_item_id: string; + /** + * Bulk Download Item Name + * @description The name of the bulk image download item + */ + bulk_download_item_name: string; + }; + /** + * DownloadCancelledEvent + * @description Event model for download_cancelled + */ + DownloadCancelledEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Source + * 
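The regenerated `InvocationOutputMap` earlier in this hunk maps each invocation `type` string to its output schema. Assuming it remains exposed under `components['schemas']` as shown, the mapping can serve as a type-level lookup table; `OutputFor` below is a hypothetical helper, not part of the diff:

```ts
import type { components } from 'services/api/schema';

type InvocationOutputMap = components['schemas']['InvocationOutputMap'];

// Type-level lookup: the output schema produced by a given invocation type.
type OutputFor<T extends keyof InvocationOutputMap> = InvocationOutputMap[T];

// For example, `img_resize` resolves to components['schemas']['ImageOutput'].
export type ResizeOutput = OutputFor<'img_resize'>;
```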
@description The source of the download + */ + source: string; + }; + /** + * DownloadCompleteEvent + * @description Event model for download_complete + */ + DownloadCompleteEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Source + * @description The source of the download + */ + source: string; + /** + * Download Path + * @description The local path where the download is saved + */ + download_path: string; + /** + * Total Bytes + * @description The total number of bytes downloaded + */ + total_bytes: number; + }; + /** + * DownloadErrorEvent + * @description Event model for download_error + */ + DownloadErrorEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Source + * @description The source of the download + */ + source: string; + /** + * Error Type + * @description The type of error + */ + error_type: string; + /** + * Error + * @description The error message + */ + error: string; + }; + /** + * DownloadProgressEvent + * @description Event model for download_progress + */ + DownloadProgressEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Source + * @description The source of the download + */ + source: string; + /** + * Download Path + * @description The local path where the download is saved + */ + download_path: string; + /** + * Current Bytes + * @description The number of bytes downloaded so far + */ + current_bytes: number; + /** + * Total Bytes + * @description The total number of bytes to be downloaded + */ + total_bytes: number; + }; + /** + * DownloadStartedEvent + * @description Event model for download_started + */ + DownloadStartedEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Source + * @description The source of the download + */ + source: string; + /** + * Download Path + * @description The local path where the download is saved + */ + download_path: string; + }; + /** + * BaseInvocation + * @description All invocations must use the `@invocation` decorator to provide their unique type. + */ + BaseInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. + * @default false + */ + is_intermediate: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache: boolean; + }; + /** + * BaseInvocationOutput + * @description Base class for all invocation outputs. + * + * All invocation outputs must use the `@invocation_output` decorator to provide their unique type. 
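The new download events report raw byte counters rather than a precomputed percentage. A sketch of the arithmetic a `download_progress` listener would typically do (the `downloadPercent` helper and its zero-total guard are illustrative):

```ts
import type { components } from 'services/api/schema';

type DownloadProgressEvent = components['schemas']['DownloadProgressEvent'];

// Turn the byte counters into a 0-100 integer for a progress bar.
export const downloadPercent = (event: DownloadProgressEvent): number =>
  event.total_bytes > 0 ? Math.round((event.current_bytes / event.total_bytes) * 100) : 0;
```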
+ */ + BaseInvocationOutput: Record; + /** + * InvocationCompleteEvent + * @description Event model for invocation_complete + */ + InvocationCompleteEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Queue Id + * @description The ID of the queue + */ + queue_id: string; + /** + * Item Id + * @description The ID of the queue item + */ + item_id: number; + /** + * Batch Id + * @description The ID of the queue batch + */ + batch_id: string; + /** + * Session Id + * @description The ID of the session (aka graph execution state) + */ + session_id: string; + /** @description The ID of the invocation */ + invocation: components["schemas"]["BaseInvocation"]; + /** + * Invocation Source Id + * @description The ID of the prepared invocation's source node + */ + invocation_source_id: string; + /** @description The result of the invocation */ + result: components["schemas"]["BaseInvocationOutput"]; + }; + /** + * InvocationDenoiseProgressEvent + * @description Event model for invocation_denoise_progress + */ + InvocationDenoiseProgressEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Queue Id + * @description The ID of the queue + */ + queue_id: string; + /** + * Item Id + * @description The ID of the queue item + */ + item_id: number; + /** + * Batch Id + * @description The ID of the queue batch + */ + batch_id: string; + /** + * Session Id + * @description The ID of the session (aka graph execution state) + */ + session_id: string; + /** @description The ID of the invocation */ + invocation: components["schemas"]["BaseInvocation"]; + /** + * Invocation Source Id + * @description The ID of the prepared invocation's source node + */ + invocation_source_id: string; + /** @description The progress image sent at each step during processing */ + progress_image: components["schemas"]["ProgressImage"]; + /** + * Step + * @description The current step of the invocation + */ + step: number; + /** + * Total Steps + * @description The total number of steps in the invocation + */ + total_steps: number; + /** + * Order + * @description The order of the invocation in the session + */ + order: number; + /** + * Percentage + * @description The percentage of completion of the invocation + */ + percentage: number; + }; + /** + * InvocationErrorEvent + * @description Event model for invocation_error + */ + InvocationErrorEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Queue Id + * @description The ID of the queue + */ + queue_id: string; + /** + * Item Id + * @description The ID of the queue item + */ + item_id: number; + /** + * Batch Id + * @description The ID of the queue batch + */ + batch_id: string; + /** + * Session Id + * @description The ID of the session (aka graph execution state) + */ + session_id: string; + /** @description The ID of the invocation */ + invocation: components["schemas"]["BaseInvocation"]; + /** + * Invocation Source Id + * @description The ID of the prepared invocation's source node + */ + invocation_source_id: string; + /** + * Error Type + * @description The error type + */ + error_type: string; + /** + * Error Message + * @description The error message + */ + error_message: string; + /** + * Error Traceback + * @description The error traceback + */ + error_traceback: string; + /** + * User Id + * @description The ID of the user who created the invocation + * @default null + */ + user_id: string | null; + /** + * 
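`InvocationDenoiseProgressEvent` above bundles the step counters with the progress image. A small, assumed formatting helper that touches only fields shown in the schema:

```ts
import type { components } from 'services/api/schema';

type InvocationDenoiseProgressEvent = components['schemas']['InvocationDenoiseProgressEvent'];

// Human-readable status line for the current denoising step.
export const formatDenoiseProgress = (event: InvocationDenoiseProgressEvent): string =>
  `denoising step ${event.step} of ${event.total_steps}`;
```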
Project Id + * @description The ID of the user who created the invocation + * @default null + */ + project_id: string | null; + }; + /** + * InvocationStartedEvent + * @description Event model for invocation_started + */ + InvocationStartedEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Queue Id + * @description The ID of the queue + */ + queue_id: string; + /** + * Item Id + * @description The ID of the queue item + */ + item_id: number; + /** + * Batch Id + * @description The ID of the queue batch + */ + batch_id: string; + /** + * Session Id + * @description The ID of the session (aka graph execution state) + */ + session_id: string; + /** @description The ID of the invocation */ + invocation: components["schemas"]["BaseInvocation"]; + /** + * Invocation Source Id + * @description The ID of the prepared invocation's source node + */ + invocation_source_id: string; + }; + /** + * ModelInstallCancelledEvent + * @description Event model for model_install_cancelled + */ + ModelInstallCancelledEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Id + * @description The ID of the install job + */ + id: number; + /** + * Source + * @description Source of the model; local path, repo_id or url + */ + source: string; + }; + /** + * ModelInstallCompleteEvent + * @description Event model for model_install_complete + */ + ModelInstallCompleteEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Id + * @description The ID of the install job + */ + id: number; + /** + * Source + * @description Source of the model; local path, repo_id or url + */ + source: string; + /** + * Key + * @description Model config record key + */ + key: string; + /** + * Total Bytes + * @description Size of the model (may be None for installation of a local path) + */ + total_bytes: number | null; + }; + /** + * ModelInstallDownloadProgressEvent + * @description Event model for model_install_download_progress + */ + ModelInstallDownloadProgressEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Id + * @description The ID of the install job + */ + id: number; + /** + * Source + * @description Source of the model; local path, repo_id or url + */ + source: string; + /** + * Local Path + * @description Where model is downloading to + */ + local_path: string; + /** + * Bytes + * @description Number of bytes downloaded so far + */ + bytes: number; + /** + * Total Bytes + * @description Total size of download, including all files + */ + total_bytes: number; + /** + * Parts + * @description Progress of downloading URLs that comprise the model, if any + */ + parts: ({ + [key: string]: number | string; + })[]; + }; + /** + * ModelInstallDownloadsCompleteEvent + * @description Emitted once when an install job becomes active. 
+ */ + ModelInstallDownloadsCompleteEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Id + * @description The ID of the install job + */ + id: number; + /** + * Source + * @description Source of the model; local path, repo_id or url + */ + source: string; + }; + /** + * ModelInstallErrorEvent + * @description Event model for model_install_error + */ + ModelInstallErrorEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Id + * @description The ID of the install job + */ + id: number; + /** + * Source + * @description Source of the model; local path, repo_id or url + */ + source: string; + /** + * Error Type + * @description The name of the exception + */ + error_type: string; + /** + * Error + * @description A text description of the exception + */ + error: string; + }; + /** + * ModelInstallStartedEvent + * @description Event model for model_install_started + */ + ModelInstallStartedEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Id + * @description The ID of the install job + */ + id: number; + /** + * Source + * @description Source of the model; local path, repo_id or url + */ + source: string; + }; + /** + * ModelLoadCompleteEvent + * @description Event model for model_load_complete + */ + ModelLoadCompleteEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Config + * @description The model's config + */ + config: components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"]; + /** + * @description The submodel type, if any + * @default null + */ + submodel_type: components["schemas"]["SubModelType"] | null; + }; + /** + * ModelLoadStartedEvent + * @description Event model for model_load_started + */ + ModelLoadStartedEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Config + * @description The model's config + */ + config: components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"]; + /** + * @description The submodel type, if any + * @default null + */ + submodel_type: components["schemas"]["SubModelType"] | null; + }; + /** + * QueueClearedEvent + * 
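`ModelLoadStartedEvent` and `ModelLoadCompleteEvent` carry the model's config as a union of the per-format config schemas. Because each config keeps a literal `type` discriminant (see the hunks earlier in this diff), a listener can narrow the union directly; the `logModelLoad` helper below is purely illustrative:

```ts
import type { components } from 'services/api/schema';

type ModelLoadCompleteEvent = components['schemas']['ModelLoadCompleteEvent'];

// Narrow the config union via its literal `type` discriminant.
export const logModelLoad = (event: ModelLoadCompleteEvent): void => {
  if (event.config.type === 'vae') {
    console.debug(`Loaded a VAE in ${event.config.format} format`);
  } else if (event.config.type === 't2i_adapter') {
    console.debug('Loaded a T2I-Adapter');
  }
};
```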
@description Event model for queue_cleared + */ + QueueClearedEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Queue Id + * @description The ID of the queue + */ + queue_id: string; + }; + /** + * QueueItemStatusChangedEvent + * @description Event model for queue_item_status_changed + */ + QueueItemStatusChangedEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Queue Id + * @description The ID of the queue + */ + queue_id: string; + /** + * Item Id + * @description The ID of the queue item + */ + item_id: number; + /** + * Batch Id + * @description The ID of the queue batch + */ + batch_id: string; + /** + * Status + * @description The new status of the queue item + * @enum {string} + */ + status: "pending" | "in_progress" | "completed" | "failed" | "canceled"; + /** + * Error Type + * @description The error type, if any + * @default null + */ + error_type: string | null; + /** + * Error Message + * @description The error message, if any + * @default null + */ + error_message: string | null; + /** + * Error Traceback + * @description The error traceback, if any + * @default null + */ + error_traceback: string | null; + /** + * Created At + * @description The timestamp when the queue item was created + * @default null + */ + created_at: string | null; + /** + * Updated At + * @description The timestamp when the queue item was last updated + * @default null + */ + updated_at: string | null; + /** + * Started At + * @description The timestamp when the queue item was started + * @default null + */ + started_at: string | null; + /** + * Completed At + * @description The timestamp when the queue item was completed + * @default null + */ + completed_at: string | null; + /** @description The status of the batch */ + batch_status: components["schemas"]["BatchStatus"]; + /** @description The status of the queue */ + queue_status: components["schemas"]["SessionQueueStatus"]; + /** + * Session Id + * @description The ID of the session (aka graph execution state) + */ + session_id: string; }; }; responses: never; diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 1160a2bee5..3522d719fb 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -1,4 +1,3 @@ -import type { UseToastOptions } from '@invoke-ai/ui-library'; import type { EntityState } from '@reduxjs/toolkit'; import type { components, paths } from 'services/api/schema'; import type { O } from 'ts-toolbelt'; @@ -39,7 +38,6 @@ export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDT // Models export type ModelType = S['ModelType']; -export type SubModelType = S['SubModelType']; export type BaseModelType = S['BaseModelType']; // Model Configs @@ -200,7 +198,7 @@ type CanvasInitialImageAction = { type ToastAction = { type: 'TOAST'; - toastOptions?: UseToastOptions; + title?: string; }; type AddToBatchAction = { diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts index 8dd1cfd4fa..257819b4c8 100644 --- a/invokeai/frontend/web/src/services/events/actions.ts +++ b/invokeai/frontend/web/src/services/events/actions.ts @@ -1,101 +1,58 @@ import { createAction } from '@reduxjs/toolkit'; import type { - BulkDownloadCompletedEvent, + BulkDownloadCompleteEvent, BulkDownloadFailedEvent, BulkDownloadStartedEvent, - 
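`QueueItemStatusChangedEvent` now exposes the item's status as a closed string union alongside the batch and queue summaries. A hedged example of checking for terminal states (`isTerminal` is a hypothetical helper):

```ts
import type { components } from 'services/api/schema';

type QueueItemStatusChangedEvent = components['schemas']['QueueItemStatusChangedEvent'];

// The status union is closed, so terminal states can be tested directly without
// consulting the embedded batch/queue summaries.
export const isTerminal = (event: QueueItemStatusChangedEvent): boolean =>
  event.status === 'completed' || event.status === 'failed' || event.status === 'canceled';
```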
GeneratorProgressEvent, - GraphExecutionStateCompleteEvent, + DownloadCancelledEvent, + DownloadCompleteEvent, + DownloadErrorEvent, + DownloadProgressEvent, + DownloadStartedEvent, InvocationCompleteEvent, + InvocationDenoiseProgressEvent, InvocationErrorEvent, - InvocationRetrievalErrorEvent, InvocationStartedEvent, ModelInstallCancelledEvent, - ModelInstallCompletedEvent, - ModelInstallDownloadingEvent, + ModelInstallCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, ModelInstallErrorEvent, - ModelLoadCompletedEvent, + ModelInstallStartedEvent, + ModelLoadCompleteEvent, ModelLoadStartedEvent, QueueItemStatusChangedEvent, - SessionRetrievalErrorEvent, } from 'services/events/types'; -// Create actions for each socket -// Middleware and redux can then respond to them as needed +const createSocketAction = (name: string) => + createAction(`socket/${name}`); -export const socketConnected = createAction('socket/socketConnected'); - -export const socketDisconnected = createAction('socket/socketDisconnected'); - -export const socketSubscribedSession = createAction<{ - sessionId: string; -}>('socket/socketSubscribedSession'); - -export const socketUnsubscribedSession = createAction<{ sessionId: string }>('socket/socketUnsubscribedSession'); - -export const socketInvocationStarted = createAction<{ - data: InvocationStartedEvent; -}>('socket/socketInvocationStarted'); - -export const socketInvocationComplete = createAction<{ - data: InvocationCompleteEvent; -}>('socket/socketInvocationComplete'); - -export const socketInvocationError = createAction<{ - data: InvocationErrorEvent; -}>('socket/socketInvocationError'); - -export const socketGraphExecutionStateComplete = createAction<{ - data: GraphExecutionStateCompleteEvent; -}>('socket/socketGraphExecutionStateComplete'); - -export const socketGeneratorProgress = createAction<{ - data: GeneratorProgressEvent; -}>('socket/socketGeneratorProgress'); - -export const socketModelLoadStarted = createAction<{ - data: ModelLoadStartedEvent; -}>('socket/socketModelLoadStarted'); - -export const socketModelLoadCompleted = createAction<{ - data: ModelLoadCompletedEvent; -}>('socket/socketModelLoadCompleted'); - -export const socketModelInstallDownloading = createAction<{ - data: ModelInstallDownloadingEvent; -}>('socket/socketModelInstallDownloading'); - -export const socketModelInstallCompleted = createAction<{ - data: ModelInstallCompletedEvent; -}>('socket/socketModelInstallCompleted'); - -export const socketModelInstallError = createAction<{ - data: ModelInstallErrorEvent; -}>('socket/socketModelInstallError'); - -export const socketModelInstallCancelled = createAction<{ - data: ModelInstallCancelledEvent; -}>('socket/socketModelInstallCancelled'); - -export const socketSessionRetrievalError = createAction<{ - data: SessionRetrievalErrorEvent; -}>('socket/socketSessionRetrievalError'); - -export const socketInvocationRetrievalError = createAction<{ - data: InvocationRetrievalErrorEvent; -}>('socket/socketInvocationRetrievalError'); - -export const socketQueueItemStatusChanged = createAction<{ - data: QueueItemStatusChangedEvent; -}>('socket/socketQueueItemStatusChanged'); - -export const socketBulkDownloadStarted = createAction<{ - data: BulkDownloadStartedEvent; -}>('socket/socketBulkDownloadStarted'); - -export const socketBulkDownloadCompleted = createAction<{ - data: BulkDownloadCompletedEvent; -}>('socket/socketBulkDownloadCompleted'); - -export const socketBulkDownloadFailed = createAction<{ - data: 
BulkDownloadFailedEvent; -}>('socket/socketBulkDownloadFailed'); +export const socketConnected = createSocketAction('Connected'); +export const socketDisconnected = createSocketAction('Disconnected'); +export const socketInvocationStarted = createSocketAction('InvocationStartedEvent'); +export const socketInvocationComplete = createSocketAction('InvocationCompleteEvent'); +export const socketInvocationError = createSocketAction('InvocationErrorEvent'); +export const socketGeneratorProgress = createSocketAction( + 'InvocationDenoiseProgressEvent' +); +export const socketModelLoadStarted = createSocketAction('ModelLoadStartedEvent'); +export const socketModelLoadComplete = createSocketAction('ModelLoadCompleteEvent'); +export const socketDownloadStarted = createSocketAction('DownloadStartedEvent'); +export const socketDownloadProgress = createSocketAction('DownloadProgressEvent'); +export const socketDownloadComplete = createSocketAction('DownloadCompleteEvent'); +export const socketDownloadCancelled = createSocketAction('DownloadCancelledEvent'); +export const socketDownloadError = createSocketAction('DownloadErrorEvent'); +export const socketModelInstallStarted = createSocketAction('ModelInstallStartedEvent'); +export const socketModelInstallDownloadProgress = createSocketAction( + 'ModelInstallDownloadProgressEvent' +); +export const socketModelInstallDownloadsComplete = createSocketAction( + 'ModelInstallDownloadsCompleteEvent' +); +export const socketModelInstallComplete = createSocketAction('ModelInstallCompleteEvent'); +export const socketModelInstallError = createSocketAction('ModelInstallErrorEvent'); +export const socketModelInstallCancelled = createSocketAction('ModelInstallCancelledEvent'); +export const socketQueueItemStatusChanged = + createSocketAction('QueueItemStatusChangedEvent'); +export const socketBulkDownloadStarted = createSocketAction('BulkDownloadStartedEvent'); +export const socketBulkDownloadComplete = createSocketAction('BulkDownloadCompleteEvent'); +export const socketBulkDownloadError = createSocketAction('BulkDownloadFailedEvent'); diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.ts b/invokeai/frontend/web/src/services/events/setEventListeners.ts new file mode 100644 index 0000000000..6bc0154ef0 --- /dev/null +++ b/invokeai/frontend/web/src/services/events/setEventListeners.ts @@ -0,0 +1,128 @@ +import { $baseUrl } from 'app/store/nanostores/baseUrl'; +import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId'; +import { $queueId } from 'app/store/nanostores/queueId'; +import type { AppDispatch } from 'app/store/store'; +import { toast } from 'features/toast/toast'; +import { + socketBulkDownloadComplete, + socketBulkDownloadError, + socketBulkDownloadStarted, + socketConnected, + socketDisconnected, + socketDownloadCancelled, + socketDownloadComplete, + socketDownloadError, + socketDownloadProgress, + socketDownloadStarted, + socketGeneratorProgress, + socketInvocationComplete, + socketInvocationError, + socketInvocationStarted, + socketModelInstallCancelled, + socketModelInstallComplete, + socketModelInstallDownloadProgress, + socketModelInstallDownloadsComplete, + socketModelInstallError, + socketModelInstallStarted, + socketModelLoadComplete, + socketModelLoadStarted, + socketQueueItemStatusChanged, +} from 'services/events/actions'; +import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; +import type { Socket } from 'socket.io-client'; + +type SetEventListenersArg = { + socket: Socket; + 
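The rewritten `actions.ts` derives every socket action from a single `createSocketAction` factory. A minimal sketch of such a payload-typed factory, assuming a generic payload parameter that is not visible in this hunk (the placeholder event type below is illustrative, not the real one):

```ts
import { createAction } from '@reduxjs/toolkit';

// Hypothetical sketch: one generic factory instead of a hand-written createAction per
// event. `T` is the socket payload, carried under a `data` key to match how the
// listeners dispatch these actions.
const createSocketAction = <T>(name: string) => createAction<{ data: T }>(`socket/${name}`);

// Placeholder payload type, for illustration only.
type InvocationStartedEvent = { timestamp: number; queue_id: string };

export const socketInvocationStarted = createSocketAction<InvocationStartedEvent>('InvocationStartedEvent');
// Later: dispatch(socketInvocationStarted({ data: payload }));
```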
dispatch: AppDispatch; +}; + +export const setEventListeners = ({ socket, dispatch }: SetEventListenersArg) => { + socket.on('connect', () => { + dispatch(socketConnected()); + const queue_id = $queueId.get(); + socket.emit('subscribe_queue', { queue_id }); + if (!$baseUrl.get()) { + const bulk_download_id = $bulkDownloadId.get(); + socket.emit('subscribe_bulk_download', { bulk_download_id }); + } + }); + socket.on('connect_error', (error) => { + if (error && error.message) { + const data: string | undefined = (error as unknown as { data: string | undefined }).data; + if (data === 'ERR_UNAUTHENTICATED') { + toast({ + id: `connect-error-${error.message}`, + title: error.message, + status: 'error', + duration: 10000, + }); + } + } + }); + socket.on('disconnect', () => { + dispatch(socketDisconnected()); + }); + socket.on('invocation_started', (data) => { + dispatch(socketInvocationStarted({ data })); + }); + socket.on('invocation_denoise_progress', (data) => { + dispatch(socketGeneratorProgress({ data })); + }); + socket.on('invocation_error', (data) => { + dispatch(socketInvocationError({ data })); + }); + socket.on('invocation_complete', (data) => { + dispatch(socketInvocationComplete({ data })); + }); + socket.on('model_load_started', (data) => { + dispatch(socketModelLoadStarted({ data })); + }); + socket.on('model_load_complete', (data) => { + dispatch(socketModelLoadComplete({ data })); + }); + socket.on('download_started', (data) => { + dispatch(socketDownloadStarted({ data })); + }); + socket.on('download_progress', (data) => { + dispatch(socketDownloadProgress({ data })); + }); + socket.on('download_complete', (data) => { + dispatch(socketDownloadComplete({ data })); + }); + socket.on('download_cancelled', (data) => { + dispatch(socketDownloadCancelled({ data })); + }); + socket.on('download_error', (data) => { + dispatch(socketDownloadError({ data })); + }); + socket.on('model_install_started', (data) => { + dispatch(socketModelInstallStarted({ data })); + }); + socket.on('model_install_download_progress', (data) => { + dispatch(socketModelInstallDownloadProgress({ data })); + }); + socket.on('model_install_downloads_complete', (data) => { + dispatch(socketModelInstallDownloadsComplete({ data })); + }); + socket.on('model_install_complete', (data) => { + dispatch(socketModelInstallComplete({ data })); + }); + socket.on('model_install_error', (data) => { + dispatch(socketModelInstallError({ data })); + }); + socket.on('model_install_cancelled', (data) => { + dispatch(socketModelInstallCancelled({ data })); + }); + socket.on('queue_item_status_changed', (data) => { + dispatch(socketQueueItemStatusChanged({ data })); + }); + socket.on('bulk_download_started', (data) => { + dispatch(socketBulkDownloadStarted({ data })); + }); + socket.on('bulk_download_complete', (data) => { + dispatch(socketBulkDownloadComplete({ data })); + }); + socket.on('bulk_download_error', (data) => { + dispatch(socketBulkDownloadError({ data })); + }); +}; diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index 161a85b8f6..3a7de93627 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -1,275 +1,73 @@ -import type { components } from 'services/api/schema'; -import type { AnyModelConfig, Graph, GraphExecutionState, SubModelType } from 'services/api/types'; - -/** - * A progress image, we get one for each step in the generation - */ -export type ProgressImage = { - dataURL: 
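The new `setEventListeners.ts` relies on the socket being parameterized with `ServerToClientEvents` / `ClientToServerEvents`, so each `socket.on(...)` callback infers its payload type. A hedged sketch of that wiring (the `io()` call here is a placeholder; the app configures its own connection):

```ts
import { io, type Socket } from 'socket.io-client';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';

// Typing the socket with both event maps lets every `socket.on(...)` callback infer
// its payload, which is why the listener bodies above need no casts.
const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(); // placeholder connection

socket.on('download_progress', (payload) => {
  // `payload` is inferred as DownloadProgressEvent from ServerToClientEvents.
  console.debug(`downloaded ${payload.current_bytes} of ${payload.total_bytes} bytes`);
});
```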
string; - width: number; - height: number; -}; +import type { Graph, GraphExecutionState, S } from 'services/api/types'; export type AnyInvocation = NonNullable[string]>; export type AnyResult = NonNullable; -type BaseNode = { - id: string; - type: string; - [key: string]: AnyInvocation[keyof AnyInvocation]; -}; +export type ModelLoadStartedEvent = S['ModelLoadStartedEvent']; +export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent']; -export type ModelLoadStartedEvent = { - queue_id: string; - queue_item_id: number; - queue_batch_id: string; - graph_execution_state_id: string; - model_config: AnyModelConfig; - submodel_type?: SubModelType | null; +export type InvocationStartedEvent = Omit & { invocation: AnyInvocation }; +export type InvocationDenoiseProgressEvent = Omit & { + invocation: AnyInvocation; }; - -export type ModelLoadCompletedEvent = { - queue_id: string; - queue_item_id: number; - queue_batch_id: string; - graph_execution_state_id: string; - model_config: AnyModelConfig; - submodel_type?: SubModelType | null; -}; - -export type ModelInstallDownloadingEvent = { - bytes: number; - local_path: string; - source: string; - timestamp: number; - total_bytes: number; - id: number; -}; - -export type ModelInstallCompletedEvent = { - key: number; - source: string; - timestamp: number; - id: number; -}; - -export type ModelInstallErrorEvent = { - error: string; - error_type: string; - source: string; - timestamp: number; - id: number; -}; - -export type ModelInstallCancelledEvent = { - source: string; - timestamp: number; - id: number; -}; - -/** - * A `generator_progress` socket.io event. - * - * @example socket.on('generator_progress', (data: GeneratorProgressEvent) => { ... } - */ -export type GeneratorProgressEvent = { - queue_id: string; - queue_item_id: number; - queue_batch_id: string; - graph_execution_state_id: string; - node_id: string; - source_node_id: string; - progress_image?: ProgressImage; - step: number; - order: number; - total_steps: number; -}; - -/** - * A `invocation_complete` socket.io event. - * - * `result` is a discriminated union with a `type` property as the discriminant. - * - * @example socket.on('invocation_complete', (data: InvocationCompleteEvent) => { ... } - */ -export type InvocationCompleteEvent = { - queue_id: string; - queue_item_id: number; - queue_batch_id: string; - graph_execution_state_id: string; - node: BaseNode; - source_node_id: string; +export type InvocationCompleteEvent = Omit & { result: AnyResult; + invocation: AnyInvocation; }; +export type InvocationErrorEvent = Omit & { invocation: AnyInvocation }; +export type ProgressImage = InvocationDenoiseProgressEvent['progress_image']; -/** - * A `invocation_error` socket.io event. - * - * @example socket.on('invocation_error', (data: InvocationErrorEvent) => { ... 
} - */ -export type InvocationErrorEvent = { - queue_id: string; - queue_item_id: number; - queue_batch_id: string; - graph_execution_state_id: string; - node: BaseNode; - source_node_id: string; - error_type: string; - error: string; -}; +export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent']; +export type ModelInstallDownloadsCompleteEvent = S['ModelInstallDownloadsCompleteEvent']; +export type ModelInstallCompleteEvent = S['ModelInstallCompleteEvent']; +export type ModelInstallErrorEvent = S['ModelInstallErrorEvent']; +export type ModelInstallStartedEvent = S['ModelInstallStartedEvent']; +export type ModelInstallCancelledEvent = S['ModelInstallCancelledEvent']; -/** - * A `invocation_started` socket.io event. - * - * @example socket.on('invocation_started', (data: InvocationStartedEvent) => { ... } - */ -export type InvocationStartedEvent = { - queue_id: string; - queue_item_id: number; - queue_batch_id: string; - graph_execution_state_id: string; - node: BaseNode; - source_node_id: string; -}; +export type DownloadStartedEvent = S['DownloadStartedEvent']; +export type DownloadProgressEvent = S['DownloadProgressEvent']; +export type DownloadCompleteEvent = S['DownloadCompleteEvent']; +export type DownloadCancelledEvent = S['DownloadCancelledEvent']; +export type DownloadErrorEvent = S['DownloadErrorEvent']; -/** - * A `graph_execution_state_complete` socket.io event. - * - * @example socket.on('graph_execution_state_complete', (data: GraphExecutionStateCompleteEvent) => { ... } - */ -export type GraphExecutionStateCompleteEvent = { - queue_id: string; - queue_item_id: number; - queue_batch_id: string; - graph_execution_state_id: string; -}; +export type QueueItemStatusChangedEvent = S['QueueItemStatusChangedEvent']; -/** - * A `session_retrieval_error` socket.io event. - * - * @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... } - */ -export type SessionRetrievalErrorEvent = { - queue_id: string; - queue_item_id: number; - queue_batch_id: string; - graph_execution_state_id: string; - error_type: string; - error: string; -}; - -/** - * A `invocation_retrieval_error` socket.io event. - * - * @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... } - */ -export type InvocationRetrievalErrorEvent = { - queue_id: string; - queue_item_id: number; - queue_batch_id: string; - graph_execution_state_id: string; - node_id: string; - error_type: string; - error: string; -}; - -/** - * A `queue_item_status_changed` socket.io event. - * - * @example socket.on('queue_item_status_changed', (data: QueueItemStatusChangedEvent) => { ... 
} - */ -export type QueueItemStatusChangedEvent = { - queue_id: string; - queue_item: { - queue_id: string; - item_id: number; - batch_id: string; - session_id: string; - status: components['schemas']['SessionQueueItemDTO']['status']; - error: string | undefined; - created_at: string; - updated_at: string; - started_at: string | undefined; - completed_at: string | undefined; - }; - batch_status: { - queue_id: string; - batch_id: string; - pending: number; - in_progress: number; - completed: number; - failed: number; - canceled: number; - total: number; - }; - queue_status: { - queue_id: string; - item_id?: number; - batch_id?: string; - session_id?: string; - pending: number; - in_progress: number; - completed: number; - failed: number; - canceled: number; - total: number; - }; -}; +export type BulkDownloadStartedEvent = S['BulkDownloadStartedEvent']; +export type BulkDownloadCompleteEvent = S['BulkDownloadCompleteEvent']; +export type BulkDownloadFailedEvent = S['BulkDownloadErrorEvent']; type ClientEmitSubscribeQueue = { queue_id: string; }; - -type ClientEmitUnsubscribeQueue = { - queue_id: string; -}; - -export type BulkDownloadStartedEvent = { - bulk_download_id: string; - bulk_download_item_id: string; - bulk_download_item_name: string; -}; - -export type BulkDownloadCompletedEvent = { - bulk_download_id: string; - bulk_download_item_id: string; - bulk_download_item_name: string; -}; - -export type BulkDownloadFailedEvent = { - bulk_download_id: string; - bulk_download_item_id: string; - bulk_download_item_name: string; - error: string; -}; - +type ClientEmitUnsubscribeQueue = ClientEmitSubscribeQueue; type ClientEmitSubscribeBulkDownload = { bulk_download_id: string; }; - -type ClientEmitUnsubscribeBulkDownload = { - bulk_download_id: string; -}; +type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload; export type ServerToClientEvents = { - generator_progress: (payload: GeneratorProgressEvent) => void; + invocation_denoise_progress: (payload: InvocationDenoiseProgressEvent) => void; invocation_complete: (payload: InvocationCompleteEvent) => void; invocation_error: (payload: InvocationErrorEvent) => void; invocation_started: (payload: InvocationStartedEvent) => void; - graph_execution_state_complete: (payload: GraphExecutionStateCompleteEvent) => void; + download_started: (payload: DownloadStartedEvent) => void; + download_progress: (payload: DownloadProgressEvent) => void; + download_complete: (payload: DownloadCompleteEvent) => void; + download_cancelled: (payload: DownloadCancelledEvent) => void; + download_error: (payload: DownloadErrorEvent) => void; model_load_started: (payload: ModelLoadStartedEvent) => void; - model_load_completed: (payload: ModelLoadCompletedEvent) => void; - model_install_downloading: (payload: ModelInstallDownloadingEvent) => void; - model_install_completed: (payload: ModelInstallCompletedEvent) => void; + model_install_started: (payload: ModelInstallStartedEvent) => void; + model_install_download_progress: (payload: ModelInstallDownloadProgressEvent) => void; + model_install_downloads_complete: (payload: ModelInstallDownloadsCompleteEvent) => void; + model_install_complete: (payload: ModelInstallCompleteEvent) => void; model_install_error: (payload: ModelInstallErrorEvent) => void; - model_install_canceled: (payload: ModelInstallCancelledEvent) => void; - session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void; - invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void; + model_install_cancelled: 
(payload: ModelInstallCancelledEvent) => void; + model_load_complete: (payload: ModelLoadCompleteEvent) => void; queue_item_status_changed: (payload: QueueItemStatusChangedEvent) => void; bulk_download_started: (payload: BulkDownloadStartedEvent) => void; - bulk_download_completed: (payload: BulkDownloadCompletedEvent) => void; - bulk_download_failed: (payload: BulkDownloadFailedEvent) => void; + bulk_download_complete: (payload: BulkDownloadCompleteEvent) => void; + bulk_download_error: (payload: BulkDownloadFailedEvent) => void; }; export type ClientToServerEvents = { diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts deleted file mode 100644 index 4476624e4e..0000000000 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ /dev/null @@ -1,210 +0,0 @@ -import { $baseUrl } from 'app/store/nanostores/baseUrl'; -import { $bulkDownloadId } from 'app/store/nanostores/bulkDownloadId'; -import { $queueId } from 'app/store/nanostores/queueId'; -import type { AppDispatch } from 'app/store/store'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; -import { - socketBulkDownloadCompleted, - socketBulkDownloadFailed, - socketBulkDownloadStarted, - socketConnected, - socketDisconnected, - socketGeneratorProgress, - socketGraphExecutionStateComplete, - socketInvocationComplete, - socketInvocationError, - socketInvocationRetrievalError, - socketInvocationStarted, - socketModelInstallCompleted, - socketModelInstallDownloading, - socketModelInstallError, - socketModelLoadCompleted, - socketModelLoadStarted, - socketQueueItemStatusChanged, - socketSessionRetrievalError, -} from 'services/events/actions'; -import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; -import type { Socket } from 'socket.io-client'; - -type SetEventListenersArg = { - socket: Socket; - dispatch: AppDispatch; -}; - -export const setEventListeners = (arg: SetEventListenersArg) => { - const { socket, dispatch } = arg; - - /** - * Connect - */ - socket.on('connect', () => { - dispatch(socketConnected()); - const queue_id = $queueId.get(); - socket.emit('subscribe_queue', { queue_id }); - if (!$baseUrl.get()) { - const bulk_download_id = $bulkDownloadId.get(); - socket.emit('subscribe_bulk_download', { bulk_download_id }); - } - }); - - socket.on('connect_error', (error) => { - if (error && error.message) { - const data: string | undefined = (error as unknown as { data: string | undefined }).data; - if (data === 'ERR_UNAUTHENTICATED') { - dispatch( - addToast( - makeToast({ - title: error.message, - status: 'error', - duration: 10000, - }) - ) - ); - } - } - }); - - /** - * Disconnect - */ - socket.on('disconnect', () => { - dispatch(socketDisconnected()); - }); - - /** - * Invocation started - */ - socket.on('invocation_started', (data) => { - dispatch(socketInvocationStarted({ data })); - }); - - /** - * Generator progress - */ - socket.on('generator_progress', (data) => { - dispatch(socketGeneratorProgress({ data })); - }); - - /** - * Invocation error - */ - socket.on('invocation_error', (data) => { - dispatch(socketInvocationError({ data })); - }); - - /** - * Invocation complete - */ - socket.on('invocation_complete', (data) => { - dispatch( - socketInvocationComplete({ - data, - }) - ); - }); - - /** - * Graph complete - */ - socket.on('graph_execution_state_complete', (data) => { - dispatch( - 
socketGraphExecutionStateComplete({ - data, - }) - ); - }); - - /** - * Model load started - */ - socket.on('model_load_started', (data) => { - dispatch( - socketModelLoadStarted({ - data, - }) - ); - }); - - /** - * Model load completed - */ - socket.on('model_load_completed', (data) => { - dispatch( - socketModelLoadCompleted({ - data, - }) - ); - }); - - /** - * Model Install Downloading - */ - socket.on('model_install_downloading', (data) => { - dispatch( - socketModelInstallDownloading({ - data, - }) - ); - }); - - /** - * Model Install Completed - */ - socket.on('model_install_completed', (data) => { - dispatch( - socketModelInstallCompleted({ - data, - }) - ); - }); - - /** - * Model Install Error - */ - socket.on('model_install_error', (data) => { - dispatch( - socketModelInstallError({ - data, - }) - ); - }); - - /** - * Session retrieval error - */ - socket.on('session_retrieval_error', (data) => { - dispatch( - socketSessionRetrievalError({ - data, - }) - ); - }); - - /** - * Invocation retrieval error - */ - socket.on('invocation_retrieval_error', (data) => { - dispatch( - socketInvocationRetrievalError({ - data, - }) - ); - }); - - socket.on('queue_item_status_changed', (data) => { - dispatch(socketQueueItemStatusChanged({ data })); - }); - - socket.on('bulk_download_started', (data) => { - dispatch(socketBulkDownloadStarted({ data })); - }); - - socket.on('bulk_download_completed', (data) => { - dispatch(socketBulkDownloadCompleted({ data })); - }); - - socket.on('bulk_download_failed', (data) => { - dispatch(socketBulkDownloadFailed({ data })); - }); -}; diff --git a/invokeai/version/invokeai_version.py b/invokeai/version/invokeai_version.py index aef46acb47..6e997e12f5 100644 --- a/invokeai/version/invokeai_version.py +++ b/invokeai/version/invokeai_version.py @@ -1 +1 @@ -__version__ = "4.2.1" +__version__ = "4.2.3" diff --git a/pyproject.toml b/pyproject.toml index 86cbb8315c..3913a6cd1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,8 +33,8 @@ classifiers = [ ] dependencies = [ # Core generation dependencies, pinned for reproducible builds. - "accelerate==0.29.2", - "clip_anytorch==2.5.2", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", + "accelerate==0.30.1", + "clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "compel==2.0.2", "controlnet-aux==0.0.7", "diffusers[torch]==0.27.2", @@ -45,18 +45,18 @@ dependencies = [ "onnxruntime==1.16.3", "opencv-python==4.9.0.80", "pytorch-lightning==2.1.3", - "safetensors==0.4.2", + "safetensors==0.4.3", "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "torch==2.2.2", "torchmetrics==0.11.4", "torchsde==0.2.6", "torchvision==0.17.2", - "transformers==4.39.3", + "transformers==4.41.1", # Core application dependencies, pinned for reproducible builds. 
"fastapi-events==0.11.0", "fastapi==0.110.0", - "huggingface-hub==0.22.2", + "huggingface-hub==0.23.1", "pydantic-settings==2.2.1", "pydantic==2.6.3", "python-socketio==5.11.1", diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index b18f6e038d..bf3bd27993 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -9,6 +9,11 @@ import pytest from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService +from invokeai.app.services.events.events_common import ( + BulkDownloadCompleteEvent, + BulkDownloadErrorEvent, + BulkDownloadStartedEvent, +) from invokeai.app.services.image_records.image_records_common import ( ImageCategory, ImageRecordNotFoundException, @@ -281,9 +286,9 @@ def assert_handler_success( # Check that the correct events were emitted assert len(event_bus.events) == 2 - assert event_bus.events[0].event_name == "bulk_download_started" - assert event_bus.events[1].event_name == "bulk_download_completed" - assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path) + assert isinstance(event_bus.events[0], BulkDownloadStartedEvent) + assert isinstance(event_bus.events[1], BulkDownloadCompleteEvent) + assert event_bus.events[1].bulk_download_item_name == os.path.basename(expected_zip_path) def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): @@ -329,9 +334,9 @@ def test_handler_on_generic_exception( event_bus: TestEventService = mock_invoker.services.events assert len(event_bus.events) == 2 - assert event_bus.events[0].event_name == "bulk_download_started" - assert event_bus.events[1].event_name == "bulk_download_failed" - assert event_bus.events[1].payload["error"] == exception.__str__() + assert isinstance(event_bus.events[0], BulkDownloadStartedEvent) + assert isinstance(event_bus.events[1], BulkDownloadErrorEvent) + assert event_bus.events[1].error == exception.__str__() def execute_handler_test_on_error( @@ -344,9 +349,9 @@ def execute_handler_test_on_error( event_bus: TestEventService = mock_invoker.services.events assert len(event_bus.events) == 2 - assert event_bus.events[0].event_name == "bulk_download_started" - assert event_bus.events[1].event_name == "bulk_download_failed" - assert event_bus.events[1].payload["error"] == error.__str__() + assert isinstance(event_bus.events[0], BulkDownloadStartedEvent) + assert isinstance(event_bus.events[1], BulkDownloadErrorEvent) + assert event_bus.events[1].error == error.__str__() def test_delete(tmp_path: Path): diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index c9317163c8..fd2e2a65ae 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -14,6 +14,13 @@ from requests_testadapter import TestAdapter from invokeai.app.services.config import get_config from invokeai.app.services.config.config_default import URLRegexTokenPair from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob +from invokeai.app.services.events.events_common import ( + 
DownloadCancelledEvent, + DownloadCompleteEvent, + DownloadErrorEvent, + DownloadProgressEvent, + DownloadStartedEvent, +) from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.test_nodes import TestEventService @@ -88,14 +95,14 @@ def test_event_bus(tmp_path: Path, mm2_session: Session) -> None: queue.join() events = event_bus.events assert len(events) == 3 - assert events[0].payload["timestamp"] <= events[1].payload["timestamp"] - assert events[1].payload["timestamp"] <= events[2].payload["timestamp"] - assert events[0].event_name == "download_started" - assert events[1].event_name == "download_progress" - assert events[1].payload["total_bytes"] > 0 - assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"] - assert events[2].event_name == "download_complete" - assert events[2].payload["total_bytes"] == 32029 + assert isinstance(events[0], DownloadStartedEvent) + assert isinstance(events[1], DownloadProgressEvent) + assert isinstance(events[2], DownloadCompleteEvent) + assert events[0].timestamp <= events[1].timestamp + assert events[1].timestamp <= events[2].timestamp + assert events[1].total_bytes > 0 + assert events[1].current_bytes <= events[1].total_bytes + assert events[2].total_bytes == 32029 # test a failure event_bus.events = [] # reset our accumulator @@ -104,10 +111,10 @@ def test_event_bus(tmp_path: Path, mm2_session: Session) -> None: events = event_bus.events print("\n".join([x.model_dump_json() for x in events])) assert len(events) == 1 - assert events[0].event_name == "download_error" - assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)" - assert events[0].payload["error"] is not None - assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"]) + assert isinstance(events[0], DownloadErrorEvent) + assert events[0].error_type == "HTTPError(NOT FOUND)" + assert events[0].error is not None + assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].error) queue.stop() @@ -171,8 +178,8 @@ def test_cancel(tmp_path: Path, mm2_session: Session) -> None: assert job.status == DownloadJobStatus.CANCELLED assert cancelled events = event_bus.events - assert events[-1].event_name == "download_cancelled" - assert events[-1].payload["source"] == "http://www.civitai.com/models/12345" + assert isinstance(events[-1], DownloadCancelledEvent) + assert events[-1].source == "http://www.civitai.com/models/12345" queue.stop() @@ -278,7 +285,7 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any assert job.status == DownloadJobStatus.CANCELLED assert cancelled events = event_bus.events - assert "download_cancelled" in [x.event_name for x in events] + assert DownloadCancelledEvent in [type(x) for x in events] queue.stop() diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index ca8616238f..5c9f908ccc 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -13,12 +13,20 @@ from pydantic_core import Url from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.services.events.events_common import ( + ModelInstallCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, + 
ModelInstallStartedEvent, +) from invokeai.app.services.model_install import ( HFModelSource, + ModelInstallServiceBase, +) +from invokeai.app.services.model_install.model_install_common import ( InstallStatus, LocalModelSource, ModelInstallJob, - ModelInstallServiceBase, URLModelSource, ) from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException @@ -30,6 +38,7 @@ from invokeai.backend.model_manager.config import ( ModelType, ) from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 +from tests.test_nodes import TestEventService OS = platform.uname().system @@ -137,17 +146,16 @@ def test_background_install( assert job.total_bytes == size # test that the expected events were issued - bus = mm2_installer.event_bus + bus: TestEventService = mm2_installer.event_bus assert bus assert hasattr(bus, "events") assert len(bus.events) == 2 - event_names = [x.event_name for x in bus.events] - assert "model_install_running" in event_names - assert "model_install_completed" in event_names - assert Path(bus.events[0].payload["source"]) == source - assert Path(bus.events[1].payload["source"]) == source - key = bus.events[1].payload["key"] + assert isinstance(bus.events[0], ModelInstallStartedEvent) + assert isinstance(bus.events[1], ModelInstallCompleteEvent) + assert Path(bus.events[0].source) == source + assert Path(bus.events[1].source) == source + key = bus.events[1].key assert key is not None # see if the thing actually got installed at the expected location @@ -226,7 +234,7 @@ def test_delete_register( def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors")) - bus = mm2_installer.event_bus + bus: TestEventService = mm2_installer.event_bus store = mm2_installer.record_store assert store is not None assert bus is not None @@ -244,20 +252,17 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: assert (mm2_app_config.models_path / model_record.path).exists() assert len(bus.events) == 4 - event_names = [x.event_name for x in bus.events] - assert event_names == [ - "model_install_downloading", - "model_install_downloads_done", - "model_install_running", - "model_install_completed", - ] + assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent) + assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent) + assert isinstance(bus.events[2], ModelInstallStartedEvent) + assert isinstance(bus.events[3], ModelInstallCompleteEvent) @pytest.mark.timeout(timeout=10, method="thread") def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo")) - bus = mm2_installer.event_bus + bus: TestEventService = mm2_installer.event_bus store = mm2_installer.record_store assert isinstance(bus, EventServiceBase) assert store is not None @@ -274,15 +279,10 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con assert model_record.type == ModelType.Main assert model_record.format == ModelFormat.Diffusers - assert hasattr(bus, "events") # the dummyeventservice has this + assert any(isinstance(x, ModelInstallStartedEvent) for x in bus.events) + assert any(isinstance(x, ModelInstallDownloadProgressEvent) for x in bus.events) + assert any(isinstance(x, ModelInstallCompleteEvent) for x in bus.events) assert len(bus.events) >= 3 - 
event_names = {x.event_name for x in bus.events} - assert event_names == { - "model_install_downloading", - "model_install_downloads_done", - "model_install_running", - "model_install_completed", - } @pytest.mark.timeout(timeout=10, method="thread") @@ -308,19 +308,24 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con assert hasattr(bus, "events") # the dummyeventservice has this assert len(bus.events) >= 3 - event_names = {x.event_name for x in bus.events} - assert event_names == { - "model_install_downloading", - "model_install_downloads_done", - "model_install_running", - "model_install_completed", - } + event_types = [type(x) for x in bus.events] + assert all( + x in event_types + for x in [ + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, + ModelInstallStartedEvent, + ModelInstallCompleteEvent, + ] + ) - completed_events = [x for x in bus.events if x.event_name == "model_install_completed"] - downloading_events = [x for x in bus.events if x.event_name == "model_install_downloading"] - assert completed_events[0].payload["total_bytes"] == downloading_events[-1].payload["bytes"] - assert job.total_bytes == completed_events[0].payload["total_bytes"] - assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].payload["parts"]) + completed_events = [x for x in bus.events if isinstance(x, ModelInstallCompleteEvent)] + downloading_events = [x for x in bus.events if isinstance(x, ModelInstallDownloadProgressEvent)] + assert completed_events[0].total_bytes == downloading_events[-1].bytes + assert job.total_bytes == completed_events[0].total_bytes + print(downloading_events[-1]) + print(job.download_parts) + assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].parts) def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index 2af321d60f..a10bc4d66a 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -19,7 +19,7 @@ def mock_context( return build_invocation_context( services=mock_services, data=None, # type: ignore - cancel_event=None, # type: ignore + is_canceled=None, # type: ignore ) diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 0301101a19..dc2ad2f1e4 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -3,16 +3,13 @@ import os import shutil from pathlib import Path -from typing import Any, Dict, List import pytest -from pydantic import BaseModel from requests.sessions import Session from requests_testadapter import TestAdapter, TestSession from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase @@ -39,27 +36,7 @@ from tests.backend.model_manager.model_metadata.metadata_examples import ( RepoHFModelJson1, ) from tests.fixtures.sqlite_database import create_mock_sqlite_database - - -class 
DummyEvent(BaseModel): - """Dummy Event to use with Dummy Event service.""" - - event_name: str - payload: Dict[str, Any] - - -class DummyEventService(EventServiceBase): - """Dummy event service for testing.""" - - events: List[DummyEvent] - - def __init__(self) -> None: - super().__init__() - self.events = [] - - def dispatch(self, event_name: str, payload: Any) -> None: - """Dispatch an event by appending it to self.events.""" - self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"])) +from tests.test_nodes import TestEventService # Create a temporary directory using the contents of `./data/invokeai_root` as the template @@ -127,7 +104,7 @@ def mm2_installer( ) -> ModelInstallServiceBase: logger = InvokeAILogger.get_logger() db = create_mock_sqlite_database(mm2_app_config, logger) - events = DummyEventService() + events = TestEventService() store = ModelRecordServiceSQL(db) installer = ModelInstallService( diff --git a/tests/conftest.py b/tests/conftest.py index 7a7fdf32bb..8a67e9473c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,7 @@ from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService from invokeai.app.services.invoker import Invoker from invokeai.backend.util.logging import InvokeAILogger +from tests.backend.model_manager.model_manager_fixtures import * # noqa: F403 from tests.fixtures.sqlite_database import create_mock_sqlite_database # noqa: F401 from tests.test_nodes import TestEventService diff --git a/tests/test_nodes.py b/tests/test_nodes.py index e1fe857040..2d413a2687 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1,7 +1,5 @@ from typing import Any, Callable, Union -from pydantic import BaseModel - from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -10,6 +8,7 @@ from invokeai.app.invocations.baseinvocation import ( ) from invokeai.app.invocations.fields import InputField, OutputField from invokeai.app.invocations.image import ImageField +from invokeai.app.services.events.events_common import EventBase from invokeai.app.services.shared.invocation_context import InvocationContext @@ -117,11 +116,10 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg ) -class TestEvent(BaseModel): +class TestEvent(EventBase): __test__ = False # not a pytest test case - event_name: str - payload: Any + __event_name__ = "test_event" class TestEventService(EventServiceBase): @@ -129,10 +127,10 @@ class TestEventService(EventServiceBase): def __init__(self): super().__init__() - self.events: list[TestEvent] = [] + self.events: list[EventBase] = [] - def dispatch(self, event_name: str, payload: Any) -> None: - self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"])) + def dispatch(self, event: EventBase) -> None: + self.events.append(event) pass
-{queueItem.error}
+{queueItem?.error_traceback || queueItem?.error_message}
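The hunks above rename several socket.io events and re-derive their payload types from the OpenAPI schema (`S['...']`). As a rough illustration only, and not part of this patch, the sketch below shows how a typed client might consume a few of the renamed events. The import path mirrors the one used by the removed `setEventListeners.ts`; the connection options and the `'default'` queue id are placeholders (the app reads the real id from the `$queueId` store).

```ts
// A minimal sketch, not part of this diff: consuming the renamed events with the
// refactored types. The import path mirrors the removed setEventListeners.ts; the
// connection options and the 'default' queue id are placeholders.
import { io, type Socket } from 'socket.io-client';

import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';

// Typing the socket with both event maps gives typed `on` handlers and `emit` payloads.
const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io();

// `generator_progress` is now `invocation_denoise_progress`; the payload type is inferred.
socket.on('invocation_denoise_progress', (payload) => {
  console.debug('denoise progress', payload);
});

// `bulk_download_failed` is now `bulk_download_error`.
socket.on('bulk_download_error', (payload) => {
  console.error('bulk download error', payload);
});

// Queue subscription keeps its previous payload shape ({ queue_id }).
socket.emit('subscribe_queue', { queue_id: 'default' });
```

With both event maps supplied as `Socket` generics, handler payload types are inferred from `ServerToClientEvents`, so no per-event casts are needed.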