diff --git a/docker/README.md b/docker/README.md index d5a472da7d..9e7ac15145 100644 --- a/docker/README.md +++ b/docker/README.md @@ -64,7 +64,7 @@ GPU_DRIVER=nvidia Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail. -## Even Moar Customizing! +## Even More Customizing! See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below. diff --git a/docs/contributing/frontend/WORKFLOWS.md b/docs/contributing/frontend/WORKFLOWS.md index e71d797b8a..533419e070 100644 --- a/docs/contributing/frontend/WORKFLOWS.md +++ b/docs/contributing/frontend/WORKFLOWS.md @@ -117,13 +117,13 @@ Stateless fields do not store their value in the node, so their field instances "Custom" fields will always be treated as stateless fields. -##### Collection and Scalar Fields +##### Single and Collection Fields -Field types have a name and two flags which may identify it as a **collection** or **collection or scalar** field. +Field types have a name and cardinality property which may identify it as a **SINGLE**, **COLLECTION** or **SINGLE_OR_COLLECTION** field. -If a field is annotated in python as a list, its field type is parsed and flagged as a **collection** type (e.g. `list[int]`). - -If it is annotated as a union of a type and list, the type will be flagged as a **collection or scalar** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed). +- If a field is annotated in python as a singular value or class, its field type is parsed as a **SINGLE** type (e.g. `int`, `ImageField`, `str`). +- If a field is annotated in python as a list, its field type is parsed as a **COLLECTION** type (e.g. `list[int]`). +- If it is annotated as a union of a type and list, the type will be parsed as a **SINGLE_OR_COLLECTION** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed). ## Implementation @@ -173,8 +173,7 @@ Field types are represented as structured objects: ```ts type FieldType = { name: string; - isCollection: boolean; - isCollectionOrScalar: boolean; + cardinality: 'SINGLE' | 'COLLECTION' | 'SINGLE_OR_COLLECTION'; }; ``` @@ -186,7 +185,7 @@ There are 4 general cases for field type parsing. When a field is annotated as a primitive values (e.g. `int`, `str`, `float`), the field type parsing is fairly straightforward. The field is represented by a simple OpenAPI **schema object**, which has a `type` property. -We create a field type name from this `type` string (e.g. `string` -> `StringField`). +We create a field type name from this `type` string (e.g. `string` -> `StringField`). The cardinality is `"SINGLE"`. ##### Complex Types @@ -200,13 +199,13 @@ We need to **dereference** the schema to pull these out. Dereferencing may requi When a field is annotated as a list of a single type, the schema object has an `items` property. They may be a schema object or reference object and must be parsed to determine the item type. -We use the item type for field type name, adding `isCollection: true` to the field type. +We use the item type for field type name. The cardinality is `"COLLECTION"`. -##### Collection or Scalar Types +##### Single or Collection Types When a field is annotated as a union of a type and list of that type, the schema object has an `anyOf` property, which holds a list of valid types for the union. -After verifying that the union has two members (a type and list of the same type), we use the type for field type name, adding `isCollectionOrScalar: true` to the field type. +After verifying that the union has two members (a type and list of the same type), we use the type for field type name, with cardinality `"SINGLE_OR_COLLECTION"`. ##### Optional Fields diff --git a/docs/features/CONTROLNET.md b/docs/features/CONTROLNET.md index d07353089d..718b12b0f8 100644 --- a/docs/features/CONTROLNET.md +++ b/docs/features/CONTROLNET.md @@ -165,7 +165,7 @@ Additionally, each section can be expanded with the "Show Advanced" button in o There are several ways to install IP-Adapter models with an existing InvokeAI installation: 1. Through the command line interface launched from the invoke.sh / invoke.bat scripts, option [4] to download models. -2. Through the Model Manager UI with models from the *Tools* section of [www.models.invoke.ai](https://www.models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models. +2. Through the Model Manager UI with models from the *Tools* section of [models.invoke.ai](https://models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models. 3. **Advanced -- Not recommended ** Manually downloading the IP-Adapter and Image Encoder files - Image Encoder folders shouid be placed in the `models\any\clip_vision` folders. IP Adapter Model folders should be placed in the relevant `ip-adapter` folder of relevant base model folder of Invoke root directory. For example, for the SDXL IP-Adapter, files should be added to the `model/sdxl/ip_adapter/` folder. #### Using IP-Adapter diff --git a/docs/help/diffusion.md b/docs/help/diffusion.md index 0dbb09f304..7182a51d67 100644 --- a/docs/help/diffusion.md +++ b/docs/help/diffusion.md @@ -20,7 +20,7 @@ When you generate an image using text-to-image, multiple steps occur in latent s 4. The VAE decodes the final latent image from latent space into image space. Image-to-image is a similar process, with only step 1 being different: -1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how may noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process. +1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how many noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process. Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor). diff --git a/docs/installation/020_INSTALL_MANUAL.md b/docs/installation/020_INSTALL_MANUAL.md index 36859a5795..059834eb45 100644 --- a/docs/installation/020_INSTALL_MANUAL.md +++ b/docs/installation/020_INSTALL_MANUAL.md @@ -10,7 +10,7 @@ InvokeAI is distributed as a python package on PyPI, installable with `pip`. The ### Requirements -Before you start, go through the [installation requirements]. +Before you start, go through the [installation requirements](./INSTALL_REQUIREMENTS.md). ### Installation Walkthrough @@ -79,7 +79,7 @@ Before you start, go through the [installation requirements]. 1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features. - - You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website]. Copy the `--extra-index-url` string from this and append it to your install command. + - You may need to provide an [extra index URL](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-extra-index-url). Select your platform configuration using [this tool on the PyTorch website](https://pytorch.org/get-started/locally/). Copy the `--extra-index-url` string from this and append it to your install command. !!! example "Install with an extra index URL" @@ -116,4 +116,4 @@ Before you start, go through the [installation requirements]. !!! warning - If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root_dir \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable. + If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable. diff --git a/installer/templates/invoke.bat.in b/installer/templates/invoke.bat.in index 9c9c08d82d..c8ef19710b 100644 --- a/installer/templates/invoke.bat.in +++ b/installer/templates/invoke.bat.in @@ -10,8 +10,7 @@ set INVOKEAI_ROOT=. echo Desired action: echo 1. Generate images with the browser-based interface echo 2. Open the developer console -echo 3. Run the InvokeAI image database maintenance script -echo 4. Command-line help +echo 3. Command-line help echo Q - Quit echo. echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest. @@ -34,9 +33,6 @@ IF /I "%choice%" == "1" ( echo *** Type `exit` to quit this shell and deactivate the Python virtual environment *** call cmd /k ) ELSE IF /I "%choice%" == "3" ( - echo Running the db maintenance script... - python .venv\Scripts\invokeai-db-maintenance.exe -) ELSE IF /I "%choice%" == "4" ( echo Displaying command line help... python .venv\Scripts\invokeai-web.exe --help %* pause diff --git a/installer/templates/invoke.sh.in b/installer/templates/invoke.sh.in index 9c45eba5b0..b8d5a7af23 100644 --- a/installer/templates/invoke.sh.in +++ b/installer/templates/invoke.sh.in @@ -47,11 +47,6 @@ do_choice() { bash --init-file "$file_name" ;; 3) - clear - printf "Running the db maintenance script\n" - invokeai-db-maintenance --root ${INVOKEAI_ROOT} - ;; - 4) clear printf "Command-line help\n" invokeai-web --help @@ -71,8 +66,7 @@ do_line_input() { printf "What would you like to do?\n" printf "1: Generate images using the browser-based interface\n" printf "2: Open the developer console\n" - printf "3: Run the InvokeAI image database maintenance script\n" - printf "4: Command-line help\n" + printf "3: Command-line help\n" printf "Q: Quit\n\n" printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.\n\n" read -p "Please enter 1-4, Q: [1] " yn diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 0cfcf2f3b7..19a7bb083d 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -18,6 +18,7 @@ from ..services.boards.boards_default import BoardService from ..services.bulk_download.bulk_download_default import BulkDownloadService from ..services.config import InvokeAIAppConfig from ..services.download import DownloadQueueService +from ..services.events.events_fastapievents import FastAPIEventService from ..services.image_files.image_files_disk import DiskImageFileStorage from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage from ..services.images.images_default import ImageService @@ -29,11 +30,10 @@ from ..services.model_images.model_images_default import ModelImageFileStorageDi from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_records import ModelRecordServiceSQL from ..services.names.names_default import SimpleNameService -from ..services.session_processor.session_processor_default import DefaultSessionProcessor +from ..services.session_processor.session_processor_default import DefaultSessionProcessor, DefaultSessionRunner from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.urls.urls_default import LocalUrlService from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage -from .events import FastAPIEventService # TODO: is there a better way to achieve this? @@ -103,7 +103,7 @@ class ApiDependencies: ) names = SimpleNameService() performance_statistics = InvocationStatsService() - session_processor = DefaultSessionProcessor() + session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner()) session_queue = SqliteSessionQueue(db=db) urls = LocalUrlService() workflow_records = SqliteWorkflowRecordsStorage(db=db) diff --git a/invokeai/app/api/events.py b/invokeai/app/api/events.py deleted file mode 100644 index 2ac07e6dfe..0000000000 --- a/invokeai/app/api/events.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -import asyncio -import threading -from queue import Empty, Queue -from typing import Any - -from fastapi_events.dispatcher import dispatch - -from ..services.events.events_base import EventServiceBase - - -class FastAPIEventService(EventServiceBase): - event_handler_id: int - __queue: Queue - __stop_event: threading.Event - - def __init__(self, event_handler_id: int) -> None: - self.event_handler_id = event_handler_id - self.__queue = Queue() - self.__stop_event = threading.Event() - asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event)) - - super().__init__() - - def stop(self, *args, **kwargs): - self.__stop_event.set() - self.__queue.put(None) - - def dispatch(self, event_name: str, payload: Any) -> None: - self.__queue.put({"event_name": event_name, "payload": payload}) - - async def __dispatch_from_queue(self, stop_event: threading.Event): - """Get events on from the queue and dispatch them, from the correct thread""" - while not stop_event.is_set(): - try: - event = self.__queue.get(block=False) - if not event: # Probably stopping - continue - - dispatch( - event.get("event_name"), - payload=event.get("payload"), - middleware_id=self.event_handler_id, - ) - - except Empty: - await asyncio.sleep(0.1) - pass - - except asyncio.CancelledError as e: - raise e # Raise a proper error diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 9c55ff6531..a947b83abe 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -6,7 +6,7 @@ from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, from fastapi.responses import FileResponse from fastapi.routing import APIRouter from PIL import Image -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, JsonValue from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin @@ -41,14 +41,17 @@ async def upload_image( board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"), session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"), crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"), + metadata: Optional[JsonValue] = Body( + default=None, description="The metadata to associate with the image", embed=True + ), ) -> ImageDTO: """Uploads an image""" if not file.content_type or not file.content_type.startswith("image"): raise HTTPException(status_code=415, detail="Not an image") - metadata = None - workflow = None - graph = None + _metadata = None + _workflow = None + _graph = None contents = await file.read() try: @@ -62,27 +65,27 @@ async def upload_image( # TODO: retain non-invokeai metadata on upload? # attempt to parse metadata from image - metadata_raw = pil_image.info.get("invokeai_metadata", None) + metadata_raw = metadata if isinstance(metadata, str) else pil_image.info.get("invokeai_metadata", None) if isinstance(metadata_raw, str): - metadata = metadata_raw + _metadata = metadata_raw else: - ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image") + ApiDependencies.invoker.services.logger.debug("Failed to parse metadata for uploaded image") pass # attempt to parse workflow from image workflow_raw = pil_image.info.get("invokeai_workflow", None) if isinstance(workflow_raw, str): - workflow = workflow_raw + _workflow = workflow_raw else: - ApiDependencies.invoker.services.logger.warn("Failed to parse workflow for uploaded image") + ApiDependencies.invoker.services.logger.debug("Failed to parse workflow for uploaded image") pass # attempt to extract graph from image graph_raw = pil_image.info.get("invokeai_graph", None) if isinstance(graph_raw, str): - graph = graph_raw + _graph = graph_raw else: - ApiDependencies.invoker.services.logger.warn("Failed to parse graph for uploaded image") + ApiDependencies.invoker.services.logger.debug("Failed to parse graph for uploaded image") pass try: @@ -92,9 +95,9 @@ async def upload_image( image_category=image_category, session_id=session_id, board_id=board_id, - metadata=metadata, - workflow=workflow, - graph=graph, + metadata=_metadata, + workflow=_workflow, + graph=_graph, is_intermediate=is_intermediate, ) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 1ba3e30e07..b1221f7a34 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException from typing_extensions import Annotated from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException -from invokeai.app.services.model_install import ModelInstallJob +from invokeai.app.services.model_install.model_install_common import ModelInstallJob from invokeai.app.services.model_records import ( DuplicateModelException, InvalidModelException, diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py index 40f1f2213b..7161e54a41 100644 --- a/invokeai/app/api/routers/session_queue.py +++ b/invokeai/app/api/routers/session_queue.py @@ -203,6 +203,7 @@ async def get_batch_status( responses={ 200: {"model": SessionQueueItem}, }, + response_model_exclude_none=True, ) async def get_queue_item( queue_id: str = Path(description="The queue id to perform this operation on"), diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index 463545d9bc..b39922c69b 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -1,66 +1,125 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +from typing import Any + from fastapi import FastAPI -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event +from pydantic import BaseModel from socketio import ASGIApp, AsyncServer -from ..services.events.events_base import EventServiceBase +from invokeai.app.services.events.events_common import ( + BatchEnqueuedEvent, + BulkDownloadCompleteEvent, + BulkDownloadErrorEvent, + BulkDownloadEventBase, + BulkDownloadStartedEvent, + DownloadCancelledEvent, + DownloadCompleteEvent, + DownloadErrorEvent, + DownloadEventBase, + DownloadProgressEvent, + DownloadStartedEvent, + FastAPIEvent, + InvocationCompleteEvent, + InvocationDenoiseProgressEvent, + InvocationErrorEvent, + InvocationStartedEvent, + ModelEventBase, + ModelInstallCancelledEvent, + ModelInstallCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, + ModelInstallErrorEvent, + ModelInstallStartedEvent, + ModelLoadCompleteEvent, + ModelLoadStartedEvent, + QueueClearedEvent, + QueueEventBase, + QueueItemStatusChangedEvent, + register_events, +) + + +class QueueSubscriptionEvent(BaseModel): + """Event data for subscribing to the socket.io queue room. + This is a pydantic model to ensure the data is in the correct format.""" + + queue_id: str + + +class BulkDownloadSubscriptionEvent(BaseModel): + """Event data for subscribing to the socket.io bulk downloads room. + This is a pydantic model to ensure the data is in the correct format.""" + + bulk_download_id: str + + +QUEUE_EVENTS = { + InvocationStartedEvent, + InvocationDenoiseProgressEvent, + InvocationCompleteEvent, + InvocationErrorEvent, + QueueItemStatusChangedEvent, + BatchEnqueuedEvent, + QueueClearedEvent, +} + +MODEL_EVENTS = { + DownloadCancelledEvent, + DownloadCompleteEvent, + DownloadErrorEvent, + DownloadProgressEvent, + DownloadStartedEvent, + ModelLoadStartedEvent, + ModelLoadCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, + ModelInstallStartedEvent, + ModelInstallCompleteEvent, + ModelInstallCancelledEvent, + ModelInstallErrorEvent, +} + +BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent} class SocketIO: - __sio: AsyncServer - __app: ASGIApp + _sub_queue = "subscribe_queue" + _unsub_queue = "unsubscribe_queue" - __sub_queue: str = "subscribe_queue" - __unsub_queue: str = "unsubscribe_queue" - - __sub_bulk_download: str = "subscribe_bulk_download" - __unsub_bulk_download: str = "unsubscribe_bulk_download" + _sub_bulk_download = "subscribe_bulk_download" + _unsub_bulk_download = "unsubscribe_bulk_download" def __init__(self, app: FastAPI): - self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") - self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") - app.mount("/ws", self.__app) + self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") + self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io") + app.mount("/ws", self._app) - self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue) - self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue) - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) - local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event) + self._sio.on(self._sub_queue, handler=self._handle_sub_queue) + self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue) + self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download) + self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download) - self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download) - self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download) - local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event) + register_events(QUEUE_EVENTS, self._handle_queue_event) + register_events(MODEL_EVENTS, self._handle_model_event) + register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event) - async def _handle_queue_event(self, event: Event): - await self.__sio.emit( - event=event[1]["event"], - data=event[1]["data"], - room=event[1]["data"]["queue_id"], - ) + async def _handle_sub_queue(self, sid: str, data: Any) -> None: + await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id) - async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None: - if "queue_id" in data: - await self.__sio.enter_room(sid, data["queue_id"]) + async def _handle_unsub_queue(self, sid: str, data: Any) -> None: + await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id) - async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None: - if "queue_id" in data: - await self.__sio.leave_room(sid, data["queue_id"]) + async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None: + await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) - async def _handle_model_event(self, event: Event) -> None: - await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) + async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: + await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) - async def _handle_bulk_download_event(self, event: Event): - await self.__sio.emit( - event=event[1]["event"], - data=event[1]["data"], - room=event[1]["data"]["bulk_download_id"], - ) + async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): + await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id) - async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs): - if "bulk_download_id" in data: - await self.__sio.enter_room(sid, data["bulk_download_id"]) + async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None: + await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json")) - async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs): - if "bulk_download_id" in data: - await self.__sio.leave_room(sid, data["bulk_download_id"]) + async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None: + await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 062682f7d0..b7da548377 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -27,6 +27,7 @@ import invokeai.frontend.web as web_dir from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.events.events_common import EventBase from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.backend.util.devices import TorchDevice @@ -182,23 +183,14 @@ def custom_openapi() -> dict[str, Any]: openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type()) invoker_schema["class"] = "invocation" - # This code no longer seems to be necessary? - # Leave it here just in case - # - # from invokeai.backend.model_manager import get_model_config_formats - # formats = get_model_config_formats() - # for model_config_name, enum_set in formats.items(): - - # if model_config_name in openapi_schema["components"]["schemas"]: - # # print(f"Config with name {name} already defined") - # continue - - # openapi_schema["components"]["schemas"][model_config_name] = { - # "title": model_config_name, - # "description": "An enumeration.", - # "type": "string", - # "enum": [v.value for v in enum_set], - # } + # Add all event schemas + for event in sorted(EventBase.get_events(), key=lambda e: e.__name__): + json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}") + if "$defs" in json_schema: + for schema_key, schema in json_schema["$defs"].items(): + openapi_schema["components"]["schemas"][schema_key] = schema + del json_schema["$defs"] + openapi_schema["components"]["schemas"][event.__name__] = json_schema app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 158f11a58e..766b44fdc8 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -65,11 +65,7 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: tokenizer_info = context.models.load(self.clip.tokenizer) - tokenizer_model = tokenizer_info.model - assert isinstance(tokenizer_model, CLIPTokenizer) text_encoder_info = context.models.load(self.clip.text_encoder) - text_encoder_model = text_encoder_info.model - assert isinstance(text_encoder_model, CLIPTextModel) def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.clip.loras: @@ -84,19 +80,21 @@ class CompelInvocation(BaseInvocation): ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context) with ( - ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as ( - tokenizer, - ti_manager, - ), + # apply all patches while the model is on the target device text_encoder_info as text_encoder, - # Apply the LoRA after text_encoder has been moved to its target device for faster patching. + tokenizer_info as tokenizer, ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. - ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers), + ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as ( + patched_tokenizer, + ti_manager, + ), ): assert isinstance(text_encoder, CLIPTextModel) + assert isinstance(tokenizer, CLIPTokenizer) compel = Compel( - tokenizer=tokenizer, + tokenizer=patched_tokenizer, text_encoder=text_encoder, textual_inversion_manager=ti_manager, dtype_for_device_getter=TorchDevice.choose_torch_dtype, @@ -106,7 +104,7 @@ class CompelInvocation(BaseInvocation): conjunction = Compel.parse_prompt_string(self.prompt) if context.config.get().log_tokenization: - log_tokenization_for_conjunction(conjunction, tokenizer) + log_tokenization_for_conjunction(conjunction, patched_tokenizer) c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction) @@ -136,11 +134,7 @@ class SDXLPromptInvocationBase: zero_on_empty: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: tokenizer_info = context.models.load(clip_field.tokenizer) - tokenizer_model = tokenizer_info.model - assert isinstance(tokenizer_model, CLIPTokenizer) text_encoder_info = context.models.load(clip_field.text_encoder) - text_encoder_model = text_encoder_info.model - assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection)) # return zero on empty if prompt == "" and zero_on_empty: @@ -177,20 +171,23 @@ class SDXLPromptInvocationBase: ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context) with ( - ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as ( - tokenizer, - ti_manager, - ), + # apply all patches while the model is on the target device text_encoder_info as text_encoder, - # Apply the LoRA after text_encoder has been moved to its target device for faster patching. + tokenizer_info as tokenizer, ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. - ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers), + ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as ( + patched_tokenizer, + ti_manager, + ), ): assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)) + assert isinstance(tokenizer, CLIPTokenizer) + text_encoder = cast(CLIPTextModel, text_encoder) compel = Compel( - tokenizer=tokenizer, + tokenizer=patched_tokenizer, text_encoder=text_encoder, textual_inversion_manager=ti_manager, dtype_for_device_getter=TorchDevice.choose_torch_dtype, @@ -203,7 +200,7 @@ class SDXLPromptInvocationBase: if context.config.get().log_tokenization: # TODO: better logging for and syntax - log_tokenization_for_conjunction(conjunction, tokenizer) + log_tokenization_for_conjunction(conjunction, patched_tokenizer) # TODO: ask for optimizations? to not run text_encoder twice c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index e69f4b54ad..e533583829 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,7 +25,6 @@ from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.invocations.fields import ( FieldDescriptions, ImageField, - Input, InputField, OutputField, UIType, @@ -82,13 +81,13 @@ class ControlOutput(BaseInvocationOutput): control: ControlField = OutputField(description=FieldDescriptions.control) -@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1") +@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2") class ControlNetInvocation(BaseInvocation): """Collects ControlNet info to pass to other nodes""" image: ImageField = InputField(description="The control image") control_model: ModelIdentifierField = InputField( - description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel + description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel ) control_weight: Union[float, List[float]] = InputField( default=1.0, ge=-1, le=2, description="The weight given to the ControlNet" diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 34a30628da..de40879eef 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType +from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights @@ -58,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput): CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"} -@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0") +@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.1") class IPAdapterInvocation(BaseInvocation): """Collects IP-Adapter info to pass to other nodes.""" @@ -67,7 +67,6 @@ class IPAdapterInvocation(BaseInvocation): ip_adapter_model: ModelIdentifierField = InputField( description="The IP-Adapter model.", title="IP-Adapter Model", - input=Input.Direct, ui_order=-1, ui_type=UIType.IPAdapterModel, ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b3ac3973bf..a88eff0fcb 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -930,9 +930,9 @@ class DenoiseLatentsInvocation(BaseInvocation): assert isinstance(unet_info.model, UNet2DConditionModel) with ( ExitStack() as exit_stack, - ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config), - set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME unet_info as unet, + ModelPatcher.apply_freeu(unet, self.unet.freeu_config), + set_seamless(unet, self.unet.seamless_axes), # FIXME # Apply the LoRA after unet has been moved to its target device for faster patching. ModelPatcher.apply_lora_unet(unet, _lora_loader()), ): diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 245034c481..94a6136fcb 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -11,6 +11,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, + Classification, invocation, invocation_output, ) @@ -93,19 +94,46 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput): pass +@invocation_output("model_identifier_output") +class ModelIdentifierOutput(BaseInvocationOutput): + """Model identifier output""" + + model: ModelIdentifierField = OutputField(description="Model identifier", title="Model") + + +@invocation( + "model_identifier", + title="Model identifier", + tags=["model"], + category="model", + version="1.0.0", + classification=Classification.Prototype, +) +class ModelIdentifierInvocation(BaseInvocation): + """Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as + input for any model, even if the model types don't match. If you connect this to a mismatched input, you'll get an + error.""" + + model: ModelIdentifierField = InputField(description="The model to select", title="Model") + + def invoke(self, context: InvocationContext) -> ModelIdentifierOutput: + if not context.models.exists(self.model.key): + raise Exception(f"Unknown model {self.model.key}") + + return ModelIdentifierOutput(model=self.model) + + @invocation( "main_model_loader", title="Main Model", tags=["model"], category="model", - version="1.0.2", + version="1.0.3", ) class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" - model: ModelIdentifierField = InputField( - description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel - ) + model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel) # TODO: precision? def invoke(self, context: InvocationContext) -> ModelLoaderOutput: @@ -134,12 +162,12 @@ class LoRALoaderOutput(BaseInvocationOutput): clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") -@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.2") +@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3") class LoRALoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" lora: ModelIdentifierField = InputField( - description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel + description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel ) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) unet: Optional[UNetField] = InputField( @@ -197,12 +225,12 @@ class LoRASelectorOutput(BaseInvocationOutput): lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA") -@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.0") +@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.1") class LoRASelectorInvocation(BaseInvocation): """Selects a LoRA model and weight.""" lora: ModelIdentifierField = InputField( - description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel + description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel ) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) @@ -273,13 +301,13 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput): title="SDXL LoRA", tags=["lora", "model"], category="model", - version="1.0.2", + version="1.0.3", ) class SDXLLoRALoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" lora: ModelIdentifierField = InputField( - description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel + description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel ) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) unet: Optional[UNetField] = InputField( @@ -414,12 +442,12 @@ class SDXLLoRACollectionLoader(BaseInvocation): return output -@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2") +@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.3") class VAELoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" vae_model: ModelIdentifierField = InputField( - description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel + description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel ) def invoke(self, context: InvocationContext) -> VAEOutput: diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 9b1ee90350..1c0817cb92 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,4 +1,4 @@ -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType +from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager import SubModelType @@ -30,12 +30,12 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") -@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.2") +@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.3") class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" model: ModelIdentifierField = InputField( - description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel + description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel ) # TODO: precision? @@ -67,13 +67,13 @@ class SDXLModelLoaderInvocation(BaseInvocation): title="SDXL Refiner Model", tags=["model", "sdxl", "refiner"], category="model", - version="1.0.2", + version="1.0.3", ) class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" model: ModelIdentifierField = InputField( - description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel + description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel ) # TODO: precision? diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index b22a089d3f..04f9a6c695 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import ( invocation, invocation_output, ) -from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext @@ -45,7 +45,7 @@ class T2IAdapterOutput(BaseInvocationOutput): @invocation( - "t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.2" + "t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.3" ) class T2IAdapterInvocation(BaseInvocation): """Collects T2I-Adapter info to pass to other nodes.""" @@ -55,7 +55,6 @@ class T2IAdapterInvocation(BaseInvocation): t2i_adapter_model: ModelIdentifierField = InputField( description="The T2I-Adapter model.", title="T2I-Adapter Model", - input=Input.Direct, ui_order=-1, ui_type=UIType.T2IAdapterModel, ) diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index 04cec928f4..d4bf059b8f 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -106,9 +106,7 @@ class BulkDownloadService(BulkDownloadBase): if self._invoker: assert bulk_download_id is not None self._invoker.services.events.emit_bulk_download_started( - bulk_download_id=bulk_download_id, - bulk_download_item_id=bulk_download_item_id, - bulk_download_item_name=bulk_download_item_name, + bulk_download_id, bulk_download_item_id, bulk_download_item_name ) def _signal_job_completed( @@ -118,10 +116,8 @@ class BulkDownloadService(BulkDownloadBase): if self._invoker: assert bulk_download_id is not None assert bulk_download_item_name is not None - self._invoker.services.events.emit_bulk_download_completed( - bulk_download_id=bulk_download_id, - bulk_download_item_id=bulk_download_item_id, - bulk_download_item_name=bulk_download_item_name, + self._invoker.services.events.emit_bulk_download_complete( + bulk_download_id, bulk_download_item_id, bulk_download_item_name ) def _signal_job_failed( @@ -131,11 +127,8 @@ class BulkDownloadService(BulkDownloadBase): if self._invoker: assert bulk_download_id is not None assert exception is not None - self._invoker.services.events.emit_bulk_download_failed( - bulk_download_id=bulk_download_id, - bulk_download_item_id=bulk_download_item_id, - bulk_download_item_name=bulk_download_item_name, - error=str(exception), + self._invoker.services.events.emit_bulk_download_error( + bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception) ) def stop(self, *args, **kwargs): diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 4555477004..5025255c91 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -8,7 +8,7 @@ import time import traceback from pathlib import Path from queue import Empty, PriorityQueue -from typing import Any, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set import requests from pydantic.networks import AnyHttpUrl @@ -34,6 +34,9 @@ from .download_base import ( UnknownJobIDException, ) +if TYPE_CHECKING: + from invokeai.app.services.events.events_base import EventServiceBase + # Maximum number of bytes to download during each call to requests.iter_content() DOWNLOAD_CHUNK_SIZE = 100000 @@ -45,7 +48,7 @@ class DownloadQueueService(DownloadQueueServiceBase): self, max_parallel_dl: int = 5, app_config: Optional[InvokeAIAppConfig] = None, - event_bus: Optional[EventServiceBase] = None, + event_bus: Optional["EventServiceBase"] = None, requests_session: Optional[requests.sessions.Session] = None, ): """ @@ -408,28 +411,18 @@ class DownloadQueueService(DownloadQueueServiceBase): job.status = DownloadJobStatus.RUNNING self._execute_cb(job, "on_start") if self._event_bus: - assert job.download_path - self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix()) + self._event_bus.emit_download_started(job) def _signal_job_progress(self, job: DownloadJob) -> None: self._execute_cb(job, "on_progress") if self._event_bus: - assert job.download_path - self._event_bus.emit_download_progress( - str(job.source), - download_path=job.download_path.as_posix(), - current_bytes=job.bytes, - total_bytes=job.total_bytes, - ) + self._event_bus.emit_download_progress(job) def _signal_job_complete(self, job: DownloadJob) -> None: job.status = DownloadJobStatus.COMPLETED self._execute_cb(job, "on_complete") if self._event_bus: - assert job.download_path - self._event_bus.emit_download_complete( - str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes - ) + self._event_bus.emit_download_complete(job) def _signal_job_cancelled(self, job: DownloadJob) -> None: if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]: @@ -437,7 +430,7 @@ class DownloadQueueService(DownloadQueueServiceBase): job.status = DownloadJobStatus.CANCELLED self._execute_cb(job, "on_cancelled") if self._event_bus: - self._event_bus.emit_download_cancelled(str(job.source)) + self._event_bus.emit_download_cancelled(job) # if multifile download, then signal the parent if parent_job := self._download_part2parent.get(job.source, None): @@ -451,9 +444,7 @@ class DownloadQueueService(DownloadQueueServiceBase): self._execute_cb(job, "on_error", excp) if self._event_bus: - assert job.error_type - assert job.error - self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error) + self._event_bus.emit_download_error(job) def _cleanup_cancelled_job(self, job: DownloadJob) -> None: self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}") diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index aa91cdaec8..3c0fb0a30b 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -1,490 +1,195 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Optional -from invokeai.app.services.session_processor.session_processor_common import ProgressImage -from invokeai.app.services.session_queue.session_queue_common import ( - BatchStatus, - EnqueueBatchResult, - SessionQueueItem, - SessionQueueStatus, +from invokeai.app.services.events.events_common import ( + BatchEnqueuedEvent, + BulkDownloadCompleteEvent, + BulkDownloadErrorEvent, + BulkDownloadStartedEvent, + DownloadCancelledEvent, + DownloadCompleteEvent, + DownloadErrorEvent, + DownloadProgressEvent, + DownloadStartedEvent, + EventBase, + InvocationCompleteEvent, + InvocationDenoiseProgressEvent, + InvocationErrorEvent, + InvocationStartedEvent, + ModelInstallCancelledEvent, + ModelInstallCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, + ModelInstallErrorEvent, + ModelInstallStartedEvent, + ModelLoadCompleteEvent, + ModelLoadStartedEvent, + QueueClearedEvent, + QueueItemStatusChangedEvent, ) -from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_manager import AnyModelConfig -from invokeai.backend.model_manager.config import SubModelType +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState + +if TYPE_CHECKING: + from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput + from invokeai.app.services.download.download_base import DownloadJob + from invokeai.app.services.events.events_common import EventBase + from invokeai.app.services.model_install.model_install_common import ModelInstallJob + from invokeai.app.services.session_processor.session_processor_common import ProgressImage + from invokeai.app.services.session_queue.session_queue_common import ( + BatchStatus, + EnqueueBatchResult, + SessionQueueItem, + SessionQueueStatus, + ) + from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType class EventServiceBase: - queue_event: str = "queue_event" - bulk_download_event: str = "bulk_download_event" - download_event: str = "download_event" - model_event: str = "model_event" - """Basic event bus, to have an empty stand-in when not needed""" - def dispatch(self, event_name: str, payload: Any) -> None: + def dispatch(self, event: "EventBase") -> None: pass - def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None: - """Bulk download events are emitted to a room with queue_id as the room name""" - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.bulk_download_event, - payload={"event": event_name, "data": payload}, - ) + # region: Invocation - def __emit_queue_event(self, event_name: str, payload: dict) -> None: - """Queue events are emitted to a room with queue_id as the room name""" - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.queue_event, - payload={"event": event_name, "data": payload}, - ) + def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None: + """Emitted when an invocation is started""" + self.dispatch(InvocationStartedEvent.build(queue_item, invocation)) - def __emit_download_event(self, event_name: str, payload: dict) -> None: - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.download_event, - payload={"event": event_name, "data": payload}, - ) - - def __emit_model_event(self, event_name: str, payload: dict) -> None: - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.model_event, - payload={"event": event_name, "data": payload}, - ) - - # Define events here for every event in the system. - # This will make them easier to integrate until we find a schema generator. - def emit_generator_progress( + def emit_invocation_denoise_progress( self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node_id: str, - source_node_id: str, - progress_image: Optional[ProgressImage], - step: int, - order: int, - total_steps: int, + queue_item: "SessionQueueItem", + invocation: "BaseInvocation", + intermediate_state: PipelineIntermediateState, + progress_image: "ProgressImage", ) -> None: - """Emitted when there is generation progress""" - self.__emit_queue_event( - event_name="generator_progress", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node_id": node_id, - "source_node_id": source_node_id, - "progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None, - "step": step, - "order": order, - "total_steps": total_steps, - }, - ) + """Emitted at each step during denoising of an invocation.""" + self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image)) def emit_invocation_complete( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - result: dict, - node: dict, - source_node_id: str, + self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput" ) -> None: - """Emitted when an invocation has completed""" - self.__emit_queue_event( - event_name="invocation_complete", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - "result": result, - }, - ) + """Emitted when an invocation is complete""" + self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output)) def emit_invocation_error( self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node: dict, - source_node_id: str, + queue_item: "SessionQueueItem", + invocation: "BaseInvocation", error_type: str, - error: str, - user_id: str | None, - project_id: str | None, + error_message: str, + error_traceback: str, ) -> None: - """Emitted when an invocation has completed""" - self.__emit_queue_event( - event_name="invocation_error", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - "error_type": error_type, - "error": error, - "user_id": user_id, - "project_id": project_id, - }, - ) + """Emitted when an invocation encounters an error""" + self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback)) - def emit_invocation_started( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node: dict, - source_node_id: str, - ) -> None: - """Emitted when an invocation has started""" - self.__emit_queue_event( - event_name="invocation_started", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - }, - ) + # endregion - def emit_graph_execution_complete( - self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str - ) -> None: - """Emitted when a session has completed all invocations""" - self.__emit_queue_event( - event_name="graph_execution_state_complete", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - }, - ) - - def emit_model_load_started( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> None: - """Emitted when a model is requested""" - self.__emit_queue_event( - event_name="model_load_started", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "model_config": model_config.model_dump(mode="json"), - "submodel_type": submodel_type, - }, - ) - - def emit_model_load_completed( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> None: - """Emitted when a model is correctly loaded (returns model info)""" - self.__emit_queue_event( - event_name="model_load_completed", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "model_config": model_config.model_dump(mode="json"), - "submodel_type": submodel_type, - }, - ) - - def emit_session_canceled( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - ) -> None: - """Emitted when a session is canceled""" - self.__emit_queue_event( - event_name="session_canceled", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - }, - ) + # region Queue def emit_queue_item_status_changed( - self, - session_queue_item: SessionQueueItem, - batch_status: BatchStatus, - queue_status: SessionQueueStatus, + self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus" ) -> None: """Emitted when a queue item's status changes""" - self.__emit_queue_event( - event_name="queue_item_status_changed", - payload={ - "queue_id": queue_status.queue_id, - "queue_item": { - "queue_id": session_queue_item.queue_id, - "item_id": session_queue_item.item_id, - "status": session_queue_item.status, - "batch_id": session_queue_item.batch_id, - "session_id": session_queue_item.session_id, - "error": session_queue_item.error, - "created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None, - "updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None, - "started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None, - "completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None, - }, - "batch_status": batch_status.model_dump(mode="json"), - "queue_status": queue_status.model_dump(mode="json"), - }, - ) + self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status)) - def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None: + def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None: """Emitted when a batch is enqueued""" - self.__emit_queue_event( - event_name="batch_enqueued", - payload={ - "queue_id": enqueue_result.queue_id, - "batch_id": enqueue_result.batch.batch_id, - "enqueued": enqueue_result.enqueued, - }, - ) + self.dispatch(BatchEnqueuedEvent.build(enqueue_result)) def emit_queue_cleared(self, queue_id: str) -> None: - """Emitted when the queue is cleared""" - self.__emit_queue_event( - event_name="queue_cleared", - payload={"queue_id": queue_id}, - ) + """Emitted when a queue is cleared""" + self.dispatch(QueueClearedEvent.build(queue_id)) - def emit_download_started(self, source: str, download_path: str) -> None: - """ - Emit when a download job is started. + # endregion - :param url: The downloaded url - """ - self.__emit_download_event( - event_name="download_started", - payload={"source": source, "download_path": download_path}, - ) + # region Download - def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None: - """ - Emit "download_progress" events at regular intervals during a download job. + def emit_download_started(self, job: "DownloadJob") -> None: + """Emitted when a download is started""" + self.dispatch(DownloadStartedEvent.build(job)) - :param source: The downloaded source - :param download_path: The local downloaded file - :param current_bytes: Number of bytes downloaded so far - :param total_bytes: The size of the file being downloaded (if known) - """ - self.__emit_download_event( - event_name="download_progress", - payload={ - "source": source, - "download_path": download_path, - "current_bytes": current_bytes, - "total_bytes": total_bytes, - }, - ) + def emit_download_progress(self, job: "DownloadJob") -> None: + """Emitted at intervals during a download""" + self.dispatch(DownloadProgressEvent.build(job)) - def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None: - """ - Emit a "download_complete" event at the end of a successful download. + def emit_download_complete(self, job: "DownloadJob") -> None: + """Emitted when a download is completed""" + self.dispatch(DownloadCompleteEvent.build(job)) - :param source: Source URL - :param download_path: Path to the locally downloaded file - :param total_bytes: The size of the downloaded file - """ - self.__emit_download_event( - event_name="download_complete", - payload={ - "source": source, - "download_path": download_path, - "total_bytes": total_bytes, - }, - ) + def emit_download_cancelled(self, job: "DownloadJob") -> None: + """Emitted when a download is cancelled""" + self.dispatch(DownloadCancelledEvent.build(job)) - def emit_download_cancelled(self, source: str) -> None: - """Emit a "download_cancelled" event in the event that the download was cancelled by user.""" - self.__emit_download_event( - event_name="download_cancelled", - payload={ - "source": source, - }, - ) + def emit_download_error(self, job: "DownloadJob") -> None: + """Emitted when a download encounters an error""" + self.dispatch(DownloadErrorEvent.build(job)) - def emit_download_error(self, source: str, error_type: str, error: str) -> None: - """ - Emit a "download_error" event when an download job encounters an exception. + # endregion - :param source: Source URL - :param error_type: The name of the exception that raised the error - :param error: The traceback from this error - """ - self.__emit_download_event( - event_name="download_error", - payload={ - "source": source, - "error_type": error_type, - "error": error, - }, - ) + # region Model loading - def emit_model_install_downloading( - self, - source: str, - local_path: str, - bytes: int, - total_bytes: int, - parts: List[Dict[str, Union[str, int]]], - id: int, + def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None: + """Emitted when a model load is started.""" + self.dispatch(ModelLoadStartedEvent.build(config, submodel_type)) + + def emit_model_load_complete( + self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None ) -> None: - """ - Emit at intervals while the install job is in progress (remote models only). + """Emitted when a model load is complete.""" + self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type)) - :param source: Source of the model - :param local_path: Where model is downloading to - :param parts: Progress of downloading URLs that comprise the model, if any. - :param bytes: Number of bytes downloaded so far. - :param total_bytes: Total size of download, including all files. - This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes". - """ - self.__emit_model_event( - event_name="model_install_downloading", - payload={ - "source": source, - "local_path": local_path, - "bytes": bytes, - "total_bytes": total_bytes, - "parts": parts, - "id": id, - }, - ) + # endregion - def emit_model_install_downloads_done(self, source: str) -> None: - """ - Emit once when all parts are downloaded, but before the probing and registration start. + # region Model install - :param source: Source of the model; local path, repo_id or url - """ - self.__emit_model_event( - event_name="model_install_downloads_done", - payload={"source": source}, - ) + def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None: + """Emitted at intervals while the install job is in progress (remote models only).""" + self.dispatch(ModelInstallDownloadProgressEvent.build(job)) - def emit_model_install_running(self, source: str) -> None: - """ - Emit once when an install job becomes active. + def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None: + self.dispatch(ModelInstallDownloadsCompleteEvent.build(job)) - :param source: Source of the model; local path, repo_id or url - """ - self.__emit_model_event( - event_name="model_install_running", - payload={"source": source}, - ) + def emit_model_install_started(self, job: "ModelInstallJob") -> None: + """Emitted once when an install job is started (after any download).""" + self.dispatch(ModelInstallStartedEvent.build(job)) - def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None: - """ - Emit when an install job is completed successfully. + def emit_model_install_complete(self, job: "ModelInstallJob") -> None: + """Emitted when an install job is completed successfully.""" + self.dispatch(ModelInstallCompleteEvent.build(job)) - :param source: Source of the model; local path, repo_id or url - :param key: Model config record key - :param total_bytes: Size of the model (may be None for installation of a local path) - """ - self.__emit_model_event( - event_name="model_install_completed", - payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id}, - ) + def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None: + """Emitted when an install job is cancelled.""" + self.dispatch(ModelInstallCancelledEvent.build(job)) - def emit_model_install_cancelled(self, source: str, id: int) -> None: - """ - Emit when an install job is cancelled. + def emit_model_install_error(self, job: "ModelInstallJob") -> None: + """Emitted when an install job encounters an exception.""" + self.dispatch(ModelInstallErrorEvent.build(job)) - :param source: Source of the model; local path, repo_id or url - """ - self.__emit_model_event( - event_name="model_install_cancelled", - payload={"source": source, "id": id}, - ) + # endregion - def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None: - """ - Emit when an install job encounters an exception. - - :param source: Source of the model - :param error_type: The name of the exception - :param error: A text description of the exception - """ - self.__emit_model_event( - event_name="model_install_error", - payload={"source": source, "error_type": error_type, "error": error, "id": id}, - ) + # region Bulk image download def emit_bulk_download_started( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str ) -> None: - """Emitted when a bulk download starts""" - self._emit_bulk_download_event( - event_name="bulk_download_started", - payload={ - "bulk_download_id": bulk_download_id, - "bulk_download_item_id": bulk_download_item_id, - "bulk_download_item_name": bulk_download_item_name, - }, - ) + """Emitted when a bulk image download is started""" + self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) - def emit_bulk_download_completed( + def emit_bulk_download_complete( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str ) -> None: - """Emitted when a bulk download completes""" - self._emit_bulk_download_event( - event_name="bulk_download_completed", - payload={ - "bulk_download_id": bulk_download_id, - "bulk_download_item_id": bulk_download_item_id, - "bulk_download_item_name": bulk_download_item_name, - }, - ) + """Emitted when a bulk image download is complete""" + self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) - def emit_bulk_download_failed( + def emit_bulk_download_error( self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str ) -> None: - """Emitted when a bulk download fails""" - self._emit_bulk_download_event( - event_name="bulk_download_failed", - payload={ - "bulk_download_id": bulk_download_id, - "bulk_download_item_id": bulk_download_item_id, - "bulk_download_item_name": bulk_download_item_name, - "error": error, - }, + """Emitted when a bulk image download has an error""" + self.dispatch( + BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error) ) + + # endregion diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py new file mode 100644 index 0000000000..7d3d489bf5 --- /dev/null +++ b/invokeai/app/services/events/events_common.py @@ -0,0 +1,591 @@ +from math import floor +from typing import TYPE_CHECKING, Any, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar + +from fastapi_events.handlers.local import local_handler +from fastapi_events.registry.payload_schema import registry as payload_schema +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny + +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput +from invokeai.app.services.session_processor.session_processor_common import ProgressImage +from invokeai.app.services.session_queue.session_queue_common import ( + QUEUE_ITEM_STATUS, + BatchStatus, + EnqueueBatchResult, + SessionQueueItem, + SessionQueueStatus, +) +from invokeai.app.util.misc import get_timestamp +from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState + +if TYPE_CHECKING: + from invokeai.app.services.download.download_base import DownloadJob + from invokeai.app.services.model_install.model_install_common import ModelInstallJob + + +class EventBase(BaseModel): + """Base class for all events. All events must inherit from this class. + + Events must define a class attribute `__event_name__` to identify the event. + + All other attributes should be defined as normal for a pydantic model. + + A timestamp is automatically added to the event when it is created. + """ + + timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp) + + model_config = ConfigDict(json_schema_serialization_defaults_required=True) + + @classmethod + def get_events(cls) -> set[type["EventBase"]]: + """Get a set of all event models.""" + + event_subclasses: set[type["EventBase"]] = set() + for subclass in cls.__subclasses__(): + # We only want to include subclasses that are event models, not intermediary classes + if hasattr(subclass, "__event_name__"): + event_subclasses.add(subclass) + event_subclasses.update(subclass.get_events()) + + return event_subclasses + + +TEvent = TypeVar("TEvent", bound=EventBase, contravariant=True) + +FastAPIEvent: TypeAlias = tuple[str, TEvent] +""" +A tuple representing a `fastapi-events` event, with the event name and payload. +Provide a generic type to `TEvent` to specify the payload type. +""" + + +class FastAPIEventFunc(Protocol, Generic[TEvent]): + def __call__(self, event: FastAPIEvent[TEvent]) -> Optional[Coroutine[Any, Any, None]]: ... + + +def register_events(events: set[type[TEvent]] | type[TEvent], func: FastAPIEventFunc[TEvent]) -> None: + """Register a function to handle specific events. + + :param events: An event or set of events to handle + :param func: The function to handle the events + """ + events = events if isinstance(events, set) else {events} + for event in events: + assert hasattr(event, "__event_name__") + local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue] + + +class QueueEventBase(EventBase): + """Base class for queue events""" + + queue_id: str = Field(description="The ID of the queue") + + +class QueueItemEventBase(QueueEventBase): + """Base class for queue item events""" + + item_id: int = Field(description="The ID of the queue item") + batch_id: str = Field(description="The ID of the queue batch") + + +class InvocationEventBase(QueueItemEventBase): + """Base class for invocation events""" + + session_id: str = Field(description="The ID of the session (aka graph execution state)") + queue_id: str = Field(description="The ID of the queue") + item_id: int = Field(description="The ID of the queue item") + batch_id: str = Field(description="The ID of the queue batch") + session_id: str = Field(description="The ID of the session (aka graph execution state)") + invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation") + invocation_source_id: str = Field(description="The ID of the prepared invocation's source node") + + +@payload_schema.register +class InvocationStartedEvent(InvocationEventBase): + """Event model for invocation_started""" + + __event_name__ = "invocation_started" + + @classmethod + def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation=invocation, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + ) + + +@payload_schema.register +class InvocationDenoiseProgressEvent(InvocationEventBase): + """Event model for invocation_denoise_progress""" + + __event_name__ = "invocation_denoise_progress" + + progress_image: ProgressImage = Field(description="The progress image sent at each step during processing") + step: int = Field(description="The current step of the invocation") + total_steps: int = Field(description="The total number of steps in the invocation") + order: int = Field(description="The order of the invocation in the session") + percentage: float = Field(description="The percentage of completion of the invocation") + + @classmethod + def build( + cls, + queue_item: SessionQueueItem, + invocation: BaseInvocation, + intermediate_state: PipelineIntermediateState, + progress_image: ProgressImage, + ) -> "InvocationDenoiseProgressEvent": + step = intermediate_state.step + total_steps = intermediate_state.total_steps + order = intermediate_state.order + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation=invocation, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + progress_image=progress_image, + step=step, + total_steps=total_steps, + order=order, + percentage=cls.calc_percentage(step, total_steps, order), + ) + + @staticmethod + def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float: + """Calculate the percentage of completion of denoising.""" + if total_steps == 0: + return 0.0 + if scheduler_order == 2: + return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2) + # order == 1 + return (step + 1 + 1) / (total_steps + 1) + + +@payload_schema.register +class InvocationCompleteEvent(InvocationEventBase): + """Event model for invocation_complete""" + + __event_name__ = "invocation_complete" + + result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation") + + @classmethod + def build( + cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput + ) -> "InvocationCompleteEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation=invocation, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + result=result, + ) + + +@payload_schema.register +class InvocationErrorEvent(InvocationEventBase): + """Event model for invocation_error""" + + __event_name__ = "invocation_error" + + error_type: str = Field(description="The error type") + error_message: str = Field(description="The error message") + error_traceback: str = Field(description="The error traceback") + user_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation") + project_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation") + + @classmethod + def build( + cls, + queue_item: SessionQueueItem, + invocation: BaseInvocation, + error_type: str, + error_message: str, + error_traceback: str, + ) -> "InvocationErrorEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation=invocation, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + user_id=getattr(queue_item, "user_id", None), + project_id=getattr(queue_item, "project_id", None), + ) + + +@payload_schema.register +class QueueItemStatusChangedEvent(QueueItemEventBase): + """Event model for queue_item_status_changed""" + + __event_name__ = "queue_item_status_changed" + + status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item") + error_type: Optional[str] = Field(default=None, description="The error type, if any") + error_message: Optional[str] = Field(default=None, description="The error message, if any") + error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any") + created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created") + updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated") + started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started") + completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed") + batch_status: BatchStatus = Field(description="The status of the batch") + queue_status: SessionQueueStatus = Field(description="The status of the queue") + session_id: str = Field(description="The ID of the session (aka graph execution state)") + + @classmethod + def build( + cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus + ) -> "QueueItemStatusChangedEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + status=queue_item.status, + error_type=queue_item.error_type, + error_message=queue_item.error_message, + error_traceback=queue_item.error_traceback, + created_at=str(queue_item.created_at) if queue_item.created_at else None, + updated_at=str(queue_item.updated_at) if queue_item.updated_at else None, + started_at=str(queue_item.started_at) if queue_item.started_at else None, + completed_at=str(queue_item.completed_at) if queue_item.completed_at else None, + batch_status=batch_status, + queue_status=queue_status, + ) + + +@payload_schema.register +class BatchEnqueuedEvent(QueueEventBase): + """Event model for batch_enqueued""" + + __event_name__ = "batch_enqueued" + + batch_id: str = Field(description="The ID of the batch") + enqueued: int = Field(description="The number of invocations enqueued") + requested: int = Field( + description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)" + ) + priority: int = Field(description="The priority of the batch") + + @classmethod + def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent": + return cls( + queue_id=enqueue_result.queue_id, + batch_id=enqueue_result.batch.batch_id, + enqueued=enqueue_result.enqueued, + requested=enqueue_result.requested, + priority=enqueue_result.priority, + ) + + +@payload_schema.register +class QueueClearedEvent(QueueEventBase): + """Event model for queue_cleared""" + + __event_name__ = "queue_cleared" + + @classmethod + def build(cls, queue_id: str) -> "QueueClearedEvent": + return cls(queue_id=queue_id) + + +class DownloadEventBase(EventBase): + """Base class for events associated with a download""" + + source: str = Field(description="The source of the download") + + +@payload_schema.register +class DownloadStartedEvent(DownloadEventBase): + """Event model for download_started""" + + __event_name__ = "download_started" + + download_path: str = Field(description="The local path where the download is saved") + + @classmethod + def build(cls, job: "DownloadJob") -> "DownloadStartedEvent": + assert job.download_path + return cls(source=str(job.source), download_path=job.download_path.as_posix()) + + +@payload_schema.register +class DownloadProgressEvent(DownloadEventBase): + """Event model for download_progress""" + + __event_name__ = "download_progress" + + download_path: str = Field(description="The local path where the download is saved") + current_bytes: int = Field(description="The number of bytes downloaded so far") + total_bytes: int = Field(description="The total number of bytes to be downloaded") + + @classmethod + def build(cls, job: "DownloadJob") -> "DownloadProgressEvent": + assert job.download_path + return cls( + source=str(job.source), + download_path=job.download_path.as_posix(), + current_bytes=job.bytes, + total_bytes=job.total_bytes, + ) + + +@payload_schema.register +class DownloadCompleteEvent(DownloadEventBase): + """Event model for download_complete""" + + __event_name__ = "download_complete" + + download_path: str = Field(description="The local path where the download is saved") + total_bytes: int = Field(description="The total number of bytes downloaded") + + @classmethod + def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent": + assert job.download_path + return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes) + + +@payload_schema.register +class DownloadCancelledEvent(DownloadEventBase): + """Event model for download_cancelled""" + + __event_name__ = "download_cancelled" + + @classmethod + def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent": + return cls(source=str(job.source)) + + +@payload_schema.register +class DownloadErrorEvent(DownloadEventBase): + """Event model for download_error""" + + __event_name__ = "download_error" + + error_type: str = Field(description="The type of error") + error: str = Field(description="The error message") + + @classmethod + def build(cls, job: "DownloadJob") -> "DownloadErrorEvent": + assert job.error_type + assert job.error + return cls(source=str(job.source), error_type=job.error_type, error=job.error) + + +class ModelEventBase(EventBase): + """Base class for events associated with a model""" + + +@payload_schema.register +class ModelLoadStartedEvent(ModelEventBase): + """Event model for model_load_started""" + + __event_name__ = "model_load_started" + + config: AnyModelConfig = Field(description="The model's config") + submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any") + + @classmethod + def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent": + return cls(config=config, submodel_type=submodel_type) + + +@payload_schema.register +class ModelLoadCompleteEvent(ModelEventBase): + """Event model for model_load_complete""" + + __event_name__ = "model_load_complete" + + config: AnyModelConfig = Field(description="The model's config") + submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any") + + @classmethod + def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent": + return cls(config=config, submodel_type=submodel_type) + + +@payload_schema.register +class ModelInstallDownloadProgressEvent(ModelEventBase): + """Event model for model_install_download_progress""" + + __event_name__ = "model_install_download_progress" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + local_path: str = Field(description="Where model is downloading to") + bytes: int = Field(description="Number of bytes downloaded so far") + total_bytes: int = Field(description="Total size of download, including all files") + parts: list[dict[str, int | str]] = Field( + description="Progress of downloading URLs that comprise the model, if any" + ) + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent": + parts: list[dict[str, str | int]] = [ + { + "url": str(x.source), + "local_path": str(x.download_path), + "bytes": x.bytes, + "total_bytes": x.total_bytes, + } + for x in job.download_parts + ] + return cls( + id=job.id, + source=str(job.source), + local_path=job.local_path.as_posix(), + parts=parts, + bytes=job.bytes, + total_bytes=job.total_bytes, + ) + + +@payload_schema.register +class ModelInstallDownloadsCompleteEvent(ModelEventBase): + """Emitted once when an install job becomes active.""" + + __event_name__ = "model_install_downloads_complete" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent": + return cls(id=job.id, source=str(job.source)) + + +@payload_schema.register +class ModelInstallStartedEvent(ModelEventBase): + """Event model for model_install_started""" + + __event_name__ = "model_install_started" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent": + return cls(id=job.id, source=str(job.source)) + + +@payload_schema.register +class ModelInstallCompleteEvent(ModelEventBase): + """Event model for model_install_complete""" + + __event_name__ = "model_install_complete" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + key: str = Field(description="Model config record key") + total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent": + assert job.config_out is not None + return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes) + + +@payload_schema.register +class ModelInstallCancelledEvent(ModelEventBase): + """Event model for model_install_cancelled""" + + __event_name__ = "model_install_cancelled" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent": + return cls(id=job.id, source=str(job.source)) + + +@payload_schema.register +class ModelInstallErrorEvent(ModelEventBase): + """Event model for model_install_error""" + + __event_name__ = "model_install_error" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + error_type: str = Field(description="The name of the exception") + error: str = Field(description="A text description of the exception") + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent": + assert job.error_type is not None + assert job.error is not None + return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error) + + +class BulkDownloadEventBase(EventBase): + """Base class for events associated with a bulk image download""" + + bulk_download_id: str = Field(description="The ID of the bulk image download") + bulk_download_item_id: str = Field(description="The ID of the bulk image download item") + bulk_download_item_name: str = Field(description="The name of the bulk image download item") + + +@payload_schema.register +class BulkDownloadStartedEvent(BulkDownloadEventBase): + """Event model for bulk_download_started""" + + __event_name__ = "bulk_download_started" + + @classmethod + def build( + cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + ) -> "BulkDownloadStartedEvent": + return cls( + bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, + ) + + +@payload_schema.register +class BulkDownloadCompleteEvent(BulkDownloadEventBase): + """Event model for bulk_download_complete""" + + __event_name__ = "bulk_download_complete" + + @classmethod + def build( + cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + ) -> "BulkDownloadCompleteEvent": + return cls( + bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, + ) + + +@payload_schema.register +class BulkDownloadErrorEvent(BulkDownloadEventBase): + """Event model for bulk_download_error""" + + __event_name__ = "bulk_download_error" + + error: str = Field(description="The error message") + + @classmethod + def build( + cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str + ) -> "BulkDownloadErrorEvent": + return cls( + bulk_download_id=bulk_download_id, + bulk_download_item_id=bulk_download_item_id, + bulk_download_item_name=bulk_download_item_name, + error=error, + ) diff --git a/invokeai/app/services/events/events_fastapievents.py b/invokeai/app/services/events/events_fastapievents.py new file mode 100644 index 0000000000..8279d3bb34 --- /dev/null +++ b/invokeai/app/services/events/events_fastapievents.py @@ -0,0 +1,47 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +import asyncio +import threading +from queue import Empty, Queue + +from fastapi_events.dispatcher import dispatch + +from invokeai.app.services.events.events_common import ( + EventBase, +) + +from .events_base import EventServiceBase + + +class FastAPIEventService(EventServiceBase): + def __init__(self, event_handler_id: int) -> None: + self.event_handler_id = event_handler_id + self._queue = Queue[EventBase | None]() + self._stop_event = threading.Event() + asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event)) + + super().__init__() + + def stop(self, *args, **kwargs): + self._stop_event.set() + self._queue.put(None) + + def dispatch(self, event: EventBase) -> None: + self._queue.put(event) + + async def _dispatch_from_queue(self, stop_event: threading.Event): + """Get events on from the queue and dispatch them, from the correct thread""" + while not stop_event.is_set(): + try: + event = self._queue.get(block=False) + if not event: # Probably stopping + continue + # Leave the payloads as live pydantic models + dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False) + + except Empty: + await asyncio.sleep(0.1) + pass + + except asyncio.CancelledError as e: + raise e # Raise a proper error diff --git a/invokeai/app/services/model_install/__init__.py b/invokeai/app/services/model_install/__init__.py index 00a33c203e..941485a134 100644 --- a/invokeai/app/services/model_install/__init__.py +++ b/invokeai/app/services/model_install/__init__.py @@ -1,11 +1,13 @@ """Initialization file for model install service package.""" from .model_install_base import ( + ModelInstallServiceBase, +) +from .model_install_common import ( HFModelSource, InstallStatus, LocalModelSource, ModelInstallJob, - ModelInstallServiceBase, ModelSource, UnknownInstallJobException, URLModelSource, diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index b622c8dade..76b77f0419 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -1,242 +1,17 @@ # Copyright 2023 Lincoln D. Stein and the InvokeAI development team """Baseclass definitions for the model installer.""" -import re -import traceback from abc import ABC, abstractmethod -from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Union - -from pydantic import BaseModel, Field, PrivateAttr, field_validator -from pydantic.networks import AnyHttpUrl -from typing_extensions import Annotated +from typing import Any, Dict, List, Optional, Union from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob +from invokeai.app.services.download import DownloadQueueServiceBase from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker +from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource from invokeai.app.services.model_records import ModelRecordServiceBase -from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant -from invokeai.backend.model_manager.config import ModelSourceType -from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata - - -class InstallStatus(str, Enum): - """State of an install job running in the background.""" - - WAITING = "waiting" # waiting to be dequeued - DOWNLOADING = "downloading" # downloading of model files in process - DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run - RUNNING = "running" # being processed - COMPLETED = "completed" # finished running - ERROR = "error" # terminated with an error message - CANCELLED = "cancelled" # terminated with an error message - - -class ModelInstallPart(BaseModel): - url: AnyHttpUrl - path: Path - bytes: int = 0 - total_bytes: int = 0 - - -class UnknownInstallJobException(Exception): - """Raised when the status of an unknown job is requested.""" - - -class StringLikeSource(BaseModel): - """ - Base class for model sources, implements functions that lets the source be sorted and indexed. - - These shenanigans let this stuff work: - - source1 = LocalModelSource(path='C:/users/mort/foo.safetensors') - mydict = {source1: 'model 1'} - assert mydict['C:/users/mort/foo.safetensors'] == 'model 1' - assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1' - - source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors')) - assert source1 == source2 - assert source1 == 'C:/users/mort/foo.safetensors' - """ - - def __hash__(self) -> int: - """Return hash of the path field, for indexing.""" - return hash(str(self)) - - def __lt__(self, other: object) -> int: - """Return comparison of the stringified version, for sorting.""" - return str(self) < str(other) - - def __eq__(self, other: object) -> bool: - """Return equality on the stringified version.""" - if isinstance(other, Path): - return str(self) == other.as_posix() - else: - return str(self) == str(other) - - -class LocalModelSource(StringLikeSource): - """A local file or directory path.""" - - path: str | Path - inplace: Optional[bool] = False - type: Literal["local"] = "local" - - # these methods allow the source to be used in a string-like way, - # for example as an index into a dict - def __str__(self) -> str: - """Return string version of path when string rep needed.""" - return Path(self.path).as_posix() - - -class HFModelSource(StringLikeSource): - """ - A HuggingFace repo_id with optional variant, sub-folder and access token. - Note that the variant option, if not provided to the constructor, will default to fp16, which is - what people (almost) always want. - """ - - repo_id: str - variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16 - subfolder: Optional[Path] = None - access_token: Optional[str] = None - type: Literal["hf"] = "hf" - - @field_validator("repo_id") - @classmethod - def proper_repo_id(cls, v: str) -> str: # noqa D102 - if not re.match(r"^([.\w-]+/[.\w-]+)$", v): - raise ValueError(f"{v}: invalid repo_id format") - return v - - def __str__(self) -> str: - """Return string version of repoid when string rep needed.""" - base: str = self.repo_id - if self.variant: - base += f":{self.variant or ''}" - if self.subfolder: - base += f":{self.subfolder}" - return base - - -class URLModelSource(StringLikeSource): - """A generic URL point to a checkpoint file.""" - - url: AnyHttpUrl - access_token: Optional[str] = None - type: Literal["url"] = "url" - - def __str__(self) -> str: - """Return string version of the url when string rep needed.""" - return str(self.url) - - -ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")] - -MODEL_SOURCE_TO_TYPE_MAP = { - URLModelSource: ModelSourceType.Url, - HFModelSource: ModelSourceType.HFRepoID, - LocalModelSource: ModelSourceType.Path, -} - - -class ModelInstallJob(BaseModel): - """Object that tracks the current status of an install request.""" - - id: int = Field(description="Unique ID for this job") - status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") - error_reason: Optional[str] = Field(default=None, description="Information about why the job failed") - config_in: Dict[str, Any] = Field( - default_factory=dict, description="Configuration information (e.g. 'description') to apply to model." - ) - config_out: Optional[AnyModelConfig] = Field( - default=None, description="After successful installation, this will hold the configuration object." - ) - inplace: bool = Field( - default=False, description="Leave model in its current location; otherwise install under models directory" - ) - source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model") - local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source") - bytes: int = Field( - default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)" - ) - total_bytes: int = Field(default=0, description="Total size of the model to be installed") - source_metadata: Optional[AnyModelRepoMetadata] = Field( - default=None, description="Metadata provided by the model source" - ) - error: Optional[str] = Field( - default=None, description="On an error condition, this field will contain the text of the exception" - ) - error_traceback: Optional[str] = Field( - default=None, description="On an error condition, this field will contain the exception traceback" - ) - # internal flags and transitory settings - _install_tmpdir: Optional[Path] = PrivateAttr(default=None) - _download_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None) - _exception: Optional[Exception] = PrivateAttr(default=None) - - def set_error(self, e: Exception) -> None: - """Record the error and traceback from an exception.""" - self._exception = e - self.error = str(e) - self.error_traceback = self._format_error(e) - self.status = InstallStatus.ERROR - self.error_reason = self._exception.__class__.__name__ if self._exception else None - - def cancel(self) -> None: - """Call to cancel the job.""" - self.status = InstallStatus.CANCELLED - - @property - def error_type(self) -> Optional[str]: - """Class name of the exception that led to status==ERROR.""" - return self._exception.__class__.__name__ if self._exception else None - - def _format_error(self, exception: Exception) -> str: - """Error traceback.""" - return "".join(traceback.format_exception(exception)) - - @property - def cancelled(self) -> bool: - """Set status to CANCELLED.""" - return self.status == InstallStatus.CANCELLED - - @property - def errored(self) -> bool: - """Return true if job has errored.""" - return self.status == InstallStatus.ERROR - - @property - def waiting(self) -> bool: - """Return true if job is waiting to run.""" - return self.status == InstallStatus.WAITING - - @property - def downloading(self) -> bool: - """Return true if job is downloading.""" - return self.status == InstallStatus.DOWNLOADING - - @property - def downloads_done(self) -> bool: - """Return true if job's downloads ae done.""" - return self.status == InstallStatus.DOWNLOADS_DONE - - @property - def running(self) -> bool: - """Return true if job is running.""" - return self.status == InstallStatus.RUNNING - - @property - def complete(self) -> bool: - """Return true if job completed without errors.""" - return self.status == InstallStatus.COMPLETED - - @property - def in_terminal_state(self) -> bool: - """Return true if job is in a terminal state.""" - return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED] +from invokeai.backend.model_manager import AnyModelConfig class ModelInstallServiceBase(ABC): @@ -280,7 +55,7 @@ class ModelInstallServiceBase(ABC): @property @abstractmethod - def event_bus(self) -> Optional[EventServiceBase]: + def event_bus(self) -> Optional["EventServiceBase"]: """Return the event service base object associated with the installer.""" @abstractmethod diff --git a/invokeai/app/services/model_install/model_install_common.py b/invokeai/app/services/model_install/model_install_common.py new file mode 100644 index 0000000000..751b5baa4b --- /dev/null +++ b/invokeai/app/services/model_install/model_install_common.py @@ -0,0 +1,227 @@ +import re +import traceback +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Literal, Optional, Set, Union + +from pydantic import BaseModel, Field, PrivateAttr, field_validator +from pydantic.networks import AnyHttpUrl +from typing_extensions import Annotated + +from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob +from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant +from invokeai.backend.model_manager.config import ModelSourceType +from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata + + +class InstallStatus(str, Enum): + """State of an install job running in the background.""" + + WAITING = "waiting" # waiting to be dequeued + DOWNLOADING = "downloading" # downloading of model files in process + DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run + RUNNING = "running" # being processed + COMPLETED = "completed" # finished running + ERROR = "error" # terminated with an error message + CANCELLED = "cancelled" # terminated with an error message + + +class UnknownInstallJobException(Exception): + """Raised when the status of an unknown job is requested.""" + + +class StringLikeSource(BaseModel): + """ + Base class for model sources, implements functions that lets the source be sorted and indexed. + + These shenanigans let this stuff work: + + source1 = LocalModelSource(path='C:/users/mort/foo.safetensors') + mydict = {source1: 'model 1'} + assert mydict['C:/users/mort/foo.safetensors'] == 'model 1' + assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1' + + source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors')) + assert source1 == source2 + assert source1 == 'C:/users/mort/foo.safetensors' + """ + + def __hash__(self) -> int: + """Return hash of the path field, for indexing.""" + return hash(str(self)) + + def __lt__(self, other: object) -> int: + """Return comparison of the stringified version, for sorting.""" + return str(self) < str(other) + + def __eq__(self, other: object) -> bool: + """Return equality on the stringified version.""" + if isinstance(other, Path): + return str(self) == other.as_posix() + else: + return str(self) == str(other) + + +class LocalModelSource(StringLikeSource): + """A local file or directory path.""" + + path: str | Path + inplace: Optional[bool] = False + type: Literal["local"] = "local" + + # these methods allow the source to be used in a string-like way, + # for example as an index into a dict + def __str__(self) -> str: + """Return string version of path when string rep needed.""" + return Path(self.path).as_posix() + + +class HFModelSource(StringLikeSource): + """ + A HuggingFace repo_id with optional variant, sub-folder and access token. + Note that the variant option, if not provided to the constructor, will default to fp16, which is + what people (almost) always want. + """ + + repo_id: str + variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16 + subfolder: Optional[Path] = None + access_token: Optional[str] = None + type: Literal["hf"] = "hf" + + @field_validator("repo_id") + @classmethod + def proper_repo_id(cls, v: str) -> str: # noqa D102 + if not re.match(r"^([.\w-]+/[.\w-]+)$", v): + raise ValueError(f"{v}: invalid repo_id format") + return v + + def __str__(self) -> str: + """Return string version of repoid when string rep needed.""" + base: str = self.repo_id + if self.variant: + base += f":{self.variant or ''}" + if self.subfolder: + base += f":{self.subfolder}" + return base + + +class URLModelSource(StringLikeSource): + """A generic URL point to a checkpoint file.""" + + url: AnyHttpUrl + access_token: Optional[str] = None + type: Literal["url"] = "url" + + def __str__(self) -> str: + """Return string version of the url when string rep needed.""" + return str(self.url) + + +ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")] + +MODEL_SOURCE_TO_TYPE_MAP = { + URLModelSource: ModelSourceType.Url, + HFModelSource: ModelSourceType.HFRepoID, + LocalModelSource: ModelSourceType.Path, +} + + +class ModelInstallJob(BaseModel): + """Object that tracks the current status of an install request.""" + + id: int = Field(description="Unique ID for this job") + status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process") + error_reason: Optional[str] = Field(default=None, description="Information about why the job failed") + config_in: Dict[str, Any] = Field( + default_factory=dict, description="Configuration information (e.g. 'description') to apply to model." + ) + config_out: Optional[AnyModelConfig] = Field( + default=None, description="After successful installation, this will hold the configuration object." + ) + inplace: bool = Field( + default=False, description="Leave model in its current location; otherwise install under models directory" + ) + source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model") + local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source") + bytes: int = Field( + default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)" + ) + total_bytes: int = Field(default=0, description="Total size of the model to be installed") + source_metadata: Optional[AnyModelRepoMetadata] = Field( + default=None, description="Metadata provided by the model source" + ) + download_parts: Set[DownloadJob] = Field( + default_factory=set, description="Download jobs contributing to this install" + ) + error: Optional[str] = Field( + default=None, description="On an error condition, this field will contain the text of the exception" + ) + error_traceback: Optional[str] = Field( + default=None, description="On an error condition, this field will contain the exception traceback" + ) + # internal flags and transitory settings + _install_tmpdir: Optional[Path] = PrivateAttr(default=None) + _download_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None) + _exception: Optional[Exception] = PrivateAttr(default=None) + + def set_error(self, e: Exception) -> None: + """Record the error and traceback from an exception.""" + self._exception = e + self.error = str(e) + self.error_traceback = self._format_error(e) + self.status = InstallStatus.ERROR + self.error_reason = self._exception.__class__.__name__ if self._exception else None + + def cancel(self) -> None: + """Call to cancel the job.""" + self.status = InstallStatus.CANCELLED + + @property + def error_type(self) -> Optional[str]: + """Class name of the exception that led to status==ERROR.""" + return self._exception.__class__.__name__ if self._exception else None + + def _format_error(self, exception: Exception) -> str: + """Error traceback.""" + return "".join(traceback.format_exception(exception)) + + @property + def cancelled(self) -> bool: + """Set status to CANCELLED.""" + return self.status == InstallStatus.CANCELLED + + @property + def errored(self) -> bool: + """Return true if job has errored.""" + return self.status == InstallStatus.ERROR + + @property + def waiting(self) -> bool: + """Return true if job is waiting to run.""" + return self.status == InstallStatus.WAITING + + @property + def downloading(self) -> bool: + """Return true if job is downloading.""" + return self.status == InstallStatus.DOWNLOADING + + @property + def downloads_done(self) -> bool: + """Return true if job's downloads ae done.""" + return self.status == InstallStatus.DOWNLOADS_DONE + + @property + def running(self) -> bool: + """Return true if job is running.""" + return self.status == InstallStatus.RUNNING + + @property + def complete(self) -> bool: + """Return true if job completed without errors.""" + return self.status == InstallStatus.COMPLETED + + @property + def in_terminal_state(self) -> bool: + """Return true if job is in a terminal state.""" + return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED] diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index cde9a6502e..c78a09ce87 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -9,7 +9,7 @@ from pathlib import Path from queue import Empty, Queue from shutil import copyfile, copytree, move, rmtree from tempfile import mkdtemp -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import torch import yaml @@ -21,6 +21,7 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker +from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.backend.model_manager.config import ( @@ -45,13 +46,12 @@ from invokeai.backend.util.catch_sigint import catch_sigint from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.util import slugify -from .model_install_base import ( +from .model_install_common import ( MODEL_SOURCE_TO_TYPE_MAP, HFModelSource, InstallStatus, LocalModelSource, ModelInstallJob, - ModelInstallServiceBase, ModelSource, StringLikeSource, URLModelSource, @@ -59,6 +59,9 @@ from .model_install_base import ( TMPDIR_PREFIX = "tmpinstall_" +if TYPE_CHECKING: + from invokeai.app.services.events.events_base import EventServiceBase + class ModelInstallService(ModelInstallServiceBase): """class for InvokeAI model installation.""" @@ -68,7 +71,7 @@ class ModelInstallService(ModelInstallServiceBase): app_config: InvokeAIAppConfig, record_store: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, - event_bus: Optional[EventServiceBase] = None, + event_bus: Optional["EventServiceBase"] = None, session: Optional[Session] = None, ): """ @@ -104,7 +107,7 @@ class ModelInstallService(ModelInstallServiceBase): return self._record_store @property - def event_bus(self) -> Optional[EventServiceBase]: # noqa D102 + def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102 return self._event_bus # make the invoker optional here because we don't need it and it @@ -825,6 +828,7 @@ class ModelInstallService(ModelInstallServiceBase): else: # update sizes install_job.bytes = sum(x.bytes for x in download_job.download_parts) + install_job.download_parts = download_job.download_parts self._signal_job_downloading(install_job) def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None: @@ -864,36 +868,20 @@ class ModelInstallService(ModelInstallServiceBase): job.status = InstallStatus.RUNNING self._logger.info(f"Model install started: {job.source}") if self._event_bus: - self._event_bus.emit_model_install_running(str(job.source)) + self._event_bus.emit_model_install_started(job) def _signal_job_downloading(self, job: ModelInstallJob) -> None: if self._event_bus: assert job._download_job is not None - parts: List[Dict[str, str | int]] = [ - { - "url": str(x.source), - "local_path": str(x.download_path), - "bytes": x.bytes, - "total_bytes": x.total_bytes, - } - for x in job._download_job.download_parts - ] assert job.bytes is not None assert job.total_bytes is not None - self._event_bus.emit_model_install_downloading( - str(job.source), - local_path=job.local_path.as_posix(), - parts=parts, - bytes=sum(x["bytes"] for x in parts), - total_bytes=sum(x["total_bytes"] for x in parts), - id=job.id, - ) + self._event_bus.emit_model_install_download_progress(job) def _signal_job_downloads_done(self, job: ModelInstallJob) -> None: job.status = InstallStatus.DOWNLOADS_DONE self._logger.info(f"Model download complete: {job.source}") if self._event_bus: - self._event_bus.emit_model_install_downloads_done(str(job.source)) + self._event_bus.emit_model_install_downloads_complete(job) def _signal_job_completed(self, job: ModelInstallJob) -> None: job.status = InstallStatus.COMPLETED @@ -903,24 +891,19 @@ class ModelInstallService(ModelInstallServiceBase): if self._event_bus: assert job.local_path is not None assert job.config_out is not None - key = job.config_out.key - self._event_bus.emit_model_install_completed( - source=str(job.source), key=key, id=job.id, total_bytes=job.bytes - ) + self._event_bus.emit_model_install_complete(job) def _signal_job_errored(self, job: ModelInstallJob) -> None: self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}") if self._event_bus: - error_type = job.error_type - error = job.error - assert error_type is not None - assert error is not None - self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id) + assert job.error_type is not None + assert job.error is not None + self._event_bus.emit_model_install_error(job) def _signal_job_cancelled(self, job: ModelInstallJob) -> None: self._logger.info(f"Model install canceled: {job.source}") if self._event_bus: - self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id) + self._event_bus.emit_model_install_cancelled(job) @staticmethod def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]: diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 7de36793fb..22d815483e 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -7,7 +7,6 @@ from typing import Callable, Dict, Optional from torch import Tensor -from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase @@ -18,18 +17,12 @@ class ModelLoadServiceBase(ABC): """Wrapper around AnyModelLoader.""" @abstractmethod - def load_model( - self, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Given a model's configuration, load it and return the LoadedModel object. :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param submodel: For main (pipeline models), the submodel to fetch. - :param context_data: Invocation context data used for event reporting """ @property diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index cd14235ee0..221a042da5 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -11,7 +11,6 @@ from torch import load as torch_load from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.invoker import Invoker -from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import ( LoadedModel, @@ -59,25 +58,18 @@ class ModelLoadService(ModelLoadServiceBase): """Return the checkpoint convert cache used by this loader.""" return self._convert_cache - def load_model( - self, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Given a model's configuration, load it and return the LoadedModel object. :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param submodel: For main (pipeline models), the submodel to fetch. - :param context: Invocation context used for event reporting """ - if context_data: - self._emit_load_event( - context_data=context_data, - model_config=model_config, - submodel_type=submodel_type, - ) + + # We don't have an invoker during testing + # TODO(psyche): Mock this method on the invoker in the tests + if hasattr(self, "_invoker"): + self._invoker.services.events.emit_model_load_started(model_config, submodel_type) implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore loaded_model: LoadedModel = implementation( @@ -87,13 +79,9 @@ class ModelLoadService(ModelLoadServiceBase): convert_cache=self._convert_cache, ).load_model(model_config, submodel_type) - if context_data: - self._emit_load_event( - context_data=context_data, - model_config=model_config, - submodel_type=submodel_type, - loaded=True, - ) + if hasattr(self, "_invoker"): + self._invoker.services.events.emit_model_load_complete(model_config, submodel_type) + return loaded_model def load_model_from_path( @@ -150,32 +138,3 @@ class ModelLoadService(ModelLoadServiceBase): raw_model = loader(model_path) ram_cache.put(key=cache_key, model=raw_model) return LoadedModel(_locker=ram_cache.get(key=cache_key)) - - def _emit_load_event( - self, - context_data: InvocationContextData, - model_config: AnyModelConfig, - loaded: Optional[bool] = False, - submodel_type: Optional[SubModelType] = None, - ) -> None: - if not self._invoker: - return - - if not loaded: - self._invoker.services.events.emit_model_load_started( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - graph_execution_state_id=context_data.queue_item.session_id, - model_config=model_config, - submodel_type=submodel_type, - ) - else: - self._invoker.services.events.emit_model_load_completed( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - graph_execution_state_id=context_data.queue_item.session_id, - model_config=model_config, - submodel_type=submodel_type, - ) diff --git a/invokeai/app/services/session_processor/session_processor_base.py b/invokeai/app/services/session_processor/session_processor_base.py index 485ef2f8c3..15611bb5f8 100644 --- a/invokeai/app/services/session_processor/session_processor_base.py +++ b/invokeai/app/services/session_processor/session_processor_base.py @@ -1,6 +1,49 @@ from abc import ABC, abstractmethod +from threading import Event +from typing import Optional, Protocol +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput +from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus +from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem +from invokeai.app.util.profiler import Profiler + + +class SessionRunnerBase(ABC): + """ + Base class for session runner. + """ + + @abstractmethod + def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None: + """Starts the session runner. + + Args: + services: The invocation services. + cancel_event: The cancel event. + profiler: The profiler to use for session profiling via cProfile. Omit to disable profiling. Basic session + stats will be still be recorded and logged when profiling is disabled. + """ + pass + + @abstractmethod + def run(self, queue_item: SessionQueueItem) -> None: + """Runs a session. + + Args: + queue_item: The session to run. + """ + pass + + @abstractmethod + def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None: + """Run a single node in the graph. + + Args: + invocation: The invocation to run. + queue_item: The session queue item. + """ + pass class SessionProcessorBase(ABC): @@ -26,3 +69,85 @@ class SessionProcessorBase(ABC): def get_status(self) -> SessionProcessorStatus: """Gets the status of the session processor""" pass + + +class OnBeforeRunNode(Protocol): + def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None: + """Callback to run before executing a node. + + Args: + invocation: The invocation that will be executed. + queue_item: The session queue item. + """ + ... + + +class OnAfterRunNode(Protocol): + def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None: + """Callback to run before executing a node. + + Args: + invocation: The invocation that was executed. + queue_item: The session queue item. + """ + ... + + +class OnNodeError(Protocol): + def __call__( + self, + invocation: BaseInvocation, + queue_item: SessionQueueItem, + error_type: str, + error_message: str, + error_traceback: str, + ) -> None: + """Callback to run when a node has an error. + + Args: + invocation: The invocation that errored. + queue_item: The session queue item. + error_type: The type of error, e.g. "ValueError". + error_message: The error message, e.g. "Invalid value". + error_traceback: The stringified error traceback. + """ + ... + + +class OnBeforeRunSession(Protocol): + def __call__(self, queue_item: SessionQueueItem) -> None: + """Callback to run before executing a session. + + Args: + queue_item: The session queue item. + """ + ... + + +class OnAfterRunSession(Protocol): + def __call__(self, queue_item: SessionQueueItem) -> None: + """Callback to run after executing a session. + + Args: + queue_item: The session queue item. + """ + ... + + +class OnNonFatalProcessorError(Protocol): + def __call__( + self, + queue_item: Optional[SessionQueueItem], + error_type: str, + error_message: str, + error_traceback: str, + ) -> None: + """Callback to run when a non-fatal error occurs in the processor. + + Args: + queue_item: The session queue item, if one was being executed when the error occurred. + error_type: The type of error, e.g. "ValueError". + error_message: The error message, e.g. "Invalid value". + error_traceback: The stringified error traceback. + """ + ... diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 894996b1e6..3f348fb239 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -4,24 +4,325 @@ from threading import BoundedSemaphore, Thread from threading import Event as ThreadEvent from typing import Optional -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event as FastAPIEvent - -from invokeai.app.invocations.baseinvocation import BaseInvocation -from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput +from invokeai.app.services.events.events_common import ( + BatchEnqueuedEvent, + FastAPIEvent, + QueueClearedEvent, + QueueItemStatusChangedEvent, + register_events, +) from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError +from invokeai.app.services.session_processor.session_processor_base import ( + OnAfterRunNode, + OnAfterRunSession, + OnBeforeRunNode, + OnBeforeRunSession, + OnNodeError, + OnNonFatalProcessorError, +) from invokeai.app.services.session_processor.session_processor_common import CanceledException -from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem +from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem, SessionQueueItemNotFoundError +from invokeai.app.services.shared.graph import NodeInputError from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.util.profiler import Profiler from ..invoker import Invoker -from .session_processor_base import SessionProcessorBase +from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase from .session_processor_common import SessionProcessorStatus +class DefaultSessionRunner(SessionRunnerBase): + """Processes a single session's invocations.""" + + def __init__( + self, + on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None, + on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None, + on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None, + on_node_error_callbacks: Optional[list[OnNodeError]] = None, + on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None, + ): + """ + Args: + on_before_run_session_callbacks: Callbacks to run before the session starts. + on_before_run_node_callbacks: Callbacks to run before each node starts. + on_after_run_node_callbacks: Callbacks to run after each node completes. + on_node_error_callbacks: Callbacks to run when a node errors. + on_after_run_session_callbacks: Callbacks to run after the session completes. + """ + + self._on_before_run_session_callbacks = on_before_run_session_callbacks or [] + self._on_before_run_node_callbacks = on_before_run_node_callbacks or [] + self._on_after_run_node_callbacks = on_after_run_node_callbacks or [] + self._on_node_error_callbacks = on_node_error_callbacks or [] + self._on_after_run_session_callbacks = on_after_run_session_callbacks or [] + + def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None): + self._services = services + self._cancel_event = cancel_event + self._profiler = profiler + + def _is_canceled(self) -> bool: + """Check if the cancel event is set. This is also passed to the invocation context builder and called during + denoising to check if the session has been canceled.""" + return self._cancel_event.is_set() + + def run(self, queue_item: SessionQueueItem): + # Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here. + + self._on_before_run_session(queue_item=queue_item) + + # Loop over invocations until the session is complete or canceled + while True: + try: + invocation = queue_item.session.next() + # Anything other than a `NodeInputError` is handled as a processor error + except NodeInputError as e: + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._on_node_error( + invocation=e.node, + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) + break + + if invocation is None or self._is_canceled(): + break + + self.run_node(invocation, queue_item) + + # The session is complete if all invocations have been run or there is an error on the session. + # At this time, the queue item may be canceled, but the object itself here won't be updated yet. We must + # use the cancel event to check if the session is canceled. + if ( + queue_item.session.is_complete() + or self._is_canceled() + or queue_item.status in ["failed", "canceled", "completed"] + ): + break + + self._on_after_run_session(queue_item=queue_item) + + def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): + try: + # Any unhandled exception in this scope is an invocation error & will fail the graph + with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id): + self._on_before_run_node(invocation, queue_item) + + data = InvocationContextData( + invocation=invocation, + source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id], + queue_item=queue_item, + ) + context = build_invocation_context( + data=data, + services=self._services, + is_canceled=self._is_canceled, + ) + + # Invoke the node + output = invocation.invoke_internal(context=context, services=self._services) + # Save output and history + queue_item.session.complete(invocation.id, output) + + self._on_after_run_node(invocation, queue_item, output) + + except KeyboardInterrupt: + # TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here? + pass + except CanceledException: + # A CanceledException is raised during the denoising step callback if the cancel event is set. We don't need + # to do any handling here, and no error should be set - just pass and the cancellation will be handled + # correctly in the next iteration of the session runner loop. + # + # See the comment in the processor's `_on_queue_item_status_changed()` method for more details on how we + # handle cancellation. + pass + except Exception as e: + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._on_node_error( + invocation=invocation, + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) + + def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: + """Called before a session is run. + + - Start the profiler if profiling is enabled. + - Run any callbacks registered for this event. + """ + + self._services.logger.debug( + f"On before run session: queue item {queue_item.item_id}, session {queue_item.session_id}" + ) + + # If profiling is enabled, start the profiler + if self._profiler is not None: + self._profiler.start(profile_id=queue_item.session_id) + + for callback in self._on_before_run_session_callbacks: + callback(queue_item=queue_item) + + def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: + """Called after a session is run. + + - Stop the profiler if profiling is enabled. + - Update the queue item's session object in the database. + - If not already canceled or failed, complete the queue item. + - Log and reset performance statistics. + - Run any callbacks registered for this event. + """ + + self._services.logger.debug( + f"On after run session: queue item {queue_item.item_id}, session {queue_item.session_id}" + ) + + # If we are profiling, stop the profiler and dump the profile & stats + if self._profiler is not None: + profile_path = self._profiler.stop() + stats_path = profile_path.with_suffix(".json") + self._services.performance_statistics.dump_stats( + graph_execution_state_id=queue_item.session.id, output_path=stats_path + ) + + try: + # Update the queue item with the completed session. If the queue item has been removed from the queue, + # we'll get a SessionQueueItemNotFoundError and we can ignore it. This can happen if the queue is cleared + # while the session is running. + queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) + + # The queue item may have been canceled or failed while the session was running. We should only complete it + # if it is not already canceled or failed. + if queue_item.status not in ["canceled", "failed"]: + queue_item = self._services.session_queue.complete_queue_item(queue_item.item_id) + + # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor + # we don't care about that - suppress the error. + with suppress(GESStatsNotFoundError): + self._services.performance_statistics.log_stats(queue_item.session.id) + self._services.performance_statistics.reset_stats() + + for callback in self._on_after_run_session_callbacks: + callback(queue_item=queue_item) + except SessionQueueItemNotFoundError: + pass + + def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): + """Called before a node is run. + + - Emits an invocation started event. + - Run any callbacks registered for this event. + """ + + self._services.logger.debug( + f"On before run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})" + ) + + # Send starting event + self._services.events.emit_invocation_started(queue_item=queue_item, invocation=invocation) + + for callback in self._on_before_run_node_callbacks: + callback(invocation=invocation, queue_item=queue_item) + + def _on_after_run_node( + self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput + ): + """Called after a node is run. + + - Emits an invocation complete event. + - Run any callbacks registered for this event. + """ + + self._services.logger.debug( + f"On after run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})" + ) + + # Send complete event on successful runs + self._services.events.emit_invocation_complete(invocation=invocation, queue_item=queue_item, output=output) + + for callback in self._on_after_run_node_callbacks: + callback(invocation=invocation, queue_item=queue_item, output=output) + + def _on_node_error( + self, + invocation: BaseInvocation, + queue_item: SessionQueueItem, + error_type: str, + error_message: str, + error_traceback: str, + ): + """Called when a node errors. Node errors may occur when running or preparing the node.. + + - Set the node error on the session object. + - Log the error. + - Fail the queue item. + - Emits an invocation error event. + - Run any callbacks registered for this event. + """ + + self._services.logger.debug( + f"On node error: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})" + ) + + # Node errors do not get the full traceback. Only the queue item gets the full traceback. + node_error = f"{error_type}: {error_message}" + queue_item.session.set_node_error(invocation.id, node_error) + self._services.logger.error( + f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}" + ) + self._services.logger.error(error_traceback) + + # Fail the queue item + queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) + queue_item = self._services.session_queue.fail_queue_item( + queue_item.item_id, error_type, error_message, error_traceback + ) + + # Send error event + self._services.events.emit_invocation_error( + queue_item=queue_item, + invocation=invocation, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) + + for callback in self._on_node_error_callbacks: + callback( + invocation=invocation, + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) + + class DefaultSessionProcessor(SessionProcessorBase): - def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None: + def __init__( + self, + session_runner: Optional[SessionRunnerBase] = None, + on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None, + thread_limit: int = 1, + polling_interval: int = 1, + ) -> None: + super().__init__() + + self.session_runner = session_runner if session_runner else DefaultSessionRunner() + self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or [] + self._thread_limit = thread_limit + self._polling_interval = polling_interval + + def start(self, invoker: Invoker) -> None: self._invoker: Invoker = invoker self._queue_item: Optional[SessionQueueItem] = None self._invocation: Optional[BaseInvocation] = None @@ -31,11 +332,11 @@ class DefaultSessionProcessor(SessionProcessorBase): self._poll_now_event = ThreadEvent() self._cancel_event = ThreadEvent() - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event) + register_events(QueueClearedEvent, self._on_queue_cleared) + register_events(BatchEnqueuedEvent, self._on_batch_enqueued) + register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed) - self._thread_limit = thread_limit - self._thread_semaphore = BoundedSemaphore(thread_limit) - self._polling_interval = polling_interval + self._thread_semaphore = BoundedSemaphore(self._thread_limit) # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, # the profiler will create a new profile for each session. @@ -49,6 +350,7 @@ class DefaultSessionProcessor(SessionProcessorBase): else None ) + self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler) self._thread = Thread( name="session_processor", target=self._process, @@ -67,30 +369,25 @@ class DefaultSessionProcessor(SessionProcessorBase): def _poll_now(self) -> None: self._poll_now_event.set() - async def _on_queue_event(self, event: FastAPIEvent) -> None: - event_name = event[1]["event"] + async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None: + if self._queue_item and self._queue_item.queue_id == event[1].queue_id: + self._cancel_event.set() + self._poll_now() - if ( - event_name == "session_canceled" - and self._queue_item - and self._queue_item.item_id == event[1]["data"]["queue_item_id"] - ): - self._cancel_event.set() - self._poll_now() - elif ( - event_name == "queue_cleared" - and self._queue_item - and self._queue_item.queue_id == event[1]["data"]["queue_id"] - ): - self._cancel_event.set() - self._poll_now() - elif event_name == "batch_enqueued": - self._poll_now() - elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [ - "completed", - "failed", - "canceled", - ]: + async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None: + self._poll_now() + + async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None: + if self._queue_item and event[1].status in ["completed", "failed", "canceled"]: + # When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is + # emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel + # event, which the session runner checks between invocations. If set, the session runner loop is broken. + # + # Long-running nodes that cannot be interrupted easily present a challenge. `denoise_latents` is one such + # node, but it gets a step callback, called on each step of denoising. This callback checks if the queue item + # is canceled, and if it is, raises a `CanceledException` to stop execution immediately. + if event[1].status == "canceled": + self._cancel_event.set() self._poll_now() def resume(self) -> SessionProcessorStatus: @@ -116,8 +413,8 @@ class DefaultSessionProcessor(SessionProcessorBase): resume_event: ThreadEvent, cancel_event: ThreadEvent, ): - # Outermost processor try block; any unhandled exception is a fatal processor error try: + # Any unhandled exception in this block is a fatal processor error and will stop the processor. self._thread_semaphore.acquire() stop_event.clear() resume_event.set() @@ -125,8 +422,8 @@ class DefaultSessionProcessor(SessionProcessorBase): while not stop_event.is_set(): poll_now_event.clear() - # Middle processor try block; any unhandled exception is a non-fatal processor error try: + # Any unhandled exception in this block is a nonfatal processor error and will be handled. # If we are paused, wait for resume event resume_event.wait() @@ -142,159 +439,69 @@ class DefaultSessionProcessor(SessionProcessorBase): self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") cancel_event.clear() - # If profiling is enabled, start the profiler - if self._profiler is not None: - self._profiler.start(profile_id=self._queue_item.session_id) + # Run the graph + self.session_runner.run(queue_item=self._queue_item) - # Prepare invocations and take the first - self._invocation = self._queue_item.session.next() - - # Loop over invocations until the session is complete or canceled - while self._invocation is not None and not cancel_event.is_set(): - # get the source node id to provide to clients (the prepared node id is not as useful) - source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id] - - # Send starting event - self._invoker.services.events.emit_invocation_started( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session_id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - ) - - # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph - try: - with self._invoker.services.performance_statistics.collect_stats( - self._invocation, self._queue_item.session.id - ): - # Build invocation context (the node-facing API) - data = InvocationContextData( - invocation=self._invocation, - source_invocation_id=source_invocation_id, - queue_item=self._queue_item, - ) - context = build_invocation_context( - data=data, - services=self._invoker.services, - cancel_event=self._cancel_event, - ) - - # Invoke the node - outputs = self._invocation.invoke_internal( - context=context, services=self._invoker.services - ) - - # Save outputs and history - self._queue_item.session.complete(self._invocation.id, outputs) - - # Send complete event - self._invoker.services.events.emit_invocation_complete( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - result=outputs.model_dump(), - ) - - except KeyboardInterrupt: - # TODO(MM2): Create an event for this - pass - - except CanceledException: - # When the user cancels the graph, we first set the cancel event. The event is checked - # between invocations, in this loop. Some invocations are long-running, and we need to - # be able to cancel them mid-execution. - # - # For example, denoising is a long-running invocation with many steps. A step callback - # is executed after each step. This step callback checks if the canceled event is set, - # then raises a CanceledException to stop execution immediately. - # - # When we get a CanceledException, we don't need to do anything - just pass and let the - # loop go to its next iteration, and the cancel event will be handled correctly. - pass - - except Exception as e: - error = traceback.format_exc() - - # Save error - self._queue_item.session.set_node_error(self._invocation.id, error) - self._invoker.services.logger.error( - f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}" - ) - self._invoker.services.logger.error(error) - - # Send error event - self._invoker.services.events.emit_invocation_error( - queue_batch_id=self._queue_item.session_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - error_type=e.__class__.__name__, - error=error, - user_id=None, - project_id=None, - ) - pass - - # The session is complete if the all invocations are complete or there was an error - if self._queue_item.session.is_complete() or cancel_event.is_set(): - # Send complete event - self._invoker.services.events.emit_graph_execution_complete( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - ) - # If we are profiling, stop the profiler and dump the profile & stats - if self._profiler: - profile_path = self._profiler.stop() - stats_path = profile_path.with_suffix(".json") - self._invoker.services.performance_statistics.dump_stats( - graph_execution_state_id=self._queue_item.session.id, output_path=stats_path - ) - # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor - # we don't care about that - suppress the error. - with suppress(GESStatsNotFoundError): - self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id) - self._invoker.services.performance_statistics.reset_stats() - - # Set the invocation to None to prepare for the next session - self._invocation = None - else: - # Prepare the next invocation - self._invocation = self._queue_item.session.next() - else: - # The queue was empty, wait for next polling interval or event to try again - self._invoker.services.logger.debug("Waiting for next polling interval or event") - poll_now_event.wait(self._polling_interval) - continue - except Exception: - # Non-fatal error in processor - self._invoker.services.logger.error( - f"Non-fatal error in session processor:\n{traceback.format_exc()}" + except Exception as e: + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._on_non_fatal_processor_error( + queue_item=self._queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, ) - # Cancel the queue item - if self._queue_item is not None: - self._invoker.services.session_queue.cancel_queue_item( - self._queue_item.item_id, error=traceback.format_exc() - ) - # Reset the invocation to None to prepare for the next session - self._invocation = None - # Immediately poll for next queue item + # Wait for next polling interval or event to try again poll_now_event.wait(self._polling_interval) continue - except Exception: + except Exception as e: # Fatal error in processor, log and pass - we're done here - self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}") + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._invoker.services.logger.error(f"Fatal Error in session processor {error_type}: {error_message}") + self._invoker.services.logger.error(error_traceback) pass finally: stop_event.clear() poll_now_event.clear() self._queue_item = None self._thread_semaphore.release() + + def _on_non_fatal_processor_error( + self, + queue_item: Optional[SessionQueueItem], + error_type: str, + error_message: str, + error_traceback: str, + ) -> None: + """Called when a non-fatal error occurs in the processor. + + - Log the error. + - If a queue item is provided, update the queue item with the completed session & fail it. + - Run any callbacks registered for this event. + """ + + self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}") + self._invoker.services.logger.error(error_traceback) + + if queue_item is not None: + # Update the queue item with the completed session & fail it + queue_item = self._invoker.services.session_queue.set_queue_item_session( + queue_item.item_id, queue_item.session + ) + queue_item = self._invoker.services.session_queue.fail_queue_item( + item_id=queue_item.item_id, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) + + for callback in self._on_non_fatal_processor_error_callbacks: + callback( + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index e0b6e4f528..341e034487 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -16,6 +16,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( SessionQueueItemDTO, SessionQueueStatus, ) +from invokeai.app.services.shared.graph import GraphExecutionState from invokeai.app.services.shared.pagination import CursorPaginatedResults @@ -73,10 +74,22 @@ class SessionQueueBase(ABC): pass @abstractmethod - def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem: + def complete_queue_item(self, item_id: int) -> SessionQueueItem: + """Completes a session queue item""" + pass + + @abstractmethod + def cancel_queue_item(self, item_id: int) -> SessionQueueItem: """Cancels a session queue item""" pass + @abstractmethod + def fail_queue_item( + self, item_id: int, error_type: str, error_message: str, error_traceback: str + ) -> SessionQueueItem: + """Fails a session queue item""" + pass + @abstractmethod def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: """Cancels all queue items with matching batch IDs""" @@ -103,3 +116,8 @@ class SessionQueueBase(ABC): def get_queue_item(self, item_id: int) -> SessionQueueItem: """Gets a session queue item by ID""" pass + + @abstractmethod + def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem: + """Sets the session for a session queue item. Use this to update the session state.""" + pass diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 94db6999c2..7f4601eba7 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -3,7 +3,16 @@ import json from itertools import chain, product from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast -from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator +from pydantic import ( + AliasChoices, + BaseModel, + ConfigDict, + Field, + StrictStr, + TypeAdapter, + field_validator, + model_validator, +) from pydantic_core import to_jsonable_python from invokeai.app.invocations.baseinvocation import BaseInvocation @@ -189,7 +198,13 @@ class SessionQueueItemWithoutGraph(BaseModel): session_id: str = Field( description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed." ) - error: Optional[str] = Field(default=None, description="The error message if this queue item errored") + error_type: Optional[str] = Field(default=None, description="The error type if this queue item errored") + error_message: Optional[str] = Field(default=None, description="The error message if this queue item errored") + error_traceback: Optional[str] = Field( + default=None, + description="The error traceback if this queue item errored", + validation_alias=AliasChoices("error_traceback", "error"), + ) created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created") updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated") started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started") diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index ffcd7c40ca..467853aae4 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -2,10 +2,6 @@ import sqlite3 import threading from typing import Optional, Union, cast -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event as FastAPIEvent - -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase from invokeai.app.services.session_queue.session_queue_common import ( @@ -27,6 +23,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( calc_session_count, prepare_values_to_insert, ) +from invokeai.app.services.shared.graph import GraphExecutionState from invokeai.app.services.shared.pagination import CursorPaginatedResults from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase @@ -41,7 +38,7 @@ class SqliteSessionQueue(SessionQueueBase): self.__invoker = invoker self._set_in_progress_to_canceled() prune_result = self.prune(DEFAULT_QUEUE_ID) - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event) + if prune_result.deleted > 0: self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") @@ -51,52 +48,6 @@ class SqliteSessionQueue(SessionQueueBase): self.__conn = db.conn self.__cursor = self.__conn.cursor() - def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool: - return event[1]["event"] in match_in - - async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent: - event_name = event[1]["event"] - - # This was a match statement, but match is not supported on python 3.9 - if event_name == "graph_execution_state_complete": - await self._handle_complete_event(event) - elif event_name == "invocation_error": - await self._handle_error_event(event) - elif event_name == "session_canceled": - await self._handle_cancel_event(event) - return event - - async def _handle_complete_event(self, event: FastAPIEvent) -> None: - try: - item_id = event[1]["data"]["queue_item_id"] - # When a queue item has an error, we get an error event, then a completed event. - # Mark the queue item completed only if it isn't already marked completed, e.g. - # by a previously-handled error event. - queue_item = self.get_queue_item(item_id) - if queue_item.status not in ["completed", "failed", "canceled"]: - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed") - except SessionQueueItemNotFoundError: - return - - async def _handle_error_event(self, event: FastAPIEvent) -> None: - try: - item_id = event[1]["data"]["queue_item_id"] - error = event[1]["data"]["error"] - queue_item = self.get_queue_item(item_id) - # always set to failed if have an error, even if previously the item was marked completed or canceled - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error) - except SessionQueueItemNotFoundError: - return - - async def _handle_cancel_event(self, event: FastAPIEvent) -> None: - try: - item_id = event[1]["data"]["queue_item_id"] - queue_item = self.get_queue_item(item_id) - if queue_item.status not in ["completed", "failed", "canceled"]: - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled") - except SessionQueueItemNotFoundError: - return - def _set_in_progress_to_canceled(self) -> None: """ Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue. @@ -271,17 +222,22 @@ class SqliteSessionQueue(SessionQueueBase): return SessionQueueItem.queue_item_from_dict(dict(result)) def _set_queue_item_status( - self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None + self, + item_id: int, + status: QUEUE_ITEM_STATUS, + error_type: Optional[str] = None, + error_message: Optional[str] = None, + error_traceback: Optional[str] = None, ) -> SessionQueueItem: try: self.__lock.acquire() self.__cursor.execute( """--sql UPDATE session_queue - SET status = ?, error = ? + SET status = ?, error_type = ?, error_message = ?, error_traceback = ? WHERE item_id = ? """, - (status, error, item_id), + (status, error_type, error_message, error_traceback, item_id), ) self.__conn.commit() except Exception: @@ -292,11 +248,7 @@ class SqliteSessionQueue(SessionQueueBase): queue_item = self.get_queue_item(item_id) batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_item.queue_id) - self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=queue_item, - batch_status=batch_status, - queue_status=queue_status, - ) + self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status) return queue_item def is_empty(self, queue_id: str) -> IsEmptyResult: @@ -338,26 +290,6 @@ class SqliteSessionQueue(SessionQueueBase): self.__lock.release() return IsFullResult(is_full=is_full) - def delete_queue_item(self, item_id: int) -> SessionQueueItem: - queue_item = self.get_queue_item(item_id=item_id) - try: - self.__lock.acquire() - self.__cursor.execute( - """--sql - DELETE FROM session_queue - WHERE - item_id = ? - """, - (item_id,), - ) - self.__conn.commit() - except Exception: - self.__conn.rollback() - raise - finally: - self.__lock.release() - return queue_item - def clear(self, queue_id: str) -> ClearResult: try: self.__lock.acquire() @@ -424,17 +356,28 @@ class SqliteSessionQueue(SessionQueueBase): self.__lock.release() return PruneResult(deleted=count) - def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem: - queue_item = self.get_queue_item(item_id) - if queue_item.status not in ["canceled", "failed", "completed"]: - status = "failed" if error is not None else "canceled" - queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here - self.__invoker.services.events.emit_session_canceled( - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, - queue_batch_id=queue_item.batch_id, - graph_execution_state_id=queue_item.session_id, - ) + def cancel_queue_item(self, item_id: int) -> SessionQueueItem: + queue_item = self._set_queue_item_status(item_id=item_id, status="canceled") + return queue_item + + def complete_queue_item(self, item_id: int) -> SessionQueueItem: + queue_item = self._set_queue_item_status(item_id=item_id, status="completed") + return queue_item + + def fail_queue_item( + self, + item_id: int, + error_type: str, + error_message: str, + error_traceback: str, + ) -> SessionQueueItem: + queue_item = self._set_queue_item_status( + item_id=item_id, + status="failed", + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) return queue_item def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: @@ -470,18 +413,10 @@ class SqliteSessionQueue(SessionQueueBase): ) self.__conn.commit() if current_queue_item is not None and current_queue_item.batch_id in batch_ids: - self.__invoker.services.events.emit_session_canceled( - queue_item_id=current_queue_item.item_id, - queue_id=current_queue_item.queue_id, - queue_batch_id=current_queue_item.batch_id, - graph_execution_state_id=current_queue_item.session_id, - ) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_id) self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=current_queue_item, - batch_status=batch_status, - queue_status=queue_status, + current_queue_item, batch_status, queue_status ) except Exception: self.__conn.rollback() @@ -521,18 +456,10 @@ class SqliteSessionQueue(SessionQueueBase): ) self.__conn.commit() if current_queue_item is not None and current_queue_item.queue_id == queue_id: - self.__invoker.services.events.emit_session_canceled( - queue_item_id=current_queue_item.item_id, - queue_id=current_queue_item.queue_id, - queue_batch_id=current_queue_item.batch_id, - graph_execution_state_id=current_queue_item.session_id, - ) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_id) self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=current_queue_item, - batch_status=batch_status, - queue_status=queue_status, + current_queue_item, batch_status, queue_status ) except Exception: self.__conn.rollback() @@ -562,6 +489,29 @@ class SqliteSessionQueue(SessionQueueBase): raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}") return SessionQueueItem.queue_item_from_dict(dict(result)) + def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem: + try: + # Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors + # when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced + # during execution. + session_json = session.model_dump_json(warnings=False, exclude_none=True) + self.__lock.acquire() + self.__cursor.execute( + """--sql + UPDATE session_queue + SET session = ? + WHERE item_id = ? + """, + (session_json, item_id), + ) + self.__conn.commit() + except Exception: + self.__conn.rollback() + raise + finally: + self.__lock.release() + return self.get_queue_item(item_id) + def list_queue_items( self, queue_id: str, @@ -578,7 +528,9 @@ class SqliteSessionQueue(SessionQueueBase): status, priority, field_values, - error, + error_type, + error_message, + error_traceback, created_at, updated_at, completed_at, diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index cc2ea5cedb..8508d2484c 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -8,6 +8,7 @@ import networkx as nx from pydantic import ( BaseModel, GetJsonSchemaHandler, + ValidationError, field_validator, ) from pydantic.fields import Field @@ -190,6 +191,39 @@ class UnknownGraphValidationError(ValueError): pass +class NodeInputError(ValueError): + """Raised when a node fails preparation. This occurs when a node's inputs are being set from its incomers, but an + input fails validation. + + Attributes: + node: The node that failed preparation. Note: only successfully set fields will be accurate. Review the error to + determine which field caused the failure. + """ + + def __init__(self, node: BaseInvocation, e: ValidationError): + self.original_error = e + self.node = node + # When preparing a node, we set each input one-at-a-time. We may thus safely assume that the first error + # represents the first input that failed. + self.failed_input = loc_to_dot_sep(e.errors()[0]["loc"]) + super().__init__(f"Node {node.id} has invalid incoming input for {self.failed_input}") + + +def loc_to_dot_sep(loc: tuple[Union[str, int], ...]) -> str: + """Helper to pretty-print pydantic error locations as dot-separated strings. + Taken from https://docs.pydantic.dev/latest/errors/errors/#customize-error-messages + """ + path = "" + for i, x in enumerate(loc): + if isinstance(x, str): + if i > 0: + path += "." + path += x + else: + path += f"[{x}]" + return path + + @invocation_output("iterate_output") class IterateInvocationOutput(BaseInvocationOutput): """Used to connect iteration outputs. Will be expanded to a specific output.""" @@ -821,7 +855,10 @@ class GraphExecutionState(BaseModel): # Get values from edges if next_node is not None: - self._prepare_inputs(next_node) + try: + self._prepare_inputs(next_node) + except ValidationError as e: + raise NodeInputError(next_node, e) # If next is still none, there's no next node, return None return next_node diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 1acc5c725d..c932e66989 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,4 +1,3 @@ -import threading from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Callable, Dict, Optional, Union @@ -359,12 +358,11 @@ class ModelsInterface(InvocationContextInterface): if isinstance(identifier, str): model = self._services.model_manager.store.get_model(identifier) - result: LoadedModel = self._services.model_manager.load.load_model(model, submodel_type, self._data) + return self._services.model_manager.load.load_model(model, submodel_type) else: _submodel_type = submodel_type or identifier.submodel_type model = self._services.model_manager.store.get_model(identifier.key) - result = self._services.model_manager.load.load_model(model, _submodel_type, self._data) - return result + return self._services.model_manager.load.load_model(model, _submodel_type) def load_by_attrs( self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None @@ -388,8 +386,7 @@ class ModelsInterface(InvocationContextInterface): if len(configs) > 1: raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}") - result: LoadedModel = self._services.model_manager.load.load_model(configs[0], submodel_type, self._data) - return result + return self._services.model_manager.load.load_model(configs[0], submodel_type) def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig: """Get a model's config. @@ -516,10 +513,10 @@ class ConfigInterface(InvocationContextInterface): class UtilInterface(InvocationContextInterface): def __init__( - self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event + self, services: InvocationServices, data: InvocationContextData, is_canceled: Callable[[], bool] ) -> None: super().__init__(services, data) - self._cancel_event = cancel_event + self._is_canceled = is_canceled def is_canceled(self) -> bool: """Checks if the current session has been canceled. @@ -527,7 +524,7 @@ class UtilInterface(InvocationContextInterface): Returns: True if the current session has been canceled, False if not. """ - return self._cancel_event.is_set() + return self._is_canceled() def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: """ @@ -602,7 +599,7 @@ class InvocationContext: def build_invocation_context( services: InvocationServices, data: InvocationContextData, - cancel_event: threading.Event, + is_canceled: Callable[[], bool], ) -> InvocationContext: """Builds the invocation context for a specific invocation execution. @@ -619,7 +616,7 @@ def build_invocation_context( tensors = TensorsInterface(services=services, data=data) models = ModelsInterface(services=services, data=data) config = ConfigInterface(services=services, data=data) - util = UtilInterface(services=services, data=data, cancel_event=cancel_event) + util = UtilInterface(services=services, data=data, is_canceled=is_canceled) conditioning = ConditioningInterface(services=services, data=data) boards = BoardsInterface(services=services, data=data) diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 61f35a3b4e..cadf09f457 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -42,7 +42,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_7()) migrator.register_migration(build_migration_8(app_config=config)) migrator.register_migration(build_migration_9()) - migrator.register_migration(build_migration_10(app_config=config, logger=logger)) + migrator.register_migration(build_migration_10()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py index 4c4f742d4c..ce2cd2e965 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_10.py @@ -1,75 +1,35 @@ -import shutil import sqlite3 -from logging import Logger -from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration -LEGACY_CORE_MODELS = [ - # OpenPose - "any/annotators/dwpose/yolox_l.onnx", - "any/annotators/dwpose/dw-ll_ucoco_384.onnx", - # DepthAnything - "any/annotators/depth_anything/depth_anything_vitl14.pth", - "any/annotators/depth_anything/depth_anything_vitb14.pth", - "any/annotators/depth_anything/depth_anything_vits14.pth", - # Lama inpaint - "core/misc/lama/lama.pt", - # RealESRGAN upscale - "core/upscaling/realesrgan/RealESRGAN_x4plus.pth", - "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", - "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", - "core/upscaling/realesrgan/RealESRGAN_x2plus.pth", -] - class Migration10Callback: - def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None: - self._app_config = app_config - self._logger = logger - def __call__(self, cursor: sqlite3.Cursor) -> None: - self._remove_convert_cache() - self._remove_downloaded_models() - self._remove_unused_core_models() + self._update_error_cols(cursor) - def _remove_convert_cache(self) -> None: - """Rename models/.cache to models/.convert_cache.""" - self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.") - legacy_convert_path = self._app_config.root_path / "models" / ".cache" - shutil.rmtree(legacy_convert_path, ignore_errors=True) + def _update_error_cols(self, cursor: sqlite3.Cursor) -> None: + """ + - Adds `error_type` and `error_message` columns to the session queue table. + - Renames the `error` column to `error_traceback`. + """ - def _remove_downloaded_models(self) -> None: - """Remove models from their old locations; they will re-download when needed.""" - self._logger.info( - "Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache." - ) - for model_path in LEGACY_CORE_MODELS: - legacy_dest_path = self._app_config.models_path / model_path - legacy_dest_path.unlink(missing_ok=True) - - def _remove_unused_core_models(self) -> None: - """Remove unused core models and their directories.""" - self._logger.info("Removing defunct core models.") - for dir in ["face_restoration", "misc", "upscaling"]: - path_to_remove = self._app_config.models_path / "core" / dir - shutil.rmtree(path_to_remove, ignore_errors=True) - shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True) + cursor.execute("ALTER TABLE session_queue ADD COLUMN error_type TEXT;") + cursor.execute("ALTER TABLE session_queue ADD COLUMN error_message TEXT;") + cursor.execute("ALTER TABLE session_queue RENAME COLUMN error TO error_traceback;") -def build_migration_10(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: +def build_migration_10() -> Migration: """ Build the migration from database version 9 to 10. This migration does the following: - - Moves "core" models previously downloaded with download_with_progress_bar() into new - "models/.download_cache" directory. - - Renames "models/.cache" to "models/.convert_cache". + - Adds `error_type` and `error_message` columns to the session queue table. + - Renames the `error` column to `error_traceback`. """ migration_10 = Migration( from_version=9, to_version=10, - callback=Migration10Callback(app_config=app_config, logger=logger), + callback=Migration10Callback(), ) return migration_10 diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py new file mode 100644 index 0000000000..3b616e2b82 --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py @@ -0,0 +1,77 @@ +import shutil +import sqlite3 +from logging import Logger + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + +LEGACY_CORE_MODELS = [ + # OpenPose + "any/annotators/dwpose/yolox_l.onnx", + "any/annotators/dwpose/dw-ll_ucoco_384.onnx", + # DepthAnything + "any/annotators/depth_anything/depth_anything_vitl14.pth", + "any/annotators/depth_anything/depth_anything_vitb14.pth", + "any/annotators/depth_anything/depth_anything_vits14.pth", + # Lama inpaint + "core/misc/lama/lama.pt", + # RealESRGAN upscale + "core/upscaling/realesrgan/RealESRGAN_x4plus.pth", + "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", + "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", + "core/upscaling/realesrgan/RealESRGAN_x2plus.pth", +] + + +class Migration11Callback: + def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None: + self._app_config = app_config + self._logger = logger + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._remove_convert_cache() + self._remove_downloaded_models() + self._remove_unused_core_models() + + def _remove_convert_cache(self) -> None: + """Rename models/.cache to models/.convert_cache.""" + self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.") + legacy_convert_path = self._app_config.root_path / "models" / ".cache" + shutil.rmtree(legacy_convert_path, ignore_errors=True) + + def _remove_downloaded_models(self) -> None: + """Remove models from their old locations; they will re-download when needed.""" + self._logger.info( + "Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache." + ) + for model_path in LEGACY_CORE_MODELS: + legacy_dest_path = self._app_config.models_path / model_path + legacy_dest_path.unlink(missing_ok=True) + + def _remove_unused_core_models(self) -> None: + """Remove unused core models and their directories.""" + self._logger.info("Removing defunct core models.") + for dir in ["face_restoration", "misc", "upscaling"]: + path_to_remove = self._app_config.models_path / "core" / dir + shutil.rmtree(path_to_remove, ignore_errors=True) + shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True) + + +def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: + """ + Build the migration from database version 9 to 10. + + This migration does the following: + - Moves "core" models previously downloaded with download_with_progress_bar() into new + "models/.download_cache" directory. + - Renames "models/.cache" to "models/.convert_cache". + - Adds `error_type` and `error_message` columns to the session queue table. + - Renames the `error` column to `error_traceback`. + """ + migration_11 = Migration( + from_version=10, + to_version=11, + callback=Migration11Callback(app_config=app_config, logger=logger), + ) + + return migration_11 diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 8cb59f5b3a..8992e59ace 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, Optional import torch from PIL import Image @@ -13,8 +13,36 @@ if TYPE_CHECKING: from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.shared.invocation_context import InvocationContextData +# fast latents preview matrix for sdxl +# generated by @StAlKeR7779 +SDXL_LATENT_RGB_FACTORS = [ + # R G B + [0.3816, 0.4930, 0.5320], + [-0.3753, 0.1631, 0.1739], + [0.1770, 0.3588, -0.2048], + [-0.4350, -0.2644, -0.4289], +] +SDXL_SMOOTH_MATRIX = [ + [0.0358, 0.0964, 0.0358], + [0.0964, 0.4711, 0.0964], + [0.0358, 0.0964, 0.0358], +] -def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None): +# origingally adapted from code by @erucipe and @keturn here: +# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7 +# these updated numbers for v1.5 are from @torridgristle +SD1_5_LATENT_RGB_FACTORS = [ + # R G B + [0.3444, 0.1385, 0.0670], # L1 + [0.1247, 0.4027, 0.1494], # L2 + [-0.3192, 0.2513, 0.2103], # L3 + [-0.1307, -0.1874, -0.7445], # L4 +] + + +def sample_to_lowres_estimated_image( + samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None +): latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors if smooth_matrix is not None: @@ -47,64 +75,12 @@ def stable_diffusion_step_callback( else: sample = intermediate_state.latents - # TODO: This does not seem to be needed any more? - # # txt2img provides a Tensor in the step_callback - # # img2img provides a PipelineIntermediateState - # if isinstance(sample, PipelineIntermediateState): - # # this was an img2img - # print('img2img') - # latents = sample.latents - # step = sample.step - # else: - # print('txt2img') - # latents = sample - # step = intermediate_state.step - - # TODO: only output a preview image when requested - if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]: - # fast latents preview matrix for sdxl - # generated by @StAlKeR7779 - sdxl_latent_rgb_factors = torch.tensor( - [ - # R G B - [0.3816, 0.4930, 0.5320], - [-0.3753, 0.1631, 0.1739], - [0.1770, 0.3588, -0.2048], - [-0.4350, -0.2644, -0.4289], - ], - dtype=sample.dtype, - device=sample.device, - ) - - sdxl_smooth_matrix = torch.tensor( - [ - [0.0358, 0.0964, 0.0358], - [0.0964, 0.4711, 0.0964], - [0.0358, 0.0964, 0.0358], - ], - dtype=sample.dtype, - device=sample.device, - ) - + sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device) + sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device) image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix) else: - # origingally adapted from code by @erucipe and @keturn here: - # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7 - - # these updated numbers for v1.5 are from @torridgristle - v1_5_latent_rgb_factors = torch.tensor( - [ - # R G B - [0.3444, 0.1385, 0.0670], # L1 - [0.1247, 0.4027, 0.1494], # L2 - [-0.3192, 0.2513, 0.2103], # L3 - [-0.1307, -0.1874, -0.7445], # L4 - ], - dtype=sample.dtype, - device=sample.device, - ) - + v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device) image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors) (width, height) = image.size @@ -113,15 +89,9 @@ def stable_diffusion_step_callback( dataURL = image_to_dataURL(image, image_format="JPEG") - events.emit_generator_progress( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - graph_execution_state_id=context_data.queue_item.session_id, - node_id=context_data.invocation.id, - source_node_id=context_data.source_invocation_id, - progress_image=ProgressImage(width=width, height=height, dataURL=dataURL), - step=intermediate_state.step, - order=intermediate_state.order, - total_steps=intermediate_state.total_steps, + events.emit_invocation_denoise_progress( + context_data.queue_item, + context_data.invocation, + intermediate_state, + ProgressImage(dataURL=dataURL, width=width, height=height), ) diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index ec77bbe477..bb7fd6f1d4 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -42,10 +42,26 @@ T = TypeVar("T") @dataclass class CacheRecord(Generic[T]): - """Elements of the cache.""" + """ + Elements of the cache: + + key: Unique key for each model, same as used in the models database. + model: Model in memory. + state_dict: A read-only copy of the model's state dict in RAM. It will be + used as a template for creating a copy in the VRAM. + size: Size of the model + loaded: True if the model's state dict is currently in VRAM + + Before a model is executed, the state_dict template is copied into VRAM, + and then injected into the model. When the model is finished, the VRAM + copy of the state dict is deleted, and the RAM version is reinjected + into the model. + """ key: str model: T + device: torch.device + state_dict: Optional[Dict[str, torch.Tensor]] size: int loaded: bool = False _locks: int = 0 diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index bd7b2ffc7a..10c0210052 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -20,7 +20,6 @@ context. Use like this: import gc import math -import sys import time from contextlib import suppress from logging import Logger @@ -163,7 +162,9 @@ class ModelCache(ModelCacheBase[AnyModel]): return size = calc_model_size_by_data(model) self.make_room(size) - cache_record = CacheRecord(key, model, size) + + state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None + cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size) self._cached_models[key] = cache_record self._cache_stack.append(key) @@ -253,21 +254,40 @@ class ModelCache(ModelCacheBase[AnyModel]): May raise a torch.cuda.OutOfMemoryError """ self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") - model = cache_entry.model + source_device = cache_entry.device - source_device = model.device if hasattr(model, "device") else self.storage_device + # Note: We compare device types only so that 'cuda' == 'cuda:0'. + # This would need to be revised to support multi-GPU. if torch.device(source_device).type == torch.device(target_device).type: return + if not hasattr(cache_entry.model, "to"): + return + + # This roundabout method for moving the model around is done to avoid + # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM. + # When moving to VRAM, we copy (not move) each element of the state dict from + # RAM to a new state dict in VRAM, and then inject it into the model. + # This operation is slightly faster than running `to()` on the whole model. + # + # When the model needs to be removed from VRAM we simply delete the copy + # of the state dict in VRAM, and reinject the state dict that is cached + # in RAM into the model. So this operation is very fast. start_model_to_time = time.time() snapshot_before = self._capture_memory_snapshot() + try: - if hasattr(model, "to"): - model.to(target_device) - elif isinstance(model, dict): - for _, v in model.items(): - if hasattr(v, "to"): - v.to(target_device) + if cache_entry.state_dict is not None: + assert hasattr(cache_entry.model, "load_state_dict") + if target_device == self.storage_device: + cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True) + else: + new_dict: Dict[str, torch.Tensor] = {} + for k, v in cache_entry.state_dict.items(): + new_dict[k] = v.to(torch.device(target_device), copy=True) + cache_entry.model.load_state_dict(new_dict, assign=True) + cache_entry.model.to(target_device) + cache_entry.device = target_device except Exception as e: # blow away cache entry self._delete_cache_entry(cache_entry) raise e @@ -347,43 +367,12 @@ class ModelCache(ModelCacheBase[AnyModel]): while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): model_key = self._cache_stack[pos] cache_entry = self._cached_models[model_key] - - refs = sys.getrefcount(cache_entry.model) - - # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly - # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way: - # https://docs.python.org/3/library/gc.html#gc.get_referrers - - # manualy clear local variable references of just finished function calls - # for some reason python don't want to collect it even by gc.collect() immidiately - if refs > 2: - while True: - cleared = False - for referrer in gc.get_referrers(cache_entry.model): - if type(referrer).__name__ == "frame": - # RuntimeError: cannot clear an executing frame - with suppress(RuntimeError): - referrer.clear() - cleared = True - # break - - # repeat if referrers changes(due to frame clear), else exit loop - if cleared: - gc.collect() - else: - break - device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None self.logger.debug( - f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}," - f" refs: {refs}" + f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}" ) - # Expected refs: - # 1 from cache_entry - # 1 from getrefcount function - # 1 from onnx runtime object - if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2): + if not cache_entry.locked: self.logger.debug( f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" ) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 7de7a8e01c..f7a91ef756 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -2,6 +2,7 @@ "accessibility": { "about": "About", "createIssue": "Create Issue", + "submitSupportTicket": "Submit Support Ticket", "invokeProgressBar": "Invoke progress bar", "menu": "Menu", "mode": "Mode", @@ -146,7 +147,9 @@ "viewing": "Viewing", "viewingDesc": "Review images in a large gallery view", "editing": "Editing", - "editingDesc": "Edit on the Control Layers canvas" + "editingDesc": "Edit on the Control Layers canvas", + "enabled": "Enabled", + "disabled": "Disabled" }, "controlnet": { "controlAdapter_one": "Control Adapter", @@ -775,10 +778,14 @@ "cannotConnectToSelf": "Cannot connect to self", "cannotDuplicateConnection": "Cannot create duplicate connections", "cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types", + "missingNode": "Missing invocation node", + "missingInvocationTemplate": "Missing invocation template", + "missingFieldTemplate": "Missing field template", "nodePack": "Node pack", "collection": "Collection", - "collectionFieldType": "{{name}} Collection", - "collectionOrScalarFieldType": "{{name}} Collection|Scalar", + "singleFieldType": "{{name}} (Single)", + "collectionFieldType": "{{name}} (Collection)", + "collectionOrScalarFieldType": "{{name}} (Single or Collection)", "colorCodeEdges": "Color-Code Edges", "colorCodeEdgesHelp": "Color-code edges according to their connected fields", "connectionWouldCreateCycle": "Connection would create a cycle", @@ -893,7 +900,10 @@ "zoomInNodes": "Zoom In", "zoomOutNodes": "Zoom Out", "betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.", - "prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time." + "prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time.", + "imageAccessError": "Unable to find image {{image_name}}, resetting to default", + "boardAccessError": "Unable to find board {{board_id}}, resetting to default", + "modelAccessError": "Unable to find model {{key}}, resetting to default" }, "parameters": { "aspect": "Aspect", @@ -948,7 +958,7 @@ "controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model", "controlAdapterNoImageSelected": "no Control Adapter image selected", "controlAdapterImageNotProcessed": "Control Adapter image not processed", - "t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of 64", + "t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of {{multiple}}", "ipAdapterNoModelSelected": "no IP adapter selected", "ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model", "ipAdapterNoImageSelected": "no IP Adapter image selected", @@ -1066,8 +1076,9 @@ }, "toast": { "addedToBoard": "Added to board", - "baseModelChangedCleared_one": "Base model changed, cleared or disabled {{count}} incompatible submodel", - "baseModelChangedCleared_other": "Base model changed, cleared or disabled {{count}} incompatible submodels", + "baseModelChanged": "Base Model Changed", + "baseModelChangedCleared_one": "Cleared or disabled {{count}} incompatible submodel", + "baseModelChangedCleared_other": "Cleared or disabled {{count}} incompatible submodels", "canceled": "Processing Canceled", "canvasCopiedClipboard": "Canvas Copied to Clipboard", "canvasDownloaded": "Canvas Downloaded", @@ -1088,10 +1099,17 @@ "metadataLoadFailed": "Failed to load metadata", "modelAddedSimple": "Model Added to Queue", "modelImportCanceled": "Model Import Canceled", + "outOfMemoryError": "Out of Memory Error", + "outOfMemoryErrorDesc": "Your current generation settings exceed system capacity. Please adjust your settings and try again.", "parameters": "Parameters", - "parameterNotSet": "{{parameter}} not set", - "parameterSet": "{{parameter}} set", - "parametersNotSet": "Parameters Not Set", + "parameterSet": "Parameter Recalled", + "parameterSetDesc": "Recalled {{parameter}}", + "parameterNotSet": "Parameter Not Recalled", + "parameterNotSetDesc": "Unable to recall {{parameter}}", + "parameterNotSetDescWithMessage": "Unable to recall {{parameter}}: {{message}}", + "parametersSet": "Parameters Recalled", + "parametersNotSet": "Parameters Not Recalled", + "errorCopied": "Error Copied", "problemCopyingCanvas": "Problem Copying Canvas", "problemCopyingCanvasDesc": "Unable to export base layer", "problemCopyingImage": "Unable to Copy Image", @@ -1111,11 +1129,13 @@ "sentToImageToImage": "Sent To Image To Image", "sentToUnifiedCanvas": "Sent to Unified Canvas", "serverError": "Server Error", + "sessionRef": "Session: {{sessionId}}", "setAsCanvasInitialImage": "Set as canvas initial image", "setCanvasInitialImage": "Set canvas initial image", "setControlImage": "Set as control image", "setInitialImage": "Set as initial image", "setNodeField": "Set as node field", + "somethingWentWrong": "Something Went Wrong", "uploadFailed": "Upload failed", "uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image", "uploadInitialImage": "Upload Initial Image", @@ -1555,7 +1575,6 @@ "controlLayers": "Control Layers", "globalMaskOpacity": "Global Mask Opacity", "autoNegative": "Auto Negative", - "toggleVisibility": "Toggle Layer Visibility", "deletePrompt": "Delete Prompt", "resetRegion": "Reset Region", "debugLayers": "Debug Layers", diff --git a/invokeai/frontend/web/public/locales/es.json b/invokeai/frontend/web/public/locales/es.json index dbdda8e209..169bfdb066 100644 --- a/invokeai/frontend/web/public/locales/es.json +++ b/invokeai/frontend/web/public/locales/es.json @@ -382,7 +382,7 @@ "canvasMerged": "Lienzo consolidado", "sentToImageToImage": "Enviar hacia Imagen a Imagen", "sentToUnifiedCanvas": "Enviar hacia Lienzo Consolidado", - "parametersNotSet": "Parámetros no establecidos", + "parametersNotSet": "Parámetros no recuperados", "metadataLoadFailed": "Error al cargar metadatos", "serverError": "Error en el servidor", "canceled": "Procesando la cancelación", @@ -390,7 +390,8 @@ "uploadFailedInvalidUploadDesc": "Debe ser una sola imagen PNG o JPEG", "parameterSet": "Conjunto de parámetros", "parameterNotSet": "Parámetro no configurado", - "problemCopyingImage": "No se puede copiar la imagen" + "problemCopyingImage": "No se puede copiar la imagen", + "errorCopied": "Error al copiar" }, "tooltip": { "feature": { diff --git a/invokeai/frontend/web/public/locales/it.json b/invokeai/frontend/web/public/locales/it.json index f365b43e10..bd82dd9a5b 100644 --- a/invokeai/frontend/web/public/locales/it.json +++ b/invokeai/frontend/web/public/locales/it.json @@ -524,7 +524,20 @@ "missingNodeTemplate": "Modello di nodo mancante", "missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} ingresso mancante", "missingFieldTemplate": "Modello di campo mancante", - "imageNotProcessedForControlAdapter": "L'immagine dell'adattatore di controllo #{{number}} non è stata elaborata" + "imageNotProcessedForControlAdapter": "L'immagine dell'adattatore di controllo #{{number}} non è stata elaborata", + "layer": { + "initialImageNoImageSelected": "Nessuna immagine iniziale selezionata", + "t2iAdapterIncompatibleDimensions": "L'adattatore T2I richiede che la dimensione dell'immagine sia un multiplo di {{multiple}}", + "controlAdapterNoModelSelected": "Nessun modello di Adattatore di Controllo selezionato", + "controlAdapterIncompatibleBaseModel": "Il modello base dell'adattatore di controllo non è compatibile", + "controlAdapterNoImageSelected": "Nessuna immagine dell'adattatore di controllo selezionata", + "controlAdapterImageNotProcessed": "Immagine dell'adattatore di controllo non elaborata", + "ipAdapterNoModelSelected": "Nessun adattatore IP selezionato", + "ipAdapterIncompatibleBaseModel": "Il modello base dell'adattatore IP non è compatibile", + "ipAdapterNoImageSelected": "Nessuna immagine dell'adattatore IP selezionata", + "rgNoPromptsOrIPAdapters": "Nessun prompt o adattatore IP", + "rgNoRegion": "Nessuna regione selezionata" + } }, "useCpuNoise": "Usa la CPU per generare rumore", "iterations": "Iterazioni", @@ -824,8 +837,8 @@ "unableToUpdateNodes_other": "Impossibile aggiornare {{count}} nodi", "addLinearView": "Aggiungi alla vista Lineare", "unknownErrorValidatingWorkflow": "Errore sconosciuto durante la convalida del flusso di lavoro", - "collectionFieldType": "{{name}} Raccolta", - "collectionOrScalarFieldType": "{{name}} Raccolta|Scalare", + "collectionFieldType": "{{name}} (Raccolta)", + "collectionOrScalarFieldType": "{{name}} (Singola o Raccolta)", "nodeVersion": "Versione Nodo", "inputFieldTypeParseError": "Impossibile analizzare il tipo di campo di input {{node}}.{{field}} ({{message}})", "unsupportedArrayItemType": "Tipo di elemento dell'array non supportato \"{{type}}\"", @@ -863,7 +876,13 @@ "edit": "Modifica", "graph": "Grafico", "showEdgeLabelsHelp": "Mostra etichette sui collegamenti, che indicano i nodi collegati", - "showEdgeLabels": "Mostra le etichette del collegamento" + "showEdgeLabels": "Mostra le etichette del collegamento", + "cannotMixAndMatchCollectionItemTypes": "Impossibile combinare e abbinare i tipi di elementi della raccolta", + "noGraph": "Nessun grafico", + "missingNode": "Nodo di invocazione mancante", + "missingInvocationTemplate": "Modello di invocazione mancante", + "missingFieldTemplate": "Modello di campo mancante", + "singleFieldType": "{{name}} (Singola)" }, "boards": { "autoAddBoard": "Aggiungi automaticamente bacheca", @@ -1034,7 +1053,16 @@ "graphFailedToQueue": "Impossibile mettere in coda il grafico", "batchFieldValues": "Valori Campi Lotto", "time": "Tempo", - "openQueue": "Apri coda" + "openQueue": "Apri coda", + "iterations_one": "Iterazione", + "iterations_many": "Iterazioni", + "iterations_other": "Iterazioni", + "prompts_one": "Prompt", + "prompts_many": "Prompt", + "prompts_other": "Prompt", + "generations_one": "Generazione", + "generations_many": "Generazioni", + "generations_other": "Generazioni" }, "models": { "noMatchingModels": "Nessun modello corrispondente", @@ -1563,7 +1591,6 @@ "brushSize": "Dimensioni del pennello", "globalMaskOpacity": "Opacità globale della maschera", "autoNegative": "Auto Negativo", - "toggleVisibility": "Attiva/disattiva la visibilità dei livelli", "deletePrompt": "Cancella il prompt", "debugLayers": "Debug dei Livelli", "rectangle": "Rettangolo", diff --git a/invokeai/frontend/web/public/locales/nl.json b/invokeai/frontend/web/public/locales/nl.json index 76377bd215..afcce62163 100644 --- a/invokeai/frontend/web/public/locales/nl.json +++ b/invokeai/frontend/web/public/locales/nl.json @@ -6,7 +6,7 @@ "settingsLabel": "Instellingen", "img2img": "Afbeelding naar afbeelding", "unifiedCanvas": "Centraal canvas", - "nodes": "Werkstroom-editor", + "nodes": "Werkstromen", "upload": "Upload", "load": "Laad", "statusDisconnected": "Niet verbonden", @@ -34,7 +34,60 @@ "controlNet": "ControlNet", "imageFailedToLoad": "Kan afbeelding niet laden", "learnMore": "Meer informatie", - "advanced": "Uitgebreid" + "advanced": "Uitgebreid", + "file": "Bestand", + "installed": "Geïnstalleerd", + "notInstalled": "Niet $t(common.installed)", + "simple": "Eenvoudig", + "somethingWentWrong": "Er ging iets mis", + "add": "Voeg toe", + "checkpoint": "Checkpoint", + "details": "Details", + "outputs": "Uitvoeren", + "save": "Bewaar", + "nextPage": "Volgende pagina", + "blue": "Blauw", + "alpha": "Alfa", + "red": "Rood", + "editor": "Editor", + "folder": "Map", + "format": "structuur", + "goTo": "Ga naar", + "template": "Sjabloon", + "input": "Invoer", + "loglevel": "Logboekniveau", + "safetensors": "Safetensors", + "saveAs": "Bewaar als", + "created": "Gemaakt", + "green": "Groen", + "tab": "Tab", + "positivePrompt": "Positieve prompt", + "negativePrompt": "Negatieve prompt", + "selected": "Geselecteerd", + "orderBy": "Sorteer op", + "prevPage": "Vorige pagina", + "beta": "Bèta", + "copyError": "$t(gallery.copy) Fout", + "toResolve": "Op te lossen", + "aboutDesc": "Gebruik je Invoke voor het werk? Kijk dan naar:", + "aboutHeading": "Creatieve macht voor jou", + "copy": "Kopieer", + "data": "Gegevens", + "or": "of", + "updated": "Bijgewerkt", + "outpaint": "outpainten", + "viewing": "Bekijken", + "viewingDesc": "Beoordeel afbeelding in een grote galerijweergave", + "editing": "Bewerken", + "editingDesc": "Bewerk op het canvas Stuurlagen", + "ai": "ai", + "inpaint": "inpainten", + "unknown": "Onbekend", + "delete": "Verwijder", + "direction": "Richting", + "error": "Fout", + "localSystem": "Lokaal systeem", + "unknownError": "Onbekende fout" }, "gallery": { "galleryImageSize": "Afbeeldingsgrootte", @@ -310,10 +363,41 @@ "modelSyncFailed": "Synchronisatie modellen mislukt", "modelDeleteFailed": "Model kon niet verwijderd worden", "convertingModelBegin": "Model aan het converteren. Even geduld.", - "predictionType": "Soort voorspelling (voor Stable Diffusion 2.x-modellen en incidentele Stable Diffusion 1.x-modellen)", + "predictionType": "Soort voorspelling", "advanced": "Uitgebreid", "modelType": "Soort model", - "vaePrecision": "Nauwkeurigheid VAE" + "vaePrecision": "Nauwkeurigheid VAE", + "loraTriggerPhrases": "LoRA-triggerzinnen", + "urlOrLocalPathHelper": "URL's zouden moeten wijzen naar een los bestand. Lokale paden kunnen wijzen naar een los bestand of map voor een individueel Diffusers-model.", + "modelName": "Modelnaam", + "path": "Pad", + "triggerPhrases": "Triggerzinnen", + "typePhraseHere": "Typ zin hier in", + "useDefaultSettings": "Gebruik standaardinstellingen", + "modelImageDeleteFailed": "Fout bij verwijderen modelafbeelding", + "modelImageUpdated": "Modelafbeelding bijgewerkt", + "modelImageUpdateFailed": "Fout bij bijwerken modelafbeelding", + "noMatchingModels": "Geen overeenkomende modellen", + "scanPlaceholder": "Pad naar een lokale map", + "noModelsInstalled": "Geen modellen geïnstalleerd", + "noModelsInstalledDesc1": "Installeer modellen met de", + "noModelSelected": "Geen model geselecteerd", + "starterModels": "Beginnermodellen", + "textualInversions": "Tekstuele omkeringen", + "upcastAttention": "Upcast-aandacht", + "uploadImage": "Upload afbeelding", + "mainModelTriggerPhrases": "Triggerzinnen hoofdmodel", + "urlOrLocalPath": "URL of lokaal pad", + "scanFolderHelper": "De map zal recursief worden ingelezen voor modellen. Dit kan enige tijd in beslag nemen voor erg grote mappen.", + "simpleModelPlaceholder": "URL of pad naar een lokaal pad of Diffusers-map", + "modelSettings": "Modelinstellingen", + "pathToConfig": "Pad naar configuratie", + "prune": "Snoei", + "pruneTooltip": "Snoei voltooide importeringen uit wachtrij", + "repoVariant": "Repovariant", + "scanFolder": "Lees map in", + "scanResults": "Resultaten inlezen", + "source": "Bron" }, "parameters": { "images": "Afbeeldingen", @@ -353,13 +437,13 @@ "copyImage": "Kopieer afbeelding", "denoisingStrength": "Sterkte ontruisen", "scheduler": "Planner", - "seamlessXAxis": "X-as", - "seamlessYAxis": "Y-as", + "seamlessXAxis": "Naadloze tegels in x-as", + "seamlessYAxis": "Naadloze tegels in y-as", "clipSkip": "Overslaan CLIP", "negativePromptPlaceholder": "Negatieve prompt", "controlNetControlMode": "Aansturingsmodus", "positivePromptPlaceholder": "Positieve prompt", - "maskBlur": "Vervaag", + "maskBlur": "Vervaging van masker", "invoke": { "noNodesInGraph": "Geen knooppunten in graaf", "noModelSelected": "Geen model ingesteld", @@ -369,11 +453,25 @@ "missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} invoer ontbreekt", "noControlImageForControlAdapter": "Controle-adapter #{{number}} heeft geen controle-afbeelding", "noModelForControlAdapter": "Control-adapter #{{number}} heeft geen model ingesteld staan.", - "incompatibleBaseModelForControlAdapter": "Model van controle-adapter #{{number}} is ongeldig in combinatie met het hoofdmodel.", + "incompatibleBaseModelForControlAdapter": "Model van controle-adapter #{{number}} is niet compatibel met het hoofdmodel.", "systemDisconnected": "Systeem is niet verbonden", "missingNodeTemplate": "Knooppuntsjabloon ontbreekt", "missingFieldTemplate": "Veldsjabloon ontbreekt", - "addingImagesTo": "Bezig met toevoegen van afbeeldingen aan" + "addingImagesTo": "Bezig met toevoegen van afbeeldingen aan", + "layer": { + "initialImageNoImageSelected": "geen initiële afbeelding geselecteerd", + "controlAdapterNoModelSelected": "geen controle-adaptermodel geselecteerd", + "controlAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor controle-adapter", + "controlAdapterNoImageSelected": "geen afbeelding voor controle-adapter geselecteerd", + "controlAdapterImageNotProcessed": "Afbeelding voor controle-adapter niet verwerkt", + "ipAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor IP-adapter", + "ipAdapterNoImageSelected": "geen afbeelding voor IP-adapter geselecteerd", + "rgNoRegion": "geen gebied geselecteerd", + "rgNoPromptsOrIPAdapters": "geen tekstprompts of IP-adapters", + "t2iAdapterIncompatibleDimensions": "T2I-adapter vereist een afbeelding met afmetingen met een veelvoud van 64", + "ipAdapterNoModelSelected": "geen IP-adapter geselecteerd" + }, + "imageNotProcessedForControlAdapter": "De afbeelding van controle-adapter #{{number}} is niet verwerkt" }, "isAllowedToUpscale": { "useX2Model": "Afbeelding is te groot om te vergroten met het x4-model. Gebruik hiervoor het x2-model", @@ -383,7 +481,26 @@ "useCpuNoise": "Gebruik CPU-ruis", "imageActions": "Afbeeldingshandeling", "iterations": "Iteraties", - "coherenceMode": "Modus" + "coherenceMode": "Modus", + "infillColorValue": "Vulkleur", + "remixImage": "Meng afbeelding opnieuw", + "setToOptimalSize": "Optimaliseer grootte voor het model", + "setToOptimalSizeTooSmall": "$t(parameters.setToOptimalSize) (is mogelijk te klein)", + "aspect": "Beeldverhouding", + "infillMosaicTileWidth": "Breedte tegel", + "setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (is mogelijk te groot)", + "lockAspectRatio": "Zet beeldverhouding vast", + "infillMosaicTileHeight": "Hoogte tegel", + "globalNegativePromptPlaceholder": "Globale negatieve prompt", + "globalPositivePromptPlaceholder": "Globale positieve prompt", + "useSize": "Gebruik grootte", + "swapDimensions": "Wissel afmetingen om", + "globalSettings": "Globale instellingen", + "coherenceEdgeSize": "Randgrootte", + "coherenceMinDenoise": "Min. ontruising", + "infillMosaicMinColor": "Min. kleur", + "infillMosaicMaxColor": "Max. kleur", + "cfgRescaleMultiplier": "Vermenigvuldiger voor CFG-herschaling" }, "settings": { "models": "Modellen", @@ -410,7 +527,12 @@ "intermediatesCleared_one": "{{count}} tussentijdse afbeelding gewist", "intermediatesCleared_other": "{{count}} tussentijdse afbeeldingen gewist", "clearIntermediatesDesc1": "Als je tussentijdse afbeeldingen wist, dan wordt de staat hersteld van je canvas en van ControlNet.", - "intermediatesClearedFailed": "Fout bij wissen van tussentijdse afbeeldingen" + "intermediatesClearedFailed": "Fout bij wissen van tussentijdse afbeeldingen", + "clearIntermediatesDisabled": "Wachtrij moet leeg zijn om tussentijdse afbeeldingen te kunnen leegmaken", + "enableInformationalPopovers": "Schakel informatieve hulpballonnen in", + "enableInvisibleWatermark": "Schakel onzichtbaar watermerk in", + "enableNSFWChecker": "Schakel NSFW-controle in", + "reloadingIn": "Opnieuw laden na" }, "toast": { "uploadFailed": "Upload mislukt", @@ -425,8 +547,8 @@ "connected": "Verbonden met server", "canceled": "Verwerking geannuleerd", "uploadFailedInvalidUploadDesc": "Moet een enkele PNG- of JPEG-afbeelding zijn", - "parameterNotSet": "Parameter niet ingesteld", - "parameterSet": "Instellen parameters", + "parameterNotSet": "{{parameter}} niet ingesteld", + "parameterSet": "{{parameter}} ingesteld", "problemCopyingImage": "Kan Afbeelding Niet Kopiëren", "baseModelChangedCleared_one": "Basismodel is gewijzigd: {{count}} niet-compatibel submodel weggehaald of uitgeschakeld", "baseModelChangedCleared_other": "Basismodel is gewijzigd: {{count}} niet-compatibele submodellen weggehaald of uitgeschakeld", @@ -443,11 +565,11 @@ "maskSavedAssets": "Masker bewaard in Assets", "problemDownloadingCanvas": "Fout bij downloaden van canvas", "problemMergingCanvas": "Fout bij samenvoegen canvas", - "setCanvasInitialImage": "Ingesteld als initiële canvasafbeelding", + "setCanvasInitialImage": "Initiële canvasafbeelding ingesteld", "imageUploaded": "Afbeelding geüpload", "addedToBoard": "Toegevoegd aan bord", "workflowLoaded": "Werkstroom geladen", - "modelAddedSimple": "Model toegevoegd", + "modelAddedSimple": "Model toegevoegd aan wachtrij", "problemImportingMaskDesc": "Kan masker niet exporteren", "problemCopyingCanvas": "Fout bij kopiëren canvas", "problemSavingCanvas": "Fout bij bewaren canvas", @@ -459,7 +581,18 @@ "maskSentControlnetAssets": "Masker gestuurd naar ControlNet en Assets", "canvasSavedGallery": "Canvas bewaard in galerij", "imageUploadFailed": "Fout bij uploaden afbeelding", - "problemImportingMask": "Fout bij importeren masker" + "problemImportingMask": "Fout bij importeren masker", + "workflowDeleted": "Werkstroom verwijderd", + "invalidUpload": "Ongeldige upload", + "uploadInitialImage": "Initiële afbeelding uploaden", + "setAsCanvasInitialImage": "Ingesteld als initiële afbeelding voor canvas", + "problemRetrievingWorkflow": "Fout bij ophalen van werkstroom", + "parameters": "Parameters", + "modelImportCanceled": "Importeren model geannuleerd", + "problemDeletingWorkflow": "Fout bij verwijderen van werkstroom", + "prunedQueue": "Wachtrij gesnoeid", + "problemDownloadingImage": "Fout bij downloaden afbeelding", + "resetInitialImage": "Initiële afbeelding hersteld" }, "tooltip": { "feature": { @@ -533,7 +666,11 @@ "showOptionsPanel": "Toon zijscherm", "menu": "Menu", "showGalleryPanel": "Toon deelscherm Galerij", - "loadMore": "Laad meer" + "loadMore": "Laad meer", + "about": "Over", + "mode": "Modus", + "resetUI": "$t(accessibility.reset) UI", + "createIssue": "Maak probleem aan" }, "nodes": { "zoomOutNodes": "Uitzoomen", @@ -547,7 +684,7 @@ "loadWorkflow": "Laad werkstroom", "downloadWorkflow": "Download JSON van werkstroom", "scheduler": "Planner", - "missingTemplate": "Ontbrekende sjabloon", + "missingTemplate": "Ongeldig knooppunt: knooppunt {{node}} van het soort {{type}} heeft een ontbrekend sjabloon (niet geïnstalleerd?)", "workflowDescription": "Korte beschrijving", "versionUnknown": " Versie onbekend", "noNodeSelected": "Geen knooppunt gekozen", @@ -563,7 +700,7 @@ "integer": "Geheel getal", "nodeTemplate": "Sjabloon knooppunt", "nodeOpacity": "Dekking knooppunt", - "unableToLoadWorkflow": "Kan werkstroom niet valideren", + "unableToLoadWorkflow": "Fout bij laden werkstroom", "snapToGrid": "Lijn uit op raster", "noFieldsLinearview": "Geen velden toegevoegd aan lineaire weergave", "nodeSearch": "Zoek naar knooppunten", @@ -614,11 +751,56 @@ "unknownField": "Onbekend veld", "colorCodeEdges": "Kleurgecodeerde randen", "unknownNode": "Onbekend knooppunt", - "mismatchedVersion": "Heeft niet-overeenkomende versie", + "mismatchedVersion": "Ongeldig knooppunt: knooppunt {{node}} van het soort {{type}} heeft een niet-overeenkomende versie (probeer het bij te werken?)", "addNodeToolTip": "Voeg knooppunt toe (Shift+A, spatie)", "loadingNodes": "Bezig met laden van knooppunten...", "snapToGridHelp": "Lijn knooppunten uit op raster bij verplaatsing", - "workflowSettings": "Instellingen werkstroomeditor" + "workflowSettings": "Instellingen werkstroomeditor", + "addLinearView": "Voeg toe aan lineaire weergave", + "nodePack": "Knooppuntpakket", + "unknownInput": "Onbekende invoer: {{name}}", + "sourceNodeFieldDoesNotExist": "Ongeldige rand: bron-/uitvoerveld {{node}}.{{field}} bestaat niet", + "collectionFieldType": "Verzameling {{name}}", + "deletedInvalidEdge": "Ongeldige hoek {{source}} -> {{target}} verwijderd", + "graph": "Grafiek", + "targetNodeDoesNotExist": "Ongeldige rand: doel-/invoerknooppunt {{node}} bestaat niet", + "resetToDefaultValue": "Herstel naar standaardwaarden", + "editMode": "Bewerk in Werkstroom-editor", + "showEdgeLabels": "Toon randlabels", + "showEdgeLabelsHelp": "Toon labels aan randen, waarmee de verbonden knooppunten mee worden aangegeven", + "clearWorkflowDesc2": "Je huidige werkstroom heeft niet-bewaarde wijzigingen.", + "unableToParseFieldType": "fout bij bepalen soort veld", + "sourceNodeDoesNotExist": "Ongeldige rand: bron-/uitvoerknooppunt {{node}} bestaat niet", + "unsupportedArrayItemType": "niet-ondersteunde soort van het array-onderdeel \"{{type}}\"", + "targetNodeFieldDoesNotExist": "Ongeldige rand: doel-/invoerveld {{node}}.{{field}} bestaat niet", + "reorderLinearView": "Herorden lineaire weergave", + "newWorkflowDesc": "Een nieuwe werkstroom aanmaken?", + "collectionOrScalarFieldType": "Verzameling|scalair {{name}}", + "newWorkflow": "Nieuwe werkstroom", + "unknownErrorValidatingWorkflow": "Onbekende fout bij valideren werkstroom", + "unsupportedAnyOfLength": "te veel union-leden ({{count}})", + "unknownOutput": "Onbekende uitvoer: {{name}}", + "viewMode": "Gebruik in lineaire weergave", + "unableToExtractSchemaNameFromRef": "fout bij het extraheren van de schemanaam via de ref", + "unsupportedMismatchedUnion": "niet-overeenkomende soort CollectionOrScalar met basissoorten {{firstType}} en {{secondType}}", + "unknownNodeType": "Onbekend soort knooppunt", + "edit": "Bewerk", + "updateAllNodes": "Werk knooppunten bij", + "allNodesUpdated": "Alle knooppunten bijgewerkt", + "nodeVersion": "Knooppuntversie", + "newWorkflowDesc2": "Je huidige werkstroom heeft niet-bewaarde wijzigingen.", + "clearWorkflow": "Maak werkstroom leeg", + "clearWorkflowDesc": "Deze werkstroom leegmaken en met een nieuwe beginnen?", + "inputFieldTypeParseError": "Fout bij bepalen van het soort invoerveld {{node}}.{{field}} ({{message}})", + "outputFieldTypeParseError": "Fout bij het bepalen van het soort uitvoerveld {{node}}.{{field}} ({{message}})", + "unableToExtractEnumOptions": "fout bij extraheren enumeratie-opties", + "unknownFieldType": "Soort $t(nodes.unknownField): {{type}}", + "unableToGetWorkflowVersion": "Fout bij ophalen schemaversie van werkstroom", + "betaDesc": "Deze uitvoering is in bèta. Totdat deze stabiel is kunnen er wijzigingen voorkomen gedurende app-updates die zaken kapotmaken. We zijn van plan om deze uitvoering op lange termijn te gaan ondersteunen.", + "prototypeDesc": "Deze uitvoering is een prototype. Er kunnen wijzigingen voorkomen gedurende app-updates die zaken kapotmaken. Deze kunnen op een willekeurig moment verwijderd worden.", + "noFieldsViewMode": "Deze werkstroom heeft geen geselecteerde velden om te tonen. Bekijk de volledige werkstroom om de waarden te configureren.", + "unableToUpdateNodes_one": "Fout bij bijwerken van {{count}} knooppunt", + "unableToUpdateNodes_other": "Fout bij bijwerken van {{count}} knooppunten" }, "controlnet": { "amult": "a_mult", @@ -691,9 +873,28 @@ "canny": "Canny", "depthZoeDescription": "Genereer diepteblad via Zoe", "hedDescription": "Herkenning van holistisch-geneste randen", - "setControlImageDimensions": "Stel afmetingen controle-afbeelding in op B/H", + "setControlImageDimensions": "Kopieer grootte naar B/H (optimaliseer voor model)", "scribble": "Krabbel", - "maxFaces": "Max. gezichten" + "maxFaces": "Max. gezichten", + "dwOpenpose": "DW Openpose", + "depthAnything": "Depth Anything", + "base": "Basis", + "hands": "Handen", + "selectCLIPVisionModel": "Selecteer een CLIP Vision-model", + "modelSize": "Modelgrootte", + "small": "Klein", + "large": "Groot", + "resizeSimple": "Wijzig grootte (eenvoudig)", + "beginEndStepPercentShort": "Begin-/eind-%", + "depthAnythingDescription": "Genereren dieptekaart d.m.v. de techniek Depth Anything", + "face": "Gezicht", + "body": "Lichaam", + "dwOpenposeDescription": "Schatting menselijke pose d.m.v. DW Openpose", + "ipAdapterMethod": "Methode", + "full": "Volledig", + "style": "Alleen stijl", + "composition": "Alleen samenstelling", + "setControlImageDimensionsForce": "Kopieer grootte naar B/H (negeer model)" }, "dynamicPrompts": { "seedBehaviour": { @@ -706,7 +907,10 @@ "maxPrompts": "Max. prompts", "promptsWithCount_one": "{{count}} prompt", "promptsWithCount_other": "{{count}} prompts", - "dynamicPrompts": "Dynamische prompts" + "dynamicPrompts": "Dynamische prompts", + "showDynamicPrompts": "Toon dynamische prompts", + "loading": "Genereren van dynamische prompts...", + "promptsPreview": "Voorvertoning prompts" }, "popovers": { "noiseUseCPU": { @@ -719,7 +923,7 @@ }, "paramScheduler": { "paragraphs": [ - "De planner bepaalt hoe ruis per iteratie wordt toegevoegd aan een afbeelding of hoe een monster wordt bijgewerkt op basis van de uitvoer van een model." + "De planner gebruikt gedurende het genereringsproces." ], "heading": "Planner" }, @@ -806,8 +1010,8 @@ }, "clipSkip": { "paragraphs": [ - "Kies hoeveel CLIP-modellagen je wilt overslaan.", - "Bepaalde modellen werken beter met bepaalde Overslaan CLIP-instellingen." + "Aantal over te slaan CLIP-modellagen.", + "Bepaalde modellen zijn beter geschikt met bepaalde Overslaan CLIP-instellingen." ], "heading": "Overslaan CLIP" }, @@ -991,17 +1195,26 @@ "denoisingStrength": "Sterkte ontruising", "refinermodel": "Verfijningsmodel", "posAestheticScore": "Positieve esthetische score", - "concatPromptStyle": "Plak prompt- en stijltekst aan elkaar", + "concatPromptStyle": "Koppelen van prompt en stijl", "loading": "Bezig met laden...", "steps": "Stappen", - "posStylePrompt": "Positieve-stijlprompt" + "posStylePrompt": "Positieve-stijlprompt", + "freePromptStyle": "Handmatige stijlprompt", + "refinerSteps": "Aantal stappen verfijner" }, "models": { "noMatchingModels": "Geen overeenkomend modellen", "loading": "bezig met laden", "noMatchingLoRAs": "Geen overeenkomende LoRA's", "noModelsAvailable": "Geen modellen beschikbaar", - "selectModel": "Kies een model" + "selectModel": "Kies een model", + "noLoRAsInstalled": "Geen LoRA's geïnstalleerd", + "noRefinerModelsInstalled": "Geen SDXL-verfijningsmodellen geïnstalleerd", + "defaultVAE": "Standaard-VAE", + "lora": "LoRA", + "esrganModel": "ESRGAN-model", + "addLora": "Voeg LoRA toe", + "concepts": "Concepten" }, "boards": { "autoAddBoard": "Voeg automatisch bord toe", @@ -1019,7 +1232,13 @@ "downloadBoard": "Download bord", "changeBoard": "Wijzig bord", "loading": "Bezig met laden...", - "clearSearch": "Maak zoekopdracht leeg" + "clearSearch": "Maak zoekopdracht leeg", + "deleteBoard": "Verwijder bord", + "deleteBoardAndImages": "Verwijder bord en afbeeldingen", + "deleteBoardOnly": "Verwijder alleen bord", + "deletedBoardsCannotbeRestored": "Verwijderde borden kunnen niet worden hersteld", + "movingImagesToBoard_one": "Verplaatsen van {{count}} afbeelding naar bord:", + "movingImagesToBoard_other": "Verplaatsen van {{count}} afbeeldingen naar bord:" }, "invocationCache": { "disable": "Schakel uit", @@ -1036,5 +1255,39 @@ "clear": "Wis", "maxCacheSize": "Max. grootte cache", "cacheSize": "Grootte cache" + }, + "accordions": { + "generation": { + "title": "Genereren" + }, + "image": { + "title": "Afbeelding" + }, + "advanced": { + "title": "Geavanceerd", + "options": "$t(accordions.advanced.title) Opties" + }, + "control": { + "title": "Besturing" + }, + "compositing": { + "title": "Samenstellen", + "coherenceTab": "Coherentiefase", + "infillTab": "Invullen" + } + }, + "hrf": { + "upscaleMethod": "Opschaalmethode", + "metadata": { + "strength": "Sterkte oplossing voor hoge resolutie", + "method": "Methode oplossing voor hoge resolutie", + "enabled": "Oplossing voor hoge resolutie ingeschakeld" + }, + "hrf": "Oplossing voor hoge resolutie", + "enableHrf": "Schakel oplossing in voor hoge resolutie" + }, + "prompt": { + "addPromptTrigger": "Voeg prompttrigger toe", + "compatibleEmbeddings": "Compatibele embeddings" } } diff --git a/invokeai/frontend/web/public/locales/ru.json b/invokeai/frontend/web/public/locales/ru.json index 7fa1b73e7a..03ff7eb706 100644 --- a/invokeai/frontend/web/public/locales/ru.json +++ b/invokeai/frontend/web/public/locales/ru.json @@ -1594,7 +1594,6 @@ "deleteAll": "Удалить всё", "addLayer": "Добавить слой", "moveToFront": "На передний план", - "toggleVisibility": "Переключить видимость слоя", "addPositivePrompt": "Добавить $t(common.positivePrompt)", "addIPAdapter": "Добавить $t(common.ipAdapter)", "regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)", diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index 30d8f41200..2d878d96e7 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -21,10 +21,10 @@ import i18n from 'i18n'; import { size } from 'lodash-es'; import { memo, useCallback, useEffect } from 'react'; import { ErrorBoundary } from 'react-error-boundary'; +import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo'; import AppErrorBoundaryFallback from './AppErrorBoundaryFallback'; import PreselectedImage from './PreselectedImage'; -import Toaster from './Toaster'; const DEFAULT_CONFIG = {}; @@ -46,6 +46,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => { useSocketIO(); useGlobalModifiersInit(); useGlobalHotkeys(); + useGetOpenAPISchemaQuery(); const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone(); @@ -94,7 +95,6 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => { - ); diff --git a/invokeai/frontend/web/src/app/components/AppErrorBoundaryFallback.tsx b/invokeai/frontend/web/src/app/components/AppErrorBoundaryFallback.tsx index d2992a8cd9..ced3037a40 100644 --- a/invokeai/frontend/web/src/app/components/AppErrorBoundaryFallback.tsx +++ b/invokeai/frontend/web/src/app/components/AppErrorBoundaryFallback.tsx @@ -1,5 +1,8 @@ -import { Button, Flex, Heading, Link, Text, useToast } from '@invoke-ai/ui-library'; +import { Button, Flex, Heading, Image, Link, Text } from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { toast } from 'features/toast/toast'; import newGithubIssueUrl from 'new-github-issue-url'; +import InvokeLogoYellow from 'public/assets/images/invoke-symbol-ylw-lrg.svg'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiArrowCounterClockwiseBold, PiArrowSquareOutBold, PiCopyBold } from 'react-icons/pi'; @@ -11,31 +14,39 @@ type Props = { }; const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => { - const toast = useToast(); const { t } = useTranslation(); + const isLocal = useAppSelector((s) => s.config.isLocal); const handleCopy = useCallback(() => { const text = JSON.stringify(serializeError(error), null, 2); navigator.clipboard.writeText(`\`\`\`\n${text}\n\`\`\``); toast({ - title: 'Error Copied', + id: 'ERROR_COPIED', + title: t('toast.errorCopied'), }); - }, [error, toast]); + }, [error, t]); - const url = useMemo( - () => - newGithubIssueUrl({ + const url = useMemo(() => { + if (isLocal) { + return newGithubIssueUrl({ user: 'invoke-ai', repo: 'InvokeAI', template: 'BUG_REPORT.yml', title: `[bug]: ${error.name}: ${error.message}`, - }), - [error.message, error.name] - ); + }); + } else { + return 'https://support.invoke.ai/support/tickets/new'; + } + }, [error.message, error.name, isLocal]); + return ( - {t('common.somethingWentWrong')} + + invoke-logo + {t('common.somethingWentWrong')} + + { {t('common.copyError')} - + diff --git a/invokeai/frontend/web/src/app/components/Toaster.ts b/invokeai/frontend/web/src/app/components/Toaster.ts deleted file mode 100644 index c86fd5060d..0000000000 --- a/invokeai/frontend/web/src/app/components/Toaster.ts +++ /dev/null @@ -1,44 +0,0 @@ -import { useToast } from '@invoke-ai/ui-library'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { addToast, clearToastQueue } from 'features/system/store/systemSlice'; -import type { MakeToastArg } from 'features/system/util/makeToast'; -import { makeToast } from 'features/system/util/makeToast'; -import { memo, useCallback, useEffect } from 'react'; - -/** - * Logical component. Watches the toast queue and makes toasts when the queue is not empty. - * @returns null - */ -const Toaster = () => { - const dispatch = useAppDispatch(); - const toastQueue = useAppSelector((s) => s.system.toastQueue); - const toast = useToast(); - useEffect(() => { - toastQueue.forEach((t) => { - toast(t); - }); - toastQueue.length > 0 && dispatch(clearToastQueue()); - }, [dispatch, toast, toastQueue]); - - return null; -}; - -/** - * Returns a function that can be used to make a toast. - * @example - * const toaster = useAppToaster(); - * toaster('Hello world!'); - * toaster({ title: 'Hello world!', status: 'success' }); - * @returns A function that can be used to make a toast. - * @see makeToast - * @see MakeToastArg - * @see UseToastOptions - */ -export const useAppToaster = () => { - const dispatch = useAppDispatch(); - const toaster = useCallback((arg: MakeToastArg) => dispatch(addToast(makeToast(arg))), [dispatch]); - - return toaster; -}; - -export default memo(Toaster); diff --git a/invokeai/frontend/web/src/app/hooks/useSocketIO.ts b/invokeai/frontend/web/src/app/hooks/useSocketIO.ts index aaa3b8f6f2..d3baf5f452 100644 --- a/invokeai/frontend/web/src/app/hooks/useSocketIO.ts +++ b/invokeai/frontend/web/src/app/hooks/useSocketIO.ts @@ -6,8 +6,8 @@ import { useAppDispatch } from 'app/store/storeHooks'; import type { MapStore } from 'nanostores'; import { atom, map } from 'nanostores'; import { useEffect, useMemo } from 'react'; +import { setEventListeners } from 'services/events/setEventListeners'; import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; -import { setEventListeners } from 'services/events/util/setEventListeners'; import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client'; import { io } from 'socket.io-client'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 0c0c8ed2bc..0fd2f1b79c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -35,28 +35,22 @@ import { addImageUploadedFulfilledListener } from 'app/store/middleware/listener import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected'; import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded'; import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged'; +import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings'; import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected'; import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected'; import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress'; -import { addGraphExecutionStateCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete'; import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete'; import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError'; -import { addInvocationRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError'; import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted'; import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall'; import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad'; import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged'; -import { addSessionRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError'; -import { addSocketSubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed'; -import { addSocketUnsubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed'; import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved'; import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested'; import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested'; import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested'; import type { AppDispatch, RootState } from 'app/store/store'; -import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings'; - export const listenerMiddleware = createListenerMiddleware(); export type AppStartListening = TypedStartListening; @@ -104,18 +98,13 @@ addCommitStagingAreaImageListener(startAppListening); // Socket.IO addGeneratorProgressEventListener(startAppListening); -addGraphExecutionStateCompleteEventListener(startAppListening); addInvocationCompleteEventListener(startAppListening); addInvocationErrorEventListener(startAppListening); addInvocationStartedEventListener(startAppListening); addSocketConnectedEventListener(startAppListening); addSocketDisconnectedEventListener(startAppListening); -addSocketSubscribedEventListener(startAppListening); -addSocketUnsubscribedEventListener(startAppListening); addModelLoadEventListener(startAppListening); addModelInstallEventListener(startAppListening); -addSessionRetrievalErrorEventListener(startAppListening); -addInvocationRetrievalErrorEventListener(startAppListening); addSocketQueueItemStatusChangedEventListener(startAppListening); addBulkDownloadListeners(startAppListening); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts index ae26531722..9095a08431 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts @@ -8,7 +8,7 @@ import { resetCanvas, setInitialCanvasImage, } from 'features/canvas/store/canvasSlice'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { queueApi } from 'services/api/endpoints/queue'; @@ -30,22 +30,20 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis req.reset(); if (canceled > 0) { log.debug(`Canceled ${canceled} canvas batches`); - dispatch( - addToast({ - title: t('queue.cancelBatchSucceeded'), - status: 'success', - }) - ); + toast({ + id: 'CANCEL_BATCH_SUCCEEDED', + title: t('queue.cancelBatchSucceeded'), + status: 'success', + }); } dispatch(canvasBatchIdsReset()); } catch { log.error('Failed to cancel canvas batches'); - dispatch( - addToast({ - title: t('queue.cancelBatchFailed'), - status: 'error', - }) - ); + toast({ + id: 'CANCEL_BATCH_FAILED', + title: t('queue.cancelBatchFailed'), + status: 'error', + }); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts index 68eda997b7..3f74bf9b61 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts @@ -1,8 +1,8 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { parseify } from 'common/util/serialize'; -import { toast } from 'common/util/toast'; import { zPydanticValidationError } from 'features/system/store/zodSchemas'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { truncate, upperFirst } from 'lodash-es'; import { queueApi } from 'services/api/endpoints/queue'; @@ -16,18 +16,15 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) = const arg = action.meta.arg.originalArgs; logger('queue').debug({ enqueueResult: parseify(response) }, 'Batch enqueued'); - if (!toast.isActive('batch-queued')) { - toast({ - id: 'batch-queued', - title: t('queue.batchQueued'), - description: t('queue.batchQueuedDesc', { - count: response.enqueued, - direction: arg.prepend ? t('queue.front') : t('queue.back'), - }), - duration: 1000, - status: 'success', - }); - } + toast({ + id: 'QUEUE_BATCH_SUCCEEDED', + title: t('queue.batchQueued'), + status: 'success', + description: t('queue.batchQueuedDesc', { + count: response.enqueued, + direction: arg.prepend ? t('queue.front') : t('queue.back'), + }), + }); }, }); @@ -40,9 +37,10 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) = if (!response) { toast({ + id: 'QUEUE_BATCH_FAILED', title: t('queue.batchFailedToQueue'), status: 'error', - description: 'Unknown Error', + description: t('common.unknownError'), }); logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue')); return; @@ -52,7 +50,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) = if (result.success) { result.data.data.detail.map((e) => { toast({ - id: 'batch-failed-to-queue', + id: 'QUEUE_BATCH_FAILED', title: truncate(upperFirst(e.msg), { length: 128 }), status: 'error', description: truncate( @@ -64,9 +62,10 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) = }); } else if (response.status !== 403) { toast({ + id: 'QUEUE_BATCH_FAILED', title: t('queue.batchFailedToQueue'), - description: t('common.unknownError'), status: 'error', + description: t('common.unknownError'), }); } logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue')); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx index 38a0fd7911..489f218370 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx @@ -1,13 +1,12 @@ -import type { UseToastOptions } from '@invoke-ai/ui-library'; import { ExternalLink } from '@invoke-ai/ui-library'; import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { toast } from 'common/util/toast'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; import { - socketBulkDownloadCompleted, - socketBulkDownloadFailed, + socketBulkDownloadComplete, + socketBulkDownloadError, socketBulkDownloadStarted, } from 'services/events/actions'; @@ -28,7 +27,6 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) = // Show the response message if it exists, otherwise show the default message description: action.payload.response || t('gallery.bulkDownloadRequestedDesc'), duration: null, - isClosable: true, }); }, }); @@ -40,9 +38,9 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) = // There isn't any toast to update if we get this event. toast({ + id: 'BULK_DOWNLOAD_REQUEST_FAILED', title: t('gallery.bulkDownloadRequestFailed'), - status: 'success', - isClosable: true, + status: 'error', }); }, }); @@ -56,7 +54,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) = }); startAppListening({ - actionCreator: socketBulkDownloadCompleted, + actionCreator: socketBulkDownloadComplete, effect: async (action) => { log.debug(action.payload.data, 'Bulk download preparation completed'); @@ -65,7 +63,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) = // TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first const url = `/api/v1/images/download/${bulk_download_item_name}`; - const toastOptions: UseToastOptions = { + toast({ id: bulk_download_item_name, title: t('gallery.bulkDownloadReady', 'Download ready'), status: 'success', @@ -77,38 +75,24 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) = /> ), duration: null, - isClosable: true, - }; - - if (toast.isActive(bulk_download_item_name)) { - toast.update(bulk_download_item_name, toastOptions); - } else { - toast(toastOptions); - } + }); }, }); startAppListening({ - actionCreator: socketBulkDownloadFailed, + actionCreator: socketBulkDownloadError, effect: async (action) => { log.debug(action.payload.data, 'Bulk download preparation failed'); const { bulk_download_item_name } = action.payload.data; - const toastOptions: UseToastOptions = { + toast({ id: bulk_download_item_name, title: t('gallery.bulkDownloadFailed'), status: 'error', description: action.payload.data.error, duration: null, - isClosable: true, - }; - - if (toast.isActive(bulk_download_item_name)) { - toast.update(bulk_download_item_name, toastOptions); - } else { - toast(toastOptions); - } + }); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard.ts index e1f4804d56..311dda3e2e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard.ts @@ -2,14 +2,14 @@ import { $logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { canvasCopiedToClipboard } from 'features/canvas/store/actions'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; -import { addToast } from 'features/system/store/systemSlice'; import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartListening) => { startAppListening({ actionCreator: canvasCopiedToClipboard, - effect: async (action, { dispatch, getState }) => { + effect: async (action, { getState }) => { const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' }); const state = getState(); @@ -19,22 +19,20 @@ export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartLi copyBlobToClipboard(blob); } catch (err) { moduleLog.error(String(err)); - dispatch( - addToast({ - title: t('toast.problemCopyingCanvas'), - description: t('toast.problemCopyingCanvasDesc'), - status: 'error', - }) - ); + toast({ + id: 'CANVAS_COPY_FAILED', + title: t('toast.problemCopyingCanvas'), + description: t('toast.problemCopyingCanvasDesc'), + status: 'error', + }); return; } - dispatch( - addToast({ - title: t('toast.canvasCopiedClipboard'), - status: 'success', - }) - ); + toast({ + id: 'CANVAS_COPY_SUCCEEDED', + title: t('toast.canvasCopiedClipboard'), + status: 'success', + }); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage.ts index 5b8150bd20..71e616b9ea 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage.ts @@ -3,13 +3,13 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { canvasDownloadedAsImage } from 'features/canvas/store/actions'; import { downloadBlob } from 'features/canvas/util/downloadBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartListening) => { startAppListening({ actionCreator: canvasDownloadedAsImage, - effect: async (action, { dispatch, getState }) => { + effect: async (action, { getState }) => { const moduleLog = $logger.get().child({ namespace: 'canvasSavedToGalleryListener' }); const state = getState(); @@ -18,18 +18,17 @@ export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartLi blob = await getBaseLayerBlob(state); } catch (err) { moduleLog.error(String(err)); - dispatch( - addToast({ - title: t('toast.problemDownloadingCanvas'), - description: t('toast.problemDownloadingCanvasDesc'), - status: 'error', - }) - ); + toast({ + id: 'CANVAS_DOWNLOAD_FAILED', + title: t('toast.problemDownloadingCanvas'), + description: t('toast.problemDownloadingCanvasDesc'), + status: 'error', + }); return; } downloadBlob(blob, 'canvas.png'); - dispatch(addToast({ title: t('toast.canvasDownloaded'), status: 'success' })); + toast({ id: 'CANVAS_DOWNLOAD_SUCCEEDED', title: t('toast.canvasDownloaded'), status: 'success' }); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet.ts index 55392ebff4..2aa1f52d6c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet.ts @@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { canvasImageToControlAdapter } from 'features/canvas/store/actions'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; @@ -20,13 +20,12 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi blob = await getBaseLayerBlob(state, true); } catch (err) { log.error(String(err)); - dispatch( - addToast({ - title: t('toast.problemSavingCanvas'), - description: t('toast.problemSavingCanvasDesc'), - status: 'error', - }) - ); + toast({ + id: 'PROBLEM_SAVING_CANVAS', + title: t('toast.problemSavingCanvas'), + description: t('toast.problemSavingCanvasDesc'), + status: 'error', + }); return; } @@ -43,7 +42,7 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi crop_visible: false, postUploadAction: { type: 'TOAST', - toastOptions: { title: t('toast.canvasSentControlnetAssets') }, + title: t('toast.canvasSentControlnetAssets'), }, }) ).unwrap(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery.ts index af0c3878fc..454342b997 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery.ts @@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { canvasMaskSavedToGallery } from 'features/canvas/store/actions'; import { getCanvasData } from 'features/canvas/util/getCanvasData'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; @@ -29,13 +29,12 @@ export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartL if (!maskBlob) { log.error('Problem getting mask layer blob'); - dispatch( - addToast({ - title: t('toast.problemSavingMask'), - description: t('toast.problemSavingMaskDesc'), - status: 'error', - }) - ); + toast({ + id: 'PROBLEM_SAVING_MASK', + title: t('toast.problemSavingMask'), + description: t('toast.problemSavingMaskDesc'), + status: 'error', + }); return; } @@ -52,7 +51,7 @@ export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartL crop_visible: true, postUploadAction: { type: 'TOAST', - toastOptions: { title: t('toast.maskSavedAssets') }, + title: t('toast.maskSavedAssets'), }, }) ); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet.ts index 569b4badc7..2e6ca61d8a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet.ts @@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { canvasMaskToControlAdapter } from 'features/canvas/store/actions'; import { getCanvasData } from 'features/canvas/util/getCanvasData'; import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; @@ -30,13 +30,12 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis if (!maskBlob) { log.error('Problem getting mask layer blob'); - dispatch( - addToast({ - title: t('toast.problemImportingMask'), - description: t('toast.problemImportingMaskDesc'), - status: 'error', - }) - ); + toast({ + id: 'PROBLEM_IMPORTING_MASK', + title: t('toast.problemImportingMask'), + description: t('toast.problemImportingMaskDesc'), + status: 'error', + }); return; } @@ -53,7 +52,7 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis crop_visible: false, postUploadAction: { type: 'TOAST', - toastOptions: { title: t('toast.maskSentControlnetAssets') }, + title: t('toast.maskSentControlnetAssets'), }, }) ).unwrap(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts index 71b0e62b44..9ae6de2e76 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts @@ -4,7 +4,7 @@ import { canvasMerged } from 'features/canvas/store/actions'; import { $canvasBaseLayer } from 'features/canvas/store/canvasNanostore'; import { setMergedCanvas } from 'features/canvas/store/canvasSlice'; import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; @@ -17,13 +17,12 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) => if (!blob) { moduleLog.error('Problem getting base layer blob'); - dispatch( - addToast({ - title: t('toast.problemMergingCanvas'), - description: t('toast.problemMergingCanvasDesc'), - status: 'error', - }) - ); + toast({ + id: 'PROBLEM_MERGING_CANVAS', + title: t('toast.problemMergingCanvas'), + description: t('toast.problemMergingCanvasDesc'), + status: 'error', + }); return; } @@ -31,13 +30,12 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) => if (!canvasBaseLayer) { moduleLog.error('Problem getting canvas base layer'); - dispatch( - addToast({ - title: t('toast.problemMergingCanvas'), - description: t('toast.problemMergingCanvasDesc'), - status: 'error', - }) - ); + toast({ + id: 'PROBLEM_MERGING_CANVAS', + title: t('toast.problemMergingCanvas'), + description: t('toast.problemMergingCanvasDesc'), + status: 'error', + }); return; } @@ -54,7 +52,7 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) => is_intermediate: true, postUploadAction: { type: 'TOAST', - toastOptions: { title: t('toast.canvasMerged') }, + title: t('toast.canvasMerged'), }, }) ).unwrap(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts index e3ba988886..71586b5f6e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts @@ -1,8 +1,9 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import { parseify } from 'common/util/serialize'; import { canvasSavedToGallery } from 'features/canvas/store/actions'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; @@ -18,13 +19,12 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe blob = await getBaseLayerBlob(state); } catch (err) { log.error(String(err)); - dispatch( - addToast({ - title: t('toast.problemSavingCanvas'), - description: t('toast.problemSavingCanvasDesc'), - status: 'error', - }) - ); + toast({ + id: 'CANVAS_SAVE_FAILED', + title: t('toast.problemSavingCanvas'), + description: t('toast.problemSavingCanvasDesc'), + status: 'error', + }); return; } @@ -41,7 +41,10 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe crop_visible: true, postUploadAction: { type: 'TOAST', - toastOptions: { title: t('toast.canvasSavedGallery') }, + title: t('toast.canvasSavedGallery'), + }, + metadata: { + _canvas_objects: parseify(state.canvas.layerState.objects), }, }) ); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts index 2a59cc0317..581146c25c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts @@ -14,8 +14,9 @@ import { } from 'features/controlLayers/store/controlLayersSlice'; import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters'; import { isImageOutput } from 'features/nodes/types/common'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; +import { isEqual } from 'lodash-es'; import { getImageDTO } from 'services/api/endpoints/images'; import { queueApi } from 'services/api/endpoints/queue'; import type { BatchConfig } from 'services/api/types'; @@ -47,8 +48,10 @@ const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batc export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => { startAppListening({ matcher, - effect: async (action, { dispatch, getState, cancelActiveListeners, delay, take, signal }) => { + effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take, signal }) => { const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId; + const state = getState(); + const originalState = getOriginalState(); // Cancel any in-progress instances of this listener cancelActiveListeners(); @@ -57,21 +60,33 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni // Delay before starting actual work await delay(DEBOUNCE_MS); - // Double-check that we are still eligible for processing - const state = getState(); const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId); - // If we have no image or there is no processor config, bail if (!layer) { return; } + // We should only process if the processor settings or image have changed + const originalLayer = originalState.controlLayers.present.layers + .filter(isControlAdapterLayer) + .find((l) => l.id === layerId); + const originalImage = originalLayer?.controlAdapter.image; + const originalConfig = originalLayer?.controlAdapter.processorConfig; + const image = layer.controlAdapter.image; const config = layer.controlAdapter.processorConfig; + if (isEqual(config, originalConfig) && isEqual(image, originalImage)) { + // Neither config nor image have changed, we can bail + return; + } + if (!image || !config) { - // The user has reset the image or config, so we should clear the processed image + // - If we have no image, we have nothing to process + // - If we have no processor config, we have nothing to process + // Clear the processed image and bail dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null })); + return; } // At this point, the user has stopped fiddling with the processor settings and there is a processor selected. @@ -81,8 +96,8 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId); } - // @ts-expect-error: TS isn't able to narrow the typing of buildNode and `config` will error... - const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config); + // TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now + const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never); const enqueueBatchArg: BatchConfig = { prepend: true, batch: { @@ -118,8 +133,8 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni const [invocationCompleteAction] = await take( (action): action is ReturnType => socketInvocationComplete.match(action) && - action.payload.data.queue_batch_id === enqueueResult.batch.batch_id && - action.payload.data.source_node_id === processorNode.id + action.payload.data.batch_id === enqueueResult.batch.batch_id && + action.payload.data.invocation_source_id === processorNode.id ); // We still have to check the output type @@ -159,12 +174,11 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni } } - dispatch( - addToast({ - title: t('queue.graphFailedToQueue'), - status: 'error', - }) - ); + toast({ + id: 'GRAPH_QUEUE_FAILED', + title: t('queue.graphFailedToQueue'), + status: 'error', + }); } } finally { req.reset(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts index 0055866aa7..1e485b31d5 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts @@ -10,7 +10,7 @@ import { } from 'features/controlAdapters/store/controlAdaptersSlice'; import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; import { isImageOutput } from 'features/nodes/types/common'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; import { queueApi } from 'services/api/endpoints/queue'; @@ -69,8 +69,8 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL const [invocationCompleteAction] = await take( (action): action is ReturnType => socketInvocationComplete.match(action) && - action.payload.data.queue_batch_id === enqueueResult.batch.batch_id && - action.payload.data.source_node_id === nodeId + action.payload.data.batch_id === enqueueResult.batch.batch_id && + action.payload.data.invocation_source_id === nodeId ); // We still have to check the output type @@ -108,12 +108,11 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL } } - dispatch( - addToast({ - title: t('queue.graphFailedToQueue'), - status: 'error', - }) - ); + toast({ + id: 'GRAPH_QUEUE_FAILED', + title: t('queue.graphFailedToQueue'), + status: 'error', + }); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index d5d74bf668..cd5304c32b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -1,4 +1,3 @@ -import type { UseToastOptions } from '@invoke-ai/ui-library'; import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; @@ -14,7 +13,7 @@ import { } from 'features/controlLayers/store/controlLayersSlice'; import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { omit } from 'lodash-es'; import { boardsApi } from 'services/api/endpoints/boards'; @@ -42,16 +41,17 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis return; } - const DEFAULT_UPLOADED_TOAST: UseToastOptions = { + const DEFAULT_UPLOADED_TOAST = { + id: 'IMAGE_UPLOADED', title: t('toast.imageUploaded'), status: 'success', - }; + } as const; // default action - just upload and alert user if (postUploadAction?.type === 'TOAST') { - const { toastOptions } = postUploadAction; if (!autoAddBoardId || autoAddBoardId === 'none') { - dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions })); + const title = postUploadAction.title || DEFAULT_UPLOADED_TOAST.title; + toast({ ...DEFAULT_UPLOADED_TOAST, title }); } else { // Add this image to the board dispatch( @@ -70,24 +70,20 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis ? `${t('toast.addedToBoard')} ${board.board_name}` : `${t('toast.addedToBoard')} ${autoAddBoardId}`; - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description, - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description, + }); } return; } if (postUploadAction?.type === 'SET_CANVAS_INITIAL_IMAGE') { dispatch(setInitialCanvasImage(imageDTO, selectOptimalDimension(state))); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: t('toast.setAsCanvasInitialImage'), - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description: t('toast.setAsCanvasInitialImage'), + }); return; } @@ -105,68 +101,56 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis controlImage: imageDTO.image_name, }) ); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: t('toast.setControlImage'), - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description: t('toast.setControlImage'), + }); return; } if (postUploadAction?.type === 'SET_CA_LAYER_IMAGE') { const { layerId } = postUploadAction; dispatch(caLayerImageChanged({ layerId, imageDTO })); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: t('toast.setControlImage'), - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description: t('toast.setControlImage'), + }); } if (postUploadAction?.type === 'SET_IPA_LAYER_IMAGE') { const { layerId } = postUploadAction; dispatch(ipaLayerImageChanged({ layerId, imageDTO })); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: t('toast.setControlImage'), - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description: t('toast.setControlImage'), + }); } if (postUploadAction?.type === 'SET_RG_LAYER_IP_ADAPTER_IMAGE') { const { layerId, ipAdapterId } = postUploadAction; dispatch(rgLayerIPAdapterImageChanged({ layerId, ipAdapterId, imageDTO })); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: t('toast.setControlImage'), - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description: t('toast.setControlImage'), + }); } if (postUploadAction?.type === 'SET_II_LAYER_IMAGE') { const { layerId } = postUploadAction; dispatch(iiLayerImageChanged({ layerId, imageDTO })); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: t('toast.setControlImage'), - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description: t('toast.setControlImage'), + }); } if (postUploadAction?.type === 'SET_NODES_IMAGE') { const { nodeId, fieldName } = postUploadAction; dispatch(fieldImageValueChanged({ nodeId, fieldName, value: imageDTO })); - dispatch( - addToast({ - ...DEFAULT_UPLOADED_TOAST, - description: `${t('toast.setNodeField')} ${fieldName}`, - }) - ); + toast({ + ...DEFAULT_UPLOADED_TOAST, + description: `${t('toast.setNodeField')} ${fieldName}`, + }); return; } }, @@ -174,7 +158,7 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis startAppListening({ matcher: imagesApi.endpoints.uploadImage.matchRejected, - effect: (action, { dispatch }) => { + effect: (action) => { const log = logger('images'); const sanitizedData = { arg: { @@ -183,13 +167,11 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis }, }; log.error({ ...sanitizedData }, 'Image upload failed'); - dispatch( - addToast({ - title: t('toast.imageUploadFailed'), - description: action.error.message, - status: 'error', - }) - ); + toast({ + title: t('toast.imageUploadFailed'), + description: action.error.message, + status: 'error', + }); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index bc049cf498..239a5b863d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -8,8 +8,7 @@ import { loraRemoved } from 'features/lora/store/loraSlice'; import { modelSelected } from 'features/parameters/store/actions'; import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice'; import { zParameterModel } from 'features/parameters/types/parameterSchemas'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { forEach } from 'lodash-es'; @@ -60,16 +59,14 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = }); if (modelsCleared > 0) { - dispatch( - addToast( - makeToast({ - title: t('toast.baseModelChangedCleared', { - count: modelsCleared, - }), - status: 'warning', - }) - ) - ); + toast({ + id: 'BASE_MODEL_CHANGED', + title: t('toast.baseModelChanged'), + description: t('toast.baseModelChangedCleared', { + count: modelsCleared, + }), + status: 'warning', + }); } } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts index 61a978d576..415c359d70 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts @@ -19,8 +19,7 @@ import { isParameterWidth, zParameterVAEModel, } from 'features/parameters/types/parameterSchemas'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models'; import { isNonRefinerMainModelConfig } from 'services/api/types'; @@ -109,7 +108,7 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni } } - dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) }))); + toast({ id: 'PARAMETER_SET', title: t('toast.parameterSet', { parameter: 'Default settings' }) }); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts index 2dd598396a..08ad830ba4 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts @@ -1,7 +1,8 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { deepClone } from 'common/util/deepClone'; -import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState'; +import { parseify } from 'common/util/serialize'; +import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; import { socketGeneratorProgress } from 'services/events/actions'; @@ -11,13 +12,14 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis startAppListening({ actionCreator: socketGeneratorProgress, effect: (action) => { - log.trace(action.payload, `Generator progress`); - const { source_node_id, step, total_steps, progress_image } = action.payload.data; - const nes = deepClone($nodeExecutionStates.get()[source_node_id]); + log.trace(parseify(action.payload), `Generator progress`); + const { invocation_source_id, step, total_steps, progress_image } = action.payload.data; + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); if (nes) { nes.status = zNodeStatus.enum.IN_PROGRESS; nes.progress = (step + 1) / total_steps; nes.progressImage = progress_image ?? null; + upsertExecutionState(nes.nodeId, nes); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts deleted file mode 100644 index 5221679232..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { logger } from 'app/logging/logger'; -import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { socketGraphExecutionStateComplete } from 'services/events/actions'; - -const log = logger('socketio'); - -export const addGraphExecutionStateCompleteEventListener = (startAppListening: AppStartListening) => { - startAppListening({ - actionCreator: socketGraphExecutionStateComplete, - effect: (action) => { - log.debug(action.payload, 'Session complete'); - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index 06dc08d846..1a04f9493a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -29,12 +29,12 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi actionCreator: socketInvocationComplete, effect: async (action, { dispatch, getState }) => { const { data } = action.payload; - log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`); + log.debug({ data: parseify(data) }, `Invocation complete (${data.invocation.type})`); - const { result, node, queue_batch_id, source_node_id } = data; + const { result, invocation_source_id } = data; // This complete event has an associated image output - if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) { - const { image_name } = result.image; + if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation.type)) { + const { image_name } = data.result.image; const { canvas, gallery } = getState(); // This populates the `getImageDTO` cache @@ -48,7 +48,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi imageDTORequest.unsubscribe(); // Add canvas images to the staging area - if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) { + if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) { dispatch(addImageToStagingArea(imageDTO)); } @@ -114,7 +114,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi } } - const nes = deepClone($nodeExecutionStates.get()[source_node_id]); + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); if (nes) { nes.status = zNodeStatus.enum.COMPLETED; if (nes.progress !== null) { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts index ce26c4dd7d..b34f34a079 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts @@ -1,6 +1,7 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { deepClone } from 'common/util/deepClone'; +import { parseify } from 'common/util/serialize'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; import { socketInvocationError } from 'services/events/actions'; @@ -11,14 +12,18 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe startAppListening({ actionCreator: socketInvocationError, effect: (action) => { - log.error(action.payload, `Invocation error (${action.payload.data.node.type})`); - const { source_node_id } = action.payload.data; - const nes = deepClone($nodeExecutionStates.get()[source_node_id]); + const { invocation_source_id, invocation, error_type, error_message, error_traceback } = action.payload.data; + log.error(parseify(action.payload), `Invocation error (${invocation.type})`); + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); if (nes) { nes.status = zNodeStatus.enum.FAILED; - nes.error = action.payload.data.error; nes.progress = null; nes.progressImage = null; + nes.error = { + error_type, + error_message, + error_traceback, + }; upsertExecutionState(nes.nodeId, nes); } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts deleted file mode 100644 index 44da4c0ddb..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { logger } from 'app/logging/logger'; -import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { socketInvocationRetrievalError } from 'services/events/actions'; - -const log = logger('socketio'); - -export const addInvocationRetrievalErrorEventListener = (startAppListening: AppStartListening) => { - startAppListening({ - actionCreator: socketInvocationRetrievalError, - effect: (action) => { - log.error(action.payload, `Invocation retrieval error (${action.payload.data.graph_execution_state_id})`); - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts index 9d6e0ac14d..7dae869ce2 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts @@ -1,6 +1,7 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { deepClone } from 'common/util/deepClone'; +import { parseify } from 'common/util/serialize'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; import { socketInvocationStarted } from 'services/events/actions'; @@ -11,9 +12,9 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis startAppListening({ actionCreator: socketInvocationStarted, effect: (action) => { - log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`); - const { source_node_id } = action.payload.data; - const nes = deepClone($nodeExecutionStates.get()[source_node_id]); + log.debug(parseify(action.payload), `Invocation started (${action.payload.data.invocation.type})`); + const { invocation_source_id } = action.payload.data; + const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]); if (nes) { nes.status = zNodeStatus.enum.IN_PROGRESS; upsertExecutionState(nes.nodeId, nes); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts index f474c2736b..7fafb8302c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts @@ -3,14 +3,14 @@ import { api, LIST_TAG } from 'services/api'; import { modelsApi } from 'services/api/endpoints/models'; import { socketModelInstallCancelled, - socketModelInstallCompleted, - socketModelInstallDownloading, + socketModelInstallComplete, + socketModelInstallDownloadProgress, socketModelInstallError, } from 'services/events/actions'; export const addModelInstallEventListener = (startAppListening: AppStartListening) => { startAppListening({ - actionCreator: socketModelInstallDownloading, + actionCreator: socketModelInstallDownloadProgress, effect: async (action, { dispatch }) => { const { bytes, total_bytes, id } = action.payload.data; @@ -29,7 +29,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin }); startAppListening({ - actionCreator: socketModelInstallCompleted, + actionCreator: socketModelInstallComplete, effect: (action, { dispatch }) => { const { id } = action.payload.data; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad.ts index 4f4ec7635e..0240fe219a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad.ts @@ -1,6 +1,6 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { socketModelLoadCompleted, socketModelLoadStarted } from 'services/events/actions'; +import { socketModelLoadComplete, socketModelLoadStarted } from 'services/events/actions'; const log = logger('socketio'); @@ -8,10 +8,11 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening) startAppListening({ actionCreator: socketModelLoadStarted, effect: (action) => { - const { model_config, submodel_type } = action.payload.data; - const { name, base, type } = model_config; + const { config, submodel_type } = action.payload.data; + const { name, base, type } = config; const extras: string[] = [base, type]; + if (submodel_type) { extras.push(submodel_type); } @@ -23,10 +24,10 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening) }); startAppListening({ - actionCreator: socketModelLoadCompleted, + actionCreator: socketModelLoadComplete, effect: (action) => { - const { model_config, submodel_type } = action.payload.data; - const { name, base, type } = model_config; + const { config, submodel_type } = action.payload.data; + const { name, base, type } = config; const extras: string[] = [base, type]; if (submodel_type) { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.tsx similarity index 55% rename from invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts rename to invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.tsx index 2adc529766..8a83609b3c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.tsx @@ -3,6 +3,8 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { deepClone } from 'common/util/deepClone'; import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; +import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription'; +import { toast } from 'features/toast/toast'; import { forEach } from 'lodash-es'; import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; import { socketQueueItemStatusChanged } from 'services/events/actions'; @@ -12,18 +14,38 @@ const log = logger('socketio'); export const addSocketQueueItemStatusChangedEventListener = (startAppListening: AppStartListening) => { startAppListening({ actionCreator: socketQueueItemStatusChanged, - effect: async (action, { dispatch }) => { + effect: async (action, { dispatch, getState }) => { // we've got new status for the queue item, batch and queue - const { queue_item, batch_status, queue_status } = action.payload.data; + const { + item_id, + session_id, + status, + started_at, + updated_at, + completed_at, + batch_status, + queue_status, + error_type, + error_message, + error_traceback, + } = action.payload.data; - log.debug(action.payload, `Queue item ${queue_item.item_id} status updated: ${queue_item.status}`); + log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`); // Update this specific queue item in the list of queue items (this is the queue item DTO, without the session) dispatch( queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => { queueItemsAdapter.updateOne(draft, { - id: String(queue_item.item_id), - changes: queue_item, + id: String(item_id), + changes: { + status, + started_at, + updated_at: updated_at ?? undefined, + completed_at: completed_at ?? undefined, + error_type, + error_message, + error_traceback, + }, }); }) ); @@ -43,23 +65,18 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening: queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status) ); - // Update the queue item status (this is the full queue item, including the session) - dispatch( - queueApi.util.updateQueryData('getQueueItem', queue_item.item_id, (draft) => { - if (!draft) { - return; - } - Object.assign(draft, queue_item); - }) - ); - // Invalidate caches for things we cannot update // TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again dispatch( - queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus']) + queueApi.util.invalidateTags([ + 'CurrentSessionQueueItem', + 'NextSessionQueueItem', + 'InvocationCacheStatus', + { type: 'SessionQueueItem', id: item_id }, + ]) ); - if (['in_progress'].includes(action.payload.data.queue_item.status)) { + if (status === 'in_progress') { forEach($nodeExecutionStates.get(), (nes) => { if (!nes) { return; @@ -72,6 +89,25 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening: clone.outputs = []; $nodeExecutionStates.setKey(clone.nodeId, clone); }); + } else if (status === 'failed' && error_type) { + const isLocal = getState().config.isLocal ?? true; + const sessionId = session_id; + + toast({ + id: `INVOCATION_ERROR_${error_type}`, + title: getTitleFromErrorType(error_type), + status: 'error', + duration: null, + updateDescription: isLocal, + description: ( + + ), + }); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts deleted file mode 100644 index a1a497dc08..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { logger } from 'app/logging/logger'; -import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { socketSessionRetrievalError } from 'services/events/actions'; - -const log = logger('socketio'); - -export const addSessionRetrievalErrorEventListener = (startAppListening: AppStartListening) => { - startAppListening({ - actionCreator: socketSessionRetrievalError, - effect: (action) => { - log.error(action.payload, `Session retrieval error (${action.payload.data.graph_execution_state_id})`); - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts deleted file mode 100644 index 48324cb652..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { logger } from 'app/logging/logger'; -import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { socketSubscribedSession } from 'services/events/actions'; - -const log = logger('socketio'); - -export const addSocketSubscribedEventListener = (startAppListening: AppStartListening) => { - startAppListening({ - actionCreator: socketSubscribedSession, - effect: (action) => { - log.debug(action.payload, 'Subscribed'); - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts deleted file mode 100644 index 7a76a809d6..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts +++ /dev/null @@ -1,13 +0,0 @@ -import { logger } from 'app/logging/logger'; -import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { socketUnsubscribedSession } from 'services/events/actions'; -const log = logger('socketio'); - -export const addSocketUnsubscribedEventListener = (startAppListening: AppStartListening) => { - startAppListening({ - actionCreator: socketUnsubscribedSession, - effect: (action) => { - log.debug(action.payload, 'Unsubscribed'); - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts index 6816e25bc1..6c4c2a9df1 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts @@ -1,6 +1,6 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { stagingAreaImageSaved } from 'features/canvas/store/actions'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; @@ -29,15 +29,14 @@ export const addStagingAreaImageSavedListener = (startAppListening: AppStartList }) ); } - dispatch(addToast({ title: t('toast.imageSaved'), status: 'success' })); + toast({ id: 'IMAGE_SAVED', title: t('toast.imageSaved'), status: 'success' }); } catch (error) { - dispatch( - addToast({ - title: t('toast.imageSavingFailed'), - description: (error as Error)?.message, - status: 'error', - }) - ); + toast({ + id: 'IMAGE_SAVE_FAILED', + title: t('toast.imageSavingFailed'), + description: (error as Error)?.message, + status: 'error', + }); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts index 63d960b406..07df2a4f42 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts @@ -1,12 +1,11 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { updateAllNodesRequested } from 'features/nodes/store/actions'; -import { $templates, nodeReplaced } from 'features/nodes/store/nodesSlice'; +import { $templates, nodesChanged } from 'features/nodes/store/nodesSlice'; import { NodeUpdateError } from 'features/nodes/types/error'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartListening) => { @@ -31,7 +30,12 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi } try { const updatedNode = updateNode(node, template); - dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode })); + dispatch( + nodesChanged([ + { type: 'remove', id: updatedNode.id }, + { type: 'add', item: updatedNode }, + ]) + ); } catch (e) { if (e instanceof NodeUpdateError) { unableToUpdateCount++; @@ -45,24 +49,18 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi count: unableToUpdateCount, }) ); - dispatch( - addToast( - makeToast({ - title: t('nodes.unableToUpdateNodes', { - count: unableToUpdateCount, - }), - }) - ) - ); + toast({ + id: 'UNABLE_TO_UPDATE_NODES', + title: t('nodes.unableToUpdateNodes', { + count: unableToUpdateCount, + }), + }); } else { - dispatch( - addToast( - makeToast({ - title: t('nodes.allNodesUpdated'), - status: 'success', - }) - ) - ); + toast({ + id: 'ALL_NODES_UPDATED', + title: t('nodes.allNodesUpdated'), + status: 'success', + }); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts index ff5d5f24be..ce480a3573 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts @@ -4,7 +4,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware' import { parseify } from 'common/util/serialize'; import { buildAdHocUpscaleGraph } from 'features/nodes/util/graph/buildAdHocUpscaleGraph'; import { createIsAllowedToUpscaleSelector } from 'features/parameters/hooks/useIsAllowedToUpscale'; -import { addToast } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { queueApi } from 'services/api/endpoints/queue'; import type { BatchConfig, ImageDTO } from 'services/api/types'; @@ -29,12 +29,11 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening { imageDTO }, t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge') // should never coalesce ); - dispatch( - addToast({ - title: t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge'), // should never coalesce - status: 'error', - }) - ); + toast({ + id: 'NOT_ALLOWED_TO_UPSCALE', + title: t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge'), // should never coalesce + status: 'error', + }); return; } @@ -65,12 +64,11 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening if (error instanceof Object && 'status' in error && error.status === 403) { return; } else { - dispatch( - addToast({ - title: t('queue.graphFailedToQueue'), - status: 'error', - }) - ); + toast({ + id: 'GRAPH_QUEUE_FAILED', + title: t('queue.graphFailedToQueue'), + status: 'error', + }); } } }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts index a680bbca97..e6fc5a526a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts @@ -8,23 +8,23 @@ import type { Templates } from 'features/nodes/store/types'; import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error'; import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow'; import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; +import { checkBoardAccess, checkImageAccess, checkModelAccess } from 'services/api/hooks/accessChecks'; import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types'; import { z } from 'zod'; import { fromZodError } from 'zod-validation-error'; -const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => { +const getWorkflow = async (data: GraphAndWorkflowResponse, templates: Templates) => { if (data.workflow) { // Prefer to load the workflow if it's available - it has more information const parsed = JSON.parse(data.workflow); - return validateWorkflow(parsed, templates); + return await validateWorkflow(parsed, templates, checkImageAccess, checkBoardAccess, checkModelAccess); } else if (data.graph) { // Else we fall back on the graph, using the graphToWorkflow function to convert and do layout const parsed = JSON.parse(data.graph); const workflow = graphToWorkflow(parsed as NonNullableGraph, true); - return validateWorkflow(workflow, templates); + return await validateWorkflow(workflow, templates, checkImageAccess, checkBoardAccess, checkModelAccess); } else { throw new Error('No workflow or graph provided'); } @@ -33,13 +33,13 @@ const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => { export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => { startAppListening({ actionCreator: workflowLoadRequested, - effect: (action, { dispatch }) => { + effect: async (action, { dispatch }) => { const log = logger('nodes'); const { data, asCopy } = action.payload; const nodeTemplates = $templates.get(); try { - const { workflow, warnings } = getWorkflow(data, nodeTemplates); + const { workflow, warnings } = await getWorkflow(data, nodeTemplates); if (asCopy) { // If we're loading a copy, we need to remove the ID so that the backend will create a new workflow @@ -48,23 +48,18 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList dispatch(workflowLoaded(workflow)); if (!warnings.length) { - dispatch( - addToast( - makeToast({ - title: t('toast.workflowLoaded'), - status: 'success', - }) - ) - ); + toast({ + id: 'WORKFLOW_LOADED', + title: t('toast.workflowLoaded'), + status: 'success', + }); } else { - dispatch( - addToast( - makeToast({ - title: t('toast.loadedWithWarnings'), - status: 'warning', - }) - ) - ); + toast({ + id: 'WORKFLOW_LOADED', + title: t('toast.loadedWithWarnings'), + status: 'warning', + }); + warnings.forEach(({ message, ...rest }) => { log.warn(rest, message); }); @@ -77,54 +72,42 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList if (e instanceof WorkflowVersionError) { // The workflow version was not recognized in the valid list of versions log.error({ error: parseify(e) }, e.message); - dispatch( - addToast( - makeToast({ - title: t('nodes.unableToValidateWorkflow'), - status: 'error', - description: e.message, - }) - ) - ); + toast({ + id: 'UNABLE_TO_VALIDATE_WORKFLOW', + title: t('nodes.unableToValidateWorkflow'), + status: 'error', + description: e.message, + }); } else if (e instanceof WorkflowMigrationError) { // There was a problem migrating the workflow to the latest version log.error({ error: parseify(e) }, e.message); - dispatch( - addToast( - makeToast({ - title: t('nodes.unableToValidateWorkflow'), - status: 'error', - description: e.message, - }) - ) - ); + toast({ + id: 'UNABLE_TO_VALIDATE_WORKFLOW', + title: t('nodes.unableToValidateWorkflow'), + status: 'error', + description: e.message, + }); } else if (e instanceof z.ZodError) { // There was a problem validating the workflow itself const { message } = fromZodError(e, { prefix: t('nodes.workflowValidation'), }); log.error({ error: parseify(e) }, message); - dispatch( - addToast( - makeToast({ - title: t('nodes.unableToValidateWorkflow'), - status: 'error', - description: message, - }) - ) - ); + toast({ + id: 'UNABLE_TO_VALIDATE_WORKFLOW', + title: t('nodes.unableToValidateWorkflow'), + status: 'error', + description: message, + }); } else { // Some other error occurred log.error({ error: parseify(e) }, t('nodes.unknownErrorValidatingWorkflow')); - dispatch( - addToast( - makeToast({ - title: t('nodes.unableToValidateWorkflow'), - status: 'error', - description: t('nodes.unknownErrorValidatingWorkflow'), - }) - ) - ); + toast({ + id: 'UNABLE_TO_VALIDATE_WORKFLOW', + title: t('nodes.unableToValidateWorkflow'), + status: 'error', + description: t('nodes.unknownErrorValidatingWorkflow'), + }); } } }, diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index 4982dbb83f..21636ada49 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -74,6 +74,7 @@ export type AppConfig = { maxUpscalePixels?: number; metadataFetchDebounce?: number; workflowFetchDebounce?: number; + isLocal?: boolean; sd: { defaultModel?: string; disabledControlNetModels: string[]; diff --git a/invokeai/frontend/web/src/common/hooks/useCopyImageToClipboard.ts b/invokeai/frontend/web/src/common/hooks/useCopyImageToClipboard.ts index ef9db44a9d..233b841034 100644 --- a/invokeai/frontend/web/src/common/hooks/useCopyImageToClipboard.ts +++ b/invokeai/frontend/web/src/common/hooks/useCopyImageToClipboard.ts @@ -1,11 +1,10 @@ -import { useAppToaster } from 'app/components/Toaster'; import { useImageUrlToBlob } from 'common/hooks/useImageUrlToBlob'; import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard'; +import { toast } from 'features/toast/toast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; export const useCopyImageToClipboard = () => { - const toaster = useAppToaster(); const { t } = useTranslation(); const imageUrlToBlob = useImageUrlToBlob(); @@ -16,12 +15,11 @@ export const useCopyImageToClipboard = () => { const copyImageToClipboard = useCallback( async (image_url: string) => { if (!isClipboardAPIAvailable) { - toaster({ + toast({ + id: 'PROBLEM_COPYING_IMAGE', title: t('toast.problemCopyingImage'), description: "Your browser doesn't support the Clipboard API.", status: 'error', - duration: 2500, - isClosable: true, }); } try { @@ -33,23 +31,21 @@ export const useCopyImageToClipboard = () => { copyBlobToClipboard(blob); - toaster({ + toast({ + id: 'IMAGE_COPIED', title: t('toast.imageCopied'), status: 'success', - duration: 2500, - isClosable: true, }); } catch (err) { - toaster({ + toast({ + id: 'PROBLEM_COPYING_IMAGE', title: t('toast.problemCopyingImage'), description: String(err), status: 'error', - duration: 2500, - isClosable: true, }); } }, - [imageUrlToBlob, isClipboardAPIAvailable, t, toaster] + [imageUrlToBlob, isClipboardAPIAvailable, t] ); return { isClipboardAPIAvailable, copyImageToClipboard }; diff --git a/invokeai/frontend/web/src/common/hooks/useDownloadImage.ts b/invokeai/frontend/web/src/common/hooks/useDownloadImage.ts index 26a17e1d0c..ede247b9fb 100644 --- a/invokeai/frontend/web/src/common/hooks/useDownloadImage.ts +++ b/invokeai/frontend/web/src/common/hooks/useDownloadImage.ts @@ -1,13 +1,12 @@ import { useStore } from '@nanostores/react'; -import { useAppToaster } from 'app/components/Toaster'; import { $authToken } from 'app/store/nanostores/authToken'; import { useAppDispatch } from 'app/store/storeHooks'; import { imageDownloaded } from 'features/gallery/store/actions'; +import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; export const useDownloadImage = () => { - const toaster = useAppToaster(); const { t } = useTranslation(); const dispatch = useAppDispatch(); const authToken = useStore($authToken); @@ -37,16 +36,15 @@ export const useDownloadImage = () => { window.URL.revokeObjectURL(url); dispatch(imageDownloaded()); } catch (err) { - toaster({ + toast({ + id: 'PROBLEM_DOWNLOADING_IMAGE', title: t('toast.problemDownloadingImage'), description: String(err), status: 'error', - duration: 2500, - isClosable: true, }); } }, - [t, toaster, dispatch, authToken] + [t, dispatch, authToken] ); return { downloadImage }; diff --git a/invokeai/frontend/web/src/common/hooks/useFullscreenDropzone.ts b/invokeai/frontend/web/src/common/hooks/useFullscreenDropzone.ts index 0334294e98..5b1bf1f5b3 100644 --- a/invokeai/frontend/web/src/common/hooks/useFullscreenDropzone.ts +++ b/invokeai/frontend/web/src/common/hooks/useFullscreenDropzone.ts @@ -1,6 +1,6 @@ -import { useAppToaster } from 'app/components/Toaster'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { toast } from 'features/toast/toast'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { useCallback, useEffect, useState } from 'react'; import type { Accept, FileRejection } from 'react-dropzone'; @@ -26,7 +26,6 @@ const selectPostUploadAction = createMemoizedSelector(activeTabNameSelector, (ac export const useFullscreenDropzone = () => { const { t } = useTranslation(); - const toaster = useAppToaster(); const postUploadAction = useAppSelector(selectPostUploadAction); const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId); const [isHandlingUpload, setIsHandlingUpload] = useState(false); @@ -37,13 +36,14 @@ export const useFullscreenDropzone = () => { (rejection: FileRejection) => { setIsHandlingUpload(true); - toaster({ + toast({ + id: 'UPLOAD_FAILED', title: t('toast.uploadFailed'), description: rejection.errors.map((error) => error.message).join('\n'), status: 'error', }); }, - [t, toaster] + [t] ); const fileAcceptedCallback = useCallback( @@ -62,7 +62,8 @@ export const useFullscreenDropzone = () => { const onDrop = useCallback( (acceptedFiles: Array, fileRejections: Array) => { if (fileRejections.length > 1) { - toaster({ + toast({ + id: 'UPLOAD_FAILED', title: t('toast.uploadFailed'), description: t('toast.uploadFailedInvalidUploadDesc'), status: 'error', @@ -78,7 +79,7 @@ export const useFullscreenDropzone = () => { fileAcceptedCallback(file); }); }, - [t, toaster, fileAcceptedCallback, fileRejectionCallback] + [t, fileAcceptedCallback, fileRejectionCallback] ); const onDragOver = useCallback(() => { diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index 41d6f4607e..dbf3c41480 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -137,7 +137,7 @@ const createSelector = (templates: Templates) => if (l.controlAdapter.type === 't2i_adapter') { const multiple = model?.base === 'sdxl' ? 32 : 64; if (size.width % multiple !== 0 || size.height % multiple !== 0) { - problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions')); + problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions', { multiple })); } } } diff --git a/invokeai/frontend/web/src/common/util/toast.ts b/invokeai/frontend/web/src/common/util/toast.ts deleted file mode 100644 index ac61a4a12d..0000000000 --- a/invokeai/frontend/web/src/common/util/toast.ts +++ /dev/null @@ -1,6 +0,0 @@ -import { createStandaloneToast, theme, TOAST_OPTIONS } from '@invoke-ai/ui-library'; - -export const { toast } = createStandaloneToast({ - theme: theme, - defaultOptions: TOAST_OPTIONS.defaultOptions, -}); diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index a22f23d9d3..fbb6378166 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -613,7 +613,7 @@ export const canvasSlice = createSlice({ state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id); } - const queueItemStatus = action.payload.data.queue_item.status; + const queueItemStatus = action.payload.data.status; if (queueItemStatus === 'canceled' || queueItemStatus === 'failed') { resetStagingAreaIfEmpty(state); } diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayer.tsx index 36509ec1d3..9e71ad943c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayer.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CALayer/CALayer.tsx @@ -4,7 +4,7 @@ import { CALayerControlAdapterWrapper } from 'features/controlLayers/components/ import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton'; import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu'; import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle'; -import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; +import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper'; import { layerSelected, selectCALayerOrThrow } from 'features/controlLayers/store/controlLayersSlice'; import { memo, useCallback } from 'react'; @@ -26,7 +26,7 @@ export const CALayer = memo(({ layerId }: Props) => { return ( - + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/IILayer/IILayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/IILayer/IILayer.tsx index c6efd041ca..c53c4c7631 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/IILayer/IILayer.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/IILayer/IILayer.tsx @@ -5,7 +5,7 @@ import { InitialImagePreview } from 'features/controlLayers/components/IILayer/I import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton'; import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu'; import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle'; -import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; +import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper'; import { iiLayerDenoisingStrengthChanged, @@ -66,7 +66,7 @@ export const IILayer = memo(({ layerId }: Props) => { return ( - + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/IPALayer/IPALayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/IPALayer/IPALayer.tsx index 2077700104..e8f60c8d07 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/IPALayer/IPALayer.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/IPALayer/IPALayer.tsx @@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { IPALayerIPAdapterWrapper } from 'features/controlLayers/components/IPALayer/IPALayerIPAdapterWrapper'; import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton'; import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle'; -import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; +import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper'; import { layerSelected, selectIPALayerOrThrow } from 'features/controlLayers/store/controlLayersSlice'; import { memo, useCallback } from 'react'; @@ -22,7 +22,7 @@ export const IPALayer = memo(({ layerId }: Props) => { return ( - + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/LayerCommon/LayerVisibilityToggle.tsx b/invokeai/frontend/web/src/features/controlLayers/components/LayerCommon/LayerVisibilityToggle.tsx index d2dab39e36..227d74c35a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/LayerCommon/LayerVisibilityToggle.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/LayerCommon/LayerVisibilityToggle.tsx @@ -1,8 +1,8 @@ import { IconButton } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import { stopPropagation } from 'common/util/stopPropagation'; -import { useLayerIsVisible } from 'features/controlLayers/hooks/layerStateHooks'; -import { layerVisibilityToggled } from 'features/controlLayers/store/controlLayersSlice'; +import { useLayerIsEnabled } from 'features/controlLayers/hooks/layerStateHooks'; +import { layerIsEnabledToggled } from 'features/controlLayers/store/controlLayersSlice'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiCheckBold } from 'react-icons/pi'; @@ -11,21 +11,21 @@ type Props = { layerId: string; }; -export const LayerVisibilityToggle = memo(({ layerId }: Props) => { +export const LayerIsEnabledToggle = memo(({ layerId }: Props) => { const { t } = useTranslation(); const dispatch = useAppDispatch(); - const isVisible = useLayerIsVisible(layerId); + const isEnabled = useLayerIsEnabled(layerId); const onClick = useCallback(() => { - dispatch(layerVisibilityToggled(layerId)); + dispatch(layerIsEnabledToggled(layerId)); }, [dispatch, layerId]); return ( : undefined} + icon={isEnabled ? : undefined} onClick={onClick} colorScheme="base" onDoubleClick={stopPropagation} // double click expands the layer @@ -33,4 +33,4 @@ export const LayerVisibilityToggle = memo(({ layerId }: Props) => { ); }); -LayerVisibilityToggle.displayName = 'LayerVisibilityToggle'; +LayerIsEnabledToggle.displayName = 'LayerVisibilityToggle'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayer.tsx index a6bce5316e..cc331017d3 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayer.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayer.tsx @@ -6,7 +6,7 @@ import { AddPromptButtons } from 'features/controlLayers/components/AddPromptBut import { LayerDeleteButton } from 'features/controlLayers/components/LayerCommon/LayerDeleteButton'; import { LayerMenu } from 'features/controlLayers/components/LayerCommon/LayerMenu'; import { LayerTitle } from 'features/controlLayers/components/LayerCommon/LayerTitle'; -import { LayerVisibilityToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; +import { LayerIsEnabledToggle } from 'features/controlLayers/components/LayerCommon/LayerVisibilityToggle'; import { LayerWrapper } from 'features/controlLayers/components/LayerCommon/LayerWrapper'; import { isRegionalGuidanceLayer, @@ -55,7 +55,7 @@ export const RGLayer = memo(({ layerId }: Props) => { return ( - + {autoNegative === 'invert' && ( diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerNegativePrompt.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerNegativePrompt.tsx index ce02811ebf..ba02aa9242 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerNegativePrompt.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerNegativePrompt.tsx @@ -45,7 +45,6 @@ export const RGLayerNegativePrompt = memo(({ layerId }: Props) => { variant="darkFilled" paddingRight={30} fontSize="sm" - spellCheck={false} /> diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerPositivePrompt.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerPositivePrompt.tsx index 56d3953e25..6f85ea077c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerPositivePrompt.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RGLayer/RGLayerPositivePrompt.tsx @@ -45,7 +45,6 @@ export const RGLayerPositivePrompt = memo(({ layerId }: Props) => { variant="darkFilled" paddingRight={30} minH={28} - spellCheck={false} /> diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/layerStateHooks.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/layerStateHooks.ts index f2054779d4..21e49ba15e 100644 --- a/invokeai/frontend/web/src/features/controlLayers/hooks/layerStateHooks.ts +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/layerStateHooks.ts @@ -39,7 +39,7 @@ export const useLayerNegativePrompt = (layerId: string) => { return prompt; }; -export const useLayerIsVisible = (layerId: string) => { +export const useLayerIsEnabled = (layerId: string) => { const selectLayer = useMemo( () => createSelector(selectControlLayersSlice, (controlLayers) => { diff --git a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts index 32e29918ae..5fa8cc3dfb 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts @@ -139,7 +139,7 @@ export const controlLayersSlice = createSlice({ layerSelected: (state, action: PayloadAction) => { exclusivelySelectLayer(state, action.payload); }, - layerVisibilityToggled: (state, action: PayloadAction) => { + layerIsEnabledToggled: (state, action: PayloadAction) => { const layer = state.layers.find((l) => l.id === action.payload); if (layer) { layer.isEnabled = !layer.isEnabled; @@ -616,12 +616,24 @@ export const controlLayersSlice = createSlice({ iiLayerAdded: { reducer: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => { const { layerId, imageDTO } = action.payload; + + // Retain opacity and denoising strength of existing initial image layer if exists + let opacity = 1; + let denoisingStrength = 0.75; + const iiLayer = state.layers.find((l) => l.id === layerId); + if (iiLayer) { + assert(isInitialImageLayer(iiLayer)); + opacity = iiLayer.opacity; + denoisingStrength = iiLayer.denoisingStrength; + } + // Highlander! There can be only one! state.layers = state.layers.filter((l) => (isInitialImageLayer(l) ? false : true)); + const layer: InitialImageLayer = { id: layerId, type: 'initial_image_layer', - opacity: 1, + opacity, x: 0, y: 0, bbox: null, @@ -629,7 +641,7 @@ export const controlLayersSlice = createSlice({ isEnabled: true, image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null, isSelected: true, - denoisingStrength: 0.75, + denoisingStrength, }; state.layers.push(layer); exclusivelySelectLayer(state, layer.id); @@ -779,7 +791,7 @@ class LayerColors { export const { // Any Layer Type layerSelected, - layerVisibilityToggled, + layerIsEnabledToggled, layerTranslated, layerBboxChanged, layerReset, diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx index a25f6d8c0e..b3119aa8fa 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx @@ -1,6 +1,5 @@ import { Flex, MenuDivider, MenuItem, Spinner } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { useAppToaster } from 'app/components/Toaster'; import { $customStarUI } from 'app/store/nanostores/customStarUI'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useCopyImageToClipboard } from 'common/hooks/useCopyImageToClipboard'; @@ -11,10 +10,13 @@ import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice'; import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; import { useImageActions } from 'features/gallery/hooks/useImageActions'; import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/actions'; +import { $templates } from 'features/nodes/store/nodesSlice'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { toast } from 'features/toast/toast'; import { setActiveTab } from 'features/ui/store/uiSlice'; import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow'; +import { size } from 'lodash-es'; import { memo, useCallback } from 'react'; import { flushSync } from 'react-dom'; import { useTranslation } from 'react-i18next'; @@ -44,10 +46,10 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { const optimalDimension = useAppSelector(selectOptimalDimension); const dispatch = useAppDispatch(); const { t } = useTranslation(); - const toaster = useAppToaster(); const isCanvasEnabled = useFeatureStatus('canvas'); const customStarUi = useStore($customStarUI); const { downloadImage } = useDownloadImage(); + const templates = useStore($templates); const { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, isLoadingMetadata } = useImageActions(imageDTO?.image_name); @@ -83,13 +85,12 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { }); dispatch(setInitialCanvasImage(imageDTO, optimalDimension)); - toaster({ + toast({ + id: 'SENT_TO_CANVAS', title: t('toast.sentToUnifiedCanvas'), status: 'success', - duration: 2500, - isClosable: true, }); - }, [dispatch, imageDTO, t, toaster, optimalDimension]); + }, [dispatch, imageDTO, t, optimalDimension]); const handleChangeBoard = useCallback(() => { dispatch(imagesToChangeSelected([imageDTO])); @@ -133,7 +134,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { : } onClickCapture={handleLoadWorkflow} - isDisabled={!imageDTO.has_workflow} + isDisabled={!imageDTO.has_workflow || !size(templates)} > {t('nodes.loadWorkflow')} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx index ada9c35d28..d500d692fe 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx @@ -1,4 +1,5 @@ import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { skipToken } from '@reduxjs/toolkit/query'; import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested'; @@ -12,12 +13,14 @@ import { sentImageToImg2Img } from 'features/gallery/store/actions'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; import { selectGallerySlice } from 'features/gallery/store/gallerySlice'; import { parseAndRecallImageDimensions } from 'features/metadata/util/handlers'; +import { $templates } from 'features/nodes/store/nodesSlice'; import ParamUpscalePopover from 'features/parameters/components/Upscale/ParamUpscaleSettings'; import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { selectSystemSlice } from 'features/system/store/systemSlice'; import { setActiveTab } from 'features/ui/store/uiSlice'; import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow'; +import { size } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; @@ -48,7 +51,7 @@ const CurrentImageButtons = () => { const lastSelectedImage = useAppSelector(selectLastSelectedImage); const selection = useAppSelector((s) => s.gallery.selection); const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons); - + const templates = useStore($templates); const isUpscalingEnabled = useFeatureStatus('upscaling'); const isQueueMutationInProgress = useIsQueueMutationInProgress(); const { t } = useTranslation(); @@ -143,7 +146,7 @@ const CurrentImageButtons = () => { icon={} tooltip={`${t('nodes.loadWorkflow')} (W)`} aria-label={`${t('nodes.loadWorkflow')} (W)`} - isDisabled={!imageDTO?.has_workflow} + isDisabled={!imageDTO?.has_workflow || !size(templates)} onClick={handleLoadWorkflow} isLoading={getAndLoadEmbeddedWorkflowResult.isLoading} /> diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ToggleMetadataViewerButton.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ToggleMetadataViewerButton.tsx index 4bf55116db..df3fbe2765 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ToggleMetadataViewerButton.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ToggleMetadataViewerButton.tsx @@ -1,6 +1,5 @@ import { IconButton } from '@invoke-ai/ui-library'; import { skipToken } from '@reduxjs/toolkit/query'; -import { useAppToaster } from 'app/components/Toaster'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; import { setShouldShowImageDetails } from 'features/ui/store/uiSlice'; @@ -14,7 +13,6 @@ export const ToggleMetadataViewerButton = memo(() => { const dispatch = useAppDispatch(); const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails); const lastSelectedImage = useAppSelector(selectLastSelectedImage); - const toaster = useAppToaster(); const { t } = useTranslation(); const { currentData: imageDTO } = useGetImageDTOQuery(lastSelectedImage?.image_name ?? skipToken); @@ -24,7 +22,7 @@ export const ToggleMetadataViewerButton = memo(() => { [dispatch, shouldShowImageDetails] ); - useHotkeys('i', toggleMetadataViewer, { enabled: Boolean(imageDTO) }, [imageDTO, shouldShowImageDetails, toaster]); + useHotkeys('i', toggleMetadataViewer, { enabled: Boolean(imageDTO) }, [imageDTO, shouldShowImageDetails]); return ( { const recallSeed = useCallback(() => { handlers.seed.parse(metadata).then((seed) => { - handlers.seed.recall && handlers.seed.recall(seed); + handlers.seed.recall && handlers.seed.recall(seed, true); }); }, [metadata]); diff --git a/invokeai/frontend/web/src/features/metadata/util/handlers.ts b/invokeai/frontend/web/src/features/metadata/util/handlers.ts index b0d0e22688..2829507dcd 100644 --- a/invokeai/frontend/web/src/features/metadata/util/handlers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/handlers.ts @@ -1,5 +1,4 @@ import { objectKeys } from 'common/util/objectKeys'; -import { toast } from 'common/util/toast'; import type { Layer } from 'features/controlLayers/store/types'; import type { LoRA } from 'features/lora/store/loraSlice'; import type { @@ -15,6 +14,7 @@ import type { import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers'; import { validators } from 'features/metadata/util/validators'; import type { ModelIdentifierField } from 'features/nodes/types/common'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { assert } from 'tsafe'; @@ -89,23 +89,23 @@ const renderLayersValue: MetadataRenderValueFunc = async (layers) => { return `${layers.length} ${t('controlLayers.layers', { count: layers.length })}`; }; -const parameterSetToast = (parameter: string, description?: string) => { +const parameterSetToast = (parameter: string) => { toast({ - title: t('toast.parameterSet', { parameter }), - description, + id: 'PARAMETER_SET', + title: t('toast.parameterSet'), + description: t('toast.parameterSetDesc', { parameter }), status: 'info', - duration: 2500, - isClosable: true, }); }; -const parameterNotSetToast = (parameter: string, description?: string) => { +const parameterNotSetToast = (parameter: string, message?: string) => { toast({ - title: t('toast.parameterNotSet', { parameter }), - description, + id: 'PARAMETER_NOT_SET', + title: t('toast.parameterNotSet'), + description: message + ? t('toast.parameterNotSetDescWithMessage', { parameter, message }) + : t('toast.parameterNotSetDesc', { parameter }), status: 'warning', - duration: 2500, - isClosable: true, }); }; @@ -458,7 +458,18 @@ export const parseAndRecallAllMetadata = async ( }); }) ); + if (results.some((result) => result.status === 'fulfilled')) { - parameterSetToast(t('toast.parameters')); + toast({ + id: 'PARAMETER_SET', + title: t('toast.parametersSet'), + status: 'info', + }); + } else { + toast({ + id: 'PARAMETER_SET', + title: t('toast.parametersNotSet'), + status: 'warning', + }); } }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useInstallModel.ts b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useInstallModel.ts new file mode 100644 index 0000000000..7636b9f314 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useInstallModel.ts @@ -0,0 +1,48 @@ +import { toast } from 'features/toast/toast'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useInstallModelMutation } from 'services/api/endpoints/models'; + +type InstallModelArg = { + source: string; + inplace?: boolean; + onSuccess?: () => void; + onError?: (error: unknown) => void; +}; + +export const useInstallModel = () => { + const { t } = useTranslation(); + const [_installModel, request] = useInstallModelMutation(); + + const installModel = useCallback( + ({ source, inplace, onSuccess, onError }: InstallModelArg) => { + _installModel({ source, inplace }) + .unwrap() + .then((_) => { + if (onSuccess) { + onSuccess(); + } + toast({ + id: 'MODEL_INSTALL_QUEUED', + title: t('toast.modelAddedSimple'), + status: 'success', + }); + }) + .catch((error) => { + if (onError) { + onError(error); + } + if (error) { + toast({ + id: 'MODEL_INSTALL_QUEUE_FAILED', + title: `${error.data.detail} `, + status: 'error', + }); + } + }); + }, + [_installModel, t] + ); + + return [installModel, request] as const; +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx index 6106264b78..6da320aa0b 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx @@ -17,7 +17,11 @@ export const useStarterModelsToast = () => { useEffect(() => { if (toast.isActive(TOAST_ID)) { - return; + if (mainModels.length === 0) { + return; + } else { + toast.close(TOAST_ID); + } } if (data && mainModels.length === 0 && !didToast && isEnabled) { toast({ diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx index 184429478e..ee5960f7d2 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceForm.tsx @@ -1,11 +1,9 @@ import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; import type { ChangeEventHandler } from 'react'; import { useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { useInstallModelMutation, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models'; +import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models'; import { HuggingFaceResults } from './HuggingFaceResults'; @@ -14,50 +12,19 @@ export const HuggingFaceForm = () => { const [displayResults, setDisplayResults] = useState(false); const [errorMessage, setErrorMessage] = useState(''); const { t } = useTranslation(); - const dispatch = useAppDispatch(); const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery(); - const [installModel] = useInstallModelMutation(); - - const handleInstallModel = useCallback( - (source: string) => { - installModel({ source }) - .unwrap() - .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelAddedSimple'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); - } - }); - }, - [installModel, dispatch, t] - ); + const [installModel] = useInstallModel(); const getModels = useCallback(async () => { _getHuggingFaceModels(huggingFaceRepo) .unwrap() .then((response) => { if (response.is_diffusers) { - handleInstallModel(huggingFaceRepo); + installModel({ source: huggingFaceRepo }); setDisplayResults(false); } else if (response.urls?.length === 1 && response.urls[0]) { - handleInstallModel(response.urls[0]); + installModel({ source: response.urls[0] }); setDisplayResults(false); } else { setDisplayResults(true); @@ -66,7 +33,7 @@ export const HuggingFaceForm = () => { .catch((error) => { setErrorMessage(error.data.detail || ''); }); - }, [_getHuggingFaceModels, handleInstallModel, huggingFaceRepo]); + }, [_getHuggingFaceModels, installModel, huggingFaceRepo]); const handleSetHuggingFaceRepo: ChangeEventHandler = useCallback((e) => { setHuggingFaceRepo(e.target.value); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResultItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResultItem.tsx index 1595a61147..32970a3666 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResultItem.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResultItem.tsx @@ -1,47 +1,20 @@ import { Flex, IconButton, Text } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiPlusBold } from 'react-icons/pi'; -import { useInstallModelMutation } from 'services/api/endpoints/models'; type Props = { result: string; }; export const HuggingFaceResultItem = ({ result }: Props) => { const { t } = useTranslation(); - const dispatch = useAppDispatch(); - const [installModel] = useInstallModelMutation(); + const [installModel] = useInstallModel(); - const handleInstall = useCallback(() => { - installModel({ source: result }) - .unwrap() - .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelAddedSimple'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); - } - }); - }, [installModel, result, dispatch, t]); + const onClick = useCallback(() => { + installModel({ source: result }); + }, [installModel, result]); return ( @@ -51,7 +24,7 @@ export const HuggingFaceResultItem = ({ result }: Props) => { {result} - } onClick={handleInstall} size="sm" /> + } onClick={onClick} size="sm" /> ); }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx index 8144accf3f..826fd177ea 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/HuggingFaceFolder/HuggingFaceResults.tsx @@ -8,15 +8,12 @@ import { InputGroup, InputRightElement, } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; import type { ChangeEventHandler } from 'react'; import { useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { PiXBold } from 'react-icons/pi'; -import { useInstallModelMutation } from 'services/api/endpoints/models'; import { HuggingFaceResultItem } from './HuggingFaceResultItem'; @@ -27,9 +24,8 @@ type HuggingFaceResultsProps = { export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => { const { t } = useTranslation(); const [searchTerm, setSearchTerm] = useState(''); - const dispatch = useAppDispatch(); - const [installModel] = useInstallModelMutation(); + const [installModel] = useInstallModel(); const filteredResults = useMemo(() => { return results.filter((result) => { @@ -46,34 +42,11 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => { setSearchTerm(''); }, []); - const handleAddAll = useCallback(() => { + const onClickAddAll = useCallback(() => { for (const result of filteredResults) { - installModel({ source: result }) - .unwrap() - .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelAddedSimple'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); - } - }); + installModel({ source: result }); } - }, [filteredResults, installModel, dispatch, t]); + }, [filteredResults, installModel]); return ( <> @@ -82,7 +55,7 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => { {t('modelManager.availableModels')} - diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/InstallModelForm.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/InstallModelForm.tsx index 282e07ee27..cc052878bf 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/InstallModelForm.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/InstallModelForm.tsx @@ -1,12 +1,9 @@ import { Button, Checkbox, Flex, FormControl, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; import { t } from 'i18next'; import { useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; -import { useInstallModelMutation } from 'services/api/endpoints/models'; type SimpleImportModelConfig = { location: string; @@ -14,9 +11,7 @@ type SimpleImportModelConfig = { }; export const InstallModelForm = () => { - const dispatch = useAppDispatch(); - - const [installModel, { isLoading }] = useInstallModelMutation(); + const [installModel, { isLoading }] = useInstallModel(); const { register, handleSubmit, formState, reset } = useForm({ defaultValues: { @@ -26,40 +21,22 @@ export const InstallModelForm = () => { mode: 'onChange', }); + const resetForm = useCallback(() => reset(undefined, { keepValues: true }), [reset]); + const onSubmit = useCallback>( (values) => { if (!values?.location) { return; } - installModel({ source: values.location, inplace: values.inplace }) - .unwrap() - .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelAddedSimple'), - status: 'success', - }) - ) - ); - reset(undefined, { keepValues: true }); - }) - .catch((error) => { - reset(undefined, { keepValues: true }); - if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); - } - }); + installModel({ + source: values.location, + inplace: values.inplace, + onSuccess: resetForm, + onError: resetForm, + }); }, - [dispatch, reset, installModel] + [installModel, resetForm] ); return ( diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueue.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueue.tsx index 5db2743669..b3544af5b3 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueue.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueue.tsx @@ -1,8 +1,6 @@ import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { useCallback, useMemo } from 'react'; import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models'; @@ -10,8 +8,6 @@ import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } fro import { ModelInstallQueueItem } from './ModelInstallQueueItem'; export const ModelInstallQueue = () => { - const dispatch = useAppDispatch(); - const { data } = useListModelInstallsQuery(); const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation(); @@ -20,28 +16,22 @@ export const ModelInstallQueue = () => { _pruneCompletedModelInstalls() .unwrap() .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.prunedQueue'), - status: 'success', - }) - ) - ); + toast({ + id: 'MODEL_INSTALL_QUEUE_PRUNED', + title: t('toast.prunedQueue'), + status: 'success', + }); }) .catch((error) => { if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); + toast({ + id: 'MODEL_INSTALL_QUEUE_PRUNE_FAILED', + title: `${error.data.detail} `, + status: 'error', + }); } }); - }, [_pruneCompletedModelInstalls, dispatch]); + }, [_pruneCompletedModelInstalls]); const pruneAvailable = useMemo(() => { return data?.some( diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueueItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueueItem.tsx index d1fc600510..82a28b2d75 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueueItem.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ModelInstallQueue/ModelInstallQueueItem.tsx @@ -1,7 +1,5 @@ import { Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { isNil } from 'lodash-es'; import { useCallback, useMemo } from 'react'; @@ -29,7 +27,6 @@ const formatBytes = (bytes: number) => { export const ModelInstallQueueItem = (props: ModelListItemProps) => { const { installJob } = props; - const dispatch = useAppDispatch(); const [deleteImportModel] = useCancelModelInstallMutation(); @@ -37,28 +34,22 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => { deleteImportModel(installJob.id) .unwrap() .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelImportCanceled'), - status: 'success', - }) - ) - ); + toast({ + id: 'MODEL_INSTALL_CANCELED', + title: t('toast.modelImportCanceled'), + status: 'success', + }); }) .catch((error) => { if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); + toast({ + id: 'MODEL_INSTALL_CANCEL_FAILED', + title: `${error.data.detail} `, + status: 'error', + }); } }); - }, [deleteImportModel, installJob, dispatch]); + }, [deleteImportModel, installJob]); const sourceLocation = useMemo(() => { switch (installJob.source.type) { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ScanFolder/ScanFolderResults.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ScanFolder/ScanFolderResults.tsx index 360d6c1403..749ef4c8e0 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ScanFolder/ScanFolderResults.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ScanFolder/ScanFolderResults.tsx @@ -11,15 +11,13 @@ import { InputGroup, InputRightElement, } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; import type { ChangeEvent, ChangeEventHandler } from 'react'; import { useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { PiXBold } from 'react-icons/pi'; -import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models'; +import type { ScanFolderResponse } from 'services/api/endpoints/models'; import { ScanModelResultItem } from './ScanFolderResultItem'; @@ -30,9 +28,8 @@ type ScanModelResultsProps = { export const ScanModelsResults = ({ results }: ScanModelResultsProps) => { const { t } = useTranslation(); const [searchTerm, setSearchTerm] = useState(''); - const dispatch = useAppDispatch(); const [inplace, setInplace] = useState(true); - const [installModel] = useInstallModelMutation(); + const [installModel] = useInstallModel(); const filteredResults = useMemo(() => { return results.filter((result) => { @@ -58,61 +55,15 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => { if (result.is_installed) { continue; } - installModel({ source: result.path, inplace }) - .unwrap() - .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelAddedSimple'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); - } - }); + installModel({ source: result.path, inplace }); } - }, [filteredResults, installModel, inplace, dispatch, t]); + }, [filteredResults, installModel, inplace]); const handleInstallOne = useCallback( (source: string) => { - installModel({ source, inplace }) - .unwrap() - .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelAddedSimple'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); - } - }); + installModel({ source, inplace }); }, - [installModel, inplace, dispatch, t] + [installModel, inplace] ); return ( diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx index a32fad810f..98e1e39640 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx @@ -1,20 +1,16 @@ import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; +import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel'; import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiPlusBold } from 'react-icons/pi'; import type { GetStarterModelsResponse } from 'services/api/endpoints/models'; -import { useInstallModelMutation } from 'services/api/endpoints/models'; type Props = { result: GetStarterModelsResponse[number]; }; export const StarterModelsResultItem = ({ result }: Props) => { const { t } = useTranslation(); - const dispatch = useAppDispatch(); const allSources = useMemo(() => { const _allSources = [result.source]; if (result.dependencies) { @@ -22,36 +18,13 @@ export const StarterModelsResultItem = ({ result }: Props) => { } return _allSources; }, [result]); - const [installModel] = useInstallModelMutation(); + const [installModel] = useInstallModel(); - const handleQuickAdd = useCallback(() => { + const onClick = useCallback(() => { for (const source of allSources) { - installModel({ source }) - .unwrap() - .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('toast.modelAddedSimple'), - status: 'success', - }) - ) - ); - }) - .catch((error) => { - if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); - } - }); + installModel({ source }); } - }, [allSources, installModel, dispatch, t]); + }, [allSources, installModel]); return ( @@ -67,7 +40,7 @@ export const StarterModelsResultItem = ({ result }: Props) => { {result.is_installed ? ( {t('common.installed')} ) : ( - } onClick={handleQuickAdd} size="sm" /> + } onClick={onClick} size="sm" /> )} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx index c9d0c03ed8..a4a6d5c833 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelListItem.tsx @@ -4,8 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice'; import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge'; import ModelFormatBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import type { MouseEvent } from 'react'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -53,25 +52,19 @@ const ModelListItem = (props: ModelListItemProps) => { deleteModel({ key: model.key }) .unwrap() .then((_) => { - dispatch( - addToast( - makeToast({ - title: `${t('modelManager.modelDeleted')}: ${model.name}`, - status: 'success', - }) - ) - ); + toast({ + id: 'MODEL_DELETED', + title: `${t('modelManager.modelDeleted')}: ${model.name}`, + status: 'success', + }); }) .catch((error) => { if (error) { - dispatch( - addToast( - makeToast({ - title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`, - status: 'error', - }) - ) - ); + toast({ + id: 'MODEL_DELETE_FAILED', + title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`, + status: 'error', + }); } }); dispatch(setSelectedModelKey(null)); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings.tsx index dcdc4e2a36..9a84fbc726 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings.tsx @@ -1,10 +1,9 @@ import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppSelector } from 'app/store/storeHooks'; import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings'; import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor'; import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; @@ -19,7 +18,6 @@ export type ControlNetOrT2IAdapterDefaultSettingsFormData = { export const ControlNetOrT2IAdapterDefaultSettings = () => { const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const { t } = useTranslation(); - const dispatch = useAppDispatch(); const { defaultSettingsDefaults, isLoading: isLoadingDefaultSettings } = useControlNetOrT2IAdapterDefaultSettings(selectedModelKey); @@ -46,30 +44,24 @@ export const ControlNetOrT2IAdapterDefaultSettings = () => { }) .unwrap() .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('modelManager.defaultSettingsSaved'), - status: 'success', - }) - ) - ); + toast({ + id: 'DEFAULT_SETTINGS_SAVED', + title: t('modelManager.defaultSettingsSaved'), + status: 'success', + }); reset(data); }) .catch((error) => { if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); + toast({ + id: 'DEFAULT_SETTINGS_SAVE_FAILED', + title: `${error.data.detail} `, + status: 'error', + }); } }); }, - [selectedModelKey, dispatch, reset, updateModel, t] + [selectedModelKey, reset, updateModel, t] ); if (isLoadingDefaultSettings) { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelImageUpload.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelImageUpload.tsx index 0d7920ef77..292835a7b7 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelImageUpload.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Fields/ModelImageUpload.tsx @@ -1,8 +1,6 @@ import { Box, Button, Flex, Icon, IconButton, Image, Tooltip } from '@invoke-ai/ui-library'; -import { useAppDispatch } from 'app/store/storeHooks'; import { typedMemo } from 'common/util/typedMemo'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { useCallback, useState } from 'react'; import { useDropzone } from 'react-dropzone'; import { useTranslation } from 'react-i18next'; @@ -15,7 +13,6 @@ type Props = { }; const ModelImageUpload = ({ model_key, model_image }: Props) => { - const dispatch = useAppDispatch(); const [image, setImage] = useState(model_image || null); const { t } = useTranslation(); @@ -34,27 +31,21 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => { .unwrap() .then(() => { setImage(URL.createObjectURL(file)); - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelImageUpdated'), - status: 'success', - }) - ) - ); + toast({ + id: 'MODEL_IMAGE_UPDATED', + title: t('modelManager.modelImageUpdated'), + status: 'success', + }); }) - .catch((_) => { - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelImageUpdateFailed'), - status: 'error', - }) - ) - ); + .catch(() => { + toast({ + id: 'MODEL_IMAGE_UPDATE_FAILED', + title: t('modelManager.modelImageUpdateFailed'), + status: 'error', + }); }); }, - [dispatch, model_key, t, updateModelImage] + [model_key, t, updateModelImage] ); const handleResetImage = useCallback(() => { @@ -65,26 +56,20 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => { deleteModelImage(model_key) .unwrap() .then(() => { - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelImageDeleted'), - status: 'success', - }) - ) - ); + toast({ + id: 'MODEL_IMAGE_DELETED', + title: t('modelManager.modelImageDeleted'), + status: 'success', + }); }) - .catch((_) => { - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelImageDeleteFailed'), - status: 'error', - }) - ) - ); + .catch(() => { + toast({ + id: 'MODEL_IMAGE_DELETE_FAILED', + title: t('modelManager.modelImageDeleteFailed'), + status: 'error', + }); }); - }, [dispatch, model_key, t, deleteModelImage]); + }, [model_key, t, deleteModelImage]); const { getInputProps, getRootProps } = useDropzone({ accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] }, diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings.tsx index e096b11209..233fc7bc6b 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings.tsx @@ -1,11 +1,10 @@ import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppSelector } from 'app/store/storeHooks'; import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings'; import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight'; import { DefaultWidth } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultWidth'; import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; @@ -39,7 +38,6 @@ export type MainModelDefaultSettingsFormData = { export const MainModelDefaultSettings = () => { const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const { t } = useTranslation(); - const dispatch = useAppDispatch(); const { defaultSettingsDefaults, @@ -76,30 +74,24 @@ export const MainModelDefaultSettings = () => { }) .unwrap() .then((_) => { - dispatch( - addToast( - makeToast({ - title: t('modelManager.defaultSettingsSaved'), - status: 'success', - }) - ) - ); + toast({ + id: 'DEFAULT_SETTINGS_SAVED', + title: t('modelManager.defaultSettingsSaved'), + status: 'success', + }); reset(data); }) .catch((error) => { if (error) { - dispatch( - addToast( - makeToast({ - title: `${error.data.detail} `, - status: 'error', - }) - ) - ); + toast({ + id: 'DEFAULT_SETTINGS_SAVE_FAILED', + title: `${error.data.detail} `, + status: 'error', + }); } }); }, - [selectedModelKey, dispatch, reset, updateModel, t] + [selectedModelKey, reset, updateModel, t] ); if (isLoadingDefaultSettings) { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx index d95eed8d24..fa7ca4c394 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/Model.tsx @@ -4,8 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton'; import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import type { SubmitHandler } from 'react-hook-form'; import { useForm } from 'react-hook-form'; @@ -47,25 +46,19 @@ export const Model = () => { .then((payload) => { form.reset(payload, { keepDefaultValues: true }); dispatch(setSelectedModelMode('view')); - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdated'), - status: 'success', - }) - ) - ); + toast({ + id: 'MODEL_UPDATED', + title: t('modelManager.modelUpdated'), + status: 'success', + }); }) .catch((_) => { form.reset(); - dispatch( - addToast( - makeToast({ - title: t('modelManager.modelUpdateFailed'), - status: 'error', - }) - ) - ); + toast({ + id: 'MODEL_UPDATE_FAILED', + title: t('modelManager.modelUpdateFailed'), + status: 'error', + }); }); }, [dispatch, data?.key, form, t, updateModel] diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton.tsx index bcff0451d6..40ffca76b4 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton.tsx @@ -9,9 +9,7 @@ import { useDisclosure, } from '@invoke-ai/ui-library'; import { skipToken } from '@reduxjs/toolkit/query'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; +import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { useConvertModelMutation, useGetModelConfigQuery } from 'services/api/endpoints/models'; @@ -22,7 +20,6 @@ interface ModelConvertProps { export const ModelConvertButton = (props: ModelConvertProps) => { const { modelKey } = props; - const dispatch = useAppDispatch(); const { t } = useTranslation(); const { data } = useGetModelConfigQuery(modelKey ?? skipToken); const [convertModel, { isLoading }] = useConvertModelMutation(); @@ -33,38 +30,26 @@ export const ModelConvertButton = (props: ModelConvertProps) => { return; } - dispatch( - addToast( - makeToast({ - title: `${t('modelManager.convertingModelBegin')}: ${data?.name}`, - status: 'info', - }) - ) - ); + const toastId = `CONVERTING_MODEL_${data.key}`; + toast({ + id: toastId, + title: `${t('modelManager.convertingModelBegin')}: ${data?.name}`, + status: 'info', + }); convertModel(data?.key) .unwrap() .then(() => { - dispatch( - addToast( - makeToast({ - title: `${t('modelManager.modelConverted')}: ${data?.name}`, - status: 'success', - }) - ) - ); + toast({ id: toastId, title: `${t('modelManager.modelConverted')}: ${data?.name}`, status: 'success' }); }) .catch(() => { - dispatch( - addToast( - makeToast({ - title: `${t('modelManager.modelConversionFailed')}: ${data?.name}`, - status: 'error', - }) - ) - ); + toast({ + id: toastId, + title: `${t('modelManager.modelConversionFailed')}: ${data?.name}`, + status: 'error', + }); }); - }, [data, isLoading, dispatch, t, convertModel]); + }, [data, isLoading, t, convertModel]); if (data?.format !== 'checkpoint') { return; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx index 1f8e50b9da..8bc775c872 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelEdit.tsx @@ -72,10 +72,12 @@ export const ModelEdit = ({ form }: Props) => { {t('modelManager.baseModel')} - - {t('modelManager.variant')} - - + {data.type === 'main' && ( + + {t('modelManager.variant')} + + + )} {data.type === 'main' && data.format === 'checkpoint' && ( <> diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 95104c683c..6da87f4e98 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -3,33 +3,35 @@ import 'reactflow/dist/style.css'; import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { useAppToaster } from 'app/components/Toaster'; import { useAppDispatch, useAppStore } from 'app/store/storeHooks'; import type { SelectInstance } from 'chakra-react-select'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; import { $cursorPos, + $edgePendingUpdate, $isAddNodePopoverOpen, $pendingConnection, $templates, closeAddNodePopover, - connectionMade, - nodeAdded, + edgesChanged, + nodesChanged, openAddNodePopover, } from 'features/nodes/store/nodesSlice'; -import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle'; -import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; +import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition'; +import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; +import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; +import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import type { AnyNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation'; +import { toast } from 'features/toast/toast'; import { filter, map, memoize, some } from 'lodash-es'; -import type { KeyboardEventHandler } from 'react'; import { memo, useCallback, useMemo, useRef } from 'react'; import { flushSync } from 'react-dom'; import { useHotkeys } from 'react-hotkeys-hook'; import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types'; import { useTranslation } from 'react-i18next'; import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters'; -import { assert } from 'tsafe'; +import type { EdgeChange, NodeChange } from 'reactflow'; const createRegex = memoize( (inputValue: string) => @@ -58,7 +60,6 @@ const filterOption = memoize((option: FilterOptionOption, inputV const AddNodePopover = () => { const dispatch = useAppDispatch(); const buildInvocation = useBuildNode(); - const toaster = useAppToaster(); const { t } = useTranslation(); const selectRef = useRef | null>(null); const inputRef = useRef(null); @@ -69,17 +70,19 @@ const AddNodePopover = () => { const filteredTemplates = useMemo(() => { // If we have a connection in progress, we need to filter the node choices + const templatesArray = map(templates); if (!pendingConnection) { - return map(templates); + return templatesArray; } return filter(templates, (template) => { - const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind; - const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs; - return some(fields, (field) => { - const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type; - const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type; - return validateSourceAndTargetTypes(sourceType, targetType); + const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs; + return some(candidateFields, (field) => { + const sourceType = + pendingConnection.handleType === 'source' ? field.type : pendingConnection.fieldTemplate.type; + const targetType = + pendingConnection.handleType === 'target' ? field.type : pendingConnection.fieldTemplate.type; + return validateConnectionTypes(sourceType, targetType); }); }); }, [templates, pendingConnection]); @@ -123,17 +126,43 @@ const AddNodePopover = () => { const errorMessage = t('nodes.unknownNode', { nodeType: nodeType, }); - toaster({ + toast({ status: 'error', title: errorMessage, }); return null; } + + // Find a cozy spot for the node const cursorPos = $cursorPos.get(); - dispatch(nodeAdded({ node, cursorPos })); + const { nodes, edges } = store.getState().nodes.present; + node.position = findUnoccupiedPosition(nodes, cursorPos?.x ?? node.position.x, cursorPos?.y ?? node.position.y); + node.selected = true; + + // Deselect all other nodes and edges + const nodeChanges: NodeChange[] = [{ type: 'add', item: node }]; + const edgeChanges: EdgeChange[] = []; + nodes.forEach(({ id, selected }) => { + if (selected) { + nodeChanges.push({ type: 'select', id, selected: false }); + } + }); + edges.forEach(({ id, selected }) => { + if (selected) { + edgeChanges.push({ type: 'select', id, selected: false }); + } + }); + + // Onwards! + if (nodeChanges.length > 0) { + dispatch(nodesChanged(nodeChanges)); + } + if (edgeChanges.length > 0) { + dispatch(edgesChanged(edgeChanges)); + } return node; }, - [dispatch, buildInvocation, toaster, t] + [buildInvocation, store, dispatch, t] ); const onChange = useCallback( @@ -145,12 +174,28 @@ const AddNodePopover = () => { // Auto-connect an edge if we just added a node and have a pending connection if (pendingConnection && isInvocationNode(node)) { - const template = templates[node.data.type]; - assert(template, 'Template not found'); + const edgePendingUpdate = $edgePendingUpdate.get(); + const { handleType } = pendingConnection; + + const source = handleType === 'source' ? pendingConnection.nodeId : node.id; + const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null; + const target = handleType === 'target' ? pendingConnection.nodeId : node.id; + const targetHandle = handleType === 'target' ? pendingConnection.handleId : null; + const { nodes, edges } = store.getState().nodes.present; - const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template); + const connection = getFirstValidConnection( + source, + sourceHandle, + target, + targetHandle, + nodes, + edges, + templates, + edgePendingUpdate + ); if (connection) { - dispatch(connectionMade(connection)); + const newEdge = connectionToEdge(connection); + dispatch(edgesChanged([{ type: 'add', item: newEdge }])); } } @@ -160,25 +205,24 @@ const AddNodePopover = () => { ); const handleHotkeyOpen: HotkeyCallback = useCallback((e) => { - e.preventDefault(); - openAddNodePopover(); - flushSync(() => { - selectRef.current?.inputRef?.focus(); - }); + if (!$isAddNodePopoverOpen.get()) { + e.preventDefault(); + openAddNodePopover(); + flushSync(() => { + selectRef.current?.inputRef?.focus(); + }); + } }, []); const handleHotkeyClose: HotkeyCallback = useCallback(() => { - closeAddNodePopover(); - }, []); - - useHotkeys(['shift+a', 'space'], handleHotkeyOpen); - useHotkeys(['escape'], handleHotkeyClose); - const onKeyDown: KeyboardEventHandler = useCallback((e) => { - if (e.key === 'Escape') { + if ($isAddNodePopoverOpen.get()) { closeAddNodePopover(); } }, []); + useHotkeys(['shift+a', 'space'], handleHotkeyOpen); + useHotkeys(['escape'], handleHotkeyClose, { enableOnFormTags: ['TEXTAREA'] }); + const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]); return ( @@ -215,7 +259,6 @@ const AddNodePopover = () => { filterOption={filterOption} onChange={onChange} onMenuClose={closeAddNodePopover} - onKeyDown={onKeyDown} inputRef={inputRef} closeMenuOnSelect={false} /> diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 656de737c7..1748989394 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -1,6 +1,6 @@ import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks'; import { useConnection } from 'features/nodes/hooks/useConnection'; import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste'; import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState'; @@ -8,38 +8,35 @@ import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection' import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher'; import { $cursorPos, + $didUpdateEdge, + $edgePendingUpdate, $isAddNodePopoverOpen, - $isUpdatingEdge, + $lastEdgeUpdateMouseEvent, $pendingConnection, $viewport, - connectionMade, - edgeAdded, - edgeDeleted, edgesChanged, - edgesDeleted, nodesChanged, - nodesDeleted, redo, - selectedAll, undo, } from 'features/nodes/store/nodesSlice'; import { $flow } from 'features/nodes/store/reactFlowInstance'; +import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import type { CSSProperties, MouseEvent } from 'react'; import { memo, useCallback, useMemo, useRef } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import type { + EdgeChange, + NodeChange, OnEdgesChange, - OnEdgesDelete, OnEdgeUpdateFunc, OnInit, OnMoveEnd, OnNodesChange, - OnNodesDelete, ProOptions, ReactFlowProps, ReactFlowState, } from 'reactflow'; -import { Background, ReactFlow, useStore as useReactFlowStore } from 'reactflow'; +import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from 'reactflow'; import CustomConnectionLine from './connectionLines/CustomConnectionLine'; import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge'; @@ -48,8 +45,6 @@ import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode'; import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper'; import NotesNode from './nodes/Notes/NotesNode'; -const DELETE_KEYS = ['Delete', 'Backspace']; - const edgeTypes = { collapsed: InvocationCollapsedEdge, default: InvocationDefaultEdge, @@ -81,6 +76,8 @@ export const Flow = memo(() => { const flowWrapper = useRef(null); const isValidConnection = useIsValidConnection(); const cancelConnection = useReactFlowStore(selectCancelConnection); + const updateNodeInternals = useUpdateNodeInternals(); + const store = useAppStore(); useWorkflowWatcher(); useSyncExecutionState(); const [borderRadius] = useToken('radii', ['base']); @@ -93,29 +90,17 @@ export const Flow = memo(() => { ); const onNodesChange: OnNodesChange = useCallback( - (changes) => { - dispatch(nodesChanged(changes)); + (nodeChanges) => { + dispatch(nodesChanged(nodeChanges)); }, [dispatch] ); const onEdgesChange: OnEdgesChange = useCallback( (changes) => { - dispatch(edgesChanged(changes)); - }, - [dispatch] - ); - - const onEdgesDelete: OnEdgesDelete = useCallback( - (edges) => { - dispatch(edgesDeleted(edges)); - }, - [dispatch] - ); - - const onNodesDelete: OnNodesDelete = useCallback( - (nodes) => { - dispatch(nodesDeleted(nodes)); + if (changes.length > 0) { + dispatch(edgesChanged(changes)); + } }, [dispatch] ); @@ -157,45 +142,50 @@ export const Flow = memo(() => { * where the edge is deleted if you click it accidentally). */ - // We have a ref for cursor position, but it is the *projected* cursor position. - // Easiest to just keep track of the last mouse event for this particular feature - const edgeUpdateMouseEvent = useRef(); - - const onEdgeUpdateStart: NonNullable = useCallback( - (e, edge, _handleType) => { - $isUpdatingEdge.set(true); - // update mouse event - edgeUpdateMouseEvent.current = e; - // always delete the edge when starting an updated - dispatch(edgeDeleted(edge.id)); - }, - [dispatch] - ); + const onEdgeUpdateStart: NonNullable = useCallback((e, edge, _handleType) => { + $edgePendingUpdate.set(edge); + $didUpdateEdge.set(false); + $lastEdgeUpdateMouseEvent.set(e); + }, []); const onEdgeUpdate: OnEdgeUpdateFunc = useCallback( - (_oldEdge, newConnection) => { - // Because we deleted the edge when the update started, we must create a new edge from the connection - dispatch(connectionMade(newConnection)); + (oldEdge, newConnection) => { + // This event is fired when an edge update is successful + $didUpdateEdge.set(true); + // When an edge update is successful, we need to delete the old edge and create a new one + const newEdge = connectionToEdge(newConnection); + dispatch( + edgesChanged([ + { type: 'remove', id: oldEdge.id }, + { type: 'add', item: newEdge }, + ]) + ); + // Because we shift the position of handles depending on whether a field is connected or not, we must use + // updateNodeInternals to tell reactflow to recalculate the positions of the handles + updateNodeInternals([oldEdge.source, oldEdge.target, newEdge.source, newEdge.target]); }, - [dispatch] + [dispatch, updateNodeInternals] ); const onEdgeUpdateEnd: NonNullable = useCallback( (e, edge, _handleType) => { - $isUpdatingEdge.set(false); - $pendingConnection.set(null); - // Handle the case where user begins a drag but didn't move the cursor - we deleted the edge when starting - // the edge update - we need to add it back - if ( - // ignore touch events - !('touches' in e) && - edgeUpdateMouseEvent.current?.clientX === e.clientX && - edgeUpdateMouseEvent.current?.clientY === e.clientY - ) { - dispatch(edgeAdded(edge)); + const didUpdateEdge = $didUpdateEdge.get(); + // Fall back to a reasonable default event + const lastEvent = $lastEdgeUpdateMouseEvent.get() ?? { clientX: 0, clientY: 0 }; + // We have to narrow this event down to MouseEvents - could be TouchEvent + const didMouseMove = + !('touches' in e) && Math.hypot(e.clientX - lastEvent.clientX, e.clientY - lastEvent.clientY) > 5; + + // If we got this far and did not successfully update an edge, and the mouse moved away from the handle, + // the user probably intended to delete the edge + if (!didUpdateEdge && didMouseMove) { + dispatch(edgesChanged([{ type: 'remove', id: edge.id }])); } - // reset mouse event - edgeUpdateMouseEvent.current = undefined; + + $edgePendingUpdate.set(null); + $didUpdateEdge.set(false); + $pendingConnection.set(null); + $lastEdgeUpdateMouseEvent.set(null); }, [dispatch] ); @@ -216,9 +206,27 @@ export const Flow = memo(() => { const onSelectAllHotkey = useCallback( (e: KeyboardEvent) => { e.preventDefault(); - dispatch(selectedAll()); + const { nodes, edges } = store.getState().nodes.present; + const nodeChanges: NodeChange[] = []; + const edgeChanges: EdgeChange[] = []; + nodes.forEach(({ id, selected }) => { + if (!selected) { + nodeChanges.push({ type: 'select', id, selected: true }); + } + }); + edges.forEach(({ id, selected }) => { + if (!selected) { + edgeChanges.push({ type: 'select', id, selected: true }); + } + }); + if (nodeChanges.length > 0) { + dispatch(nodesChanged(nodeChanges)); + } + if (edgeChanges.length > 0) { + dispatch(edgesChanged(edgeChanges)); + } }, - [dispatch] + [dispatch, store] ); useHotkeys(['Ctrl+a', 'Meta+a'], onSelectAllHotkey); @@ -255,12 +263,37 @@ export const Flow = memo(() => { useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey); const onEscapeHotkey = useCallback(() => { - $pendingConnection.set(null); - $isAddNodePopoverOpen.set(false); - cancelConnection(); + if (!$edgePendingUpdate.get()) { + $pendingConnection.set(null); + $isAddNodePopoverOpen.set(false); + cancelConnection(); + } }, [cancelConnection]); useHotkeys('esc', onEscapeHotkey); + const onDeleteHotkey = useCallback(() => { + const { nodes, edges } = store.getState().nodes.present; + const nodeChanges: NodeChange[] = []; + const edgeChanges: EdgeChange[] = []; + nodes + .filter((n) => n.selected) + .forEach(({ id }) => { + nodeChanges.push({ type: 'remove', id }); + }); + edges + .filter((e) => e.selected) + .forEach(({ id }) => { + edgeChanges.push({ type: 'remove', id }); + }); + if (nodeChanges.length > 0) { + dispatch(nodesChanged(nodeChanges)); + } + if (edgeChanges.length > 0) { + dispatch(edgesChanged(edgeChanges)); + } + }, [dispatch, store]); + useHotkeys(['delete', 'backspace'], onDeleteHotkey); + return ( { onMouseMove={onMouseMove} onNodesChange={onNodesChange} onEdgesChange={onEdgesChange} - onEdgesDelete={onEdgesDelete} onEdgeUpdate={onEdgeUpdate} onEdgeUpdateStart={onEdgeUpdateStart} onEdgeUpdateEnd={onEdgeUpdateEnd} - onNodesDelete={onNodesDelete} onConnectStart={onConnectStart} onConnect={onConnect} onConnectEnd={onConnectEnd} @@ -292,9 +323,10 @@ export const Flow = memo(() => { proOptions={proOptions} style={flowStyles} onPaneClick={handlePaneClick} - deleteKeyCode={DELETE_KEYS} + deleteKeyCode={null} selectionMode={selectionMode} elevateEdgesOnSelect + nodeDragThreshold={1} > diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx index 2e2fb31154..0d7e7b7d5e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx @@ -2,13 +2,13 @@ import { Badge, Flex } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import { useAppSelector } from 'app/store/storeHooks'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; +import { getEdgeStyles } from 'features/nodes/components/flow/edges/util/getEdgeColor'; +import { makeEdgeSelector } from 'features/nodes/components/flow/edges/util/makeEdgeSelector'; import { $templates } from 'features/nodes/store/nodesSlice'; import { memo, useMemo } from 'react'; import type { EdgeProps } from 'reactflow'; import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow'; -import { makeEdgeSelector } from './util/makeEdgeSelector'; - const InvocationCollapsedEdge = ({ sourceX, sourceY, @@ -18,19 +18,19 @@ const InvocationCollapsedEdge = ({ targetPosition, markerEnd, data, - selected, + selected = false, source, - target, sourceHandleId, + target, targetHandleId, }: EdgeProps<{ count: number }>) => { const templates = useStore($templates); const selector = useMemo( - () => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected), - [templates, selected, source, sourceHandleId, target, targetHandleId] + () => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId), + [templates, source, sourceHandleId, target, targetHandleId] ); - const { isSelected, shouldAnimate } = useAppSelector(selector); + const { shouldAnimateEdges, areConnectedNodesSelected } = useAppSelector(selector); const [edgePath, labelX, labelY] = getBezierPath({ sourceX, @@ -44,14 +44,8 @@ const InvocationCollapsedEdge = ({ const { base500 } = useChakraThemeTokens(); const edgeStyles = useMemo( - () => ({ - strokeWidth: isSelected ? 3 : 2, - stroke: base500, - opacity: isSelected ? 0.8 : 0.5, - animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined, - strokeDasharray: shouldAnimate ? 5 : 'none', - }), - [base500, isSelected, shouldAnimate] + () => getEdgeStyles(base500, selected, shouldAnimateEdges, areConnectedNodesSelected), + [areConnectedNodesSelected, base500, selected, shouldAnimateEdges] ); return ( @@ -60,11 +54,15 @@ const InvocationCollapsedEdge = ({ {data?.count && data.count > 1 && ( - + {data.count} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx index 2e4340975b..5a27e974e5 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx @@ -1,8 +1,8 @@ import { Flex, Text } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import { useAppSelector } from 'app/store/storeHooks'; +import { getEdgeStyles } from 'features/nodes/components/flow/edges/util/getEdgeColor'; import { $templates } from 'features/nodes/store/nodesSlice'; -import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; import type { EdgeProps } from 'reactflow'; import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow'; @@ -17,7 +17,7 @@ const InvocationDefaultEdge = ({ sourcePosition, targetPosition, markerEnd, - selected, + selected = false, source, target, sourceHandleId, @@ -25,11 +25,11 @@ const InvocationDefaultEdge = ({ }: EdgeProps) => { const templates = useStore($templates); const selector = useMemo( - () => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected), - [templates, source, sourceHandleId, target, targetHandleId, selected] + () => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId), + [templates, source, sourceHandleId, target, targetHandleId] ); - const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector); + const { shouldAnimateEdges, areConnectedNodesSelected, stroke, label } = useAppSelector(selector); const shouldShowEdgeLabels = useAppSelector((s) => s.workflowSettings.shouldShowEdgeLabels); const [edgePath, labelX, labelY] = getBezierPath({ @@ -41,15 +41,9 @@ const InvocationDefaultEdge = ({ targetPosition, }); - const edgeStyles = useMemo( - () => ({ - strokeWidth: isSelected ? 3 : 2, - stroke, - opacity: isSelected ? 0.8 : 0.5, - animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined, - strokeDasharray: shouldAnimate ? 5 : 'none', - }), - [isSelected, shouldAnimate, stroke] + const edgeStyles = useMemo( + () => getEdgeStyles(stroke, selected, shouldAnimateEdges, areConnectedNodesSelected), + [areConnectedNodesSelected, stroke, selected, shouldAnimateEdges] ); return ( @@ -65,13 +59,13 @@ const InvocationDefaultEdge = ({ bg="base.800" borderRadius="base" borderWidth={1} - borderColor={isSelected ? 'undefined' : 'transparent'} - opacity={isSelected ? 1 : 0.5} + borderColor={selected ? 'undefined' : 'transparent'} + opacity={selected ? 1 : 0.5} py={1} px={3} shadow="md" > - + {label} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts index e7fa43015b..b5801c45ed 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts @@ -1,6 +1,7 @@ import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { FIELD_COLORS } from 'features/nodes/types/constants'; import type { FieldType } from 'features/nodes/types/field'; +import type { CSSProperties } from 'react'; export const getFieldColor = (fieldType: FieldType | null): string => { if (!fieldType) { @@ -10,3 +11,16 @@ export const getFieldColor = (fieldType: FieldType | null): string => { return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500'); }; + +export const getEdgeStyles = ( + stroke: string, + selected: boolean, + shouldAnimateEdges: boolean, + areConnectedNodesSelected: boolean +): CSSProperties => ({ + strokeWidth: 3, + stroke, + opacity: selected ? 1 : 0.5, + animation: shouldAnimateEdges ? 'dashdraw 0.5s linear infinite' : undefined, + strokeDasharray: selected || areConnectedNodesSelected ? 5 : 'none', +}); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts index 87ef8eb629..9c67728722 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts @@ -1,5 +1,6 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; +import { deepClone } from 'common/util/deepClone'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import type { Templates } from 'features/nodes/store/types'; import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice'; @@ -8,8 +9,8 @@ import { isInvocationNode } from 'features/nodes/types/invocation'; import { getFieldColor } from './getEdgeColor'; const defaultReturnValue = { - isSelected: false, - shouldAnimate: false, + areConnectedNodesSelected: false, + shouldAnimateEdges: false, stroke: colorTokenToCssVar('base.500'), label: '', }; @@ -19,21 +20,27 @@ export const makeEdgeSelector = ( source: string, sourceHandleId: string | null | undefined, target: string, - targetHandleId: string | null | undefined, - selected?: boolean + targetHandleId: string | null | undefined ) => createMemoizedSelector( selectNodesSlice, selectWorkflowSettingsSlice, - (nodes, workflowSettings): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => { + ( + nodes, + workflowSettings + ): { areConnectedNodesSelected: boolean; shouldAnimateEdges: boolean; stroke: string; label: string } => { + const { shouldAnimateEdges, shouldColorEdges } = workflowSettings; const sourceNode = nodes.nodes.find((node) => node.id === source); const targetNode = nodes.nodes.find((node) => node.id === target); + const returnValue = deepClone(defaultReturnValue); + returnValue.shouldAnimateEdges = shouldAnimateEdges; + const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode); - const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected); + returnValue.areConnectedNodesSelected = Boolean(sourceNode?.selected || targetNode?.selected); if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) { - return defaultReturnValue; + return returnValue; } const sourceNodeTemplate = templates[sourceNode.data.type]; @@ -42,16 +49,10 @@ export const makeEdgeSelector = ( const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId]; const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined; - const stroke = - sourceType && workflowSettings.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); + returnValue.stroke = sourceType && shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); - const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`; + returnValue.label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`; - return { - isSelected, - shouldAnimate: workflowSettings.shouldAnimateEdges && isSelected, - stroke, - label, - }; + return returnValue; } ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx index 0147bcaed2..baa7fc262a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx @@ -1,7 +1,7 @@ import { Flex, Grid, GridItem } from '@invoke-ai/ui-library'; import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper'; -import { useAnyOrDirectInputFieldNames } from 'features/nodes/hooks/useAnyOrDirectInputFieldNames'; -import { useConnectionInputFieldNames } from 'features/nodes/hooks/useConnectionInputFieldNames'; +import { InvocationInputFieldCheck } from 'features/nodes/components/flow/nodes/Invocation/fields/InvocationFieldCheck'; +import { useFieldNames } from 'features/nodes/hooks/useFieldNames'; import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames'; import { useWithFooter } from 'features/nodes/hooks/useWithFooter'; import { memo } from 'react'; @@ -20,8 +20,7 @@ type Props = { }; const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => { - const inputConnectionFieldNames = useConnectionInputFieldNames(nodeId); - const inputAnyOrDirectFieldNames = useAnyOrDirectInputFieldNames(nodeId); + const fieldNames = useFieldNames(nodeId); const withFooter = useWithFooter(nodeId); const outputFieldNames = useOutputFieldNames(nodeId); @@ -41,9 +40,11 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => { > - {inputConnectionFieldNames.map((fieldName, i) => ( + {fieldNames.connectionFields.map((fieldName, i) => ( - + + + ))} {outputFieldNames.map((fieldName, i) => ( @@ -52,8 +53,23 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => { ))} - {inputAnyOrDirectFieldNames.map((fieldName) => ( - + {fieldNames.anyOrDirectFields.map((fieldName) => ( + + + + ))} + {fieldNames.missingFields.map((fieldName) => ( + + + ))} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx index 959b13c2d0..143dee983f 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx @@ -2,10 +2,12 @@ import { Tooltip } from '@invoke-ai/ui-library'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor'; import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType'; +import type { ValidationResult } from 'features/nodes/store/util/validateConnection'; import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants'; -import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; +import { type FieldInputTemplate, type FieldOutputTemplate, isSingle } from 'features/nodes/types/field'; import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; import type { HandleType } from 'reactflow'; import { Handle, Position } from 'reactflow'; @@ -14,11 +16,12 @@ type FieldHandleProps = { handleType: HandleType; isConnectionInProgress: boolean; isConnectionStartField: boolean; - connectionError?: string; + validationResult: ValidationResult; }; const FieldHandle = (props: FieldHandleProps) => { - const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, connectionError } = props; + const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, validationResult } = props; + const { t } = useTranslation(); const { name } = fieldTemplate; const type = fieldTemplate.type; const fieldTypeName = useFieldTypeName(type); @@ -26,11 +29,11 @@ const FieldHandle = (props: FieldHandleProps) => { const isModelType = MODEL_TYPES.some((t) => t === type.name); const color = getFieldColor(type); const s: CSSProperties = { - backgroundColor: type.isCollection || type.isCollectionOrScalar ? colorTokenToCssVar('base.900') : color, + backgroundColor: !isSingle(type) ? colorTokenToCssVar('base.900') : color, position: 'absolute', width: '1rem', height: '1rem', - borderWidth: type.isCollection || type.isCollectionOrScalar ? 4 : 0, + borderWidth: !isSingle(type) ? 4 : 0, borderStyle: 'solid', borderColor: color, borderRadius: isModelType ? 4 : '100%', @@ -43,11 +46,11 @@ const FieldHandle = (props: FieldHandleProps) => { s.insetInlineEnd = '-1rem'; } - if (isConnectionInProgress && !isConnectionStartField && connectionError) { + if (isConnectionInProgress && !isConnectionStartField && !validationResult.isValid) { s.filter = 'opacity(0.4) grayscale(0.7)'; } - if (isConnectionInProgress && connectionError) { + if (isConnectionInProgress && !validationResult.isValid) { if (isConnectionStartField) { s.cursor = 'grab'; } else { @@ -58,14 +61,14 @@ const FieldHandle = (props: FieldHandleProps) => { } return s; - }, [connectionError, handleType, isConnectionInProgress, isConnectionStartField, type]); + }, [handleType, isConnectionInProgress, isConnectionStartField, type, validationResult.isValid]); const tooltip = useMemo(() => { - if (isConnectionInProgress && connectionError) { - return connectionError; + if (isConnectionInProgress && validationResult.messageTKey) { + return t(validationResult.messageTKey); } return fieldTypeName; - }, [connectionError, fieldTypeName, isConnectionInProgress]); + }, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]); return ( { - const { t } = useTranslation(); const fieldTemplate = useFieldInputTemplate(nodeId, fieldName); - const fieldInstance = useFieldInputInstance(nodeId, fieldName); const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName); const [isHovered, setIsHovered] = useState(false); - const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = + const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } = useConnectionState({ nodeId, fieldName, kind: 'inputs' }); const isMissingInput = useMemo(() => { @@ -55,20 +51,6 @@ const InputField = ({ nodeId, fieldName }: Props) => { setIsHovered(false); }, []); - if (!fieldTemplate || !fieldInstance) { - return ( - - - - {t('nodes.unknownInput', { - name: fieldInstance?.label ?? fieldTemplate?.title ?? fieldName, - })} - - - - ); - } - if (fieldTemplate.input === 'connection' || isConnected) { return ( @@ -88,7 +70,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { handleType="target" isConnectionInProgress={isConnectionInProgress} isConnectionStartField={isConnectionStartField} - connectionError={connectionError} + validationResult={validationResult} /> ); @@ -126,7 +108,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { handleType="target" isConnectionInProgress={isConnectionInProgress} isConnectionStartField={isConnectionStartField} - connectionError={connectionError} + validationResult={validationResult} /> )} @@ -134,27 +116,3 @@ const InputField = ({ nodeId, fieldName }: Props) => { }; export default memo(InputField); - -type InputFieldWrapperProps = PropsWithChildren<{ - shouldDim: boolean; -}>; - -const InputFieldWrapper = memo(({ shouldDim, children }: InputFieldWrapperProps) => { - return ( - - {children} - - ); -}); - -InputFieldWrapper.displayName = 'InputFieldWrapper'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index b6e331c114..99937ceec4 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -1,3 +1,4 @@ +import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent'; import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance'; import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate'; import { @@ -23,6 +24,8 @@ import { isLoRAModelFieldInputTemplate, isMainModelFieldInputInstance, isMainModelFieldInputTemplate, + isModelIdentifierFieldInputInstance, + isModelIdentifierFieldInputTemplate, isSchedulerFieldInputInstance, isSchedulerFieldInputTemplate, isSDXLMainModelFieldInputInstance, @@ -95,6 +98,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return ; } + if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) { + return ; + } + if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) { return ; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldWrapper.tsx new file mode 100644 index 0000000000..8723538f85 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldWrapper.tsx @@ -0,0 +1,27 @@ +import { Flex } from '@invoke-ai/ui-library'; +import type { PropsWithChildren } from 'react'; +import { memo } from 'react'; + +type InputFieldWrapperProps = PropsWithChildren<{ + shouldDim: boolean; +}>; + +export const InputFieldWrapper = memo(({ shouldDim, children }: InputFieldWrapperProps) => { + return ( + + {children} + + ); +}); + +InputFieldWrapper.displayName = 'InputFieldWrapper'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InvocationFieldCheck.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InvocationFieldCheck.tsx new file mode 100644 index 0000000000..f4b6be0cd6 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InvocationFieldCheck.tsx @@ -0,0 +1,59 @@ +import { Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppSelector } from 'app/store/storeHooks'; +import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { selectInvocationNode } from 'features/nodes/store/selectors'; +import type { PropsWithChildren } from 'react'; +import { memo, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; + +type Props = PropsWithChildren<{ + nodeId: string; + fieldName: string; +}>; + +export const InvocationInputFieldCheck = memo(({ nodeId, fieldName, children }: Props) => { + const { t } = useTranslation(); + const templates = useStore($templates); + const selector = useMemo( + () => + createSelector(selectNodesSlice, (nodesSlice) => { + const node = selectInvocationNode(nodesSlice, nodeId); + const instance = node.data.inputs[fieldName]; + const template = templates[node.data.type]; + const fieldTemplate = template?.inputs[fieldName]; + return { + name: instance?.label || fieldTemplate?.title || fieldName, + hasInstance: Boolean(instance), + hasTemplate: Boolean(fieldTemplate), + }; + }), + [fieldName, nodeId, templates] + ); + const { hasInstance, hasTemplate, name } = useAppSelector(selector); + + if (!hasTemplate || !hasInstance) { + return ( + + + + {t('nodes.unknownInput', { name })} + + + + ); + } + + return children; +}); + +InvocationInputFieldCheck.displayName = 'InvocationInputFieldCheck'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx index 0cd199f7a4..ef466b2882 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx @@ -3,6 +3,7 @@ import { CSS } from '@dnd-kit/utilities'; import { Flex, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay'; +import { InvocationInputFieldCheck } from 'features/nodes/components/flow/nodes/Invocation/fields/InvocationFieldCheck'; import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue'; import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; import { workflowExposedFieldRemoved } from 'features/nodes/store/workflowSlice'; @@ -20,7 +21,7 @@ type Props = { fieldName: string; }; -const LinearViewField = ({ nodeId, fieldName }: Props) => { +const LinearViewFieldInternal = ({ nodeId, fieldName }: Props) => { const dispatch = useAppDispatch(); const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName); const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId); @@ -99,4 +100,12 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => { ); }; +const LinearViewField = ({ nodeId, fieldName }: Props) => { + return ( + + + + ); +}; + export default memo(LinearViewField); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx index f2d776a2da..94e8b62744 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx @@ -18,7 +18,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => { const { t } = useTranslation(); const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName); - const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = + const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } = useConnectionState({ nodeId, fieldName, kind: 'outputs' }); if (!fieldTemplate) { @@ -52,7 +52,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => { handleType="source" isConnectionInProgress={isConnectionInProgress} isConnectionStartField={isConnectionStartField} - connectionError={connectionError} + validationResult={validationResult} /> ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx new file mode 100644 index 0000000000..4019689978 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx @@ -0,0 +1,66 @@ +import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library'; +import { EMPTY_ARRAY } from 'app/store/constants'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { fieldModelIdentifierValueChanged } from 'features/nodes/store/nodesSlice'; +import type { ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate } from 'features/nodes/types/field'; +import { memo, useCallback, useMemo } from 'react'; +import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const ModelIdentifierFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const dispatch = useAppDispatch(); + const { data, isLoading } = useGetModelConfigsQuery(); + const _onChange = useCallback( + (value: AnyModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldModelIdentifierValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + const modelConfigs = useMemo(() => { + if (!data) { + return EMPTY_ARRAY; + } + + return modelConfigsAdapterSelectors.selectAll(data); + }, [data]); + + const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + modelConfigs, + onChange: _onChange, + isLoading, + selectedModel: field.value, + groupByType: true, + }); + + return ( + + + + + + ); +}; + +export default memo(ModelIdentifierFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx index 966809cb0e..76666af396 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx @@ -48,7 +48,7 @@ const NotesNode = (props: NodeProps) => { gap={1} > -