mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
71 Commits
v4.2.9.dev
...
feat/ui/wo
Author | SHA1 | Date | |
---|---|---|---|
f505ec64ba | |||
f22eb368a3 | |||
96ae22c7e0 | |||
f5447cdc23 | |||
c76a6bd65f | |||
6c4eeaa569 | |||
1bbd13ead7 | |||
321b939d0e | |||
8fb77e431e | |||
083a4f3faa | |||
2005411f7e | |||
ba7b1b2665 | |||
b7ffd36cc6 | |||
199ddd6623 | |||
a7207ed8cf | |||
6bb2dda3f1 | |||
c1e5cd5893 | |||
ff249a2315 | |||
c58f8c3269 | |||
ed772a7107 | |||
cb0b389b4b | |||
8892df1d97 | |||
bc5f356390 | |||
bcb85e100d | |||
1f27ddc07d | |||
7a2b606001 | |||
83ddcc5f3a | |||
55fa785561 | |||
06429028c8 | |||
8b6e322697 | |||
54a67459bf | |||
7fe5283e74 | |||
fe0391c86b | |||
25386a76ef | |||
fd30cb4d90 | |||
0266946d3d | |||
a7f91b3e01 | |||
de0b72528c | |||
2932652787 | |||
db6bc7305a | |||
a5db204629 | |||
8e2b61e19f | |||
a3faa3792a | |||
c16eba78ab | |||
1a191c4655 | |||
e36d925bce | |||
b1ba18b3d1 | |||
aff46759f9 | |||
d7b7dcc7fe | |||
889a26c5b6 | |||
b4c774896a | |||
afbe889d35 | |||
9c1e52b1ef | |||
3f5ab02da9 | |||
bf48e8a03a | |||
e52434cb99 | |||
483bdbcb9f | |||
ae421fb4ab | |||
cc295a9f0a | |||
a7e23af9c6 | |||
3de4390711 | |||
3ceee2b2b2 | |||
5c7ed24aab | |||
183c9c4799 | |||
8baf3f78a2 | |||
ac2eb16a65 | |||
4aa7bee4b9 | |||
7e5ba2795e | |||
97a6c6eea7 | |||
f0e60a4ba2 | |||
aa089e8108 |
@ -9,11 +9,15 @@ complex functionality.
|
|||||||
|
|
||||||
## Invocations Directory
|
## Invocations Directory
|
||||||
|
|
||||||
InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These can be used as examples to create your own nodes.
|
InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These
|
||||||
|
can be used as examples to create your own nodes.
|
||||||
|
|
||||||
New nodes should be added to a subfolder in `nodes` direction found at the root level of the InvokeAI installation location. Nodes added to this folder will be able to be used upon application startup.
|
New nodes should be added to a subfolder in `nodes` direction found at the root
|
||||||
|
level of the InvokeAI installation location. Nodes added to this folder will be
|
||||||
|
able to be used upon application startup.
|
||||||
|
|
||||||
|
Example `nodes` subfolder structure:
|
||||||
|
|
||||||
Example `nodes` subfolder structure:
|
|
||||||
```py
|
```py
|
||||||
├── __init__.py # Invoke-managed custom node loader
|
├── __init__.py # Invoke-managed custom node loader
|
||||||
│
|
│
|
||||||
@ -30,14 +34,14 @@ Example `nodes` subfolder structure:
|
|||||||
└── fancy_node.py
|
└── fancy_node.py
|
||||||
```
|
```
|
||||||
|
|
||||||
Each node folder must have an `__init__.py` file that imports its nodes. Only nodes imported in the `__init__.py` file are loaded.
|
Each node folder must have an `__init__.py` file that imports its nodes. Only
|
||||||
See the README in the nodes folder for more examples:
|
nodes imported in the `__init__.py` file are loaded. See the README in the nodes
|
||||||
|
folder for more examples:
|
||||||
|
|
||||||
```py
|
```py
|
||||||
from .cool_node import CoolInvocation
|
from .cool_node import CoolInvocation
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Creating A New Invocation
|
## Creating A New Invocation
|
||||||
|
|
||||||
In order to understand the process of creating a new Invocation, let us actually
|
In order to understand the process of creating a new Invocation, let us actually
|
||||||
@ -131,7 +135,6 @@ from invokeai.app.invocations.primitives import ImageField
|
|||||||
class ResizeInvocation(BaseInvocation):
|
class ResizeInvocation(BaseInvocation):
|
||||||
'''Resizes an image'''
|
'''Resizes an image'''
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The input image")
|
image: ImageField = InputField(description="The input image")
|
||||||
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||||
@ -167,7 +170,6 @@ from invokeai.app.invocations.primitives import ImageField
|
|||||||
class ResizeInvocation(BaseInvocation):
|
class ResizeInvocation(BaseInvocation):
|
||||||
'''Resizes an image'''
|
'''Resizes an image'''
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The input image")
|
image: ImageField = InputField(description="The input image")
|
||||||
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||||
@ -197,7 +199,6 @@ from invokeai.app.invocations.image import ImageOutput
|
|||||||
class ResizeInvocation(BaseInvocation):
|
class ResizeInvocation(BaseInvocation):
|
||||||
'''Resizes an image'''
|
'''Resizes an image'''
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The input image")
|
image: ImageField = InputField(description="The input image")
|
||||||
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||||
@ -229,30 +230,17 @@ class ResizeInvocation(BaseInvocation):
|
|||||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
# Load the image using InvokeAI's predefined Image Service. Returns the PIL image.
|
# Load the input image as a PIL image
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
|
||||||
# Resizing the image
|
# Resize the image
|
||||||
resized_image = image.resize((self.width, self.height))
|
resized_image = image.resize((self.width, self.height))
|
||||||
|
|
||||||
# Save the image using InvokeAI's predefined Image Service. Returns the prepared PIL image.
|
# Save the image
|
||||||
output_image = context.services.images.create(
|
image_dto = context.images.save(image=resized_image)
|
||||||
image=resized_image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Returning the Image
|
# Return an ImageOutput
|
||||||
return ImageOutput(
|
return ImageOutput.build(image_dto)
|
||||||
image=ImageField(
|
|
||||||
image_name=output_image.image_name,
|
|
||||||
),
|
|
||||||
width=output_image.width,
|
|
||||||
height=output_image.height,
|
|
||||||
)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Note:** Do not be overwhelmed by the `ImageOutput` process. InvokeAI has a
|
**Note:** Do not be overwhelmed by the `ImageOutput` process. InvokeAI has a
|
||||||
@ -343,27 +331,25 @@ class ImageColorStringOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
That's all there is to it.
|
That's all there is to it.
|
||||||
|
|
||||||
<!-- TODO: DANGER - we probably do not want people to create their own field types, because this requires a lot of work on the frontend to accomodate.
|
|
||||||
|
|
||||||
### Custom Input Fields
|
### Custom Input Fields
|
||||||
|
|
||||||
Now that you know how to create your own Invocations, let us dive into slightly
|
Now that you know how to create your own Invocations, let us dive into slightly
|
||||||
more advanced topics.
|
more advanced topics.
|
||||||
|
|
||||||
While creating your own Invocations, you might run into a scenario where the
|
While creating your own Invocations, you might run into a scenario where the
|
||||||
existing input types in InvokeAI do not meet your requirements. In such cases,
|
existing fields in InvokeAI do not meet your requirements. In such cases, you
|
||||||
you can create your own input types.
|
can create your own fields.
|
||||||
|
|
||||||
Let us create one as an example. Let us say we want to create a color input
|
Let us create one as an example. Let us say we want to create a color input
|
||||||
field that represents a color code. But before we start on that here are some
|
field that represents a color code. But before we start on that here are some
|
||||||
general good practices to keep in mind.
|
general good practices to keep in mind.
|
||||||
|
|
||||||
**Good Practices**
|
### Best Practices
|
||||||
|
|
||||||
- There is no naming convention for input fields but we highly recommend that
|
- There is no naming convention for input fields but we highly recommend that
|
||||||
you name it something appropriate like `ColorField`.
|
you name it something appropriate like `ColorField`.
|
||||||
- It is not mandatory but it is heavily recommended to add a relevant
|
- It is not mandatory but it is heavily recommended to add a relevant
|
||||||
`docstring` to describe your input field.
|
`docstring` to describe your field.
|
||||||
- Keep your field in the same file as the Invocation that it is made for or in
|
- Keep your field in the same file as the Invocation that it is made for or in
|
||||||
another file where it is relevant.
|
another file where it is relevant.
|
||||||
|
|
||||||
@ -378,10 +364,13 @@ class ColorField(BaseModel):
|
|||||||
pass
|
pass
|
||||||
```
|
```
|
||||||
|
|
||||||
Perfect. Now let us create our custom inputs for our field. This is exactly
|
Perfect. Now let us create the properties for our field. This is similar to how
|
||||||
similar how you created input fields for your Invocation. All the same rules
|
you created input fields for your Invocation. All the same rules apply. Let us
|
||||||
apply. Let us create four fields representing the _red(r)_, _blue(b)_,
|
create four fields representing the _red(r)_, _blue(b)_, _green(g)_ and
|
||||||
_green(g)_ and _alpha(a)_ channel of the color.
|
_alpha(a)_ channel of the color.
|
||||||
|
|
||||||
|
> Technically, the properties are _also_ called fields - but in this case, it
|
||||||
|
> refers to a `pydantic` field.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class ColorField(BaseModel):
|
class ColorField(BaseModel):
|
||||||
@ -396,25 +385,11 @@ That's it. We now have a new input field type that we can use in our Invocations
|
|||||||
like this.
|
like this.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
color: ColorField = Field(default=ColorField(r=0, g=0, b=0, a=0), description='Background color of an image')
|
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=0), description='Background color of an image')
|
||||||
```
|
```
|
||||||
|
|
||||||
### Custom Components For Frontend
|
### Using the custom field
|
||||||
|
|
||||||
Every backend input type should have a corresponding frontend component so the
|
When you start the UI, your custom field will be automatically recognized.
|
||||||
UI knows what to render when you use a particular field type.
|
|
||||||
|
|
||||||
If you are using existing field types, we already have components for those. So
|
Custom fields only support connection inputs in the Workflow Editor.
|
||||||
you don't have to worry about creating anything new. But this might not always
|
|
||||||
be the case. Sometimes you might want to create new field types and have the
|
|
||||||
frontend UI deal with it in a different way.
|
|
||||||
|
|
||||||
This is where we venture into the world of React and Javascript and create our
|
|
||||||
own new components for our Invocations. Do not fear the world of JS. It's
|
|
||||||
actually pretty straightforward.
|
|
||||||
|
|
||||||
Let us create a new component for our custom color field we created above. When
|
|
||||||
we use a color field, let us say we want the UI to display a color picker for
|
|
||||||
the user to pick from rather than entering values. That is what we will build
|
|
||||||
now.
|
|
||||||
-->
|
|
||||||
|
@ -2,9 +2,14 @@
|
|||||||
|
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||||
|
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
||||||
|
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.version.invokeai_version import __version__
|
from invokeai.version.invokeai_version import __version__
|
||||||
|
|
||||||
@ -23,8 +28,6 @@ from ..services.invocation_queue.invocation_queue_memory import MemoryInvocation
|
|||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||||
from ..services.invoker import Invoker
|
from ..services.invoker import Invoker
|
||||||
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
|
||||||
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
|
||||||
from ..services.model_install import ModelInstallService
|
from ..services.model_install import ModelInstallService
|
||||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||||
from ..services.model_records import ModelRecordServiceSQL
|
from ..services.model_records import ModelRecordServiceSQL
|
||||||
@ -68,6 +71,9 @@ class ApiDependencies:
|
|||||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||||
|
|
||||||
output_folder = config.output_path
|
output_folder = config.output_path
|
||||||
|
if output_folder is None:
|
||||||
|
raise ValueError("Output folder is not set")
|
||||||
|
|
||||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
|
|
||||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||||
@ -84,7 +90,12 @@ class ApiDependencies:
|
|||||||
image_records = SqliteImageRecordStorage(db=db)
|
image_records = SqliteImageRecordStorage(db=db)
|
||||||
images = ImageService()
|
images = ImageService()
|
||||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
tensors = ObjectSerializerForwardCache(
|
||||||
|
ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True)
|
||||||
|
)
|
||||||
|
conditioning = ObjectSerializerForwardCache(
|
||||||
|
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||||
|
)
|
||||||
model_manager = ModelManagerService(config, logger)
|
model_manager = ModelManagerService(config, logger)
|
||||||
model_record_service = ModelRecordServiceSQL(db=db)
|
model_record_service = ModelRecordServiceSQL(db=db)
|
||||||
download_queue_service = DownloadQueueService(event_bus=events)
|
download_queue_service = DownloadQueueService(event_bus=events)
|
||||||
@ -117,7 +128,6 @@ class ApiDependencies:
|
|||||||
image_records=image_records,
|
image_records=image_records,
|
||||||
images=images,
|
images=images,
|
||||||
invocation_cache=invocation_cache,
|
invocation_cache=invocation_cache,
|
||||||
latents=latents,
|
|
||||||
logger=logger,
|
logger=logger,
|
||||||
model_manager=model_manager,
|
model_manager=model_manager,
|
||||||
model_records=model_record_service,
|
model_records=model_record_service,
|
||||||
@ -131,6 +141,8 @@ class ApiDependencies:
|
|||||||
session_queue=session_queue,
|
session_queue=session_queue,
|
||||||
urls=urls,
|
urls=urls,
|
||||||
workflow_records=workflow_records,
|
workflow_records=workflow_records,
|
||||||
|
tensors=tensors,
|
||||||
|
conditioning=conditioning,
|
||||||
)
|
)
|
||||||
|
|
||||||
ApiDependencies.invoker = Invoker(services)
|
ApiDependencies.invoker = Invoker(services)
|
||||||
|
@ -8,7 +8,7 @@ from fastapi.routing import APIRouter
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
|
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
|
@ -6,6 +6,7 @@ import sys
|
|||||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||||
from invokeai.version.invokeai_version import __version__
|
from invokeai.version.invokeai_version import __version__
|
||||||
|
|
||||||
|
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||||
from .services.config import InvokeAIAppConfig
|
from .services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
@ -57,8 +58,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import (
|
from .invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
InputFieldJSONSchemaExtra,
|
|
||||||
OutputFieldJSONSchemaExtra,
|
|
||||||
UIConfigBase,
|
UIConfigBase,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,13 +12,16 @@ from types import UnionType
|
|||||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
|
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
|
||||||
|
|
||||||
import semver
|
import semver
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
|
from pydantic import BaseModel, ConfigDict, Field, create_model
|
||||||
from pydantic.fields import FieldInfo, _Unset
|
from pydantic.fields import FieldInfo
|
||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
|
from invokeai.app.invocations.fields import (
|
||||||
|
FieldKind,
|
||||||
|
Input,
|
||||||
|
)
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
|
||||||
from invokeai.app.util.metaenum import MetaEnum
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
@ -52,393 +55,6 @@ class Classification(str, Enum, metaclass=MetaEnum):
|
|||||||
Prototype = "prototype"
|
Prototype = "prototype"
|
||||||
|
|
||||||
|
|
||||||
class Input(str, Enum, metaclass=MetaEnum):
|
|
||||||
"""
|
|
||||||
The type of input a field accepts.
|
|
||||||
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
|
|
||||||
are instantiated.
|
|
||||||
- `Input.Connection`: The field must have its value provided by a connection.
|
|
||||||
- `Input.Any`: The field may have its value provided either directly or by a connection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
Connection = "connection"
|
|
||||||
Direct = "direct"
|
|
||||||
Any = "any"
|
|
||||||
|
|
||||||
|
|
||||||
class FieldKind(str, Enum, metaclass=MetaEnum):
|
|
||||||
"""
|
|
||||||
The kind of field.
|
|
||||||
- `Input`: An input field on a node.
|
|
||||||
- `Output`: An output field on a node.
|
|
||||||
- `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is
|
|
||||||
one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name
|
|
||||||
"metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic,
|
|
||||||
allowing "metadata" for that field.
|
|
||||||
- `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs,
|
|
||||||
but which are used to store information about the node. For example, the `id` and `type` fields are node
|
|
||||||
attributes.
|
|
||||||
|
|
||||||
The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app
|
|
||||||
startup, and when generating the OpenAPI schema for the workflow editor.
|
|
||||||
"""
|
|
||||||
|
|
||||||
Input = "input"
|
|
||||||
Output = "output"
|
|
||||||
Internal = "internal"
|
|
||||||
NodeAttribute = "node_attribute"
|
|
||||||
|
|
||||||
|
|
||||||
class UIType(str, Enum, metaclass=MetaEnum):
|
|
||||||
"""
|
|
||||||
Type hints for the UI for situations in which the field type is not enough to infer the correct UI type.
|
|
||||||
|
|
||||||
- Model Fields
|
|
||||||
The most common node-author-facing use will be for model fields. Internally, there is no difference
|
|
||||||
between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the
|
|
||||||
base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that
|
|
||||||
the field is an SDXL main model field.
|
|
||||||
|
|
||||||
- Any Field
|
|
||||||
We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to
|
|
||||||
indicate that the field accepts any type. Use with caution. This cannot be used on outputs.
|
|
||||||
|
|
||||||
- Scheduler Field
|
|
||||||
Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field.
|
|
||||||
|
|
||||||
- Internal Fields
|
|
||||||
Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate
|
|
||||||
handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These
|
|
||||||
should not be used by node authors.
|
|
||||||
|
|
||||||
- DEPRECATED Fields
|
|
||||||
These types are deprecated and should not be used by node authors. A warning will be logged if one is
|
|
||||||
used, and the type will be ignored. They are included here for backwards compatibility.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# region Model Field Types
|
|
||||||
SDXLMainModel = "SDXLMainModelField"
|
|
||||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
|
||||||
ONNXModel = "ONNXModelField"
|
|
||||||
VaeModel = "VAEModelField"
|
|
||||||
LoRAModel = "LoRAModelField"
|
|
||||||
ControlNetModel = "ControlNetModelField"
|
|
||||||
IPAdapterModel = "IPAdapterModelField"
|
|
||||||
# endregion
|
|
||||||
|
|
||||||
# region Misc Field Types
|
|
||||||
Scheduler = "SchedulerField"
|
|
||||||
Any = "AnyField"
|
|
||||||
# endregion
|
|
||||||
|
|
||||||
# region Internal Field Types
|
|
||||||
_Collection = "CollectionField"
|
|
||||||
_CollectionItem = "CollectionItemField"
|
|
||||||
# endregion
|
|
||||||
|
|
||||||
# region DEPRECATED
|
|
||||||
Boolean = "DEPRECATED_Boolean"
|
|
||||||
Color = "DEPRECATED_Color"
|
|
||||||
Conditioning = "DEPRECATED_Conditioning"
|
|
||||||
Control = "DEPRECATED_Control"
|
|
||||||
Float = "DEPRECATED_Float"
|
|
||||||
Image = "DEPRECATED_Image"
|
|
||||||
Integer = "DEPRECATED_Integer"
|
|
||||||
Latents = "DEPRECATED_Latents"
|
|
||||||
String = "DEPRECATED_String"
|
|
||||||
BooleanCollection = "DEPRECATED_BooleanCollection"
|
|
||||||
ColorCollection = "DEPRECATED_ColorCollection"
|
|
||||||
ConditioningCollection = "DEPRECATED_ConditioningCollection"
|
|
||||||
ControlCollection = "DEPRECATED_ControlCollection"
|
|
||||||
FloatCollection = "DEPRECATED_FloatCollection"
|
|
||||||
ImageCollection = "DEPRECATED_ImageCollection"
|
|
||||||
IntegerCollection = "DEPRECATED_IntegerCollection"
|
|
||||||
LatentsCollection = "DEPRECATED_LatentsCollection"
|
|
||||||
StringCollection = "DEPRECATED_StringCollection"
|
|
||||||
BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic"
|
|
||||||
ColorPolymorphic = "DEPRECATED_ColorPolymorphic"
|
|
||||||
ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic"
|
|
||||||
ControlPolymorphic = "DEPRECATED_ControlPolymorphic"
|
|
||||||
FloatPolymorphic = "DEPRECATED_FloatPolymorphic"
|
|
||||||
ImagePolymorphic = "DEPRECATED_ImagePolymorphic"
|
|
||||||
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
|
|
||||||
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
|
|
||||||
StringPolymorphic = "DEPRECATED_StringPolymorphic"
|
|
||||||
MainModel = "DEPRECATED_MainModel"
|
|
||||||
UNet = "DEPRECATED_UNet"
|
|
||||||
Vae = "DEPRECATED_Vae"
|
|
||||||
CLIP = "DEPRECATED_CLIP"
|
|
||||||
Collection = "DEPRECATED_Collection"
|
|
||||||
CollectionItem = "DEPRECATED_CollectionItem"
|
|
||||||
Enum = "DEPRECATED_Enum"
|
|
||||||
WorkflowField = "DEPRECATED_WorkflowField"
|
|
||||||
IsIntermediate = "DEPRECATED_IsIntermediate"
|
|
||||||
BoardField = "DEPRECATED_BoardField"
|
|
||||||
MetadataItem = "DEPRECATED_MetadataItem"
|
|
||||||
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
|
|
||||||
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
|
|
||||||
MetadataDict = "DEPRECATED_MetadataDict"
|
|
||||||
# endregion
|
|
||||||
|
|
||||||
|
|
||||||
class UIComponent(str, Enum, metaclass=MetaEnum):
|
|
||||||
"""
|
|
||||||
The type of UI component to use for a field, used to override the default components, which are
|
|
||||||
inferred from the field type.
|
|
||||||
"""
|
|
||||||
|
|
||||||
None_ = "none"
|
|
||||||
Textarea = "textarea"
|
|
||||||
Slider = "slider"
|
|
||||||
|
|
||||||
|
|
||||||
class InputFieldJSONSchemaExtra(BaseModel):
|
|
||||||
"""
|
|
||||||
Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution,
|
|
||||||
and by the workflow editor during schema parsing and UI rendering.
|
|
||||||
"""
|
|
||||||
|
|
||||||
input: Input
|
|
||||||
orig_required: bool
|
|
||||||
field_kind: FieldKind
|
|
||||||
default: Optional[Any] = None
|
|
||||||
orig_default: Optional[Any] = None
|
|
||||||
ui_hidden: bool = False
|
|
||||||
ui_type: Optional[UIType] = None
|
|
||||||
ui_component: Optional[UIComponent] = None
|
|
||||||
ui_order: Optional[int] = None
|
|
||||||
ui_choice_labels: Optional[dict[str, str]] = None
|
|
||||||
|
|
||||||
model_config = ConfigDict(
|
|
||||||
validate_assignment=True,
|
|
||||||
json_schema_serialization_defaults_required=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OutputFieldJSONSchemaExtra(BaseModel):
|
|
||||||
"""
|
|
||||||
Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor
|
|
||||||
during schema parsing and UI rendering.
|
|
||||||
"""
|
|
||||||
|
|
||||||
field_kind: FieldKind
|
|
||||||
ui_hidden: bool
|
|
||||||
ui_type: Optional[UIType]
|
|
||||||
ui_order: Optional[int]
|
|
||||||
|
|
||||||
model_config = ConfigDict(
|
|
||||||
validate_assignment=True,
|
|
||||||
json_schema_serialization_defaults_required=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def InputField(
|
|
||||||
# copied from pydantic's Field
|
|
||||||
# TODO: Can we support default_factory?
|
|
||||||
default: Any = _Unset,
|
|
||||||
default_factory: Callable[[], Any] | None = _Unset,
|
|
||||||
title: str | None = _Unset,
|
|
||||||
description: str | None = _Unset,
|
|
||||||
pattern: str | None = _Unset,
|
|
||||||
strict: bool | None = _Unset,
|
|
||||||
gt: float | None = _Unset,
|
|
||||||
ge: float | None = _Unset,
|
|
||||||
lt: float | None = _Unset,
|
|
||||||
le: float | None = _Unset,
|
|
||||||
multiple_of: float | None = _Unset,
|
|
||||||
allow_inf_nan: bool | None = _Unset,
|
|
||||||
max_digits: int | None = _Unset,
|
|
||||||
decimal_places: int | None = _Unset,
|
|
||||||
min_length: int | None = _Unset,
|
|
||||||
max_length: int | None = _Unset,
|
|
||||||
# custom
|
|
||||||
input: Input = Input.Any,
|
|
||||||
ui_type: Optional[UIType] = None,
|
|
||||||
ui_component: Optional[UIComponent] = None,
|
|
||||||
ui_hidden: bool = False,
|
|
||||||
ui_order: Optional[int] = None,
|
|
||||||
ui_choice_labels: Optional[dict[str, str]] = None,
|
|
||||||
) -> Any:
|
|
||||||
"""
|
|
||||||
Creates an input field for an invocation.
|
|
||||||
|
|
||||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
|
|
||||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
|
||||||
|
|
||||||
:param Input input: [Input.Any] The kind of input this field requires. \
|
|
||||||
`Input.Direct` means a value must be provided on instantiation. \
|
|
||||||
`Input.Connection` means the value must be provided by a connection. \
|
|
||||||
`Input.Any` means either will do.
|
|
||||||
|
|
||||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
|
||||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
|
||||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
|
||||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
|
||||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
|
||||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
|
||||||
|
|
||||||
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
|
|
||||||
The UI will always render a suitable component, but sometimes you want something different than the default. \
|
|
||||||
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
|
|
||||||
For this case, you could provide `UIComponent.Textarea`.
|
|
||||||
|
|
||||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
|
||||||
|
|
||||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
|
|
||||||
|
|
||||||
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
|
|
||||||
"""
|
|
||||||
|
|
||||||
json_schema_extra_ = InputFieldJSONSchemaExtra(
|
|
||||||
input=input,
|
|
||||||
ui_type=ui_type,
|
|
||||||
ui_component=ui_component,
|
|
||||||
ui_hidden=ui_hidden,
|
|
||||||
ui_order=ui_order,
|
|
||||||
ui_choice_labels=ui_choice_labels,
|
|
||||||
field_kind=FieldKind.Input,
|
|
||||||
orig_required=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
"""
|
|
||||||
There is a conflict between the typing of invocation definitions and the typing of an invocation's
|
|
||||||
`invoke()` function.
|
|
||||||
|
|
||||||
On instantiation of a node, the invocation definition is used to create the python class. At this time,
|
|
||||||
any number of fields may be optional, because they may be provided by connections.
|
|
||||||
|
|
||||||
On calling of `invoke()`, however, those fields may be required.
|
|
||||||
|
|
||||||
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
|
|
||||||
|
|
||||||
`image` is required during the call to `invoke()`, but when the python class is instantiated,
|
|
||||||
the field may not be present. This is fine, because that image field will be provided by a
|
|
||||||
connection from an ancestor node, which outputs an image.
|
|
||||||
|
|
||||||
This means we want to type the `image` field as optional for the node class definition, but required
|
|
||||||
for the `invoke()` function.
|
|
||||||
|
|
||||||
If we use `typing.Optional` in the node class definition, the field will be typed as optional in the
|
|
||||||
`invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or
|
|
||||||
any static type analysis tools will complain.
|
|
||||||
|
|
||||||
To get around this, in node class definitions, we type all fields correctly for the `invoke()` function,
|
|
||||||
but secretly make them optional in `InputField()`. We also store the original required bool and/or default
|
|
||||||
value. When we call `invoke()`, we use this stored information to do an additional check on the class.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if default_factory is not _Unset and default_factory is not None:
|
|
||||||
default = default_factory()
|
|
||||||
logger.warn('"default_factory" is not supported, calling it now to set "default"')
|
|
||||||
|
|
||||||
# These are the args we may wish pass to the pydantic `Field()` function
|
|
||||||
field_args = {
|
|
||||||
"default": default,
|
|
||||||
"title": title,
|
|
||||||
"description": description,
|
|
||||||
"pattern": pattern,
|
|
||||||
"strict": strict,
|
|
||||||
"gt": gt,
|
|
||||||
"ge": ge,
|
|
||||||
"lt": lt,
|
|
||||||
"le": le,
|
|
||||||
"multiple_of": multiple_of,
|
|
||||||
"allow_inf_nan": allow_inf_nan,
|
|
||||||
"max_digits": max_digits,
|
|
||||||
"decimal_places": decimal_places,
|
|
||||||
"min_length": min_length,
|
|
||||||
"max_length": max_length,
|
|
||||||
}
|
|
||||||
|
|
||||||
# We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected
|
|
||||||
provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}
|
|
||||||
|
|
||||||
# Because we are manually making fields optional, we need to store the original required bool for reference later
|
|
||||||
json_schema_extra_.orig_required = default is PydanticUndefined
|
|
||||||
|
|
||||||
# Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
|
||||||
if input is Input.Any or input is Input.Connection:
|
|
||||||
default_ = None if default is PydanticUndefined else default
|
|
||||||
provided_args.update({"default": default_})
|
|
||||||
if default is not PydanticUndefined:
|
|
||||||
# Before invoking, we'll check for the original default value and set it on the field if the field has no value
|
|
||||||
json_schema_extra_.default = default
|
|
||||||
json_schema_extra_.orig_default = default
|
|
||||||
elif default is not PydanticUndefined:
|
|
||||||
default_ = default
|
|
||||||
provided_args.update({"default": default_})
|
|
||||||
json_schema_extra_.orig_default = default_
|
|
||||||
|
|
||||||
return Field(
|
|
||||||
**provided_args,
|
|
||||||
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def OutputField(
|
|
||||||
# copied from pydantic's Field
|
|
||||||
default: Any = _Unset,
|
|
||||||
title: str | None = _Unset,
|
|
||||||
description: str | None = _Unset,
|
|
||||||
pattern: str | None = _Unset,
|
|
||||||
strict: bool | None = _Unset,
|
|
||||||
gt: float | None = _Unset,
|
|
||||||
ge: float | None = _Unset,
|
|
||||||
lt: float | None = _Unset,
|
|
||||||
le: float | None = _Unset,
|
|
||||||
multiple_of: float | None = _Unset,
|
|
||||||
allow_inf_nan: bool | None = _Unset,
|
|
||||||
max_digits: int | None = _Unset,
|
|
||||||
decimal_places: int | None = _Unset,
|
|
||||||
min_length: int | None = _Unset,
|
|
||||||
max_length: int | None = _Unset,
|
|
||||||
# custom
|
|
||||||
ui_type: Optional[UIType] = None,
|
|
||||||
ui_hidden: bool = False,
|
|
||||||
ui_order: Optional[int] = None,
|
|
||||||
) -> Any:
|
|
||||||
"""
|
|
||||||
Creates an output field for an invocation output.
|
|
||||||
|
|
||||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
|
||||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
|
||||||
|
|
||||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
|
||||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
|
||||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
|
||||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
|
||||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
|
||||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
|
||||||
|
|
||||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
|
||||||
|
|
||||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
|
||||||
"""
|
|
||||||
return Field(
|
|
||||||
default=default,
|
|
||||||
title=title,
|
|
||||||
description=description,
|
|
||||||
pattern=pattern,
|
|
||||||
strict=strict,
|
|
||||||
gt=gt,
|
|
||||||
ge=ge,
|
|
||||||
lt=lt,
|
|
||||||
le=le,
|
|
||||||
multiple_of=multiple_of,
|
|
||||||
allow_inf_nan=allow_inf_nan,
|
|
||||||
max_digits=max_digits,
|
|
||||||
decimal_places=decimal_places,
|
|
||||||
min_length=min_length,
|
|
||||||
max_length=max_length,
|
|
||||||
json_schema_extra=OutputFieldJSONSchemaExtra(
|
|
||||||
ui_type=ui_type,
|
|
||||||
ui_hidden=ui_hidden,
|
|
||||||
ui_order=ui_order,
|
|
||||||
field_kind=FieldKind.Output,
|
|
||||||
).model_dump(exclude_none=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UIConfigBase(BaseModel):
|
class UIConfigBase(BaseModel):
|
||||||
"""
|
"""
|
||||||
Provides additional node configuration to the UI.
|
Provides additional node configuration to the UI.
|
||||||
@ -460,33 +76,6 @@ class UIConfigBase(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InvocationContext:
|
|
||||||
"""Initialized and provided to on execution of invocations."""
|
|
||||||
|
|
||||||
services: InvocationServices
|
|
||||||
graph_execution_state_id: str
|
|
||||||
queue_id: str
|
|
||||||
queue_item_id: int
|
|
||||||
queue_batch_id: str
|
|
||||||
workflow: Optional[WorkflowWithoutID]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
services: InvocationServices,
|
|
||||||
queue_id: str,
|
|
||||||
queue_item_id: int,
|
|
||||||
queue_batch_id: str,
|
|
||||||
graph_execution_state_id: str,
|
|
||||||
workflow: Optional[WorkflowWithoutID],
|
|
||||||
):
|
|
||||||
self.services = services
|
|
||||||
self.graph_execution_state_id = graph_execution_state_id
|
|
||||||
self.queue_id = queue_id
|
|
||||||
self.queue_item_id = queue_item_id
|
|
||||||
self.queue_batch_id = queue_batch_id
|
|
||||||
self.workflow = workflow
|
|
||||||
|
|
||||||
|
|
||||||
class BaseInvocationOutput(BaseModel):
|
class BaseInvocationOutput(BaseModel):
|
||||||
"""
|
"""
|
||||||
Base class for all invocation outputs.
|
Base class for all invocation outputs.
|
||||||
@ -632,7 +221,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
"""Invoke with provided context and return outputs."""
|
"""Invoke with provided context and return outputs."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
def invoke_internal(self, context: InvocationContext, services: "InvocationServices") -> BaseInvocationOutput:
|
||||||
"""
|
"""
|
||||||
Internal invoke method, calls `invoke()` after some prep.
|
Internal invoke method, calls `invoke()` after some prep.
|
||||||
Handles optional fields that are required to call `invoke()` and invocation cache.
|
Handles optional fields that are required to call `invoke()` and invocation cache.
|
||||||
@ -657,23 +246,23 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
raise MissingInputException(self.model_fields["type"].default, field_name)
|
raise MissingInputException(self.model_fields["type"].default, field_name)
|
||||||
|
|
||||||
# skip node cache codepath if it's disabled
|
# skip node cache codepath if it's disabled
|
||||||
if context.services.configuration.node_cache_size == 0:
|
if services.configuration.node_cache_size == 0:
|
||||||
return self.invoke(context)
|
return self.invoke(context)
|
||||||
|
|
||||||
output: BaseInvocationOutput
|
output: BaseInvocationOutput
|
||||||
if self.use_cache:
|
if self.use_cache:
|
||||||
key = context.services.invocation_cache.create_key(self)
|
key = services.invocation_cache.create_key(self)
|
||||||
cached_value = context.services.invocation_cache.get(key)
|
cached_value = services.invocation_cache.get(key)
|
||||||
if cached_value is None:
|
if cached_value is None:
|
||||||
context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
|
services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
|
||||||
output = self.invoke(context)
|
output = self.invoke(context)
|
||||||
context.services.invocation_cache.save(key, output)
|
services.invocation_cache.save(key, output)
|
||||||
return output
|
return output
|
||||||
else:
|
else:
|
||||||
context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
|
services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
|
||||||
return cached_value
|
return cached_value
|
||||||
else:
|
else:
|
||||||
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
|
services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
|
||||||
return self.invoke(context)
|
return self.invoke(context)
|
||||||
|
|
||||||
id: str = Field(
|
id: str = Field(
|
||||||
@ -714,9 +303,7 @@ RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
|
|||||||
"workflow",
|
"workflow",
|
||||||
}
|
}
|
||||||
|
|
||||||
RESERVED_INPUT_FIELD_NAMES = {
|
RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"}
|
||||||
"metadata",
|
|
||||||
}
|
|
||||||
|
|
||||||
RESERVED_OUTPUT_FIELD_NAMES = {"type"}
|
RESERVED_OUTPUT_FIELD_NAMES = {"type"}
|
||||||
|
|
||||||
@ -926,37 +513,3 @@ def invocation_output(
|
|||||||
return cls
|
return cls
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class MetadataField(RootModel):
|
|
||||||
"""
|
|
||||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
|
||||||
Metadata is stored without a strict schema.
|
|
||||||
"""
|
|
||||||
|
|
||||||
root: dict[str, Any] = Field(description="The metadata")
|
|
||||||
|
|
||||||
|
|
||||||
MetadataFieldValidator = TypeAdapter(MetadataField)
|
|
||||||
|
|
||||||
|
|
||||||
class WithMetadata(BaseModel):
|
|
||||||
metadata: Optional[MetadataField] = Field(
|
|
||||||
default=None,
|
|
||||||
description=FieldDescriptions.metadata,
|
|
||||||
json_schema_extra=InputFieldJSONSchemaExtra(
|
|
||||||
field_kind=FieldKind.Internal,
|
|
||||||
input=Input.Connection,
|
|
||||||
orig_required=False,
|
|
||||||
).model_dump(exclude_none=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class WithWorkflow:
|
|
||||||
workflow = None
|
|
||||||
|
|
||||||
def __init_subclass__(cls) -> None:
|
|
||||||
logger.warn(
|
|
||||||
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
|
|
||||||
)
|
|
||||||
super().__init_subclass__()
|
|
||||||
|
@ -5,9 +5,11 @@ import numpy as np
|
|||||||
from pydantic import ValidationInfo, field_validator
|
from pydantic import ValidationInfo, field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.misc import SEED_MAX
|
from invokeai.app.util.misc import SEED_MAX
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, invocation
|
||||||
|
from .fields import InputField
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
|
@ -1,14 +1,21 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compel import Compel, ReturnedEmbeddingsType
|
from compel import Compel, ReturnedEmbeddingsType
|
||||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
from invokeai.app.invocations.fields import (
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
OutputField,
|
||||||
|
UIComponent,
|
||||||
|
)
|
||||||
|
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
BasicConditioningInfo,
|
BasicConditioningInfo,
|
||||||
|
ConditioningFieldData,
|
||||||
ExtraConditioningInfo,
|
ExtraConditioningInfo,
|
||||||
SDXLConditioningInfo,
|
SDXLConditioningInfo,
|
||||||
)
|
)
|
||||||
@ -20,21 +27,12 @@ from ..util.ti_utils import extract_ti_triggers_from_prompt
|
|||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
Input,
|
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
OutputField,
|
|
||||||
UIComponent,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from .model import ClipField
|
from .model import ClipField
|
||||||
|
|
||||||
|
# unconditioned: Optional[torch.Tensor]
|
||||||
@dataclass
|
|
||||||
class ConditioningFieldData:
|
|
||||||
conditionings: List[BasicConditioningInfo]
|
|
||||||
# unconditioned: Optional[torch.Tensor]
|
|
||||||
|
|
||||||
|
|
||||||
# class ConditioningAlgo(str, Enum):
|
# class ConditioningAlgo(str, Enum):
|
||||||
@ -48,7 +46,7 @@ class ConditioningFieldData:
|
|||||||
title="Prompt",
|
title="Prompt",
|
||||||
tags=["prompt", "compel"],
|
tags=["prompt", "compel"],
|
||||||
category="conditioning",
|
category="conditioning",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class CompelInvocation(BaseInvocation):
|
class CompelInvocation(BaseInvocation):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
@ -66,25 +64,17 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
||||||
**self.clip.tokenizer.model_dump(),
|
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
|
||||||
**self.clip.text_encoder.model_dump(),
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.clip.loras:
|
for lora in self.clip.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||||
**lora.model_dump(exclude={"weight"}), context=context
|
|
||||||
)
|
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
ti_list = []
|
ti_list = []
|
||||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||||
@ -93,11 +83,10 @@ class CompelInvocation(BaseInvocation):
|
|||||||
ti_list.append(
|
ti_list.append(
|
||||||
(
|
(
|
||||||
name,
|
name,
|
||||||
context.services.model_manager.get_model(
|
context.models.load(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=self.clip.text_encoder.base_model,
|
base_model=self.clip.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
context=context,
|
|
||||||
).context.model,
|
).context.model,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -128,7 +117,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||||
|
|
||||||
if context.services.configuration.log_tokenization:
|
if context.config.get().log_tokenization:
|
||||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||||
|
|
||||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||||
@ -149,14 +138,9 @@ class CompelInvocation(BaseInvocation):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
conditioning_name = context.conditioning.save(conditioning_data)
|
||||||
context.services.latents.save(conditioning_name, conditioning_data)
|
|
||||||
|
|
||||||
return ConditioningOutput(
|
return ConditioningOutput.build(conditioning_name)
|
||||||
conditioning=ConditioningField(
|
|
||||||
conditioning_name=conditioning_name,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SDXLPromptInvocationBase:
|
class SDXLPromptInvocationBase:
|
||||||
@ -169,14 +153,8 @@ class SDXLPromptInvocationBase:
|
|||||||
lora_prefix: str,
|
lora_prefix: str,
|
||||||
zero_on_empty: bool,
|
zero_on_empty: bool,
|
||||||
):
|
):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
|
||||||
**clip_field.tokenizer.model_dump(),
|
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
|
||||||
**clip_field.text_encoder.model_dump(),
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
# return zero on empty
|
# return zero on empty
|
||||||
if prompt == "" and zero_on_empty:
|
if prompt == "" and zero_on_empty:
|
||||||
@ -200,14 +178,12 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||||
**lora.model_dump(exclude={"weight"}), context=context
|
|
||||||
)
|
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
ti_list = []
|
ti_list = []
|
||||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||||
@ -216,11 +192,10 @@ class SDXLPromptInvocationBase:
|
|||||||
ti_list.append(
|
ti_list.append(
|
||||||
(
|
(
|
||||||
name,
|
name,
|
||||||
context.services.model_manager.get_model(
|
context.models.load(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=clip_field.text_encoder.base_model,
|
base_model=clip_field.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
context=context,
|
|
||||||
).context.model,
|
).context.model,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -253,7 +228,7 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
conjunction = Compel.parse_prompt_string(prompt)
|
conjunction = Compel.parse_prompt_string(prompt)
|
||||||
|
|
||||||
if context.services.configuration.log_tokenization:
|
if context.config.get().log_tokenization:
|
||||||
# TODO: better logging for and syntax
|
# TODO: better logging for and syntax
|
||||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||||
|
|
||||||
@ -286,7 +261,7 @@ class SDXLPromptInvocationBase:
|
|||||||
title="SDXL Prompt",
|
title="SDXL Prompt",
|
||||||
tags=["sdxl", "compel", "prompt"],
|
tags=["sdxl", "compel", "prompt"],
|
||||||
category="conditioning",
|
category="conditioning",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
@ -368,14 +343,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
conditioning_name = context.conditioning.save(conditioning_data)
|
||||||
context.services.latents.save(conditioning_name, conditioning_data)
|
|
||||||
|
|
||||||
return ConditioningOutput(
|
return ConditioningOutput.build(conditioning_name)
|
||||||
conditioning=ConditioningField(
|
|
||||||
conditioning_name=conditioning_name,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -383,7 +353,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
title="SDXL Refiner Prompt",
|
title="SDXL Refiner Prompt",
|
||||||
tags=["sdxl", "compel", "prompt"],
|
tags=["sdxl", "compel", "prompt"],
|
||||||
category="conditioning",
|
category="conditioning",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
@ -421,14 +391,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
conditioning_name = context.conditioning.save(conditioning_data)
|
||||||
context.services.latents.save(conditioning_name, conditioning_data)
|
|
||||||
|
|
||||||
return ConditioningOutput(
|
return ConditioningOutput.build(conditioning_name)
|
||||||
conditioning=ConditioningField(
|
|
||||||
conditioning_name=conditioning_name,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("clip_skip_output")
|
@invocation_output("clip_skip_output")
|
||||||
|
14
invokeai/app/invocations/constants.py
Normal file
14
invokeai/app/invocations/constants.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
|
|
||||||
|
LATENT_SCALE_FACTOR = 8
|
||||||
|
"""
|
||||||
|
HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
|
||||||
|
be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale
|
||||||
|
factor is hard-coded to a literal '8' rather than using this constant.
|
||||||
|
The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
|
||||||
|
"""A literal type representing the valid scheduler names."""
|
@ -25,22 +25,25 @@ from controlnet_aux.util import HWC3, ade_palette
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.fields import (
|
||||||
|
FieldDescriptions,
|
||||||
|
ImageField,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
OutputField,
|
||||||
|
WithBoard,
|
||||||
|
WithMetadata,
|
||||||
|
)
|
||||||
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
|
||||||
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
|
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
|
||||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
||||||
|
from invokeai.backend.model_management.models.base import BaseModelType
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
Input,
|
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
OutputField,
|
|
||||||
WithMetadata,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
@ -140,7 +143,7 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
||||||
class ImageProcessorInvocation(BaseInvocation, WithMetadata):
|
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Base class for invocations that preprocess images for ControlNet"""
|
"""Base class for invocations that preprocess images for ControlNet"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to process")
|
image: ImageField = InputField(description="The image to process")
|
||||||
@ -150,22 +153,13 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
raw_image = context.services.images.get_pil_image(self.image.image_name)
|
raw_image = context.images.get_pil(self.image.image_name)
|
||||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||||
processed_image = self.run_processor(raw_image)
|
processed_image = self.run_processor(raw_image)
|
||||||
|
|
||||||
# currently can't see processed image in node UI without a showImage node,
|
# currently can't see processed image in node UI without a showImage node,
|
||||||
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=processed_image)
|
||||||
image=processed_image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.CONTROL,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
node_id=self.id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
metadata=self.metadata,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
"""Builds an ImageOutput and its ImageField"""
|
"""Builds an ImageOutput and its ImageField"""
|
||||||
processed_image_field = ImageField(image_name=image_dto.image_name)
|
processed_image_field = ImageField(image_name=image_dto.image_name)
|
||||||
@ -184,7 +178,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata):
|
|||||||
title="Canny Processor",
|
title="Canny Processor",
|
||||||
tags=["controlnet", "canny"],
|
tags=["controlnet", "canny"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Canny edge detection for ControlNet"""
|
"""Canny edge detection for ControlNet"""
|
||||||
@ -207,7 +201,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="HED (softedge) Processor",
|
title="HED (softedge) Processor",
|
||||||
tags=["controlnet", "hed", "softedge"],
|
tags=["controlnet", "hed", "softedge"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies HED edge detection to image"""
|
"""Applies HED edge detection to image"""
|
||||||
@ -236,7 +230,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Lineart Processor",
|
title="Lineart Processor",
|
||||||
tags=["controlnet", "lineart"],
|
tags=["controlnet", "lineart"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies line art processing to image"""
|
"""Applies line art processing to image"""
|
||||||
@ -258,7 +252,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Lineart Anime Processor",
|
title="Lineart Anime Processor",
|
||||||
tags=["controlnet", "lineart", "anime"],
|
tags=["controlnet", "lineart", "anime"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies line art anime processing to image"""
|
"""Applies line art anime processing to image"""
|
||||||
@ -281,7 +275,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Midas Depth Processor",
|
title="Midas Depth Processor",
|
||||||
tags=["controlnet", "midas"],
|
tags=["controlnet", "midas"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Midas depth processing to image"""
|
"""Applies Midas depth processing to image"""
|
||||||
@ -308,7 +302,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Normal BAE Processor",
|
title="Normal BAE Processor",
|
||||||
tags=["controlnet"],
|
tags=["controlnet"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies NormalBae processing to image"""
|
"""Applies NormalBae processing to image"""
|
||||||
@ -325,7 +319,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0"
|
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.1"
|
||||||
)
|
)
|
||||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies MLSD processing to image"""
|
"""Applies MLSD processing to image"""
|
||||||
@ -348,7 +342,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0"
|
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.1"
|
||||||
)
|
)
|
||||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies PIDI processing to image"""
|
"""Applies PIDI processing to image"""
|
||||||
@ -375,7 +369,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Content Shuffle Processor",
|
title="Content Shuffle Processor",
|
||||||
tags=["controlnet", "contentshuffle"],
|
tags=["controlnet", "contentshuffle"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies content shuffle processing to image"""
|
"""Applies content shuffle processing to image"""
|
||||||
@ -405,7 +399,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Zoe (Depth) Processor",
|
title="Zoe (Depth) Processor",
|
||||||
tags=["controlnet", "zoe", "depth"],
|
tags=["controlnet", "zoe", "depth"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Zoe depth processing to image"""
|
"""Applies Zoe depth processing to image"""
|
||||||
@ -421,7 +415,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Mediapipe Face Processor",
|
title="Mediapipe Face Processor",
|
||||||
tags=["controlnet", "mediapipe", "face"],
|
tags=["controlnet", "mediapipe", "face"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies mediapipe face processing to image"""
|
"""Applies mediapipe face processing to image"""
|
||||||
@ -444,7 +438,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Leres (Depth) Processor",
|
title="Leres (Depth) Processor",
|
||||||
tags=["controlnet", "leres", "depth"],
|
tags=["controlnet", "leres", "depth"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies leres processing to image"""
|
"""Applies leres processing to image"""
|
||||||
@ -473,7 +467,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Tile Resample Processor",
|
title="Tile Resample Processor",
|
||||||
tags=["controlnet", "tile"],
|
tags=["controlnet", "tile"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Tile resampler processor"""
|
"""Tile resampler processor"""
|
||||||
@ -513,7 +507,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
|||||||
title="Segment Anything Processor",
|
title="Segment Anything Processor",
|
||||||
tags=["controlnet", "segmentanything"],
|
tags=["controlnet", "segmentanything"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies segment anything processing to image"""
|
"""Applies segment anything processing to image"""
|
||||||
@ -555,7 +549,7 @@ class SamDetectorReproducibleColors(SamDetector):
|
|||||||
title="Color Map Processor",
|
title="Color Map Processor",
|
||||||
tags=["controlnet"],
|
tags=["controlnet"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Generates a color map from the provided image"""
|
"""Generates a color map from the provided image"""
|
||||||
|
@ -5,22 +5,24 @@ import cv2 as cv
|
|||||||
import numpy
|
import numpy
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.fields import ImageField
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
|
from .baseinvocation import BaseInvocation, invocation
|
||||||
|
from .fields import InputField, WithBoard, WithMetadata
|
||||||
|
|
||||||
|
|
||||||
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0")
|
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.1")
|
||||||
class CvInpaintInvocation(BaseInvocation, WithMetadata):
|
class CvInpaintInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Simple inpaint using opencv."""
|
"""Simple inpaint using opencv."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to inpaint")
|
image: ImageField = InputField(description="The image to inpaint")
|
||||||
mask: ImageField = InputField(description="The mask to use when inpainting")
|
mask: ImageField = InputField(description="The mask to use when inpainting")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
mask = context.services.images.get_pil_image(self.mask.image_name)
|
mask = context.images.get_pil(self.mask.image_name)
|
||||||
|
|
||||||
# Convert to cv image/mask
|
# Convert to cv image/mask
|
||||||
# TODO: consider making these utility functions
|
# TODO: consider making these utility functions
|
||||||
@ -34,18 +36,6 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata):
|
|||||||
# TODO: consider making a utility function
|
# TODO: consider making a utility function
|
||||||
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
|
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=image_inpainted)
|
||||||
image=image_inpainted,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput.build(image_dto)
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
@ -13,15 +13,13 @@ from pydantic import field_validator
|
|||||||
import invokeai.assets.fonts as font_assets
|
import invokeai.assets.fonts as font_assets
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
OutputField,
|
|
||||||
WithMetadata,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithBoard, WithMetadata
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
|
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("face_mask_output")
|
@invocation_output("face_mask_output")
|
||||||
@ -306,37 +304,37 @@ def extract_face(
|
|||||||
|
|
||||||
# Adjust the crop boundaries to stay within the original image's dimensions
|
# Adjust the crop boundaries to stay within the original image's dimensions
|
||||||
if x_min < 0:
|
if x_min < 0:
|
||||||
context.services.logger.warning("FaceTools --> -X-axis padding reached image edge.")
|
context.logger.warning("FaceTools --> -X-axis padding reached image edge.")
|
||||||
x_max -= x_min
|
x_max -= x_min
|
||||||
x_min = 0
|
x_min = 0
|
||||||
elif x_max > mask.width:
|
elif x_max > mask.width:
|
||||||
context.services.logger.warning("FaceTools --> +X-axis padding reached image edge.")
|
context.logger.warning("FaceTools --> +X-axis padding reached image edge.")
|
||||||
x_min -= x_max - mask.width
|
x_min -= x_max - mask.width
|
||||||
x_max = mask.width
|
x_max = mask.width
|
||||||
|
|
||||||
if y_min < 0:
|
if y_min < 0:
|
||||||
context.services.logger.warning("FaceTools --> +Y-axis padding reached image edge.")
|
context.logger.warning("FaceTools --> +Y-axis padding reached image edge.")
|
||||||
y_max -= y_min
|
y_max -= y_min
|
||||||
y_min = 0
|
y_min = 0
|
||||||
elif y_max > mask.height:
|
elif y_max > mask.height:
|
||||||
context.services.logger.warning("FaceTools --> -Y-axis padding reached image edge.")
|
context.logger.warning("FaceTools --> -Y-axis padding reached image edge.")
|
||||||
y_min -= y_max - mask.height
|
y_min -= y_max - mask.height
|
||||||
y_max = mask.height
|
y_max = mask.height
|
||||||
|
|
||||||
# Ensure the crop is square and adjust the boundaries if needed
|
# Ensure the crop is square and adjust the boundaries if needed
|
||||||
if x_max - x_min != crop_size:
|
if x_max - x_min != crop_size:
|
||||||
context.services.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.")
|
context.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.")
|
||||||
diff = crop_size - (x_max - x_min)
|
diff = crop_size - (x_max - x_min)
|
||||||
x_min -= diff // 2
|
x_min -= diff // 2
|
||||||
x_max += diff - diff // 2
|
x_max += diff - diff // 2
|
||||||
|
|
||||||
if y_max - y_min != crop_size:
|
if y_max - y_min != crop_size:
|
||||||
context.services.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.")
|
context.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.")
|
||||||
diff = crop_size - (y_max - y_min)
|
diff = crop_size - (y_max - y_min)
|
||||||
y_min -= diff // 2
|
y_min -= diff // 2
|
||||||
y_max += diff - diff // 2
|
y_max += diff - diff // 2
|
||||||
|
|
||||||
context.services.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}")
|
context.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}")
|
||||||
|
|
||||||
# Crop the output image to the specified size with the center of the face mesh as the center.
|
# Crop the output image to the specified size with the center of the face mesh as the center.
|
||||||
mask = mask.crop((x_min, y_min, x_max, y_max))
|
mask = mask.crop((x_min, y_min, x_max, y_max))
|
||||||
@ -368,7 +366,7 @@ def get_faces_list(
|
|||||||
|
|
||||||
# Generate the face box mask and get the center of the face.
|
# Generate the face box mask and get the center of the face.
|
||||||
if not should_chunk:
|
if not should_chunk:
|
||||||
context.services.logger.info("FaceTools --> Attempting full image face detection.")
|
context.logger.info("FaceTools --> Attempting full image face detection.")
|
||||||
result = generate_face_box_mask(
|
result = generate_face_box_mask(
|
||||||
context=context,
|
context=context,
|
||||||
minimum_confidence=minimum_confidence,
|
minimum_confidence=minimum_confidence,
|
||||||
@ -380,7 +378,7 @@ def get_faces_list(
|
|||||||
draw_mesh=draw_mesh,
|
draw_mesh=draw_mesh,
|
||||||
)
|
)
|
||||||
if should_chunk or len(result) == 0:
|
if should_chunk or len(result) == 0:
|
||||||
context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
|
context.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
image_chunks = []
|
image_chunks = []
|
||||||
x_offsets = []
|
x_offsets = []
|
||||||
@ -399,7 +397,7 @@ def get_faces_list(
|
|||||||
x_offsets.append(x)
|
x_offsets.append(x)
|
||||||
y_offsets.append(0)
|
y_offsets.append(0)
|
||||||
fx += increment
|
fx += increment
|
||||||
context.services.logger.info(f"FaceTools --> Chunk starting at x = {x}")
|
context.logger.info(f"FaceTools --> Chunk starting at x = {x}")
|
||||||
elif height > width:
|
elif height > width:
|
||||||
# Portrait - slice the image vertically
|
# Portrait - slice the image vertically
|
||||||
fy = 0.0
|
fy = 0.0
|
||||||
@ -411,10 +409,10 @@ def get_faces_list(
|
|||||||
x_offsets.append(0)
|
x_offsets.append(0)
|
||||||
y_offsets.append(y)
|
y_offsets.append(y)
|
||||||
fy += increment
|
fy += increment
|
||||||
context.services.logger.info(f"FaceTools --> Chunk starting at y = {y}")
|
context.logger.info(f"FaceTools --> Chunk starting at y = {y}")
|
||||||
|
|
||||||
for idx in range(len(image_chunks)):
|
for idx in range(len(image_chunks)):
|
||||||
context.services.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}")
|
context.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}")
|
||||||
result = result + generate_face_box_mask(
|
result = result + generate_face_box_mask(
|
||||||
context=context,
|
context=context,
|
||||||
minimum_confidence=minimum_confidence,
|
minimum_confidence=minimum_confidence,
|
||||||
@ -428,7 +426,7 @@ def get_faces_list(
|
|||||||
|
|
||||||
if len(result) == 0:
|
if len(result) == 0:
|
||||||
# Give up
|
# Give up
|
||||||
context.services.logger.warning(
|
context.logger.warning(
|
||||||
"FaceTools --> No face detected in chunked input image. Passing through original image."
|
"FaceTools --> No face detected in chunked input image. Passing through original image."
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -437,7 +435,7 @@ def get_faces_list(
|
|||||||
return all_faces
|
return all_faces
|
||||||
|
|
||||||
|
|
||||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.0")
|
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.1")
|
||||||
class FaceOffInvocation(BaseInvocation, WithMetadata):
|
class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||||
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
||||||
|
|
||||||
@ -470,11 +468,11 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if len(all_faces) == 0:
|
if len(all_faces) == 0:
|
||||||
context.services.logger.warning("FaceOff --> No faces detected. Passing through original image.")
|
context.logger.warning("FaceOff --> No faces detected. Passing through original image.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.face_id > len(all_faces) - 1:
|
if self.face_id > len(all_faces) - 1:
|
||||||
context.services.logger.warning(
|
context.logger.warning(
|
||||||
f"FaceOff --> Face ID {self.face_id} is outside of the number of faces detected ({len(all_faces)}). Passing through original image."
|
f"FaceOff --> Face ID {self.face_id} is outside of the number of faces detected ({len(all_faces)}). Passing through original image."
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
@ -486,7 +484,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
|||||||
return face_data
|
return face_data
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FaceOffOutput:
|
def invoke(self, context: InvocationContext) -> FaceOffOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
result = self.faceoff(context=context, image=image)
|
result = self.faceoff(context=context, image=image)
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
@ -500,24 +498,9 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
|||||||
x = result["x_min"]
|
x = result["x_min"]
|
||||||
y = result["y_min"]
|
y = result["y_min"]
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=result_image)
|
||||||
image=result_image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
mask_dto = context.services.images.create(
|
mask_dto = context.images.save(image=result_mask, image_category=ImageCategory.MASK)
|
||||||
image=result_mask,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.MASK,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
)
|
|
||||||
|
|
||||||
output = FaceOffOutput(
|
output = FaceOffOutput(
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
@ -531,7 +514,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.0")
|
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.1")
|
||||||
class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||||
"""Face mask creation using mediapipe face detection"""
|
"""Face mask creation using mediapipe face detection"""
|
||||||
|
|
||||||
@ -580,7 +563,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
|||||||
|
|
||||||
if len(intersected_face_ids) == 0:
|
if len(intersected_face_ids) == 0:
|
||||||
id_range_str = ",".join([str(id) for id in id_range])
|
id_range_str = ",".join([str(id) for id in id_range])
|
||||||
context.services.logger.warning(
|
context.logger.warning(
|
||||||
f"Face IDs must be in range of detected faces - requested {self.face_ids}, detected {id_range_str}. Passing through original image."
|
f"Face IDs must be in range of detected faces - requested {self.face_ids}, detected {id_range_str}. Passing through original image."
|
||||||
)
|
)
|
||||||
return FaceMaskResult(
|
return FaceMaskResult(
|
||||||
@ -616,27 +599,12 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FaceMaskOutput:
|
def invoke(self, context: InvocationContext) -> FaceMaskOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
result = self.facemask(context=context, image=image)
|
result = self.facemask(context=context, image=image)
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=result["image"])
|
||||||
image=result["image"],
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
mask_dto = context.services.images.create(
|
mask_dto = context.images.save(image=result["mask"], image_category=ImageCategory.MASK)
|
||||||
image=result["mask"],
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.MASK,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
)
|
|
||||||
|
|
||||||
output = FaceMaskOutput(
|
output = FaceMaskOutput(
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
@ -649,9 +617,9 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.0"
|
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.1"
|
||||||
)
|
)
|
||||||
class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
|
class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="Image to face detect")
|
image: ImageField = InputField(description="Image to face detect")
|
||||||
@ -705,21 +673,9 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
result_image = self.faceidentifier(context=context, image=image)
|
result_image = self.faceidentifier(context=context, image=image)
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=result_image)
|
||||||
image=result_image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput.build(image_dto)
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
565
invokeai/app/invocations/fields.py
Normal file
565
invokeai/app/invocations/fields.py
Normal file
@ -0,0 +1,565 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Optional, Tuple
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
|
||||||
|
from pydantic.fields import _Unset
|
||||||
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
|
logger = InvokeAILogger.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class UIType(str, Enum, metaclass=MetaEnum):
|
||||||
|
"""
|
||||||
|
Type hints for the UI for situations in which the field type is not enough to infer the correct UI type.
|
||||||
|
|
||||||
|
- Model Fields
|
||||||
|
The most common node-author-facing use will be for model fields. Internally, there is no difference
|
||||||
|
between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the
|
||||||
|
base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that
|
||||||
|
the field is an SDXL main model field.
|
||||||
|
|
||||||
|
- Any Field
|
||||||
|
We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to
|
||||||
|
indicate that the field accepts any type. Use with caution. This cannot be used on outputs.
|
||||||
|
|
||||||
|
- Scheduler Field
|
||||||
|
Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field.
|
||||||
|
|
||||||
|
- Internal Fields
|
||||||
|
Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate
|
||||||
|
handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These
|
||||||
|
should not be used by node authors.
|
||||||
|
|
||||||
|
- DEPRECATED Fields
|
||||||
|
These types are deprecated and should not be used by node authors. A warning will be logged if one is
|
||||||
|
used, and the type will be ignored. They are included here for backwards compatibility.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# region Model Field Types
|
||||||
|
SDXLMainModel = "SDXLMainModelField"
|
||||||
|
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||||
|
ONNXModel = "ONNXModelField"
|
||||||
|
VaeModel = "VAEModelField"
|
||||||
|
LoRAModel = "LoRAModelField"
|
||||||
|
ControlNetModel = "ControlNetModelField"
|
||||||
|
IPAdapterModel = "IPAdapterModelField"
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Misc Field Types
|
||||||
|
Scheduler = "SchedulerField"
|
||||||
|
Any = "AnyField"
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Internal Field Types
|
||||||
|
_Collection = "CollectionField"
|
||||||
|
_CollectionItem = "CollectionItemField"
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region DEPRECATED
|
||||||
|
Boolean = "DEPRECATED_Boolean"
|
||||||
|
Color = "DEPRECATED_Color"
|
||||||
|
Conditioning = "DEPRECATED_Conditioning"
|
||||||
|
Control = "DEPRECATED_Control"
|
||||||
|
Float = "DEPRECATED_Float"
|
||||||
|
Image = "DEPRECATED_Image"
|
||||||
|
Integer = "DEPRECATED_Integer"
|
||||||
|
Latents = "DEPRECATED_Latents"
|
||||||
|
String = "DEPRECATED_String"
|
||||||
|
BooleanCollection = "DEPRECATED_BooleanCollection"
|
||||||
|
ColorCollection = "DEPRECATED_ColorCollection"
|
||||||
|
ConditioningCollection = "DEPRECATED_ConditioningCollection"
|
||||||
|
ControlCollection = "DEPRECATED_ControlCollection"
|
||||||
|
FloatCollection = "DEPRECATED_FloatCollection"
|
||||||
|
ImageCollection = "DEPRECATED_ImageCollection"
|
||||||
|
IntegerCollection = "DEPRECATED_IntegerCollection"
|
||||||
|
LatentsCollection = "DEPRECATED_LatentsCollection"
|
||||||
|
StringCollection = "DEPRECATED_StringCollection"
|
||||||
|
BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic"
|
||||||
|
ColorPolymorphic = "DEPRECATED_ColorPolymorphic"
|
||||||
|
ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic"
|
||||||
|
ControlPolymorphic = "DEPRECATED_ControlPolymorphic"
|
||||||
|
FloatPolymorphic = "DEPRECATED_FloatPolymorphic"
|
||||||
|
ImagePolymorphic = "DEPRECATED_ImagePolymorphic"
|
||||||
|
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
|
||||||
|
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
|
||||||
|
StringPolymorphic = "DEPRECATED_StringPolymorphic"
|
||||||
|
MainModel = "DEPRECATED_MainModel"
|
||||||
|
UNet = "DEPRECATED_UNet"
|
||||||
|
Vae = "DEPRECATED_Vae"
|
||||||
|
CLIP = "DEPRECATED_CLIP"
|
||||||
|
Collection = "DEPRECATED_Collection"
|
||||||
|
CollectionItem = "DEPRECATED_CollectionItem"
|
||||||
|
Enum = "DEPRECATED_Enum"
|
||||||
|
WorkflowField = "DEPRECATED_WorkflowField"
|
||||||
|
IsIntermediate = "DEPRECATED_IsIntermediate"
|
||||||
|
BoardField = "DEPRECATED_BoardField"
|
||||||
|
MetadataItem = "DEPRECATED_MetadataItem"
|
||||||
|
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
|
||||||
|
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
|
||||||
|
MetadataDict = "DEPRECATED_MetadataDict"
|
||||||
|
|
||||||
|
|
||||||
|
class UIComponent(str, Enum, metaclass=MetaEnum):
|
||||||
|
"""
|
||||||
|
The type of UI component to use for a field, used to override the default components, which are
|
||||||
|
inferred from the field type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
None_ = "none"
|
||||||
|
Textarea = "textarea"
|
||||||
|
Slider = "slider"
|
||||||
|
|
||||||
|
|
||||||
|
class FieldDescriptions:
|
||||||
|
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||||
|
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||||
|
cfg_scale = "Classifier-Free Guidance scale"
|
||||||
|
cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR"
|
||||||
|
scheduler = "Scheduler to use during inference"
|
||||||
|
positive_cond = "Positive conditioning tensor"
|
||||||
|
negative_cond = "Negative conditioning tensor"
|
||||||
|
noise = "Noise tensor"
|
||||||
|
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||||
|
unet = "UNet (scheduler, LoRAs)"
|
||||||
|
vae = "VAE"
|
||||||
|
cond = "Conditioning tensor"
|
||||||
|
controlnet_model = "ControlNet model to load"
|
||||||
|
vae_model = "VAE model to load"
|
||||||
|
lora_model = "LoRA model to load"
|
||||||
|
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||||
|
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||||
|
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||||
|
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||||
|
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||||
|
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||||
|
raw_prompt = "Raw prompt text (no parsing)"
|
||||||
|
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
|
||||||
|
skipped_layers = "Number of layers to skip in text encoder"
|
||||||
|
seed = "Seed for random number generation"
|
||||||
|
steps = "Number of steps to run"
|
||||||
|
width = "Width of output (px)"
|
||||||
|
height = "Height of output (px)"
|
||||||
|
control = "ControlNet(s) to apply"
|
||||||
|
ip_adapter = "IP-Adapter to apply"
|
||||||
|
t2i_adapter = "T2I-Adapter(s) to apply"
|
||||||
|
denoised_latents = "Denoised latents tensor"
|
||||||
|
latents = "Latents tensor"
|
||||||
|
strength = "Strength of denoising (proportional to steps)"
|
||||||
|
metadata = "Optional metadata to be saved with the image"
|
||||||
|
metadata_collection = "Collection of Metadata"
|
||||||
|
metadata_item_polymorphic = "A single metadata item or collection of metadata items"
|
||||||
|
metadata_item_label = "Label for this metadata item"
|
||||||
|
metadata_item_value = "The value for this metadata item (may be any type)"
|
||||||
|
workflow = "Optional workflow to be saved with the image"
|
||||||
|
interp_mode = "Interpolation mode"
|
||||||
|
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||||
|
fp32 = "Whether or not to use full float32 precision"
|
||||||
|
precision = "Precision to use"
|
||||||
|
tiled = "Processing using overlapping tiles (reduce memory consumption)"
|
||||||
|
detect_res = "Pixel resolution for detection"
|
||||||
|
image_res = "Pixel resolution for output image"
|
||||||
|
safe_mode = "Whether or not to use safe mode"
|
||||||
|
scribble_mode = "Whether or not to use scribble mode"
|
||||||
|
scale_factor = "The factor by which to scale"
|
||||||
|
blend_alpha = (
|
||||||
|
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
|
||||||
|
)
|
||||||
|
num_1 = "The first number"
|
||||||
|
num_2 = "The second number"
|
||||||
|
mask = "The mask to use for the operation"
|
||||||
|
board = "The board to save the image to"
|
||||||
|
image = "The image to process"
|
||||||
|
tile_size = "Tile size"
|
||||||
|
inclusive_low = "The inclusive low value"
|
||||||
|
exclusive_high = "The exclusive high value"
|
||||||
|
decimal_places = "The number of decimal places to round to"
|
||||||
|
freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
|
||||||
|
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
|
||||||
|
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
|
||||||
|
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
|
||||||
|
|
||||||
|
|
||||||
|
class ImageField(BaseModel):
|
||||||
|
"""An image primitive field"""
|
||||||
|
|
||||||
|
image_name: str = Field(description="The name of the image")
|
||||||
|
|
||||||
|
|
||||||
|
class BoardField(BaseModel):
|
||||||
|
"""A board primitive field"""
|
||||||
|
|
||||||
|
board_id: str = Field(description="The id of the board")
|
||||||
|
|
||||||
|
|
||||||
|
class DenoiseMaskField(BaseModel):
|
||||||
|
"""An inpaint mask field"""
|
||||||
|
|
||||||
|
mask_name: str = Field(description="The name of the mask image")
|
||||||
|
masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
|
||||||
|
|
||||||
|
|
||||||
|
class LatentsField(BaseModel):
|
||||||
|
"""A latents tensor primitive field"""
|
||||||
|
|
||||||
|
latents_name: str = Field(description="The name of the latents")
|
||||||
|
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
|
||||||
|
|
||||||
|
|
||||||
|
class ColorField(BaseModel):
|
||||||
|
"""A color primitive field"""
|
||||||
|
|
||||||
|
r: int = Field(ge=0, le=255, description="The red component")
|
||||||
|
g: int = Field(ge=0, le=255, description="The green component")
|
||||||
|
b: int = Field(ge=0, le=255, description="The blue component")
|
||||||
|
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||||
|
|
||||||
|
def tuple(self) -> Tuple[int, int, int, int]:
|
||||||
|
return (self.r, self.g, self.b, self.a)
|
||||||
|
|
||||||
|
|
||||||
|
class ConditioningField(BaseModel):
|
||||||
|
"""A conditioning tensor primitive value"""
|
||||||
|
|
||||||
|
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataField(RootModel):
|
||||||
|
"""
|
||||||
|
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||||
|
Metadata is stored without a strict schema.
|
||||||
|
"""
|
||||||
|
|
||||||
|
root: dict[str, Any] = Field(description="The metadata")
|
||||||
|
|
||||||
|
|
||||||
|
MetadataFieldValidator = TypeAdapter(MetadataField)
|
||||||
|
|
||||||
|
|
||||||
|
class Input(str, Enum, metaclass=MetaEnum):
|
||||||
|
"""
|
||||||
|
The type of input a field accepts.
|
||||||
|
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
|
||||||
|
are instantiated.
|
||||||
|
- `Input.Connection`: The field must have its value provided by a connection.
|
||||||
|
- `Input.Any`: The field may have its value provided either directly or by a connection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
Connection = "connection"
|
||||||
|
Direct = "direct"
|
||||||
|
Any = "any"
|
||||||
|
|
||||||
|
|
||||||
|
class FieldKind(str, Enum, metaclass=MetaEnum):
|
||||||
|
"""
|
||||||
|
The kind of field.
|
||||||
|
- `Input`: An input field on a node.
|
||||||
|
- `Output`: An output field on a node.
|
||||||
|
- `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is
|
||||||
|
one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name
|
||||||
|
"metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic,
|
||||||
|
allowing "metadata" for that field.
|
||||||
|
- `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs,
|
||||||
|
but which are used to store information about the node. For example, the `id` and `type` fields are node
|
||||||
|
attributes.
|
||||||
|
|
||||||
|
The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app
|
||||||
|
startup, and when generating the OpenAPI schema for the workflow editor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
Input = "input"
|
||||||
|
Output = "output"
|
||||||
|
Internal = "internal"
|
||||||
|
NodeAttribute = "node_attribute"
|
||||||
|
|
||||||
|
|
||||||
|
class InputFieldJSONSchemaExtra(BaseModel):
|
||||||
|
"""
|
||||||
|
Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution,
|
||||||
|
and by the workflow editor during schema parsing and UI rendering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
input: Input
|
||||||
|
orig_required: bool
|
||||||
|
field_kind: FieldKind
|
||||||
|
default: Optional[Any] = None
|
||||||
|
orig_default: Optional[Any] = None
|
||||||
|
ui_hidden: bool = False
|
||||||
|
ui_type: Optional[UIType] = None
|
||||||
|
ui_component: Optional[UIComponent] = None
|
||||||
|
ui_order: Optional[int] = None
|
||||||
|
ui_choice_labels: Optional[dict[str, str]] = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
validate_assignment=True,
|
||||||
|
json_schema_serialization_defaults_required=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WithMetadata(BaseModel):
|
||||||
|
"""
|
||||||
|
Inherit from this class if your node needs a metadata input field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
metadata: Optional[MetadataField] = Field(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.metadata,
|
||||||
|
json_schema_extra=InputFieldJSONSchemaExtra(
|
||||||
|
field_kind=FieldKind.Internal,
|
||||||
|
input=Input.Connection,
|
||||||
|
orig_required=False,
|
||||||
|
).model_dump(exclude_none=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WithWorkflow:
|
||||||
|
workflow = None
|
||||||
|
|
||||||
|
def __init_subclass__(cls) -> None:
|
||||||
|
logger.warn(
|
||||||
|
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
|
||||||
|
)
|
||||||
|
super().__init_subclass__()
|
||||||
|
|
||||||
|
|
||||||
|
class WithBoard(BaseModel):
|
||||||
|
"""
|
||||||
|
Inherit from this class if your node needs a board input field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
board: Optional[BoardField] = Field(
|
||||||
|
default=None,
|
||||||
|
description=FieldDescriptions.board,
|
||||||
|
json_schema_extra=InputFieldJSONSchemaExtra(
|
||||||
|
field_kind=FieldKind.Internal,
|
||||||
|
input=Input.Direct,
|
||||||
|
orig_required=False,
|
||||||
|
).model_dump(exclude_none=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OutputFieldJSONSchemaExtra(BaseModel):
|
||||||
|
"""
|
||||||
|
Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor
|
||||||
|
during schema parsing and UI rendering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
field_kind: FieldKind
|
||||||
|
ui_hidden: bool
|
||||||
|
ui_type: Optional[UIType]
|
||||||
|
ui_order: Optional[int]
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
validate_assignment=True,
|
||||||
|
json_schema_serialization_defaults_required=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def InputField(
|
||||||
|
# copied from pydantic's Field
|
||||||
|
# TODO: Can we support default_factory?
|
||||||
|
default: Any = _Unset,
|
||||||
|
default_factory: Callable[[], Any] | None = _Unset,
|
||||||
|
title: str | None = _Unset,
|
||||||
|
description: str | None = _Unset,
|
||||||
|
pattern: str | None = _Unset,
|
||||||
|
strict: bool | None = _Unset,
|
||||||
|
gt: float | None = _Unset,
|
||||||
|
ge: float | None = _Unset,
|
||||||
|
lt: float | None = _Unset,
|
||||||
|
le: float | None = _Unset,
|
||||||
|
multiple_of: float | None = _Unset,
|
||||||
|
allow_inf_nan: bool | None = _Unset,
|
||||||
|
max_digits: int | None = _Unset,
|
||||||
|
decimal_places: int | None = _Unset,
|
||||||
|
min_length: int | None = _Unset,
|
||||||
|
max_length: int | None = _Unset,
|
||||||
|
# custom
|
||||||
|
input: Input = Input.Any,
|
||||||
|
ui_type: Optional[UIType] = None,
|
||||||
|
ui_component: Optional[UIComponent] = None,
|
||||||
|
ui_hidden: bool = False,
|
||||||
|
ui_order: Optional[int] = None,
|
||||||
|
ui_choice_labels: Optional[dict[str, str]] = None,
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Creates an input field for an invocation.
|
||||||
|
|
||||||
|
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
|
||||||
|
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||||
|
|
||||||
|
:param Input input: [Input.Any] The kind of input this field requires. \
|
||||||
|
`Input.Direct` means a value must be provided on instantiation. \
|
||||||
|
`Input.Connection` means the value must be provided by a connection. \
|
||||||
|
`Input.Any` means either will do.
|
||||||
|
|
||||||
|
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||||
|
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||||
|
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||||
|
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||||
|
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||||
|
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||||
|
|
||||||
|
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
|
||||||
|
The UI will always render a suitable component, but sometimes you want something different than the default. \
|
||||||
|
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
|
||||||
|
For this case, you could provide `UIComponent.Textarea`.
|
||||||
|
|
||||||
|
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||||
|
|
||||||
|
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
|
||||||
|
|
||||||
|
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
json_schema_extra_ = InputFieldJSONSchemaExtra(
|
||||||
|
input=input,
|
||||||
|
ui_type=ui_type,
|
||||||
|
ui_component=ui_component,
|
||||||
|
ui_hidden=ui_hidden,
|
||||||
|
ui_order=ui_order,
|
||||||
|
ui_choice_labels=ui_choice_labels,
|
||||||
|
field_kind=FieldKind.Input,
|
||||||
|
orig_required=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
There is a conflict between the typing of invocation definitions and the typing of an invocation's
|
||||||
|
`invoke()` function.
|
||||||
|
|
||||||
|
On instantiation of a node, the invocation definition is used to create the python class. At this time,
|
||||||
|
any number of fields may be optional, because they may be provided by connections.
|
||||||
|
|
||||||
|
On calling of `invoke()`, however, those fields may be required.
|
||||||
|
|
||||||
|
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
|
||||||
|
|
||||||
|
`image` is required during the call to `invoke()`, but when the python class is instantiated,
|
||||||
|
the field may not be present. This is fine, because that image field will be provided by a
|
||||||
|
connection from an ancestor node, which outputs an image.
|
||||||
|
|
||||||
|
This means we want to type the `image` field as optional for the node class definition, but required
|
||||||
|
for the `invoke()` function.
|
||||||
|
|
||||||
|
If we use `typing.Optional` in the node class definition, the field will be typed as optional in the
|
||||||
|
`invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or
|
||||||
|
any static type analysis tools will complain.
|
||||||
|
|
||||||
|
To get around this, in node class definitions, we type all fields correctly for the `invoke()` function,
|
||||||
|
but secretly make them optional in `InputField()`. We also store the original required bool and/or default
|
||||||
|
value. When we call `invoke()`, we use this stored information to do an additional check on the class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if default_factory is not _Unset and default_factory is not None:
|
||||||
|
default = default_factory()
|
||||||
|
logger.warn('"default_factory" is not supported, calling it now to set "default"')
|
||||||
|
|
||||||
|
# These are the args we may wish pass to the pydantic `Field()` function
|
||||||
|
field_args = {
|
||||||
|
"default": default,
|
||||||
|
"title": title,
|
||||||
|
"description": description,
|
||||||
|
"pattern": pattern,
|
||||||
|
"strict": strict,
|
||||||
|
"gt": gt,
|
||||||
|
"ge": ge,
|
||||||
|
"lt": lt,
|
||||||
|
"le": le,
|
||||||
|
"multiple_of": multiple_of,
|
||||||
|
"allow_inf_nan": allow_inf_nan,
|
||||||
|
"max_digits": max_digits,
|
||||||
|
"decimal_places": decimal_places,
|
||||||
|
"min_length": min_length,
|
||||||
|
"max_length": max_length,
|
||||||
|
}
|
||||||
|
|
||||||
|
# We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected
|
||||||
|
provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}
|
||||||
|
|
||||||
|
# Because we are manually making fields optional, we need to store the original required bool for reference later
|
||||||
|
json_schema_extra_.orig_required = default is PydanticUndefined
|
||||||
|
|
||||||
|
# Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
||||||
|
if input is Input.Any or input is Input.Connection:
|
||||||
|
default_ = None if default is PydanticUndefined else default
|
||||||
|
provided_args.update({"default": default_})
|
||||||
|
if default is not PydanticUndefined:
|
||||||
|
# Before invoking, we'll check for the original default value and set it on the field if the field has no value
|
||||||
|
json_schema_extra_.default = default
|
||||||
|
json_schema_extra_.orig_default = default
|
||||||
|
elif default is not PydanticUndefined:
|
||||||
|
default_ = default
|
||||||
|
provided_args.update({"default": default_})
|
||||||
|
json_schema_extra_.orig_default = default_
|
||||||
|
|
||||||
|
return Field(
|
||||||
|
**provided_args,
|
||||||
|
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def OutputField(
|
||||||
|
# copied from pydantic's Field
|
||||||
|
default: Any = _Unset,
|
||||||
|
title: str | None = _Unset,
|
||||||
|
description: str | None = _Unset,
|
||||||
|
pattern: str | None = _Unset,
|
||||||
|
strict: bool | None = _Unset,
|
||||||
|
gt: float | None = _Unset,
|
||||||
|
ge: float | None = _Unset,
|
||||||
|
lt: float | None = _Unset,
|
||||||
|
le: float | None = _Unset,
|
||||||
|
multiple_of: float | None = _Unset,
|
||||||
|
allow_inf_nan: bool | None = _Unset,
|
||||||
|
max_digits: int | None = _Unset,
|
||||||
|
decimal_places: int | None = _Unset,
|
||||||
|
min_length: int | None = _Unset,
|
||||||
|
max_length: int | None = _Unset,
|
||||||
|
# custom
|
||||||
|
ui_type: Optional[UIType] = None,
|
||||||
|
ui_hidden: bool = False,
|
||||||
|
ui_order: Optional[int] = None,
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Creates an output field for an invocation output.
|
||||||
|
|
||||||
|
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||||
|
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||||
|
|
||||||
|
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||||
|
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||||
|
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||||
|
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||||
|
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||||
|
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||||
|
|
||||||
|
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||||
|
|
||||||
|
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||||
|
"""
|
||||||
|
return Field(
|
||||||
|
default=default,
|
||||||
|
title=title,
|
||||||
|
description=description,
|
||||||
|
pattern=pattern,
|
||||||
|
strict=strict,
|
||||||
|
gt=gt,
|
||||||
|
ge=ge,
|
||||||
|
lt=lt,
|
||||||
|
le=le,
|
||||||
|
multiple_of=multiple_of,
|
||||||
|
allow_inf_nan=allow_inf_nan,
|
||||||
|
max_digits=max_digits,
|
||||||
|
decimal_places=decimal_places,
|
||||||
|
min_length=min_length,
|
||||||
|
max_length=max_length,
|
||||||
|
json_schema_extra=OutputFieldJSONSchemaExtra(
|
||||||
|
ui_type=ui_type,
|
||||||
|
ui_hidden=ui_hidden,
|
||||||
|
ui_order=ui_order,
|
||||||
|
field_kind=FieldKind.Output,
|
||||||
|
).model_dump(exclude_none=True),
|
||||||
|
)
|
File diff suppressed because it is too large
Load Diff
@ -6,14 +6,16 @@ from typing import Literal, Optional, get_args
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
from invokeai.app.invocations.fields import ColorField, ImageField
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.misc import SEED_MAX
|
from invokeai.app.util.misc import SEED_MAX
|
||||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||||
from invokeai.backend.image_util.lama import LaMA
|
from invokeai.backend.image_util.lama import LaMA
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
|
from .baseinvocation import BaseInvocation, invocation
|
||||||
|
from .fields import InputField, WithBoard, WithMetadata
|
||||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||||
|
|
||||||
|
|
||||||
@ -118,8 +120,8 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
|||||||
return si
|
return si
|
||||||
|
|
||||||
|
|
||||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
|
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||||
class InfillColorInvocation(BaseInvocation, WithMetadata):
|
class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Infills transparent areas of an image with a solid color"""
|
"""Infills transparent areas of an image with a solid color"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
@ -129,33 +131,20 @@ class InfillColorInvocation(BaseInvocation, WithMetadata):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
|
||||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||||
|
|
||||||
infilled.paste(image, (0, 0), image.split()[-1])
|
infilled.paste(image, (0, 0), image.split()[-1])
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=infilled)
|
||||||
image=infilled,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
metadata=self.metadata,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput.build(image_dto)
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||||
class InfillTileInvocation(BaseInvocation, WithMetadata):
|
class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Infills transparent areas of an image with tiles of the image"""
|
"""Infills transparent areas of an image with tiles of the image"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
@ -168,33 +157,20 @@ class InfillTileInvocation(BaseInvocation, WithMetadata):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
|
||||||
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
|
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
|
||||||
infilled.paste(image, (0, 0), image.split()[-1])
|
infilled.paste(image, (0, 0), image.split()[-1])
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=infilled)
|
||||||
image=infilled,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
metadata=self.metadata,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput.build(image_dto)
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0"
|
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1"
|
||||||
)
|
)
|
||||||
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
|
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
@ -202,7 +178,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
|
|||||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
|
image = context.images.get_pil(self.image.image_name).convert("RGBA")
|
||||||
|
|
||||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||||
|
|
||||||
@ -227,77 +203,38 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
|
|||||||
infilled.paste(image, (0, 0), mask=image.split()[-1])
|
infilled.paste(image, (0, 0), mask=image.split()[-1])
|
||||||
# image.paste(infilled, (0, 0), mask=image.split()[-1])
|
# image.paste(infilled, (0, 0), mask=image.split()[-1])
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=infilled)
|
||||||
image=infilled,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
metadata=self.metadata,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput.build(image_dto)
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
|
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||||
class LaMaInfillInvocation(BaseInvocation, WithMetadata):
|
class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Infills transparent areas of an image using the LaMa model"""
|
"""Infills transparent areas of an image using the LaMa model"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
|
||||||
infilled = infill_lama(image.copy())
|
infilled = infill_lama(image.copy())
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=infilled)
|
||||||
image=infilled,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
metadata=self.metadata,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput.build(image_dto)
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
|
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||||
class CV2InfillInvocation(BaseInvocation, WithMetadata):
|
class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
|
||||||
infilled = infill_cv2(image.copy())
|
infilled = infill_cv2(image.copy())
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=infilled)
|
||||||
image=infilled,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
metadata=self.metadata,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput.build(image_dto)
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
@ -7,16 +7,13 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_valida
|
|||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
Input,
|
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
OutputField,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||||
from invokeai.app.invocations.primitives import ImageField
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
|
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
|
||||||
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
|
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
|
||||||
|
|
||||||
@ -65,7 +62,7 @@ class IPAdapterOutput(BaseInvocationOutput):
|
|||||||
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
|
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
|
||||||
|
|
||||||
|
|
||||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.1")
|
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.2")
|
||||||
class IPAdapterInvocation(BaseInvocation):
|
class IPAdapterInvocation(BaseInvocation):
|
||||||
"""Collects IP-Adapter info to pass to other nodes."""
|
"""Collects IP-Adapter info to pass to other nodes."""
|
||||||
|
|
||||||
@ -98,7 +95,7 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||||
ip_adapter_info = context.services.model_manager.model_info(
|
ip_adapter_info = context.models.get_info(
|
||||||
self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
|
self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
|
||||||
)
|
)
|
||||||
# HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
|
# HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
|
||||||
@ -107,7 +104,7 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
# is currently messy due to differences between how the model info is generated when installing a model from
|
# is currently messy due to differences between how the model info is generated when installing a model from
|
||||||
# disk vs. downloading the model.
|
# disk vs. downloading the model.
|
||||||
image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
|
image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
|
||||||
os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"])
|
os.path.join(context.config.get().models_path, ip_adapter_info["path"])
|
||||||
)
|
)
|
||||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||||
image_encoder_model = CLIPVisionModelField(
|
image_encoder_model = CLIPVisionModelField(
|
||||||
|
@ -23,21 +23,29 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
|
|||||||
from pydantic import field_validator
|
from pydantic import field_validator
|
||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
|
|
||||||
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||||
|
from invokeai.app.invocations.fields import (
|
||||||
|
ConditioningField,
|
||||||
|
DenoiseMaskField,
|
||||||
|
FieldDescriptions,
|
||||||
|
ImageField,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
LatentsField,
|
||||||
|
OutputField,
|
||||||
|
UIType,
|
||||||
|
WithBoard,
|
||||||
|
WithMetadata,
|
||||||
|
)
|
||||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||||
from invokeai.app.invocations.primitives import (
|
from invokeai.app.invocations.primitives import (
|
||||||
DenoiseMaskField,
|
|
||||||
DenoiseMaskOutput,
|
DenoiseMaskOutput,
|
||||||
ImageField,
|
|
||||||
ImageOutput,
|
ImageOutput,
|
||||||
LatentsField,
|
|
||||||
LatentsOutput,
|
LatentsOutput,
|
||||||
build_latents_output,
|
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
|
||||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
||||||
@ -59,16 +67,9 @@ from ...backend.util.devices import choose_precision, choose_torch_device
|
|||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
Input,
|
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
OutputField,
|
|
||||||
UIType,
|
|
||||||
WithMetadata,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from .compel import ConditioningField
|
|
||||||
from .controlnet_image_processors import ControlField
|
from .controlnet_image_processors import ControlField
|
||||||
from .model import ModelInfo, UNetField, VaeField
|
from .model import ModelInfo, UNetField, VaeField
|
||||||
|
|
||||||
@ -77,18 +78,10 @@ if choose_torch_device() == torch.device("mps"):
|
|||||||
|
|
||||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
|
|
||||||
|
|
||||||
# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
|
|
||||||
# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale
|
|
||||||
# factor is hard-coded to a literal '8' rather than using this constant.
|
|
||||||
# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
|
|
||||||
LATENT_SCALE_FACTOR = 8
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("scheduler_output")
|
@invocation_output("scheduler_output")
|
||||||
class SchedulerOutput(BaseInvocationOutput):
|
class SchedulerOutput(BaseInvocationOutput):
|
||||||
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -101,7 +94,7 @@ class SchedulerOutput(BaseInvocationOutput):
|
|||||||
class SchedulerInvocation(BaseInvocation):
|
class SchedulerInvocation(BaseInvocation):
|
||||||
"""Selects a scheduler."""
|
"""Selects a scheduler."""
|
||||||
|
|
||||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||||
default="euler",
|
default="euler",
|
||||||
description=FieldDescriptions.scheduler,
|
description=FieldDescriptions.scheduler,
|
||||||
ui_type=UIType.Scheduler,
|
ui_type=UIType.Scheduler,
|
||||||
@ -116,7 +109,7 @@ class SchedulerInvocation(BaseInvocation):
|
|||||||
title="Create Denoise Mask",
|
title="Create Denoise Mask",
|
||||||
tags=["mask", "denoise"],
|
tags=["mask", "denoise"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||||
"""Creates mask for denoising model run."""
|
"""Creates mask for denoising model run."""
|
||||||
@ -144,7 +137,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||||
if self.image is not None:
|
if self.image is not None:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||||
if image.dim() == 3:
|
if image.dim() == 3:
|
||||||
image = image.unsqueeze(0)
|
image = image.unsqueeze(0)
|
||||||
@ -152,33 +145,26 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
|||||||
image = None
|
image = None
|
||||||
|
|
||||||
mask = self.prep_mask_tensor(
|
mask = self.prep_mask_tensor(
|
||||||
context.services.images.get_pil_image(self.mask.image_name),
|
context.images.get_pil(self.mask.image_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
if image is not None:
|
if image is not None:
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||||
**self.vae.vae.model_dump(),
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||||
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
|
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||||
# TODO:
|
# TODO:
|
||||||
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
|
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
|
||||||
|
|
||||||
masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents"
|
masked_latents_name = context.tensors.save(tensor=masked_latents)
|
||||||
context.services.latents.save(masked_latents_name, masked_latents)
|
|
||||||
else:
|
else:
|
||||||
masked_latents_name = None
|
masked_latents_name = None
|
||||||
|
|
||||||
mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
|
mask_name = context.tensors.save(tensor=mask)
|
||||||
context.services.latents.save(mask_name, mask)
|
|
||||||
|
|
||||||
return DenoiseMaskOutput(
|
return DenoiseMaskOutput.build(
|
||||||
denoise_mask=DenoiseMaskField(
|
mask_name=mask_name,
|
||||||
mask_name=mask_name,
|
masked_latents_name=masked_latents_name,
|
||||||
masked_latents_name=masked_latents_name,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -189,10 +175,7 @@ def get_scheduler(
|
|||||||
seed: int,
|
seed: int,
|
||||||
) -> Scheduler:
|
) -> Scheduler:
|
||||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||||
orig_scheduler_info = context.services.model_manager.get_model(
|
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
|
||||||
**scheduler_info.model_dump(),
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
with orig_scheduler_info as orig_scheduler:
|
with orig_scheduler_info as orig_scheduler:
|
||||||
scheduler_config = orig_scheduler.config
|
scheduler_config = orig_scheduler.config
|
||||||
|
|
||||||
@ -221,7 +204,7 @@ def get_scheduler(
|
|||||||
title="Denoise Latents",
|
title="Denoise Latents",
|
||||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.5.1",
|
version="1.5.2",
|
||||||
)
|
)
|
||||||
class DenoiseLatentsInvocation(BaseInvocation):
|
class DenoiseLatentsInvocation(BaseInvocation):
|
||||||
"""Denoises noisy latents to decodable images"""
|
"""Denoises noisy latents to decodable images"""
|
||||||
@ -249,7 +232,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
description=FieldDescriptions.denoising_start,
|
description=FieldDescriptions.denoising_start,
|
||||||
)
|
)
|
||||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||||
default="euler",
|
default="euler",
|
||||||
description=FieldDescriptions.scheduler,
|
description=FieldDescriptions.scheduler,
|
||||||
ui_type=UIType.Scheduler,
|
ui_type=UIType.Scheduler,
|
||||||
@ -307,22 +290,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
|
||||||
def dispatch_progress(
|
|
||||||
self,
|
|
||||||
context: InvocationContext,
|
|
||||||
source_node_id: str,
|
|
||||||
intermediate_state: PipelineIntermediateState,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> None:
|
|
||||||
stable_diffusion_step_callback(
|
|
||||||
context=context,
|
|
||||||
intermediate_state=intermediate_state,
|
|
||||||
node=self.model_dump(),
|
|
||||||
source_node_id=source_node_id,
|
|
||||||
base_model=base_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_conditioning_data(
|
def get_conditioning_data(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
@ -330,11 +297,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
unet,
|
unet,
|
||||||
seed,
|
seed,
|
||||||
) -> ConditioningData:
|
) -> ConditioningData:
|
||||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
||||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||||
extra_conditioning_info = c.extra_conditioning
|
extra_conditioning_info = c.extra_conditioning
|
||||||
|
|
||||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
||||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
conditioning_data = ConditioningData(
|
conditioning_data = ConditioningData(
|
||||||
@ -422,17 +389,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
controlnet_data = []
|
controlnet_data = []
|
||||||
for control_info in control_list:
|
for control_info in control_list:
|
||||||
control_model = exit_stack.enter_context(
|
control_model = exit_stack.enter_context(
|
||||||
context.services.model_manager.get_model(
|
context.models.load(
|
||||||
model_name=control_info.control_model.model_name,
|
model_name=control_info.control_model.model_name,
|
||||||
model_type=ModelType.ControlNet,
|
model_type=ModelType.ControlNet,
|
||||||
base_model=control_info.control_model.base_model,
|
base_model=control_info.control_model.base_model,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# control_models.append(control_model)
|
# control_models.append(control_model)
|
||||||
control_image_field = control_info.image
|
control_image_field = control_info.image
|
||||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
input_image = context.images.get_pil(control_image_field.image_name)
|
||||||
# self.image.image_type, self.image.image_name
|
# self.image.image_type, self.image.image_name
|
||||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||||
# and add in batch_size, num_images_per_prompt?
|
# and add in batch_size, num_images_per_prompt?
|
||||||
@ -490,19 +456,17 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
conditioning_data.ip_adapter_conditioning = []
|
conditioning_data.ip_adapter_conditioning = []
|
||||||
for single_ip_adapter in ip_adapter:
|
for single_ip_adapter in ip_adapter:
|
||||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||||
context.services.model_manager.get_model(
|
context.models.load(
|
||||||
model_name=single_ip_adapter.ip_adapter_model.model_name,
|
model_name=single_ip_adapter.ip_adapter_model.model_name,
|
||||||
model_type=ModelType.IPAdapter,
|
model_type=ModelType.IPAdapter,
|
||||||
base_model=single_ip_adapter.ip_adapter_model.base_model,
|
base_model=single_ip_adapter.ip_adapter_model.base_model,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
image_encoder_model_info = context.services.model_manager.get_model(
|
image_encoder_model_info = context.models.load(
|
||||||
model_name=single_ip_adapter.image_encoder_model.model_name,
|
model_name=single_ip_adapter.image_encoder_model.model_name,
|
||||||
model_type=ModelType.CLIPVision,
|
model_type=ModelType.CLIPVision,
|
||||||
base_model=single_ip_adapter.image_encoder_model.base_model,
|
base_model=single_ip_adapter.image_encoder_model.base_model,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||||
@ -510,7 +474,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
if not isinstance(single_ipa_images, list):
|
if not isinstance(single_ipa_images, list):
|
||||||
single_ipa_images = [single_ipa_images]
|
single_ipa_images = [single_ipa_images]
|
||||||
|
|
||||||
single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images]
|
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_images]
|
||||||
|
|
||||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||||
@ -554,13 +518,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
t2i_adapter_data = []
|
t2i_adapter_data = []
|
||||||
for t2i_adapter_field in t2i_adapter:
|
for t2i_adapter_field in t2i_adapter:
|
||||||
t2i_adapter_model_info = context.services.model_manager.get_model(
|
t2i_adapter_model_info = context.models.load(
|
||||||
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
|
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
|
||||||
model_type=ModelType.T2IAdapter,
|
model_type=ModelType.T2IAdapter,
|
||||||
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
|
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
|
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
||||||
|
|
||||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||||
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
|
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
|
||||||
@ -647,14 +610,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
return num_inference_steps, timesteps, init_timestep
|
return num_inference_steps, timesteps, init_timestep
|
||||||
|
|
||||||
def prep_inpaint_mask(self, context, latents):
|
def prep_inpaint_mask(self, context: InvocationContext, latents):
|
||||||
if self.denoise_mask is None:
|
if self.denoise_mask is None:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
mask = context.services.latents.get(self.denoise_mask.mask_name)
|
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||||
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||||
if self.denoise_mask.masked_latents_name is not None:
|
if self.denoise_mask.masked_latents_name is not None:
|
||||||
masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name)
|
masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
|
||||||
else:
|
else:
|
||||||
masked_latents = None
|
masked_latents = None
|
||||||
|
|
||||||
@ -666,11 +629,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
seed = None
|
seed = None
|
||||||
noise = None
|
noise = None
|
||||||
if self.noise is not None:
|
if self.noise is not None:
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.tensors.load(self.noise.latents_name)
|
||||||
seed = self.noise.seed
|
seed = self.noise.seed
|
||||||
|
|
||||||
if self.latents is not None:
|
if self.latents is not None:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
if seed is None:
|
if seed is None:
|
||||||
seed = self.latents.seed
|
seed = self.latents.seed
|
||||||
|
|
||||||
@ -696,27 +659,17 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
do_classifier_free_guidance=True,
|
do_classifier_free_guidance=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
|
context.util.sd_step_callback(state, self.unet.unet.base_model)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||||
**lora.model_dump(exclude={"weight"}),
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.models.load(**self.unet.unet.model_dump())
|
||||||
**self.unet.unet.model_dump(),
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
with (
|
with (
|
||||||
ExitStack() as exit_stack,
|
ExitStack() as exit_stack,
|
||||||
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
|
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
|
||||||
@ -792,9 +745,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
if choose_torch_device() == torch.device("mps"):
|
if choose_torch_device() == torch.device("mps"):
|
||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = context.tensors.save(tensor=result_latents)
|
||||||
context.services.latents.save(name, result_latents)
|
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed)
|
||||||
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -802,9 +754,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
title="Latents to Image",
|
title="Latents to Image",
|
||||||
tags=["latents", "image", "vae", "l2i"],
|
tags=["latents", "image", "vae", "l2i"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
|
|
||||||
latents: LatentsField = InputField(
|
latents: LatentsField = InputField(
|
||||||
@ -820,12 +772,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||||
**self.vae.vae.model_dump(),
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
|
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
|
||||||
latents = latents.to(vae.device)
|
latents = latents.to(vae.device)
|
||||||
@ -854,7 +803,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
|||||||
vae.to(dtype=torch.float16)
|
vae.to(dtype=torch.float16)
|
||||||
latents = latents.half()
|
latents = latents.half()
|
||||||
|
|
||||||
if self.tiled or context.services.configuration.tiled_decode:
|
if self.tiled or context.config.get().tiled_decode:
|
||||||
vae.enable_tiling()
|
vae.enable_tiling()
|
||||||
else:
|
else:
|
||||||
vae.disable_tiling()
|
vae.disable_tiling()
|
||||||
@ -878,22 +827,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
|||||||
if choose_torch_device() == torch.device("mps"):
|
if choose_torch_device() == torch.device("mps"):
|
||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=image)
|
||||||
image=image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
metadata=self.metadata,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput.build(image_dto)
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||||
@ -904,7 +840,7 @@ LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic",
|
|||||||
title="Resize Latents",
|
title="Resize Latents",
|
||||||
tags=["latents", "resize"],
|
tags=["latents", "resize"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class ResizeLatentsInvocation(BaseInvocation):
|
class ResizeLatentsInvocation(BaseInvocation):
|
||||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||||
@ -927,7 +863,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
|
|
||||||
# TODO:
|
# TODO:
|
||||||
device = choose_torch_device()
|
device = choose_torch_device()
|
||||||
@ -945,10 +881,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
if device == torch.device("mps"):
|
if device == torch.device("mps"):
|
||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = context.tensors.save(tensor=resized_latents)
|
||||||
# context.services.latents.set(name, resized_latents)
|
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
context.services.latents.save(name, resized_latents)
|
|
||||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -956,7 +890,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
title="Scale Latents",
|
title="Scale Latents",
|
||||||
tags=["latents", "resize"],
|
tags=["latents", "resize"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class ScaleLatentsInvocation(BaseInvocation):
|
class ScaleLatentsInvocation(BaseInvocation):
|
||||||
"""Scales latents by a given factor."""
|
"""Scales latents by a given factor."""
|
||||||
@ -970,7 +904,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
|
|
||||||
# TODO:
|
# TODO:
|
||||||
device = choose_torch_device()
|
device = choose_torch_device()
|
||||||
@ -989,10 +923,8 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
if device == torch.device("mps"):
|
if device == torch.device("mps"):
|
||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = context.tensors.save(tensor=resized_latents)
|
||||||
# context.services.latents.set(name, resized_latents)
|
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
context.services.latents.save(name, resized_latents)
|
|
||||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -1000,7 +932,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
title="Image to Latents",
|
title="Image to Latents",
|
||||||
tags=["latents", "image", "vae", "i2l"],
|
tags=["latents", "image", "vae", "i2l"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class ImageToLatentsInvocation(BaseInvocation):
|
class ImageToLatentsInvocation(BaseInvocation):
|
||||||
"""Encodes an image into latents."""
|
"""Encodes an image into latents."""
|
||||||
@ -1061,12 +993,9 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||||
**self.vae.vae.model_dump(),
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||||
if image_tensor.dim() == 3:
|
if image_tensor.dim() == 3:
|
||||||
@ -1074,10 +1003,9 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
|
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
|
||||||
latents = latents.to("cpu")
|
latents = latents.to("cpu")
|
||||||
context.services.latents.save(name, latents)
|
name = context.tensors.save(tensor=latents)
|
||||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||||
|
|
||||||
@singledispatchmethod
|
@singledispatchmethod
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -1097,7 +1025,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
title="Blend Latents",
|
title="Blend Latents",
|
||||||
tags=["latents", "blend"],
|
tags=["latents", "blend"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class BlendLatentsInvocation(BaseInvocation):
|
class BlendLatentsInvocation(BaseInvocation):
|
||||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||||
@ -1113,8 +1041,8 @@ class BlendLatentsInvocation(BaseInvocation):
|
|||||||
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
|
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents_a = context.services.latents.get(self.latents_a.latents_name)
|
latents_a = context.tensors.load(self.latents_a.latents_name)
|
||||||
latents_b = context.services.latents.get(self.latents_b.latents_name)
|
latents_b = context.tensors.load(self.latents_b.latents_name)
|
||||||
|
|
||||||
if latents_a.shape != latents_b.shape:
|
if latents_a.shape != latents_b.shape:
|
||||||
raise Exception("Latents to blend must be the same size.")
|
raise Exception("Latents to blend must be the same size.")
|
||||||
@ -1168,10 +1096,8 @@ class BlendLatentsInvocation(BaseInvocation):
|
|||||||
if device == torch.device("mps"):
|
if device == torch.device("mps"):
|
||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = context.tensors.save(tensor=blended_latents)
|
||||||
# context.services.latents.set(name, resized_latents)
|
return LatentsOutput.build(latents_name=name, latents=blended_latents)
|
||||||
context.services.latents.save(name, blended_latents)
|
|
||||||
return build_latents_output(latents_name=name, latents=blended_latents)
|
|
||||||
|
|
||||||
|
|
||||||
# The Crop Latents node was copied from @skunkworxdark's implementation here:
|
# The Crop Latents node was copied from @skunkworxdark's implementation here:
|
||||||
@ -1181,7 +1107,7 @@ class BlendLatentsInvocation(BaseInvocation):
|
|||||||
title="Crop Latents",
|
title="Crop Latents",
|
||||||
tags=["latents", "crop"],
|
tags=["latents", "crop"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
|
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
|
||||||
# Currently, if the class names conflict then 'GET /openapi.json' fails.
|
# Currently, if the class names conflict then 'GET /openapi.json' fails.
|
||||||
@ -1216,7 +1142,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
|
|
||||||
x1 = self.x // LATENT_SCALE_FACTOR
|
x1 = self.x // LATENT_SCALE_FACTOR
|
||||||
y1 = self.y // LATENT_SCALE_FACTOR
|
y1 = self.y // LATENT_SCALE_FACTOR
|
||||||
@ -1225,10 +1151,9 @@ class CropLatentsCoreInvocation(BaseInvocation):
|
|||||||
|
|
||||||
cropped_latents = latents[..., y1:y2, x1:x2]
|
cropped_latents = latents[..., y1:y2, x1:x2]
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = context.tensors.save(tensor=cropped_latents)
|
||||||
context.services.latents.save(name, cropped_latents)
|
|
||||||
|
|
||||||
return build_latents_output(latents_name=name, latents=cropped_latents)
|
return LatentsOutput.build(latents_name=name, latents=cropped_latents)
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("ideal_size_output")
|
@invocation_output("ideal_size_output")
|
||||||
|
@ -5,10 +5,11 @@ from typing import Literal
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import ValidationInfo, field_validator
|
from pydantic import ValidationInfo, field_validator
|
||||||
|
|
||||||
|
from invokeai.app.invocations.fields import FieldDescriptions, InputField
|
||||||
from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
|
from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, invocation
|
||||||
|
|
||||||
|
|
||||||
@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")
|
@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")
|
||||||
|
@ -5,20 +5,22 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
MetadataField,
|
|
||||||
OutputField,
|
|
||||||
UIType,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||||
|
from invokeai.app.invocations.fields import (
|
||||||
|
FieldDescriptions,
|
||||||
|
ImageField,
|
||||||
|
InputField,
|
||||||
|
MetadataField,
|
||||||
|
OutputField,
|
||||||
|
UIType,
|
||||||
|
)
|
||||||
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
||||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||||
from invokeai.app.invocations.primitives import ImageField
|
|
||||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
|
||||||
from ...version import __version__
|
from ...version import __version__
|
||||||
|
|
||||||
|
@ -3,17 +3,14 @@ from typing import List, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.shared.models import FreeUConfig
|
from invokeai.app.shared.models import FreeUConfig
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
Input,
|
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
OutputField,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
@ -105,7 +102,7 @@ class LoRAModelField(BaseModel):
|
|||||||
title="Main Model",
|
title="Main Model",
|
||||||
tags=["model"],
|
tags=["model"],
|
||||||
category="model",
|
category="model",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class MainModelLoaderInvocation(BaseInvocation):
|
class MainModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
@ -119,7 +116,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
|||||||
model_type = ModelType.Main
|
model_type = ModelType.Main
|
||||||
|
|
||||||
# TODO: not found exceptions
|
# TODO: not found exceptions
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.models.exists(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
@ -206,7 +203,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
|
|||||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
|
|
||||||
|
|
||||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0")
|
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
|
||||||
class LoraLoaderInvocation(BaseInvocation):
|
class LoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
@ -232,7 +229,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
base_model = self.lora.base_model
|
base_model = self.lora.base_model
|
||||||
lora_name = self.lora.model_name
|
lora_name = self.lora.model_name
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.models.exists(
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_name=lora_name,
|
model_name=lora_name,
|
||||||
model_type=ModelType.Lora,
|
model_type=ModelType.Lora,
|
||||||
@ -288,7 +285,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
|||||||
title="SDXL LoRA",
|
title="SDXL LoRA",
|
||||||
tags=["lora", "model"],
|
tags=["lora", "model"],
|
||||||
category="model",
|
category="model",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
@ -321,7 +318,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
|||||||
base_model = self.lora.base_model
|
base_model = self.lora.base_model
|
||||||
lora_name = self.lora.model_name
|
lora_name = self.lora.model_name
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.models.exists(
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_name=lora_name,
|
model_name=lora_name,
|
||||||
model_type=ModelType.Lora,
|
model_type=ModelType.Lora,
|
||||||
@ -387,7 +384,7 @@ class VAEModelField(BaseModel):
|
|||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
|
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
|
||||||
class VaeLoaderInvocation(BaseInvocation):
|
class VaeLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||||
|
|
||||||
@ -402,7 +399,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
|||||||
model_name = self.vae_model.model_name
|
model_name = self.vae_model.model_name
|
||||||
model_type = ModelType.Vae
|
model_type = ModelType.Vae
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.models.exists(
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
|
@ -4,17 +4,15 @@
|
|||||||
import torch
|
import torch
|
||||||
from pydantic import field_validator
|
from pydantic import field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.latent import LatentsField
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.misc import SEED_MAX
|
from invokeai.app.util.misc import SEED_MAX
|
||||||
|
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
OutputField,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
@ -69,13 +67,13 @@ class NoiseOutput(BaseInvocationOutput):
|
|||||||
width: int = OutputField(description=FieldDescriptions.width)
|
width: int = OutputField(description=FieldDescriptions.width)
|
||||||
height: int = OutputField(description=FieldDescriptions.height)
|
height: int = OutputField(description=FieldDescriptions.height)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput":
|
||||||
return NoiseOutput(
|
return cls(
|
||||||
noise=LatentsField(latents_name=latents_name, seed=seed),
|
noise=LatentsField(latents_name=latents_name, seed=seed),
|
||||||
width=latents.size()[3] * 8,
|
width=latents.size()[3] * LATENT_SCALE_FACTOR,
|
||||||
height=latents.size()[2] * 8,
|
height=latents.size()[2] * LATENT_SCALE_FACTOR,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -96,13 +94,13 @@ class NoiseInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
width: int = InputField(
|
width: int = InputField(
|
||||||
default=512,
|
default=512,
|
||||||
multiple_of=8,
|
multiple_of=LATENT_SCALE_FACTOR,
|
||||||
gt=0,
|
gt=0,
|
||||||
description=FieldDescriptions.width,
|
description=FieldDescriptions.width,
|
||||||
)
|
)
|
||||||
height: int = InputField(
|
height: int = InputField(
|
||||||
default=512,
|
default=512,
|
||||||
multiple_of=8,
|
multiple_of=LATENT_SCALE_FACTOR,
|
||||||
gt=0,
|
gt=0,
|
||||||
description=FieldDescriptions.height,
|
description=FieldDescriptions.height,
|
||||||
)
|
)
|
||||||
@ -124,6 +122,5 @@ class NoiseInvocation(BaseInvocation):
|
|||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
use_cpu=self.use_cpu,
|
use_cpu=self.use_cpu,
|
||||||
)
|
)
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = context.tensors.save(tensor=noise)
|
||||||
context.services.latents.save(name, noise)
|
return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed)
|
||||||
return build_noise_output(latents_name=name, latents=noise, seed=self.seed)
|
|
||||||
|
@ -1,508 +0,0 @@
|
|||||||
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
# from contextlib import ExitStack
|
|
||||||
from typing import List, Literal, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
|
||||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
|
||||||
|
|
||||||
from ...backend.model_management import ONNXModelPatcher
|
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
|
||||||
from ...backend.util import choose_torch_device
|
|
||||||
from ..util.ti_utils import extract_ti_triggers_from_prompt
|
|
||||||
from .baseinvocation import (
|
|
||||||
BaseInvocation,
|
|
||||||
BaseInvocationOutput,
|
|
||||||
Input,
|
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
OutputField,
|
|
||||||
UIComponent,
|
|
||||||
UIType,
|
|
||||||
WithMetadata,
|
|
||||||
invocation,
|
|
||||||
invocation_output,
|
|
||||||
)
|
|
||||||
from .controlnet_image_processors import ControlField
|
|
||||||
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
|
|
||||||
from .model import ClipField, ModelInfo, UNetField, VaeField
|
|
||||||
|
|
||||||
ORT_TO_NP_TYPE = {
|
|
||||||
"tensor(bool)": np.bool_,
|
|
||||||
"tensor(int8)": np.int8,
|
|
||||||
"tensor(uint8)": np.uint8,
|
|
||||||
"tensor(int16)": np.int16,
|
|
||||||
"tensor(uint16)": np.uint16,
|
|
||||||
"tensor(int32)": np.int32,
|
|
||||||
"tensor(uint32)": np.uint32,
|
|
||||||
"tensor(int64)": np.int64,
|
|
||||||
"tensor(uint64)": np.uint64,
|
|
||||||
"tensor(float16)": np.float16,
|
|
||||||
"tensor(float)": np.float32,
|
|
||||||
"tensor(double)": np.float64,
|
|
||||||
}
|
|
||||||
|
|
||||||
PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())]
|
|
||||||
|
|
||||||
|
|
||||||
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
|
|
||||||
class ONNXPromptInvocation(BaseInvocation):
|
|
||||||
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
|
||||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
|
||||||
**self.clip.tokenizer.model_dump(),
|
|
||||||
)
|
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
|
||||||
**self.clip.text_encoder.model_dump(),
|
|
||||||
)
|
|
||||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
|
||||||
loras = [
|
|
||||||
(
|
|
||||||
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
|
||||||
lora.weight,
|
|
||||||
)
|
|
||||||
for lora in self.clip.loras
|
|
||||||
]
|
|
||||||
|
|
||||||
ti_list = []
|
|
||||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
|
||||||
name = trigger[1:-1]
|
|
||||||
try:
|
|
||||||
ti_list.append(
|
|
||||||
(
|
|
||||||
name,
|
|
||||||
context.services.model_manager.get_model(
|
|
||||||
model_name=name,
|
|
||||||
base_model=self.clip.text_encoder.base_model,
|
|
||||||
model_type=ModelType.TextualInversion,
|
|
||||||
).context.model,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
# print(e)
|
|
||||||
# import traceback
|
|
||||||
# print(traceback.format_exc())
|
|
||||||
print(f'Warn: trigger: "{trigger}" not found')
|
|
||||||
if loras or ti_list:
|
|
||||||
text_encoder.release_session()
|
|
||||||
with (
|
|
||||||
ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),
|
|
||||||
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager),
|
|
||||||
):
|
|
||||||
text_encoder.create_session()
|
|
||||||
|
|
||||||
# copy from
|
|
||||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153
|
|
||||||
text_inputs = tokenizer(
|
|
||||||
self.prompt,
|
|
||||||
padding="max_length",
|
|
||||||
max_length=tokenizer.model_max_length,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="np",
|
|
||||||
)
|
|
||||||
text_input_ids = text_inputs.input_ids
|
|
||||||
"""
|
|
||||||
untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
|
|
||||||
|
|
||||||
if not np.array_equal(text_input_ids, untruncated_ids):
|
|
||||||
removed_text = self.tokenizer.batch_decode(
|
|
||||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
|
||||||
)
|
|
||||||
logger.warning(
|
|
||||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
|
||||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
|
||||||
|
|
||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
|
||||||
|
|
||||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
|
||||||
context.services.latents.save(conditioning_name, (prompt_embeds, None))
|
|
||||||
|
|
||||||
return ConditioningOutput(
|
|
||||||
conditioning=ConditioningField(
|
|
||||||
conditioning_name=conditioning_name,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Text to image
|
|
||||||
@invocation(
|
|
||||||
"t2l_onnx",
|
|
||||||
title="ONNX Text to Latents",
|
|
||||||
tags=["latents", "inference", "txt2img", "onnx"],
|
|
||||||
category="latents",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class ONNXTextToLatentsInvocation(BaseInvocation):
|
|
||||||
"""Generates latents from conditionings."""
|
|
||||||
|
|
||||||
positive_conditioning: ConditioningField = InputField(
|
|
||||||
description=FieldDescriptions.positive_cond,
|
|
||||||
input=Input.Connection,
|
|
||||||
)
|
|
||||||
negative_conditioning: ConditioningField = InputField(
|
|
||||||
description=FieldDescriptions.negative_cond,
|
|
||||||
input=Input.Connection,
|
|
||||||
)
|
|
||||||
noise: LatentsField = InputField(
|
|
||||||
description=FieldDescriptions.noise,
|
|
||||||
input=Input.Connection,
|
|
||||||
)
|
|
||||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
|
||||||
cfg_scale: Union[float, List[float]] = InputField(
|
|
||||||
default=7.5,
|
|
||||||
ge=1,
|
|
||||||
description=FieldDescriptions.cfg_scale,
|
|
||||||
)
|
|
||||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
|
||||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
|
|
||||||
)
|
|
||||||
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
|
||||||
unet: UNetField = InputField(
|
|
||||||
description=FieldDescriptions.unet,
|
|
||||||
input=Input.Connection,
|
|
||||||
)
|
|
||||||
control: Union[ControlField, list[ControlField]] = InputField(
|
|
||||||
default=None,
|
|
||||||
description=FieldDescriptions.control,
|
|
||||||
)
|
|
||||||
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
|
||||||
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
|
||||||
|
|
||||||
@field_validator("cfg_scale")
|
|
||||||
def ge_one(cls, v):
|
|
||||||
"""validate that all cfg_scale values are >= 1"""
|
|
||||||
if isinstance(v, list):
|
|
||||||
for i in v:
|
|
||||||
if i < 1:
|
|
||||||
raise ValueError("cfg_scale must be greater than 1")
|
|
||||||
else:
|
|
||||||
if v < 1:
|
|
||||||
raise ValueError("cfg_scale must be greater than 1")
|
|
||||||
return v
|
|
||||||
|
|
||||||
# based on
|
|
||||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
|
||||||
c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
|
||||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
|
||||||
if isinstance(c, torch.Tensor):
|
|
||||||
c = c.cpu().numpy()
|
|
||||||
if isinstance(uc, torch.Tensor):
|
|
||||||
uc = uc.cpu().numpy()
|
|
||||||
device = torch.device(choose_torch_device())
|
|
||||||
prompt_embeds = np.concatenate([uc, c])
|
|
||||||
|
|
||||||
latents = context.services.latents.get(self.noise.latents_name)
|
|
||||||
if isinstance(latents, torch.Tensor):
|
|
||||||
latents = latents.cpu().numpy()
|
|
||||||
|
|
||||||
# TODO: better execution device handling
|
|
||||||
latents = latents.astype(ORT_TO_NP_TYPE[self.precision])
|
|
||||||
|
|
||||||
# get the initial random noise unless the user supplied it
|
|
||||||
do_classifier_free_guidance = True
|
|
||||||
# latents_dtype = prompt_embeds.dtype
|
|
||||||
# latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
|
||||||
# if latents.shape != latents_shape:
|
|
||||||
# raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
|
||||||
|
|
||||||
scheduler = get_scheduler(
|
|
||||||
context=context,
|
|
||||||
scheduler_info=self.unet.scheduler,
|
|
||||||
scheduler_name=self.scheduler,
|
|
||||||
seed=0, # TODO: refactor this node
|
|
||||||
)
|
|
||||||
|
|
||||||
def torch2numpy(latent: torch.Tensor):
|
|
||||||
return latent.cpu().numpy()
|
|
||||||
|
|
||||||
def numpy2torch(latent, device):
|
|
||||||
return torch.from_numpy(latent).to(device)
|
|
||||||
|
|
||||||
def dispatch_progress(
|
|
||||||
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
|
|
||||||
) -> None:
|
|
||||||
stable_diffusion_step_callback(
|
|
||||||
context=context,
|
|
||||||
intermediate_state=intermediate_state,
|
|
||||||
node=self.model_dump(),
|
|
||||||
source_node_id=source_node_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
scheduler.set_timesteps(self.steps)
|
|
||||||
latents = latents * np.float64(scheduler.init_noise_sigma)
|
|
||||||
|
|
||||||
extra_step_kwargs = {}
|
|
||||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
|
||||||
extra_step_kwargs.update(
|
|
||||||
eta=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump())
|
|
||||||
|
|
||||||
with unet_info as unet: # , ExitStack() as stack:
|
|
||||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
|
||||||
loras = [
|
|
||||||
(
|
|
||||||
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
|
||||||
lora.weight,
|
|
||||||
)
|
|
||||||
for lora in self.unet.loras
|
|
||||||
]
|
|
||||||
|
|
||||||
if loras:
|
|
||||||
unet.release_session()
|
|
||||||
with ONNXModelPatcher.apply_lora_unet(unet, loras):
|
|
||||||
# TODO:
|
|
||||||
_, _, h, w = latents.shape
|
|
||||||
unet.create_session(h, w)
|
|
||||||
|
|
||||||
timestep_dtype = next(
|
|
||||||
(input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
|
|
||||||
)
|
|
||||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
|
||||||
for i in tqdm(range(len(scheduler.timesteps))):
|
|
||||||
t = scheduler.timesteps[i]
|
|
||||||
# expand the latents if we are doing classifier free guidance
|
|
||||||
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
|
||||||
latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
|
|
||||||
latent_model_input = latent_model_input.cpu().numpy()
|
|
||||||
|
|
||||||
# predict the noise residual
|
|
||||||
timestep = np.array([t], dtype=timestep_dtype)
|
|
||||||
noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
|
|
||||||
noise_pred = noise_pred[0]
|
|
||||||
|
|
||||||
# perform guidance
|
|
||||||
if do_classifier_free_guidance:
|
|
||||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
|
||||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
|
||||||
scheduler_output = scheduler.step(
|
|
||||||
numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
|
|
||||||
)
|
|
||||||
latents = torch2numpy(scheduler_output.prev_sample)
|
|
||||||
|
|
||||||
state = PipelineIntermediateState(
|
|
||||||
run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample
|
|
||||||
)
|
|
||||||
dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state)
|
|
||||||
|
|
||||||
# call the callback, if provided
|
|
||||||
# if callback is not None and i % callback_steps == 0:
|
|
||||||
# callback(i, t, latents)
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
|
||||||
context.services.latents.save(name, latents)
|
|
||||||
return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))
|
|
||||||
|
|
||||||
|
|
||||||
# Latent to image
|
|
||||||
@invocation(
|
|
||||||
"l2i_onnx",
|
|
||||||
title="ONNX Latents to Image",
|
|
||||||
tags=["latents", "image", "vae", "onnx"],
|
|
||||||
category="image",
|
|
||||||
version="1.2.0",
|
|
||||||
)
|
|
||||||
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
|
|
||||||
"""Generates an image from latents."""
|
|
||||||
|
|
||||||
latents: LatentsField = InputField(
|
|
||||||
description=FieldDescriptions.denoised_latents,
|
|
||||||
input=Input.Connection,
|
|
||||||
)
|
|
||||||
vae: VaeField = InputField(
|
|
||||||
description=FieldDescriptions.vae,
|
|
||||||
input=Input.Connection,
|
|
||||||
)
|
|
||||||
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
|
||||||
|
|
||||||
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
|
||||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
|
||||||
**self.vae.vae.model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# clear memory as vae decode can request a lot
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
with vae_info as vae:
|
|
||||||
vae.create_session()
|
|
||||||
|
|
||||||
# copied from
|
|
||||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427
|
|
||||||
latents = 1 / 0.18215 * latents
|
|
||||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
|
||||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
|
||||||
image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])])
|
|
||||||
|
|
||||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
|
||||||
image = image.transpose((0, 2, 3, 1))
|
|
||||||
image = VaeImageProcessor.numpy_to_pil(image)[0]
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
|
||||||
image=image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
metadata=self.metadata,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("model_loader_output_onnx")
|
|
||||||
class ONNXModelLoaderOutput(BaseInvocationOutput):
|
|
||||||
"""Model loader output"""
|
|
||||||
|
|
||||||
unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
|
||||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
|
||||||
vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
|
|
||||||
vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
|
|
||||||
|
|
||||||
|
|
||||||
class OnnxModelField(BaseModel):
|
|
||||||
"""Onnx model field"""
|
|
||||||
|
|
||||||
model_name: str = Field(description="Name of the model")
|
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
|
||||||
model_type: ModelType = Field(description="Model Type")
|
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
|
||||||
|
|
||||||
|
|
||||||
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
|
|
||||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
|
||||||
"""Loads a main model, outputting its submodels."""
|
|
||||||
|
|
||||||
model: OnnxModelField = InputField(
|
|
||||||
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
|
||||||
base_model = self.model.base_model
|
|
||||||
model_name = self.model.model_name
|
|
||||||
model_type = ModelType.ONNX
|
|
||||||
|
|
||||||
# TODO: not found exceptions
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
):
|
|
||||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.Tokenizer,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.TextEncoder,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.UNet,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
return ONNXModelLoaderOutput(
|
|
||||||
unet=UNetField(
|
|
||||||
unet=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.UNet,
|
|
||||||
),
|
|
||||||
scheduler=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.Scheduler,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
clip=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.Tokenizer,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.TextEncoder,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
skipped_layers=0,
|
|
||||||
),
|
|
||||||
vae_decoder=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.VaeDecoder,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
vae_encoder=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.VaeEncoder,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
@ -40,8 +40,10 @@ from easing_functions import (
|
|||||||
from matplotlib.ticker import MaxNLocator
|
from matplotlib.ticker import MaxNLocator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import FloatCollectionOutput
|
from invokeai.app.invocations.primitives import FloatCollectionOutput
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, invocation
|
||||||
|
from .fields import InputField
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -109,7 +111,7 @@ EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
|
|||||||
title="Step Param Easing",
|
title="Step Param Easing",
|
||||||
tags=["step", "easing"],
|
tags=["step", "easing"],
|
||||||
category="step",
|
category="step",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class StepParamEasingInvocation(BaseInvocation):
|
class StepParamEasingInvocation(BaseInvocation):
|
||||||
"""Experimental per-step parameter easing for denoising steps"""
|
"""Experimental per-step parameter easing for denoising steps"""
|
||||||
@ -148,19 +150,19 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
postlist = list(num_poststeps * [self.post_end_value])
|
postlist = list(num_poststeps * [self.post_end_value])
|
||||||
|
|
||||||
if log_diagnostics:
|
if log_diagnostics:
|
||||||
context.services.logger.debug("start_step: " + str(start_step))
|
context.logger.debug("start_step: " + str(start_step))
|
||||||
context.services.logger.debug("end_step: " + str(end_step))
|
context.logger.debug("end_step: " + str(end_step))
|
||||||
context.services.logger.debug("num_easing_steps: " + str(num_easing_steps))
|
context.logger.debug("num_easing_steps: " + str(num_easing_steps))
|
||||||
context.services.logger.debug("num_presteps: " + str(num_presteps))
|
context.logger.debug("num_presteps: " + str(num_presteps))
|
||||||
context.services.logger.debug("num_poststeps: " + str(num_poststeps))
|
context.logger.debug("num_poststeps: " + str(num_poststeps))
|
||||||
context.services.logger.debug("prelist size: " + str(len(prelist)))
|
context.logger.debug("prelist size: " + str(len(prelist)))
|
||||||
context.services.logger.debug("postlist size: " + str(len(postlist)))
|
context.logger.debug("postlist size: " + str(len(postlist)))
|
||||||
context.services.logger.debug("prelist: " + str(prelist))
|
context.logger.debug("prelist: " + str(prelist))
|
||||||
context.services.logger.debug("postlist: " + str(postlist))
|
context.logger.debug("postlist: " + str(postlist))
|
||||||
|
|
||||||
easing_class = EASING_FUNCTIONS_MAP[self.easing]
|
easing_class = EASING_FUNCTIONS_MAP[self.easing]
|
||||||
if log_diagnostics:
|
if log_diagnostics:
|
||||||
context.services.logger.debug("easing class: " + str(easing_class))
|
context.logger.debug("easing class: " + str(easing_class))
|
||||||
easing_list = []
|
easing_list = []
|
||||||
if self.mirror: # "expected" mirroring
|
if self.mirror: # "expected" mirroring
|
||||||
# if number of steps is even, squeeze duration down to (number_of_steps)/2
|
# if number of steps is even, squeeze duration down to (number_of_steps)/2
|
||||||
@ -171,7 +173,7 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
|
|
||||||
base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
|
base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
|
||||||
if log_diagnostics:
|
if log_diagnostics:
|
||||||
context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
context.logger.debug("base easing duration: " + str(base_easing_duration))
|
||||||
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
|
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
|
||||||
easing_function = easing_class(
|
easing_function = easing_class(
|
||||||
start=self.start_value,
|
start=self.start_value,
|
||||||
@ -183,14 +185,14 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
easing_val = easing_function.ease(step_index)
|
easing_val = easing_function.ease(step_index)
|
||||||
base_easing_vals.append(easing_val)
|
base_easing_vals.append(easing_val)
|
||||||
if log_diagnostics:
|
if log_diagnostics:
|
||||||
context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
|
context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
|
||||||
if even_num_steps:
|
if even_num_steps:
|
||||||
mirror_easing_vals = list(reversed(base_easing_vals))
|
mirror_easing_vals = list(reversed(base_easing_vals))
|
||||||
else:
|
else:
|
||||||
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
|
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
|
||||||
if log_diagnostics:
|
if log_diagnostics:
|
||||||
context.services.logger.debug("base easing vals: " + str(base_easing_vals))
|
context.logger.debug("base easing vals: " + str(base_easing_vals))
|
||||||
context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
|
context.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
|
||||||
easing_list = base_easing_vals + mirror_easing_vals
|
easing_list = base_easing_vals + mirror_easing_vals
|
||||||
|
|
||||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||||
@ -225,12 +227,12 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
step_val = easing_function.ease(step_index)
|
step_val = easing_function.ease(step_index)
|
||||||
easing_list.append(step_val)
|
easing_list.append(step_val)
|
||||||
if log_diagnostics:
|
if log_diagnostics:
|
||||||
context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
|
context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
|
||||||
|
|
||||||
if log_diagnostics:
|
if log_diagnostics:
|
||||||
context.services.logger.debug("prelist size: " + str(len(prelist)))
|
context.logger.debug("prelist size: " + str(len(prelist)))
|
||||||
context.services.logger.debug("easing_list size: " + str(len(easing_list)))
|
context.logger.debug("easing_list size: " + str(len(easing_list)))
|
||||||
context.services.logger.debug("postlist size: " + str(len(postlist)))
|
context.logger.debug("postlist size: " + str(len(postlist)))
|
||||||
|
|
||||||
param_list = prelist + easing_list + postlist
|
param_list = prelist + easing_list + postlist
|
||||||
|
|
||||||
|
@ -1,20 +1,28 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||||
|
from invokeai.app.invocations.fields import (
|
||||||
|
ColorField,
|
||||||
|
ConditioningField,
|
||||||
|
DenoiseMaskField,
|
||||||
|
FieldDescriptions,
|
||||||
|
ImageField,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
LatentsField,
|
||||||
|
OutputField,
|
||||||
|
UIComponent,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.images.images_common import ImageDTO
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
Input,
|
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
OutputField,
|
|
||||||
UIComponent,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
@ -221,18 +229,6 @@ class StringCollectionInvocation(BaseInvocation):
|
|||||||
# region Image
|
# region Image
|
||||||
|
|
||||||
|
|
||||||
class ImageField(BaseModel):
|
|
||||||
"""An image primitive field"""
|
|
||||||
|
|
||||||
image_name: str = Field(description="The name of the image")
|
|
||||||
|
|
||||||
|
|
||||||
class BoardField(BaseModel):
|
|
||||||
"""A board primitive field"""
|
|
||||||
|
|
||||||
board_id: str = Field(description="The id of the board")
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("image_output")
|
@invocation_output("image_output")
|
||||||
class ImageOutput(BaseInvocationOutput):
|
class ImageOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single image"""
|
"""Base class for nodes that output a single image"""
|
||||||
@ -241,6 +237,14 @@ class ImageOutput(BaseInvocationOutput):
|
|||||||
width: int = OutputField(description="The width of the image in pixels")
|
width: int = OutputField(description="The width of the image in pixels")
|
||||||
height: int = OutputField(description="The height of the image in pixels")
|
height: int = OutputField(description="The height of the image in pixels")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build(cls, image_dto: ImageDTO) -> "ImageOutput":
|
||||||
|
return cls(
|
||||||
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("image_collection_output")
|
@invocation_output("image_collection_output")
|
||||||
class ImageCollectionOutput(BaseInvocationOutput):
|
class ImageCollectionOutput(BaseInvocationOutput):
|
||||||
@ -251,16 +255,14 @@ class ImageCollectionOutput(BaseInvocationOutput):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0")
|
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1")
|
||||||
class ImageInvocation(
|
class ImageInvocation(BaseInvocation):
|
||||||
BaseInvocation,
|
|
||||||
):
|
|
||||||
"""An image primitive value"""
|
"""An image primitive value"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The image to load")
|
image: ImageField = InputField(description="The image to load")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(image_name=self.image.image_name),
|
image=ImageField(image_name=self.image.image_name),
|
||||||
@ -290,42 +292,40 @@ class ImageCollectionInvocation(BaseInvocation):
|
|||||||
# region DenoiseMask
|
# region DenoiseMask
|
||||||
|
|
||||||
|
|
||||||
class DenoiseMaskField(BaseModel):
|
|
||||||
"""An inpaint mask field"""
|
|
||||||
|
|
||||||
mask_name: str = Field(description="The name of the mask image")
|
|
||||||
masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("denoise_mask_output")
|
@invocation_output("denoise_mask_output")
|
||||||
class DenoiseMaskOutput(BaseInvocationOutput):
|
class DenoiseMaskOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single image"""
|
"""Base class for nodes that output a single image"""
|
||||||
|
|
||||||
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput":
|
||||||
|
return cls(
|
||||||
|
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Latents
|
# region Latents
|
||||||
|
|
||||||
|
|
||||||
class LatentsField(BaseModel):
|
|
||||||
"""A latents tensor primitive field"""
|
|
||||||
|
|
||||||
latents_name: str = Field(description="The name of the latents")
|
|
||||||
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("latents_output")
|
@invocation_output("latents_output")
|
||||||
class LatentsOutput(BaseInvocationOutput):
|
class LatentsOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single latents tensor"""
|
"""Base class for nodes that output a single latents tensor"""
|
||||||
|
|
||||||
latents: LatentsField = OutputField(
|
latents: LatentsField = OutputField(description=FieldDescriptions.latents)
|
||||||
description=FieldDescriptions.latents,
|
|
||||||
)
|
|
||||||
width: int = OutputField(description=FieldDescriptions.width)
|
width: int = OutputField(description=FieldDescriptions.width)
|
||||||
height: int = OutputField(description=FieldDescriptions.height)
|
height: int = OutputField(description=FieldDescriptions.height)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput":
|
||||||
|
return cls(
|
||||||
|
latents=LatentsField(latents_name=latents_name, seed=seed),
|
||||||
|
width=latents.size()[3] * LATENT_SCALE_FACTOR,
|
||||||
|
height=latents.size()[2] * LATENT_SCALE_FACTOR,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("latents_collection_output")
|
@invocation_output("latents_collection_output")
|
||||||
class LatentsCollectionOutput(BaseInvocationOutput):
|
class LatentsCollectionOutput(BaseInvocationOutput):
|
||||||
@ -337,7 +337,7 @@ class LatentsCollectionOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0"
|
"latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.1"
|
||||||
)
|
)
|
||||||
class LatentsInvocation(BaseInvocation):
|
class LatentsInvocation(BaseInvocation):
|
||||||
"""A latents tensor primitive value"""
|
"""A latents tensor primitive value"""
|
||||||
@ -345,9 +345,9 @@ class LatentsInvocation(BaseInvocation):
|
|||||||
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
|
|
||||||
return build_latents_output(self.latents.latents_name, latents)
|
return LatentsOutput.build(self.latents.latents_name, latents)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -368,31 +368,11 @@ class LatentsCollectionInvocation(BaseInvocation):
|
|||||||
return LatentsCollectionOutput(collection=self.collection)
|
return LatentsCollectionOutput(collection=self.collection)
|
||||||
|
|
||||||
|
|
||||||
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None):
|
|
||||||
return LatentsOutput(
|
|
||||||
latents=LatentsField(latents_name=latents_name, seed=seed),
|
|
||||||
width=latents.size()[3] * 8,
|
|
||||||
height=latents.size()[2] * 8,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Color
|
# region Color
|
||||||
|
|
||||||
|
|
||||||
class ColorField(BaseModel):
|
|
||||||
"""A color primitive field"""
|
|
||||||
|
|
||||||
r: int = Field(ge=0, le=255, description="The red component")
|
|
||||||
g: int = Field(ge=0, le=255, description="The green component")
|
|
||||||
b: int = Field(ge=0, le=255, description="The blue component")
|
|
||||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
|
||||||
|
|
||||||
def tuple(self) -> Tuple[int, int, int, int]:
|
|
||||||
return (self.r, self.g, self.b, self.a)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("color_output")
|
@invocation_output("color_output")
|
||||||
class ColorOutput(BaseInvocationOutput):
|
class ColorOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single color"""
|
"""Base class for nodes that output a single color"""
|
||||||
@ -424,18 +404,16 @@ class ColorInvocation(BaseInvocation):
|
|||||||
# region Conditioning
|
# region Conditioning
|
||||||
|
|
||||||
|
|
||||||
class ConditioningField(BaseModel):
|
|
||||||
"""A conditioning tensor primitive value"""
|
|
||||||
|
|
||||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("conditioning_output")
|
@invocation_output("conditioning_output")
|
||||||
class ConditioningOutput(BaseInvocationOutput):
|
class ConditioningOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single conditioning tensor"""
|
"""Base class for nodes that output a single conditioning tensor"""
|
||||||
|
|
||||||
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build(cls, conditioning_name: str) -> "ConditioningOutput":
|
||||||
|
return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("conditioning_collection_output")
|
@invocation_output("conditioning_collection_output")
|
||||||
class ConditioningCollectionOutput(BaseInvocationOutput):
|
class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||||
|
@ -6,8 +6,10 @@ from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPrompt
|
|||||||
from pydantic import field_validator
|
from pydantic import field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import StringCollectionOutput
|
from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
|
from .baseinvocation import BaseInvocation, invocation
|
||||||
|
from .fields import InputField, UIComponent
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
|
@ -1,14 +1,10 @@
|
|||||||
from invokeai.app.shared.fields import FieldDescriptions
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
|
||||||
from ...backend.model_management import ModelType, SubModelType
|
from ...backend.model_management import ModelType, SubModelType
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
Input,
|
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
OutputField,
|
|
||||||
UIType,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
@ -34,7 +30,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
|||||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0")
|
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
|
||||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl base model, outputting its submodels."""
|
"""Loads an sdxl base model, outputting its submodels."""
|
||||||
|
|
||||||
@ -49,7 +45,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
model_type = ModelType.Main
|
model_type = ModelType.Main
|
||||||
|
|
||||||
# TODO: not found exceptions
|
# TODO: not found exceptions
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.models.exists(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
@ -120,7 +116,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
title="SDXL Refiner Model",
|
title="SDXL Refiner Model",
|
||||||
tags=["model", "sdxl", "refiner"],
|
tags=["model", "sdxl", "refiner"],
|
||||||
category="model",
|
category="model",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||||
@ -138,7 +134,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
|||||||
model_type = ModelType.Main
|
model_type = ModelType.Main
|
||||||
|
|
||||||
# TODO: not found exceptions
|
# TODO: not found exceptions
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.models.exists(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
|
@ -2,16 +2,15 @@
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
OutputField,
|
|
||||||
UIComponent,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
|
from .fields import InputField, OutputField, UIComponent
|
||||||
from .primitives import StringOutput
|
from .primitives import StringOutput
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,17 +5,13 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_valida
|
|||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
Input,
|
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
OutputField,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
||||||
from invokeai.app.invocations.primitives import ImageField
|
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
|
||||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.model_management.models.base import BaseModelType
|
from invokeai.backend.model_management.models.base import BaseModelType
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,16 +8,12 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
Classification,
|
Classification,
|
||||||
Input,
|
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
OutputField,
|
|
||||||
WithMetadata,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithBoard, WithMetadata
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.tiles.tiles import (
|
from invokeai.backend.tiles.tiles import (
|
||||||
calc_tiles_even_split,
|
calc_tiles_even_split,
|
||||||
calc_tiles_min_overlap,
|
calc_tiles_min_overlap,
|
||||||
@ -236,7 +232,7 @@ BLEND_MODES = Literal["Linear", "Seam"]
|
|||||||
version="1.1.0",
|
version="1.1.0",
|
||||||
classification=Classification.Beta,
|
classification=Classification.Beta,
|
||||||
)
|
)
|
||||||
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
|
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Merge multiple tile images into a single image."""
|
"""Merge multiple tile images into a single image."""
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
@ -268,7 +264,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
|
|||||||
# existed in memory at an earlier point in the graph.
|
# existed in memory at an earlier point in the graph.
|
||||||
tile_np_images: list[np.ndarray] = []
|
tile_np_images: list[np.ndarray] = []
|
||||||
for image in images:
|
for image in images:
|
||||||
pil_image = context.services.images.get_pil_image(image.image_name)
|
pil_image = context.images.get_pil(image.image_name)
|
||||||
pil_image = pil_image.convert("RGB")
|
pil_image = pil_image.convert("RGB")
|
||||||
tile_np_images.append(np.array(pil_image))
|
tile_np_images.append(np.array(pil_image))
|
||||||
|
|
||||||
@ -291,18 +287,5 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
|
|||||||
# Convert into a PIL image and save
|
# Convert into a PIL image and save
|
||||||
pil_image = Image.fromarray(np_image)
|
pil_image = Image.fromarray(np_image)
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=pil_image)
|
||||||
image=pil_image,
|
return ImageOutput.build(image_dto)
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
metadata=self.metadata,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
return ImageOutput(
|
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
@ -8,13 +8,15 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.fields import ImageField
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
|
from .baseinvocation import BaseInvocation, invocation
|
||||||
|
from .fields import InputField, WithBoard, WithMetadata
|
||||||
|
|
||||||
# TODO: Populate this from disk?
|
# TODO: Populate this from disk?
|
||||||
# TODO: Use model manager to load?
|
# TODO: Use model manager to load?
|
||||||
@ -29,8 +31,8 @@ if choose_torch_device() == torch.device("mps"):
|
|||||||
from torch import mps
|
from torch import mps
|
||||||
|
|
||||||
|
|
||||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0")
|
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.1")
|
||||||
class ESRGANInvocation(BaseInvocation, WithMetadata):
|
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Upscales an image using RealESRGAN."""
|
"""Upscales an image using RealESRGAN."""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The input image")
|
image: ImageField = InputField(description="The input image")
|
||||||
@ -42,8 +44,8 @@ class ESRGANInvocation(BaseInvocation, WithMetadata):
|
|||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
models_path = context.services.configuration.models_path
|
models_path = context.config.get().models_path
|
||||||
|
|
||||||
rrdbnet_model = None
|
rrdbnet_model = None
|
||||||
netscale = None
|
netscale = None
|
||||||
@ -87,7 +89,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata):
|
|||||||
netscale = 2
|
netscale = 2
|
||||||
else:
|
else:
|
||||||
msg = f"Invalid RealESRGAN model: {self.model_name}"
|
msg = f"Invalid RealESRGAN model: {self.model_name}"
|
||||||
context.services.logger.error(msg)
|
context.logger.error(msg)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
|
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
|
||||||
@ -110,19 +112,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata):
|
|||||||
if choose_torch_device() == torch.device("mps"):
|
if choose_torch_device() == torch.device("mps"):
|
||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=pil_image)
|
||||||
image=pil_image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
metadata=self.metadata,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput.build(image_dto)
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
@ -11,7 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
SessionQueueStatus,
|
SessionQueueStatus,
|
||||||
)
|
)
|
||||||
from invokeai.app.util.misc import get_timestamp
|
from invokeai.app.util.misc import get_timestamp
|
||||||
from invokeai.backend.model_management.model_manager import ModelInfo
|
from invokeai.backend.model_management.model_manager import LoadedModelInfo
|
||||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||||
|
|
||||||
|
|
||||||
@ -55,7 +55,7 @@ class EventServiceBase:
|
|||||||
queue_item_id: int,
|
queue_item_id: int,
|
||||||
queue_batch_id: str,
|
queue_batch_id: str,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
node: dict,
|
node_id: str,
|
||||||
source_node_id: str,
|
source_node_id: str,
|
||||||
progress_image: Optional[ProgressImage],
|
progress_image: Optional[ProgressImage],
|
||||||
step: int,
|
step: int,
|
||||||
@ -70,7 +70,7 @@ class EventServiceBase:
|
|||||||
"queue_item_id": queue_item_id,
|
"queue_item_id": queue_item_id,
|
||||||
"queue_batch_id": queue_batch_id,
|
"queue_batch_id": queue_batch_id,
|
||||||
"graph_execution_state_id": graph_execution_state_id,
|
"graph_execution_state_id": graph_execution_state_id,
|
||||||
"node_id": node.get("id"),
|
"node_id": node_id,
|
||||||
"source_node_id": source_node_id,
|
"source_node_id": source_node_id,
|
||||||
"progress_image": progress_image.model_dump() if progress_image is not None else None,
|
"progress_image": progress_image.model_dump() if progress_image is not None else None,
|
||||||
"step": step,
|
"step": step,
|
||||||
@ -201,7 +201,7 @@ class EventServiceBase:
|
|||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel: SubModelType,
|
submodel: SubModelType,
|
||||||
model_info: ModelInfo,
|
loaded_model_info: LoadedModelInfo,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
@ -215,9 +215,9 @@ class EventServiceBase:
|
|||||||
"base_model": base_model,
|
"base_model": base_model,
|
||||||
"model_type": model_type,
|
"model_type": model_type,
|
||||||
"submodel": submodel,
|
"submodel": submodel,
|
||||||
"hash": model_info.hash,
|
"hash": loaded_model_info.hash,
|
||||||
"location": str(model_info.location),
|
"location": str(loaded_model_info.location),
|
||||||
"precision": str(model_info.precision),
|
"precision": str(loaded_model_info.precision),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
from invokeai.app.invocations.fields import MetadataField
|
||||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from PIL import Image, PngImagePlugin
|
|||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
from send2trash import send2trash
|
from send2trash import send2trash
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
from invokeai.app.invocations.fields import MetadataField
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||||
|
@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import MetadataField
|
from invokeai.app.invocations.fields import MetadataField
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
|
|
||||||
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
|
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
|
||||||
|
@ -3,7 +3,7 @@ import threading
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Union, cast
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
|
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ from typing import Callable, Optional
|
|||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
from invokeai.app.invocations.fields import MetadataField
|
||||||
from invokeai.app.services.image_records.image_records_common import (
|
from invokeai.app.services.image_records.image_records_common import (
|
||||||
ImageCategory,
|
ImageCategory,
|
||||||
ImageRecord,
|
ImageRecord,
|
||||||
|
@ -2,7 +2,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
from invokeai.app.invocations.fields import MetadataField
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||||
|
@ -37,7 +37,8 @@ class MemoryInvocationCache(InvocationCacheBase):
|
|||||||
if self._max_cache_size == 0:
|
if self._max_cache_size == 0:
|
||||||
return
|
return
|
||||||
self._invoker.services.images.on_deleted(self._delete_by_match)
|
self._invoker.services.images.on_deleted(self._delete_by_match)
|
||||||
self._invoker.services.latents.on_deleted(self._delete_by_match)
|
self._invoker.services.tensors.on_deleted(self._delete_by_match)
|
||||||
|
self._invoker.services.conditioning.on_deleted(self._delete_by_match)
|
||||||
|
|
||||||
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
@ -5,11 +5,11 @@ from threading import BoundedSemaphore, Event, Thread
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
|
||||||
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
|
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
|
||||||
from invokeai.app.services.invocation_stats.invocation_stats_common import (
|
from invokeai.app.services.invocation_stats.invocation_stats_common import (
|
||||||
GESStatsNotFoundError,
|
GESStatsNotFoundError,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||||
from invokeai.app.util.profiler import Profiler
|
from invokeai.app.util.profiler import Profiler
|
||||||
|
|
||||||
from ..invoker import Invoker
|
from ..invoker import Invoker
|
||||||
@ -131,16 +131,20 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
# which handles a few things:
|
# which handles a few things:
|
||||||
# - nodes that require a value, but get it only from a connection
|
# - nodes that require a value, but get it only from a connection
|
||||||
# - referencing the invocation cache instead of executing the node
|
# - referencing the invocation cache instead of executing the node
|
||||||
outputs = invocation.invoke_internal(
|
context_data = InvocationContextData(
|
||||||
InvocationContext(
|
invocation=invocation,
|
||||||
services=self.__invoker.services,
|
session_id=graph_id,
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
workflow=queue_item.workflow,
|
||||||
queue_item_id=queue_item.session_queue_item_id,
|
source_node_id=source_node_id,
|
||||||
queue_id=queue_item.session_queue_id,
|
queue_id=queue_item.session_queue_id,
|
||||||
queue_batch_id=queue_item.session_queue_batch_id,
|
queue_item_id=queue_item.session_queue_item_id,
|
||||||
workflow=queue_item.workflow,
|
batch_id=queue_item.session_queue_batch_id,
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
context = build_invocation_context(
|
||||||
|
services=self.__invoker.services,
|
||||||
|
context_data=context_data,
|
||||||
|
)
|
||||||
|
outputs = invocation.invoke_internal(context=context, services=self.__invoker.services)
|
||||||
|
|
||||||
# Check queue to see if this is canceled, and skip if so
|
# Check queue to see if this is canceled, and skip if so
|
||||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||||
|
@ -3,9 +3,15 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||||
|
|
||||||
from .board_image_records.board_image_records_base import BoardImageRecordStorageBase
|
from .board_image_records.board_image_records_base import BoardImageRecordStorageBase
|
||||||
from .board_images.board_images_base import BoardImagesServiceABC
|
from .board_images.board_images_base import BoardImagesServiceABC
|
||||||
from .board_records.board_records_base import BoardRecordStorageBase
|
from .board_records.board_records_base import BoardRecordStorageBase
|
||||||
@ -21,7 +27,6 @@ if TYPE_CHECKING:
|
|||||||
from .invocation_queue.invocation_queue_base import InvocationQueueABC
|
from .invocation_queue.invocation_queue_base import InvocationQueueABC
|
||||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||||
from .item_storage.item_storage_base import ItemStorageABC
|
from .item_storage.item_storage_base import ItemStorageABC
|
||||||
from .latents_storage.latents_storage_base import LatentsStorageBase
|
|
||||||
from .model_install import ModelInstallServiceBase
|
from .model_install import ModelInstallServiceBase
|
||||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||||
from .model_records import ModelRecordServiceBase
|
from .model_records import ModelRecordServiceBase
|
||||||
@ -36,33 +41,6 @@ if TYPE_CHECKING:
|
|||||||
class InvocationServices:
|
class InvocationServices:
|
||||||
"""Services that can be used by invocations"""
|
"""Services that can be used by invocations"""
|
||||||
|
|
||||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
|
||||||
board_images: "BoardImagesServiceABC"
|
|
||||||
board_image_record_storage: "BoardImageRecordStorageBase"
|
|
||||||
boards: "BoardServiceABC"
|
|
||||||
board_records: "BoardRecordStorageBase"
|
|
||||||
configuration: "InvokeAIAppConfig"
|
|
||||||
events: "EventServiceBase"
|
|
||||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
|
|
||||||
images: "ImageServiceABC"
|
|
||||||
image_records: "ImageRecordStorageBase"
|
|
||||||
image_files: "ImageFileStorageBase"
|
|
||||||
latents: "LatentsStorageBase"
|
|
||||||
logger: "Logger"
|
|
||||||
model_manager: "ModelManagerServiceBase"
|
|
||||||
model_records: "ModelRecordServiceBase"
|
|
||||||
download_queue: "DownloadQueueServiceBase"
|
|
||||||
model_install: "ModelInstallServiceBase"
|
|
||||||
processor: "InvocationProcessorABC"
|
|
||||||
performance_statistics: "InvocationStatsServiceBase"
|
|
||||||
queue: "InvocationQueueABC"
|
|
||||||
session_queue: "SessionQueueBase"
|
|
||||||
session_processor: "SessionProcessorBase"
|
|
||||||
invocation_cache: "InvocationCacheBase"
|
|
||||||
names: "NameServiceBase"
|
|
||||||
urls: "UrlServiceBase"
|
|
||||||
workflow_records: "WorkflowRecordsStorageBase"
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
board_images: "BoardImagesServiceABC",
|
board_images: "BoardImagesServiceABC",
|
||||||
@ -75,7 +53,6 @@ class InvocationServices:
|
|||||||
images: "ImageServiceABC",
|
images: "ImageServiceABC",
|
||||||
image_files: "ImageFileStorageBase",
|
image_files: "ImageFileStorageBase",
|
||||||
image_records: "ImageRecordStorageBase",
|
image_records: "ImageRecordStorageBase",
|
||||||
latents: "LatentsStorageBase",
|
|
||||||
logger: "Logger",
|
logger: "Logger",
|
||||||
model_manager: "ModelManagerServiceBase",
|
model_manager: "ModelManagerServiceBase",
|
||||||
model_records: "ModelRecordServiceBase",
|
model_records: "ModelRecordServiceBase",
|
||||||
@ -90,6 +67,8 @@ class InvocationServices:
|
|||||||
names: "NameServiceBase",
|
names: "NameServiceBase",
|
||||||
urls: "UrlServiceBase",
|
urls: "UrlServiceBase",
|
||||||
workflow_records: "WorkflowRecordsStorageBase",
|
workflow_records: "WorkflowRecordsStorageBase",
|
||||||
|
tensors: "ObjectSerializerBase[torch.Tensor]",
|
||||||
|
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.board_image_records = board_image_records
|
self.board_image_records = board_image_records
|
||||||
@ -101,7 +80,6 @@ class InvocationServices:
|
|||||||
self.images = images
|
self.images = images
|
||||||
self.image_files = image_files
|
self.image_files = image_files
|
||||||
self.image_records = image_records
|
self.image_records = image_records
|
||||||
self.latents = latents
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
self.model_records = model_records
|
self.model_records = model_records
|
||||||
@ -116,3 +94,5 @@ class InvocationServices:
|
|||||||
self.names = names
|
self.names = names
|
||||||
self.urls = urls
|
self.urls = urls
|
||||||
self.workflow_records = workflow_records
|
self.workflow_records = workflow_records
|
||||||
|
self.tensors = tensors
|
||||||
|
self.conditioning = conditioning
|
||||||
|
@ -30,7 +30,7 @@ class ItemStorageABC(ABC, Generic[T]):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set(self, item: T) -> None:
|
def set(self, item: T) -> None:
|
||||||
"""
|
"""
|
||||||
Sets the item. The id will be extracted based on id_field.
|
Sets the item.
|
||||||
:param item: the item to set
|
:param item: the item to set
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
@ -1,45 +0,0 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class LatentsStorageBase(ABC):
|
|
||||||
"""Responsible for storing and retrieving latents."""
|
|
||||||
|
|
||||||
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
|
|
||||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._on_changed_callbacks = []
|
|
||||||
self._on_deleted_callbacks = []
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, name: str) -> torch.Tensor:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save(self, name: str, data: torch.Tensor) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, name: str) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
|
|
||||||
"""Register a callback for when an item is changed"""
|
|
||||||
self._on_changed_callbacks.append(on_changed)
|
|
||||||
|
|
||||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
|
||||||
"""Register a callback for when an item is deleted"""
|
|
||||||
self._on_deleted_callbacks.append(on_deleted)
|
|
||||||
|
|
||||||
def _on_changed(self, item: torch.Tensor) -> None:
|
|
||||||
for callback in self._on_changed_callbacks:
|
|
||||||
callback(item)
|
|
||||||
|
|
||||||
def _on_deleted(self, item_id: str) -> None:
|
|
||||||
for callback in self._on_deleted_callbacks:
|
|
||||||
callback(item_id)
|
|
@ -1,58 +0,0 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
|
||||||
|
|
||||||
from .latents_storage_base import LatentsStorageBase
|
|
||||||
|
|
||||||
|
|
||||||
class DiskLatentsStorage(LatentsStorageBase):
|
|
||||||
"""Stores latents in a folder on disk without caching"""
|
|
||||||
|
|
||||||
__output_folder: Path
|
|
||||||
|
|
||||||
def __init__(self, output_folder: Union[str, Path]):
|
|
||||||
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
|
||||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
|
||||||
self._invoker = invoker
|
|
||||||
self._delete_all_latents()
|
|
||||||
|
|
||||||
def get(self, name: str) -> torch.Tensor:
|
|
||||||
latent_path = self.get_path(name)
|
|
||||||
return torch.load(latent_path)
|
|
||||||
|
|
||||||
def save(self, name: str, data: torch.Tensor) -> None:
|
|
||||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
|
||||||
latent_path = self.get_path(name)
|
|
||||||
torch.save(data, latent_path)
|
|
||||||
|
|
||||||
def delete(self, name: str) -> None:
|
|
||||||
latent_path = self.get_path(name)
|
|
||||||
latent_path.unlink()
|
|
||||||
|
|
||||||
def get_path(self, name: str) -> Path:
|
|
||||||
return self.__output_folder / name
|
|
||||||
|
|
||||||
def _delete_all_latents(self) -> None:
|
|
||||||
"""
|
|
||||||
Deletes all latents from disk.
|
|
||||||
Must be called after we have access to `self._invoker` (e.g. in `start()`).
|
|
||||||
"""
|
|
||||||
deleted_latents_count = 0
|
|
||||||
freed_space = 0
|
|
||||||
for latents_file in Path(self.__output_folder).glob("*"):
|
|
||||||
if latents_file.is_file():
|
|
||||||
freed_space += latents_file.stat().st_size
|
|
||||||
deleted_latents_count += 1
|
|
||||||
latents_file.unlink()
|
|
||||||
if deleted_latents_count > 0:
|
|
||||||
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
|
|
||||||
self._invoker.services.logger.info(
|
|
||||||
f"Deleted {deleted_latents_count} latents files (freed {freed_space_in_mb}MB)"
|
|
||||||
)
|
|
@ -1,68 +0,0 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
from queue import Queue
|
|
||||||
from typing import Dict, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
|
||||||
|
|
||||||
from .latents_storage_base import LatentsStorageBase
|
|
||||||
|
|
||||||
|
|
||||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
|
||||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
|
||||||
|
|
||||||
__cache: Dict[str, torch.Tensor]
|
|
||||||
__cache_ids: Queue
|
|
||||||
__max_cache_size: int
|
|
||||||
__underlying_storage: LatentsStorageBase
|
|
||||||
|
|
||||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
|
||||||
super().__init__()
|
|
||||||
self.__underlying_storage = underlying_storage
|
|
||||||
self.__cache = {}
|
|
||||||
self.__cache_ids = Queue()
|
|
||||||
self.__max_cache_size = max_cache_size
|
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
|
||||||
self._invoker = invoker
|
|
||||||
start_op = getattr(self.__underlying_storage, "start", None)
|
|
||||||
if callable(start_op):
|
|
||||||
start_op(invoker)
|
|
||||||
|
|
||||||
def stop(self, invoker: Invoker) -> None:
|
|
||||||
self._invoker = invoker
|
|
||||||
stop_op = getattr(self.__underlying_storage, "stop", None)
|
|
||||||
if callable(stop_op):
|
|
||||||
stop_op(invoker)
|
|
||||||
|
|
||||||
def get(self, name: str) -> torch.Tensor:
|
|
||||||
cache_item = self.__get_cache(name)
|
|
||||||
if cache_item is not None:
|
|
||||||
return cache_item
|
|
||||||
|
|
||||||
latent = self.__underlying_storage.get(name)
|
|
||||||
self.__set_cache(name, latent)
|
|
||||||
return latent
|
|
||||||
|
|
||||||
def save(self, name: str, data: torch.Tensor) -> None:
|
|
||||||
self.__underlying_storage.save(name, data)
|
|
||||||
self.__set_cache(name, data)
|
|
||||||
self._on_changed(data)
|
|
||||||
|
|
||||||
def delete(self, name: str) -> None:
|
|
||||||
self.__underlying_storage.delete(name)
|
|
||||||
if name in self.__cache:
|
|
||||||
del self.__cache[name]
|
|
||||||
self._on_deleted(name)
|
|
||||||
|
|
||||||
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
|
|
||||||
return None if name not in self.__cache else self.__cache[name]
|
|
||||||
|
|
||||||
def __set_cache(self, name: str, data: torch.Tensor):
|
|
||||||
if name not in self.__cache:
|
|
||||||
self.__cache[name] = data
|
|
||||||
self.__cache_ids.put(name)
|
|
||||||
if self.__cache_ids.qsize() > self.__max_cache_size:
|
|
||||||
self.__cache.pop(self.__cache_ids.get())
|
|
@ -5,25 +5,23 @@ from __future__ import annotations
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
from typing import Callable, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||||
from invokeai.backend.model_management import (
|
from invokeai.backend.model_management import (
|
||||||
AddModelResult,
|
AddModelResult,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
|
LoadedModelInfo,
|
||||||
MergeInterpolationMethod,
|
MergeInterpolationMethod,
|
||||||
ModelInfo,
|
|
||||||
ModelType,
|
ModelType,
|
||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_management.model_cache import CacheStats
|
from invokeai.backend.model_management.model_cache import CacheStats
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext
|
|
||||||
|
|
||||||
|
|
||||||
class ModelManagerServiceBase(ABC):
|
class ModelManagerServiceBase(ABC):
|
||||||
"""Responsible for managing models on disk and in memory"""
|
"""Responsible for managing models on disk and in memory"""
|
||||||
@ -49,9 +47,8 @@ class ModelManagerServiceBase(ABC):
|
|||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel: Optional[SubModelType] = None,
|
submodel: Optional[SubModelType] = None,
|
||||||
node: Optional[BaseInvocation] = None,
|
context_data: Optional[InvocationContextData] = None,
|
||||||
context: Optional[InvocationContext] = None,
|
) -> LoadedModelInfo:
|
||||||
) -> ModelInfo:
|
|
||||||
"""Retrieve the indicated model with name and type.
|
"""Retrieve the indicated model with name and type.
|
||||||
submodel can be used to get a part (such as the vae)
|
submodel can be used to get a part (such as the vae)
|
||||||
of a diffusers pipeline."""
|
of a diffusers pipeline."""
|
||||||
|
@ -11,11 +11,13 @@ from pydantic import Field
|
|||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||||
from invokeai.backend.model_management import (
|
from invokeai.backend.model_management import (
|
||||||
AddModelResult,
|
AddModelResult,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
|
LoadedModelInfo,
|
||||||
MergeInterpolationMethod,
|
MergeInterpolationMethod,
|
||||||
ModelInfo,
|
|
||||||
ModelManager,
|
ModelManager,
|
||||||
ModelMerger,
|
ModelMerger,
|
||||||
ModelNotFoundException,
|
ModelNotFoundException,
|
||||||
@ -30,7 +32,7 @@ from invokeai.backend.util import choose_precision, choose_torch_device
|
|||||||
from .model_manager_base import ModelManagerServiceBase
|
from .model_manager_base import ModelManagerServiceBase
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
pass
|
||||||
|
|
||||||
|
|
||||||
# simple implementation
|
# simple implementation
|
||||||
@ -86,47 +88,50 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
)
|
)
|
||||||
logger.info("Model manager service initialized")
|
logger.info("Model manager service initialized")
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
self._invoker: Optional[Invoker] = invoker
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel: Optional[SubModelType] = None,
|
submodel: Optional[SubModelType] = None,
|
||||||
context: Optional[InvocationContext] = None,
|
context_data: Optional[InvocationContextData] = None,
|
||||||
) -> ModelInfo:
|
) -> LoadedModelInfo:
|
||||||
"""
|
"""
|
||||||
Retrieve the indicated model. submodel can be used to get a
|
Retrieve the indicated model. submodel can be used to get a
|
||||||
part (such as the vae) of a diffusers mode.
|
part (such as the vae) of a diffusers mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# we can emit model loading events if we are executing with access to the invocation context
|
# we can emit model loading events if we are executing with access to the invocation context
|
||||||
if context:
|
if context_data is not None:
|
||||||
self._emit_load_event(
|
self._emit_load_event(
|
||||||
context=context,
|
context_data=context_data,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_info = self.mgr.get_model(
|
loaded_model_info = self.mgr.get_model(
|
||||||
model_name,
|
model_name,
|
||||||
base_model,
|
base_model,
|
||||||
model_type,
|
model_type,
|
||||||
submodel,
|
submodel,
|
||||||
)
|
)
|
||||||
|
|
||||||
if context:
|
if context_data is not None:
|
||||||
self._emit_load_event(
|
self._emit_load_event(
|
||||||
context=context,
|
context_data=context_data,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
model_info=model_info,
|
loaded_model_info=loaded_model_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_info
|
return loaded_model_info
|
||||||
|
|
||||||
def model_exists(
|
def model_exists(
|
||||||
self,
|
self,
|
||||||
@ -263,34 +268,37 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
|
|
||||||
def _emit_load_event(
|
def _emit_load_event(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context_data: InvocationContextData,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel: Optional[SubModelType] = None,
|
submodel: Optional[SubModelType] = None,
|
||||||
model_info: Optional[ModelInfo] = None,
|
loaded_model_info: Optional[LoadedModelInfo] = None,
|
||||||
):
|
):
|
||||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
if self._invoker is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._invoker.services.queue.is_canceled(context_data.session_id):
|
||||||
raise CanceledException()
|
raise CanceledException()
|
||||||
|
|
||||||
if model_info:
|
if loaded_model_info:
|
||||||
context.services.events.emit_model_load_completed(
|
self._invoker.services.events.emit_model_load_completed(
|
||||||
queue_id=context.queue_id,
|
queue_id=context_data.queue_id,
|
||||||
queue_item_id=context.queue_item_id,
|
queue_item_id=context_data.queue_item_id,
|
||||||
queue_batch_id=context.queue_batch_id,
|
queue_batch_id=context_data.batch_id,
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
graph_execution_state_id=context_data.session_id,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
model_info=model_info,
|
loaded_model_info=loaded_model_info,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
context.services.events.emit_model_load_started(
|
self._invoker.services.events.emit_model_load_started(
|
||||||
queue_id=context.queue_id,
|
queue_id=context_data.queue_id,
|
||||||
queue_item_id=context.queue_item_id,
|
queue_item_id=context_data.queue_item_id,
|
||||||
queue_batch_id=context.queue_batch_id,
|
queue_batch_id=context_data.batch_id,
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
graph_execution_state_id=context_data.session_id,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
|
@ -0,0 +1,44 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Callable, Generic, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectSerializerBase(ABC, Generic[T]):
|
||||||
|
"""Saves and loads arbitrary python objects."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._on_deleted_callbacks: list[Callable[[str], None]] = []
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load(self, name: str) -> T:
|
||||||
|
"""
|
||||||
|
Loads the object.
|
||||||
|
:param name: The name of the object to load.
|
||||||
|
:raises ObjectNotFoundError: if the object is not found
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(self, obj: T) -> str:
|
||||||
|
"""
|
||||||
|
Saves the object, returning its name.
|
||||||
|
:param obj: The object to save.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, name: str) -> None:
|
||||||
|
"""
|
||||||
|
Deletes the object, if it exists.
|
||||||
|
:param name: The name of the object to delete.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||||
|
"""Register a callback for when an object is deleted"""
|
||||||
|
self._on_deleted_callbacks.append(on_deleted)
|
||||||
|
|
||||||
|
def _on_deleted(self, name: str) -> None:
|
||||||
|
for callback in self._on_deleted_callbacks:
|
||||||
|
callback(name)
|
@ -0,0 +1,5 @@
|
|||||||
|
class ObjectNotFoundError(KeyError):
|
||||||
|
"""Raised when an object is not found while loading"""
|
||||||
|
|
||||||
|
def __init__(self, name: str) -> None:
|
||||||
|
super().__init__(f"Object with name {name} not found")
|
@ -0,0 +1,85 @@
|
|||||||
|
import tempfile
|
||||||
|
import typing
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||||
|
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeleteAllResult:
|
||||||
|
deleted_count: int
|
||||||
|
freed_space_bytes: float
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||||
|
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
|
||||||
|
|
||||||
|
:param output_dir: The folder where the serialized objects will be stored
|
||||||
|
:param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, output_dir: Path, ephemeral: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self._ephemeral = ephemeral
|
||||||
|
self._base_output_dir = output_dir
|
||||||
|
self._base_output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
# Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows
|
||||||
|
self._tempdir = (
|
||||||
|
tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None
|
||||||
|
)
|
||||||
|
self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir
|
||||||
|
self.__obj_class_name: Optional[str] = None
|
||||||
|
|
||||||
|
def load(self, name: str) -> T:
|
||||||
|
file_path = self._get_path(name)
|
||||||
|
try:
|
||||||
|
return torch.load(file_path) # pyright: ignore [reportUnknownMemberType]
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
raise ObjectNotFoundError(name) from e
|
||||||
|
|
||||||
|
def save(self, obj: T) -> str:
|
||||||
|
name = self._new_name()
|
||||||
|
file_path = self._get_path(name)
|
||||||
|
torch.save(obj, file_path) # pyright: ignore [reportUnknownMemberType]
|
||||||
|
return name
|
||||||
|
|
||||||
|
def delete(self, name: str) -> None:
|
||||||
|
file_path = self._get_path(name)
|
||||||
|
file_path.unlink()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _obj_class_name(self) -> str:
|
||||||
|
if not self.__obj_class_name:
|
||||||
|
# `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
|
||||||
|
self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue]
|
||||||
|
return self.__obj_class_name
|
||||||
|
|
||||||
|
def _get_path(self, name: str) -> Path:
|
||||||
|
return self._output_dir / name
|
||||||
|
|
||||||
|
def _new_name(self) -> str:
|
||||||
|
return f"{self._obj_class_name}_{uuid_string()}"
|
||||||
|
|
||||||
|
def _tempdir_cleanup(self) -> None:
|
||||||
|
"""Calls `cleanup` on the temporary directory, if it exists."""
|
||||||
|
if self._tempdir:
|
||||||
|
self._tempdir.cleanup()
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
# In case the service is not properly stopped, clean up the temporary directory when the class instance is GC'd.
|
||||||
|
self._tempdir_cleanup()
|
||||||
|
|
||||||
|
def stop(self, invoker: "Invoker") -> None:
|
||||||
|
self._tempdir_cleanup()
|
@ -0,0 +1,65 @@
|
|||||||
|
from queue import Queue
|
||||||
|
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||||
|
|
||||||
|
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
|
||||||
|
"""
|
||||||
|
Provides a LRU cache for an instance of `ObjectSerializerBase`.
|
||||||
|
Saving an object to the cache always writes through to the underlying storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
|
||||||
|
super().__init__()
|
||||||
|
self._underlying_storage = underlying_storage
|
||||||
|
self._cache: dict[str, T] = {}
|
||||||
|
self._cache_ids = Queue[str]()
|
||||||
|
self._max_cache_size = max_cache_size
|
||||||
|
|
||||||
|
def start(self, invoker: "Invoker") -> None:
|
||||||
|
self._invoker = invoker
|
||||||
|
start_op = getattr(self._underlying_storage, "start", None)
|
||||||
|
if callable(start_op):
|
||||||
|
start_op(invoker)
|
||||||
|
|
||||||
|
def stop(self, invoker: "Invoker") -> None:
|
||||||
|
self._invoker = invoker
|
||||||
|
stop_op = getattr(self._underlying_storage, "stop", None)
|
||||||
|
if callable(stop_op):
|
||||||
|
stop_op(invoker)
|
||||||
|
|
||||||
|
def load(self, name: str) -> T:
|
||||||
|
cache_item = self._get_cache(name)
|
||||||
|
if cache_item is not None:
|
||||||
|
return cache_item
|
||||||
|
|
||||||
|
obj = self._underlying_storage.load(name)
|
||||||
|
self._set_cache(name, obj)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def save(self, obj: T) -> str:
|
||||||
|
name = self._underlying_storage.save(obj)
|
||||||
|
self._set_cache(name, obj)
|
||||||
|
return name
|
||||||
|
|
||||||
|
def delete(self, name: str) -> None:
|
||||||
|
self._underlying_storage.delete(name)
|
||||||
|
if name in self._cache:
|
||||||
|
del self._cache[name]
|
||||||
|
self._on_deleted(name)
|
||||||
|
|
||||||
|
def _get_cache(self, name: str) -> Optional[T]:
|
||||||
|
return None if name not in self._cache else self._cache[name]
|
||||||
|
|
||||||
|
def _set_cache(self, name: str, data: T):
|
||||||
|
if name not in self._cache:
|
||||||
|
self._cache[name] = data
|
||||||
|
self._cache_ids.put(name)
|
||||||
|
if self._cache_ids.qsize() > self._max_cache_size:
|
||||||
|
self._cache.pop(self._cache_ids.get())
|
@ -13,14 +13,11 @@ from invokeai.app.invocations import * # noqa: F401 F403
|
|||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
Input,
|
|
||||||
InputField,
|
|
||||||
InvocationContext,
|
|
||||||
OutputField,
|
|
||||||
UIType,
|
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
# in 3.10 this would be "from types import NoneType"
|
# in 3.10 this would be "from types import NoneType"
|
||||||
|
409
invokeai/app/services/shared/invocation_context.py
Normal file
409
invokeai/app/services/shared/invocation_context.py
Normal file
@ -0,0 +1,409 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from PIL.Image import Image
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
||||||
|
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||||
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
|
from invokeai.app.services.images.images_common import ImageDTO
|
||||||
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
|
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||||
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
|
from invokeai.backend.model_management.model_manager import LoadedModelInfo
|
||||||
|
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||||
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
|
|
||||||
|
"""
|
||||||
|
The InvocationContext provides access to various services and data about the current invocation.
|
||||||
|
|
||||||
|
We do not provide the invocation services directly, as their methods are both dangerous and
|
||||||
|
inconvenient to use.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
- The `images` service allows nodes to delete or unsafely modify existing images.
|
||||||
|
- The `configuration` service allows nodes to change the app's config at runtime.
|
||||||
|
- The `events` service allows nodes to emit arbitrary events.
|
||||||
|
|
||||||
|
Wrapping these services provides a simpler and safer interface for nodes to use.
|
||||||
|
|
||||||
|
When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere
|
||||||
|
with each other.
|
||||||
|
|
||||||
|
Many of the wrappers have the same signature as the methods they wrap. This allows us to write
|
||||||
|
user-facing docstrings and not need to go and update the internal services to match.
|
||||||
|
|
||||||
|
Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InvocationContextData:
|
||||||
|
invocation: "BaseInvocation"
|
||||||
|
"""The invocation that is being executed."""
|
||||||
|
session_id: str
|
||||||
|
"""The session that is being executed."""
|
||||||
|
queue_id: str
|
||||||
|
"""The queue in which the session is being executed."""
|
||||||
|
source_node_id: str
|
||||||
|
"""The ID of the node from which the currently executing invocation was prepared."""
|
||||||
|
queue_item_id: int
|
||||||
|
"""The ID of the queue item that is being executed."""
|
||||||
|
batch_id: str
|
||||||
|
"""The ID of the batch that is being executed."""
|
||||||
|
workflow: Optional[WorkflowWithoutID] = None
|
||||||
|
"""The workflow associated with this queue item, if any."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationContextInterface:
|
||||||
|
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
|
||||||
|
self._services = services
|
||||||
|
self._context_data = context_data
|
||||||
|
|
||||||
|
|
||||||
|
class BoardsInterface(InvocationContextInterface):
|
||||||
|
def create(self, board_name: str) -> BoardDTO:
|
||||||
|
"""
|
||||||
|
Creates a board.
|
||||||
|
|
||||||
|
:param board_name: The name of the board to create.
|
||||||
|
"""
|
||||||
|
return self._services.boards.create(board_name)
|
||||||
|
|
||||||
|
def get_dto(self, board_id: str) -> BoardDTO:
|
||||||
|
"""
|
||||||
|
Gets a board DTO.
|
||||||
|
|
||||||
|
:param board_id: The ID of the board to get.
|
||||||
|
"""
|
||||||
|
return self._services.boards.get_dto(board_id)
|
||||||
|
|
||||||
|
def get_all(self) -> list[BoardDTO]:
|
||||||
|
"""
|
||||||
|
Gets all boards.
|
||||||
|
"""
|
||||||
|
return self._services.boards.get_all()
|
||||||
|
|
||||||
|
def add_image_to_board(self, board_id: str, image_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Adds an image to a board.
|
||||||
|
|
||||||
|
:param board_id: The ID of the board to add the image to.
|
||||||
|
:param image_name: The name of the image to add to the board.
|
||||||
|
"""
|
||||||
|
return self._services.board_images.add_image_to_board(board_id, image_name)
|
||||||
|
|
||||||
|
def get_all_image_names_for_board(self, board_id: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Gets all image names for a board.
|
||||||
|
|
||||||
|
:param board_id: The ID of the board to get the image names for.
|
||||||
|
"""
|
||||||
|
return self._services.board_images.get_all_board_image_names_for_board(board_id)
|
||||||
|
|
||||||
|
|
||||||
|
class LoggerInterface(InvocationContextInterface):
|
||||||
|
def debug(self, message: str) -> None:
|
||||||
|
"""
|
||||||
|
Logs a debug message.
|
||||||
|
|
||||||
|
:param message: The message to log.
|
||||||
|
"""
|
||||||
|
self._services.logger.debug(message)
|
||||||
|
|
||||||
|
def info(self, message: str) -> None:
|
||||||
|
"""
|
||||||
|
Logs an info message.
|
||||||
|
|
||||||
|
:param message: The message to log.
|
||||||
|
"""
|
||||||
|
self._services.logger.info(message)
|
||||||
|
|
||||||
|
def warning(self, message: str) -> None:
|
||||||
|
"""
|
||||||
|
Logs a warning message.
|
||||||
|
|
||||||
|
:param message: The message to log.
|
||||||
|
"""
|
||||||
|
self._services.logger.warning(message)
|
||||||
|
|
||||||
|
def error(self, message: str) -> None:
|
||||||
|
"""
|
||||||
|
Logs an error message.
|
||||||
|
|
||||||
|
:param message: The message to log.
|
||||||
|
"""
|
||||||
|
self._services.logger.error(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImagesInterface(InvocationContextInterface):
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
image: Image,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
|
image_category: ImageCategory = ImageCategory.GENERAL,
|
||||||
|
metadata: Optional[MetadataField] = None,
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""
|
||||||
|
Saves an image, returning its DTO.
|
||||||
|
|
||||||
|
If the current queue item has a workflow or metadata, it is automatically saved with the image.
|
||||||
|
|
||||||
|
:param image: The image to save, as a PIL image.
|
||||||
|
:param board_id: The board ID to add the image to, if it should be added. It the invocation \
|
||||||
|
inherits from `WithBoard`, that board will be used automatically. **Use this only if \
|
||||||
|
you want to override or provide a board manually!**
|
||||||
|
:param image_category: The category of the image. Only the GENERAL category is added \
|
||||||
|
to the gallery.
|
||||||
|
:param metadata: The metadata to save with the image, if it should have any. If the \
|
||||||
|
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
|
||||||
|
**Use this only if you want to override or provide metadata manually!**
|
||||||
|
"""
|
||||||
|
|
||||||
|
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
|
||||||
|
metadata_ = None
|
||||||
|
if metadata:
|
||||||
|
metadata_ = metadata
|
||||||
|
elif isinstance(self._context_data.invocation, WithMetadata):
|
||||||
|
metadata_ = self._context_data.invocation.metadata
|
||||||
|
|
||||||
|
# If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None.
|
||||||
|
board_id_ = None
|
||||||
|
if board_id:
|
||||||
|
board_id_ = board_id
|
||||||
|
elif isinstance(self._context_data.invocation, WithBoard) and self._context_data.invocation.board:
|
||||||
|
board_id_ = self._context_data.invocation.board.board_id
|
||||||
|
|
||||||
|
return self._services.images.create(
|
||||||
|
image=image,
|
||||||
|
is_intermediate=self._context_data.invocation.is_intermediate,
|
||||||
|
image_category=image_category,
|
||||||
|
board_id=board_id_,
|
||||||
|
metadata=metadata_,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
workflow=self._context_data.workflow,
|
||||||
|
session_id=self._context_data.session_id,
|
||||||
|
node_id=self._context_data.invocation.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_pil(self, image_name: str) -> Image:
|
||||||
|
"""
|
||||||
|
Gets an image as a PIL Image object.
|
||||||
|
|
||||||
|
:param image_name: The name of the image to get.
|
||||||
|
"""
|
||||||
|
return self._services.images.get_pil_image(image_name)
|
||||||
|
|
||||||
|
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||||
|
"""
|
||||||
|
Gets an image's metadata, if it has any.
|
||||||
|
|
||||||
|
:param image_name: The name of the image to get the metadata for.
|
||||||
|
"""
|
||||||
|
return self._services.images.get_metadata(image_name)
|
||||||
|
|
||||||
|
def get_dto(self, image_name: str) -> ImageDTO:
|
||||||
|
"""
|
||||||
|
Gets an image as an ImageDTO object.
|
||||||
|
|
||||||
|
:param image_name: The name of the image to get.
|
||||||
|
"""
|
||||||
|
return self._services.images.get_dto(image_name)
|
||||||
|
|
||||||
|
|
||||||
|
class TensorsInterface(InvocationContextInterface):
|
||||||
|
def save(self, tensor: Tensor) -> str:
|
||||||
|
"""
|
||||||
|
Saves a tensor, returning its name.
|
||||||
|
|
||||||
|
:param tensor: The tensor to save.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = self._services.tensors.save(obj=tensor)
|
||||||
|
return name
|
||||||
|
|
||||||
|
def load(self, name: str) -> Tensor:
|
||||||
|
"""
|
||||||
|
Loads a tensor by name.
|
||||||
|
|
||||||
|
:param name: The name of the tensor to load.
|
||||||
|
"""
|
||||||
|
return self._services.tensors.load(name)
|
||||||
|
|
||||||
|
|
||||||
|
class ConditioningInterface(InvocationContextInterface):
|
||||||
|
def save(self, conditioning_data: ConditioningFieldData) -> str:
|
||||||
|
"""
|
||||||
|
Saves a conditioning data object, returning its name.
|
||||||
|
|
||||||
|
:param conditioning_context_data: The conditioning data to save.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = self._services.conditioning.save(obj=conditioning_data)
|
||||||
|
return name
|
||||||
|
|
||||||
|
def load(self, name: str) -> ConditioningFieldData:
|
||||||
|
"""
|
||||||
|
Loads conditioning data by name.
|
||||||
|
|
||||||
|
:param name: The name of the conditioning data to load.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self._services.conditioning.load(name)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelsInterface(InvocationContextInterface):
|
||||||
|
def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if a model exists.
|
||||||
|
|
||||||
|
:param model_name: The name of the model to check.
|
||||||
|
:param base_model: The base model of the model to check.
|
||||||
|
:param model_type: The type of the model to check.
|
||||||
|
"""
|
||||||
|
return self._services.model_manager.model_exists(model_name, base_model, model_type)
|
||||||
|
|
||||||
|
def load(
|
||||||
|
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
|
||||||
|
) -> LoadedModelInfo:
|
||||||
|
"""
|
||||||
|
Loads a model.
|
||||||
|
|
||||||
|
:param model_name: The name of the model to get.
|
||||||
|
:param base_model: The base model of the model to get.
|
||||||
|
:param model_type: The type of the model to get.
|
||||||
|
:param submodel: The submodel of the model to get.
|
||||||
|
:returns: An object representing the loaded model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The model manager emits events as it loads the model. It needs the context data to build
|
||||||
|
# the event payloads.
|
||||||
|
|
||||||
|
return self._services.model_manager.get_model(
|
||||||
|
model_name, base_model, model_type, submodel, context_data=self._context_data
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
|
"""
|
||||||
|
Gets a model's info, an dict-like object.
|
||||||
|
|
||||||
|
:param model_name: The name of the model to get.
|
||||||
|
:param base_model: The base model of the model to get.
|
||||||
|
:param model_type: The type of the model to get.
|
||||||
|
"""
|
||||||
|
return self._services.model_manager.model_info(model_name, base_model, model_type)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigInterface(InvocationContextInterface):
|
||||||
|
def get(self) -> InvokeAIAppConfig:
|
||||||
|
"""Gets the app's config."""
|
||||||
|
|
||||||
|
return self._services.configuration.get_config()
|
||||||
|
|
||||||
|
|
||||||
|
class UtilInterface(InvocationContextInterface):
|
||||||
|
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
|
||||||
|
"""
|
||||||
|
The step callback emits a progress event with the current step, the total number of
|
||||||
|
steps, a preview image, and some other internal metadata.
|
||||||
|
|
||||||
|
This should be called after each denoising step.
|
||||||
|
|
||||||
|
:param intermediate_state: The intermediate state of the diffusion pipeline.
|
||||||
|
:param base_model: The base model for the current denoising step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The step callback needs access to the events and the invocation queue services, but this
|
||||||
|
# represents a dangerous level of access.
|
||||||
|
#
|
||||||
|
# We wrap the step callback so that nodes do not have direct access to these services.
|
||||||
|
|
||||||
|
stable_diffusion_step_callback(
|
||||||
|
context_data=self._context_data,
|
||||||
|
intermediate_state=intermediate_state,
|
||||||
|
base_model=base_model,
|
||||||
|
invocation_queue=self._services.queue,
|
||||||
|
events=self._services.events,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationContext:
|
||||||
|
"""
|
||||||
|
The `InvocationContext` provides access to various services and data for the current invocation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
images: ImagesInterface,
|
||||||
|
tensors: TensorsInterface,
|
||||||
|
conditioning: ConditioningInterface,
|
||||||
|
models: ModelsInterface,
|
||||||
|
logger: LoggerInterface,
|
||||||
|
config: ConfigInterface,
|
||||||
|
util: UtilInterface,
|
||||||
|
boards: BoardsInterface,
|
||||||
|
context_data: InvocationContextData,
|
||||||
|
services: InvocationServices,
|
||||||
|
) -> None:
|
||||||
|
self.images = images
|
||||||
|
"""Provides methods to save, get and update images and their metadata."""
|
||||||
|
self.tensors = tensors
|
||||||
|
"""Provides methods to save and get tensors, including image, noise, masks, and masked images."""
|
||||||
|
self.conditioning = conditioning
|
||||||
|
"""Provides methods to save and get conditioning data."""
|
||||||
|
self.models = models
|
||||||
|
"""Provides methods to check if a model exists, get a model, and get a model's info."""
|
||||||
|
self.logger = logger
|
||||||
|
"""Provides access to the app logger."""
|
||||||
|
self.config = config
|
||||||
|
"""Provides access to the app's config."""
|
||||||
|
self.util = util
|
||||||
|
"""Provides utility methods."""
|
||||||
|
self.boards = boards
|
||||||
|
"""Provides methods to interact with boards."""
|
||||||
|
self._data = context_data
|
||||||
|
"""Provides data about the current queue item and invocation. This is an internal API and may change without warning."""
|
||||||
|
self._services = services
|
||||||
|
"""Provides access to the full application services. This is an internal API and may change without warning."""
|
||||||
|
|
||||||
|
|
||||||
|
def build_invocation_context(
|
||||||
|
services: InvocationServices,
|
||||||
|
context_data: InvocationContextData,
|
||||||
|
) -> InvocationContext:
|
||||||
|
"""
|
||||||
|
Builds the invocation context for a specific invocation execution.
|
||||||
|
|
||||||
|
:param invocation_services: The invocation services to wrap.
|
||||||
|
:param invocation_context_data: The invocation context data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger = LoggerInterface(services=services, context_data=context_data)
|
||||||
|
images = ImagesInterface(services=services, context_data=context_data)
|
||||||
|
tensors = TensorsInterface(services=services, context_data=context_data)
|
||||||
|
models = ModelsInterface(services=services, context_data=context_data)
|
||||||
|
config = ConfigInterface(services=services, context_data=context_data)
|
||||||
|
util = UtilInterface(services=services, context_data=context_data)
|
||||||
|
conditioning = ConditioningInterface(services=services, context_data=context_data)
|
||||||
|
boards = BoardsInterface(services=services, context_data=context_data)
|
||||||
|
|
||||||
|
ctx = InvocationContext(
|
||||||
|
images=images,
|
||||||
|
logger=logger,
|
||||||
|
config=config,
|
||||||
|
tensors=tensors,
|
||||||
|
models=models,
|
||||||
|
context_data=context_data,
|
||||||
|
util=util,
|
||||||
|
conditioning=conditioning,
|
||||||
|
services=services,
|
||||||
|
boards=boards,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ctx
|
@ -1,67 +0,0 @@
|
|||||||
class FieldDescriptions:
|
|
||||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
|
||||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
|
||||||
cfg_scale = "Classifier-Free Guidance scale"
|
|
||||||
cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR"
|
|
||||||
scheduler = "Scheduler to use during inference"
|
|
||||||
positive_cond = "Positive conditioning tensor"
|
|
||||||
negative_cond = "Negative conditioning tensor"
|
|
||||||
noise = "Noise tensor"
|
|
||||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
|
||||||
unet = "UNet (scheduler, LoRAs)"
|
|
||||||
vae = "VAE"
|
|
||||||
cond = "Conditioning tensor"
|
|
||||||
controlnet_model = "ControlNet model to load"
|
|
||||||
vae_model = "VAE model to load"
|
|
||||||
lora_model = "LoRA model to load"
|
|
||||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
|
||||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
|
||||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
|
||||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
|
||||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
|
||||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
|
||||||
raw_prompt = "Raw prompt text (no parsing)"
|
|
||||||
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
|
|
||||||
skipped_layers = "Number of layers to skip in text encoder"
|
|
||||||
seed = "Seed for random number generation"
|
|
||||||
steps = "Number of steps to run"
|
|
||||||
width = "Width of output (px)"
|
|
||||||
height = "Height of output (px)"
|
|
||||||
control = "ControlNet(s) to apply"
|
|
||||||
ip_adapter = "IP-Adapter to apply"
|
|
||||||
t2i_adapter = "T2I-Adapter(s) to apply"
|
|
||||||
denoised_latents = "Denoised latents tensor"
|
|
||||||
latents = "Latents tensor"
|
|
||||||
strength = "Strength of denoising (proportional to steps)"
|
|
||||||
metadata = "Optional metadata to be saved with the image"
|
|
||||||
metadata_collection = "Collection of Metadata"
|
|
||||||
metadata_item_polymorphic = "A single metadata item or collection of metadata items"
|
|
||||||
metadata_item_label = "Label for this metadata item"
|
|
||||||
metadata_item_value = "The value for this metadata item (may be any type)"
|
|
||||||
workflow = "Optional workflow to be saved with the image"
|
|
||||||
interp_mode = "Interpolation mode"
|
|
||||||
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
|
||||||
fp32 = "Whether or not to use full float32 precision"
|
|
||||||
precision = "Precision to use"
|
|
||||||
tiled = "Processing using overlapping tiles (reduce memory consumption)"
|
|
||||||
detect_res = "Pixel resolution for detection"
|
|
||||||
image_res = "Pixel resolution for output image"
|
|
||||||
safe_mode = "Whether or not to use safe mode"
|
|
||||||
scribble_mode = "Whether or not to use scribble mode"
|
|
||||||
scale_factor = "The factor by which to scale"
|
|
||||||
blend_alpha = (
|
|
||||||
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
|
|
||||||
)
|
|
||||||
num_1 = "The first number"
|
|
||||||
num_2 = "The second number"
|
|
||||||
mask = "The mask to use for the operation"
|
|
||||||
board = "The board to save the image to"
|
|
||||||
image = "The image to process"
|
|
||||||
tile_size = "Tile size"
|
|
||||||
inclusive_low = "The inclusive low value"
|
|
||||||
exclusive_high = "The exclusive high value"
|
|
||||||
decimal_places = "The number of decimal places to round to"
|
|
||||||
freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
|
|
||||||
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
|
|
||||||
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
|
|
||||||
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
|
|
@ -1,6 +1,6 @@
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.shared.fields import FieldDescriptions
|
from invokeai.app.invocations.fields import FieldDescriptions
|
||||||
|
|
||||||
|
|
||||||
class FreeUConfig(BaseModel):
|
class FreeUConfig(BaseModel):
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
@ -6,7 +8,11 @@ from invokeai.app.services.invocation_processor.invocation_processor_common impo
|
|||||||
from ...backend.model_management.models import BaseModelType
|
from ...backend.model_management.models import BaseModelType
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.util.util import image_to_dataURL
|
from ...backend.util.util import image_to_dataURL
|
||||||
from ..invocations.baseinvocation import InvocationContext
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
|
from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||||
|
|
||||||
|
|
||||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
||||||
@ -25,13 +31,13 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=
|
|||||||
|
|
||||||
|
|
||||||
def stable_diffusion_step_callback(
|
def stable_diffusion_step_callback(
|
||||||
context: InvocationContext,
|
context_data: "InvocationContextData",
|
||||||
intermediate_state: PipelineIntermediateState,
|
intermediate_state: PipelineIntermediateState,
|
||||||
node: dict,
|
|
||||||
source_node_id: str,
|
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
):
|
invocation_queue: "InvocationQueueABC",
|
||||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
events: "EventServiceBase",
|
||||||
|
) -> None:
|
||||||
|
if invocation_queue.is_canceled(context_data.session_id):
|
||||||
raise CanceledException
|
raise CanceledException
|
||||||
|
|
||||||
# Some schedulers report not only the noisy latents at the current timestep,
|
# Some schedulers report not only the noisy latents at the current timestep,
|
||||||
@ -108,13 +114,13 @@ def stable_diffusion_step_callback(
|
|||||||
|
|
||||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||||
|
|
||||||
context.services.events.emit_generator_progress(
|
events.emit_generator_progress(
|
||||||
queue_id=context.queue_id,
|
queue_id=context_data.queue_id,
|
||||||
queue_item_id=context.queue_item_id,
|
queue_item_id=context_data.queue_item_id,
|
||||||
queue_batch_id=context.queue_batch_id,
|
queue_batch_id=context_data.batch_id,
|
||||||
graph_execution_state_id=context.graph_execution_state_id,
|
graph_execution_state_id=context_data.session_id,
|
||||||
node=node,
|
node_id=context_data.invocation.id,
|
||||||
source_node_id=source_node_id,
|
source_node_id=context_data.source_node_id,
|
||||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||||
step=intermediate_state.step,
|
step=intermediate_state.step,
|
||||||
order=intermediate_state.order,
|
order=intermediate_state.order,
|
||||||
|
@ -1,5 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.backend
|
Initialization file for invokeai.backend
|
||||||
"""
|
"""
|
||||||
from .model_management import BaseModelType, ModelCache, ModelInfo, ModelManager, ModelType, SubModelType # noqa: F401
|
from .model_management import ( # noqa: F401
|
||||||
|
BaseModelType,
|
||||||
|
LoadedModelInfo,
|
||||||
|
ModelCache,
|
||||||
|
ModelManager,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
)
|
||||||
from .model_management.models import SilenceWarnings # noqa: F401
|
from .model_management.models import SilenceWarnings # noqa: F401
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
Initialization file for invokeai.backend.model_management
|
Initialization file for invokeai.backend.model_management
|
||||||
"""
|
"""
|
||||||
# This import must be first
|
# This import must be first
|
||||||
from .model_manager import AddModelResult, ModelInfo, ModelManager, SchedulerPredictionType
|
from .model_manager import AddModelResult, LoadedModelInfo, ModelManager, SchedulerPredictionType
|
||||||
from .lora import ModelPatcher, ONNXModelPatcher
|
from .lora import ModelPatcher, ONNXModelPatcher
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
|
|
||||||
|
@ -271,7 +271,7 @@ CONFIG_FILE_VERSION = "3.0.0"
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelInfo:
|
class LoadedModelInfo:
|
||||||
context: ModelLocker
|
context: ModelLocker
|
||||||
name: str
|
name: str
|
||||||
base_model: BaseModelType
|
base_model: BaseModelType
|
||||||
@ -450,7 +450,7 @@ class ModelManager(object):
|
|||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None,
|
||||||
) -> ModelInfo:
|
) -> LoadedModelInfo:
|
||||||
"""Given a model named identified in models.yaml, return
|
"""Given a model named identified in models.yaml, return
|
||||||
an ModelInfo object describing it.
|
an ModelInfo object describing it.
|
||||||
:param model_name: symbolic name of the model in models.yaml
|
:param model_name: symbolic name of the model in models.yaml
|
||||||
@ -508,7 +508,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
model_hash = "<NO_HASH>" # TODO:
|
model_hash = "<NO_HASH>" # TODO:
|
||||||
|
|
||||||
return ModelInfo(
|
return LoadedModelInfo(
|
||||||
context=model_context,
|
context=model_context,
|
||||||
name=model_name,
|
name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
|
@ -32,6 +32,11 @@ class BasicConditioningInfo:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConditioningFieldData:
|
||||||
|
conditionings: List[BasicConditioningInfo]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||||
pooled_embeds: torch.Tensor
|
pooled_embeds: torch.Tensor
|
||||||
|
@ -3,7 +3,7 @@ from typing import Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||||
from invokeai.backend.tiles.utils import TBLR, Tile, paste, seam_blend
|
from invokeai.backend.tiles.utils import TBLR, Tile, paste, seam_blend
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import torch
|
|||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
from invokeai.backend.model_management.model_manager import ModelInfo
|
from invokeai.backend.model_management.model_manager import LoadedModelInfo
|
||||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType
|
from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType
|
||||||
|
|
||||||
|
|
||||||
@ -34,8 +34,8 @@ def install_and_load_model(
|
|||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None,
|
||||||
) -> ModelInfo:
|
) -> LoadedModelInfo:
|
||||||
"""Install a model if it is not already installed, then get the ModelInfo for that model.
|
"""Install a model if it is not already installed, then get the LoadedModelInfo for that model.
|
||||||
|
|
||||||
This is intended as a utility function for tests.
|
This is intended as a utility function for tests.
|
||||||
|
|
||||||
@ -49,9 +49,9 @@ def install_and_load_model(
|
|||||||
submodel_type (Optional[SubModelType]): The submodel type, forwarded to ModelManager.get_model(...).
|
submodel_type (Optional[SubModelType]): The submodel type, forwarded to ModelManager.get_model(...).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ModelInfo
|
LoadedModelInfo
|
||||||
"""
|
"""
|
||||||
# If the requested model is already installed, return its ModelInfo.
|
# If the requested model is already installed, return its LoadedModelInfo.
|
||||||
with contextlib.suppress(ModelNotFoundException):
|
with contextlib.suppress(ModelNotFoundException):
|
||||||
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
|
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
|
||||||
|
|
||||||
|
@ -1,12 +0,0 @@
|
|||||||
import react from '@vitejs/plugin-react-swc';
|
|
||||||
import { visualizer } from 'rollup-plugin-visualizer';
|
|
||||||
import type { PluginOption, UserConfig } from 'vite';
|
|
||||||
import eslint from 'vite-plugin-eslint';
|
|
||||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
|
||||||
|
|
||||||
export const commonPlugins: UserConfig['plugins'] = [
|
|
||||||
react(),
|
|
||||||
eslint(),
|
|
||||||
tsconfigPaths(),
|
|
||||||
visualizer() as unknown as PluginOption,
|
|
||||||
];
|
|
@ -1,33 +0,0 @@
|
|||||||
import type { UserConfig } from 'vite';
|
|
||||||
|
|
||||||
import { commonPlugins } from './common.mjs';
|
|
||||||
|
|
||||||
export const appConfig: UserConfig = {
|
|
||||||
base: './',
|
|
||||||
plugins: [...commonPlugins],
|
|
||||||
build: {
|
|
||||||
chunkSizeWarningLimit: 1500,
|
|
||||||
},
|
|
||||||
server: {
|
|
||||||
// Proxy HTTP requests to the flask server
|
|
||||||
proxy: {
|
|
||||||
// Proxy socket.io to the nodes socketio server
|
|
||||||
'/ws/socket.io': {
|
|
||||||
target: 'ws://127.0.0.1:9090',
|
|
||||||
ws: true,
|
|
||||||
},
|
|
||||||
// Proxy openapi schema definiton
|
|
||||||
'/openapi.json': {
|
|
||||||
target: 'http://127.0.0.1:9090/openapi.json',
|
|
||||||
rewrite: (path) => path.replace(/^\/openapi.json/, ''),
|
|
||||||
changeOrigin: true,
|
|
||||||
},
|
|
||||||
// proxy nodes api
|
|
||||||
'/api/v1': {
|
|
||||||
target: 'http://127.0.0.1:9090/api/v1',
|
|
||||||
rewrite: (path) => path.replace(/^\/api\/v1/, ''),
|
|
||||||
changeOrigin: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
@ -1,46 +0,0 @@
|
|||||||
import path from 'path';
|
|
||||||
import type { UserConfig } from 'vite';
|
|
||||||
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
|
|
||||||
import dts from 'vite-plugin-dts';
|
|
||||||
|
|
||||||
import { commonPlugins } from './common.mjs';
|
|
||||||
|
|
||||||
export const packageConfig: UserConfig = {
|
|
||||||
base: './',
|
|
||||||
plugins: [
|
|
||||||
...commonPlugins,
|
|
||||||
dts({
|
|
||||||
insertTypesEntry: true,
|
|
||||||
}),
|
|
||||||
cssInjectedByJsPlugin(),
|
|
||||||
],
|
|
||||||
build: {
|
|
||||||
cssCodeSplit: true,
|
|
||||||
lib: {
|
|
||||||
entry: path.resolve(__dirname, '../src/index.ts'),
|
|
||||||
name: 'InvokeAIUI',
|
|
||||||
fileName: (format) => `invoke-ai-ui.${format}.js`,
|
|
||||||
},
|
|
||||||
rollupOptions: {
|
|
||||||
external: ['react', 'react-dom', '@emotion/react', '@chakra-ui/react', '@invoke-ai/ui-library'],
|
|
||||||
output: {
|
|
||||||
globals: {
|
|
||||||
react: 'React',
|
|
||||||
'react-dom': 'ReactDOM',
|
|
||||||
'@emotion/react': 'EmotionReact',
|
|
||||||
'@invoke-ai/ui-library': 'UiLibrary',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
resolve: {
|
|
||||||
alias: {
|
|
||||||
app: path.resolve(__dirname, '../src/app'),
|
|
||||||
assets: path.resolve(__dirname, '../src/assets'),
|
|
||||||
common: path.resolve(__dirname, '../src/common'),
|
|
||||||
features: path.resolve(__dirname, '../src/features'),
|
|
||||||
services: path.resolve(__dirname, '../src/services'),
|
|
||||||
theme: path.resolve(__dirname, '../src/theme'),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
@ -33,7 +33,9 @@
|
|||||||
"preinstall": "npx only-allow pnpm",
|
"preinstall": "npx only-allow pnpm",
|
||||||
"storybook": "storybook dev -p 6006",
|
"storybook": "storybook dev -p 6006",
|
||||||
"build-storybook": "storybook build",
|
"build-storybook": "storybook build",
|
||||||
"unimported": "npx unimported"
|
"unimported": "npx unimported",
|
||||||
|
"test": "vitest",
|
||||||
|
"test:no-watch": "vitest --no-watch"
|
||||||
},
|
},
|
||||||
"madge": {
|
"madge": {
|
||||||
"excludeRegExp": [
|
"excludeRegExp": [
|
||||||
@ -157,7 +159,8 @@
|
|||||||
"vite-plugin-css-injected-by-js": "^3.3.1",
|
"vite-plugin-css-injected-by-js": "^3.3.1",
|
||||||
"vite-plugin-dts": "^3.7.1",
|
"vite-plugin-dts": "^3.7.1",
|
||||||
"vite-plugin-eslint": "^1.8.1",
|
"vite-plugin-eslint": "^1.8.1",
|
||||||
"vite-tsconfig-paths": "^4.3.1"
|
"vite-tsconfig-paths": "^4.3.1",
|
||||||
|
"vitest": "^1.2.2"
|
||||||
},
|
},
|
||||||
"pnpm": {
|
"pnpm": {
|
||||||
"patchedDependencies": {
|
"patchedDependencies": {
|
||||||
|
222
invokeai/frontend/web/pnpm-lock.yaml
generated
222
invokeai/frontend/web/pnpm-lock.yaml
generated
@ -215,7 +215,7 @@ devDependencies:
|
|||||||
version: 7.6.10(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3)(vite@5.0.12)
|
version: 7.6.10(react-dom@18.2.0)(react@18.2.0)(typescript@5.3.3)(vite@5.0.12)
|
||||||
'@storybook/test':
|
'@storybook/test':
|
||||||
specifier: ^7.6.10
|
specifier: ^7.6.10
|
||||||
version: 7.6.10
|
version: 7.6.10(vitest@1.2.2)
|
||||||
'@storybook/theming':
|
'@storybook/theming':
|
||||||
specifier: ^7.6.10
|
specifier: ^7.6.10
|
||||||
version: 7.6.10(react-dom@18.2.0)(react@18.2.0)
|
version: 7.6.10(react-dom@18.2.0)(react@18.2.0)
|
||||||
@ -318,6 +318,9 @@ devDependencies:
|
|||||||
vite-tsconfig-paths:
|
vite-tsconfig-paths:
|
||||||
specifier: ^4.3.1
|
specifier: ^4.3.1
|
||||||
version: 4.3.1(typescript@5.3.3)(vite@5.0.12)
|
version: 4.3.1(typescript@5.3.3)(vite@5.0.12)
|
||||||
|
vitest:
|
||||||
|
specifier: ^1.2.2
|
||||||
|
version: 1.2.2(@types/node@20.11.5)
|
||||||
|
|
||||||
packages:
|
packages:
|
||||||
|
|
||||||
@ -5464,7 +5467,7 @@ packages:
|
|||||||
- supports-color
|
- supports-color
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
/@storybook/test@7.6.10:
|
/@storybook/test@7.6.10(vitest@1.2.2):
|
||||||
resolution: {integrity: sha512-dn/T+HcWOBlVh3c74BHurp++BaqBoQgNbSIaXlYDpJoZ+DzNIoEQVsWFYm5gCbtKK27iFd4n52RiQI3f6Vblqw==}
|
resolution: {integrity: sha512-dn/T+HcWOBlVh3c74BHurp++BaqBoQgNbSIaXlYDpJoZ+DzNIoEQVsWFYm5gCbtKK27iFd4n52RiQI3f6Vblqw==}
|
||||||
dependencies:
|
dependencies:
|
||||||
'@storybook/client-logger': 7.6.10
|
'@storybook/client-logger': 7.6.10
|
||||||
@ -5472,7 +5475,7 @@ packages:
|
|||||||
'@storybook/instrumenter': 7.6.10
|
'@storybook/instrumenter': 7.6.10
|
||||||
'@storybook/preview-api': 7.6.10
|
'@storybook/preview-api': 7.6.10
|
||||||
'@testing-library/dom': 9.3.4
|
'@testing-library/dom': 9.3.4
|
||||||
'@testing-library/jest-dom': 6.2.0
|
'@testing-library/jest-dom': 6.2.0(vitest@1.2.2)
|
||||||
'@testing-library/user-event': 14.3.0(@testing-library/dom@9.3.4)
|
'@testing-library/user-event': 14.3.0(@testing-library/dom@9.3.4)
|
||||||
'@types/chai': 4.3.11
|
'@types/chai': 4.3.11
|
||||||
'@vitest/expect': 0.34.7
|
'@vitest/expect': 0.34.7
|
||||||
@ -5652,7 +5655,7 @@ packages:
|
|||||||
pretty-format: 27.5.1
|
pretty-format: 27.5.1
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
/@testing-library/jest-dom@6.2.0:
|
/@testing-library/jest-dom@6.2.0(vitest@1.2.2):
|
||||||
resolution: {integrity: sha512-+BVQlJ9cmEn5RDMUS8c2+TU6giLvzaHZ8sU/x0Jj7fk+6/46wPdwlgOPcpxS17CjcanBi/3VmGMqVr2rmbUmNw==}
|
resolution: {integrity: sha512-+BVQlJ9cmEn5RDMUS8c2+TU6giLvzaHZ8sU/x0Jj7fk+6/46wPdwlgOPcpxS17CjcanBi/3VmGMqVr2rmbUmNw==}
|
||||||
engines: {node: '>=14', npm: '>=6', yarn: '>=1'}
|
engines: {node: '>=14', npm: '>=6', yarn: '>=1'}
|
||||||
peerDependencies:
|
peerDependencies:
|
||||||
@ -5678,6 +5681,7 @@ packages:
|
|||||||
dom-accessibility-api: 0.6.3
|
dom-accessibility-api: 0.6.3
|
||||||
lodash: 4.17.21
|
lodash: 4.17.21
|
||||||
redent: 3.0.0
|
redent: 3.0.0
|
||||||
|
vitest: 1.2.2(@types/node@20.11.5)
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
/@testing-library/user-event@14.3.0(@testing-library/dom@9.3.4):
|
/@testing-library/user-event@14.3.0(@testing-library/dom@9.3.4):
|
||||||
@ -6490,12 +6494,42 @@ packages:
|
|||||||
chai: 4.4.1
|
chai: 4.4.1
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/@vitest/expect@1.2.2:
|
||||||
|
resolution: {integrity: sha512-3jpcdPAD7LwHUUiT2pZTj2U82I2Tcgg2oVPvKxhn6mDI2On6tfvPQTjAI4628GUGDZrCm4Zna9iQHm5cEexOAg==}
|
||||||
|
dependencies:
|
||||||
|
'@vitest/spy': 1.2.2
|
||||||
|
'@vitest/utils': 1.2.2
|
||||||
|
chai: 4.4.1
|
||||||
|
dev: true
|
||||||
|
|
||||||
|
/@vitest/runner@1.2.2:
|
||||||
|
resolution: {integrity: sha512-JctG7QZ4LSDXr5CsUweFgcpEvrcxOV1Gft7uHrvkQ+fsAVylmWQvnaAr/HDp3LAH1fztGMQZugIheTWjaGzYIg==}
|
||||||
|
dependencies:
|
||||||
|
'@vitest/utils': 1.2.2
|
||||||
|
p-limit: 5.0.0
|
||||||
|
pathe: 1.1.2
|
||||||
|
dev: true
|
||||||
|
|
||||||
|
/@vitest/snapshot@1.2.2:
|
||||||
|
resolution: {integrity: sha512-SmGY4saEw1+bwE1th6S/cZmPxz/Q4JWsl7LvbQIky2tKE35US4gd0Mjzqfr84/4OD0tikGWaWdMja/nWL5NIPA==}
|
||||||
|
dependencies:
|
||||||
|
magic-string: 0.30.5
|
||||||
|
pathe: 1.1.2
|
||||||
|
pretty-format: 29.7.0
|
||||||
|
dev: true
|
||||||
|
|
||||||
/@vitest/spy@0.34.7:
|
/@vitest/spy@0.34.7:
|
||||||
resolution: {integrity: sha512-NMMSzOY2d8L0mcOt4XcliDOS1ISyGlAXuQtERWVOoVHnKwmG+kKhinAiGw3dTtMQWybfa89FG8Ucg9tiC/FhTQ==}
|
resolution: {integrity: sha512-NMMSzOY2d8L0mcOt4XcliDOS1ISyGlAXuQtERWVOoVHnKwmG+kKhinAiGw3dTtMQWybfa89FG8Ucg9tiC/FhTQ==}
|
||||||
dependencies:
|
dependencies:
|
||||||
tinyspy: 2.2.0
|
tinyspy: 2.2.0
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/@vitest/spy@1.2.2:
|
||||||
|
resolution: {integrity: sha512-k9Gcahssw8d7X3pSLq3e3XEu/0L78mUkCjivUqCQeXJm9clfXR/Td8+AP+VC1O6fKPIDLcHDTAmBOINVuv6+7g==}
|
||||||
|
dependencies:
|
||||||
|
tinyspy: 2.2.0
|
||||||
|
dev: true
|
||||||
|
|
||||||
/@vitest/utils@0.34.7:
|
/@vitest/utils@0.34.7:
|
||||||
resolution: {integrity: sha512-ziAavQLpCYS9sLOorGrFFKmy2gnfiNU0ZJ15TsMz/K92NAPS/rp9K4z6AJQQk5Y8adCy4Iwpxy7pQumQ/psnRg==}
|
resolution: {integrity: sha512-ziAavQLpCYS9sLOorGrFFKmy2gnfiNU0ZJ15TsMz/K92NAPS/rp9K4z6AJQQk5Y8adCy4Iwpxy7pQumQ/psnRg==}
|
||||||
dependencies:
|
dependencies:
|
||||||
@ -6504,6 +6538,15 @@ packages:
|
|||||||
pretty-format: 29.7.0
|
pretty-format: 29.7.0
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/@vitest/utils@1.2.2:
|
||||||
|
resolution: {integrity: sha512-WKITBHLsBHlpjnDQahr+XK6RE7MiAsgrIkr0pGhQ9ygoxBfUeG0lUG5iLlzqjmKSlBv3+j5EGsriBzh+C3Tq9g==}
|
||||||
|
dependencies:
|
||||||
|
diff-sequences: 29.6.3
|
||||||
|
estree-walker: 3.0.3
|
||||||
|
loupe: 2.3.7
|
||||||
|
pretty-format: 29.7.0
|
||||||
|
dev: true
|
||||||
|
|
||||||
/@volar/language-core@1.11.1:
|
/@volar/language-core@1.11.1:
|
||||||
resolution: {integrity: sha512-dOcNn3i9GgZAcJt43wuaEykSluAuOkQgzni1cuxLxTV0nJKanQztp7FxyswdRILaKH+P2XZMPRp2S4MV/pElCw==}
|
resolution: {integrity: sha512-dOcNn3i9GgZAcJt43wuaEykSluAuOkQgzni1cuxLxTV0nJKanQztp7FxyswdRILaKH+P2XZMPRp2S4MV/pElCw==}
|
||||||
dependencies:
|
dependencies:
|
||||||
@ -7184,6 +7227,11 @@ packages:
|
|||||||
engines: {node: '>=0.4.0'}
|
engines: {node: '>=0.4.0'}
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/acorn-walk@8.3.2:
|
||||||
|
resolution: {integrity: sha512-cjkyv4OtNCIeqhHrfS81QWXoCBPExR/J62oyEqepVw8WaQeSqpW2uhuLPh1m9eWhDuOo/jUXVTlifvesOWp/4A==}
|
||||||
|
engines: {node: '>=0.4.0'}
|
||||||
|
dev: true
|
||||||
|
|
||||||
/acorn@7.4.1:
|
/acorn@7.4.1:
|
||||||
resolution: {integrity: sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A==}
|
resolution: {integrity: sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A==}
|
||||||
engines: {node: '>=0.4.0'}
|
engines: {node: '>=0.4.0'}
|
||||||
@ -7661,6 +7709,11 @@ packages:
|
|||||||
engines: {node: '>= 0.8'}
|
engines: {node: '>= 0.8'}
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/cac@6.7.14:
|
||||||
|
resolution: {integrity: sha512-b6Ilus+c3RrdDk+JhLKUAQfzzgLEPy6wcXqS7f/xe1EETvsDP6GORG7SFuOs6cID5YkqchW/LXZbX5bc8j7ZcQ==}
|
||||||
|
engines: {node: '>=8'}
|
||||||
|
dev: true
|
||||||
|
|
||||||
/call-bind@1.0.5:
|
/call-bind@1.0.5:
|
||||||
resolution: {integrity: sha512-C3nQxfFZxFRVoJoGKKI8y3MOEo129NQ+FgQ08iye+Mk4zNZZGdjfs06bVTr+DBSlA66Q2VEcMki/cUCP4SercQ==}
|
resolution: {integrity: sha512-C3nQxfFZxFRVoJoGKKI8y3MOEo129NQ+FgQ08iye+Mk4zNZZGdjfs06bVTr+DBSlA66Q2VEcMki/cUCP4SercQ==}
|
||||||
dependencies:
|
dependencies:
|
||||||
@ -9173,6 +9226,12 @@ packages:
|
|||||||
resolution: {integrity: sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==}
|
resolution: {integrity: sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==}
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/estree-walker@3.0.3:
|
||||||
|
resolution: {integrity: sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==}
|
||||||
|
dependencies:
|
||||||
|
'@types/estree': 1.0.5
|
||||||
|
dev: true
|
||||||
|
|
||||||
/esutils@2.0.3:
|
/esutils@2.0.3:
|
||||||
resolution: {integrity: sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==}
|
resolution: {integrity: sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==}
|
||||||
engines: {node: '>=0.10.0'}
|
engines: {node: '>=0.10.0'}
|
||||||
@ -10547,6 +10606,10 @@ packages:
|
|||||||
hasBin: true
|
hasBin: true
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/jsonc-parser@3.2.1:
|
||||||
|
resolution: {integrity: sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==}
|
||||||
|
dev: true
|
||||||
|
|
||||||
/jsondiffpatch@0.6.0:
|
/jsondiffpatch@0.6.0:
|
||||||
resolution: {integrity: sha512-3QItJOXp2AP1uv7waBkao5nCvhEv+QmJAd38Ybq7wNI74Q+BBmnLn4EDKz6yI9xGAIQoUF87qHt+kc1IVxB4zQ==}
|
resolution: {integrity: sha512-3QItJOXp2AP1uv7waBkao5nCvhEv+QmJAd38Ybq7wNI74Q+BBmnLn4EDKz6yI9xGAIQoUF87qHt+kc1IVxB4zQ==}
|
||||||
engines: {node: ^18.0.0 || >=20.0.0}
|
engines: {node: ^18.0.0 || >=20.0.0}
|
||||||
@ -10648,6 +10711,14 @@ packages:
|
|||||||
engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0}
|
engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0}
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/local-pkg@0.5.0:
|
||||||
|
resolution: {integrity: sha512-ok6z3qlYyCDS4ZEU27HaU6x/xZa9Whf8jD4ptH5UZTQYZVYeb9bnZ3ojVhiJNLiXK1Hfc0GNbLXcmZ5plLDDBg==}
|
||||||
|
engines: {node: '>=14'}
|
||||||
|
dependencies:
|
||||||
|
mlly: 1.5.0
|
||||||
|
pkg-types: 1.0.3
|
||||||
|
dev: true
|
||||||
|
|
||||||
/locate-path@3.0.0:
|
/locate-path@3.0.0:
|
||||||
resolution: {integrity: sha512-7AO748wWnIhNqAuaty2ZWHkQHRSNfPVIsPIfwEOWO22AmaoVrWavlOcMR5nzTLNYvp36X220/maaRsrec1G65A==}
|
resolution: {integrity: sha512-7AO748wWnIhNqAuaty2ZWHkQHRSNfPVIsPIfwEOWO22AmaoVrWavlOcMR5nzTLNYvp36X220/maaRsrec1G65A==}
|
||||||
engines: {node: '>=6'}
|
engines: {node: '>=6'}
|
||||||
@ -10986,6 +11057,15 @@ packages:
|
|||||||
hasBin: true
|
hasBin: true
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/mlly@1.5.0:
|
||||||
|
resolution: {integrity: sha512-NPVQvAY1xr1QoVeG0cy8yUYC7FQcOx6evl/RjT1wL5FvzPnzOysoqB/jmx/DhssT2dYa8nxECLAaFI/+gVLhDQ==}
|
||||||
|
dependencies:
|
||||||
|
acorn: 8.11.3
|
||||||
|
pathe: 1.1.2
|
||||||
|
pkg-types: 1.0.3
|
||||||
|
ufo: 1.3.2
|
||||||
|
dev: true
|
||||||
|
|
||||||
/module-definition@3.4.0:
|
/module-definition@3.4.0:
|
||||||
resolution: {integrity: sha512-XxJ88R1v458pifaSkPNLUTdSPNVGMP2SXVncVmApGO+gAfrLANiYe6JofymCzVceGOMwQE2xogxBSc8uB7XegA==}
|
resolution: {integrity: sha512-XxJ88R1v458pifaSkPNLUTdSPNVGMP2SXVncVmApGO+gAfrLANiYe6JofymCzVceGOMwQE2xogxBSc8uB7XegA==}
|
||||||
engines: {node: '>=6.0'}
|
engines: {node: '>=6.0'}
|
||||||
@ -11380,6 +11460,13 @@ packages:
|
|||||||
yocto-queue: 0.1.0
|
yocto-queue: 0.1.0
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/p-limit@5.0.0:
|
||||||
|
resolution: {integrity: sha512-/Eaoq+QyLSiXQ4lyYV23f14mZRQcXnxfHrN0vCai+ak9G0pp9iEQukIIZq5NccEvwRB8PUnZT0KsOoDCINS1qQ==}
|
||||||
|
engines: {node: '>=18'}
|
||||||
|
dependencies:
|
||||||
|
yocto-queue: 1.0.0
|
||||||
|
dev: true
|
||||||
|
|
||||||
/p-locate@3.0.0:
|
/p-locate@3.0.0:
|
||||||
resolution: {integrity: sha512-x+12w/To+4GFfgJhBEpiDcLozRJGegY+Ei7/z0tSLkMmxGZNybVMSfWj9aJn8Z5Fc7dBUNJOOVgPv2H7IwulSQ==}
|
resolution: {integrity: sha512-x+12w/To+4GFfgJhBEpiDcLozRJGegY+Ei7/z0tSLkMmxGZNybVMSfWj9aJn8Z5Fc7dBUNJOOVgPv2H7IwulSQ==}
|
||||||
engines: {node: '>=6'}
|
engines: {node: '>=6'}
|
||||||
@ -11550,6 +11637,14 @@ packages:
|
|||||||
find-up: 5.0.0
|
find-up: 5.0.0
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/pkg-types@1.0.3:
|
||||||
|
resolution: {integrity: sha512-nN7pYi0AQqJnoLPC9eHFQ8AcyaixBUOwvqc5TDnIKCMEE6I0y8P7OKA7fPexsXGCGxQDl/cmrLAp26LhcwxZ4A==}
|
||||||
|
dependencies:
|
||||||
|
jsonc-parser: 3.2.1
|
||||||
|
mlly: 1.5.0
|
||||||
|
pathe: 1.1.2
|
||||||
|
dev: true
|
||||||
|
|
||||||
/pluralize@8.0.0:
|
/pluralize@8.0.0:
|
||||||
resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==}
|
resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==}
|
||||||
engines: {node: '>=4'}
|
engines: {node: '>=4'}
|
||||||
@ -12850,6 +12945,10 @@ packages:
|
|||||||
object-inspect: 1.13.1
|
object-inspect: 1.13.1
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/siginfo@2.0.0:
|
||||||
|
resolution: {integrity: sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==}
|
||||||
|
dev: true
|
||||||
|
|
||||||
/signal-exit@3.0.7:
|
/signal-exit@3.0.7:
|
||||||
resolution: {integrity: sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==}
|
resolution: {integrity: sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==}
|
||||||
dev: true
|
dev: true
|
||||||
@ -12968,6 +13067,10 @@ packages:
|
|||||||
stackframe: 1.3.4
|
stackframe: 1.3.4
|
||||||
dev: false
|
dev: false
|
||||||
|
|
||||||
|
/stackback@0.0.2:
|
||||||
|
resolution: {integrity: sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==}
|
||||||
|
dev: true
|
||||||
|
|
||||||
/stackframe@1.3.4:
|
/stackframe@1.3.4:
|
||||||
resolution: {integrity: sha512-oeVtt7eWQS+Na6F//S4kJ2K2VbRlS9D43mAlMyVpVWovy9o+jfgH8O9agzANzaiLjclA0oYzUXEM4PurhSUChw==}
|
resolution: {integrity: sha512-oeVtt7eWQS+Na6F//S4kJ2K2VbRlS9D43mAlMyVpVWovy9o+jfgH8O9agzANzaiLjclA0oYzUXEM4PurhSUChw==}
|
||||||
dev: false
|
dev: false
|
||||||
@ -12992,6 +13095,10 @@ packages:
|
|||||||
engines: {node: '>= 0.8'}
|
engines: {node: '>= 0.8'}
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/std-env@3.7.0:
|
||||||
|
resolution: {integrity: sha512-JPbdCEQLj1w5GilpiHAx3qJvFndqybBysA3qUOnznweH4QbNYUsW/ea8QzSrnh0vNsezMMw5bcVool8lM0gwzg==}
|
||||||
|
dev: true
|
||||||
|
|
||||||
/stop-iteration-iterator@1.0.0:
|
/stop-iteration-iterator@1.0.0:
|
||||||
resolution: {integrity: sha512-iCGQj+0l0HOdZ2AEeBADlsRC+vsnDsZsbdSiH1yNSjcfKM7fdpCMfqAL/dwF5BLiw/XhRft/Wax6zQbhq2BcjQ==}
|
resolution: {integrity: sha512-iCGQj+0l0HOdZ2AEeBADlsRC+vsnDsZsbdSiH1yNSjcfKM7fdpCMfqAL/dwF5BLiw/XhRft/Wax6zQbhq2BcjQ==}
|
||||||
engines: {node: '>= 0.4'}
|
engines: {node: '>= 0.4'}
|
||||||
@ -13161,6 +13268,12 @@ packages:
|
|||||||
engines: {node: '>=8'}
|
engines: {node: '>=8'}
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/strip-literal@1.3.0:
|
||||||
|
resolution: {integrity: sha512-PugKzOsyXpArk0yWmUwqOZecSO0GH0bPoctLcqNDH9J04pVW3lflYE0ujElBGTloevcxF5MofAOZ7C5l2b+wLg==}
|
||||||
|
dependencies:
|
||||||
|
acorn: 8.11.3
|
||||||
|
dev: true
|
||||||
|
|
||||||
/stylis@4.2.0:
|
/stylis@4.2.0:
|
||||||
resolution: {integrity: sha512-Orov6g6BB1sDfYgzWfTHDOxamtX1bE/zo104Dh9e6fqJ3PooipYyfJ0pUmrZO2wAvO8YbEyeFrkV91XTsGMSrw==}
|
resolution: {integrity: sha512-Orov6g6BB1sDfYgzWfTHDOxamtX1bE/zo104Dh9e6fqJ3PooipYyfJ0pUmrZO2wAvO8YbEyeFrkV91XTsGMSrw==}
|
||||||
dev: false
|
dev: false
|
||||||
@ -13311,6 +13424,15 @@ packages:
|
|||||||
/tiny-invariant@1.3.1:
|
/tiny-invariant@1.3.1:
|
||||||
resolution: {integrity: sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw==}
|
resolution: {integrity: sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw==}
|
||||||
|
|
||||||
|
/tinybench@2.6.0:
|
||||||
|
resolution: {integrity: sha512-N8hW3PG/3aOoZAN5V/NSAEDz0ZixDSSt5b/a05iqtpgfLWMSVuCo7w0k2vVvEjdrIoeGqZzweX2WlyioNIHchA==}
|
||||||
|
dev: true
|
||||||
|
|
||||||
|
/tinypool@0.8.2:
|
||||||
|
resolution: {integrity: sha512-SUszKYe5wgsxnNOVlBYO6IC+8VGWdVGZWAqUxp3UErNBtptZvWbwyUOyzNL59zigz2rCA92QiL3wvG+JDSdJdQ==}
|
||||||
|
engines: {node: '>=14.0.0'}
|
||||||
|
dev: true
|
||||||
|
|
||||||
/tinyspy@2.2.0:
|
/tinyspy@2.2.0:
|
||||||
resolution: {integrity: sha512-d2eda04AN/cPOR89F7Xv5bK/jrQEhmcLFe6HFldoeO9AJtps+fqEnh486vnT/8y4bw38pSyxDcTCAq+Ks2aJTg==}
|
resolution: {integrity: sha512-d2eda04AN/cPOR89F7Xv5bK/jrQEhmcLFe6HFldoeO9AJtps+fqEnh486vnT/8y4bw38pSyxDcTCAq+Ks2aJTg==}
|
||||||
engines: {node: '>=14.0.0'}
|
engines: {node: '>=14.0.0'}
|
||||||
@ -13828,6 +13950,27 @@ packages:
|
|||||||
engines: {node: '>= 0.8'}
|
engines: {node: '>= 0.8'}
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/vite-node@1.2.2(@types/node@20.11.5):
|
||||||
|
resolution: {integrity: sha512-1as4rDTgVWJO3n1uHmUYqq7nsFgINQ9u+mRcXpjeOMJUmviqNKjcZB7UfRZrlM7MjYXMKpuWp5oGkjaFLnjawg==}
|
||||||
|
engines: {node: ^18.0.0 || >=20.0.0}
|
||||||
|
hasBin: true
|
||||||
|
dependencies:
|
||||||
|
cac: 6.7.14
|
||||||
|
debug: 4.3.4
|
||||||
|
pathe: 1.1.2
|
||||||
|
picocolors: 1.0.0
|
||||||
|
vite: 5.0.12(@types/node@20.11.5)
|
||||||
|
transitivePeerDependencies:
|
||||||
|
- '@types/node'
|
||||||
|
- less
|
||||||
|
- lightningcss
|
||||||
|
- sass
|
||||||
|
- stylus
|
||||||
|
- sugarss
|
||||||
|
- supports-color
|
||||||
|
- terser
|
||||||
|
dev: true
|
||||||
|
|
||||||
/vite-plugin-css-injected-by-js@3.3.1(vite@5.0.12):
|
/vite-plugin-css-injected-by-js@3.3.1(vite@5.0.12):
|
||||||
resolution: {integrity: sha512-PjM/X45DR3/V1K1fTRs8HtZHEQ55kIfdrn+dzaqNBFrOYO073SeSNCxp4j7gSYhV9NffVHaEnOL4myoko0ePAg==}
|
resolution: {integrity: sha512-PjM/X45DR3/V1K1fTRs8HtZHEQ55kIfdrn+dzaqNBFrOYO073SeSNCxp4j7gSYhV9NffVHaEnOL4myoko0ePAg==}
|
||||||
peerDependencies:
|
peerDependencies:
|
||||||
@ -13926,6 +14069,63 @@ packages:
|
|||||||
fsevents: 2.3.3
|
fsevents: 2.3.3
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/vitest@1.2.2(@types/node@20.11.5):
|
||||||
|
resolution: {integrity: sha512-d5Ouvrnms3GD9USIK36KG8OZ5bEvKEkITFtnGv56HFaSlbItJuYr7hv2Lkn903+AvRAgSixiamozUVfORUekjw==}
|
||||||
|
engines: {node: ^18.0.0 || >=20.0.0}
|
||||||
|
hasBin: true
|
||||||
|
peerDependencies:
|
||||||
|
'@edge-runtime/vm': '*'
|
||||||
|
'@types/node': ^18.0.0 || >=20.0.0
|
||||||
|
'@vitest/browser': ^1.0.0
|
||||||
|
'@vitest/ui': ^1.0.0
|
||||||
|
happy-dom: '*'
|
||||||
|
jsdom: '*'
|
||||||
|
peerDependenciesMeta:
|
||||||
|
'@edge-runtime/vm':
|
||||||
|
optional: true
|
||||||
|
'@types/node':
|
||||||
|
optional: true
|
||||||
|
'@vitest/browser':
|
||||||
|
optional: true
|
||||||
|
'@vitest/ui':
|
||||||
|
optional: true
|
||||||
|
happy-dom:
|
||||||
|
optional: true
|
||||||
|
jsdom:
|
||||||
|
optional: true
|
||||||
|
dependencies:
|
||||||
|
'@types/node': 20.11.5
|
||||||
|
'@vitest/expect': 1.2.2
|
||||||
|
'@vitest/runner': 1.2.2
|
||||||
|
'@vitest/snapshot': 1.2.2
|
||||||
|
'@vitest/spy': 1.2.2
|
||||||
|
'@vitest/utils': 1.2.2
|
||||||
|
acorn-walk: 8.3.2
|
||||||
|
cac: 6.7.14
|
||||||
|
chai: 4.4.1
|
||||||
|
debug: 4.3.4
|
||||||
|
execa: 8.0.1
|
||||||
|
local-pkg: 0.5.0
|
||||||
|
magic-string: 0.30.5
|
||||||
|
pathe: 1.1.2
|
||||||
|
picocolors: 1.0.0
|
||||||
|
std-env: 3.7.0
|
||||||
|
strip-literal: 1.3.0
|
||||||
|
tinybench: 2.6.0
|
||||||
|
tinypool: 0.8.2
|
||||||
|
vite: 5.0.12(@types/node@20.11.5)
|
||||||
|
vite-node: 1.2.2(@types/node@20.11.5)
|
||||||
|
why-is-node-running: 2.2.2
|
||||||
|
transitivePeerDependencies:
|
||||||
|
- less
|
||||||
|
- lightningcss
|
||||||
|
- sass
|
||||||
|
- stylus
|
||||||
|
- sugarss
|
||||||
|
- supports-color
|
||||||
|
- terser
|
||||||
|
dev: true
|
||||||
|
|
||||||
/void-elements@3.1.0:
|
/void-elements@3.1.0:
|
||||||
resolution: {integrity: sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w==}
|
resolution: {integrity: sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w==}
|
||||||
engines: {node: '>=0.10.0'}
|
engines: {node: '>=0.10.0'}
|
||||||
@ -14049,6 +14249,15 @@ packages:
|
|||||||
isexe: 2.0.0
|
isexe: 2.0.0
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/why-is-node-running@2.2.2:
|
||||||
|
resolution: {integrity: sha512-6tSwToZxTOcotxHeA+qGCq1mVzKR3CwcJGmVcY+QE8SHy6TnpFnh8PAvPNHYr7EcuVeG0QSMxtYCuO1ta/G/oA==}
|
||||||
|
engines: {node: '>=8'}
|
||||||
|
hasBin: true
|
||||||
|
dependencies:
|
||||||
|
siginfo: 2.0.0
|
||||||
|
stackback: 0.0.2
|
||||||
|
dev: true
|
||||||
|
|
||||||
/wordwrap@1.0.0:
|
/wordwrap@1.0.0:
|
||||||
resolution: {integrity: sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==}
|
resolution: {integrity: sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==}
|
||||||
dev: true
|
dev: true
|
||||||
@ -14189,6 +14398,11 @@ packages:
|
|||||||
engines: {node: '>=10'}
|
engines: {node: '>=10'}
|
||||||
dev: true
|
dev: true
|
||||||
|
|
||||||
|
/yocto-queue@1.0.0:
|
||||||
|
resolution: {integrity: sha512-9bnSc/HEW2uRy67wc+T8UwauLuPJVn28jb+GtJY16iiKWyvmYJRXVT4UamsAEGQfPohgr2q4Tq0sQbQlxTfi1g==}
|
||||||
|
engines: {node: '>=12.20'}
|
||||||
|
dev: true
|
||||||
|
|
||||||
/z-schema@5.0.5:
|
/z-schema@5.0.5:
|
||||||
resolution: {integrity: sha512-D7eujBWkLa3p2sIpJA0d1pr7es+a7m0vFAnZLlCEKq/Ij2k0MLi9Br2UPxoxdYystm5K1yeBGzub0FlYUEWj2Q==}
|
resolution: {integrity: sha512-D7eujBWkLa3p2sIpJA0d1pr7es+a7m0vFAnZLlCEKq/Ij2k0MLi9Br2UPxoxdYystm5K1yeBGzub0FlYUEWj2Q==}
|
||||||
engines: {node: '>=8.0.0'}
|
engines: {node: '>=8.0.0'}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import type { UnknownAction } from '@reduxjs/toolkit';
|
import type { UnknownAction } from '@reduxjs/toolkit';
|
||||||
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
||||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice';
|
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||||
import { cloneDeep } from 'lodash-es';
|
import { cloneDeep } from 'lodash-es';
|
||||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||||
import type { Graph } from 'services/api/types';
|
import type { Graph } from 'services/api/types';
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { parseify } from 'common/util/serialize';
|
import { parseify } from 'common/util/serialize';
|
||||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice';
|
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||||
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
|
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
|
||||||
import { size } from 'lodash-es';
|
import { size } from 'lodash-es';
|
||||||
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
import { appInfoApi } from 'services/api/endpoints/appInfo';
|
||||||
|
@ -4,7 +4,7 @@ import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
|||||||
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
|
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||||
import { isImageOutput } from 'features/nodes/types/common';
|
import { isImageOutput } from 'features/nodes/types/common';
|
||||||
import { LINEAR_UI_OUTPUT, nodeIDDenyList } from 'features/nodes/util/graph/constants';
|
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
|
||||||
import { boardsApi } from 'services/api/endpoints/boards';
|
import { boardsApi } from 'services/api/endpoints/boards';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { imagesAdapter } from 'services/api/util';
|
import { imagesAdapter } from 'services/api/util';
|
||||||
@ -24,10 +24,9 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
const { data } = action.payload;
|
const { data } = action.payload;
|
||||||
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
|
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
|
||||||
|
|
||||||
const { result, node, queue_batch_id, source_node_id } = data;
|
const { result, node, queue_batch_id } = data;
|
||||||
|
|
||||||
// This complete event has an associated image output
|
// This complete event has an associated image output
|
||||||
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type) && !nodeIDDenyList.includes(source_node_id)) {
|
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
|
||||||
const { image_name } = result.image;
|
const { image_name } = result.image;
|
||||||
const { canvas, gallery } = getState();
|
const { canvas, gallery } = getState();
|
||||||
|
|
||||||
@ -42,7 +41,7 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
imageDTORequest.unsubscribe();
|
imageDTORequest.unsubscribe();
|
||||||
|
|
||||||
// Add canvas images to the staging area
|
// Add canvas images to the staging area
|
||||||
if (canvas.batchIds.includes(queue_batch_id) && [LINEAR_UI_OUTPUT].includes(data.source_node_id)) {
|
if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) {
|
||||||
dispatch(addImageToStagingArea(imageDTO));
|
dispatch(addImageToStagingArea(imageDTO));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,8 +15,7 @@ export const addUpdateAllNodesRequestedListener = () => {
|
|||||||
actionCreator: updateAllNodesRequested,
|
actionCreator: updateAllNodesRequested,
|
||||||
effect: (action, { dispatch, getState }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const nodes = getState().nodes.nodes;
|
const { nodes, templates } = getState().nodes;
|
||||||
const templates = getState().nodeTemplates.templates;
|
|
||||||
|
|
||||||
let unableToUpdateCount = 0;
|
let unableToUpdateCount = 0;
|
||||||
|
|
||||||
|
@ -39,16 +39,12 @@ export const addUpscaleRequestedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { esrganModelName } = state.postprocessing;
|
|
||||||
const { autoAddBoardId } = state.gallery;
|
|
||||||
|
|
||||||
const enqueueBatchArg: BatchConfig = {
|
const enqueueBatchArg: BatchConfig = {
|
||||||
prepend: true,
|
prepend: true,
|
||||||
batch: {
|
batch: {
|
||||||
graph: buildAdHocUpscaleGraph({
|
graph: buildAdHocUpscaleGraph({
|
||||||
image_name,
|
image_name,
|
||||||
esrganModelName,
|
state,
|
||||||
autoAddBoardId,
|
|
||||||
}),
|
}),
|
||||||
runs: 1,
|
runs: 1,
|
||||||
},
|
},
|
||||||
|
@ -18,7 +18,7 @@ export const addWorkflowLoadRequestedListener = () => {
|
|||||||
effect: (action, { dispatch, getState }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const { workflow, asCopy } = action.payload;
|
const { workflow, asCopy } = action.payload;
|
||||||
const nodeTemplates = getState().nodeTemplates.templates;
|
const nodeTemplates = getState().nodes.templates;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
|
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
|
||||||
|
@ -16,7 +16,6 @@ import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
|||||||
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
||||||
import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice';
|
import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice';
|
||||||
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
|
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { nodesTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
|
||||||
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
||||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||||
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
|
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
|
||||||
@ -46,7 +45,6 @@ const allReducers = {
|
|||||||
[gallerySlice.name]: gallerySlice.reducer,
|
[gallerySlice.name]: gallerySlice.reducer,
|
||||||
[generationSlice.name]: generationSlice.reducer,
|
[generationSlice.name]: generationSlice.reducer,
|
||||||
[nodesSlice.name]: nodesSlice.reducer,
|
[nodesSlice.name]: nodesSlice.reducer,
|
||||||
[nodesTemplatesSlice.name]: nodesTemplatesSlice.reducer,
|
|
||||||
[postprocessingSlice.name]: postprocessingSlice.reducer,
|
[postprocessingSlice.name]: postprocessingSlice.reducer,
|
||||||
[systemSlice.name]: systemSlice.reducer,
|
[systemSlice.name]: systemSlice.reducer,
|
||||||
[configSlice.name]: configSlice.reducer,
|
[configSlice.name]: configSlice.reducer,
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import type { AppThunkDispatch, RootState } from 'app/store/store';
|
import type { AppThunkDispatch, RootState } from 'app/store/store';
|
||||||
import type { TypedUseSelectorHook } from 'react-redux';
|
import type { TypedUseSelectorHook } from 'react-redux';
|
||||||
import { useDispatch, useSelector } from 'react-redux';
|
import { useDispatch, useSelector, useStore } from 'react-redux';
|
||||||
|
|
||||||
// Use throughout your app instead of plain `useDispatch` and `useSelector`
|
// Use throughout your app instead of plain `useDispatch` and `useSelector`
|
||||||
export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
|
export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
|
||||||
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
|
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
|
||||||
|
export const useAppStore = () => useStore<RootState>();
|
||||||
|
2
invokeai/frontend/web/src/app/store/util.ts
Normal file
2
invokeai/frontend/web/src/app/store/util.ts
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
export const EMPTY_ARRAY = [];
|
||||||
|
export const EMPTY_OBJECT = {};
|
@ -8,7 +8,6 @@ import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
|||||||
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||||
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
|
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||||
@ -23,11 +22,10 @@ const selector = createMemoizedSelector(
|
|||||||
selectGenerationSlice,
|
selectGenerationSlice,
|
||||||
selectSystemSlice,
|
selectSystemSlice,
|
||||||
selectNodesSlice,
|
selectNodesSlice,
|
||||||
selectNodeTemplatesSlice,
|
|
||||||
selectDynamicPromptsSlice,
|
selectDynamicPromptsSlice,
|
||||||
activeTabNameSelector,
|
activeTabNameSelector,
|
||||||
],
|
],
|
||||||
(controlAdapters, generation, system, nodes, nodeTemplates, dynamicPrompts, activeTabName) => {
|
(controlAdapters, generation, system, nodes, dynamicPrompts, activeTabName) => {
|
||||||
const { initialImage, model, positivePrompt } = generation;
|
const { initialImage, model, positivePrompt } = generation;
|
||||||
|
|
||||||
const { isConnected } = system;
|
const { isConnected } = system;
|
||||||
@ -54,7 +52,7 @@ const selector = createMemoizedSelector(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const nodeTemplate = nodeTemplates.templates[node.data.type];
|
const nodeTemplate = nodes.templates[node.data.type];
|
||||||
|
|
||||||
if (!nodeTemplate) {
|
if (!nodeTemplate) {
|
||||||
// Node type not found
|
// Node type not found
|
||||||
|
@ -7,8 +7,12 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import type { SelectInstance } from 'chakra-react-select';
|
import type { SelectInstance } from 'chakra-react-select';
|
||||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||||
import { addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice';
|
import {
|
||||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
addNodePopoverClosed,
|
||||||
|
addNodePopoverOpened,
|
||||||
|
nodeAdded,
|
||||||
|
selectNodesSlice,
|
||||||
|
} from 'features/nodes/store/nodesSlice';
|
||||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||||
import { filter, map, memoize, some } from 'lodash-es';
|
import { filter, map, memoize, some } from 'lodash-es';
|
||||||
import type { KeyboardEventHandler } from 'react';
|
import type { KeyboardEventHandler } from 'react';
|
||||||
@ -54,10 +58,10 @@ const AddNodePopover = () => {
|
|||||||
const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType);
|
const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType);
|
||||||
const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType);
|
const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType);
|
||||||
|
|
||||||
const selector = createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates) => {
|
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||||
// If we have a connection in progress, we need to filter the node choices
|
// If we have a connection in progress, we need to filter the node choices
|
||||||
const filteredNodeTemplates = fieldFilter
|
const filteredNodeTemplates = fieldFilter
|
||||||
? filter(nodeTemplates.templates, (template) => {
|
? filter(nodes.templates, (template) => {
|
||||||
const handles = handleFilter === 'source' ? template.inputs : template.outputs;
|
const handles = handleFilter === 'source' ? template.inputs : template.outputs;
|
||||||
|
|
||||||
return some(handles, (handle) => {
|
return some(handles, (handle) => {
|
||||||
@ -67,7 +71,7 @@ const AddNodePopover = () => {
|
|||||||
return validateSourceAndTargetTypes(sourceType, targetType);
|
return validateSourceAndTargetTypes(sourceType, targetType);
|
||||||
});
|
});
|
||||||
})
|
})
|
||||||
: map(nodeTemplates.templates);
|
: map(nodes.templates);
|
||||||
|
|
||||||
const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => {
|
const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => {
|
||||||
return {
|
return {
|
||||||
|
@ -1,10 +1,17 @@
|
|||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
|
|
||||||
import { getFieldColor } from './getEdgeColor';
|
import { getFieldColor } from './getEdgeColor';
|
||||||
|
|
||||||
|
const defaultReturnValue = {
|
||||||
|
isSelected: false,
|
||||||
|
shouldAnimate: false,
|
||||||
|
stroke: colorTokenToCssVar('base.500'),
|
||||||
|
};
|
||||||
|
|
||||||
export const makeEdgeSelector = (
|
export const makeEdgeSelector = (
|
||||||
source: string,
|
source: string,
|
||||||
sourceHandleId: string | null | undefined,
|
sourceHandleId: string | null | undefined,
|
||||||
@ -12,14 +19,19 @@ export const makeEdgeSelector = (
|
|||||||
targetHandleId: string | null | undefined,
|
targetHandleId: string | null | undefined,
|
||||||
selected?: boolean
|
selected?: boolean
|
||||||
) =>
|
) =>
|
||||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => {
|
||||||
const sourceNode = nodes.nodes.find((node) => node.id === source);
|
const sourceNode = nodes.nodes.find((node) => node.id === source);
|
||||||
const targetNode = nodes.nodes.find((node) => node.id === target);
|
const targetNode = nodes.nodes.find((node) => node.id === target);
|
||||||
|
|
||||||
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
|
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
|
||||||
|
|
||||||
const isSelected = sourceNode?.selected || targetNode?.selected || selected;
|
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
|
||||||
const sourceType = isInvocationToInvocationEdge ? sourceNode?.data?.outputs[sourceHandleId || '']?.type : undefined;
|
if (!sourceNode || !sourceHandleId) {
|
||||||
|
return defaultReturnValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
|
||||||
|
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
|
||||||
|
|
||||||
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
|
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
|
||||||
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
|
||||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
|
||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import type { CSSProperties } from 'react';
|
import type { CSSProperties } from 'react';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
@ -13,7 +12,7 @@ interface Props {
|
|||||||
const hiddenHandleStyles: CSSProperties = { visibility: 'hidden' };
|
const hiddenHandleStyles: CSSProperties = { visibility: 'hidden' };
|
||||||
|
|
||||||
const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
||||||
const data = useNodeData(nodeId);
|
const template = useNodeTemplate(nodeId);
|
||||||
const { base600 } = useChakraThemeTokens();
|
const { base600 } = useChakraThemeTokens();
|
||||||
|
|
||||||
const dummyHandleStyles: CSSProperties = useMemo(
|
const dummyHandleStyles: CSSProperties = useMemo(
|
||||||
@ -37,7 +36,7 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
|||||||
[dummyHandleStyles]
|
[dummyHandleStyles]
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!isInvocationNodeData(data)) {
|
if (!template) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,14 +44,14 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
|||||||
<>
|
<>
|
||||||
<Handle
|
<Handle
|
||||||
type="target"
|
type="target"
|
||||||
id={`${data.id}-collapsed-target`}
|
id={`${nodeId}-collapsed-target`}
|
||||||
isConnectable={false}
|
isConnectable={false}
|
||||||
position={Position.Left}
|
position={Position.Left}
|
||||||
style={collapsedTargetStyles}
|
style={collapsedTargetStyles}
|
||||||
/>
|
/>
|
||||||
{map(data.inputs, (input) => (
|
{map(template.inputs, (input) => (
|
||||||
<Handle
|
<Handle
|
||||||
key={`${data.id}-${input.name}-collapsed-input-handle`}
|
key={`${nodeId}-${input.name}-collapsed-input-handle`}
|
||||||
type="target"
|
type="target"
|
||||||
id={input.name}
|
id={input.name}
|
||||||
isConnectable={false}
|
isConnectable={false}
|
||||||
@ -62,14 +61,14 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
|
|||||||
))}
|
))}
|
||||||
<Handle
|
<Handle
|
||||||
type="source"
|
type="source"
|
||||||
id={`${data.id}-collapsed-source`}
|
id={`${nodeId}-collapsed-source`}
|
||||||
isConnectable={false}
|
isConnectable={false}
|
||||||
position={Position.Right}
|
position={Position.Right}
|
||||||
style={collapsedSourceStyles}
|
style={collapsedSourceStyles}
|
||||||
/>
|
/>
|
||||||
{map(data.outputs, (output) => (
|
{map(template.outputs, (output) => (
|
||||||
<Handle
|
<Handle
|
||||||
key={`${data.id}-${output.name}-collapsed-output-handle`}
|
key={`${nodeId}-${output.name}-collapsed-output-handle`}
|
||||||
type="source"
|
type="source"
|
||||||
id={output.name}
|
id={output.name}
|
||||||
isConnectable={false}
|
isConnectable={false}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import InvocationNode from 'features/nodes/components/flow/nodes/Invocation/InvocationNode';
|
import InvocationNode from 'features/nodes/components/flow/nodes/Invocation/InvocationNode';
|
||||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
import type { InvocationNodeData } from 'features/nodes/types/invocation';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import type { NodeProps } from 'reactflow';
|
import type { NodeProps } from 'reactflow';
|
||||||
@ -13,7 +13,7 @@ const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => {
|
|||||||
const { id: nodeId, type, isOpen, label } = data;
|
const { id: nodeId, type, isOpen, label } = data;
|
||||||
|
|
||||||
const hasTemplateSelector = useMemo(
|
const hasTemplateSelector = useMemo(
|
||||||
() => createSelector(selectNodeTemplatesSlice, (nodeTemplates) => Boolean(nodeTemplates.templates[type])),
|
() => createSelector(selectNodesSlice, (nodes) => Boolean(nodes.templates[type])),
|
||||||
[type]
|
[type]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ import FieldTooltipContent from './FieldTooltipContent';
|
|||||||
interface Props {
|
interface Props {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
fieldName: string;
|
fieldName: string;
|
||||||
kind: 'input' | 'output';
|
kind: 'inputs' | 'outputs';
|
||||||
isMissingInput?: boolean;
|
isMissingInput?: boolean;
|
||||||
withTooltip?: boolean;
|
withTooltip?: boolean;
|
||||||
}
|
}
|
||||||
@ -58,7 +58,7 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Tooltip
|
<Tooltip
|
||||||
label={withTooltip ? <FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="input" /> : undefined}
|
label={withTooltip ? <FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="inputs" /> : undefined}
|
||||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||||
>
|
>
|
||||||
<Editable
|
<Editable
|
||||||
|
@ -6,7 +6,7 @@ import { memo } from 'react';
|
|||||||
interface Props {
|
interface Props {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
fieldName: string;
|
fieldName: string;
|
||||||
kind: 'input' | 'output';
|
kind: 'inputs' | 'outputs';
|
||||||
isMissingInput?: boolean;
|
isMissingInput?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||||
import { useFieldInstance } from 'features/nodes/hooks/useFieldData';
|
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||||
import { isFieldInputInstance, isFieldInputTemplate } from 'features/nodes/types/field';
|
import { isFieldInputInstance, isFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
@ -9,11 +9,11 @@ import { useTranslation } from 'react-i18next';
|
|||||||
interface Props {
|
interface Props {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
fieldName: string;
|
fieldName: string;
|
||||||
kind: 'input' | 'output';
|
kind: 'inputs' | 'outputs';
|
||||||
}
|
}
|
||||||
|
|
||||||
const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
|
const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
|
||||||
const field = useFieldInstance(nodeId, fieldName);
|
const field = useFieldInputInstance(nodeId, fieldName);
|
||||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
|
const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
|
||||||
const isInputTemplate = isFieldInputTemplate(fieldTemplate);
|
const isInputTemplate = isFieldInputTemplate(fieldTemplate);
|
||||||
const fieldTypeName = useFieldTypeName(fieldTemplate?.type);
|
const fieldTypeName = useFieldTypeName(fieldTemplate?.type);
|
||||||
|
@ -25,7 +25,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
|||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
|
|
||||||
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
|
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
|
||||||
useConnectionState({ nodeId, fieldName, kind: 'input' });
|
useConnectionState({ nodeId, fieldName, kind: 'inputs' });
|
||||||
|
|
||||||
const isMissingInput = useMemo(() => {
|
const isMissingInput = useMemo(() => {
|
||||||
if (!fieldTemplate) {
|
if (!fieldTemplate) {
|
||||||
@ -76,7 +76,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
|||||||
<EditableFieldTitle
|
<EditableFieldTitle
|
||||||
nodeId={nodeId}
|
nodeId={nodeId}
|
||||||
fieldName={fieldName}
|
fieldName={fieldName}
|
||||||
kind="input"
|
kind="inputs"
|
||||||
isMissingInput={isMissingInput}
|
isMissingInput={isMissingInput}
|
||||||
withTooltip
|
withTooltip
|
||||||
/>
|
/>
|
||||||
@ -101,7 +101,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
|||||||
<EditableFieldTitle
|
<EditableFieldTitle
|
||||||
nodeId={nodeId}
|
nodeId={nodeId}
|
||||||
fieldName={fieldName}
|
fieldName={fieldName}
|
||||||
kind="input"
|
kind="inputs"
|
||||||
isMissingInput={isMissingInput}
|
isMissingInput={isMissingInput}
|
||||||
withTooltip
|
withTooltip
|
||||||
/>
|
/>
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import { Box, Text } from '@invoke-ai/ui-library';
|
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||||
import { useFieldInstance } from 'features/nodes/hooks/useFieldData';
|
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
||||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
|
||||||
import {
|
import {
|
||||||
isBoardFieldInputInstance,
|
isBoardFieldInputInstance,
|
||||||
isBoardFieldInputTemplate,
|
isBoardFieldInputTemplate,
|
||||||
@ -38,7 +37,6 @@ import {
|
|||||||
isVAEModelFieldInputTemplate,
|
isVAEModelFieldInputTemplate,
|
||||||
} from 'features/nodes/types/field';
|
} from 'features/nodes/types/field';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
|
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
|
||||||
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
|
||||||
@ -63,17 +61,8 @@ type InputFieldProps = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||||
const { t } = useTranslation();
|
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
|
||||||
const fieldInstance = useFieldInstance(nodeId, fieldName);
|
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
|
||||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
|
||||||
|
|
||||||
if (fieldTemplate?.fieldKind === 'output') {
|
|
||||||
return (
|
|
||||||
<Box p={2}>
|
|
||||||
{t('nodes.outputFieldInInput')}: {fieldInstance?.type.name}
|
|
||||||
</Box>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
|
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
|
||||||
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
@ -141,18 +130,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (fieldInstance && fieldTemplate) {
|
if (fieldTemplate) {
|
||||||
// Fallback for when there is no component for the type
|
// Fallback for when there is no component for the type
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
|
||||||
<Box p={1}>
|
|
||||||
<Text fontSize="sm" fontWeight="semibold" color="error.300">
|
|
||||||
{t('nodes.unknownFieldType', { type: fieldInstance?.type.name })}
|
|
||||||
</Text>
|
|
||||||
</Box>
|
|
||||||
);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(InputFieldRenderer);
|
export default memo(InputFieldRenderer);
|
||||||
|
@ -62,7 +62,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
|||||||
/>
|
/>
|
||||||
<Flex flexDir="column" w="full">
|
<Flex flexDir="column" w="full">
|
||||||
<Flex alignItems="center">
|
<Flex alignItems="center">
|
||||||
<EditableFieldTitle nodeId={nodeId} fieldName={fieldName} kind="input" />
|
<EditableFieldTitle nodeId={nodeId} fieldName={fieldName} kind="inputs" />
|
||||||
<Spacer />
|
<Spacer />
|
||||||
{isValueChanged && (
|
{isValueChanged && (
|
||||||
<IconButton
|
<IconButton
|
||||||
@ -75,7 +75,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
|||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
<Tooltip
|
<Tooltip
|
||||||
label={<FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="input" />}
|
label={<FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="inputs" />}
|
||||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||||
placement="top"
|
placement="top"
|
||||||
>
|
>
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import { Flex, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
|
import { Flex, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
|
||||||
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
||||||
import { useFieldOutputInstance } from 'features/nodes/hooks/useFieldOutputInstance';
|
|
||||||
import { useFieldOutputTemplate } from 'features/nodes/hooks/useFieldOutputTemplate';
|
import { useFieldOutputTemplate } from 'features/nodes/hooks/useFieldOutputTemplate';
|
||||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||||
import type { PropsWithChildren } from 'react';
|
import type { PropsWithChildren } from 'react';
|
||||||
@ -18,18 +17,17 @@ interface Props {
|
|||||||
const OutputField = ({ nodeId, fieldName }: Props) => {
|
const OutputField = ({ nodeId, fieldName }: Props) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName);
|
const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName);
|
||||||
const fieldInstance = useFieldOutputInstance(nodeId, fieldName);
|
|
||||||
|
|
||||||
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
|
const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } =
|
||||||
useConnectionState({ nodeId, fieldName, kind: 'output' });
|
useConnectionState({ nodeId, fieldName, kind: 'outputs' });
|
||||||
|
|
||||||
if (!fieldTemplate || !fieldInstance) {
|
if (!fieldTemplate) {
|
||||||
return (
|
return (
|
||||||
<OutputFieldWrapper shouldDim={shouldDim}>
|
<OutputFieldWrapper shouldDim={shouldDim}>
|
||||||
<FormControl alignItems="stretch" justifyContent="space-between" gap={2} h="full" w="full">
|
<FormControl alignItems="stretch" justifyContent="space-between" gap={2} h="full" w="full">
|
||||||
<FormLabel display="flex" alignItems="center" h="full" color="error.300" mb={0} px={1} gap={2}>
|
<FormLabel display="flex" alignItems="center" h="full" color="error.300" mb={0} px={1} gap={2}>
|
||||||
{t('nodes.unknownOutput', {
|
{t('nodes.unknownOutput', {
|
||||||
name: fieldTemplate?.title ?? fieldName,
|
name: fieldName,
|
||||||
})}
|
})}
|
||||||
</FormLabel>
|
</FormLabel>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
@ -40,7 +38,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => {
|
|||||||
return (
|
return (
|
||||||
<OutputFieldWrapper shouldDim={shouldDim}>
|
<OutputFieldWrapper shouldDim={shouldDim}>
|
||||||
<Tooltip
|
<Tooltip
|
||||||
label={<FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="output" />}
|
label={<FieldTooltipContent nodeId={nodeId} fieldName={fieldName} kind="outputs" />}
|
||||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||||
placement="top"
|
placement="top"
|
||||||
shouldWrapChildren
|
shouldWrapChildren
|
||||||
|
@ -6,19 +6,18 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon
|
|||||||
import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea';
|
import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea';
|
||||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import EditableNodeTitle from './details/EditableNodeTitle';
|
import EditableNodeTitle from './details/EditableNodeTitle';
|
||||||
|
|
||||||
const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||||
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||||
|
|
||||||
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
||||||
|
|
||||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined;
|
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
|
||||||
|
|
||||||
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
|
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
|
||||||
return;
|
return;
|
||||||
|
@ -5,7 +5,6 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
|||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -14,12 +13,12 @@ import type { AnyResult } from 'services/events/types';
|
|||||||
|
|
||||||
import ImageOutputPreview from './outputs/ImageOutputPreview';
|
import ImageOutputPreview from './outputs/ImageOutputPreview';
|
||||||
|
|
||||||
const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||||
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||||
|
|
||||||
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
||||||
|
|
||||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined;
|
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
|
||||||
|
|
||||||
const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
|
const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
|
||||||
|
|
||||||
|
@ -3,16 +3,15 @@ import { useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||||
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||||
|
|
||||||
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
|
||||||
|
|
||||||
const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined;
|
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
template: lastSelectedNodeTemplate,
|
template: lastSelectedNodeTemplate,
|
||||||
|
@ -1,26 +1,22 @@
|
|||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { EMPTY_ARRAY } from 'app/store/util';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
|
||||||
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
||||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||||
import { keys, map } from 'lodash-es';
|
import { keys, map } from 'lodash-es';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
|
export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => {
|
||||||
const selector = useMemo(
|
const selector = useMemo(
|
||||||
() =>
|
() =>
|
||||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
const template = selectNodeTemplate(nodes, nodeId);
|
||||||
if (!isInvocationNode(node)) {
|
if (!template) {
|
||||||
return [];
|
return EMPTY_ARRAY;
|
||||||
}
|
}
|
||||||
const nodeTemplate = nodeTemplates.templates[node.data.type];
|
const fields = map(template.inputs).filter(
|
||||||
if (!nodeTemplate) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
const fields = map(nodeTemplate.inputs).filter(
|
|
||||||
(field) =>
|
(field) =>
|
||||||
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
|
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
|
||||||
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||||
|
@ -13,7 +13,7 @@ export const SHARED_NODE_PROPERTIES: Partial<Node> = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const useBuildNode = () => {
|
export const useBuildNode = () => {
|
||||||
const nodeTemplates = useAppSelector((s) => s.nodeTemplates.templates);
|
const nodeTemplates = useAppSelector((s) => s.nodes.templates);
|
||||||
|
|
||||||
const flow = useReactFlow();
|
const flow = useReactFlow();
|
||||||
|
|
||||||
|
@ -1,28 +1,24 @@
|
|||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { EMPTY_ARRAY } from 'app/store/util';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
|
||||||
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
|
||||||
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
|
||||||
import { keys, map } from 'lodash-es';
|
import { keys, map } from 'lodash-es';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
export const useConnectionInputFieldNames = (nodeId: string) => {
|
export const useConnectionInputFieldNames = (nodeId: string): string[] => {
|
||||||
const selector = useMemo(
|
const selector = useMemo(
|
||||||
() =>
|
() =>
|
||||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
const template = selectNodeTemplate(nodes, nodeId);
|
||||||
if (!isInvocationNode(node)) {
|
if (!template) {
|
||||||
return [];
|
return EMPTY_ARRAY;
|
||||||
}
|
|
||||||
const nodeTemplate = nodeTemplates.templates[node.data.type];
|
|
||||||
if (!nodeTemplate) {
|
|
||||||
return [];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// get the visible fields
|
// get the visible fields
|
||||||
const fields = map(nodeTemplate.inputs).filter(
|
const fields = map(template.inputs).filter(
|
||||||
(field) =>
|
(field) =>
|
||||||
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
|
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
|
||||||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
|
||||||
|
@ -14,7 +14,7 @@ const selectIsConnectionInProgress = createSelector(
|
|||||||
export type UseConnectionStateProps = {
|
export type UseConnectionStateProps = {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
fieldName: string;
|
fieldName: string;
|
||||||
kind: 'input' | 'output';
|
kind: 'inputs' | 'outputs';
|
||||||
};
|
};
|
||||||
|
|
||||||
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => {
|
||||||
@ -26,8 +26,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
|||||||
Boolean(
|
Boolean(
|
||||||
nodes.edges.filter((edge) => {
|
nodes.edges.filter((edge) => {
|
||||||
return (
|
return (
|
||||||
(kind === 'input' ? edge.target : edge.source) === nodeId &&
|
(kind === 'inputs' ? edge.target : edge.source) === nodeId &&
|
||||||
(kind === 'input' ? edge.targetHandle : edge.sourceHandle) === fieldName
|
(kind === 'inputs' ? edge.targetHandle : edge.sourceHandle) === fieldName
|
||||||
);
|
);
|
||||||
}).length
|
}).length
|
||||||
)
|
)
|
||||||
@ -36,7 +36,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
|||||||
);
|
);
|
||||||
|
|
||||||
const selectConnectionError = useMemo(
|
const selectConnectionError = useMemo(
|
||||||
() => makeConnectionErrorSelector(nodeId, fieldName, kind === 'input' ? 'target' : 'source', fieldType),
|
() => makeConnectionErrorSelector(nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType),
|
||||||
[nodeId, fieldName, kind, fieldType]
|
[nodeId, fieldName, kind, fieldType]
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta
|
|||||||
Boolean(
|
Boolean(
|
||||||
nodes.connectionStartParams?.nodeId === nodeId &&
|
nodes.connectionStartParams?.nodeId === nodeId &&
|
||||||
nodes.connectionStartParams?.handleId === fieldName &&
|
nodes.connectionStartParams?.handleId === fieldName &&
|
||||||
nodes.connectionStartParams?.handleType === { input: 'target', output: 'source' }[kind]
|
nodes.connectionStartParams?.handleType === { inputs: 'target', outputs: 'source' }[kind]
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
[fieldName, kind, nodeId]
|
[fieldName, kind, nodeId]
|
||||||
|
@ -2,23 +2,19 @@ import { createSelector } from '@reduxjs/toolkit';
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { compareVersions } from 'compare-versions';
|
import { compareVersions } from 'compare-versions';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
import { selectNodeData, selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
export const useDoNodeVersionsMatch = (nodeId: string) => {
|
export const useDoNodeVersionsMatch = (nodeId: string): boolean => {
|
||||||
const selector = useMemo(
|
const selector = useMemo(
|
||||||
() =>
|
() =>
|
||||||
createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
createSelector(selectNodesSlice, (nodes) => {
|
||||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
const data = selectNodeData(nodes, nodeId);
|
||||||
if (!isInvocationNode(node)) {
|
const template = selectNodeTemplate(nodes, nodeId);
|
||||||
|
if (!template?.version || !data?.version) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
return compareVersions(template.version, data.version) === 0;
|
||||||
if (!nodeTemplate?.version || !node.data?.version) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return compareVersions(nodeTemplate.version, node.data.version) === 0;
|
|
||||||
}),
|
}),
|
||||||
[nodeId]
|
[nodeId]
|
||||||
);
|
);
|
||||||
|
@ -1,18 +1,18 @@
|
|||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { selectNodeData } from 'features/nodes/store/selectors';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
|
export const useDoesInputHaveValue = (nodeId: string, fieldName: string): boolean => {
|
||||||
const selector = useMemo(
|
const selector = useMemo(
|
||||||
() =>
|
() =>
|
||||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
const data = selectNodeData(nodes, nodeId);
|
||||||
if (!isInvocationNode(node)) {
|
if (!data) {
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
return node?.data.inputs[fieldName]?.value !== undefined;
|
return data.inputs[fieldName]?.value !== undefined;
|
||||||
}),
|
}),
|
||||||
[fieldName, nodeId]
|
[fieldName, nodeId]
|
||||||
);
|
);
|
||||||
|
@ -1,23 +0,0 @@
|
|||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
|
||||||
import { useMemo } from 'react';
|
|
||||||
|
|
||||||
export const useFieldInstance = (nodeId: string, fieldName: string) => {
|
|
||||||
const selector = useMemo(
|
|
||||||
() =>
|
|
||||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
|
||||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
|
||||||
if (!isInvocationNode(node)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
return node?.data.inputs[fieldName];
|
|
||||||
}),
|
|
||||||
[fieldName, nodeId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const fieldData = useAppSelector(selector);
|
|
||||||
|
|
||||||
return fieldData;
|
|
||||||
};
|
|
@ -1,23 +1,20 @@
|
|||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { selectFieldInputInstance } from 'features/nodes/store/selectors';
|
||||||
|
import type { FieldInputInstance } from 'features/nodes/types/field';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
export const useFieldInputInstance = (nodeId: string, fieldName: string) => {
|
export const useFieldInputInstance = (nodeId: string, fieldName: string): FieldInputInstance | null => {
|
||||||
const selector = useMemo(
|
const selector = useMemo(
|
||||||
() =>
|
() =>
|
||||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
return selectFieldInputInstance(nodes, nodeId, fieldName);
|
||||||
if (!isInvocationNode(node)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
return node.data.inputs[fieldName];
|
|
||||||
}),
|
}),
|
||||||
[fieldName, nodeId]
|
[fieldName, nodeId]
|
||||||
);
|
);
|
||||||
|
|
||||||
const fieldTemplate = useAppSelector(selector);
|
const fieldData = useAppSelector(selector);
|
||||||
|
|
||||||
return fieldTemplate;
|
return fieldData;
|
||||||
};
|
};
|
||||||
|
@ -1,21 +1,16 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
import { selectFieldInputTemplate } from 'features/nodes/store/selectors';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import type { FieldInput } from 'features/nodes/types/field';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
export const useFieldInputKind = (nodeId: string, fieldName: string) => {
|
export const useFieldInputKind = (nodeId: string, fieldName: string) => {
|
||||||
const selector = useMemo(
|
const selector = useMemo(
|
||||||
() =>
|
() =>
|
||||||
createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
createSelector(selectNodesSlice, (nodes): FieldInput | null => {
|
||||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
const template = selectFieldInputTemplate(nodes, nodeId, fieldName);
|
||||||
if (!isInvocationNode(node)) {
|
return template?.input ?? null;
|
||||||
return;
|
|
||||||
}
|
|
||||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
|
||||||
const fieldTemplate = nodeTemplate?.inputs[fieldName];
|
|
||||||
return fieldTemplate?.input;
|
|
||||||
}),
|
}),
|
||||||
[fieldName, nodeId]
|
[fieldName, nodeId]
|
||||||
);
|
);
|
||||||
|
@ -1,20 +1,15 @@
|
|||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice';
|
import { selectFieldInputTemplate } from 'features/nodes/store/selectors';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
export const useFieldInputTemplate = (nodeId: string, fieldName: string) => {
|
export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
|
||||||
const selector = useMemo(
|
const selector = useMemo(
|
||||||
() =>
|
() =>
|
||||||
createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => {
|
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
return selectFieldInputTemplate(nodes, nodeId, fieldName);
|
||||||
if (!isInvocationNode(node)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
|
|
||||||
return nodeTemplate?.inputs[fieldName];
|
|
||||||
}),
|
}),
|
||||||
[fieldName, nodeId]
|
[fieldName, nodeId]
|
||||||
);
|
);
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user