mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
292 Commits
v4.2.0b2
...
psyche/fea
Author | SHA1 | Date | |
---|---|---|---|
8eb5316c9f | |||
12ce095bb2 | |||
242b2a0b59 | |||
86e201612f | |||
f6d1e1be22 | |||
edf043d1d6 | |||
70a21eda78 | |||
452f4fe0e6 | |||
87f2d04ddd | |||
9c45cbe8f7 | |||
5f7c852493 | |||
35ef02bdf7 | |||
7da283f433 | |||
812cf277b8 | |||
182cb51bf0 | |||
64a3adfc64 | |||
a48ef9f7a7 | |||
9aeabf10df | |||
7b93cc8538 | |||
a2480c16e7 | |||
b1e2dd222e | |||
1f92e9eec2 | |||
fb402f3b46 | |||
0abc328ddf | |||
cfa4e5f88e | |||
24d0d4932d | |||
20db93b901 | |||
500a733d79 | |||
338d5f158b | |||
63e4b224b2 | |||
e9043ff060 | |||
c725851c64 | |||
a1c4ef55d7 | |||
e25b39aca2 | |||
32a02b3329 | |||
620ee2875e | |||
5553588147 | |||
1c29b3bd85 | |||
e88b807a13 | |||
9e55ef3d4b | |||
8062a47d16 | |||
dba8c43ecb | |||
8ebf2ddf15 | |||
f4625c2671 | |||
c94742bde6 | |||
a34faf0bd8 | |||
ecfff6cb1e | |||
ba8bed6870 | |||
ca186bca61 | |||
e2f109807c | |||
281bd31db2 | |||
cea1874e00 | |||
89b0e9e4de | |||
26d0d55d97 | |||
059c5586a4 | |||
9ed5698aa8 | |||
0b5696c5d4 | |||
a51142674a | |||
b8b671c0db | |||
7cceafe0dd | |||
cbe32b647a | |||
9a8e0842bb | |||
1d7671298f | |||
e38d75c3dc | |||
21fab9785a | |||
b3429553bb | |||
e480844042 | |||
26029108f7 | |||
504ac82077 | |||
6b11740dda | |||
a80e3448f5 | |||
4bda174eb9 | |||
b1e28c2f2c | |||
83000a4190 | |||
c98205d0d7 | |||
ce2ad5903c | |||
fe3980a369 | |||
ea97ae5ae8 | |||
3605b6b1a3 | |||
fc31dddbf7 | |||
6ad01d824d | |||
78f9f3ee95 | |||
972398d203 | |||
857889d1fa | |||
8074a802d6 | |||
059d5a682c | |||
00c2d8f95d | |||
04a596179b | |||
3fcb2720d7 | |||
6f7160b9fd | |||
6b4e464d17 | |||
9f7841a04b | |||
468644ab18 | |||
9d127fee6b | |||
6658897210 | |||
af7b194bec | |||
de1ea50e6d | |||
2680ef52c2 | |||
a012bb6e07 | |||
6a2c53f6c5 | |||
2cbf7d9221 | |||
fe7ed72c9c | |||
85a5a7c47a | |||
af3fd26d4e | |||
5127fd6320 | |||
124d34a8cc | |||
e8387d7523 | |||
a5d08c981b | |||
811d0da0f0 | |||
17e1fc5254 | |||
84e031edc2 | |||
b6b7e737e0 | |||
5f3e7afd45 | |||
b0cfca9d24 | |||
985ef89825 | |||
5928ade5fd | |||
93ebc175c6 | |||
386d552493 | |||
799cf06d20 | |||
922716d2ab | |||
66fc110b64 | |||
822f1e1f06 | |||
5d60c3c8e1 | |||
4e21d01c7f | |||
6b7b0b3777 | |||
07feb5ba07 | |||
a18d7adad4 | |||
32dff2c4e3 | |||
575ecb4028 | |||
ad8778df6c | |||
d2f5103f9f | |||
dd42a56084 | |||
23ac340a3f | |||
6791b4eaa8 | |||
a8b042177d | |||
76825f4261 | |||
78cb4d75ad | |||
a18bbac262 | |||
9ff5596963 | |||
8ea596b1e9 | |||
e3a143eaed | |||
c359ab6d9b | |||
dbfaa07e03 | |||
7f78fe7a36 | |||
6cf5b402c6 | |||
b0c7c7cb47 | |||
4d68cd8dbb | |||
2c1fa30639 | |||
708c68413d | |||
1d884fb794 | |||
f6a44681a8 | |||
d4df312300 | |||
9c0d44b412 | |||
27826369f0 | |||
31d8b50276 | |||
40b4fa7238 | |||
3b1743b7c2 | |||
f489c818f1 | |||
af477fa295 | |||
0ff0290735 | |||
67dbe6d949 | |||
4c3c2297b9 | |||
cadea55521 | |||
c8f30b1392 | |||
3d14a98abf | |||
77024bfca7 | |||
4a1c3786a1 | |||
b239891986 | |||
9fb03d43ff | |||
bdc59786bd | |||
fb6e926500 | |||
48ccd63dba | |||
ee647a05dc | |||
154b52ca4d | |||
5dd460c3ce | |||
4897ce2a13 | |||
5425526d50 | |||
5a4b050e66 | |||
8d39520232 | |||
04d12a1e98 | |||
39aa70963b | |||
5743254a41 | |||
c538ffea26 | |||
e8d3a7c870 | |||
2be66b1546 | |||
76e181fd44 | |||
b5d42fbc66 | |||
b463cd763e | |||
eb320df41d | |||
de1869773f | |||
ef89c7e537 | |||
008645d386 | |||
f8042ffb41 | |||
dbe22be598 | |||
8f6078d007 | |||
4020bf47e2 | |||
9d685da759 | |||
e3289856c0 | |||
47b8153728 | |||
7901e4c082 | |||
18b0977a31 | |||
fc6b214470 | |||
e22211dac0 | |||
e222484663 | |||
2a9cea6689 | |||
93da75209c | |||
9c819f0fd8 | |||
eef6fcf286 | |||
e375d9f787 | |||
ab18174774 | |||
9265841384 | |||
c5fd08125d | |||
11d88dae7f | |||
3b495659b0 | |||
15c9a3a4b6 | |||
60e77e4ed6 | |||
fa832a8ac6 | |||
f7834d7d59 | |||
63d7461510 | |||
1de704160e | |||
b118a2565c | |||
eb166baafe | |||
818d37f304 | |||
9cdb801c1c | |||
5da8cde4fc | |||
6ec3dc0c0d | |||
6050dffb25 | |||
93efeafe30 | |||
f167e8a8d3 | |||
124d49f35e | |||
52d8efa892 | |||
4ea8416c68 | |||
8dd0bfb068 | |||
6ff1c7d541 | |||
19f5a9c3a9 | |||
d9ce9c62ac | |||
cdc468a38c | |||
2656f13a4a | |||
da61396b1c | |||
6c9fb617dc | |||
5dd73fe53e | |||
e6793be465 | |||
63e62c5720 | |||
0848cb8ebd | |||
1b777bb972 | |||
029ee90351 | |||
2f9a064d48 | |||
b180666497 | |||
4740cd4f64 | |||
8b51298ba1 | |||
1533429e54 | |||
fc000214a5 | |||
f631aea4ee | |||
32f4c1f966 | |||
adebe639e3 | |||
44280ed472 | |||
cec8840038 | |||
fc7f484935 | |||
1aa7cd57c2 | |||
722a91aedb | |||
03c24ca9cb | |||
5820579237 | |||
6c768bfe7e | |||
5ca794b94f | |||
d20695260d | |||
d8557d573b | |||
6c1fd584d2 | |||
e8e764be20 | |||
e8023c44b0 | |||
a3a6449786 | |||
e9d2ffe3d7 | |||
23ad6fb730 | |||
00f36cb491 | |||
3f489c92c8 | |||
f147f99bef | |||
6107e3d281 | |||
de33d6e647 | |||
e36e5871a1 | |||
8b25c1a62e | |||
dfbd7eb1cf | |||
b43b2714cc | |||
e537de2f6d | |||
ccd399e277 | |||
bfad814862 | |||
6e8b7f9421 | |||
e47629cbe7 | |||
e840de27ed | |||
8342f32f2e | |||
a7aa529b99 | |||
4adc592657 | |||
e8d60e8d83 | |||
886f5c90a3 |
@ -117,13 +117,13 @@ Stateless fields do not store their value in the node, so their field instances
|
||||
|
||||
"Custom" fields will always be treated as stateless fields.
|
||||
|
||||
##### Collection and Scalar Fields
|
||||
##### Single and Collection Fields
|
||||
|
||||
Field types have a name and two flags which may identify it as a **collection** or **collection or scalar** field.
|
||||
Field types have a name and cardinality property which may identify it as a **SINGLE**, **COLLECTION** or **SINGLE_OR_COLLECTION** field.
|
||||
|
||||
If a field is annotated in python as a list, its field type is parsed and flagged as a **collection** type (e.g. `list[int]`).
|
||||
|
||||
If it is annotated as a union of a type and list, the type will be flagged as a **collection or scalar** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
|
||||
- If a field is annotated in python as a singular value or class, its field type is parsed as a **SINGLE** type (e.g. `int`, `ImageField`, `str`).
|
||||
- If a field is annotated in python as a list, its field type is parsed as a **COLLECTION** type (e.g. `list[int]`).
|
||||
- If it is annotated as a union of a type and list, the type will be parsed as a **SINGLE_OR_COLLECTION** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
|
||||
|
||||
## Implementation
|
||||
|
||||
@ -173,8 +173,7 @@ Field types are represented as structured objects:
|
||||
```ts
|
||||
type FieldType = {
|
||||
name: string;
|
||||
isCollection: boolean;
|
||||
isCollectionOrScalar: boolean;
|
||||
cardinality: 'SINGLE' | 'COLLECTION' | 'SINGLE_OR_COLLECTION';
|
||||
};
|
||||
```
|
||||
|
||||
@ -186,7 +185,7 @@ There are 4 general cases for field type parsing.
|
||||
|
||||
When a field is annotated as a primitive values (e.g. `int`, `str`, `float`), the field type parsing is fairly straightforward. The field is represented by a simple OpenAPI **schema object**, which has a `type` property.
|
||||
|
||||
We create a field type name from this `type` string (e.g. `string` -> `StringField`).
|
||||
We create a field type name from this `type` string (e.g. `string` -> `StringField`). The cardinality is `"SINGLE"`.
|
||||
|
||||
##### Complex Types
|
||||
|
||||
@ -200,13 +199,13 @@ We need to **dereference** the schema to pull these out. Dereferencing may requi
|
||||
|
||||
When a field is annotated as a list of a single type, the schema object has an `items` property. They may be a schema object or reference object and must be parsed to determine the item type.
|
||||
|
||||
We use the item type for field type name, adding `isCollection: true` to the field type.
|
||||
We use the item type for field type name. The cardinality is `"COLLECTION"`.
|
||||
|
||||
##### Collection or Scalar Types
|
||||
##### Single or Collection Types
|
||||
|
||||
When a field is annotated as a union of a type and list of that type, the schema object has an `anyOf` property, which holds a list of valid types for the union.
|
||||
|
||||
After verifying that the union has two members (a type and list of the same type), we use the type for field type name, adding `isCollectionOrScalar: true` to the field type.
|
||||
After verifying that the union has two members (a type and list of the same type), we use the type for field type name, with cardinality `"SINGLE_OR_COLLECTION"`.
|
||||
|
||||
##### Optional Fields
|
||||
|
||||
|
@ -98,7 +98,7 @@ Updating is exactly the same as installing - download the latest installer, choo
|
||||
|
||||
If you have installation issues, please review the [FAQ]. You can also [create an issue] or ask for help on [discord].
|
||||
|
||||
[installation requirements]: INSTALLATION.md#installation-requirements
|
||||
[installation requirements]: INSTALL_REQUIREMENTS.md
|
||||
[FAQ]: ../help/FAQ.md
|
||||
[install some models]: 050_INSTALLING_MODELS.md
|
||||
[configuration docs]: ../features/CONFIGURATION.md
|
||||
|
@ -10,7 +10,7 @@ InvokeAI is distributed as a python package on PyPI, installable with `pip`. The
|
||||
|
||||
### Requirements
|
||||
|
||||
Before you start, go through the [installation requirements].
|
||||
Before you start, go through the [installation requirements](./INSTALL_REQUIREMENTS.md).
|
||||
|
||||
### Installation Walkthrough
|
||||
|
||||
@ -79,7 +79,7 @@ Before you start, go through the [installation requirements].
|
||||
|
||||
1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features.
|
||||
|
||||
- You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website]. Copy the `--extra-index-url` string from this and append it to your install command.
|
||||
- You may need to provide an [extra index URL](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-extra-index-url). Select your platform configuration using [this tool on the PyTorch website](https://pytorch.org/get-started/locally/). Copy the `--extra-index-url` string from this and append it to your install command.
|
||||
|
||||
!!! example "Install with an extra index URL"
|
||||
|
||||
@ -116,4 +116,4 @@ Before you start, go through the [installation requirements].
|
||||
|
||||
!!! warning
|
||||
|
||||
If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root_dir \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.
|
||||
If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.
|
||||
|
@ -37,13 +37,13 @@ Invoke runs best with a dedicated GPU, but will fall back to running on CPU, alb
|
||||
=== "Nvidia"
|
||||
|
||||
```
|
||||
Any GPU with at least 8GB VRAM. Linux only.
|
||||
Any GPU with at least 8GB VRAM.
|
||||
```
|
||||
|
||||
=== "AMD"
|
||||
|
||||
```
|
||||
Any GPU with at least 16GB VRAM.
|
||||
Any GPU with at least 16GB VRAM. Linux only.
|
||||
```
|
||||
|
||||
=== "Mac"
|
||||
|
@ -18,6 +18,7 @@ from ..services.boards.boards_default import BoardService
|
||||
from ..services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from ..services.config import InvokeAIAppConfig
|
||||
from ..services.download import DownloadQueueService
|
||||
from ..services.events.events_fastapievents import FastAPIEventService
|
||||
from ..services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
||||
from ..services.images.images_default import ImageService
|
||||
@ -33,7 +34,6 @@ from ..services.session_processor.session_processor_default import DefaultSessio
|
||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from ..services.urls.urls_default import LocalUrlService
|
||||
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
# TODO: is there a better way to achieve this?
|
||||
|
@ -1,52 +0,0 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from ..services.events.events_base import EventServiceBase
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
event_handler_id: int
|
||||
__queue: Queue
|
||||
__stop_event: threading.Event
|
||||
|
||||
def __init__(self, event_handler_id: int) -> None:
|
||||
self.event_handler_id = event_handler_id
|
||||
self.__queue = Queue()
|
||||
self.__stop_event = threading.Event()
|
||||
asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
|
||||
|
||||
super().__init__()
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self.__stop_event.set()
|
||||
self.__queue.put(None)
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
self.__queue.put({"event_name": event_name, "payload": payload})
|
||||
|
||||
async def __dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = self.__queue.get(block=False)
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
|
||||
dispatch(
|
||||
event.get("event_name"),
|
||||
payload=event.get("payload"),
|
||||
middleware_id=self.event_handler_id,
|
||||
)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
@ -13,7 +13,6 @@ from pydantic import BaseModel, Field
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
from invokeai.backend.util.logging import logging
|
||||
from invokeai.version import __version__
|
||||
|
||||
@ -109,9 +108,7 @@ async def get_config() -> AppConfig:
|
||||
upscaling_models.append(str(Path(model).stem))
|
||||
upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)
|
||||
|
||||
nsfw_methods = []
|
||||
if SafetyChecker.safety_checker_available():
|
||||
nsfw_methods.append("nsfw_checker")
|
||||
nsfw_methods = ["nsfw_checker"]
|
||||
|
||||
watermarking_methods = ["invisible_watermark"]
|
||||
|
||||
|
@ -6,13 +6,12 @@ from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request,
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from pydantic import BaseModel, Field, JsonValue
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
@ -42,13 +41,17 @@ async def upload_image(
|
||||
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
|
||||
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
|
||||
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
|
||||
metadata: Optional[JsonValue] = Body(
|
||||
default=None, description="The metadata to associate with the image", embed=True
|
||||
),
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image"""
|
||||
if not file.content_type or not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
metadata = None
|
||||
workflow = None
|
||||
_metadata = None
|
||||
_workflow = None
|
||||
_graph = None
|
||||
|
||||
contents = await file.read()
|
||||
try:
|
||||
@ -62,22 +65,28 @@ async def upload_image(
|
||||
|
||||
# TODO: retain non-invokeai metadata on upload?
|
||||
# attempt to parse metadata from image
|
||||
metadata_raw = pil_image.info.get("invokeai_metadata", None)
|
||||
if metadata_raw:
|
||||
try:
|
||||
metadata = MetadataFieldValidator.validate_json(metadata_raw)
|
||||
except ValidationError:
|
||||
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
||||
pass
|
||||
metadata_raw = metadata if isinstance(metadata, str) else pil_image.info.get("invokeai_metadata", None)
|
||||
if isinstance(metadata_raw, str):
|
||||
_metadata = metadata_raw
|
||||
else:
|
||||
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
||||
pass
|
||||
|
||||
# attempt to parse workflow from image
|
||||
workflow_raw = pil_image.info.get("invokeai_workflow", None)
|
||||
if workflow_raw is not None:
|
||||
try:
|
||||
workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw)
|
||||
except ValidationError:
|
||||
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
||||
pass
|
||||
if isinstance(workflow_raw, str):
|
||||
_workflow = workflow_raw
|
||||
else:
|
||||
ApiDependencies.invoker.services.logger.warn("Failed to parse workflow for uploaded image")
|
||||
pass
|
||||
|
||||
# attempt to extract graph from image
|
||||
graph_raw = pil_image.info.get("invokeai_graph", None)
|
||||
if isinstance(graph_raw, str):
|
||||
_graph = graph_raw
|
||||
else:
|
||||
ApiDependencies.invoker.services.logger.warn("Failed to parse graph for uploaded image")
|
||||
pass
|
||||
|
||||
try:
|
||||
image_dto = ApiDependencies.invoker.services.images.create(
|
||||
@ -86,8 +95,9 @@ async def upload_image(
|
||||
image_category=image_category,
|
||||
session_id=session_id,
|
||||
board_id=board_id,
|
||||
metadata=metadata,
|
||||
workflow=workflow,
|
||||
metadata=_metadata,
|
||||
workflow=_workflow,
|
||||
graph=_graph,
|
||||
is_intermediate=is_intermediate,
|
||||
)
|
||||
|
||||
@ -185,14 +195,21 @@ async def get_image_metadata(
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
class WorkflowAndGraphResponse(BaseModel):
|
||||
workflow: Optional[str] = Field(description="The workflow used to generate the image, as stringified JSON")
|
||||
graph: Optional[str] = Field(description="The graph used to generate the image, as stringified JSON")
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=Optional[WorkflowWithoutID]
|
||||
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse
|
||||
)
|
||||
async def get_image_workflow(
|
||||
image_name: str = Path(description="The name of image whose workflow to get"),
|
||||
) -> Optional[WorkflowWithoutID]:
|
||||
) -> WorkflowAndGraphResponse:
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_workflow(image_name)
|
||||
workflow = ApiDependencies.invoker.services.images.get_workflow(image_name)
|
||||
graph = ApiDependencies.invoker.services.images.get_graph(image_name)
|
||||
return WorkflowAndGraphResponse(workflow=workflow, graph=graph)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
@ -6,7 +6,7 @@ import pathlib
|
||||
import shutil
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
@ -16,7 +16,8 @@ from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
@ -52,6 +53,13 @@ class ModelsList(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
|
||||
"""Add a cover image URL to a model configuration."""
|
||||
cover_image = dependencies.invoker.services.model_images.get_url(config.key)
|
||||
config.cover_image = cover_image
|
||||
return config
|
||||
|
||||
|
||||
##############################################################################
|
||||
# These are example inputs and outputs that are used in places where Swagger
|
||||
# is unable to generate a correct example.
|
||||
@ -118,8 +126,7 @@ async def list_model_records(
|
||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
)
|
||||
for model in found_models:
|
||||
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
|
||||
model.cover_image = cover_image
|
||||
model = add_cover_image_to_model_config(model, ApiDependencies)
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@ -160,12 +167,9 @@ async def get_model_record(
|
||||
key: str = Path(description="Key of the model record to fetch."),
|
||||
) -> AnyModelConfig:
|
||||
"""Get a model record"""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
config: AnyModelConfig = record_store.get_model(key)
|
||||
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
|
||||
config.cover_image = cover_image
|
||||
return config
|
||||
config = ApiDependencies.invoker.services.model_manager.store.get_model(key)
|
||||
return add_cover_image_to_model_config(config, ApiDependencies)
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@ -294,14 +298,15 @@ async def update_model_record(
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
record_store.update_model(key, changes=changes)
|
||||
model_response: AnyModelConfig = installer.sync_model_path(key)
|
||||
config = installer.sync_model_path(key)
|
||||
config = add_cover_image_to_model_config(config, ApiDependencies)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return model_response
|
||||
return config
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
@ -648,6 +653,14 @@ async def convert_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
# Update the model image if the model had one
|
||||
try:
|
||||
model_image = ApiDependencies.invoker.services.model_images.get(key)
|
||||
ApiDependencies.invoker.services.model_images.save(model_image, new_key)
|
||||
ApiDependencies.invoker.services.model_images.delete(key)
|
||||
except ModelImageFileNotFoundException:
|
||||
pass
|
||||
|
||||
# delete the original safetensors file
|
||||
installer.delete(key)
|
||||
|
||||
@ -655,7 +668,8 @@ async def convert_model(
|
||||
shutil.rmtree(cache_path)
|
||||
|
||||
# return the config record for the new diffusers directory
|
||||
new_config: AnyModelConfig = store.get_model(new_key)
|
||||
new_config = store.get_model(new_key)
|
||||
new_config = add_cover_image_to_model_config(new_config, ApiDependencies)
|
||||
return new_config
|
||||
|
||||
|
||||
|
@ -203,6 +203,7 @@ async def get_batch_status(
|
||||
responses={
|
||||
200: {"model": SessionQueueItem},
|
||||
},
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
async def get_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
|
@ -1,66 +1,131 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from pydantic import BaseModel
|
||||
from socketio import ASGIApp, AsyncServer
|
||||
|
||||
from ..services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BatchEnqueuedEvent,
|
||||
BulkDownloadCompleteEvent,
|
||||
BulkDownloadErrorEvent,
|
||||
BulkDownloadEventBase,
|
||||
BulkDownloadStartedEvent,
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
FastAPIEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
ModelEventBase,
|
||||
ModelInstallCancelledEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallErrorEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelLoadStartedEvent,
|
||||
QueueClearedEvent,
|
||||
QueueEventBase,
|
||||
QueueItemStatusChangedEvent,
|
||||
SessionCanceledEvent,
|
||||
SessionCompleteEvent,
|
||||
SessionStartedEvent,
|
||||
register_events,
|
||||
)
|
||||
|
||||
|
||||
class QueueSubscriptionEvent(BaseModel):
|
||||
queue_id: str
|
||||
|
||||
|
||||
class BulkDownloadSubscriptionEvent(BaseModel):
|
||||
bulk_download_id: str
|
||||
|
||||
|
||||
class SocketIO:
|
||||
__sio: AsyncServer
|
||||
__app: ASGIApp
|
||||
_sub_queue = "subscribe_queue"
|
||||
_unsub_queue = "unsubscribe_queue"
|
||||
|
||||
__sub_queue: str = "subscribe_queue"
|
||||
__unsub_queue: str = "unsubscribe_queue"
|
||||
|
||||
__sub_bulk_download: str = "subscribe_bulk_download"
|
||||
__unsub_bulk_download: str = "unsubscribe_bulk_download"
|
||||
_sub_bulk_download = "subscribe_bulk_download"
|
||||
_unsub_bulk_download = "unsubscribe_bulk_download"
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
|
||||
app.mount("/ws", self.__app)
|
||||
self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
|
||||
app.mount("/ws", self._app)
|
||||
|
||||
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue)
|
||||
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
|
||||
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
|
||||
self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
|
||||
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
|
||||
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
|
||||
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
|
||||
|
||||
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
|
||||
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
|
||||
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)
|
||||
|
||||
async def _handle_queue_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["queue_id"],
|
||||
register_events(
|
||||
{
|
||||
InvocationStartedEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationErrorEvent,
|
||||
SessionStartedEvent,
|
||||
SessionCompleteEvent,
|
||||
SessionCanceledEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
BatchEnqueuedEvent,
|
||||
QueueClearedEvent,
|
||||
},
|
||||
self._handle_queue_event,
|
||||
)
|
||||
|
||||
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
|
||||
if "queue_id" in data:
|
||||
await self.__sio.enter_room(sid, data["queue_id"])
|
||||
|
||||
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
|
||||
if "queue_id" in data:
|
||||
await self.__sio.leave_room(sid, data["queue_id"])
|
||||
|
||||
async def _handle_model_event(self, event: Event) -> None:
|
||||
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
|
||||
|
||||
async def _handle_bulk_download_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["bulk_download_id"],
|
||||
register_events(
|
||||
{
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
ModelLoadStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallCancelledEvent,
|
||||
ModelInstallErrorEvent,
|
||||
},
|
||||
self._handle_model_event,
|
||||
)
|
||||
|
||||
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
|
||||
if "bulk_download_id" in data:
|
||||
await self.__sio.enter_room(sid, data["bulk_download_id"])
|
||||
register_events(
|
||||
{BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent},
|
||||
self._handle_bulk_image_download_event,
|
||||
)
|
||||
|
||||
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
|
||||
if "bulk_download_id" in data:
|
||||
await self.__sio.leave_room(sid, data["bulk_download_id"])
|
||||
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
|
||||
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
|
||||
|
||||
async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
|
||||
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
|
||||
|
||||
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
|
||||
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
||||
|
||||
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
|
||||
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
||||
|
||||
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
|
||||
event_name, payload = event
|
||||
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.queue_id)
|
||||
|
||||
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase]) -> None:
|
||||
event_name, payload = event
|
||||
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"))
|
||||
|
||||
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
|
||||
event_name, payload = event
|
||||
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.bulk_download_id)
|
||||
|
@ -27,6 +27,7 @@ import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@ -164,6 +165,12 @@ def custom_openapi() -> dict[str, Any]:
|
||||
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = schema_json
|
||||
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
|
||||
@ -172,25 +179,18 @@ def custom_openapi() -> dict[str, Any]:
|
||||
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
invoker_schema["output"] = outputs_ref
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
|
||||
invoker_schema["class"] = "invocation"
|
||||
|
||||
# This code no longer seems to be necessary?
|
||||
# Leave it here just in case
|
||||
#
|
||||
# from invokeai.backend.model_manager import get_model_config_formats
|
||||
# formats = get_model_config_formats()
|
||||
# for model_config_name, enum_set in formats.items():
|
||||
|
||||
# if model_config_name in openapi_schema["components"]["schemas"]:
|
||||
# # print(f"Config with name {name} already defined")
|
||||
# continue
|
||||
|
||||
# openapi_schema["components"]["schemas"][model_config_name] = {
|
||||
# "title": model_config_name,
|
||||
# "description": "An enumeration.",
|
||||
# "type": "string",
|
||||
# "enum": [v.value for v in enum_set],
|
||||
# }
|
||||
# Add all event schemas
|
||||
for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
|
||||
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
|
||||
if "$defs" in json_schema:
|
||||
for schema_key, schema in json_schema["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = schema
|
||||
del json_schema["$defs"]
|
||||
openapi_schema["components"]["schemas"][event.__name__] = json_schema
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
@ -24,7 +24,6 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIType,
|
||||
@ -80,13 +79,13 @@ class ControlOutput(BaseInvocationOutput):
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1")
|
||||
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
|
@ -1,6 +1,5 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
|
||||
import cv2
|
||||
@ -504,7 +503,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Blur NSFW Image",
|
||||
tags=["image", "nsfw"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.3",
|
||||
)
|
||||
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Add blur to NSFW-flagged images"""
|
||||
@ -516,23 +515,12 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
logger = context.logger
|
||||
logger.debug("Running NSFW checker")
|
||||
if SafetyChecker.has_nsfw_concept(image):
|
||||
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
|
||||
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||
caution = self._get_caution_img()
|
||||
blurry_image.paste(caution, (0, 0), caution)
|
||||
image = blurry_image
|
||||
image = SafetyChecker.blur_if_nsfw(image)
|
||||
|
||||
image_dto = context.images.save(image=image)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
def _get_caution_img(self) -> Image.Image:
|
||||
import invokeai.app.assets.images as image_assets
|
||||
|
||||
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
|
||||
return caution.resize((caution.width // 2, caution.height // 2))
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_watermark",
|
||||
|
@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
@ -58,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
|
||||
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
|
||||
|
||||
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0")
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.1")
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
"""Collects IP-Adapter info to pass to other nodes."""
|
||||
|
||||
@ -67,7 +67,6 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
ip_adapter_model: ModelIdentifierField = InputField(
|
||||
description="The IP-Adapter model.",
|
||||
title="IP-Adapter Model",
|
||||
input=Input.Direct,
|
||||
ui_order=-1,
|
||||
ui_type=UIType.IPAdapterModel,
|
||||
)
|
||||
|
@ -586,13 +586,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Scheduler,
|
||||
) -> StableDiffusionGeneratorPipeline:
|
||||
# TODO:
|
||||
# configure_model_padding(
|
||||
# unet,
|
||||
# self.seamless,
|
||||
# self.seamless_axes,
|
||||
# )
|
||||
|
||||
class FakeVae:
|
||||
class FakeVaeConfig:
|
||||
def __init__(self) -> None:
|
||||
|
@ -11,6 +11,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType,
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -93,19 +94,46 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
||||
pass
|
||||
|
||||
|
||||
@invocation_output("model_identifier_output")
|
||||
class ModelIdentifierOutput(BaseInvocationOutput):
|
||||
"""Model identifier output"""
|
||||
|
||||
model: ModelIdentifierField = OutputField(description="Model identifier", title="Model")
|
||||
|
||||
|
||||
@invocation(
|
||||
"model_identifier",
|
||||
title="Model identifier",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class ModelIdentifierInvocation(BaseInvocation):
|
||||
"""Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as
|
||||
input for any model, even if the model types don't match. If you connect this to a mismatched input, you'll get an
|
||||
error."""
|
||||
|
||||
model: ModelIdentifierField = InputField(description="The model to select", title="Model")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
|
||||
if not context.models.exists(self.model.key):
|
||||
raise Exception(f"Unknown model {self.model.key}")
|
||||
|
||||
return ModelIdentifierOutput(model=self.model)
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
title="Main Model",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.2",
|
||||
version="1.0.3",
|
||||
)
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
|
||||
)
|
||||
model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel)
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
@ -134,12 +162,12 @@ class LoRALoaderOutput(BaseInvocationOutput):
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.2")
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3")
|
||||
class LoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
@ -190,6 +218,75 @@ class LoRALoaderInvocation(BaseInvocation):
|
||||
return output
|
||||
|
||||
|
||||
@invocation_output("lora_selector_output")
|
||||
class LoRASelectorOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
|
||||
|
||||
|
||||
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.1")
|
||||
class LoRASelectorInvocation(BaseInvocation):
|
||||
"""Selects a LoRA model and weight."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoRASelectorOutput:
|
||||
return LoRASelectorOutput(lora=LoRAField(lora=self.lora, weight=self.weight))
|
||||
|
||||
|
||||
@invocation("lora_collection_loader", title="LoRA Collection Loader", tags=["model"], category="model", version="1.0.0")
|
||||
class LoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""
|
||||
|
||||
loras: LoRAField | list[LoRAField] = InputField(
|
||||
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
|
||||
)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
clip: Optional[CLIPField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
|
||||
output = LoRALoaderOutput()
|
||||
loras = self.loras if isinstance(self.loras, list) else [self.loras]
|
||||
added_loras: list[str] = []
|
||||
|
||||
for lora in loras:
|
||||
if lora.lora.key in added_loras:
|
||||
continue
|
||||
|
||||
if not context.models.exists(lora.lora.key):
|
||||
raise Exception(f"Unknown lora: {lora.lora.key}!")
|
||||
|
||||
assert lora.lora.base in (BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2)
|
||||
|
||||
added_loras.append(lora.lora.key)
|
||||
|
||||
if self.unet is not None:
|
||||
if output.unet is None:
|
||||
output.unet = self.unet.model_copy(deep=True)
|
||||
output.unet.loras.append(lora)
|
||||
|
||||
if self.clip is not None:
|
||||
if output.clip is None:
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
output.clip.loras.append(lora)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@invocation_output("sdxl_lora_loader_output")
|
||||
class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL LoRA Loader Output"""
|
||||
@ -204,13 +301,13 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
||||
title="SDXL LoRA",
|
||||
tags=["lora", "model"],
|
||||
category="model",
|
||||
version="1.0.2",
|
||||
version="1.0.3",
|
||||
)
|
||||
class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
|
||||
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
@ -279,12 +376,78 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
return output
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2")
|
||||
@invocation(
|
||||
"sdxl_lora_collection_loader",
|
||||
title="SDXL LoRA Collection Loader",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SDXLLoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""
|
||||
|
||||
loras: LoRAField | list[LoRAField] = InputField(
|
||||
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
|
||||
)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
clip: Optional[CLIPField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP",
|
||||
)
|
||||
clip2: Optional[CLIPField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP 2",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
|
||||
output = SDXLLoRALoaderOutput()
|
||||
loras = self.loras if isinstance(self.loras, list) else [self.loras]
|
||||
added_loras: list[str] = []
|
||||
|
||||
for lora in loras:
|
||||
if lora.lora.key in added_loras:
|
||||
continue
|
||||
|
||||
if not context.models.exists(lora.lora.key):
|
||||
raise Exception(f"Unknown lora: {lora.lora.key}!")
|
||||
|
||||
assert lora.lora.base is BaseModelType.StableDiffusionXL
|
||||
|
||||
added_loras.append(lora.lora.key)
|
||||
|
||||
if self.unet is not None:
|
||||
if output.unet is None:
|
||||
output.unet = self.unet.model_copy(deep=True)
|
||||
output.unet.loras.append(lora)
|
||||
|
||||
if self.clip is not None:
|
||||
if output.clip is None:
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
output.clip.loras.append(lora)
|
||||
|
||||
if self.clip2 is not None:
|
||||
if output.clip2 is None:
|
||||
output.clip2 = self.clip2.model_copy(deep=True)
|
||||
output.clip2.loras.append(lora)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.3")
|
||||
class VAELoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel
|
||||
description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
|
||||
@ -30,12 +30,12 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.2")
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.3")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
||||
description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
@ -67,13 +67,13 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
title="SDXL Refiner Model",
|
||||
tags=["model", "sdxl", "refiner"],
|
||||
category="model",
|
||||
version="1.0.2",
|
||||
version="1.0.3",
|
||||
)
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
|
||||
description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
|
@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
@ -45,7 +45,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@invocation(
|
||||
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.2"
|
||||
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.3"
|
||||
)
|
||||
class T2IAdapterInvocation(BaseInvocation):
|
||||
"""Collects T2I-Adapter info to pass to other nodes."""
|
||||
@ -55,7 +55,6 @@ class T2IAdapterInvocation(BaseInvocation):
|
||||
t2i_adapter_model: ModelIdentifierField = InputField(
|
||||
description="The T2I-Adapter model.",
|
||||
title="T2I-Adapter Model",
|
||||
input=Input.Direct,
|
||||
ui_order=-1,
|
||||
ui_type=UIType.T2IAdapterModel,
|
||||
)
|
||||
|
@ -106,9 +106,7 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
self._invoker.services.events.emit_bulk_download_started(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name
|
||||
)
|
||||
|
||||
def _signal_job_completed(
|
||||
@ -118,10 +116,8 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
assert bulk_download_item_name is not None
|
||||
self._invoker.services.events.emit_bulk_download_completed(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
self._invoker.services.events.emit_bulk_download_complete(
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name
|
||||
)
|
||||
|
||||
def _signal_job_failed(
|
||||
@ -131,11 +127,8 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
assert exception is not None
|
||||
self._invoker.services.events.emit_bulk_download_failed(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
error=str(exception),
|
||||
self._invoker.services.events.emit_bulk_download_error(
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
|
||||
)
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
|
@ -8,14 +8,13 @@ import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
|
||||
|
||||
import requests
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
@ -30,6 +29,9 @@ from .download_base import (
|
||||
UnknownJobIDException,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
|
||||
# Maximum number of bytes to download during each call to requests.iter_content()
|
||||
DOWNLOAD_CHUNK_SIZE = 100000
|
||||
|
||||
@ -40,7 +42,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
max_parallel_dl: int = 5,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
requests_session: Optional[requests.sessions.Session] = None,
|
||||
):
|
||||
"""
|
||||
@ -343,8 +345,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
|
||||
self._event_bus.emit_download_started(job)
|
||||
|
||||
def _signal_job_progress(self, job: DownloadJob) -> None:
|
||||
if job.on_progress:
|
||||
@ -355,13 +356,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_progress(
|
||||
str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
current_bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
)
|
||||
self._event_bus.emit_download_progress(job)
|
||||
|
||||
def _signal_job_complete(self, job: DownloadJob) -> None:
|
||||
job.status = DownloadJobStatus.COMPLETED
|
||||
@ -373,10 +368,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_complete(
|
||||
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
|
||||
)
|
||||
self._event_bus.emit_download_complete(job)
|
||||
|
||||
def _signal_job_cancelled(self, job: DownloadJob) -> None:
|
||||
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
|
||||
@ -390,7 +382,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_cancelled(str(job.source))
|
||||
self._event_bus.emit_download_cancelled(job)
|
||||
|
||||
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
job.status = DownloadJobStatus.ERROR
|
||||
@ -403,9 +395,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.error_type
|
||||
assert job.error
|
||||
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
|
||||
self._event_bus.emit_download_error(job)
|
||||
|
||||
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
|
||||
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")
|
||||
|
@ -1,486 +1,253 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BatchEnqueuedEvent,
|
||||
BulkDownloadCompleteEvent,
|
||||
BulkDownloadErrorEvent,
|
||||
BulkDownloadStartedEvent,
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
EventBase,
|
||||
ExtraData,
|
||||
InvocationCompleteEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
ModelInstallCancelledEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallErrorEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelLoadStartedEvent,
|
||||
QueueClearedEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
SessionCanceledEvent,
|
||||
SessionCompleteEvent,
|
||||
SessionStartedEvent,
|
||||
)
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.download.download_base import DownloadJob
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
queue_event: str = "queue_event"
|
||||
bulk_download_event: str = "bulk_download_event"
|
||||
download_event: str = "download_event"
|
||||
model_event: str = "model_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
def dispatch(self, event: "EventBase") -> None:
|
||||
pass
|
||||
|
||||
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Bulk download events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.bulk_download_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
# region: Invocation
|
||||
|
||||
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Queue events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.queue_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def __emit_download_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.download_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def __emit_model_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.model_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
# Define events here for every event in the system.
|
||||
# This will make them easier to integrate until we find a schema generator.
|
||||
def emit_generator_progress(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node_id: str,
|
||||
source_node_id: str,
|
||||
progress_image: Optional[ProgressImage],
|
||||
step: int,
|
||||
order: int,
|
||||
total_steps: int,
|
||||
def emit_invocation_started(
|
||||
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", extra: Optional[ExtraData] = None
|
||||
) -> None:
|
||||
"""Emitted when there is generation progress"""
|
||||
self.__emit_queue_event(
|
||||
event_name="generator_progress",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node_id": node_id,
|
||||
"source_node_id": source_node_id,
|
||||
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
|
||||
"step": step,
|
||||
"order": order,
|
||||
"total_steps": total_steps,
|
||||
},
|
||||
"""Emitted when an invocation is started"""
|
||||
self.dispatch(InvocationStartedEvent.build(queue_item, invocation, extra))
|
||||
|
||||
def emit_invocation_denoise_progress(
|
||||
self,
|
||||
queue_item: "SessionQueueItem",
|
||||
invocation: "BaseInvocation",
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
progress_image: "ProgressImage",
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted at each step during denoising of an invocation."""
|
||||
self.dispatch(
|
||||
InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image, extra)
|
||||
)
|
||||
|
||||
def emit_invocation_complete(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
result: dict,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
queue_item: "SessionQueueItem",
|
||||
invocation: "BaseInvocation",
|
||||
output: "BaseInvocationOutput",
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_complete",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
"result": result,
|
||||
},
|
||||
)
|
||||
"""Emitted when an invocation is complete"""
|
||||
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output, extra))
|
||||
|
||||
def emit_invocation_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
queue_item: "SessionQueueItem",
|
||||
invocation: "BaseInvocation",
|
||||
error_type: str,
|
||||
error: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_error",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
"""Emitted when an invocation encounters an error"""
|
||||
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error, extra))
|
||||
|
||||
def emit_invocation_started(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_started",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
},
|
||||
)
|
||||
# endregion
|
||||
|
||||
def emit_graph_execution_complete(
|
||||
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
|
||||
) -> None:
|
||||
# region Session
|
||||
|
||||
def emit_session_started(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a session has started"""
|
||||
self.dispatch(SessionStartedEvent.build(queue_item, extra))
|
||||
|
||||
def emit_session_complete(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a session has completed all invocations"""
|
||||
self.__emit_queue_event(
|
||||
event_name="graph_execution_state_complete",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
},
|
||||
)
|
||||
self.dispatch(SessionCompleteEvent.build(queue_item, extra))
|
||||
|
||||
def emit_model_load_started(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Emitted when a model is requested"""
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_started",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_config": model_config.model_dump(mode="json"),
|
||||
"submodel_type": submodel_type,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_load_completed(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_completed",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_config": model_config.model_dump(mode="json"),
|
||||
"submodel_type": submodel_type,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_session_canceled(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
) -> None:
|
||||
def emit_session_canceled(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a session is canceled"""
|
||||
self.__emit_queue_event(
|
||||
event_name="session_canceled",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
},
|
||||
)
|
||||
self.dispatch(SessionCanceledEvent.build(queue_item, extra))
|
||||
|
||||
# endregion
|
||||
|
||||
# region Queue
|
||||
|
||||
def emit_queue_item_status_changed(
|
||||
self,
|
||||
session_queue_item: SessionQueueItem,
|
||||
batch_status: BatchStatus,
|
||||
queue_status: SessionQueueStatus,
|
||||
queue_item: "SessionQueueItem",
|
||||
batch_status: "BatchStatus",
|
||||
queue_status: "SessionQueueStatus",
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a queue item's status changes"""
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_item_status_changed",
|
||||
payload={
|
||||
"queue_id": queue_status.queue_id,
|
||||
"queue_item": {
|
||||
"queue_id": session_queue_item.queue_id,
|
||||
"item_id": session_queue_item.item_id,
|
||||
"status": session_queue_item.status,
|
||||
"batch_id": session_queue_item.batch_id,
|
||||
"session_id": session_queue_item.session_id,
|
||||
"error": session_queue_item.error,
|
||||
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
||||
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
||||
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||
},
|
||||
"batch_status": batch_status.model_dump(mode="json"),
|
||||
"queue_status": queue_status.model_dump(mode="json"),
|
||||
},
|
||||
)
|
||||
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status, extra))
|
||||
|
||||
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
|
||||
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a batch is enqueued"""
|
||||
self.__emit_queue_event(
|
||||
event_name="batch_enqueued",
|
||||
payload={
|
||||
"queue_id": enqueue_result.queue_id,
|
||||
"batch_id": enqueue_result.batch.batch_id,
|
||||
"enqueued": enqueue_result.enqueued,
|
||||
},
|
||||
)
|
||||
self.dispatch(BatchEnqueuedEvent.build(enqueue_result, extra))
|
||||
|
||||
def emit_queue_cleared(self, queue_id: str) -> None:
|
||||
"""Emitted when the queue is cleared"""
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_cleared",
|
||||
payload={"queue_id": queue_id},
|
||||
)
|
||||
def emit_queue_cleared(self, queue_id: str, extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a queue is cleared"""
|
||||
self.dispatch(QueueClearedEvent.build(queue_id, extra))
|
||||
|
||||
def emit_download_started(self, source: str, download_path: str) -> None:
|
||||
"""
|
||||
Emit when a download job is started.
|
||||
# endregion
|
||||
|
||||
:param url: The downloaded url
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_started",
|
||||
payload={"source": source, "download_path": download_path},
|
||||
)
|
||||
# region Download
|
||||
|
||||
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
|
||||
"""
|
||||
Emit "download_progress" events at regular intervals during a download job.
|
||||
def emit_download_started(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a download is started"""
|
||||
self.dispatch(DownloadStartedEvent.build(job, extra))
|
||||
|
||||
:param source: The downloaded source
|
||||
:param download_path: The local downloaded file
|
||||
:param current_bytes: Number of bytes downloaded so far
|
||||
:param total_bytes: The size of the file being downloaded (if known)
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_progress",
|
||||
payload={
|
||||
"source": source,
|
||||
"download_path": download_path,
|
||||
"current_bytes": current_bytes,
|
||||
"total_bytes": total_bytes,
|
||||
},
|
||||
)
|
||||
def emit_download_progress(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted at intervals during a download"""
|
||||
self.dispatch(DownloadProgressEvent.build(job, extra))
|
||||
|
||||
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
|
||||
"""
|
||||
Emit a "download_complete" event at the end of a successful download.
|
||||
def emit_download_complete(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a download is completed"""
|
||||
self.dispatch(DownloadCompleteEvent.build(job, extra))
|
||||
|
||||
:param source: Source URL
|
||||
:param download_path: Path to the locally downloaded file
|
||||
:param total_bytes: The size of the downloaded file
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_complete",
|
||||
payload={
|
||||
"source": source,
|
||||
"download_path": download_path,
|
||||
"total_bytes": total_bytes,
|
||||
},
|
||||
)
|
||||
def emit_download_cancelled(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a download is cancelled"""
|
||||
self.dispatch(DownloadCancelledEvent.build(job, extra))
|
||||
|
||||
def emit_download_cancelled(self, source: str) -> None:
|
||||
"""Emit a "download_cancelled" event in the event that the download was cancelled by user."""
|
||||
self.__emit_download_event(
|
||||
event_name="download_cancelled",
|
||||
payload={
|
||||
"source": source,
|
||||
},
|
||||
)
|
||||
def emit_download_error(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when a download encounters an error"""
|
||||
self.dispatch(DownloadErrorEvent.build(job, extra))
|
||||
|
||||
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
|
||||
"""
|
||||
Emit a "download_error" event when an download job encounters an exception.
|
||||
# endregion
|
||||
|
||||
:param source: Source URL
|
||||
:param error_type: The name of the exception that raised the error
|
||||
:param error: The traceback from this error
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_error",
|
||||
payload={
|
||||
"source": source,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
# region Model loading
|
||||
|
||||
def emit_model_install_downloading(
|
||||
def emit_model_load_started(
|
||||
self,
|
||||
source: str,
|
||||
local_path: str,
|
||||
bytes: int,
|
||||
total_bytes: int,
|
||||
parts: List[Dict[str, Union[str, int]]],
|
||||
id: int,
|
||||
config: "AnyModelConfig",
|
||||
submodel_type: Optional["SubModelType"] = None,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Emit at intervals while the install job is in progress (remote models only).
|
||||
"""Emitted when a model load is started."""
|
||||
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type, extra))
|
||||
|
||||
:param source: Source of the model
|
||||
:param local_path: Where model is downloading to
|
||||
:param parts: Progress of downloading URLs that comprise the model, if any.
|
||||
:param bytes: Number of bytes downloaded so far.
|
||||
:param total_bytes: Total size of download, including all files.
|
||||
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_downloading",
|
||||
payload={
|
||||
"source": source,
|
||||
"local_path": local_path,
|
||||
"bytes": bytes,
|
||||
"total_bytes": total_bytes,
|
||||
"parts": parts,
|
||||
"id": id,
|
||||
},
|
||||
)
|
||||
def emit_model_load_complete(
|
||||
self,
|
||||
config: "AnyModelConfig",
|
||||
submodel_type: Optional["SubModelType"] = None,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a model load is complete."""
|
||||
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type, extra))
|
||||
|
||||
def emit_model_install_downloads_done(self, source: str) -> None:
|
||||
"""
|
||||
Emit once when all parts are downloaded, but before the probing and registration start.
|
||||
# endregion
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_downloads_done",
|
||||
payload={"source": source},
|
||||
)
|
||||
# region Model install
|
||||
|
||||
def emit_model_install_running(self, source: str) -> None:
|
||||
"""
|
||||
Emit once when an install job becomes active.
|
||||
def emit_model_install_download_progress(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted at intervals while the install job is in progress (remote models only)."""
|
||||
self.dispatch(ModelInstallDownloadProgressEvent.build(job, extra))
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_running",
|
||||
payload={"source": source},
|
||||
)
|
||||
def emit_model_install_downloads_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job, extra))
|
||||
|
||||
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
|
||||
"""
|
||||
Emit when an install job is completed successfully.
|
||||
def emit_model_install_started(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted once when an install job is started (after any download)."""
|
||||
self.dispatch(ModelInstallStartedEvent.build(job, extra))
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
:param key: Model config record key
|
||||
:param total_bytes: Size of the model (may be None for installation of a local path)
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_completed",
|
||||
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
|
||||
)
|
||||
def emit_model_install_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when an install job is completed successfully."""
|
||||
self.dispatch(ModelInstallCompleteEvent.build(job, extra))
|
||||
|
||||
def emit_model_install_cancelled(self, source: str, id: int) -> None:
|
||||
"""
|
||||
Emit when an install job is cancelled.
|
||||
def emit_model_install_cancelled(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when an install job is cancelled."""
|
||||
self.dispatch(ModelInstallCancelledEvent.build(job, extra))
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_cancelled",
|
||||
payload={"source": source, "id": id},
|
||||
)
|
||||
def emit_model_install_error(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
|
||||
"""Emitted when an install job encounters an exception."""
|
||||
self.dispatch(ModelInstallErrorEvent.build(job, extra))
|
||||
|
||||
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
|
||||
"""
|
||||
Emit when an install job encounters an exception.
|
||||
# endregion
|
||||
|
||||
:param source: Source of the model
|
||||
:param error_type: The name of the exception
|
||||
:param error: A text description of the exception
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_error",
|
||||
payload={"source": source, "error_type": error_type, "error": error, "id": id},
|
||||
)
|
||||
# region Bulk image download
|
||||
|
||||
def emit_bulk_download_started(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a bulk download starts"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_started",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
},
|
||||
"""Emitted when a bulk image download is started"""
|
||||
self.dispatch(
|
||||
BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra)
|
||||
)
|
||||
|
||||
def emit_bulk_download_completed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
def emit_bulk_download_complete(
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a bulk download completes"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_completed",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
},
|
||||
"""Emitted when a bulk image download is complete"""
|
||||
self.dispatch(
|
||||
BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra)
|
||||
)
|
||||
|
||||
def emit_bulk_download_failed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
|
||||
def emit_bulk_download_error(
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
error: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> None:
|
||||
"""Emitted when a bulk download fails"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_failed",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
"error": error,
|
||||
},
|
||||
"""Emitted when a bulk image download has an error"""
|
||||
self.dispatch(
|
||||
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error, extra)
|
||||
)
|
||||
|
||||
# endregion
|
||||
|
707
invokeai/app/services/events/events_common.py
Normal file
707
invokeai/app/services/events/events_common.py
Normal file
@ -0,0 +1,707 @@
|
||||
from math import floor
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeAlias, TypeVar
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.download.download_base import DownloadJob
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
|
||||
|
||||
ExtraData: TypeAlias = dict[str, Any]
|
||||
|
||||
|
||||
class EventBase(BaseModel):
|
||||
"""Base class for all events. All events must inherit from this class.
|
||||
|
||||
Events must define a class attribute `__event_name__` to identify the event.
|
||||
|
||||
All other attributes should be defined as normal for a pydantic model.
|
||||
|
||||
A timestamp is automatically added to the event when it is created.
|
||||
"""
|
||||
|
||||
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
|
||||
extra: Optional[ExtraData] = Field(default=None, description="Extra data to include with the event")
|
||||
|
||||
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
|
||||
|
||||
@classmethod
|
||||
def get_events(cls) -> set[type["EventBase"]]:
|
||||
"""Get a set of all event models."""
|
||||
|
||||
event_subclasses: set[type["EventBase"]] = set()
|
||||
for subclass in cls.__subclasses__():
|
||||
# We only want to include subclasses that are event models, not intermediary classes
|
||||
if hasattr(subclass, "__event_name__"):
|
||||
event_subclasses.add(subclass)
|
||||
event_subclasses.update(subclass.get_events())
|
||||
|
||||
return event_subclasses
|
||||
|
||||
|
||||
TEvent = TypeVar("TEvent", bound=EventBase)
|
||||
|
||||
FastAPIEvent: TypeAlias = tuple[str, TEvent]
|
||||
"""
|
||||
A tuple representing a `fastapi-events` event, with the event name and payload.
|
||||
Provide a generic type to `TEvent` to specify the payload type.
|
||||
"""
|
||||
|
||||
|
||||
class FastAPIEventFunc(Protocol):
|
||||
def __call__(self, event: FastAPIEvent[Any]) -> Optional[Coroutine[Any, Any, None]]: ...
|
||||
|
||||
|
||||
def register_events(events: set[type[TEvent]], func: FastAPIEventFunc) -> None:
|
||||
"""Register a function to handle a list of events.
|
||||
|
||||
:param events: A list of event classes to handle
|
||||
:param func: The function to handle the events
|
||||
"""
|
||||
for event in events:
|
||||
assert hasattr(event, "__event_name__")
|
||||
local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
||||
|
||||
|
||||
class QueueEventBase(EventBase):
|
||||
"""Base class for queue events"""
|
||||
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
|
||||
|
||||
class QueueItemEventBase(QueueEventBase):
|
||||
"""Base class for queue item events"""
|
||||
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
|
||||
|
||||
class SessionEventBase(QueueItemEventBase):
|
||||
"""Base class for session (aka graph execution state) events"""
|
||||
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
|
||||
|
||||
class InvocationEventBase(SessionEventBase):
|
||||
"""Base class for invocation events"""
|
||||
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
invocation_id: str = Field(description="The ID of the invocation")
|
||||
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
||||
invocation_type: str = Field(description="The type of invocation")
|
||||
|
||||
|
||||
class InvocationStartedEvent(InvocationEventBase):
|
||||
"""Event model for invocation_started"""
|
||||
|
||||
__event_name__ = "invocation_started"
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, extra: Optional[ExtraData] = None
|
||||
) -> "InvocationStartedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation_id=invocation.id,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
invocation_type=invocation.get_type(),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||
"""Event model for invocation_denoise_progress"""
|
||||
|
||||
__event_name__ = "invocation_denoise_progress"
|
||||
|
||||
progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
|
||||
step: int = Field(description="The current step of the invocation")
|
||||
total_steps: int = Field(description="The total number of steps in the invocation")
|
||||
order: int = Field(description="The order of the invocation in the session")
|
||||
percentage: float = Field(description="The percentage of completion of the invocation")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: BaseInvocation,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
progress_image: ProgressImage,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "InvocationDenoiseProgressEvent":
|
||||
step = intermediate_state.step
|
||||
total_steps = intermediate_state.total_steps
|
||||
order = intermediate_state.order
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation_id=invocation.id,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
invocation_type=invocation.get_type(),
|
||||
progress_image=progress_image,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
order=order,
|
||||
percentage=cls.calc_percentage(step, total_steps, order),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
|
||||
"""Calculate the percentage of completion of denoising."""
|
||||
if total_steps == 0:
|
||||
return 0.0
|
||||
if scheduler_order == 2:
|
||||
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
|
||||
# order == 1
|
||||
return (step + 1 + 1) / (total_steps + 1)
|
||||
|
||||
|
||||
class InvocationCompleteEvent(InvocationEventBase):
|
||||
"""Event model for invocation_complete"""
|
||||
|
||||
__event_name__ = "invocation_complete"
|
||||
|
||||
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: BaseInvocation,
|
||||
result: BaseInvocationOutput,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "InvocationCompleteEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation_id=invocation.id,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
invocation_type=invocation.get_type(),
|
||||
result=result,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class InvocationErrorEvent(InvocationEventBase):
|
||||
"""Event model for invocation_error"""
|
||||
|
||||
__event_name__ = "invocation_error"
|
||||
|
||||
error_type: str = Field(description="The type of error")
|
||||
error: str = Field(description="The error message")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: BaseInvocation,
|
||||
error_type: str,
|
||||
error: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "InvocationErrorEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation_id=invocation.id,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
invocation_type=invocation.get_type(),
|
||||
error_type=error_type,
|
||||
error=error,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class SessionStartedEvent(SessionEventBase):
|
||||
"""Event model for session_started"""
|
||||
|
||||
__event_name__ = "session_started"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionStartedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class SessionCompleteEvent(SessionEventBase):
|
||||
"""Event model for session_complete"""
|
||||
|
||||
__event_name__ = "session_complete"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCompleteEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class SessionCanceledEvent(SessionEventBase):
|
||||
"""Event model for session_canceled"""
|
||||
|
||||
__event_name__ = "session_canceled"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCanceledEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
"""Event model for queue_item_status_changed"""
|
||||
|
||||
__event_name__ = "queue_item_status_changed"
|
||||
|
||||
status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item")
|
||||
error: Optional[str] = Field(default=None, description="The error message, if any")
|
||||
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
|
||||
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
|
||||
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
|
||||
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
|
||||
batch_status: BatchStatus = Field(description="The status of the batch")
|
||||
queue_status: SessionQueueStatus = Field(description="The status of the queue")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
batch_status: BatchStatus,
|
||||
queue_status: SessionQueueStatus,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "QueueItemStatusChangedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
status=queue_item.status,
|
||||
error=queue_item.error,
|
||||
created_at=str(queue_item.created_at) if queue_item.created_at else None,
|
||||
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
|
||||
started_at=str(queue_item.started_at) if queue_item.started_at else None,
|
||||
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class BatchEnqueuedEvent(QueueEventBase):
|
||||
"""Event model for batch_enqueued"""
|
||||
|
||||
__event_name__ = "batch_enqueued"
|
||||
|
||||
batch_id: str = Field(description="The ID of the batch")
|
||||
enqueued: int = Field(description="The number of invocations enqueued")
|
||||
requested: int = Field(
|
||||
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
|
||||
)
|
||||
priority: int = Field(description="The priority of the batch")
|
||||
|
||||
@classmethod
|
||||
def build(cls, enqueue_result: EnqueueBatchResult, extra: Optional[ExtraData] = None) -> "BatchEnqueuedEvent":
|
||||
return cls(
|
||||
queue_id=enqueue_result.queue_id,
|
||||
batch_id=enqueue_result.batch.batch_id,
|
||||
enqueued=enqueue_result.enqueued,
|
||||
requested=enqueue_result.requested,
|
||||
priority=enqueue_result.priority,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class QueueClearedEvent(QueueEventBase):
|
||||
"""Event model for queue_cleared"""
|
||||
|
||||
__event_name__ = "queue_cleared"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_id: str, extra: Optional[ExtraData] = None) -> "QueueClearedEvent":
|
||||
return cls(
|
||||
queue_id=queue_id,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class DownloadEventBase(EventBase):
|
||||
"""Base class for events associated with a download"""
|
||||
|
||||
source: str = Field(description="The source of the download")
|
||||
|
||||
|
||||
class DownloadStartedEvent(DownloadEventBase):
|
||||
"""Event model for download_started"""
|
||||
|
||||
__event_name__ = "download_started"
|
||||
|
||||
download_path: str = Field(description="The local path where the download is saved")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadStartedEvent":
|
||||
assert job.download_path
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class DownloadProgressEvent(DownloadEventBase):
|
||||
"""Event model for download_progress"""
|
||||
|
||||
__event_name__ = "download_progress"
|
||||
|
||||
download_path: str = Field(description="The local path where the download is saved")
|
||||
current_bytes: int = Field(description="The number of bytes downloaded so far")
|
||||
total_bytes: int = Field(description="The total number of bytes to be downloaded")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadProgressEvent":
|
||||
assert job.download_path
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
current_bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class DownloadCompleteEvent(DownloadEventBase):
|
||||
"""Event model for download_complete"""
|
||||
|
||||
__event_name__ = "download_complete"
|
||||
|
||||
download_path: str = Field(description="The local path where the download is saved")
|
||||
total_bytes: int = Field(description="The total number of bytes downloaded")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCompleteEvent":
|
||||
assert job.download_path
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
total_bytes=job.total_bytes,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class DownloadCancelledEvent(DownloadEventBase):
|
||||
"""Event model for download_cancelled"""
|
||||
|
||||
__event_name__ = "download_cancelled"
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCancelledEvent":
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class DownloadErrorEvent(DownloadEventBase):
|
||||
"""Event model for download_error"""
|
||||
|
||||
__event_name__ = "download_error"
|
||||
|
||||
error_type: str = Field(description="The type of error")
|
||||
error: str = Field(description="The error message")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadErrorEvent":
|
||||
assert job.error_type
|
||||
assert job.error
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
error_type=job.error_type,
|
||||
error=job.error,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelEventBase(EventBase):
|
||||
"""Base class for events associated with a model"""
|
||||
|
||||
|
||||
class ModelLoadStartedEvent(ModelEventBase):
|
||||
"""Event model for model_load_started"""
|
||||
|
||||
__event_name__ = "model_load_started"
|
||||
|
||||
config: AnyModelConfig = Field(description="The model's config")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None
|
||||
) -> "ModelLoadStartedEvent":
|
||||
return cls(
|
||||
config=config,
|
||||
submodel_type=submodel_type,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelLoadCompleteEvent(ModelEventBase):
|
||||
"""Event model for model_load_complete"""
|
||||
|
||||
__event_name__ = "model_load_complete"
|
||||
|
||||
config: AnyModelConfig = Field(description="The model's config")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None
|
||||
) -> "ModelLoadCompleteEvent":
|
||||
return cls(
|
||||
config=config,
|
||||
submodel_type=submodel_type,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallDownloadProgressEvent(ModelEventBase):
|
||||
"""Event model for model_install_download_progress"""
|
||||
|
||||
__event_name__ = "model_install_download_progress"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
local_path: str = Field(description="Where model is downloading to")
|
||||
bytes: int = Field(description="Number of bytes downloaded so far")
|
||||
total_bytes: int = Field(description="Total size of download, including all files")
|
||||
parts: list[dict[str, int | str]] = Field(
|
||||
description="Progress of downloading URLs that comprise the model, if any"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadProgressEvent":
|
||||
parts: list[dict[str, str | int]] = [
|
||||
{
|
||||
"url": str(x.source),
|
||||
"local_path": str(x.download_path),
|
||||
"bytes": x.bytes,
|
||||
"total_bytes": x.total_bytes,
|
||||
}
|
||||
for x in job.download_parts
|
||||
]
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
local_path=job.local_path.as_posix(),
|
||||
parts=parts,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
|
||||
"""Emitted once when an install job becomes active."""
|
||||
|
||||
__event_name__ = "model_install_downloads_complete"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadsCompleteEvent":
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallStartedEvent(ModelEventBase):
|
||||
"""Event model for model_install_started"""
|
||||
|
||||
__event_name__ = "model_install_started"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallStartedEvent":
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallCompleteEvent(ModelEventBase):
|
||||
"""Event model for model_install_complete"""
|
||||
|
||||
__event_name__ = "model_install_complete"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
key: str = Field(description="Model config record key")
|
||||
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCompleteEvent":
|
||||
assert job.config_out is not None
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
key=(job.config_out.key),
|
||||
total_bytes=job.total_bytes,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallCancelledEvent(ModelEventBase):
|
||||
"""Event model for model_install_cancelled"""
|
||||
|
||||
__event_name__ = "model_install_cancelled"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCancelledEvent":
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallErrorEvent(ModelEventBase):
|
||||
"""Event model for model_install_error"""
|
||||
|
||||
__event_name__ = "model_install_error"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
error_type: str = Field(description="The name of the exception")
|
||||
error: str = Field(description="A text description of the exception")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallErrorEvent":
|
||||
assert job.error_type is not None
|
||||
assert job.error is not None
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
error_type=job.error_type,
|
||||
error=job.error,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class BulkDownloadEventBase(EventBase):
|
||||
"""Base class for events associated with a bulk image download"""
|
||||
|
||||
bulk_download_id: str = Field(description="The ID of the bulk image download")
|
||||
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
|
||||
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
|
||||
|
||||
|
||||
class BulkDownloadStartedEvent(BulkDownloadEventBase):
|
||||
"""Event model for bulk_download_started"""
|
||||
|
||||
__event_name__ = "bulk_download_started"
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "BulkDownloadStartedEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
|
||||
"""Event model for bulk_download_complete"""
|
||||
|
||||
__event_name__ = "bulk_download_complete"
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "BulkDownloadCompleteEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
class BulkDownloadErrorEvent(BulkDownloadEventBase):
|
||||
"""Event model for bulk_download_error"""
|
||||
|
||||
__event_name__ = "bulk_download_error"
|
||||
|
||||
error: str = Field(description="The error message")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
error: str,
|
||||
extra: Optional[ExtraData] = None,
|
||||
) -> "BulkDownloadErrorEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
error=error,
|
||||
extra=extra,
|
||||
)
|
46
invokeai/app/services/events/events_fastapievents.py
Normal file
46
invokeai/app/services/events/events_fastapievents.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from invokeai.app.services.events.events_common import (
|
||||
EventBase,
|
||||
)
|
||||
|
||||
from .events_base import EventServiceBase
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
def __init__(self, event_handler_id: int) -> None:
|
||||
self.event_handler_id = event_handler_id
|
||||
self._queue = Queue[EventBase | None]()
|
||||
self._stop_event = threading.Event()
|
||||
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||
|
||||
super().__init__()
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self._stop_event.set()
|
||||
self._queue.put(None)
|
||||
|
||||
def dispatch(self, event: EventBase) -> None:
|
||||
self._queue.put(event)
|
||||
|
||||
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = self._queue.get(block=False)
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
@ -4,9 +4,6 @@ from typing import Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
|
||||
class ImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
@ -33,8 +30,9 @@ class ImageFileStorageBase(ABC):
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
metadata: Optional[str] = None,
|
||||
workflow: Optional[str] = None,
|
||||
graph: Optional[str] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
@ -46,6 +44,11 @@ class ImageFileStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
|
||||
def get_workflow(self, image_name: str) -> Optional[str]:
|
||||
"""Gets the workflow of an image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_graph(self, image_name: str) -> Optional[str]:
|
||||
"""Gets the graph of an image."""
|
||||
pass
|
||||
|
@ -7,9 +7,7 @@ from PIL import Image, PngImagePlugin
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
from .image_files_base import ImageFileStorageBase
|
||||
@ -56,8 +54,9 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
metadata: Optional[str] = None,
|
||||
workflow: Optional[str] = None,
|
||||
graph: Optional[str] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
try:
|
||||
@ -68,13 +67,14 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
info_dict = {}
|
||||
|
||||
if metadata is not None:
|
||||
metadata_json = metadata.model_dump_json()
|
||||
info_dict["invokeai_metadata"] = metadata_json
|
||||
pnginfo.add_text("invokeai_metadata", metadata_json)
|
||||
info_dict["invokeai_metadata"] = metadata
|
||||
pnginfo.add_text("invokeai_metadata", metadata)
|
||||
if workflow is not None:
|
||||
workflow_json = workflow.model_dump_json()
|
||||
info_dict["invokeai_workflow"] = workflow_json
|
||||
pnginfo.add_text("invokeai_workflow", workflow_json)
|
||||
info_dict["invokeai_workflow"] = workflow
|
||||
pnginfo.add_text("invokeai_workflow", workflow)
|
||||
if graph is not None:
|
||||
info_dict["invokeai_graph"] = graph
|
||||
pnginfo.add_text("invokeai_graph", graph)
|
||||
|
||||
# When saving the image, the image object's info field is not populated. We need to set it
|
||||
image.info = info_dict
|
||||
@ -129,11 +129,18 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
path = path if isinstance(path, Path) else Path(path)
|
||||
return path.exists()
|
||||
|
||||
def get_workflow(self, image_name: str) -> WorkflowWithoutID | None:
|
||||
def get_workflow(self, image_name: str) -> str | None:
|
||||
image = self.get(image_name)
|
||||
workflow = image.info.get("invokeai_workflow", None)
|
||||
if workflow is not None:
|
||||
return WorkflowWithoutID.model_validate_json(workflow)
|
||||
if isinstance(workflow, str):
|
||||
return workflow
|
||||
return None
|
||||
|
||||
def get_graph(self, image_name: str) -> str | None:
|
||||
image = self.get(image_name)
|
||||
graph = image.info.get("invokeai_graph", None)
|
||||
if isinstance(graph, str):
|
||||
return graph
|
||||
return None
|
||||
|
||||
def __validate_storage_folders(self) -> None:
|
||||
|
@ -80,7 +80,7 @@ class ImageRecordStorageBase(ABC):
|
||||
starred: Optional[bool] = False,
|
||||
session_id: Optional[str] = None,
|
||||
node_id: Optional[str] = None,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
metadata: Optional[str] = None,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
|
@ -328,10 +328,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
starred: Optional[bool] = False,
|
||||
session_id: Optional[str] = None,
|
||||
node_id: Optional[str] = None,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
metadata: Optional[str] = None,
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = metadata.model_dump_json() if metadata is not None else None
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@ -358,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata_json,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow,
|
||||
|
@ -12,7 +12,6 @@ from invokeai.app.services.image_records.image_records_common import (
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
|
||||
class ImageServiceABC(ABC):
|
||||
@ -51,8 +50,9 @@ class ImageServiceABC(ABC):
|
||||
session_id: Optional[str] = None,
|
||||
board_id: Optional[str] = None,
|
||||
is_intermediate: Optional[bool] = False,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
metadata: Optional[str] = None,
|
||||
workflow: Optional[str] = None,
|
||||
graph: Optional[str] = None,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
@ -87,7 +87,12 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
|
||||
def get_workflow(self, image_name: str) -> Optional[str]:
|
||||
"""Gets an image's workflow."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_graph(self, image_name: str) -> Optional[str]:
|
||||
"""Gets an image's workflow."""
|
||||
pass
|
||||
|
||||
|
@ -5,7 +5,6 @@ from PIL.Image import Image as PILImageType
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
from ..image_files.image_files_common import (
|
||||
ImageFileDeleteException,
|
||||
@ -42,8 +41,9 @@ class ImageService(ImageServiceABC):
|
||||
session_id: Optional[str] = None,
|
||||
board_id: Optional[str] = None,
|
||||
is_intermediate: Optional[bool] = False,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
metadata: Optional[str] = None,
|
||||
workflow: Optional[str] = None,
|
||||
graph: Optional[str] = None,
|
||||
) -> ImageDTO:
|
||||
if image_origin not in ResourceOrigin:
|
||||
raise InvalidOriginException
|
||||
@ -64,7 +64,7 @@ class ImageService(ImageServiceABC):
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
has_workflow=workflow is not None,
|
||||
has_workflow=workflow is not None or graph is not None,
|
||||
# Meta fields
|
||||
is_intermediate=is_intermediate,
|
||||
# Nullable fields
|
||||
@ -75,7 +75,7 @@ class ImageService(ImageServiceABC):
|
||||
if board_id is not None:
|
||||
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
self.__invoker.services.image_files.save(
|
||||
image_name=image_name, image=image, metadata=metadata, workflow=workflow
|
||||
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
|
||||
)
|
||||
image_dto = self.get_dto(image_name)
|
||||
|
||||
@ -157,7 +157,7 @@ class ImageService(ImageServiceABC):
|
||||
self.__invoker.services.logger.error("Problem getting image metadata")
|
||||
raise e
|
||||
|
||||
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
|
||||
def get_workflow(self, image_name: str) -> Optional[str]:
|
||||
try:
|
||||
return self.__invoker.services.image_files.get_workflow(image_name)
|
||||
except ImageFileNotFoundException:
|
||||
@ -167,6 +167,16 @@ class ImageService(ImageServiceABC):
|
||||
self.__invoker.services.logger.error("Problem getting image workflow")
|
||||
raise
|
||||
|
||||
def get_graph(self, image_name: str) -> Optional[str]:
|
||||
try:
|
||||
return self.__invoker.services.image_files.get_graph(image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self.__invoker.services.logger.error("Image file not found")
|
||||
raise
|
||||
except Exception:
|
||||
self.__invoker.services.logger.error("Problem getting image graph")
|
||||
raise
|
||||
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return str(self.__invoker.services.image_files.get_path(image_name, thumbnail))
|
||||
|
@ -1,11 +1,13 @@
|
||||
"""Initialization file for model install service package."""
|
||||
|
||||
from .model_install_base import (
|
||||
ModelInstallServiceBase,
|
||||
)
|
||||
from .model_install_common import (
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
UnknownInstallJobException,
|
||||
URLModelSource,
|
||||
|
@ -1,244 +1,19 @@
|
||||
# Copyright 2023 Lincoln D. Stein and the InvokeAI development team
|
||||
"""Baseclass definitions for the model installer."""
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
DOWNLOADING = "downloading" # downloading of model files in process
|
||||
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
CANCELLED = "cancelled" # terminated with an error message
|
||||
|
||||
|
||||
class ModelInstallPart(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
path: Path
|
||||
bytes: int = 0
|
||||
total_bytes: int = 0
|
||||
|
||||
|
||||
class UnknownInstallJobException(Exception):
|
||||
"""Raised when the status of an unknown job is requested."""
|
||||
|
||||
|
||||
class StringLikeSource(BaseModel):
|
||||
"""
|
||||
Base class for model sources, implements functions that lets the source be sorted and indexed.
|
||||
|
||||
These shenanigans let this stuff work:
|
||||
|
||||
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
|
||||
mydict = {source1: 'model 1'}
|
||||
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
|
||||
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
|
||||
|
||||
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
|
||||
assert source1 == source2
|
||||
assert source1 == 'C:/users/mort/foo.safetensors'
|
||||
"""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the path field, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __lt__(self, other: object) -> int:
|
||||
"""Return comparison of the stringified version, for sorting."""
|
||||
return str(self) < str(other)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Return equality on the stringified version."""
|
||||
if isinstance(other, Path):
|
||||
return str(self) == other.as_posix()
|
||||
else:
|
||||
return str(self) == str(other)
|
||||
|
||||
|
||||
class LocalModelSource(StringLikeSource):
|
||||
"""A local file or directory path."""
|
||||
|
||||
path: str | Path
|
||||
inplace: Optional[bool] = False
|
||||
type: Literal["local"] = "local"
|
||||
|
||||
# these methods allow the source to be used in a string-like way,
|
||||
# for example as an index into a dict
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of path when string rep needed."""
|
||||
return Path(self.path).as_posix()
|
||||
|
||||
|
||||
class HFModelSource(StringLikeSource):
|
||||
"""
|
||||
A HuggingFace repo_id with optional variant, sub-folder and access token.
|
||||
Note that the variant option, if not provided to the constructor, will default to fp16, which is
|
||||
what people (almost) always want.
|
||||
"""
|
||||
|
||||
repo_id: str
|
||||
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
|
||||
subfolder: Optional[Path] = None
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["hf"] = "hf"
|
||||
|
||||
@field_validator("repo_id")
|
||||
@classmethod
|
||||
def proper_repo_id(cls, v: str) -> str: # noqa D102
|
||||
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
|
||||
raise ValueError(f"{v}: invalid repo_id format")
|
||||
return v
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = self.repo_id
|
||||
if self.variant:
|
||||
base += f":{self.variant or ''}"
|
||||
if self.subfolder:
|
||||
base += f":{self.subfolder}"
|
||||
return base
|
||||
|
||||
|
||||
class URLModelSource(StringLikeSource):
|
||||
"""A generic URL point to a checkpoint file."""
|
||||
|
||||
url: AnyHttpUrl
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["url"] = "url"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of the url when string rep needed."""
|
||||
return str(self.url)
|
||||
|
||||
|
||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
}
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
)
|
||||
config_out: Optional[AnyModelConfig] = Field(
|
||||
default=None, description="After successful installation, this will hold the configuration object."
|
||||
)
|
||||
inplace: bool = Field(
|
||||
default=False, description="Leave model in its current location; otherwise install under models directory"
|
||||
)
|
||||
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
|
||||
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
|
||||
bytes: int = Field(
|
||||
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
|
||||
)
|
||||
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
|
||||
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
||||
default=None, description="Metadata provided by the model source"
|
||||
)
|
||||
download_parts: Set[DownloadJob] = Field(
|
||||
default_factory=set, description="Download jobs contributing to this install"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||
)
|
||||
error_traceback: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the exception traceback"
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
"""Record the error and traceback from an exception."""
|
||||
self._exception = e
|
||||
self.error = str(e)
|
||||
self.error_traceback = self._format_error(e)
|
||||
self.status = InstallStatus.ERROR
|
||||
self.error_reason = self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Call to cancel the job."""
|
||||
self.status = InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def error_type(self) -> Optional[str]:
|
||||
"""Class name of the exception that led to status==ERROR."""
|
||||
return self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def _format_error(self, exception: Exception) -> str:
|
||||
"""Error traceback."""
|
||||
return "".join(traceback.format_exception(exception))
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
"""Set status to CANCELLED."""
|
||||
return self.status == InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
"""Return true if job has errored."""
|
||||
return self.status == InstallStatus.ERROR
|
||||
|
||||
@property
|
||||
def waiting(self) -> bool:
|
||||
"""Return true if job is waiting to run."""
|
||||
return self.status == InstallStatus.WAITING
|
||||
|
||||
@property
|
||||
def downloading(self) -> bool:
|
||||
"""Return true if job is downloading."""
|
||||
return self.status == InstallStatus.DOWNLOADING
|
||||
|
||||
@property
|
||||
def downloads_done(self) -> bool:
|
||||
"""Return true if job's downloads ae done."""
|
||||
return self.status == InstallStatus.DOWNLOADS_DONE
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if job is running."""
|
||||
return self.status == InstallStatus.RUNNING
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
"""Return true if job completed without errors."""
|
||||
return self.status == InstallStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def in_terminal_state(self) -> bool:
|
||||
"""Return true if job is in a terminal state."""
|
||||
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
|
||||
|
||||
class ModelInstallServiceBase(ABC):
|
||||
@ -282,7 +57,7 @@ class ModelInstallServiceBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def event_bus(self) -> Optional[EventServiceBase]:
|
||||
def event_bus(self) -> Optional["EventServiceBase"]:
|
||||
"""Return the event service base object associated with the installer."""
|
||||
|
||||
@abstractmethod
|
||||
|
233
invokeai/app/services/model_install/model_install_common.py
Normal file
233
invokeai/app/services/model_install/model_install_common.py
Normal file
@ -0,0 +1,233 @@
|
||||
import re
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.download import DownloadJob
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
DOWNLOADING = "downloading" # downloading of model files in process
|
||||
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
CANCELLED = "cancelled" # terminated with an error message
|
||||
|
||||
|
||||
class ModelInstallPart(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
path: Path
|
||||
bytes: int = 0
|
||||
total_bytes: int = 0
|
||||
|
||||
|
||||
class UnknownInstallJobException(Exception):
|
||||
"""Raised when the status of an unknown job is requested."""
|
||||
|
||||
|
||||
class StringLikeSource(BaseModel):
|
||||
"""
|
||||
Base class for model sources, implements functions that lets the source be sorted and indexed.
|
||||
|
||||
These shenanigans let this stuff work:
|
||||
|
||||
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
|
||||
mydict = {source1: 'model 1'}
|
||||
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
|
||||
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
|
||||
|
||||
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
|
||||
assert source1 == source2
|
||||
assert source1 == 'C:/users/mort/foo.safetensors'
|
||||
"""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the path field, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __lt__(self, other: object) -> int:
|
||||
"""Return comparison of the stringified version, for sorting."""
|
||||
return str(self) < str(other)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Return equality on the stringified version."""
|
||||
if isinstance(other, Path):
|
||||
return str(self) == other.as_posix()
|
||||
else:
|
||||
return str(self) == str(other)
|
||||
|
||||
|
||||
class LocalModelSource(StringLikeSource):
|
||||
"""A local file or directory path."""
|
||||
|
||||
path: str | Path
|
||||
inplace: Optional[bool] = False
|
||||
type: Literal["local"] = "local"
|
||||
|
||||
# these methods allow the source to be used in a string-like way,
|
||||
# for example as an index into a dict
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of path when string rep needed."""
|
||||
return Path(self.path).as_posix()
|
||||
|
||||
|
||||
class HFModelSource(StringLikeSource):
|
||||
"""
|
||||
A HuggingFace repo_id with optional variant, sub-folder and access token.
|
||||
Note that the variant option, if not provided to the constructor, will default to fp16, which is
|
||||
what people (almost) always want.
|
||||
"""
|
||||
|
||||
repo_id: str
|
||||
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
|
||||
subfolder: Optional[Path] = None
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["hf"] = "hf"
|
||||
|
||||
@field_validator("repo_id")
|
||||
@classmethod
|
||||
def proper_repo_id(cls, v: str) -> str: # noqa D102
|
||||
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
|
||||
raise ValueError(f"{v}: invalid repo_id format")
|
||||
return v
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = self.repo_id
|
||||
if self.variant:
|
||||
base += f":{self.variant or ''}"
|
||||
if self.subfolder:
|
||||
base += f":{self.subfolder}"
|
||||
return base
|
||||
|
||||
|
||||
class URLModelSource(StringLikeSource):
|
||||
"""A generic URL point to a checkpoint file."""
|
||||
|
||||
url: AnyHttpUrl
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["url"] = "url"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of the url when string rep needed."""
|
||||
return str(self.url)
|
||||
|
||||
|
||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
}
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
)
|
||||
config_out: Optional[AnyModelConfig] = Field(
|
||||
default=None, description="After successful installation, this will hold the configuration object."
|
||||
)
|
||||
inplace: bool = Field(
|
||||
default=False, description="Leave model in its current location; otherwise install under models directory"
|
||||
)
|
||||
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
|
||||
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
|
||||
bytes: int = Field(
|
||||
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
|
||||
)
|
||||
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
|
||||
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
||||
default=None, description="Metadata provided by the model source"
|
||||
)
|
||||
download_parts: Set[DownloadJob] = Field(
|
||||
default_factory=set, description="Download jobs contributing to this install"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||
)
|
||||
error_traceback: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the exception traceback"
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
"""Record the error and traceback from an exception."""
|
||||
self._exception = e
|
||||
self.error = str(e)
|
||||
self.error_traceback = self._format_error(e)
|
||||
self.status = InstallStatus.ERROR
|
||||
self.error_reason = self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Call to cancel the job."""
|
||||
self.status = InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def error_type(self) -> Optional[str]:
|
||||
"""Class name of the exception that led to status==ERROR."""
|
||||
return self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def _format_error(self, exception: Exception) -> str:
|
||||
"""Error traceback."""
|
||||
return "".join(traceback.format_exception(exception))
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
"""Set status to CANCELLED."""
|
||||
return self.status == InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
"""Return true if job has errored."""
|
||||
return self.status == InstallStatus.ERROR
|
||||
|
||||
@property
|
||||
def waiting(self) -> bool:
|
||||
"""Return true if job is waiting to run."""
|
||||
return self.status == InstallStatus.WAITING
|
||||
|
||||
@property
|
||||
def downloading(self) -> bool:
|
||||
"""Return true if job is downloading."""
|
||||
return self.status == InstallStatus.DOWNLOADING
|
||||
|
||||
@property
|
||||
def downloads_done(self) -> bool:
|
||||
"""Return true if job's downloads ae done."""
|
||||
return self.status == InstallStatus.DOWNLOADS_DONE
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if job is running."""
|
||||
return self.status == InstallStatus.RUNNING
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
"""Return true if job completed without errors."""
|
||||
return self.status == InstallStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def in_terminal_state(self) -> bool:
|
||||
"""Return true if job is in a terminal state."""
|
||||
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
|
@ -10,7 +10,7 @@ from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
@ -20,8 +20,8 @@ from requests import Session
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@ -45,13 +45,12 @@ from invokeai.backend.util import InvokeAILogger
|
||||
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .model_install_base import (
|
||||
from .model_install_common import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
StringLikeSource,
|
||||
URLModelSource,
|
||||
@ -59,6 +58,9 @@ from .model_install_base import (
|
||||
|
||||
TMPDIR_PREFIX = "tmpinstall_"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
|
||||
|
||||
class ModelInstallService(ModelInstallServiceBase):
|
||||
"""class for InvokeAI model installation."""
|
||||
@ -68,7 +70,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
session: Optional[Session] = None,
|
||||
):
|
||||
"""
|
||||
@ -104,7 +106,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
return self._record_store
|
||||
|
||||
@property
|
||||
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
||||
def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102
|
||||
return self._event_bus
|
||||
|
||||
# make the invoker optional here because we don't need it and it
|
||||
@ -855,35 +857,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.status = InstallStatus.RUNNING
|
||||
self._logger.info(f"Model install started: {job.source}")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_running(str(job.source))
|
||||
self._event_bus.emit_model_install_started(job)
|
||||
|
||||
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
|
||||
if self._event_bus:
|
||||
parts: List[Dict[str, str | int]] = [
|
||||
{
|
||||
"url": str(x.source),
|
||||
"local_path": str(x.download_path),
|
||||
"bytes": x.bytes,
|
||||
"total_bytes": x.total_bytes,
|
||||
}
|
||||
for x in job.download_parts
|
||||
]
|
||||
assert job.bytes is not None
|
||||
assert job.total_bytes is not None
|
||||
self._event_bus.emit_model_install_downloading(
|
||||
str(job.source),
|
||||
local_path=job.local_path.as_posix(),
|
||||
parts=parts,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
id=job.id,
|
||||
)
|
||||
self._event_bus.emit_model_install_download_progress(job)
|
||||
|
||||
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.DOWNLOADS_DONE
|
||||
self._logger.info(f"Model download complete: {job.source}")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_downloads_done(str(job.source))
|
||||
self._event_bus.emit_model_install_downloads_complete(job)
|
||||
|
||||
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.COMPLETED
|
||||
@ -891,24 +875,19 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"Model install complete: {job.source}")
|
||||
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
|
||||
if self._event_bus:
|
||||
assert job.local_path is not None
|
||||
assert job.config_out is not None
|
||||
key = job.config_out.key
|
||||
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
|
||||
self._event_bus.emit_model_install_complete(job)
|
||||
|
||||
def _signal_job_errored(self, job: ModelInstallJob) -> None:
|
||||
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
|
||||
if self._event_bus:
|
||||
error_type = job.error_type
|
||||
error = job.error
|
||||
assert error_type is not None
|
||||
assert error is not None
|
||||
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
|
||||
assert job.error_type is not None
|
||||
assert job.error is not None
|
||||
self._event_bus.emit_model_install_error(job)
|
||||
|
||||
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
|
||||
self._logger.info(f"Model install canceled: {job.source}")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
|
||||
self._event_bus.emit_model_install_cancelled(job)
|
||||
|
||||
@staticmethod
|
||||
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
|
||||
|
@ -4,7 +4,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
@ -15,18 +14,12 @@ class ModelLoadServiceBase(ABC):
|
||||
"""Wrapper around AnyModelLoader."""
|
||||
|
||||
@abstractmethod
|
||||
def load_model(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context_data: Invocation context data used for event reporting
|
||||
"""
|
||||
|
||||
@property
|
||||
|
@ -5,7 +5,6 @@ from typing import Optional, Type
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import (
|
||||
LoadedModel,
|
||||
@ -51,25 +50,18 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
return self._convert_cache
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
if context_data:
|
||||
self._emit_load_event(
|
||||
context_data=context_data,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
|
||||
# We don't have an invoker during testing
|
||||
# TODO(psyche): Mock this method on the invoker in the tests
|
||||
if hasattr(self, "_invoker"):
|
||||
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
|
||||
|
||||
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
|
||||
loaded_model: LoadedModel = implementation(
|
||||
@ -79,40 +71,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
convert_cache=self._convert_cache,
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
if context_data:
|
||||
self._emit_load_event(
|
||||
context_data=context_data,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
loaded=True,
|
||||
)
|
||||
if hasattr(self, "_invoker"):
|
||||
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
|
||||
|
||||
return loaded_model
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context_data: InvocationContextData,
|
||||
model_config: AnyModelConfig,
|
||||
loaded: Optional[bool] = False,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
if not self._invoker:
|
||||
return
|
||||
|
||||
if not loaded:
|
||||
self._invoker.services.events.emit_model_load_started(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
else:
|
||||
self._invoker.services.events.emit_model_load_completed(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
|
@ -4,11 +4,16 @@ from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BatchEnqueuedEvent,
|
||||
FastAPIEvent,
|
||||
QueueClearedEvent,
|
||||
QueueEventBase,
|
||||
QueueItemStatusChangedEvent,
|
||||
SessionCanceledEvent,
|
||||
register_events,
|
||||
)
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
@ -31,8 +36,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._poll_now_event = ThreadEvent()
|
||||
self._cancel_event = ThreadEvent()
|
||||
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||
|
||||
self._thread_limit = thread_limit
|
||||
self._thread_semaphore = BoundedSemaphore(thread_limit)
|
||||
self._polling_interval = polling_interval
|
||||
@ -49,6 +52,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
else None
|
||||
)
|
||||
|
||||
register_events({SessionCanceledEvent, QueueClearedEvent, BatchEnqueuedEvent}, self._on_queue_event)
|
||||
|
||||
self._thread = Thread(
|
||||
name="session_processor",
|
||||
target=self._process,
|
||||
@ -67,30 +72,25 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def _poll_now(self) -> None:
|
||||
self._poll_now_event.set()
|
||||
|
||||
async def _on_queue_event(self, event: FastAPIEvent) -> None:
|
||||
event_name = event[1]["event"]
|
||||
|
||||
async def _on_queue_event(self, event: FastAPIEvent[QueueEventBase]) -> None:
|
||||
_event_name, payload = event
|
||||
if (
|
||||
event_name == "session_canceled"
|
||||
isinstance(payload, SessionCanceledEvent)
|
||||
and self._queue_item
|
||||
and self._queue_item.item_id == event[1]["data"]["queue_item_id"]
|
||||
and self._queue_item.item_id == payload.item_id
|
||||
):
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
elif (
|
||||
event_name == "queue_cleared"
|
||||
isinstance(payload, QueueClearedEvent)
|
||||
and self._queue_item
|
||||
and self._queue_item.queue_id == event[1]["data"]["queue_id"]
|
||||
and self._queue_item.queue_id == payload.queue_id
|
||||
):
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
elif event_name == "batch_enqueued":
|
||||
elif isinstance(payload, BatchEnqueuedEvent):
|
||||
self._poll_now()
|
||||
elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [
|
||||
"completed",
|
||||
"failed",
|
||||
"canceled",
|
||||
]:
|
||||
elif isinstance(payload, QueueItemStatusChangedEvent) and payload.status in ("completed", "failed", "canceled"):
|
||||
self._poll_now()
|
||||
|
||||
def resume(self) -> SessionProcessorStatus:
|
||||
@ -139,6 +139,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
|
||||
self._invoker.services.events.emit_session_started(self._queue_item)
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
|
||||
@ -153,16 +154,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
while self._invocation is not None and not cancel_event.is_set():
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
||||
|
||||
# Send starting event
|
||||
self._invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session_id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
)
|
||||
self._invoker.services.events.emit_invocation_started(self._queue_item, self._invocation)
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
try:
|
||||
@ -189,19 +181,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
# Save outputs and history
|
||||
self._queue_item.session.complete(self._invocation.id, outputs)
|
||||
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
self._queue_item, self._invocation, outputs
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
# TODO(MM2): I don't think this is ever raised...
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
@ -227,14 +212,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=self._queue_item.session_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
node=self._invocation.model_dump(),
|
||||
source_node_id=source_invocation_id,
|
||||
queue_item=self._queue_item,
|
||||
invocation=self._invocation,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
@ -242,13 +222,10 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
|
||||
# The session is complete if the all invocations are complete or there was an error
|
||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
||||
# Send complete event
|
||||
self._invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=self._queue_item.batch_id,
|
||||
queue_item_id=self._queue_item.item_id,
|
||||
queue_id=self._queue_item.queue_id,
|
||||
graph_execution_state_id=self._queue_item.session.id,
|
||||
self._invoker.services.session_queue.set_queue_item_session(
|
||||
self._queue_item.item_id, self._queue_item.session
|
||||
)
|
||||
self._invoker.services.events.emit_session_complete(self._queue_item)
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
@ -279,6 +256,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
)
|
||||
# Cancel the queue item
|
||||
if self._queue_item is not None:
|
||||
self._invoker.services.session_queue.set_queue_item_session(
|
||||
self._queue_item.item_id, self._queue_item.session
|
||||
)
|
||||
self._invoker.services.session_queue.cancel_queue_item(
|
||||
self._queue_item.item_id, error=traceback.format_exc()
|
||||
)
|
||||
|
@ -16,6 +16,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import GraphExecutionState
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
|
||||
|
||||
@ -103,3 +104,8 @@ class SessionQueueBase(ABC):
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
"""Gets a session queue item by ID"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
|
||||
"""Sets the session for a session queue item. Use this to update the session state."""
|
||||
pass
|
||||
|
@ -2,10 +2,13 @@ import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.events.events_common import (
|
||||
FastAPIEvent,
|
||||
InvocationErrorEvent,
|
||||
SessionCanceledEvent,
|
||||
SessionCompleteEvent,
|
||||
register_events,
|
||||
)
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
@ -27,6 +30,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
calc_session_count,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import GraphExecutionState
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
@ -41,7 +45,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__invoker = invoker
|
||||
self._set_in_progress_to_canceled()
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
|
||||
|
||||
register_events(events={InvocationErrorEvent}, func=self._handle_error_event)
|
||||
register_events(events={SessionCompleteEvent}, func=self._handle_complete_event)
|
||||
register_events(events={SessionCanceledEvent}, func=self._handle_cancel_event)
|
||||
|
||||
if prune_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
|
||||
@ -51,51 +59,35 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__conn = db.conn
|
||||
self.__cursor = self.__conn.cursor()
|
||||
|
||||
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
||||
return event[1]["event"] in match_in
|
||||
|
||||
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
|
||||
event_name = event[1]["event"]
|
||||
|
||||
# This was a match statement, but match is not supported on python 3.9
|
||||
if event_name == "graph_execution_state_complete":
|
||||
await self._handle_complete_event(event)
|
||||
elif event_name == "invocation_error":
|
||||
await self._handle_error_event(event)
|
||||
elif event_name == "session_canceled":
|
||||
await self._handle_cancel_event(event)
|
||||
return event
|
||||
|
||||
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
|
||||
async def _handle_complete_event(self, event: FastAPIEvent[SessionCompleteEvent]) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
# When a queue item has an error, we get an error event, then a completed event.
|
||||
# Mark the queue item completed only if it isn't already marked completed, e.g.
|
||||
# by a previously-handled error event.
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
_event_name, payload = event
|
||||
|
||||
async def _handle_error_event(self, event: FastAPIEvent) -> None:
|
||||
queue_item = self.get_queue_item(payload.item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
self._set_queue_item_status(item_id=payload.item_id, status="completed")
|
||||
except SessionQueueItemNotFoundError:
|
||||
pass
|
||||
|
||||
async def _handle_error_event(self, event: FastAPIEvent[InvocationErrorEvent]) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
error = event[1]["data"]["error"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
_event_name, payload = event
|
||||
# always set to failed if have an error, even if previously the item was marked completed or canceled
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
|
||||
self._set_queue_item_status(item_id=payload.item_id, status="failed", error=payload.error)
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
pass
|
||||
|
||||
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
|
||||
async def _handle_cancel_event(self, event: FastAPIEvent[SessionCanceledEvent]) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
_event_name, payload = event
|
||||
queue_item = self.get_queue_item(payload.item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
|
||||
self._set_queue_item_status(item_id=payload.item_id, status="canceled")
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
pass
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""
|
||||
@ -292,11 +284,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
session_queue_item=queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status)
|
||||
return queue_item
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
@ -429,12 +417,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||
status = "failed" if error is not None else "canceled"
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
)
|
||||
self.__invoker.services.events.emit_session_canceled(queue_item)
|
||||
return queue_item
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
@ -470,18 +453,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
queue_batch_id=current_queue_item.batch_id,
|
||||
graph_execution_state_id=current_queue_item.session_id,
|
||||
)
|
||||
self.__invoker.services.events.emit_session_canceled(current_queue_item)
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
session_queue_item=current_queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
current_queue_item, batch_status, queue_status
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
@ -521,18 +497,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
queue_batch_id=current_queue_item.batch_id,
|
||||
graph_execution_state_id=current_queue_item.session_id,
|
||||
)
|
||||
self.__invoker.services.events.emit_session_canceled(current_queue_item)
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
session_queue_item=current_queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
current_queue_item, batch_status, queue_status
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
@ -562,6 +531,29 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
|
||||
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
|
||||
try:
|
||||
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
|
||||
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
|
||||
# during execution.
|
||||
session_json = session.model_dump_json(warnings=False, exclude_none=True)
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET session = ?
|
||||
WHERE item_id = ?
|
||||
""",
|
||||
(session_json, item_id),
|
||||
)
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return self.get_queue_item(item_id)
|
||||
|
||||
def list_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
|
@ -180,9 +180,9 @@ class ImagesInterface(InvocationContextInterface):
|
||||
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
|
||||
metadata_ = None
|
||||
if metadata:
|
||||
metadata_ = metadata
|
||||
elif isinstance(self._data.invocation, WithMetadata):
|
||||
metadata_ = self._data.invocation.metadata
|
||||
metadata_ = metadata.model_dump_json()
|
||||
elif isinstance(self._data.invocation, WithMetadata) and self._data.invocation.metadata:
|
||||
metadata_ = self._data.invocation.metadata.model_dump_json()
|
||||
|
||||
# If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None.
|
||||
board_id_ = None
|
||||
@ -191,6 +191,14 @@ class ImagesInterface(InvocationContextInterface):
|
||||
elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board:
|
||||
board_id_ = self._data.invocation.board.board_id
|
||||
|
||||
workflow_ = None
|
||||
if self._data.queue_item.workflow:
|
||||
workflow_ = self._data.queue_item.workflow.model_dump_json()
|
||||
|
||||
graph_ = None
|
||||
if self._data.queue_item.session.graph:
|
||||
graph_ = self._data.queue_item.session.graph.model_dump_json()
|
||||
|
||||
return self._services.images.create(
|
||||
image=image,
|
||||
is_intermediate=self._data.invocation.is_intermediate,
|
||||
@ -198,7 +206,8 @@ class ImagesInterface(InvocationContextInterface):
|
||||
board_id=board_id_,
|
||||
metadata=metadata_,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
workflow=self._data.queue_item.workflow,
|
||||
workflow=workflow_,
|
||||
graph=graph_,
|
||||
session_id=self._data.queue_item.session_id,
|
||||
node_id=self._data.invocation.id,
|
||||
)
|
||||
@ -344,11 +353,11 @@ class ModelsInterface(InvocationContextInterface):
|
||||
|
||||
if isinstance(identifier, str):
|
||||
model = self._services.model_manager.store.get_model(identifier)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type)
|
||||
else:
|
||||
_submodel_type = submodel_type or identifier.submodel_type
|
||||
model = self._services.model_manager.store.get_model(identifier.key)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type)
|
||||
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
@ -373,7 +382,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
if len(configs) > 1:
|
||||
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type)
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
|
||||
"""Gets a model's config.
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
@ -13,8 +13,36 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
|
||||
# fast latents preview matrix for sdxl
|
||||
# generated by @StAlKeR7779
|
||||
SDXL_LATENT_RGB_FACTORS = [
|
||||
# R G B
|
||||
[0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
]
|
||||
SDXL_SMOOTH_MATRIX = [
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
]
|
||||
|
||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
SD1_5_LATENT_RGB_FACTORS = [
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
]
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(
|
||||
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
|
||||
):
|
||||
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
|
||||
|
||||
if smooth_matrix is not None:
|
||||
@ -47,64 +75,12 @@ def stable_diffusion_step_callback(
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
# TODO: This does not seem to be needed any more?
|
||||
# # txt2img provides a Tensor in the step_callback
|
||||
# # img2img provides a PipelineIntermediateState
|
||||
# if isinstance(sample, PipelineIntermediateState):
|
||||
# # this was an img2img
|
||||
# print('img2img')
|
||||
# latents = sample.latents
|
||||
# step = sample.step
|
||||
# else:
|
||||
# print('txt2img')
|
||||
# latents = sample
|
||||
# step = intermediate_state.step
|
||||
|
||||
# TODO: only output a preview image when requested
|
||||
|
||||
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
|
||||
# fast latents preview matrix for sdxl
|
||||
# generated by @StAlKeR7779
|
||||
sdxl_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
sdxl_smooth_matrix = torch.tensor(
|
||||
[
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
|
||||
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
|
||||
else:
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
v1_5_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
|
||||
|
||||
(width, height) = image.size
|
||||
@ -113,15 +89,9 @@ def stable_diffusion_step_callback(
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
events.emit_generator_progress(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
node_id=context_data.invocation.id,
|
||||
source_node_id=context_data.source_invocation_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=intermediate_state.step,
|
||||
order=intermediate_state.order,
|
||||
total_steps=intermediate_state.total_steps,
|
||||
events.emit_invocation_denoise_progress(
|
||||
context_data.queue_item,
|
||||
context_data.invocation,
|
||||
intermediate_state,
|
||||
ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
)
|
||||
|
@ -4,5 +4,4 @@ Initialization file for invokeai.backend.image_util methods.
|
||||
|
||||
from .infill_methods.patchmatch import PatchMatch # noqa: F401
|
||||
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
|
||||
from .seamless import configure_model_padding # noqa: F401
|
||||
from .util import InitImageResizer, make_grid # noqa: F401
|
||||
|
@ -8,7 +8,7 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageFilter
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
@ -16,6 +16,7 @@ from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
|
||||
|
||||
|
||||
@ -24,30 +25,30 @@ class SafetyChecker:
|
||||
Wrapper around SafetyChecker model.
|
||||
"""
|
||||
|
||||
safety_checker = None
|
||||
feature_extractor = None
|
||||
tried_load: bool = False
|
||||
safety_checker = None
|
||||
|
||||
@classmethod
|
||||
def _load_safety_checker(cls):
|
||||
if cls.tried_load:
|
||||
if cls.safety_checker is not None and cls.feature_extractor is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(get_config().models_path / CHECKER_PATH)
|
||||
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(get_config().models_path / CHECKER_PATH)
|
||||
model_path = get_config().models_path / CHECKER_PATH
|
||||
if model_path.exists():
|
||||
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
|
||||
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(model_path)
|
||||
else:
|
||||
model_path.mkdir(parents=True, exist_ok=True)
|
||||
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
|
||||
cls.feature_extractor.save_pretrained(model_path, safe_serialization=True)
|
||||
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(repo_id)
|
||||
cls.safety_checker.save_pretrained(model_path, safe_serialization=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load NSFW checker: {str(e)}")
|
||||
cls.tried_load = True
|
||||
|
||||
@classmethod
|
||||
def safety_checker_available(cls) -> bool:
|
||||
return Path(get_config().models_path, CHECKER_PATH).exists()
|
||||
|
||||
@classmethod
|
||||
def has_nsfw_concept(cls, image: Image.Image) -> bool:
|
||||
if not cls.safety_checker_available() and cls.tried_load:
|
||||
return False
|
||||
cls._load_safety_checker()
|
||||
if cls.safety_checker is None or cls.feature_extractor is None:
|
||||
return False
|
||||
@ -60,3 +61,24 @@ class SafetyChecker:
|
||||
with SilenceWarnings():
|
||||
checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
|
||||
return has_nsfw_concept[0]
|
||||
|
||||
@classmethod
|
||||
def blur_if_nsfw(cls, image: Image.Image) -> Image.Image:
|
||||
if cls.has_nsfw_concept(image):
|
||||
logger.warning("A potentially NSFW image has been detected. Image will be blurred.")
|
||||
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||
caution = cls._get_caution_img()
|
||||
# Center the caution image on the blurred image
|
||||
x = (blurry_image.width - caution.width) // 2
|
||||
y = (blurry_image.height - caution.height) // 2
|
||||
blurry_image.paste(caution, (x, y), caution)
|
||||
image = blurry_image
|
||||
|
||||
return image
|
||||
|
||||
@classmethod
|
||||
def _get_caution_img(cls) -> Image.Image:
|
||||
import invokeai.app.assets.images as image_assets
|
||||
|
||||
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
|
||||
return caution.resize((caution.width // 2, caution.height // 2))
|
||||
|
@ -1,52 +0,0 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def _conv_forward_asymmetric(self, input, weight, bias):
|
||||
"""
|
||||
Patch for Conv2d._conv_forward that supports asymmetric padding
|
||||
"""
|
||||
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
|
||||
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
|
||||
return nn.functional.conv2d(
|
||||
working,
|
||||
weight,
|
||||
bias,
|
||||
self.stride,
|
||||
nn.modules.utils._pair(0),
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
|
||||
def configure_model_padding(model, seamless, seamless_axes):
|
||||
"""
|
||||
Modifies the 2D convolution layers to use a circular padding mode based on
|
||||
the `seamless` and `seamless_axes` options.
|
||||
"""
|
||||
# TODO: get an explicit interface for this in diffusers: https://github.com/huggingface/diffusers/issues/556
|
||||
for m in model.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
if seamless:
|
||||
m.asymmetric_padding_mode = {}
|
||||
m.asymmetric_padding = {}
|
||||
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
|
||||
m.asymmetric_padding["x"] = (
|
||||
m._reversed_padding_repeated_twice[0],
|
||||
m._reversed_padding_repeated_twice[1],
|
||||
0,
|
||||
0,
|
||||
)
|
||||
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
|
||||
m.asymmetric_padding["y"] = (
|
||||
0,
|
||||
0,
|
||||
m._reversed_padding_repeated_twice[2],
|
||||
m._reversed_padding_repeated_twice[3],
|
||||
)
|
||||
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
|
||||
else:
|
||||
m._conv_forward = nn.Conv2d._conv_forward.__get__(m, nn.Conv2d)
|
||||
if hasattr(m, "asymmetric_padding_mode"):
|
||||
del m.asymmetric_padding_mode
|
||||
if hasattr(m, "asymmetric_padding"):
|
||||
del m.asymmetric_padding
|
@ -1,89 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, List, Union
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
from diffusers.models.lora import LoRACompatibleConv
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
|
||||
def _conv_forward_asymmetric(self, input, weight, bias):
|
||||
"""
|
||||
Patch for Conv2d._conv_forward that supports asymmetric padding
|
||||
"""
|
||||
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
|
||||
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
|
||||
return nn.functional.conv2d(
|
||||
working,
|
||||
weight,
|
||||
bias,
|
||||
self.stride,
|
||||
nn.modules.utils._pair(0),
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
|
||||
if not seamless_axes:
|
||||
yield
|
||||
return
|
||||
|
||||
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
|
||||
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
|
||||
# override conv_forward
|
||||
# https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
|
||||
def _conv_forward_asymmetric(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
|
||||
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
|
||||
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
|
||||
working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
|
||||
working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
|
||||
return torch.nn.functional.conv2d(
|
||||
working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
|
||||
)
|
||||
|
||||
original_layers: List[Tuple[nn.Conv2d, Callable]] = []
|
||||
|
||||
try:
|
||||
# Hard coded to skip down block layers, allowing for seamless tiling at the expense of prompt adherence
|
||||
skipped_layers = 1
|
||||
for m_name, m in model.named_modules():
|
||||
if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
continue
|
||||
x_mode = "circular" if "x" in seamless_axes else "constant"
|
||||
y_mode = "circular" if "y" in seamless_axes else "constant"
|
||||
|
||||
if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name:
|
||||
# down_blocks.1.resnets.1.conv1
|
||||
_, block_num, _, resnet_num, submodule_name = m_name.split(".")
|
||||
block_num = int(block_num)
|
||||
resnet_num = int(resnet_num)
|
||||
conv_layers: List[torch.nn.Conv2d] = []
|
||||
|
||||
if block_num >= len(model.down_blocks) - skipped_layers:
|
||||
continue
|
||||
for module in model.modules():
|
||||
if isinstance(module, torch.nn.Conv2d):
|
||||
conv_layers.append(module)
|
||||
|
||||
# Skip the second resnet (could be configurable)
|
||||
if resnet_num > 0:
|
||||
continue
|
||||
|
||||
# Skip Conv2d layers (could be configurable)
|
||||
if submodule_name == "conv2":
|
||||
continue
|
||||
|
||||
m.asymmetric_padding_mode = {}
|
||||
m.asymmetric_padding = {}
|
||||
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
|
||||
m.asymmetric_padding["x"] = (
|
||||
m._reversed_padding_repeated_twice[0],
|
||||
m._reversed_padding_repeated_twice[1],
|
||||
0,
|
||||
0,
|
||||
)
|
||||
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
|
||||
m.asymmetric_padding["y"] = (
|
||||
0,
|
||||
0,
|
||||
m._reversed_padding_repeated_twice[2],
|
||||
m._reversed_padding_repeated_twice[3],
|
||||
)
|
||||
|
||||
to_restore.append((m, m._conv_forward))
|
||||
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
|
||||
for layer in conv_layers:
|
||||
if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
|
||||
layer.lora_layer = lambda *x: 0
|
||||
original_layers.append((layer, layer._conv_forward))
|
||||
layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
for module, orig_conv_forward in to_restore:
|
||||
module._conv_forward = orig_conv_forward
|
||||
if hasattr(module, "asymmetric_padding_mode"):
|
||||
del module.asymmetric_padding_mode
|
||||
if hasattr(module, "asymmetric_padding"):
|
||||
del module.asymmetric_padding
|
||||
for layer, orig_conv_forward in original_layers:
|
||||
layer._conv_forward = orig_conv_forward
|
||||
|
@ -10,6 +10,8 @@ module.exports = {
|
||||
'path/no-relative-imports': ['error', { maxDepth: 0 }],
|
||||
// https://github.com/edvardchen/eslint-plugin-i18next/blob/HEAD/docs/rules/no-literal-string.md
|
||||
'i18next/no-literal-string': 'error',
|
||||
// https://eslint.org/docs/latest/rules/no-console
|
||||
'no-console': 'error',
|
||||
},
|
||||
overrides: [
|
||||
/**
|
||||
|
3
invokeai/frontend/web/.gitignore
vendored
3
invokeai/frontend/web/.gitignore
vendored
@ -43,4 +43,5 @@ stats.html
|
||||
yalc.lock
|
||||
|
||||
# vitest
|
||||
tsconfig.vitest-temp.json
|
||||
tsconfig.vitest-temp.json
|
||||
coverage/
|
@ -35,6 +35,7 @@
|
||||
"storybook": "storybook dev -p 6006",
|
||||
"build-storybook": "storybook build",
|
||||
"test": "vitest",
|
||||
"test:ui": "vitest --coverage --ui",
|
||||
"test:no-watch": "vitest --no-watch"
|
||||
},
|
||||
"madge": {
|
||||
@ -65,6 +66,7 @@
|
||||
"chakra-react-select": "^4.7.6",
|
||||
"compare-versions": "^6.1.0",
|
||||
"dateformat": "^5.0.3",
|
||||
"fracturedjsonjs": "^4.0.1",
|
||||
"framer-motion": "^11.1.8",
|
||||
"i18next": "^23.11.3",
|
||||
"i18next-http-backend": "^2.5.1",
|
||||
@ -131,6 +133,8 @@
|
||||
"@types/react-dom": "^18.3.0",
|
||||
"@types/uuid": "^9.0.8",
|
||||
"@vitejs/plugin-react-swc": "^3.6.0",
|
||||
"@vitest/coverage-v8": "^1.5.0",
|
||||
"@vitest/ui": "^1.5.0",
|
||||
"concurrently": "^8.2.2",
|
||||
"dpdm": "^3.14.0",
|
||||
"eslint": "^8.57.0",
|
||||
|
155
invokeai/frontend/web/pnpm-lock.yaml
generated
155
invokeai/frontend/web/pnpm-lock.yaml
generated
@ -50,6 +50,9 @@ dependencies:
|
||||
dateformat:
|
||||
specifier: ^5.0.3
|
||||
version: 5.0.3
|
||||
fracturedjsonjs:
|
||||
specifier: ^4.0.1
|
||||
version: 4.0.1
|
||||
framer-motion:
|
||||
specifier: ^11.1.8
|
||||
version: 11.1.8(react-dom@18.3.1)(react@18.3.1)
|
||||
@ -226,6 +229,12 @@ devDependencies:
|
||||
'@vitejs/plugin-react-swc':
|
||||
specifier: ^3.6.0
|
||||
version: 3.6.0(vite@5.2.11)
|
||||
'@vitest/coverage-v8':
|
||||
specifier: ^1.5.0
|
||||
version: 1.6.0(vitest@1.6.0)
|
||||
'@vitest/ui':
|
||||
specifier: ^1.5.0
|
||||
version: 1.6.0(vitest@1.6.0)
|
||||
concurrently:
|
||||
specifier: ^8.2.2
|
||||
version: 8.2.2
|
||||
@ -285,7 +294,7 @@ devDependencies:
|
||||
version: 4.3.2(typescript@5.4.5)(vite@5.2.11)
|
||||
vitest:
|
||||
specifier: ^1.6.0
|
||||
version: 1.6.0(@types/node@20.12.10)
|
||||
version: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
|
||||
|
||||
packages:
|
||||
|
||||
@ -1676,6 +1685,10 @@ packages:
|
||||
resolution: {integrity: sha512-4iri8i1AqYHJE2DstZYkyEprg6Pq6sKx3xn5FpySk9sNhH7qN2LLlHJCfDTZRILNwQNPD7mATWM0TBui7uC1pA==}
|
||||
dev: true
|
||||
|
||||
/@bcoe/v8-coverage@0.2.3:
|
||||
resolution: {integrity: sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==}
|
||||
dev: true
|
||||
|
||||
/@chakra-ui/accordion@2.3.1(@chakra-ui/system@2.6.2)(framer-motion@10.18.0)(react@18.3.1):
|
||||
resolution: {integrity: sha512-FSXRm8iClFyU+gVaXisOSEw0/4Q+qZbFRiuhIAkVU6Boj0FxAMrlo9a8AV5TuF77rgaHytCdHk0Ng+cyUijrag==}
|
||||
peerDependencies:
|
||||
@ -3632,6 +3645,11 @@ packages:
|
||||
wrap-ansi-cjs: /wrap-ansi@7.0.0
|
||||
dev: true
|
||||
|
||||
/@istanbuljs/schema@0.1.3:
|
||||
resolution: {integrity: sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA==}
|
||||
engines: {node: '>=8'}
|
||||
dev: true
|
||||
|
||||
/@jest/schemas@29.6.3:
|
||||
resolution: {integrity: sha512-mo5j5X+jIZmJQveBKeS/clAueipV7KgiX1vMgCxam1RNYiqE1w62n0/tJJnHtjW8ZHcQco5gY85jA3mi0L+nSA==}
|
||||
engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0}
|
||||
@ -3819,6 +3837,10 @@ packages:
|
||||
dev: true
|
||||
optional: true
|
||||
|
||||
/@polka/url@1.0.0-next.25:
|
||||
resolution: {integrity: sha512-j7P6Rgr3mmtdkeDGTe0E/aYyWEWVtc5yFXtHCRHs28/jptDEWfaVOc5T7cblqy1XKPPfCxJc/8DwQ5YgLOZOVQ==}
|
||||
dev: true
|
||||
|
||||
/@popperjs/core@2.11.8:
|
||||
resolution: {integrity: sha512-P1st0aksCrn9sGZhp8GMYwBnQsbvAWsZAX44oXNNvLHGqAOcoVxmjZiohstwQ7SqKnbR47akdNi+uleWD8+g6A==}
|
||||
dev: false
|
||||
@ -5143,7 +5165,7 @@ packages:
|
||||
dom-accessibility-api: 0.6.3
|
||||
lodash: 4.17.21
|
||||
redent: 3.0.0
|
||||
vitest: 1.6.0(@types/node@20.12.10)
|
||||
vitest: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
|
||||
dev: true
|
||||
|
||||
/@testing-library/user-event@14.5.2(@testing-library/dom@9.3.4):
|
||||
@ -5822,6 +5844,29 @@ packages:
|
||||
- '@swc/helpers'
|
||||
dev: true
|
||||
|
||||
/@vitest/coverage-v8@1.6.0(vitest@1.6.0):
|
||||
resolution: {integrity: sha512-KvapcbMY/8GYIG0rlwwOKCVNRc0OL20rrhFkg/CHNzncV03TE2XWvO5w9uZYoxNiMEBacAJt3unSOiZ7svePew==}
|
||||
peerDependencies:
|
||||
vitest: 1.6.0
|
||||
dependencies:
|
||||
'@ampproject/remapping': 2.3.0
|
||||
'@bcoe/v8-coverage': 0.2.3
|
||||
debug: 4.3.4
|
||||
istanbul-lib-coverage: 3.2.2
|
||||
istanbul-lib-report: 3.0.1
|
||||
istanbul-lib-source-maps: 5.0.4
|
||||
istanbul-reports: 3.1.7
|
||||
magic-string: 0.30.10
|
||||
magicast: 0.3.4
|
||||
picocolors: 1.0.0
|
||||
std-env: 3.7.0
|
||||
strip-literal: 2.1.0
|
||||
test-exclude: 6.0.0
|
||||
vitest: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
dev: true
|
||||
|
||||
/@vitest/expect@1.3.1:
|
||||
resolution: {integrity: sha512-xofQFwIzfdmLLlHa6ag0dPV8YsnKOCP1KdAeVVh34vSjN2dcUiXYCD9htu/9eM7t8Xln4v03U9HLxLpPlsXdZw==}
|
||||
dependencies:
|
||||
@ -5866,6 +5911,21 @@ packages:
|
||||
tinyspy: 2.2.1
|
||||
dev: true
|
||||
|
||||
/@vitest/ui@1.6.0(vitest@1.6.0):
|
||||
resolution: {integrity: sha512-k3Lyo+ONLOgylctiGovRKy7V4+dIN2yxstX3eY5cWFXH6WP+ooVX79YSyi0GagdTQzLmT43BF27T0s6dOIPBXA==}
|
||||
peerDependencies:
|
||||
vitest: 1.6.0
|
||||
dependencies:
|
||||
'@vitest/utils': 1.6.0
|
||||
fast-glob: 3.3.2
|
||||
fflate: 0.8.2
|
||||
flatted: 3.3.1
|
||||
pathe: 1.1.2
|
||||
picocolors: 1.0.0
|
||||
sirv: 2.0.4
|
||||
vitest: 1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0)
|
||||
dev: true
|
||||
|
||||
/@vitest/utils@1.3.1:
|
||||
resolution: {integrity: sha512-d3Waie/299qqRyHTm2DjADeTaNdNSVsnwHPWrs20JMpjh6eiVq7ggggweO8rc4arhf6rRkWuHKwvxGvejUXZZQ==}
|
||||
dependencies:
|
||||
@ -8518,6 +8578,10 @@ packages:
|
||||
resolution: {integrity: sha512-3yurQZ2hD9VISAhJJP9bpYFNQrHHBXE2JxxjY5aLEcDi46RmAzJE2OC9FAde0yis5ElW0jTTzs0zfg/Cca4XqQ==}
|
||||
dev: true
|
||||
|
||||
/fflate@0.8.2:
|
||||
resolution: {integrity: sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==}
|
||||
dev: true
|
||||
|
||||
/file-entry-cache@6.0.1:
|
||||
resolution: {integrity: sha512-7Gps/XWymbLk2QLYK4NzpMOrYjMhdIxXuIvy2QBsLE6ljuodKvdkWs/cpyJJ3CVIVpH0Oi1Hvg1ovbMzLdFBBg==}
|
||||
engines: {node: ^10.12.0 || >=12.0.0}
|
||||
@ -8691,6 +8755,10 @@ packages:
|
||||
engines: {node: '>= 0.6'}
|
||||
dev: true
|
||||
|
||||
/fracturedjsonjs@4.0.1:
|
||||
resolution: {integrity: sha512-KMhSx7o45aPVj4w27dwdQyKJkNU8oBqw8UiK/s3VzsQB3+pKQ/3AqG/YOEQblV2BDuYE5dKp0OMf8RDsshrjTA==}
|
||||
dev: false
|
||||
|
||||
/framer-motion@10.18.0(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-oGlDh1Q1XqYPksuTD/usb0I70hq95OUzmL9+6Zd+Hs4XV0oaISBa/UUMSjYiq6m8EUF32132mOJ8xVZS+I0S6w==}
|
||||
peerDependencies:
|
||||
@ -9077,6 +9145,10 @@ packages:
|
||||
resolution: {integrity: sha512-mxIDAb9Lsm6DoOJ7xH+5+X4y1LU/4Hi50L9C5sIswK3JzULS4bwk1FvjdBgvYR4bzT4tuUQiC15FE2f5HbLvYw==}
|
||||
dev: true
|
||||
|
||||
/html-escaper@2.0.2:
|
||||
resolution: {integrity: sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg==}
|
||||
dev: true
|
||||
|
||||
/html-parse-stringify@3.0.1:
|
||||
resolution: {integrity: sha512-KknJ50kTInJ7qIScF3jeaFRpMpE8/lfiTdzf/twXyPBLAGrLRTmkz3AdTnKeh40X8k9L2fdYwEp/42WGXIRGcg==}
|
||||
dependencies:
|
||||
@ -9506,6 +9578,39 @@ packages:
|
||||
engines: {node: '>=0.10.0'}
|
||||
dev: true
|
||||
|
||||
/istanbul-lib-coverage@3.2.2:
|
||||
resolution: {integrity: sha512-O8dpsF+r0WV/8MNRKfnmrtCWhuKjxrq2w+jpzBL5UZKTi2LeVWnWOmWRxFlesJONmc+wLAGvKQZEOanko0LFTg==}
|
||||
engines: {node: '>=8'}
|
||||
dev: true
|
||||
|
||||
/istanbul-lib-report@3.0.1:
|
||||
resolution: {integrity: sha512-GCfE1mtsHGOELCU8e/Z7YWzpmybrx/+dSTfLrvY8qRmaY6zXTKWn6WQIjaAFw069icm6GVMNkgu0NzI4iPZUNw==}
|
||||
engines: {node: '>=10'}
|
||||
dependencies:
|
||||
istanbul-lib-coverage: 3.2.2
|
||||
make-dir: 4.0.0
|
||||
supports-color: 7.2.0
|
||||
dev: true
|
||||
|
||||
/istanbul-lib-source-maps@5.0.4:
|
||||
resolution: {integrity: sha512-wHOoEsNJTVltaJp8eVkm8w+GVkVNHT2YDYo53YdzQEL2gWm1hBX5cGFR9hQJtuGLebidVX7et3+dmDZrmclduw==}
|
||||
engines: {node: '>=10'}
|
||||
dependencies:
|
||||
'@jridgewell/trace-mapping': 0.3.25
|
||||
debug: 4.3.4
|
||||
istanbul-lib-coverage: 3.2.2
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
dev: true
|
||||
|
||||
/istanbul-reports@3.1.7:
|
||||
resolution: {integrity: sha512-BewmUXImeuRk2YY0PVbxgKAysvhRPUQE0h5QRM++nVWyubKGV0l8qQ5op8+B2DOmwSe63Jivj0BjkPQVf8fP5g==}
|
||||
engines: {node: '>=8'}
|
||||
dependencies:
|
||||
html-escaper: 2.0.2
|
||||
istanbul-lib-report: 3.0.1
|
||||
dev: true
|
||||
|
||||
/iterable-lookahead@1.0.0:
|
||||
resolution: {integrity: sha512-hJnEP2Xk4+44DDwJqUQGdXal5VbyeWLaPyDl2AQc242Zr7iqz4DgpQOrEzglWVMGHMDCkguLHEKxd1+rOsmgSQ==}
|
||||
engines: {node: '>=4'}
|
||||
@ -9905,6 +10010,14 @@ packages:
|
||||
'@jridgewell/sourcemap-codec': 1.4.15
|
||||
dev: true
|
||||
|
||||
/magicast@0.3.4:
|
||||
resolution: {integrity: sha512-TyDF/Pn36bBji9rWKHlZe+PZb6Mx5V8IHCSxk7X4aljM4e/vyDvZZYwHewdVaqiA0nb3ghfHU/6AUpDxWoER2Q==}
|
||||
dependencies:
|
||||
'@babel/parser': 7.24.5
|
||||
'@babel/types': 7.24.5
|
||||
source-map-js: 1.2.0
|
||||
dev: true
|
||||
|
||||
/make-dir@2.1.0:
|
||||
resolution: {integrity: sha512-LS9X+dc8KLxXCb8dni79fLIIUA5VyZoyjSMCwTluaXA0o27cCK0bhXkpgw+sTXVpPy/lSO57ilRixqk0vDmtRA==}
|
||||
engines: {node: '>=6'}
|
||||
@ -9920,6 +10033,13 @@ packages:
|
||||
semver: 6.3.1
|
||||
dev: true
|
||||
|
||||
/make-dir@4.0.0:
|
||||
resolution: {integrity: sha512-hXdUTZYIVOt1Ex//jAQi+wTZZpUpwBj/0QsOzqegb3rGMMeJiSEu5xLHnYfBrRV4RH2+OCSOO95Is/7x1WJ4bw==}
|
||||
engines: {node: '>=10'}
|
||||
dependencies:
|
||||
semver: 7.6.0
|
||||
dev: true
|
||||
|
||||
/map-obj@2.0.0:
|
||||
resolution: {integrity: sha512-TzQSV2DiMYgoF5RycneKVUzIa9bQsj/B3tTgsE3dOGqlzHnGIDaC7XBE7grnA+8kZPnfqSGFe95VHc2oc0VFUQ==}
|
||||
engines: {node: '>=4'}
|
||||
@ -10094,6 +10214,11 @@ packages:
|
||||
resolution: {integrity: sha512-iSAJLHYKnX41mKcJKjqvnAN9sf0LMDTXDEvFv+ffuRR9a1MIuXLjMNL6EsnDHSkKLTWNqQQ5uo61P4EbU4NU+Q==}
|
||||
dev: false
|
||||
|
||||
/mrmime@2.0.0:
|
||||
resolution: {integrity: sha512-eu38+hdgojoyq63s+yTpN4XMBdt5l8HhMhc4VKLO9KM5caLIBvUm4thi7fFaxyTmCKeNnXZ5pAlBwCUnhA09uw==}
|
||||
engines: {node: '>=10'}
|
||||
dev: true
|
||||
|
||||
/ms@2.0.0:
|
||||
resolution: {integrity: sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==}
|
||||
dev: true
|
||||
@ -11759,6 +11884,15 @@ packages:
|
||||
engines: {node: '>=14'}
|
||||
dev: true
|
||||
|
||||
/sirv@2.0.4:
|
||||
resolution: {integrity: sha512-94Bdh3cC2PKrbgSOUqTiGPWVZeSiXfKOVZNJniWoqrWrRkB1CJzBU3NEbiTsPcYy1lDsANA/THzS+9WBiy5nfQ==}
|
||||
engines: {node: '>= 10'}
|
||||
dependencies:
|
||||
'@polka/url': 1.0.0-next.25
|
||||
mrmime: 2.0.0
|
||||
totalist: 3.0.1
|
||||
dev: true
|
||||
|
||||
/sisteransi@1.0.5:
|
||||
resolution: {integrity: sha512-bLGGlR1QxBcynn2d5YmDX4MGjlZvy2MRBDRNHLJ8VI6l6+9FUiyTFNJ0IveOSP0bcXgVDPRcfGqA0pjaqUpfVg==}
|
||||
dev: true
|
||||
@ -12184,6 +12318,15 @@ packages:
|
||||
unique-string: 2.0.0
|
||||
dev: true
|
||||
|
||||
/test-exclude@6.0.0:
|
||||
resolution: {integrity: sha512-cAGWPIyOHU6zlmg88jwm7VRyXnMN7iV68OGAbYDk/Mh/xC/pzVPlQtY6ngoIH/5/tciuhGfvESU8GrHrcxD56w==}
|
||||
engines: {node: '>=8'}
|
||||
dependencies:
|
||||
'@istanbuljs/schema': 0.1.3
|
||||
glob: 7.2.3
|
||||
minimatch: 3.1.2
|
||||
dev: true
|
||||
|
||||
/text-table@0.2.0:
|
||||
resolution: {integrity: sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==}
|
||||
dev: true
|
||||
@ -12257,6 +12400,11 @@ packages:
|
||||
engines: {node: '>=0.6'}
|
||||
dev: true
|
||||
|
||||
/totalist@3.0.1:
|
||||
resolution: {integrity: sha512-sf4i37nQ2LBx4m3wB74y+ubopq6W/dIzXg0FDGjsYnZHVa1Da8FH853wlL2gtUhg+xJXjfk3kUZS3BRoQeoQBQ==}
|
||||
engines: {node: '>=6'}
|
||||
dev: true
|
||||
|
||||
/tr46@0.0.3:
|
||||
resolution: {integrity: sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==}
|
||||
|
||||
@ -12830,7 +12978,7 @@ packages:
|
||||
fsevents: 2.3.3
|
||||
dev: true
|
||||
|
||||
/vitest@1.6.0(@types/node@20.12.10):
|
||||
/vitest@1.6.0(@types/node@20.12.10)(@vitest/ui@1.6.0):
|
||||
resolution: {integrity: sha512-H5r/dN06swuFnzNFhq/dnz37bPXnq8xB2xB5JOVk8K09rUtoeNN+LHWkoQ0A/i3hvbUKKcCei9KpbxqHMLhLLA==}
|
||||
engines: {node: ^18.0.0 || >=20.0.0}
|
||||
hasBin: true
|
||||
@ -12860,6 +13008,7 @@ packages:
|
||||
'@vitest/runner': 1.6.0
|
||||
'@vitest/snapshot': 1.6.0
|
||||
'@vitest/spy': 1.6.0
|
||||
'@vitest/ui': 1.6.0(vitest@1.6.0)
|
||||
'@vitest/utils': 1.6.0
|
||||
acorn-walk: 8.3.2
|
||||
chai: 4.4.1
|
||||
|
@ -76,7 +76,9 @@
|
||||
"aboutHeading": "Nutzen Sie Ihre kreative Energie",
|
||||
"toResolve": "Lösen",
|
||||
"add": "Hinzufügen",
|
||||
"loglevel": "Protokoll Stufe"
|
||||
"loglevel": "Protokoll Stufe",
|
||||
"selected": "Ausgewählt",
|
||||
"beta": "Beta"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Bildgröße",
|
||||
@ -86,7 +88,7 @@
|
||||
"noImagesInGallery": "Keine Bilder in der Galerie",
|
||||
"loading": "Lade",
|
||||
"deleteImage_one": "Lösche Bild",
|
||||
"deleteImage_other": "",
|
||||
"deleteImage_other": "Lösche {{count}} Bilder",
|
||||
"copy": "Kopieren",
|
||||
"download": "Runterladen",
|
||||
"setCurrentImage": "Setze aktuelle Bild",
|
||||
@ -397,7 +399,14 @@
|
||||
"cancel": "Stornieren",
|
||||
"defaultSettingsSaved": "Standardeinstellungen gespeichert",
|
||||
"addModels": "Model hinzufügen",
|
||||
"deleteModelImage": "Lösche Model Bild"
|
||||
"deleteModelImage": "Lösche Model Bild",
|
||||
"hfTokenInvalidErrorMessage": "Falscher oder fehlender HuggingFace Schlüssel.",
|
||||
"huggingFaceRepoID": "HuggingFace Repo ID",
|
||||
"hfToken": "HuggingFace Schlüssel",
|
||||
"hfTokenInvalid": "Falscher oder fehlender HF Schlüssel",
|
||||
"huggingFacePlaceholder": "besitzer/model-name",
|
||||
"hfTokenSaved": "HF Schlüssel gespeichert",
|
||||
"hfTokenUnableToVerify": "Konnte den HF Schlüssel nicht validieren"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Bilder",
|
||||
@ -686,7 +695,11 @@
|
||||
"hands": "Hände",
|
||||
"dwOpenpose": "DW Openpose",
|
||||
"dwOpenposeDescription": "Posenschätzung mit DW Openpose",
|
||||
"selectCLIPVisionModel": "Wähle ein CLIP Vision Model aus"
|
||||
"selectCLIPVisionModel": "Wähle ein CLIP Vision Model aus",
|
||||
"ipAdapterMethod": "Methode",
|
||||
"composition": "Nur Komposition",
|
||||
"full": "Voll",
|
||||
"style": "Nur Style"
|
||||
},
|
||||
"queue": {
|
||||
"status": "Status",
|
||||
@ -717,7 +730,6 @@
|
||||
"resume": "Wieder aufnehmen",
|
||||
"item": "Auftrag",
|
||||
"notReady": "Warteschlange noch nicht bereit",
|
||||
"queueCountPrediction": "{{promptsCount}} Prompts × {{iterations}} Iterationen -> {{count}} Generationen",
|
||||
"clearQueueAlertDialog": "\"Die Warteschlange leeren\" stoppt den aktuellen Prozess und leert die Warteschlange komplett.",
|
||||
"completedIn": "Fertig in",
|
||||
"cancelBatchSucceeded": "Stapel abgebrochen",
|
||||
|
@ -142,9 +142,11 @@
|
||||
"blue": "Blue",
|
||||
"alpha": "Alpha",
|
||||
"selected": "Selected",
|
||||
"viewer": "Viewer",
|
||||
"tab": "Tab",
|
||||
"close": "Close"
|
||||
"viewing": "Viewing",
|
||||
"viewingDesc": "Review images in a large gallery view",
|
||||
"editing": "Editing",
|
||||
"editingDesc": "Edit on the Control Layers canvas"
|
||||
},
|
||||
"controlnet": {
|
||||
"controlAdapter_one": "Control Adapter",
|
||||
@ -259,7 +261,6 @@
|
||||
"queue": "Queue",
|
||||
"queueFront": "Add to Front of Queue",
|
||||
"queueBack": "Add to Queue",
|
||||
"queueCountPrediction": "{{promptsCount}} prompts \u00d7 {{iterations}} iterations -> {{count}} generations",
|
||||
"queueEmpty": "Queue Empty",
|
||||
"enqueueing": "Queueing Batch",
|
||||
"resume": "Resume",
|
||||
@ -312,7 +313,13 @@
|
||||
"batchFailedToQueue": "Failed to Queue Batch",
|
||||
"graphQueued": "Graph queued",
|
||||
"graphFailedToQueue": "Failed to queue graph",
|
||||
"openQueue": "Open Queue"
|
||||
"openQueue": "Open Queue",
|
||||
"prompts_one": "Prompt",
|
||||
"prompts_other": "Prompts",
|
||||
"iterations_one": "Iteration",
|
||||
"iterations_other": "Iterations",
|
||||
"generations_one": "Generation",
|
||||
"generations_other": "Generations"
|
||||
},
|
||||
"invocationCache": {
|
||||
"invocationCache": "Invocation Cache",
|
||||
@ -365,10 +372,7 @@
|
||||
"bulkDownloadRequestFailed": "Problem Preparing Download",
|
||||
"bulkDownloadFailed": "Download Failed",
|
||||
"problemDeletingImages": "Problem Deleting Images",
|
||||
"problemDeletingImagesDesc": "One or more images could not be deleted",
|
||||
"switchTo": "Switch to {{ tab }} (Z)",
|
||||
"openFloatingViewer": "Open Floating Viewer",
|
||||
"closeFloatingViewer": "Close Floating Viewer"
|
||||
"problemDeletingImagesDesc": "One or more images could not be deleted"
|
||||
},
|
||||
"hotkeys": {
|
||||
"searchHotkeys": "Search Hotkeys",
|
||||
@ -770,10 +774,15 @@
|
||||
"cannotConnectOutputToOutput": "Cannot connect output to output",
|
||||
"cannotConnectToSelf": "Cannot connect to self",
|
||||
"cannotDuplicateConnection": "Cannot create duplicate connections",
|
||||
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
|
||||
"missingNode": "Missing invocation node",
|
||||
"missingInvocationTemplate": "Missing invocation template",
|
||||
"missingFieldTemplate": "Missing field template",
|
||||
"nodePack": "Node pack",
|
||||
"collection": "Collection",
|
||||
"collectionFieldType": "{{name}} Collection",
|
||||
"collectionOrScalarFieldType": "{{name}} Collection|Scalar",
|
||||
"singleFieldType": "{{name}} (Single)",
|
||||
"collectionFieldType": "{{name}} (Collection)",
|
||||
"collectionOrScalarFieldType": "{{name}} (Single or Collection)",
|
||||
"colorCodeEdges": "Color-Code Edges",
|
||||
"colorCodeEdgesHelp": "Color-code edges according to their connected fields",
|
||||
"connectionWouldCreateCycle": "Connection would create a cycle",
|
||||
@ -875,6 +884,7 @@
|
||||
"versionUnknown": " Version Unknown",
|
||||
"workflow": "Workflow",
|
||||
"graph": "Graph",
|
||||
"noGraph": "No Graph",
|
||||
"workflowAuthor": "Author",
|
||||
"workflowContact": "Contact",
|
||||
"workflowDescription": "Short Description",
|
||||
@ -935,17 +945,30 @@
|
||||
"noModelSelected": "No model selected",
|
||||
"noPrompts": "No prompts generated",
|
||||
"noNodesInGraph": "No nodes in graph",
|
||||
"systemDisconnected": "System disconnected"
|
||||
"systemDisconnected": "System disconnected",
|
||||
"layer": {
|
||||
"initialImageNoImageSelected": "no initial image selected",
|
||||
"controlAdapterNoModelSelected": "no Control Adapter model selected",
|
||||
"controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model",
|
||||
"controlAdapterNoImageSelected": "no Control Adapter image selected",
|
||||
"controlAdapterImageNotProcessed": "Control Adapter image not processed",
|
||||
"t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of {{multiple}}",
|
||||
"ipAdapterNoModelSelected": "no IP adapter selected",
|
||||
"ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model",
|
||||
"ipAdapterNoImageSelected": "no IP Adapter image selected",
|
||||
"rgNoPromptsOrIPAdapters": "no text prompts or IP Adapters",
|
||||
"rgNoRegion": "no region selected"
|
||||
}
|
||||
},
|
||||
"maskBlur": "Mask Blur",
|
||||
"negativePromptPlaceholder": "Negative Prompt",
|
||||
"globalNegativePromptPlaceholder": "Global Negative Prompt",
|
||||
"noiseThreshold": "Noise Threshold",
|
||||
"patchmatchDownScaleSize": "Downscale",
|
||||
"perlinNoise": "Perlin Noise",
|
||||
"positivePromptPlaceholder": "Positive Prompt",
|
||||
"globalPositivePromptPlaceholder": "Global Positive Prompt",
|
||||
"iterations": "Iterations",
|
||||
"iterationsWithCount_one": "{{count}} Iteration",
|
||||
"iterationsWithCount_other": "{{count}} Iterations",
|
||||
"scale": "Scale",
|
||||
"scaleBeforeProcessing": "Scale Before Processing",
|
||||
"scaledHeight": "Scaled H",
|
||||
@ -1547,8 +1570,6 @@
|
||||
"addIPAdapter": "Add $t(common.ipAdapter)",
|
||||
"regionalGuidance": "Regional Guidance",
|
||||
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
|
||||
"controlNetLayer": "$t(common.controlNet) $t(unifiedCanvas.layer)",
|
||||
"ipAdapterLayer": "$t(common.ipAdapter) $t(unifiedCanvas.layer)",
|
||||
"opacity": "Opacity",
|
||||
"globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
|
||||
"globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
|
||||
@ -1559,7 +1580,9 @@
|
||||
"opacityFilter": "Opacity Filter",
|
||||
"clearProcessor": "Clear Processor",
|
||||
"resetProcessor": "Reset Processor to Defaults",
|
||||
"noLayersAdded": "No Layers Added"
|
||||
"noLayersAdded": "No Layers Added",
|
||||
"layers_one": "Layer",
|
||||
"layers_other": "Layers"
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
|
@ -25,7 +25,24 @@
|
||||
"areYouSure": "¿Estas seguro?",
|
||||
"batch": "Administrador de lotes",
|
||||
"modelManager": "Administrador de modelos",
|
||||
"communityLabel": "Comunidad"
|
||||
"communityLabel": "Comunidad",
|
||||
"direction": "Dirección",
|
||||
"ai": "Ia",
|
||||
"add": "Añadir",
|
||||
"auto": "Automático",
|
||||
"copyError": "Error $t(gallery.copy)",
|
||||
"details": "Detalles",
|
||||
"or": "o",
|
||||
"checkpoint": "Punto de control",
|
||||
"controlNet": "ControlNet",
|
||||
"aboutHeading": "Sea dueño de su poder creativo",
|
||||
"advanced": "Avanzado",
|
||||
"data": "Fecha",
|
||||
"delete": "Borrar",
|
||||
"copy": "Copiar",
|
||||
"beta": "Beta",
|
||||
"on": "En",
|
||||
"aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Tamaño de la imagen",
|
||||
@ -443,7 +460,13 @@
|
||||
"previousImage": "Imagen anterior",
|
||||
"nextImage": "Siguiente imagen",
|
||||
"showOptionsPanel": "Mostrar el panel lateral",
|
||||
"menu": "Menú"
|
||||
"menu": "Menú",
|
||||
"showGalleryPanel": "Mostrar panel de galería",
|
||||
"loadMore": "Cargar más",
|
||||
"about": "Acerca de",
|
||||
"createIssue": "Crear un problema",
|
||||
"resetUI": "Interfaz de usuario $t(accessibility.reset)",
|
||||
"mode": "Modo"
|
||||
},
|
||||
"nodes": {
|
||||
"zoomInNodes": "Acercar",
|
||||
@ -456,5 +479,68 @@
|
||||
"reloadNodeTemplates": "Recargar las plantillas de nodos",
|
||||
"loadWorkflow": "Cargar el flujo de trabajo",
|
||||
"downloadWorkflow": "Descargar el flujo de trabajo en un archivo JSON"
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Agregar panel automáticamente",
|
||||
"changeBoard": "Cambiar el panel",
|
||||
"clearSearch": "Borrar la búsqueda",
|
||||
"deleteBoard": "Borrar el panel",
|
||||
"selectBoard": "Seleccionar un panel",
|
||||
"uncategorized": "Sin categoría",
|
||||
"cancel": "Cancelar",
|
||||
"addBoard": "Agregar un panel",
|
||||
"movingImagesToBoard_one": "Moviendo {{count}} imagen al panel:",
|
||||
"movingImagesToBoard_many": "Moviendo {{count}} imágenes al panel:",
|
||||
"movingImagesToBoard_other": "Moviendo {{count}} imágenes al panel:",
|
||||
"bottomMessage": "Al eliminar este panel y las imágenes que contiene, se restablecerán las funciones que los estén utilizando actualmente.",
|
||||
"deleteBoardAndImages": "Borrar el panel y las imágenes",
|
||||
"loading": "Cargando...",
|
||||
"deletedBoardsCannotbeRestored": "Los paneles eliminados no se pueden restaurar",
|
||||
"move": "Mover",
|
||||
"menuItemAutoAdd": "Agregar automáticamente a este panel",
|
||||
"searchBoard": "Buscando paneles…",
|
||||
"topMessage": "Este panel contiene imágenes utilizadas en las siguientes funciones:",
|
||||
"downloadBoard": "Descargar panel",
|
||||
"deleteBoardOnly": "Borrar solo el panel",
|
||||
"myBoard": "Mi panel",
|
||||
"noMatching": "No hay paneles que coincidan"
|
||||
},
|
||||
"accordions": {
|
||||
"compositing": {
|
||||
"title": "Composición",
|
||||
"infillTab": "Relleno"
|
||||
},
|
||||
"generation": {
|
||||
"title": "Generación"
|
||||
},
|
||||
"image": {
|
||||
"title": "Imagen"
|
||||
},
|
||||
"control": {
|
||||
"title": "Control"
|
||||
},
|
||||
"advanced": {
|
||||
"options": "$t(accordions.advanced.title) opciones",
|
||||
"title": "Avanzado"
|
||||
}
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
"generationTab": "$t(ui.tabs.generation) $t(common.tab)",
|
||||
"canvas": "Lienzo",
|
||||
"generation": "Generación",
|
||||
"queue": "Cola",
|
||||
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
|
||||
"workflows": "Flujos de trabajo",
|
||||
"models": "Modelos",
|
||||
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
|
||||
"canvasTab": "$t(ui.tabs.canvas) $t(common.tab)",
|
||||
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)"
|
||||
}
|
||||
},
|
||||
"controlLayers": {
|
||||
"layers_one": "Capa",
|
||||
"layers_many": "Capas",
|
||||
"layers_other": "Capas"
|
||||
}
|
||||
}
|
||||
|
@ -5,7 +5,7 @@
|
||||
"reportBugLabel": "Segnala un errore",
|
||||
"settingsLabel": "Impostazioni",
|
||||
"img2img": "Immagine a Immagine",
|
||||
"unifiedCanvas": "Tela unificata",
|
||||
"unifiedCanvas": "Tela",
|
||||
"nodes": "Flussi di lavoro",
|
||||
"upload": "Caricamento",
|
||||
"load": "Carica",
|
||||
@ -74,7 +74,18 @@
|
||||
"file": "File",
|
||||
"toResolve": "Da risolvere",
|
||||
"add": "Aggiungi",
|
||||
"loglevel": "Livello di log"
|
||||
"loglevel": "Livello di log",
|
||||
"beta": "Beta",
|
||||
"positivePrompt": "Prompt positivo",
|
||||
"negativePrompt": "Prompt negativo",
|
||||
"selected": "Selezionato",
|
||||
"goTo": "Vai a",
|
||||
"editor": "Editor",
|
||||
"tab": "Scheda",
|
||||
"viewing": "Visualizza",
|
||||
"viewingDesc": "Rivedi le immagini in un'ampia vista della galleria",
|
||||
"editing": "Modifica",
|
||||
"editingDesc": "Modifica nell'area Livelli di controllo"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Dimensione dell'immagine",
|
||||
@ -180,8 +191,8 @@
|
||||
"desc": "Mostra le informazioni sui metadati dell'immagine corrente"
|
||||
},
|
||||
"sendToImageToImage": {
|
||||
"title": "Invia a Immagine a Immagine",
|
||||
"desc": "Invia l'immagine corrente a da Immagine a Immagine"
|
||||
"title": "Invia a Generazione da immagine",
|
||||
"desc": "Invia l'immagine corrente a Generazione da immagine"
|
||||
},
|
||||
"deleteImage": {
|
||||
"title": "Elimina immagine",
|
||||
@ -334,6 +345,10 @@
|
||||
"remixImage": {
|
||||
"desc": "Utilizza tutti i parametri tranne il seme dell'immagine corrente",
|
||||
"title": "Remixa l'immagine"
|
||||
},
|
||||
"toggleViewer": {
|
||||
"title": "Attiva/disattiva il visualizzatore di immagini",
|
||||
"desc": "Passa dal Visualizzatore immagini all'area di lavoro per la scheda corrente."
|
||||
}
|
||||
},
|
||||
"modelManager": {
|
||||
@ -471,8 +486,8 @@
|
||||
"scaledHeight": "Altezza ridimensionata",
|
||||
"infillMethod": "Metodo di riempimento",
|
||||
"tileSize": "Dimensione piastrella",
|
||||
"sendToImg2Img": "Invia a Immagine a Immagine",
|
||||
"sendToUnifiedCanvas": "Invia a Tela Unificata",
|
||||
"sendToImg2Img": "Invia a Generazione da immagine",
|
||||
"sendToUnifiedCanvas": "Invia alla Tela",
|
||||
"downloadImage": "Scarica l'immagine",
|
||||
"usePrompt": "Usa Prompt",
|
||||
"useSeed": "Usa Seme",
|
||||
@ -508,13 +523,11 @@
|
||||
"incompatibleBaseModelForControlAdapter": "Il modello dell'adattatore di controllo #{{number}} non è compatibile con il modello principale.",
|
||||
"missingNodeTemplate": "Modello di nodo mancante",
|
||||
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} ingresso mancante",
|
||||
"missingFieldTemplate": "Modello di campo mancante"
|
||||
"missingFieldTemplate": "Modello di campo mancante",
|
||||
"imageNotProcessedForControlAdapter": "L'immagine dell'adattatore di controllo #{{number}} non è stata elaborata"
|
||||
},
|
||||
"useCpuNoise": "Usa la CPU per generare rumore",
|
||||
"iterations": "Iterazioni",
|
||||
"iterationsWithCount_one": "{{count}} Iterazione",
|
||||
"iterationsWithCount_many": "{{count}} Iterazioni",
|
||||
"iterationsWithCount_other": "{{count}} Iterazioni",
|
||||
"isAllowedToUpscale": {
|
||||
"useX2Model": "L'immagine è troppo grande per l'ampliamento con il modello x4, utilizza il modello x2",
|
||||
"tooLarge": "L'immagine è troppo grande per l'ampliamento, seleziona un'immagine più piccola"
|
||||
@ -534,7 +547,10 @@
|
||||
"infillMosaicMinColor": "Colore minimo",
|
||||
"infillMosaicMaxColor": "Colore massimo",
|
||||
"infillMosaicTileHeight": "Altezza piastrella",
|
||||
"infillColorValue": "Colore di riempimento"
|
||||
"infillColorValue": "Colore di riempimento",
|
||||
"globalSettings": "Impostazioni globali",
|
||||
"globalPositivePromptPlaceholder": "Prompt positivo globale",
|
||||
"globalNegativePromptPlaceholder": "Prompt negativo globale"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Modelli",
|
||||
@ -559,7 +575,7 @@
|
||||
"intermediatesCleared_one": "Cancellata {{count}} immagine intermedia",
|
||||
"intermediatesCleared_many": "Cancellate {{count}} immagini intermedie",
|
||||
"intermediatesCleared_other": "Cancellate {{count}} immagini intermedie",
|
||||
"clearIntermediatesDesc1": "La cancellazione delle immagini intermedie ripristinerà lo stato di Tela Unificata e ControlNet.",
|
||||
"clearIntermediatesDesc1": "La cancellazione delle immagini intermedie ripristinerà lo stato della Tela e degli Adattatori di Controllo.",
|
||||
"intermediatesClearedFailed": "Problema con la cancellazione delle immagini intermedie",
|
||||
"clearIntermediatesWithCount_one": "Cancella {{count}} immagine intermedia",
|
||||
"clearIntermediatesWithCount_many": "Cancella {{count}} immagini intermedie",
|
||||
@ -575,8 +591,8 @@
|
||||
"imageCopied": "Immagine copiata",
|
||||
"imageNotLoadedDesc": "Impossibile trovare l'immagine",
|
||||
"canvasMerged": "Tela unita",
|
||||
"sentToImageToImage": "Inviato a Immagine a Immagine",
|
||||
"sentToUnifiedCanvas": "Inviato a Tela Unificata",
|
||||
"sentToImageToImage": "Inviato a Generazione da immagine",
|
||||
"sentToUnifiedCanvas": "Inviato alla Tela",
|
||||
"parametersNotSet": "Parametri non impostati",
|
||||
"metadataLoadFailed": "Impossibile caricare i metadati",
|
||||
"serverError": "Errore del Server",
|
||||
@ -795,7 +811,7 @@
|
||||
"float": "In virgola mobile",
|
||||
"currentImageDescription": "Visualizza l'immagine corrente nell'editor dei nodi",
|
||||
"fieldTypesMustMatch": "I tipi di campo devono corrispondere",
|
||||
"edge": "Bordo",
|
||||
"edge": "Collegamento",
|
||||
"currentImage": "Immagine corrente",
|
||||
"integer": "Numero Intero",
|
||||
"inputMayOnlyHaveOneConnection": "L'ingresso può avere solo una connessione",
|
||||
@ -845,7 +861,9 @@
|
||||
"resetToDefaultValue": "Ripristina il valore predefinito",
|
||||
"noFieldsViewMode": "Questo flusso di lavoro non ha campi selezionati da visualizzare. Visualizza il flusso di lavoro completo per configurare i valori.",
|
||||
"edit": "Modifica",
|
||||
"graph": "Grafico"
|
||||
"graph": "Grafico",
|
||||
"showEdgeLabelsHelp": "Mostra etichette sui collegamenti, che indicano i nodi collegati",
|
||||
"showEdgeLabels": "Mostra le etichette del collegamento"
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Aggiungi automaticamente bacheca",
|
||||
@ -922,7 +940,7 @@
|
||||
"colorMapTileSize": "Dimensione piastrella",
|
||||
"mediapipeFaceDescription": "Rilevamento dei volti tramite Mediapipe",
|
||||
"hedDescription": "Rilevamento dei bordi nidificati olisticamente",
|
||||
"setControlImageDimensions": "Imposta le dimensioni dell'immagine di controllo su L/A",
|
||||
"setControlImageDimensions": "Copia le dimensioni in L/A (ottimizza per il modello)",
|
||||
"maxFaces": "Numero massimo di volti",
|
||||
"addT2IAdapter": "Aggiungi $t(common.t2iAdapter)",
|
||||
"addControlNet": "Aggiungi $t(common.controlNet)",
|
||||
@ -951,12 +969,17 @@
|
||||
"mediapipeFace": "Mediapipe Volto",
|
||||
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
|
||||
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))",
|
||||
"selectCLIPVisionModel": "Seleziona un modello CLIP Vision"
|
||||
"selectCLIPVisionModel": "Seleziona un modello CLIP Vision",
|
||||
"ipAdapterMethod": "Metodo",
|
||||
"full": "Completo",
|
||||
"composition": "Solo la composizione",
|
||||
"style": "Solo lo stile",
|
||||
"beginEndStepPercentShort": "Inizio/Fine %",
|
||||
"setControlImageDimensionsForce": "Copia le dimensioni in L/A (ignora il modello)"
|
||||
},
|
||||
"queue": {
|
||||
"queueFront": "Aggiungi all'inizio della coda",
|
||||
"queueBack": "Aggiungi alla coda",
|
||||
"queueCountPrediction": "{{promptsCount}} prompt × {{iterations}} iterazioni -> {{count}} generazioni",
|
||||
"queue": "Coda",
|
||||
"status": "Stato",
|
||||
"pruneSucceeded": "Rimossi {{item_count}} elementi completati dalla coda",
|
||||
@ -993,7 +1016,7 @@
|
||||
"cancelBatchSucceeded": "Lotto annullato",
|
||||
"clearTooltip": "Annulla e cancella tutti gli elementi",
|
||||
"current": "Attuale",
|
||||
"pauseTooltip": "Sospende l'elaborazione",
|
||||
"pauseTooltip": "Sospendi l'elaborazione",
|
||||
"failed": "Falliti",
|
||||
"cancelItem": "Annulla l'elemento",
|
||||
"next": "Prossimo",
|
||||
@ -1394,6 +1417,12 @@
|
||||
"paragraphs": [
|
||||
"La dimensione del bordo del passaggio di coerenza."
|
||||
]
|
||||
},
|
||||
"ipAdapterMethod": {
|
||||
"heading": "Metodo",
|
||||
"paragraphs": [
|
||||
"Metodo con cui applicare l'adattatore IP corrente."
|
||||
]
|
||||
}
|
||||
},
|
||||
"sdxl": {
|
||||
@ -1522,5 +1551,56 @@
|
||||
"compatibleEmbeddings": "Incorporamenti compatibili",
|
||||
"addPromptTrigger": "Aggiungi Trigger nel prompt",
|
||||
"noMatchingTriggers": "Nessun Trigger corrispondente"
|
||||
},
|
||||
"controlLayers": {
|
||||
"opacityFilter": "Filtro opacità",
|
||||
"deleteAll": "Cancella tutto",
|
||||
"addLayer": "Aggiungi Livello",
|
||||
"moveToFront": "Sposta in primo piano",
|
||||
"moveToBack": "Sposta in fondo",
|
||||
"moveForward": "Sposta avanti",
|
||||
"moveBackward": "Sposta indietro",
|
||||
"brushSize": "Dimensioni del pennello",
|
||||
"globalMaskOpacity": "Opacità globale della maschera",
|
||||
"autoNegative": "Auto Negativo",
|
||||
"toggleVisibility": "Attiva/disattiva la visibilità dei livelli",
|
||||
"deletePrompt": "Cancella il prompt",
|
||||
"debugLayers": "Debug dei Livelli",
|
||||
"rectangle": "Rettangolo",
|
||||
"maskPreviewColor": "Colore anteprima maschera",
|
||||
"addPositivePrompt": "Aggiungi $t(common.positivePrompt)",
|
||||
"addNegativePrompt": "Aggiungi $t(common.negativePrompt)",
|
||||
"addIPAdapter": "Aggiungi $t(common.ipAdapter)",
|
||||
"regionalGuidance": "Guida regionale",
|
||||
"regionalGuidanceLayer": "$t(unifiedCanvas.layer) $t(controlLayers.regionalGuidance)",
|
||||
"opacity": "Opacità",
|
||||
"globalControlAdapter": "$t(controlnet.controlAdapter_one) Globale",
|
||||
"globalControlAdapterLayer": "$t(controlnet.controlAdapter_one) - $t(unifiedCanvas.layer) Globale",
|
||||
"globalIPAdapter": "$t(common.ipAdapter) Globale",
|
||||
"globalIPAdapterLayer": "$t(common.ipAdapter) - $t(unifiedCanvas.layer) Globale",
|
||||
"globalInitialImage": "Immagine iniziale",
|
||||
"globalInitialImageLayer": "$t(controlLayers.globalInitialImage) - $t(unifiedCanvas.layer) Globale",
|
||||
"clearProcessor": "Cancella processore",
|
||||
"resetProcessor": "Ripristina il processore alle impostazioni predefinite",
|
||||
"noLayersAdded": "Nessun livello aggiunto",
|
||||
"resetRegion": "Reimposta la regione",
|
||||
"controlLayers": "Livelli di controllo",
|
||||
"layers_one": "Livello",
|
||||
"layers_many": "Livelli",
|
||||
"layers_other": "Livelli"
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
"generation": "Generazione",
|
||||
"generationTab": "$t(ui.tabs.generation) $t(common.tab)",
|
||||
"canvas": "Tela",
|
||||
"canvasTab": "$t(ui.tabs.canvas) $t(common.tab)",
|
||||
"workflows": "Flussi di lavoro",
|
||||
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
|
||||
"models": "Modelli",
|
||||
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
|
||||
"queue": "Coda",
|
||||
"queueTab": "$t(ui.tabs.queue) $t(common.tab)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -570,7 +570,6 @@
|
||||
"pauseSucceeded": "処理が一時停止されました",
|
||||
"queueFront": "キューの先頭へ追加",
|
||||
"queueBack": "キューに追加",
|
||||
"queueCountPrediction": "{{promptsCount}} プロンプト × {{iterations}} イテレーション -> {{count}} 枚生成",
|
||||
"pause": "一時停止",
|
||||
"queue": "キュー",
|
||||
"pauseTooltip": "処理を一時停止",
|
||||
|
@ -505,7 +505,6 @@
|
||||
"completed": "완성된",
|
||||
"queueBack": "Queue에 추가",
|
||||
"cancelFailed": "항목 취소 중 발생한 문제",
|
||||
"queueCountPrediction": "Queue에 {{predicted}} 추가",
|
||||
"batchQueued": "Batch Queued",
|
||||
"pauseFailed": "프로세서 중지 중 발생한 문제",
|
||||
"clearFailed": "Queue 제거 중 발생한 문제",
|
||||
|
@ -383,8 +383,6 @@
|
||||
"useCpuNoise": "Gebruik CPU-ruis",
|
||||
"imageActions": "Afbeeldingshandeling",
|
||||
"iterations": "Iteraties",
|
||||
"iterationsWithCount_one": "{{count}} iteratie",
|
||||
"iterationsWithCount_other": "{{count}} iteraties",
|
||||
"coherenceMode": "Modus"
|
||||
},
|
||||
"settings": {
|
||||
@ -940,7 +938,6 @@
|
||||
"completed": "Voltooid",
|
||||
"queueBack": "Voeg toe aan wachtrij",
|
||||
"cancelFailed": "Fout bij annuleren onderdeel",
|
||||
"queueCountPrediction": "Voeg {{predicted}} toe aan wachtrij",
|
||||
"batchQueued": "Reeks in wachtrij geplaatst",
|
||||
"pauseFailed": "Fout bij onderbreken verwerker",
|
||||
"clearFailed": "Fout bij wissen van wachtrij",
|
||||
|
@ -76,7 +76,18 @@
|
||||
"localSystem": "Локальная система",
|
||||
"aboutDesc": "Используя Invoke для работы? Проверьте это:",
|
||||
"add": "Добавить",
|
||||
"loglevel": "Уровень логов"
|
||||
"loglevel": "Уровень логов",
|
||||
"beta": "Бета",
|
||||
"selected": "Выбрано",
|
||||
"positivePrompt": "Позитивный запрос",
|
||||
"negativePrompt": "Негативный запрос",
|
||||
"editor": "Редактор",
|
||||
"goTo": "Перейти к",
|
||||
"tab": "Вкладка",
|
||||
"viewing": "Просмотр",
|
||||
"editing": "Редактирование",
|
||||
"viewingDesc": "Просмотр изображений в режиме большой галереи",
|
||||
"editingDesc": "Редактировать на холсте слоёв управления"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Размер изображений",
|
||||
@ -87,8 +98,8 @@
|
||||
"deleteImagePermanent": "Удаленные изображения невозможно восстановить.",
|
||||
"deleteImageBin": "Удаленные изображения будут отправлены в корзину вашей операционной системы.",
|
||||
"deleteImage_one": "Удалить изображение",
|
||||
"deleteImage_few": "",
|
||||
"deleteImage_many": "",
|
||||
"deleteImage_few": "Удалить {{count}} изображения",
|
||||
"deleteImage_many": "Удалить {{count}} изображений",
|
||||
"assets": "Ресурсы",
|
||||
"autoAssignBoardOnClick": "Авто-назначение доски по клику",
|
||||
"deleteSelection": "Удалить выделенное",
|
||||
@ -336,6 +347,10 @@
|
||||
"remixImage": {
|
||||
"desc": "Используйте все параметры, кроме сида из текущего изображения",
|
||||
"title": "Ремикс изображения"
|
||||
},
|
||||
"toggleViewer": {
|
||||
"title": "Переключить просмотр изображений",
|
||||
"desc": "Переключение между средством просмотра изображений и рабочей областью для текущей вкладки."
|
||||
}
|
||||
},
|
||||
"modelManager": {
|
||||
@ -512,7 +527,8 @@
|
||||
"missingNodeTemplate": "Отсутствует шаблон узла",
|
||||
"missingFieldTemplate": "Отсутствует шаблон поля",
|
||||
"addingImagesTo": "Добавление изображений в",
|
||||
"invoke": "Создать"
|
||||
"invoke": "Создать",
|
||||
"imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается"
|
||||
},
|
||||
"isAllowedToUpscale": {
|
||||
"useX2Model": "Изображение слишком велико для увеличения с помощью модели x4. Используйте модель x2",
|
||||
@ -523,9 +539,6 @@
|
||||
"useCpuNoise": "Использовать шум CPU",
|
||||
"imageActions": "Действия с изображениями",
|
||||
"iterations": "Кол-во",
|
||||
"iterationsWithCount_one": "{{count}} Интеграция",
|
||||
"iterationsWithCount_few": "{{count}} Итерации",
|
||||
"iterationsWithCount_many": "{{count}} Итераций",
|
||||
"useSize": "Использовать размер",
|
||||
"coherenceMode": "Режим",
|
||||
"aspect": "Соотношение",
|
||||
@ -541,7 +554,10 @@
|
||||
"infillMosaicTileHeight": "Высота плиток",
|
||||
"infillMosaicMinColor": "Мин цвет",
|
||||
"infillMosaicMaxColor": "Макс цвет",
|
||||
"infillColorValue": "Цвет заливки"
|
||||
"infillColorValue": "Цвет заливки",
|
||||
"globalSettings": "Глобальные настройки",
|
||||
"globalNegativePromptPlaceholder": "Глобальный негативный запрос",
|
||||
"globalPositivePromptPlaceholder": "Глобальный запрос"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Модели",
|
||||
@ -706,7 +722,9 @@
|
||||
"coherenceModeBoxBlur": "коробчатое размытие",
|
||||
"discardCurrent": "Отбросить текущее",
|
||||
"invertBrushSizeScrollDirection": "Инвертировать прокрутку для размера кисти",
|
||||
"initialFitImageSize": "Подогнать размер изображения при перебросе"
|
||||
"initialFitImageSize": "Подогнать размер изображения при перебросе",
|
||||
"hideBoundingBox": "Скрыть ограничительную рамку",
|
||||
"showBoundingBox": "Показать ограничительную рамку"
|
||||
},
|
||||
"accessibility": {
|
||||
"uploadImage": "Загрузить изображение",
|
||||
@ -849,7 +867,10 @@
|
||||
"editMode": "Открыть в редакторе узлов",
|
||||
"resetToDefaultValue": "Сбросить к стандартному значкнию",
|
||||
"edit": "Редактировать",
|
||||
"noFieldsViewMode": "В этом рабочем процессе нет выбранных полей для отображения. Просмотрите полный рабочий процесс для настройки значений."
|
||||
"noFieldsViewMode": "В этом рабочем процессе нет выбранных полей для отображения. Просмотрите полный рабочий процесс для настройки значений.",
|
||||
"graph": "График",
|
||||
"showEdgeLabels": "Показать метки на ребрах",
|
||||
"showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы"
|
||||
},
|
||||
"controlnet": {
|
||||
"amult": "a_mult",
|
||||
@ -917,8 +938,8 @@
|
||||
"lineartAnime": "Контурный рисунок в стиле аниме",
|
||||
"mediapipeFaceDescription": "Обнаружение лиц с помощью Mediapipe",
|
||||
"hedDescription": "Целостное обнаружение границ",
|
||||
"setControlImageDimensions": "Установите размеры контрольного изображения на Ш/В",
|
||||
"scribble": "каракули",
|
||||
"setControlImageDimensions": "Скопируйте размер в Ш/В (оптимизируйте для модели)",
|
||||
"scribble": "Штрихи",
|
||||
"maxFaces": "Макс Лица",
|
||||
"mlsdDescription": "Минималистичный детектор отрезков линии",
|
||||
"resizeSimple": "Изменить размер (простой)",
|
||||
@ -933,7 +954,18 @@
|
||||
"small": "Маленький",
|
||||
"body": "Тело",
|
||||
"hands": "Руки",
|
||||
"selectCLIPVisionModel": "Выбрать модель CLIP Vision"
|
||||
"selectCLIPVisionModel": "Выбрать модель CLIP Vision",
|
||||
"ipAdapterMethod": "Метод",
|
||||
"full": "Всё",
|
||||
"mlsd": "M-LSD",
|
||||
"h": "H",
|
||||
"style": "Только стиль",
|
||||
"dwOpenpose": "DW Openpose",
|
||||
"pidi": "PIDI",
|
||||
"composition": "Только композиция",
|
||||
"hed": "HED",
|
||||
"beginEndStepPercentShort": "Начало/конец %",
|
||||
"setControlImageDimensionsForce": "Скопируйте размер в Ш/В (игнорируйте модель)"
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Авто добавление Доски",
|
||||
@ -1312,6 +1344,12 @@
|
||||
"paragraphs": [
|
||||
"Плавно укладывайте изображение вдоль вертикальной оси."
|
||||
]
|
||||
},
|
||||
"ipAdapterMethod": {
|
||||
"heading": "Метод",
|
||||
"paragraphs": [
|
||||
"Метод, с помощью которого применяется текущий IP-адаптер."
|
||||
]
|
||||
}
|
||||
},
|
||||
"metadata": {
|
||||
@ -1359,7 +1397,6 @@
|
||||
"completed": "Выполнено",
|
||||
"queueBack": "Добавить в очередь",
|
||||
"cancelFailed": "Проблема с отменой элемента",
|
||||
"queueCountPrediction": "{{promptsCount}} запросов × {{iterations}} изображений -> {{count}} генераций",
|
||||
"batchQueued": "Пакетная очередь",
|
||||
"pauseFailed": "Проблема с приостановкой рендеринга",
|
||||
"clearFailed": "Проблема с очисткой очереди",
|
||||
@ -1475,7 +1512,11 @@
|
||||
"projectWorkflows": "Рабочие процессы проекта",
|
||||
"defaultWorkflows": "Стандартные рабочие процессы",
|
||||
"name": "Имя",
|
||||
"noRecentWorkflows": "Нет последних рабочих процессов"
|
||||
"noRecentWorkflows": "Нет последних рабочих процессов",
|
||||
"loadWorkflow": "Рабочий процесс $t(common.load)",
|
||||
"convertGraph": "Конвертировать график",
|
||||
"loadFromGraph": "Загрузка рабочего процесса из графика",
|
||||
"autoLayout": "Автоматическое расположение"
|
||||
},
|
||||
"hrf": {
|
||||
"enableHrf": "Включить исправление высокого разрешения",
|
||||
@ -1528,5 +1569,56 @@
|
||||
"addPromptTrigger": "Добавить триггер запроса",
|
||||
"compatibleEmbeddings": "Совместимые встраивания",
|
||||
"noMatchingTriggers": "Нет соответствующих триггеров"
|
||||
},
|
||||
"controlLayers": {
|
||||
"moveToBack": "На задний план",
|
||||
"moveForward": "Переместить вперёд",
|
||||
"moveBackward": "Переместить назад",
|
||||
"brushSize": "Размер кисти",
|
||||
"controlLayers": "Слои управления",
|
||||
"globalMaskOpacity": "Глобальная непрозрачность маски",
|
||||
"autoNegative": "Авто негатив",
|
||||
"deletePrompt": "Удалить запрос",
|
||||
"resetRegion": "Сбросить регион",
|
||||
"debugLayers": "Слои отладки",
|
||||
"rectangle": "Прямоугольник",
|
||||
"maskPreviewColor": "Цвет предпросмотра маски",
|
||||
"addNegativePrompt": "Добавить $t(common.negativePrompt)",
|
||||
"regionalGuidance": "Региональная точность",
|
||||
"opacity": "Непрозрачность",
|
||||
"globalControlAdapter": "Глобальный $t(controlnet.controlAdapter_one)",
|
||||
"globalControlAdapterLayer": "Глобальный $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
|
||||
"globalIPAdapter": "Глобальный $t(common.ipAdapter)",
|
||||
"globalIPAdapterLayer": "Глобальный $t(common.ipAdapter) $t(unifiedCanvas.layer)",
|
||||
"opacityFilter": "Фильтр непрозрачности",
|
||||
"deleteAll": "Удалить всё",
|
||||
"addLayer": "Добавить слой",
|
||||
"moveToFront": "На передний план",
|
||||
"toggleVisibility": "Переключить видимость слоя",
|
||||
"addPositivePrompt": "Добавить $t(common.positivePrompt)",
|
||||
"addIPAdapter": "Добавить $t(common.ipAdapter)",
|
||||
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
|
||||
"resetProcessor": "Сброс процессора по умолчанию",
|
||||
"clearProcessor": "Чистый процессор",
|
||||
"globalInitialImage": "Глобальное исходное изображение",
|
||||
"globalInitialImageLayer": "$t(controlLayers.globalInitialImage) $t(unifiedCanvas.layer)",
|
||||
"noLayersAdded": "Без слоев",
|
||||
"layers_one": "Слой",
|
||||
"layers_few": "Слоя",
|
||||
"layers_many": "Слоев"
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
"generation": "Генерация",
|
||||
"canvas": "Холст",
|
||||
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
|
||||
"models": "Модели",
|
||||
"generationTab": "$t(ui.tabs.generation) $t(common.tab)",
|
||||
"workflows": "Рабочие процессы",
|
||||
"canvasTab": "$t(ui.tabs.canvas) $t(common.tab)",
|
||||
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
|
||||
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
|
||||
"queue": "Очередь"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -66,7 +66,7 @@
|
||||
"saveAs": "保存为",
|
||||
"ai": "ai",
|
||||
"or": "或",
|
||||
"aboutDesc": "使用 Invoke 工作?查看:",
|
||||
"aboutDesc": "使用 Invoke 工作?来看看:",
|
||||
"add": "添加",
|
||||
"loglevel": "日志级别",
|
||||
"copy": "复制",
|
||||
@ -445,7 +445,6 @@
|
||||
"useX2Model": "图像太大,无法使用 x4 模型,使用 x2 模型作为替代",
|
||||
"tooLarge": "图像太大无法进行放大,请选择更小的图像"
|
||||
},
|
||||
"iterationsWithCount_other": "{{count}} 次迭代生成",
|
||||
"cfgRescaleMultiplier": "CFG 重缩放倍数",
|
||||
"useSize": "使用尺寸",
|
||||
"setToOptimalSize": "优化模型大小",
|
||||
@ -853,7 +852,6 @@
|
||||
"pruneSucceeded": "从队列修剪 {{item_count}} 个已完成的项目",
|
||||
"notReady": "无法排队",
|
||||
"batchFailedToQueue": "批次加入队列失败",
|
||||
"queueCountPrediction": "{{promptsCount}} 提示词 × {{iterations}} 迭代次数 -> {{count}} 次生成",
|
||||
"batchQueued": "加入队列的批次",
|
||||
"front": "前",
|
||||
"pruneTooltip": "修剪 {{item_count}} 个已完成的项目",
|
||||
|
@ -1,3 +1,4 @@
|
||||
/* eslint-disable no-console */
|
||||
import fs from 'node:fs';
|
||||
|
||||
import openapiTS from 'openapi-typescript';
|
||||
|
@ -21,6 +21,7 @@ import i18n from 'i18n';
|
||||
import { size } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect } from 'react';
|
||||
import { ErrorBoundary } from 'react-error-boundary';
|
||||
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
|
||||
|
||||
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
|
||||
import PreselectedImage from './PreselectedImage';
|
||||
@ -46,6 +47,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
|
||||
useSocketIO();
|
||||
useGlobalModifiersInit();
|
||||
useGlobalHotkeys();
|
||||
useGetOpenAPISchemaQuery();
|
||||
|
||||
const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone();
|
||||
|
||||
|
@ -67,6 +67,8 @@ export const useSocketIO = () => {
|
||||
|
||||
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
|
||||
window.$socketOptions = $socketOptions;
|
||||
// This is only enabled manually for debugging, console is allowed.
|
||||
/* eslint-disable-next-line no-console */
|
||||
console.log('Socket initialized', socket);
|
||||
}
|
||||
|
||||
@ -75,6 +77,8 @@ export const useSocketIO = () => {
|
||||
return () => {
|
||||
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
|
||||
window.$socketOptions = undefined;
|
||||
// This is only enabled manually for debugging, console is allowed.
|
||||
/* eslint-disable-next-line no-console */
|
||||
console.log('Socket teardown', socket);
|
||||
}
|
||||
socket.disconnect();
|
||||
|
@ -1,3 +1,6 @@
|
||||
/* eslint-disable no-console */
|
||||
// This is only enabled manually for debugging, console is allowed.
|
||||
|
||||
import type { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||
import { diff } from 'jsondiffpatch';
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
import type { UnknownAction } from '@reduxjs/toolkit';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||
import type { Graph } from 'services/api/types';
|
||||
import { socketGeneratorProgress } from 'services/events/actions';
|
||||
@ -25,13 +24,6 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
|
||||
};
|
||||
}
|
||||
|
||||
if (nodeTemplatesBuilt.match(action)) {
|
||||
return {
|
||||
...action,
|
||||
payload: '<Node templates omitted>',
|
||||
};
|
||||
}
|
||||
|
||||
if (socketGeneratorProgress.match(action)) {
|
||||
const sanitized = deepClone(action);
|
||||
if (sanitized.payload.data.progress_image) {
|
||||
|
@ -35,28 +35,23 @@ import { addImageUploadedFulfilledListener } from 'app/store/middleware/listener
|
||||
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
|
||||
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
|
||||
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
|
||||
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
|
||||
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
|
||||
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
|
||||
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
|
||||
import { addGraphExecutionStateCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete';
|
||||
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
|
||||
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
|
||||
import { addInvocationRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError';
|
||||
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
|
||||
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
|
||||
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
|
||||
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
|
||||
import { addSessionRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError';
|
||||
import { addSocketSubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed';
|
||||
import { addSocketUnsubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed';
|
||||
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
|
||||
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
|
||||
import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
@ -110,12 +105,8 @@ addInvocationErrorEventListener(startAppListening);
|
||||
addInvocationStartedEventListener(startAppListening);
|
||||
addSocketConnectedEventListener(startAppListening);
|
||||
addSocketDisconnectedEventListener(startAppListening);
|
||||
addSocketSubscribedEventListener(startAppListening);
|
||||
addSocketUnsubscribedEventListener(startAppListening);
|
||||
addModelLoadEventListener(startAppListening);
|
||||
addModelInstallEventListener(startAppListening);
|
||||
addSessionRetrievalErrorEventListener(startAppListening);
|
||||
addInvocationRetrievalErrorEventListener(startAppListening);
|
||||
addSocketQueueItemStatusChangedEventListener(startAppListening);
|
||||
addBulkDownloadListeners(startAppListening);
|
||||
|
||||
|
@ -21,7 +21,7 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
|
||||
|
||||
const { canvas, nodes, controlAdapters, controlLayers } = getState();
|
||||
deleted_images.forEach((image_name) => {
|
||||
const imageUsage = getImageUsage(canvas, nodes, controlAdapters, controlLayers.present, image_name);
|
||||
const imageUsage = getImageUsage(canvas, nodes.present, controlAdapters, controlLayers.present, image_name);
|
||||
|
||||
if (imageUsage.isCanvasImage && !wasCanvasReset) {
|
||||
dispatch(resetCanvas());
|
||||
|
@ -6,8 +6,8 @@ import { toast } from 'common/util/toast';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import {
|
||||
socketBulkDownloadCompleted,
|
||||
socketBulkDownloadFailed,
|
||||
socketBulkDownloadComplete,
|
||||
socketBulkDownloadError,
|
||||
socketBulkDownloadStarted,
|
||||
} from 'services/events/actions';
|
||||
|
||||
@ -56,7 +56,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketBulkDownloadCompleted,
|
||||
actionCreator: socketBulkDownloadComplete,
|
||||
effect: async (action) => {
|
||||
log.debug(action.payload.data, 'Bulk download preparation completed');
|
||||
|
||||
@ -89,7 +89,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketBulkDownloadFailed,
|
||||
actionCreator: socketBulkDownloadError,
|
||||
effect: async (action) => {
|
||||
log.debug(action.payload.data, 'Bulk download preparation failed');
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { canvasSavedToGallery } from 'features/canvas/store/actions';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
@ -43,6 +44,9 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe
|
||||
type: 'TOAST',
|
||||
toastOptions: { title: t('toast.canvasSavedGallery') },
|
||||
},
|
||||
metadata: {
|
||||
_canvas_objects: parseify(state.canvas.layerState.objects),
|
||||
},
|
||||
})
|
||||
);
|
||||
},
|
||||
|
@ -1,13 +1,15 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppDispatch } from 'app/store/store';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import {
|
||||
caLayerImageChanged,
|
||||
caLayerIsProcessingImageChanged,
|
||||
caLayerModelChanged,
|
||||
caLayerProcessedImageChanged,
|
||||
caLayerProcessorConfigChanged,
|
||||
caLayerProcessorPendingBatchIdChanged,
|
||||
caLayerRecalled,
|
||||
isControlAdapterLayer,
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
|
||||
@ -15,46 +17,41 @@ import { isImageOutput } from 'features/nodes/types/common';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { getImageDTO } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { BatchConfig, ImageDTO } from 'services/api/types';
|
||||
import type { BatchConfig } from 'services/api/types';
|
||||
import { socketInvocationComplete } from 'services/events/actions';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const matcher = isAnyOf(caLayerImageChanged, caLayerProcessorConfigChanged, caLayerModelChanged);
|
||||
const matcher = isAnyOf(caLayerImageChanged, caLayerProcessorConfigChanged, caLayerModelChanged, caLayerRecalled);
|
||||
|
||||
const DEBOUNCE_MS = 300;
|
||||
const log = logger('session');
|
||||
|
||||
/**
|
||||
* Simple helper to cancel a batch and reset the pending batch ID
|
||||
*/
|
||||
const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batchId: string) => {
|
||||
const req = dispatch(queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batchId] }));
|
||||
log.trace({ batchId }, 'Cancelling existing preprocessor batch');
|
||||
try {
|
||||
await req.unwrap();
|
||||
} catch {
|
||||
// no-op
|
||||
} finally {
|
||||
req.reset();
|
||||
// Always reset the pending batch ID - the cancel req could fail if the batch doesn't exist
|
||||
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
|
||||
}
|
||||
};
|
||||
|
||||
export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher,
|
||||
effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take }) => {
|
||||
const { layerId } = action.payload;
|
||||
const precheckLayerOriginal = getOriginalState()
|
||||
.controlLayers.present.layers.filter(isControlAdapterLayer)
|
||||
.find((l) => l.id === layerId);
|
||||
const precheckLayer = getState()
|
||||
.controlLayers.present.layers.filter(isControlAdapterLayer)
|
||||
.find((l) => l.id === layerId);
|
||||
|
||||
// Conditions to bail
|
||||
const layerDoesNotExist = !precheckLayer;
|
||||
const layerHasNoImage = !precheckLayer?.controlAdapter.image;
|
||||
const layerHasNoProcessorConfig = !precheckLayer?.controlAdapter.processorConfig;
|
||||
const layerIsAlreadyProcessingImage = precheckLayer?.controlAdapter.isProcessingImage;
|
||||
const areImageAndProcessorUnchanged =
|
||||
isEqual(precheckLayer?.controlAdapter.image, precheckLayerOriginal?.controlAdapter.image) &&
|
||||
isEqual(precheckLayer?.controlAdapter.processorConfig, precheckLayerOriginal?.controlAdapter.processorConfig);
|
||||
|
||||
if (
|
||||
layerDoesNotExist ||
|
||||
layerHasNoImage ||
|
||||
layerHasNoProcessorConfig ||
|
||||
areImageAndProcessorUnchanged ||
|
||||
layerIsAlreadyProcessingImage
|
||||
) {
|
||||
return;
|
||||
}
|
||||
effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take, signal }) => {
|
||||
const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId;
|
||||
const state = getState();
|
||||
const originalState = getOriginalState();
|
||||
|
||||
// Cancel any in-progress instances of this listener
|
||||
cancelActiveListeners();
|
||||
@ -62,27 +59,55 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
|
||||
|
||||
// Delay before starting actual work
|
||||
await delay(DEBOUNCE_MS);
|
||||
dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: true }));
|
||||
|
||||
// Double-check that we are still eligible for processing
|
||||
const state = getState();
|
||||
const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId);
|
||||
const image = layer?.controlAdapter.image;
|
||||
const config = layer?.controlAdapter.processorConfig;
|
||||
|
||||
// If we have no image or there is no processor config, bail
|
||||
if (!layer || !image || !config) {
|
||||
if (!layer) {
|
||||
return;
|
||||
}
|
||||
|
||||
// @ts-expect-error: TS isn't able to narrow the typing of buildNode and `config` will error...
|
||||
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config);
|
||||
// We should only process if the processor settings or image have changed
|
||||
const originalLayer = originalState.controlLayers.present.layers
|
||||
.filter(isControlAdapterLayer)
|
||||
.find((l) => l.id === layerId);
|
||||
const originalImage = originalLayer?.controlAdapter.image;
|
||||
const originalConfig = originalLayer?.controlAdapter.processorConfig;
|
||||
|
||||
const image = layer.controlAdapter.image;
|
||||
const config = layer.controlAdapter.processorConfig;
|
||||
|
||||
if (isEqual(config, originalConfig) && isEqual(image, originalImage)) {
|
||||
// Neither config nor image have changed, we can bail
|
||||
return;
|
||||
}
|
||||
|
||||
if (!image || !config) {
|
||||
// - If we have no image, we have nothing to process
|
||||
// - If we have no processor config, we have nothing to process
|
||||
// Clear the processed image and bail
|
||||
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
|
||||
return;
|
||||
}
|
||||
|
||||
// At this point, the user has stopped fiddling with the processor settings and there is a processor selected.
|
||||
|
||||
// If there is a pending processor batch, cancel it.
|
||||
if (layer.controlAdapter.processorPendingBatchId) {
|
||||
cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId);
|
||||
}
|
||||
|
||||
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
|
||||
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never);
|
||||
const enqueueBatchArg: BatchConfig = {
|
||||
prepend: true,
|
||||
batch: {
|
||||
graph: {
|
||||
nodes: {
|
||||
[processorNode.id]: { ...processorNode, is_intermediate: true },
|
||||
[processorNode.id]: {
|
||||
...processorNode,
|
||||
// Control images are always intermediate - do not save to gallery
|
||||
is_intermediate: true,
|
||||
},
|
||||
},
|
||||
edges: [],
|
||||
},
|
||||
@ -90,66 +115,74 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
|
||||
},
|
||||
};
|
||||
|
||||
// Kick off the processor batch
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
|
||||
try {
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
// TODO(psyche): Update the pydantic models, pretty sure we will _always_ have a batch_id here, but the model says it's optional
|
||||
assert(enqueueResult.batch.batch_id, 'Batch ID not returned from queue');
|
||||
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: enqueueResult.batch.batch_id }));
|
||||
log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
|
||||
|
||||
// Wait for the processor node to complete
|
||||
const [invocationCompleteAction] = await take(
|
||||
(action): action is ReturnType<typeof socketInvocationComplete> =>
|
||||
socketInvocationComplete.match(action) &&
|
||||
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.source_node_id === processorNode.id
|
||||
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.invocation_source_id === processorNode.id
|
||||
);
|
||||
|
||||
// We still have to check the output type
|
||||
if (isImageOutput(invocationCompleteAction.payload.data.result)) {
|
||||
const { image_name } = invocationCompleteAction.payload.data.result.image;
|
||||
assert(
|
||||
isImageOutput(invocationCompleteAction.payload.data.result),
|
||||
`Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
|
||||
);
|
||||
const { image_name } = invocationCompleteAction.payload.data.result.image;
|
||||
|
||||
// Wait for the ImageDTO to be received
|
||||
const [{ payload }] = await take(
|
||||
(action) =>
|
||||
imagesApi.endpoints.getImageDTO.matchFulfilled(action) && action.payload.image_name === image_name
|
||||
);
|
||||
const imageDTO = await getImageDTO(image_name);
|
||||
assert(imageDTO, "Failed to fetch processor output's image DTO");
|
||||
|
||||
const imageDTO = payload as ImageDTO;
|
||||
|
||||
log.debug({ layerId, imageDTO }, 'ControlNet image processed');
|
||||
|
||||
// Update the processed image in the store
|
||||
dispatch(
|
||||
caLayerProcessedImageChanged({
|
||||
layerId,
|
||||
imageDTO,
|
||||
})
|
||||
);
|
||||
dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: false }));
|
||||
}
|
||||
// Whew! We made it. Update the layer with the processed image
|
||||
log.debug({ layerId, imageDTO }, 'ControlNet image processed');
|
||||
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO }));
|
||||
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
|
||||
dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: false }));
|
||||
if (signal.aborted) {
|
||||
// The listener was canceled - we need to cancel the pending processor batch, if there is one (could have changed by now).
|
||||
const pendingBatchId = getState()
|
||||
.controlLayers.present.layers.filter(isControlAdapterLayer)
|
||||
.find((l) => l.id === layerId)?.controlAdapter.processorPendingBatchId;
|
||||
if (pendingBatchId) {
|
||||
cancelProcessorBatch(dispatch, layerId, pendingBatchId);
|
||||
}
|
||||
log.trace('Control Adapter preprocessor cancelled');
|
||||
} else {
|
||||
// Some other error condition...
|
||||
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
|
||||
|
||||
if (error instanceof Object) {
|
||||
if ('data' in error && 'status' in error) {
|
||||
if (error.status === 403) {
|
||||
dispatch(caLayerImageChanged({ layerId, imageDTO: null }));
|
||||
return;
|
||||
if (error instanceof Object) {
|
||||
if ('data' in error && 'status' in error) {
|
||||
if (error.status === 403) {
|
||||
dispatch(caLayerImageChanged({ layerId, imageDTO: null }));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.graphFailedToQueue'),
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.graphFailedToQueue'),
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
req.reset();
|
||||
}
|
||||
},
|
||||
});
|
||||
|
@ -69,8 +69,8 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
|
||||
const [invocationCompleteAction] = await take(
|
||||
(action): action is ReturnType<typeof socketInvocationComplete> =>
|
||||
socketInvocationComplete.match(action) &&
|
||||
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.source_node_id === nodeId
|
||||
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.invocation_source_id === nodeId
|
||||
);
|
||||
|
||||
// We still have to check the output type
|
||||
|
@ -8,8 +8,8 @@ import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||
import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
||||
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
|
||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { buildCanvasGraph } from 'features/nodes/util/graph/buildCanvasGraph';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildCanvasGraph } from 'features/nodes/util/graph/canvas/buildCanvasGraph';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
@ -1,8 +1,9 @@
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { buildGenerationTabGraph } from 'features/nodes/util/graph/buildGenerationTabGraph';
|
||||
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/buildGenerationTabSDXLGraph';
|
||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph';
|
||||
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
|
||||
@ -11,12 +12,13 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'generation',
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const { shouldShowProgressInViewer } = state.ui;
|
||||
const model = state.generation.model;
|
||||
const { prepend } = action.payload;
|
||||
|
||||
let graph;
|
||||
|
||||
if (model && model.base === 'sdxl') {
|
||||
if (model?.base === 'sdxl') {
|
||||
graph = await buildGenerationTabSDXLGraph(state);
|
||||
} else {
|
||||
graph = await buildGenerationTabGraph(state);
|
||||
@ -29,7 +31,14 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
req.reset();
|
||||
try {
|
||||
await req.unwrap();
|
||||
if (shouldShowProgressInViewer) {
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
}
|
||||
} finally {
|
||||
req.reset();
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -11,9 +11,9 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'workflows',
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const { nodes, edges } = state.nodes;
|
||||
const { nodes, edges } = state.nodes.present;
|
||||
const workflow = state.workflow;
|
||||
const graph = buildNodesGraph(state.nodes);
|
||||
const graph = buildNodesGraph(state.nodes.present);
|
||||
const builtWorkflow = buildWorkflowWithValidation({
|
||||
nodes,
|
||||
edges,
|
||||
@ -39,7 +39,11 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
req.reset();
|
||||
try {
|
||||
await req.unwrap();
|
||||
} finally {
|
||||
req.reset();
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { isImageViewerOpenChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
@ -62,7 +62,6 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
|
||||
} else {
|
||||
dispatch(selectionChanged([imageDTO]));
|
||||
}
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
|
||||
import { size } from 'lodash-es';
|
||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||
@ -9,7 +9,7 @@ import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||
export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
matcher: appInfoApi.endpoints.getOpenAPISchema.matchFulfilled,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
effect: (action, { getState }) => {
|
||||
const log = logger('system');
|
||||
const schemaJSON = action.payload;
|
||||
|
||||
@ -20,7 +20,7 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening
|
||||
|
||||
log.debug({ nodeTemplates: parseify(nodeTemplates) }, `Built ${size(nodeTemplates)} node templates`);
|
||||
|
||||
dispatch(nodeTemplatesBuilt(nodeTemplates));
|
||||
$templates.set(nodeTemplates);
|
||||
},
|
||||
});
|
||||
|
||||
|
@ -29,7 +29,7 @@ import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
state.nodes.nodes.forEach((node) => {
|
||||
state.nodes.present.nodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
@ -73,25 +73,25 @@ const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, ima
|
||||
const deleteControlLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
state.controlLayers.present.layers.forEach((l) => {
|
||||
if (isRegionalGuidanceLayer(l)) {
|
||||
if (l.ipAdapters.some((ipa) => ipa.image?.imageName === imageDTO.image_name)) {
|
||||
if (l.ipAdapters.some((ipa) => ipa.image?.name === imageDTO.image_name)) {
|
||||
dispatch(layerDeleted(l.id));
|
||||
}
|
||||
}
|
||||
if (isControlAdapterLayer(l)) {
|
||||
if (
|
||||
l.controlAdapter.image?.imageName === imageDTO.image_name ||
|
||||
l.controlAdapter.processedImage?.imageName === imageDTO.image_name
|
||||
l.controlAdapter.image?.name === imageDTO.image_name ||
|
||||
l.controlAdapter.processedImage?.name === imageDTO.image_name
|
||||
) {
|
||||
dispatch(layerDeleted(l.id));
|
||||
}
|
||||
}
|
||||
if (isIPAdapterLayer(l)) {
|
||||
if (l.ipAdapter.image?.imageName === imageDTO.image_name) {
|
||||
if (l.ipAdapter.image?.name === imageDTO.image_name) {
|
||||
dispatch(layerDeleted(l.id));
|
||||
}
|
||||
}
|
||||
if (isInitialImageLayer(l)) {
|
||||
if (l.image?.imageName === imageDTO.image_name) {
|
||||
if (l.image?.name === imageDTO.image_name) {
|
||||
dispatch(layerDeleted(l.id));
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketGeneratorProgress } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
@ -9,6 +12,14 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
|
||||
actionCreator: socketGeneratorProgress,
|
||||
effect: (action) => {
|
||||
log.trace(action.payload, `Generator progress`);
|
||||
const { invocation_source_id, step, total_steps, progress_image } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
nes.progress = (step + 1) / total_steps;
|
||||
nes.progressImage = progress_image ?? null;
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
@ -9,7 +10,9 @@ import {
|
||||
isImageViewerOpenChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
@ -26,12 +29,12 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
actionCreator: socketInvocationComplete,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { data } = action.payload;
|
||||
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
|
||||
log.debug({ data: parseify(data) }, `Invocation complete (${data.invocation_type})`);
|
||||
|
||||
const { result, node, queue_batch_id } = data;
|
||||
const { result, invocation_source_id } = data;
|
||||
// This complete event has an associated image output
|
||||
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
|
||||
const { image_name } = result.image;
|
||||
if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation_type)) {
|
||||
const { image_name } = data.result.image;
|
||||
const { canvas, gallery } = getState();
|
||||
|
||||
// This populates the `getImageDTO` cache
|
||||
@ -45,7 +48,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
imageDTORequest.unsubscribe();
|
||||
|
||||
// Add canvas images to the staging area
|
||||
if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) {
|
||||
if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) {
|
||||
dispatch(addImageToStagingArea(imageDTO));
|
||||
}
|
||||
|
||||
@ -110,6 +113,16 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.COMPLETED;
|
||||
if (nes.progress !== null) {
|
||||
nes.progress = 1;
|
||||
}
|
||||
nes.outputs.push(result);
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,5 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketInvocationError } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
@ -8,7 +11,16 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
|
||||
startAppListening({
|
||||
actionCreator: socketInvocationError,
|
||||
effect: (action) => {
|
||||
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
|
||||
log.error(action.payload, `Invocation error (${action.payload.data.invocation_type})`);
|
||||
const { invocation_source_id } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.FAILED;
|
||||
nes.error = action.payload.data.error;
|
||||
nes.progress = null;
|
||||
nes.progressImage = null;
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,14 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketInvocationRetrievalError } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
export const addInvocationRetrievalErrorEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketInvocationRetrievalError,
|
||||
effect: (action) => {
|
||||
log.error(action.payload, `Invocation retrieval error (${action.payload.data.graph_execution_state_id})`);
|
||||
},
|
||||
});
|
||||
};
|
@ -1,5 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketInvocationStarted } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
@ -8,7 +11,13 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
|
||||
startAppListening({
|
||||
actionCreator: socketInvocationStarted,
|
||||
effect: (action) => {
|
||||
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`);
|
||||
log.debug(action.payload, `Invocation started (${action.payload.data.invocation_type})`);
|
||||
const { invocation_source_id } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -3,14 +3,14 @@ import { api, LIST_TAG } from 'services/api';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import {
|
||||
socketModelInstallCancelled,
|
||||
socketModelInstallCompleted,
|
||||
socketModelInstallDownloading,
|
||||
socketModelInstallComplete,
|
||||
socketModelInstallDownloadProgress,
|
||||
socketModelInstallError,
|
||||
} from 'services/events/actions';
|
||||
|
||||
export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallDownloading,
|
||||
actionCreator: socketModelInstallDownloadProgress,
|
||||
effect: async (action, { dispatch }) => {
|
||||
const { bytes, total_bytes, id } = action.payload.data;
|
||||
|
||||
@ -29,7 +29,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallCompleted,
|
||||
actionCreator: socketModelInstallComplete,
|
||||
effect: (action, { dispatch }) => {
|
||||
const { id } = action.payload.data;
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketModelLoadCompleted, socketModelLoadStarted } from 'services/events/actions';
|
||||
import { socketModelLoadComplete, socketModelLoadStarted } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
@ -8,10 +8,11 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
|
||||
startAppListening({
|
||||
actionCreator: socketModelLoadStarted,
|
||||
effect: (action) => {
|
||||
const { model_config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = model_config;
|
||||
const { config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = config;
|
||||
|
||||
const extras: string[] = [base, type];
|
||||
|
||||
if (submodel_type) {
|
||||
extras.push(submodel_type);
|
||||
}
|
||||
@ -23,10 +24,10 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelLoadCompleted,
|
||||
actionCreator: socketModelLoadComplete,
|
||||
effect: (action) => {
|
||||
const { model_config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = model_config;
|
||||
const { config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = config;
|
||||
|
||||
const extras: string[] = [base, type];
|
||||
if (submodel_type) {
|
||||
|
@ -1,5 +1,9 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
|
||||
import { socketQueueItemStatusChanged } from 'services/events/actions';
|
||||
|
||||
@ -10,16 +14,23 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
|
||||
actionCreator: socketQueueItemStatusChanged,
|
||||
effect: async (action, { dispatch }) => {
|
||||
// we've got new status for the queue item, batch and queue
|
||||
const { queue_item, batch_status, queue_status } = action.payload.data;
|
||||
const { item_id, status, started_at, updated_at, error, completed_at, batch_status, queue_status } =
|
||||
action.payload.data;
|
||||
|
||||
log.debug(action.payload, `Queue item ${queue_item.item_id} status updated: ${queue_item.status}`);
|
||||
log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`);
|
||||
|
||||
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
|
||||
dispatch(
|
||||
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
|
||||
queueItemsAdapter.updateOne(draft, {
|
||||
id: String(queue_item.item_id),
|
||||
changes: queue_item,
|
||||
id: String(item_id),
|
||||
changes: {
|
||||
status,
|
||||
started_at,
|
||||
updated_at: updated_at ?? undefined,
|
||||
error,
|
||||
completed_at: completed_at ?? undefined,
|
||||
},
|
||||
});
|
||||
})
|
||||
);
|
||||
@ -39,21 +50,31 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
|
||||
queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status)
|
||||
);
|
||||
|
||||
// Update the queue item status (this is the full queue item, including the session)
|
||||
dispatch(
|
||||
queueApi.util.updateQueryData('getQueueItem', queue_item.item_id, (draft) => {
|
||||
if (!draft) {
|
||||
return;
|
||||
}
|
||||
Object.assign(draft, queue_item);
|
||||
})
|
||||
);
|
||||
|
||||
// Invalidate caches for things we cannot update
|
||||
// TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again
|
||||
dispatch(
|
||||
queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus'])
|
||||
queueApi.util.invalidateTags([
|
||||
'CurrentSessionQueueItem',
|
||||
'NextSessionQueueItem',
|
||||
'InvocationCacheStatus',
|
||||
{ type: 'SessionQueueItem', id: item_id },
|
||||
])
|
||||
);
|
||||
|
||||
if (['in_progress'].includes(action.payload.data.status)) {
|
||||
forEach($nodeExecutionStates.get(), (nes) => {
|
||||
if (!nes) {
|
||||
return;
|
||||
}
|
||||
const clone = deepClone(nes);
|
||||
clone.status = zNodeStatus.enum.PENDING;
|
||||
clone.error = null;
|
||||
clone.progress = null;
|
||||
clone.progressImage = null;
|
||||
clone.outputs = [];
|
||||
$nodeExecutionStates.setKey(clone.nodeId, clone);
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,14 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketSessionRetrievalError } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
export const addSessionRetrievalErrorEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketSessionRetrievalError,
|
||||
effect: (action) => {
|
||||
log.error(action.payload, `Session retrieval error (${action.payload.data.graph_execution_state_id})`);
|
||||
},
|
||||
});
|
||||
};
|
@ -1,14 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketSubscribedSession } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
export const addSocketSubscribedEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketSubscribedSession,
|
||||
effect: (action) => {
|
||||
log.debug(action.payload, 'Subscribed');
|
||||
},
|
||||
});
|
||||
};
|
@ -1,13 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketUnsubscribedSession } from 'services/events/actions';
|
||||
const log = logger('socketio');
|
||||
|
||||
export const addSocketUnsubscribedEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketUnsubscribedSession,
|
||||
effect: (action) => {
|
||||
log.debug(action.payload, 'Unsubscribed');
|
||||
},
|
||||
});
|
||||
};
|
@ -1,7 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { updateAllNodesRequested } from 'features/nodes/store/actions';
|
||||
import { nodeReplaced } from 'features/nodes/store/nodesSlice';
|
||||
import { $templates, nodesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NodeUpdateError } from 'features/nodes/types/error';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
|
||||
@ -14,7 +14,8 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
|
||||
actionCreator: updateAllNodesRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const { nodes, templates } = getState().nodes;
|
||||
const { nodes } = getState().nodes.present;
|
||||
const templates = $templates.get();
|
||||
|
||||
let unableToUpdateCount = 0;
|
||||
|
||||
@ -24,13 +25,18 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
|
||||
unableToUpdateCount++;
|
||||
return;
|
||||
}
|
||||
if (!getNeedsUpdate(node, template)) {
|
||||
if (!getNeedsUpdate(node.data, template)) {
|
||||
// No need to increment the count here, since we're not actually updating
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const updatedNode = updateNode(node, template);
|
||||
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
|
||||
dispatch(
|
||||
nodesChanged([
|
||||
{ type: 'remove', id: updatedNode.id },
|
||||
{ type: 'add', item: updatedNode },
|
||||
])
|
||||
);
|
||||
} catch (e) {
|
||||
if (e instanceof NodeUpdateError) {
|
||||
unableToUpdateCount++;
|
||||
|
@ -2,32 +2,51 @@ import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
|
||||
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
|
||||
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { t } from 'i18next';
|
||||
import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types';
|
||||
import { z } from 'zod';
|
||||
import { fromZodError } from 'zod-validation-error';
|
||||
|
||||
const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => {
|
||||
if (data.workflow) {
|
||||
// Prefer to load the workflow if it's available - it has more information
|
||||
const parsed = JSON.parse(data.workflow);
|
||||
return validateWorkflow(parsed, templates);
|
||||
} else if (data.graph) {
|
||||
// Else we fall back on the graph, using the graphToWorkflow function to convert and do layout
|
||||
const parsed = JSON.parse(data.graph);
|
||||
const workflow = graphToWorkflow(parsed as NonNullableGraph, true);
|
||||
return validateWorkflow(workflow, templates);
|
||||
} else {
|
||||
throw new Error('No workflow or graph provided');
|
||||
}
|
||||
};
|
||||
|
||||
export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: workflowLoadRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
effect: (action, { dispatch }) => {
|
||||
const log = logger('nodes');
|
||||
const { workflow, asCopy } = action.payload;
|
||||
const nodeTemplates = getState().nodes.templates;
|
||||
const { data, asCopy } = action.payload;
|
||||
const nodeTemplates = $templates.get();
|
||||
|
||||
try {
|
||||
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
|
||||
const { workflow, warnings } = getWorkflow(data, nodeTemplates);
|
||||
|
||||
if (asCopy) {
|
||||
// If we're loading a copy, we need to remove the ID so that the backend will create a new workflow
|
||||
delete validatedWorkflow.id;
|
||||
delete workflow.id;
|
||||
}
|
||||
|
||||
dispatch(workflowLoaded(validatedWorkflow));
|
||||
dispatch(workflowLoaded(workflow));
|
||||
if (!warnings.length) {
|
||||
dispatch(
|
||||
addToast(
|
||||
|
@ -21,7 +21,8 @@ import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/galle
|
||||
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
||||
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
||||
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
|
||||
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
|
||||
@ -50,7 +51,7 @@ const allReducers = {
|
||||
[canvasSlice.name]: canvasSlice.reducer,
|
||||
[gallerySlice.name]: gallerySlice.reducer,
|
||||
[generationSlice.name]: generationSlice.reducer,
|
||||
[nodesSlice.name]: nodesSlice.reducer,
|
||||
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
|
||||
[postprocessingSlice.name]: postprocessingSlice.reducer,
|
||||
[systemSlice.name]: systemSlice.reducer,
|
||||
[configSlice.name]: configSlice.reducer,
|
||||
@ -66,6 +67,7 @@ const allReducers = {
|
||||
[workflowSlice.name]: workflowSlice.reducer,
|
||||
[hrfSlice.name]: hrfSlice.reducer,
|
||||
[controlLayersSlice.name]: undoable(controlLayersSlice.reducer, controlLayersUndoableConfig),
|
||||
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
|
||||
[api.reducerPath]: api.reducer,
|
||||
};
|
||||
|
||||
@ -111,6 +113,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
|
||||
[hrfPersistConfig.name]: hrfPersistConfig,
|
||||
[controlLayersPersistConfig.name]: controlLayersPersistConfig,
|
||||
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
|
||||
};
|
||||
|
||||
const unserialize: UnserializeFunction = (data, key) => {
|
||||
|
@ -70,6 +70,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
onMouseOver,
|
||||
onMouseOut,
|
||||
dataTestId,
|
||||
...rest
|
||||
} = props;
|
||||
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
@ -138,6 +139,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
minH={minSize ? minSize : undefined}
|
||||
userSelect="none"
|
||||
cursor={isDragDisabled || !imageDTO ? 'default' : 'pointer'}
|
||||
{...rest}
|
||||
>
|
||||
{imageDTO && (
|
||||
<Flex
|
||||
|
@ -13,6 +13,7 @@ type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
||||
onChange: (value: T | null) => void;
|
||||
getIsDisabled?: (model: T) => boolean;
|
||||
isLoading?: boolean;
|
||||
groupByType?: boolean;
|
||||
};
|
||||
|
||||
type UseGroupedModelComboboxReturn = {
|
||||
@ -23,17 +24,21 @@ type UseGroupedModelComboboxReturn = {
|
||||
noOptionsMessage: () => string;
|
||||
};
|
||||
|
||||
const groupByBaseFunc = <T extends AnyModelConfig>(model: T) => model.base.toUpperCase();
|
||||
const groupByBaseAndTypeFunc = <T extends AnyModelConfig>(model: T) =>
|
||||
`${model.base.toUpperCase()} / ${model.type.replaceAll('_', ' ').toUpperCase()}`;
|
||||
|
||||
export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
||||
arg: UseGroupedModelComboboxArg<T>
|
||||
): UseGroupedModelComboboxReturn => {
|
||||
const { t } = useTranslation();
|
||||
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
|
||||
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg;
|
||||
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, groupByType = false } = arg;
|
||||
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
|
||||
if (!modelConfigs) {
|
||||
return [];
|
||||
}
|
||||
const groupedModels = groupBy(modelConfigs, 'base');
|
||||
const groupedModels = groupBy(modelConfigs, groupByType ? groupByBaseAndTypeFunc : groupByBaseFunc);
|
||||
const _options = reduce(
|
||||
groupedModels,
|
||||
(acc, val, label) => {
|
||||
@ -49,9 +54,9 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
||||
},
|
||||
[] as GroupBase<ComboboxOption>[]
|
||||
);
|
||||
_options.sort((a) => (a.label === base_model ? -1 : 1));
|
||||
_options.sort((a) => (a.label?.split('/')[0]?.toLowerCase().includes(base_model) ? -1 : 1));
|
||||
return _options;
|
||||
}, [getIsDisabled, modelConfigs, base_model]);
|
||||
}, [modelConfigs, groupByType, getIsDisabled, base_model]);
|
||||
|
||||
const value = useMemo(
|
||||
() =>
|
||||
|
@ -1,3 +1,4 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
@ -6,187 +7,230 @@ import {
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
import { selectControlLayersSlice } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { Layer } from 'features/controlLayers/store/types';
|
||||
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import i18n from 'i18next';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { forEach, upperFirst } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { getConnectedEdges } from 'reactflow';
|
||||
|
||||
const selector = createMemoizedSelector(
|
||||
[
|
||||
selectControlAdaptersSlice,
|
||||
selectGenerationSlice,
|
||||
selectSystemSlice,
|
||||
selectNodesSlice,
|
||||
selectDynamicPromptsSlice,
|
||||
selectControlLayersSlice,
|
||||
activeTabNameSelector,
|
||||
],
|
||||
(controlAdapters, generation, system, nodes, dynamicPrompts, controlLayers, activeTabName) => {
|
||||
const { model } = generation;
|
||||
const { positivePrompt } = controlLayers.present;
|
||||
const LAYER_TYPE_TO_TKEY: Record<Layer['type'], string> = {
|
||||
initial_image_layer: 'controlLayers.globalInitialImage',
|
||||
control_adapter_layer: 'controlLayers.globalControlAdapter',
|
||||
ip_adapter_layer: 'controlLayers.globalIPAdapter',
|
||||
regional_guidance_layer: 'controlLayers.regionalGuidance',
|
||||
};
|
||||
|
||||
const { isConnected } = system;
|
||||
const createSelector = (templates: Templates) =>
|
||||
createMemoizedSelector(
|
||||
[
|
||||
selectControlAdaptersSlice,
|
||||
selectGenerationSlice,
|
||||
selectSystemSlice,
|
||||
selectNodesSlice,
|
||||
selectWorkflowSettingsSlice,
|
||||
selectDynamicPromptsSlice,
|
||||
selectControlLayersSlice,
|
||||
activeTabNameSelector,
|
||||
],
|
||||
(controlAdapters, generation, system, nodes, workflowSettings, dynamicPrompts, controlLayers, activeTabName) => {
|
||||
const { model } = generation;
|
||||
const { size } = controlLayers.present;
|
||||
const { positivePrompt } = controlLayers.present;
|
||||
|
||||
const reasons: string[] = [];
|
||||
const { isConnected } = system;
|
||||
|
||||
// Cannot generate if not connected
|
||||
if (!isConnected) {
|
||||
reasons.push(i18n.t('parameters.invoke.systemDisconnected'));
|
||||
}
|
||||
const reasons: { prefix?: string; content: string }[] = [];
|
||||
|
||||
if (activeTabName === 'workflows') {
|
||||
if (nodes.shouldValidateGraph) {
|
||||
if (!nodes.nodes.length) {
|
||||
reasons.push(i18n.t('parameters.invoke.noNodesInGraph'));
|
||||
// Cannot generate if not connected
|
||||
if (!isConnected) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.systemDisconnected') });
|
||||
}
|
||||
|
||||
if (activeTabName === 'workflows') {
|
||||
if (workflowSettings.shouldValidateGraph) {
|
||||
if (!nodes.nodes.length) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
|
||||
}
|
||||
|
||||
nodes.nodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const nodeTemplate = templates[node.data.type];
|
||||
|
||||
if (!nodeTemplate) {
|
||||
// Node type not found
|
||||
reasons.push({ content: i18n.t('parameters.invoke.missingNodeTemplate') });
|
||||
return;
|
||||
}
|
||||
|
||||
const connectedEdges = getConnectedEdges([node], nodes.edges);
|
||||
|
||||
forEach(node.data.inputs, (field) => {
|
||||
const fieldTemplate = nodeTemplate.inputs[field.name];
|
||||
const hasConnection = connectedEdges.some(
|
||||
(edge) => edge.target === node.id && edge.targetHandle === field.name
|
||||
);
|
||||
|
||||
if (!fieldTemplate) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.missingFieldTemplate') });
|
||||
return;
|
||||
}
|
||||
|
||||
if (fieldTemplate.required && field.value === undefined && !hasConnection) {
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.missingInputForField', {
|
||||
nodeLabel: node.data.label || nodeTemplate.title,
|
||||
fieldLabel: field.label || fieldTemplate.title,
|
||||
}),
|
||||
});
|
||||
return;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
} else {
|
||||
if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noPrompts') });
|
||||
}
|
||||
|
||||
nodes.nodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
if (!model) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
|
||||
}
|
||||
|
||||
const nodeTemplate = nodes.templates[node.data.type];
|
||||
if (activeTabName === 'generation') {
|
||||
// Handling for generation tab
|
||||
controlLayers.present.layers
|
||||
.filter((l) => l.isEnabled)
|
||||
.forEach((l, i) => {
|
||||
const layerLiteral = i18n.t('controlLayers.layers_one');
|
||||
const layerNumber = i + 1;
|
||||
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[l.type]);
|
||||
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
|
||||
const problems: string[] = [];
|
||||
if (l.type === 'control_adapter_layer') {
|
||||
// Must have model
|
||||
if (!l.controlAdapter.model) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
|
||||
}
|
||||
// Model base must match
|
||||
if (l.controlAdapter.model?.base !== model?.base) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
|
||||
}
|
||||
// Must have a control image OR, if it has a processor, it must have a processed image
|
||||
if (!l.controlAdapter.image) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected'));
|
||||
} else if (l.controlAdapter.processorConfig && !l.controlAdapter.processedImage) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed'));
|
||||
}
|
||||
// T2I Adapters require images have dimensions that are multiples of 64 (SD1.5) or 32 (SDXL)
|
||||
if (l.controlAdapter.type === 't2i_adapter') {
|
||||
const multiple = model?.base === 'sdxl' ? 32 : 64;
|
||||
if (size.width % multiple !== 0 || size.height % multiple !== 0) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions', { multiple }));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!nodeTemplate) {
|
||||
// Node type not found
|
||||
reasons.push(i18n.t('parameters.invoke.missingNodeTemplate'));
|
||||
return;
|
||||
}
|
||||
if (l.type === 'ip_adapter_layer') {
|
||||
// Must have model
|
||||
if (!l.ipAdapter.model) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
|
||||
}
|
||||
// Model base must match
|
||||
if (l.ipAdapter.model?.base !== model?.base) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
|
||||
}
|
||||
// Must have an image
|
||||
if (!l.ipAdapter.image) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
|
||||
}
|
||||
}
|
||||
|
||||
const connectedEdges = getConnectedEdges([node], nodes.edges);
|
||||
if (l.type === 'initial_image_layer') {
|
||||
// Must have an image
|
||||
if (!l.image) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.initialImageNoImageSelected'));
|
||||
}
|
||||
}
|
||||
|
||||
forEach(node.data.inputs, (field) => {
|
||||
const fieldTemplate = nodeTemplate.inputs[field.name];
|
||||
const hasConnection = connectedEdges.some(
|
||||
(edge) => edge.target === node.id && edge.targetHandle === field.name
|
||||
);
|
||||
if (l.type === 'regional_guidance_layer') {
|
||||
// Must have a region
|
||||
if (l.maskObjects.length === 0) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
|
||||
}
|
||||
// Must have at least 1 prompt or IP Adapter
|
||||
if (l.positivePrompt === null && l.negativePrompt === null && l.ipAdapters.length === 0) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
|
||||
}
|
||||
l.ipAdapters.forEach((ipAdapter) => {
|
||||
// Must have model
|
||||
if (!ipAdapter.model) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
|
||||
}
|
||||
// Model base must match
|
||||
if (ipAdapter.model?.base !== model?.base) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
|
||||
}
|
||||
// Must have an image
|
||||
if (!ipAdapter.image) {
|
||||
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (!fieldTemplate) {
|
||||
reasons.push(i18n.t('parameters.invoke.missingFieldTemplate'));
|
||||
return;
|
||||
}
|
||||
if (problems.length) {
|
||||
const content = upperFirst(problems.join(', '));
|
||||
reasons.push({ prefix, content });
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Handling for all other tabs
|
||||
selectControlAdapterAll(controlAdapters)
|
||||
.filter((ca) => ca.isEnabled)
|
||||
.forEach((ca, i) => {
|
||||
if (!ca.isEnabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (fieldTemplate.required && field.value === undefined && !hasConnection) {
|
||||
reasons.push(
|
||||
i18n.t('parameters.invoke.missingInputForField', {
|
||||
nodeLabel: node.data.label || nodeTemplate.title,
|
||||
fieldLabel: field.label || fieldTemplate.title,
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
} else {
|
||||
if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
|
||||
reasons.push(i18n.t('parameters.invoke.noPrompts'));
|
||||
if (!ca.model) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noModelForControlAdapter', { number: i + 1 }) });
|
||||
} else if (ca.model.base !== model?.base) {
|
||||
// This should never happen, just a sanity check
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { number: i + 1 }),
|
||||
});
|
||||
}
|
||||
|
||||
if (
|
||||
!ca.controlImage ||
|
||||
(isControlNetOrT2IAdapter(ca) && !ca.processedControlImage && ca.processorType !== 'none')
|
||||
) {
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.noControlImageForControlAdapter', { number: i + 1 }),
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (!model) {
|
||||
reasons.push(i18n.t('parameters.invoke.noModelSelected'));
|
||||
}
|
||||
|
||||
if (activeTabName === 'generation') {
|
||||
// Handling for generation tab
|
||||
controlLayers.present.layers
|
||||
.filter((l) => l.isEnabled)
|
||||
.flatMap((l) => {
|
||||
if (l.type === 'control_adapter_layer') {
|
||||
return l.controlAdapter;
|
||||
} else if (l.type === 'ip_adapter_layer') {
|
||||
return l.ipAdapter;
|
||||
} else if (l.type === 'regional_guidance_layer') {
|
||||
return l.ipAdapters;
|
||||
}
|
||||
return [];
|
||||
})
|
||||
.forEach((ca, i) => {
|
||||
const hasNoModel = !ca.model;
|
||||
const mismatchedModelBase = ca.model?.base !== model?.base;
|
||||
const hasNoImage = !ca.image;
|
||||
const imageNotProcessed =
|
||||
(ca.type === 'controlnet' || ca.type === 't2i_adapter') && !ca.processedImage && ca.processorConfig;
|
||||
|
||||
if (hasNoModel) {
|
||||
reasons.push(
|
||||
i18n.t('parameters.invoke.noModelForControlAdapter', {
|
||||
number: i + 1,
|
||||
})
|
||||
);
|
||||
}
|
||||
if (mismatchedModelBase) {
|
||||
// This should never happen, just a sanity check
|
||||
reasons.push(
|
||||
i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', {
|
||||
number: i + 1,
|
||||
})
|
||||
);
|
||||
}
|
||||
if (hasNoImage) {
|
||||
reasons.push(
|
||||
i18n.t('parameters.invoke.noControlImageForControlAdapter', {
|
||||
number: i + 1,
|
||||
})
|
||||
);
|
||||
}
|
||||
if (imageNotProcessed) {
|
||||
reasons.push(
|
||||
i18n.t('parameters.invoke.imageNotProcessedForControlAdapter', {
|
||||
number: i + 1,
|
||||
})
|
||||
);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Handling for all other tabs
|
||||
selectControlAdapterAll(controlAdapters)
|
||||
.filter((ca) => ca.isEnabled)
|
||||
.forEach((ca, i) => {
|
||||
if (!ca.isEnabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!ca.model) {
|
||||
reasons.push(
|
||||
i18n.t('parameters.invoke.noModelForControlAdapter', {
|
||||
number: i + 1,
|
||||
})
|
||||
);
|
||||
} else if (ca.model.base !== model?.base) {
|
||||
// This should never happen, just a sanity check
|
||||
reasons.push(
|
||||
i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', {
|
||||
number: i + 1,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
!ca.controlImage ||
|
||||
(isControlNetOrT2IAdapter(ca) && !ca.processedControlImage && ca.processorType !== 'none')
|
||||
) {
|
||||
reasons.push(
|
||||
i18n.t('parameters.invoke.noControlImageForControlAdapter', {
|
||||
number: i + 1,
|
||||
})
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
return { isReady: !reasons.length, reasons };
|
||||
}
|
||||
|
||||
return { isReady: !reasons.length, reasons };
|
||||
}
|
||||
);
|
||||
);
|
||||
|
||||
export const useIsReadyToEnqueue = () => {
|
||||
const { isReady, reasons } = useAppSelector(selector);
|
||||
return { isReady, reasons };
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(() => createSelector(templates), [templates]);
|
||||
const value = useAppSelector(selector);
|
||||
return value;
|
||||
};
|
||||
|
@ -21,8 +21,6 @@ import {
|
||||
setShouldShowBoundingBox,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import type { CanvasLayer } from 'features/canvas/store/canvasTypes';
|
||||
import { LAYER_NAMES_DICT } from 'features/canvas/store/canvasTypes';
|
||||
import { ViewerButton } from 'features/gallery/components/ImageViewer/ViewerButton';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -217,110 +215,107 @@ const IAICanvasToolbar = () => {
|
||||
[dispatch, isMaskEnabled]
|
||||
);
|
||||
|
||||
const value = useMemo(() => LAYER_NAMES_DICT.filter((o) => o.value === layer)[0], [layer]);
|
||||
const layerOptions = useMemo<{ label: string; value: CanvasLayer }[]>(
|
||||
() => [
|
||||
{ label: t('unifiedCanvas.base'), value: 'base' },
|
||||
{ label: t('unifiedCanvas.mask'), value: 'mask' },
|
||||
],
|
||||
[t]
|
||||
);
|
||||
const layerValue = useMemo(() => layerOptions.filter((o) => o.value === layer)[0] ?? null, [layer, layerOptions]);
|
||||
|
||||
return (
|
||||
<Flex w="full" gap={2} alignItems="center">
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineEnd="auto" />
|
||||
</Flex>
|
||||
<Flex flex={1} gap={2} justifyContent="center">
|
||||
<Tooltip label={`${t('unifiedCanvas.layer')} (Q)`}>
|
||||
<FormControl isDisabled={isStaging} w="5rem">
|
||||
<Combobox value={value} options={LAYER_NAMES_DICT} onChange={handleChangeLayer} />
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
<Flex alignItems="center" gap={2} flexWrap="wrap">
|
||||
<Tooltip label={`${t('unifiedCanvas.layer')} (Q)`}>
|
||||
<FormControl isDisabled={isStaging} w="5rem">
|
||||
<Combobox value={layerValue} options={layerOptions} onChange={handleChangeLayer} />
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
|
||||
<IAICanvasMaskOptions />
|
||||
<IAICanvasToolChooserOptions />
|
||||
<IAICanvasMaskOptions />
|
||||
<IAICanvasToolChooserOptions />
|
||||
|
||||
<ButtonGroup>
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.move')} (V)`}
|
||||
tooltip={`${t('unifiedCanvas.move')} (V)`}
|
||||
icon={<PiHandGrabbingBold />}
|
||||
isChecked={tool === 'move' || isStaging}
|
||||
onClick={handleSelectMoveTool}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label={`${shouldShowBoundingBox ? t('unifiedCanvas.hideBoundingBox') : t('unifiedCanvas.showBoundingBox')} (Shift + H)`}
|
||||
tooltip={`${shouldShowBoundingBox ? t('unifiedCanvas.hideBoundingBox') : t('unifiedCanvas.showBoundingBox')} (Shift + H)`}
|
||||
icon={shouldShowBoundingBox ? <PiEyeBold /> : <PiEyeSlashBold />}
|
||||
onClick={handleSetShouldShowBoundingBox}
|
||||
isDisabled={isStaging}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.resetView')} (R)`}
|
||||
tooltip={`${t('unifiedCanvas.resetView')} (R)`}
|
||||
icon={<PiCrosshairSimpleBold />}
|
||||
onClick={handleClickResetCanvasView}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
<ButtonGroup>
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.move')} (V)`}
|
||||
tooltip={`${t('unifiedCanvas.move')} (V)`}
|
||||
icon={<PiHandGrabbingBold />}
|
||||
isChecked={tool === 'move' || isStaging}
|
||||
onClick={handleSelectMoveTool}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label={`${shouldShowBoundingBox ? t('unifiedCanvas.hideBoundingBox') : t('unifiedCanvas.showBoundingBox')} (Shift + H)`}
|
||||
tooltip={`${shouldShowBoundingBox ? t('unifiedCanvas.hideBoundingBox') : t('unifiedCanvas.showBoundingBox')} (Shift + H)`}
|
||||
icon={shouldShowBoundingBox ? <PiEyeBold /> : <PiEyeSlashBold />}
|
||||
onClick={handleSetShouldShowBoundingBox}
|
||||
isDisabled={isStaging}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.resetView')} (R)`}
|
||||
tooltip={`${t('unifiedCanvas.resetView')} (R)`}
|
||||
icon={<PiCrosshairSimpleBold />}
|
||||
onClick={handleClickResetCanvasView}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
|
||||
<ButtonGroup>
|
||||
<ButtonGroup>
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.mergeVisible')} (Shift+M)`}
|
||||
tooltip={`${t('unifiedCanvas.mergeVisible')} (Shift+M)`}
|
||||
icon={<PiStackBold />}
|
||||
onClick={handleMergeVisible}
|
||||
isDisabled={isStaging}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.saveToGallery')} (Shift+S)`}
|
||||
tooltip={`${t('unifiedCanvas.saveToGallery')} (Shift+S)`}
|
||||
icon={<PiFloppyDiskBold />}
|
||||
onClick={handleSaveToGallery}
|
||||
isDisabled={isStaging}
|
||||
/>
|
||||
{isClipboardAPIAvailable && (
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.mergeVisible')} (Shift+M)`}
|
||||
tooltip={`${t('unifiedCanvas.mergeVisible')} (Shift+M)`}
|
||||
icon={<PiStackBold />}
|
||||
onClick={handleMergeVisible}
|
||||
aria-label={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
|
||||
tooltip={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
|
||||
icon={<PiCopyBold />}
|
||||
onClick={handleCopyImageToClipboard}
|
||||
isDisabled={isStaging}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.saveToGallery')} (Shift+S)`}
|
||||
tooltip={`${t('unifiedCanvas.saveToGallery')} (Shift+S)`}
|
||||
icon={<PiFloppyDiskBold />}
|
||||
onClick={handleSaveToGallery}
|
||||
isDisabled={isStaging}
|
||||
/>
|
||||
{isClipboardAPIAvailable && (
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
|
||||
tooltip={`${t('unifiedCanvas.copyToClipboard')} (Cmd/Ctrl+C)`}
|
||||
icon={<PiCopyBold />}
|
||||
onClick={handleCopyImageToClipboard}
|
||||
isDisabled={isStaging}
|
||||
/>
|
||||
)}
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}
|
||||
tooltip={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}
|
||||
icon={<PiDownloadSimpleBold />}
|
||||
onClick={handleDownloadAsImage}
|
||||
isDisabled={isStaging}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
<ButtonGroup>
|
||||
<IAICanvasUndoButton />
|
||||
<IAICanvasRedoButton />
|
||||
</ButtonGroup>
|
||||
)}
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}
|
||||
tooltip={`${t('unifiedCanvas.downloadAsImage')} (Shift+D)`}
|
||||
icon={<PiDownloadSimpleBold />}
|
||||
onClick={handleDownloadAsImage}
|
||||
isDisabled={isStaging}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
<ButtonGroup>
|
||||
<IAICanvasUndoButton />
|
||||
<IAICanvasRedoButton />
|
||||
</ButtonGroup>
|
||||
|
||||
<ButtonGroup>
|
||||
<IconButton
|
||||
aria-label={`${t('common.upload')}`}
|
||||
tooltip={`${t('common.upload')}`}
|
||||
icon={<PiUploadSimpleBold />}
|
||||
isDisabled={isStaging}
|
||||
{...getUploadButtonProps()}
|
||||
/>
|
||||
<input {...getUploadInputProps()} />
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.clearCanvas')}`}
|
||||
tooltip={`${t('unifiedCanvas.clearCanvas')}`}
|
||||
icon={<PiTrashSimpleBold />}
|
||||
onClick={handleResetCanvas}
|
||||
colorScheme="error"
|
||||
isDisabled={isStaging}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
<ButtonGroup>
|
||||
<IAICanvasSettingsButtonPopover />
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineStart="auto">
|
||||
<ViewerButton />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<ButtonGroup>
|
||||
<IconButton
|
||||
aria-label={`${t('common.upload')}`}
|
||||
tooltip={`${t('common.upload')}`}
|
||||
icon={<PiUploadSimpleBold />}
|
||||
isDisabled={isStaging}
|
||||
{...getUploadButtonProps()}
|
||||
/>
|
||||
<input {...getUploadInputProps()} />
|
||||
<IconButton
|
||||
aria-label={`${t('unifiedCanvas.clearCanvas')}`}
|
||||
tooltip={`${t('unifiedCanvas.clearCanvas')}`}
|
||||
icon={<PiTrashSimpleBold />}
|
||||
onClick={handleResetCanvas}
|
||||
colorScheme="error"
|
||||
isDisabled={isStaging}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
<ButtonGroup>
|
||||
<IAICanvasSettingsButtonPopover />
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -613,7 +613,7 @@ export const canvasSlice = createSlice({
|
||||
state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id);
|
||||
}
|
||||
|
||||
const queueItemStatus = action.payload.data.queue_item.status;
|
||||
const queueItemStatus = action.payload.data.status;
|
||||
if (queueItemStatus === 'canceled' || queueItemStatus === 'failed') {
|
||||
resetStagingAreaIfEmpty(state);
|
||||
}
|
||||
|
@ -5,11 +5,6 @@ import { z } from 'zod';
|
||||
|
||||
export type CanvasLayer = 'base' | 'mask';
|
||||
|
||||
export const LAYER_NAMES_DICT: { label: string; value: CanvasLayer }[] = [
|
||||
{ label: 'Base', value: 'base' },
|
||||
{ label: 'Mask', value: 'mask' },
|
||||
];
|
||||
|
||||
const zBoundingBoxScaleMethod = z.enum(['none', 'auto', 'manual']);
|
||||
export type BoundingBoxScaleMethod = z.infer<typeof zBoundingBoxScaleMethod>;
|
||||
export const isBoundingBoxScaleMethod = (v: unknown): v is BoundingBoxScaleMethod =>
|
||||
|
@ -5,22 +5,7 @@ import type {
|
||||
ParameterT2IAdapterModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import type { components } from 'services/api/schema';
|
||||
import type {
|
||||
CannyImageProcessorInvocation,
|
||||
ColorMapImageProcessorInvocation,
|
||||
ContentShuffleImageProcessorInvocation,
|
||||
DepthAnythingImageProcessorInvocation,
|
||||
DWOpenposeImageProcessorInvocation,
|
||||
HedImageProcessorInvocation,
|
||||
LineartAnimeImageProcessorInvocation,
|
||||
LineartImageProcessorInvocation,
|
||||
MediapipeFaceProcessorInvocation,
|
||||
MidasDepthImageProcessorInvocation,
|
||||
MlsdImageProcessorInvocation,
|
||||
NormalbaeImageProcessorInvocation,
|
||||
PidiImageProcessorInvocation,
|
||||
ZoeDepthImageProcessorInvocation,
|
||||
} from 'services/api/types';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
import { z } from 'zod';
|
||||
|
||||
@ -28,20 +13,20 @@ import { z } from 'zod';
|
||||
* Any ControlNet processor node
|
||||
*/
|
||||
export type ControlAdapterProcessorNode =
|
||||
| CannyImageProcessorInvocation
|
||||
| ColorMapImageProcessorInvocation
|
||||
| ContentShuffleImageProcessorInvocation
|
||||
| DepthAnythingImageProcessorInvocation
|
||||
| HedImageProcessorInvocation
|
||||
| LineartAnimeImageProcessorInvocation
|
||||
| LineartImageProcessorInvocation
|
||||
| MediapipeFaceProcessorInvocation
|
||||
| MidasDepthImageProcessorInvocation
|
||||
| MlsdImageProcessorInvocation
|
||||
| NormalbaeImageProcessorInvocation
|
||||
| DWOpenposeImageProcessorInvocation
|
||||
| PidiImageProcessorInvocation
|
||||
| ZoeDepthImageProcessorInvocation;
|
||||
| Invocation<'canny_image_processor'>
|
||||
| Invocation<'color_map_image_processor'>
|
||||
| Invocation<'content_shuffle_image_processor'>
|
||||
| Invocation<'depth_anything_image_processor'>
|
||||
| Invocation<'hed_image_processor'>
|
||||
| Invocation<'lineart_anime_image_processor'>
|
||||
| Invocation<'lineart_image_processor'>
|
||||
| Invocation<'mediapipe_face_processor'>
|
||||
| Invocation<'midas_depth_image_processor'>
|
||||
| Invocation<'mlsd_image_processor'>
|
||||
| Invocation<'normalbae_image_processor'>
|
||||
| Invocation<'dw_openpose_image_processor'>
|
||||
| Invocation<'pidi_image_processor'>
|
||||
| Invocation<'zoe_depth_image_processor'>;
|
||||
|
||||
/**
|
||||
* Any ControlNet processor type
|
||||
@ -71,7 +56,7 @@ export const isControlAdapterProcessorType = (v: unknown): v is ControlAdapterPr
|
||||
* The Canny processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredCannyImageProcessorInvocation = O.Required<
|
||||
CannyImageProcessorInvocation,
|
||||
Invocation<'canny_image_processor'>,
|
||||
'type' | 'low_threshold' | 'high_threshold' | 'image_resolution' | 'detect_resolution'
|
||||
>;
|
||||
|
||||
@ -79,7 +64,7 @@ export type RequiredCannyImageProcessorInvocation = O.Required<
|
||||
* The Color Map processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredColorMapImageProcessorInvocation = O.Required<
|
||||
ColorMapImageProcessorInvocation,
|
||||
Invocation<'color_map_image_processor'>,
|
||||
'type' | 'color_map_tile_size'
|
||||
>;
|
||||
|
||||
@ -87,7 +72,7 @@ export type RequiredColorMapImageProcessorInvocation = O.Required<
|
||||
* The ContentShuffle processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredContentShuffleImageProcessorInvocation = O.Required<
|
||||
ContentShuffleImageProcessorInvocation,
|
||||
Invocation<'content_shuffle_image_processor'>,
|
||||
'type' | 'detect_resolution' | 'image_resolution' | 'w' | 'h' | 'f'
|
||||
>;
|
||||
|
||||
@ -95,7 +80,7 @@ export type RequiredContentShuffleImageProcessorInvocation = O.Required<
|
||||
* The DepthAnything processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredDepthAnythingImageProcessorInvocation = O.Required<
|
||||
DepthAnythingImageProcessorInvocation,
|
||||
Invocation<'depth_anything_image_processor'>,
|
||||
'type' | 'model_size' | 'resolution' | 'offload'
|
||||
>;
|
||||
|
||||
@ -108,7 +93,7 @@ export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSiz
|
||||
* The HED processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredHedImageProcessorInvocation = O.Required<
|
||||
HedImageProcessorInvocation,
|
||||
Invocation<'hed_image_processor'>,
|
||||
'type' | 'detect_resolution' | 'image_resolution' | 'scribble'
|
||||
>;
|
||||
|
||||
@ -116,7 +101,7 @@ export type RequiredHedImageProcessorInvocation = O.Required<
|
||||
* The Lineart Anime processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredLineartAnimeImageProcessorInvocation = O.Required<
|
||||
LineartAnimeImageProcessorInvocation,
|
||||
Invocation<'lineart_anime_image_processor'>,
|
||||
'type' | 'detect_resolution' | 'image_resolution'
|
||||
>;
|
||||
|
||||
@ -124,7 +109,7 @@ export type RequiredLineartAnimeImageProcessorInvocation = O.Required<
|
||||
* The Lineart processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredLineartImageProcessorInvocation = O.Required<
|
||||
LineartImageProcessorInvocation,
|
||||
Invocation<'lineart_image_processor'>,
|
||||
'type' | 'detect_resolution' | 'image_resolution' | 'coarse'
|
||||
>;
|
||||
|
||||
@ -132,7 +117,7 @@ export type RequiredLineartImageProcessorInvocation = O.Required<
|
||||
* The MediapipeFace processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredMediapipeFaceProcessorInvocation = O.Required<
|
||||
MediapipeFaceProcessorInvocation,
|
||||
Invocation<'mediapipe_face_processor'>,
|
||||
'type' | 'max_faces' | 'min_confidence' | 'image_resolution' | 'detect_resolution'
|
||||
>;
|
||||
|
||||
@ -140,7 +125,7 @@ export type RequiredMediapipeFaceProcessorInvocation = O.Required<
|
||||
* The MidasDepth processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredMidasDepthImageProcessorInvocation = O.Required<
|
||||
MidasDepthImageProcessorInvocation,
|
||||
Invocation<'midas_depth_image_processor'>,
|
||||
'type' | 'a_mult' | 'bg_th' | 'image_resolution' | 'detect_resolution'
|
||||
>;
|
||||
|
||||
@ -148,7 +133,7 @@ export type RequiredMidasDepthImageProcessorInvocation = O.Required<
|
||||
* The MLSD processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredMlsdImageProcessorInvocation = O.Required<
|
||||
MlsdImageProcessorInvocation,
|
||||
Invocation<'mlsd_image_processor'>,
|
||||
'type' | 'detect_resolution' | 'image_resolution' | 'thr_v' | 'thr_d'
|
||||
>;
|
||||
|
||||
@ -156,7 +141,7 @@ export type RequiredMlsdImageProcessorInvocation = O.Required<
|
||||
* The NormalBae processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredNormalbaeImageProcessorInvocation = O.Required<
|
||||
NormalbaeImageProcessorInvocation,
|
||||
Invocation<'normalbae_image_processor'>,
|
||||
'type' | 'detect_resolution' | 'image_resolution'
|
||||
>;
|
||||
|
||||
@ -164,7 +149,7 @@ export type RequiredNormalbaeImageProcessorInvocation = O.Required<
|
||||
* The DW Openpose processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredDWOpenposeImageProcessorInvocation = O.Required<
|
||||
DWOpenposeImageProcessorInvocation,
|
||||
Invocation<'dw_openpose_image_processor'>,
|
||||
'type' | 'image_resolution' | 'draw_body' | 'draw_face' | 'draw_hands'
|
||||
>;
|
||||
|
||||
@ -172,14 +157,14 @@ export type RequiredDWOpenposeImageProcessorInvocation = O.Required<
|
||||
* The Pidi processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredPidiImageProcessorInvocation = O.Required<
|
||||
PidiImageProcessorInvocation,
|
||||
Invocation<'pidi_image_processor'>,
|
||||
'type' | 'detect_resolution' | 'image_resolution' | 'safe' | 'scribble'
|
||||
>;
|
||||
|
||||
/**
|
||||
* The ZoeDepth processor node, with parameters flagged as required
|
||||
*/
|
||||
export type RequiredZoeDepthImageProcessorInvocation = O.Required<ZoeDepthImageProcessorInvocation, 'type'>;
|
||||
export type RequiredZoeDepthImageProcessorInvocation = O.Required<Invocation<'zoe_depth_image_processor'>, 'type'>;
|
||||
|
||||
/**
|
||||
* Any ControlNet Processor node, with its parameters flagged as required
|
||||
|
@ -18,7 +18,12 @@ export const AddLayerButton = memo(() => {
|
||||
|
||||
return (
|
||||
<Menu>
|
||||
<MenuButton as={Button} leftIcon={<PiPlusBold />} variant="ghost">
|
||||
<MenuButton
|
||||
as={Button}
|
||||
leftIcon={<PiPlusBold />}
|
||||
variant="ghost"
|
||||
data-testid="control-layers-add-layer-menu-button"
|
||||
>
|
||||
{t('controlLayers.addLayer')}
|
||||
</MenuButton>
|
||||
<MenuList>
|
||||
|
@ -19,7 +19,6 @@ export const CALayer = memo(({ layerId }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const isSelected = useAppSelector((s) => selectCALayerOrThrow(s.controlLayers.present, layerId).isSelected);
|
||||
const onClick = useCallback(() => {
|
||||
// Must be capture so that the layer is selected before deleting/resetting/etc
|
||||
dispatch(layerSelected(layerId));
|
||||
}, [dispatch, layerId]);
|
||||
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true });
|
||||
|
@ -42,10 +42,10 @@ export const ControlAdapterImagePreview = memo(
|
||||
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
||||
|
||||
const { currentData: controlImage, isError: isErrorControlImage } = useGetImageDTOQuery(
|
||||
controlAdapter.image?.imageName ?? skipToken
|
||||
controlAdapter.image?.name ?? skipToken
|
||||
);
|
||||
const { currentData: processedControlImage, isError: isErrorProcessedControlImage } = useGetImageDTOQuery(
|
||||
controlAdapter.processedImage?.imageName ?? skipToken
|
||||
controlAdapter.processedImage?.name ?? skipToken
|
||||
);
|
||||
|
||||
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();
|
||||
@ -124,7 +124,7 @@ export const ControlAdapterImagePreview = memo(
|
||||
controlImage &&
|
||||
processedControlImage &&
|
||||
!isMouseOverImage &&
|
||||
!controlAdapter.isProcessingImage &&
|
||||
!controlAdapter.processorPendingBatchId &&
|
||||
controlAdapter.processorConfig !== null;
|
||||
|
||||
useEffect(() => {
|
||||
@ -190,7 +190,7 @@ export const ControlAdapterImagePreview = memo(
|
||||
/>
|
||||
</>
|
||||
|
||||
{controlAdapter.isProcessingImage && (
|
||||
{controlAdapter.processorPendingBatchId !== null && (
|
||||
<Flex
|
||||
position="absolute"
|
||||
top={0}
|
||||
|
@ -42,6 +42,7 @@ export const ControlAdapterModelCombobox = memo(({ modelKey, onChange: onChangeM
|
||||
selectedModel,
|
||||
getIsDisabled,
|
||||
isLoading,
|
||||
groupByType: true,
|
||||
});
|
||||
|
||||
return (
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user