mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
67 Commits
bugfix/han
...
feat/invoc
Author | SHA1 | Date | |
---|---|---|---|
c76a6bd65f | |||
6c4eeaa569 | |||
1bbd13ead7 | |||
321b939d0e | |||
8fb77e431e | |||
083a4f3faa | |||
2005411f7e | |||
ba7b1b2665 | |||
b7ffd36cc6 | |||
199ddd6623 | |||
a7207ed8cf | |||
6bb2dda3f1 | |||
c1e5cd5893 | |||
ff249a2315 | |||
c58f8c3269 | |||
ed772a7107 | |||
cb0b389b4b | |||
8892df1d97 | |||
bc5f356390 | |||
bcb85e100d | |||
1f27ddc07d | |||
7a2b606001 | |||
83ddcc5f3a | |||
55fa785561 | |||
06429028c8 | |||
8b6e322697 | |||
54a67459bf | |||
7fe5283e74 | |||
fe0391c86b | |||
25386a76ef | |||
fd30cb4d90 | |||
0266946d3d | |||
a7f91b3e01 | |||
de0b72528c | |||
2932652787 | |||
db6bc7305a | |||
a5db204629 | |||
8e2b61e19f | |||
a3faa3792a | |||
c16eba78ab | |||
1a191c4655 | |||
e36d925bce | |||
b1ba18b3d1 | |||
aff46759f9 | |||
d7b7dcc7fe | |||
889a26c5b6 | |||
b4c774896a | |||
afbe889d35 | |||
9c1e52b1ef | |||
3f5ab02da9 | |||
bf48e8a03a | |||
e52434cb99 | |||
483bdbcb9f | |||
ae421fb4ab | |||
cc295a9f0a | |||
a7e23af9c6 | |||
3de4390711 | |||
3ceee2b2b2 | |||
5c7ed24aab | |||
183c9c4799 | |||
8baf3f78a2 | |||
ac2eb16a65 | |||
4aa7bee4b9 | |||
7e5ba2795e | |||
97a6c6eea7 | |||
f0e60a4ba2 | |||
aa089e8108 |
@ -9,11 +9,15 @@ complex functionality.
|
||||
|
||||
## Invocations Directory
|
||||
|
||||
InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These can be used as examples to create your own nodes.
|
||||
InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These
|
||||
can be used as examples to create your own nodes.
|
||||
|
||||
New nodes should be added to a subfolder in `nodes` direction found at the root level of the InvokeAI installation location. Nodes added to this folder will be able to be used upon application startup.
|
||||
New nodes should be added to a subfolder in `nodes` direction found at the root
|
||||
level of the InvokeAI installation location. Nodes added to this folder will be
|
||||
able to be used upon application startup.
|
||||
|
||||
Example `nodes` subfolder structure:
|
||||
|
||||
Example `nodes` subfolder structure:
|
||||
```py
|
||||
├── __init__.py # Invoke-managed custom node loader
|
||||
│
|
||||
@ -30,14 +34,14 @@ Example `nodes` subfolder structure:
|
||||
└── fancy_node.py
|
||||
```
|
||||
|
||||
Each node folder must have an `__init__.py` file that imports its nodes. Only nodes imported in the `__init__.py` file are loaded.
|
||||
See the README in the nodes folder for more examples:
|
||||
Each node folder must have an `__init__.py` file that imports its nodes. Only
|
||||
nodes imported in the `__init__.py` file are loaded. See the README in the nodes
|
||||
folder for more examples:
|
||||
|
||||
```py
|
||||
from .cool_node import CoolInvocation
|
||||
```
|
||||
|
||||
|
||||
## Creating A New Invocation
|
||||
|
||||
In order to understand the process of creating a new Invocation, let us actually
|
||||
@ -131,7 +135,6 @@ from invokeai.app.invocations.primitives import ImageField
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
'''Resizes an image'''
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The input image")
|
||||
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||
@ -167,7 +170,6 @@ from invokeai.app.invocations.primitives import ImageField
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
'''Resizes an image'''
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The input image")
|
||||
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||
@ -197,7 +199,6 @@ from invokeai.app.invocations.image import ImageOutput
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
'''Resizes an image'''
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The input image")
|
||||
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||
@ -229,30 +230,17 @@ class ResizeInvocation(BaseInvocation):
|
||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Load the image using InvokeAI's predefined Image Service. Returns the PIL image.
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
# Load the input image as a PIL image
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Resizing the image
|
||||
# Resize the image
|
||||
resized_image = image.resize((self.width, self.height))
|
||||
|
||||
# Save the image using InvokeAI's predefined Image Service. Returns the prepared PIL image.
|
||||
output_image = context.services.images.create(
|
||||
image=resized_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
# Save the image
|
||||
image_dto = context.images.save(image=resized_image)
|
||||
|
||||
# Returning the Image
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=output_image.image_name,
|
||||
),
|
||||
width=output_image.width,
|
||||
height=output_image.height,
|
||||
)
|
||||
# Return an ImageOutput
|
||||
return ImageOutput.build(image_dto)
|
||||
```
|
||||
|
||||
**Note:** Do not be overwhelmed by the `ImageOutput` process. InvokeAI has a
|
||||
@ -343,27 +331,25 @@ class ImageColorStringOutput(BaseInvocationOutput):
|
||||
|
||||
That's all there is to it.
|
||||
|
||||
<!-- TODO: DANGER - we probably do not want people to create their own field types, because this requires a lot of work on the frontend to accomodate.
|
||||
|
||||
### Custom Input Fields
|
||||
|
||||
Now that you know how to create your own Invocations, let us dive into slightly
|
||||
more advanced topics.
|
||||
|
||||
While creating your own Invocations, you might run into a scenario where the
|
||||
existing input types in InvokeAI do not meet your requirements. In such cases,
|
||||
you can create your own input types.
|
||||
existing fields in InvokeAI do not meet your requirements. In such cases, you
|
||||
can create your own fields.
|
||||
|
||||
Let us create one as an example. Let us say we want to create a color input
|
||||
field that represents a color code. But before we start on that here are some
|
||||
general good practices to keep in mind.
|
||||
|
||||
**Good Practices**
|
||||
### Best Practices
|
||||
|
||||
- There is no naming convention for input fields but we highly recommend that
|
||||
you name it something appropriate like `ColorField`.
|
||||
- It is not mandatory but it is heavily recommended to add a relevant
|
||||
`docstring` to describe your input field.
|
||||
`docstring` to describe your field.
|
||||
- Keep your field in the same file as the Invocation that it is made for or in
|
||||
another file where it is relevant.
|
||||
|
||||
@ -378,10 +364,13 @@ class ColorField(BaseModel):
|
||||
pass
|
||||
```
|
||||
|
||||
Perfect. Now let us create our custom inputs for our field. This is exactly
|
||||
similar how you created input fields for your Invocation. All the same rules
|
||||
apply. Let us create four fields representing the _red(r)_, _blue(b)_,
|
||||
_green(g)_ and _alpha(a)_ channel of the color.
|
||||
Perfect. Now let us create the properties for our field. This is similar to how
|
||||
you created input fields for your Invocation. All the same rules apply. Let us
|
||||
create four fields representing the _red(r)_, _blue(b)_, _green(g)_ and
|
||||
_alpha(a)_ channel of the color.
|
||||
|
||||
> Technically, the properties are _also_ called fields - but in this case, it
|
||||
> refers to a `pydantic` field.
|
||||
|
||||
```python
|
||||
class ColorField(BaseModel):
|
||||
@ -396,25 +385,11 @@ That's it. We now have a new input field type that we can use in our Invocations
|
||||
like this.
|
||||
|
||||
```python
|
||||
color: ColorField = Field(default=ColorField(r=0, g=0, b=0, a=0), description='Background color of an image')
|
||||
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=0), description='Background color of an image')
|
||||
```
|
||||
|
||||
### Custom Components For Frontend
|
||||
### Using the custom field
|
||||
|
||||
Every backend input type should have a corresponding frontend component so the
|
||||
UI knows what to render when you use a particular field type.
|
||||
When you start the UI, your custom field will be automatically recognized.
|
||||
|
||||
If you are using existing field types, we already have components for those. So
|
||||
you don't have to worry about creating anything new. But this might not always
|
||||
be the case. Sometimes you might want to create new field types and have the
|
||||
frontend UI deal with it in a different way.
|
||||
|
||||
This is where we venture into the world of React and Javascript and create our
|
||||
own new components for our Invocations. Do not fear the world of JS. It's
|
||||
actually pretty straightforward.
|
||||
|
||||
Let us create a new component for our custom color field we created above. When
|
||||
we use a color field, let us say we want the UI to display a color picker for
|
||||
the user to pick from rather than entering values. That is what we will build
|
||||
now.
|
||||
-->
|
||||
Custom fields only support connection inputs in the Workflow Editor.
|
||||
|
@ -2,9 +2,14 @@
|
||||
|
||||
from logging import Logger
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
||||
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
@ -23,8 +28,6 @@ from ..services.invocation_queue.invocation_queue_memory import MemoryInvocation
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
||||
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
||||
from ..services.model_install import ModelInstallService
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
@ -68,6 +71,9 @@ class ApiDependencies:
|
||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||
|
||||
output_folder = config.output_path
|
||||
if output_folder is None:
|
||||
raise ValueError("Output folder is not set")
|
||||
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
@ -84,7 +90,12 @@ class ApiDependencies:
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
tensors = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True)
|
||||
)
|
||||
conditioning = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
)
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
model_record_service = ModelRecordServiceSQL(db=db)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
@ -117,7 +128,6 @@ class ApiDependencies:
|
||||
image_records=image_records,
|
||||
images=images,
|
||||
invocation_cache=invocation_cache,
|
||||
latents=latents,
|
||||
logger=logger,
|
||||
model_manager=model_manager,
|
||||
model_records=model_record_service,
|
||||
@ -131,6 +141,8 @@ class ApiDependencies:
|
||||
session_queue=session_queue,
|
||||
urls=urls,
|
||||
workflow_records=workflow_records,
|
||||
tensors=tensors,
|
||||
conditioning=conditioning,
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
@ -8,7 +8,7 @@ from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
|
@ -6,6 +6,7 @@ import sys
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
@ -57,8 +58,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
InputFieldJSONSchemaExtra,
|
||||
OutputFieldJSONSchemaExtra,
|
||||
UIConfigBase,
|
||||
)
|
||||
|
||||
|
@ -12,13 +12,16 @@ from types import UnionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
|
||||
from pydantic.fields import FieldInfo, _Unset
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldKind,
|
||||
Input,
|
||||
)
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@ -52,393 +55,6 @@ class Classification(str, Enum, metaclass=MetaEnum):
|
||||
Prototype = "prototype"
|
||||
|
||||
|
||||
class Input(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The type of input a field accepts.
|
||||
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
|
||||
are instantiated.
|
||||
- `Input.Connection`: The field must have its value provided by a connection.
|
||||
- `Input.Any`: The field may have its value provided either directly or by a connection.
|
||||
"""
|
||||
|
||||
Connection = "connection"
|
||||
Direct = "direct"
|
||||
Any = "any"
|
||||
|
||||
|
||||
class FieldKind(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The kind of field.
|
||||
- `Input`: An input field on a node.
|
||||
- `Output`: An output field on a node.
|
||||
- `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is
|
||||
one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name
|
||||
"metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic,
|
||||
allowing "metadata" for that field.
|
||||
- `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs,
|
||||
but which are used to store information about the node. For example, the `id` and `type` fields are node
|
||||
attributes.
|
||||
|
||||
The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app
|
||||
startup, and when generating the OpenAPI schema for the workflow editor.
|
||||
"""
|
||||
|
||||
Input = "input"
|
||||
Output = "output"
|
||||
Internal = "internal"
|
||||
NodeAttribute = "node_attribute"
|
||||
|
||||
|
||||
class UIType(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
Type hints for the UI for situations in which the field type is not enough to infer the correct UI type.
|
||||
|
||||
- Model Fields
|
||||
The most common node-author-facing use will be for model fields. Internally, there is no difference
|
||||
between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the
|
||||
base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that
|
||||
the field is an SDXL main model field.
|
||||
|
||||
- Any Field
|
||||
We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to
|
||||
indicate that the field accepts any type. Use with caution. This cannot be used on outputs.
|
||||
|
||||
- Scheduler Field
|
||||
Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field.
|
||||
|
||||
- Internal Fields
|
||||
Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate
|
||||
handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These
|
||||
should not be used by node authors.
|
||||
|
||||
- DEPRECATED Fields
|
||||
These types are deprecated and should not be used by node authors. A warning will be logged if one is
|
||||
used, and the type will be ignored. They are included here for backwards compatibility.
|
||||
"""
|
||||
|
||||
# region Model Field Types
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VaeModel = "VAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
Scheduler = "SchedulerField"
|
||||
Any = "AnyField"
|
||||
# endregion
|
||||
|
||||
# region Internal Field Types
|
||||
_Collection = "CollectionField"
|
||||
_CollectionItem = "CollectionItemField"
|
||||
# endregion
|
||||
|
||||
# region DEPRECATED
|
||||
Boolean = "DEPRECATED_Boolean"
|
||||
Color = "DEPRECATED_Color"
|
||||
Conditioning = "DEPRECATED_Conditioning"
|
||||
Control = "DEPRECATED_Control"
|
||||
Float = "DEPRECATED_Float"
|
||||
Image = "DEPRECATED_Image"
|
||||
Integer = "DEPRECATED_Integer"
|
||||
Latents = "DEPRECATED_Latents"
|
||||
String = "DEPRECATED_String"
|
||||
BooleanCollection = "DEPRECATED_BooleanCollection"
|
||||
ColorCollection = "DEPRECATED_ColorCollection"
|
||||
ConditioningCollection = "DEPRECATED_ConditioningCollection"
|
||||
ControlCollection = "DEPRECATED_ControlCollection"
|
||||
FloatCollection = "DEPRECATED_FloatCollection"
|
||||
ImageCollection = "DEPRECATED_ImageCollection"
|
||||
IntegerCollection = "DEPRECATED_IntegerCollection"
|
||||
LatentsCollection = "DEPRECATED_LatentsCollection"
|
||||
StringCollection = "DEPRECATED_StringCollection"
|
||||
BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic"
|
||||
ColorPolymorphic = "DEPRECATED_ColorPolymorphic"
|
||||
ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic"
|
||||
ControlPolymorphic = "DEPRECATED_ControlPolymorphic"
|
||||
FloatPolymorphic = "DEPRECATED_FloatPolymorphic"
|
||||
ImagePolymorphic = "DEPRECATED_ImagePolymorphic"
|
||||
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
|
||||
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
|
||||
StringPolymorphic = "DEPRECATED_StringPolymorphic"
|
||||
MainModel = "DEPRECATED_MainModel"
|
||||
UNet = "DEPRECATED_UNet"
|
||||
Vae = "DEPRECATED_Vae"
|
||||
CLIP = "DEPRECATED_CLIP"
|
||||
Collection = "DEPRECATED_Collection"
|
||||
CollectionItem = "DEPRECATED_CollectionItem"
|
||||
Enum = "DEPRECATED_Enum"
|
||||
WorkflowField = "DEPRECATED_WorkflowField"
|
||||
IsIntermediate = "DEPRECATED_IsIntermediate"
|
||||
BoardField = "DEPRECATED_BoardField"
|
||||
MetadataItem = "DEPRECATED_MetadataItem"
|
||||
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
|
||||
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
|
||||
MetadataDict = "DEPRECATED_MetadataDict"
|
||||
# endregion
|
||||
|
||||
|
||||
class UIComponent(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The type of UI component to use for a field, used to override the default components, which are
|
||||
inferred from the field type.
|
||||
"""
|
||||
|
||||
None_ = "none"
|
||||
Textarea = "textarea"
|
||||
Slider = "slider"
|
||||
|
||||
|
||||
class InputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution,
|
||||
and by the workflow editor during schema parsing and UI rendering.
|
||||
"""
|
||||
|
||||
input: Input
|
||||
orig_required: bool
|
||||
field_kind: FieldKind
|
||||
default: Optional[Any] = None
|
||||
orig_default: Optional[Any] = None
|
||||
ui_hidden: bool = False
|
||||
ui_type: Optional[UIType] = None
|
||||
ui_component: Optional[UIComponent] = None
|
||||
ui_order: Optional[int] = None
|
||||
ui_choice_labels: Optional[dict[str, str]] = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
class OutputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor
|
||||
during schema parsing and UI rendering.
|
||||
"""
|
||||
|
||||
field_kind: FieldKind
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_order: Optional[int]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
def InputField(
|
||||
# copied from pydantic's Field
|
||||
# TODO: Can we support default_factory?
|
||||
default: Any = _Unset,
|
||||
default_factory: Callable[[], Any] | None = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
strict: bool | None = _Unset,
|
||||
gt: float | None = _Unset,
|
||||
ge: float | None = _Unset,
|
||||
lt: float | None = _Unset,
|
||||
le: float | None = _Unset,
|
||||
multiple_of: float | None = _Unset,
|
||||
allow_inf_nan: bool | None = _Unset,
|
||||
max_digits: int | None = _Unset,
|
||||
decimal_places: int | None = _Unset,
|
||||
min_length: int | None = _Unset,
|
||||
max_length: int | None = _Unset,
|
||||
# custom
|
||||
input: Input = Input.Any,
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_component: Optional[UIComponent] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
ui_choice_labels: Optional[dict[str, str]] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an input field for an invocation.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param Input input: [Input.Any] The kind of input this field requires. \
|
||||
`Input.Direct` means a value must be provided on instantiation. \
|
||||
`Input.Connection` means the value must be provided by a connection. \
|
||||
`Input.Any` means either will do.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
|
||||
The UI will always render a suitable component, but sometimes you want something different than the default. \
|
||||
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
|
||||
For this case, you could provide `UIComponent.Textarea`.
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
|
||||
|
||||
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
|
||||
"""
|
||||
|
||||
json_schema_extra_ = InputFieldJSONSchemaExtra(
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
field_kind=FieldKind.Input,
|
||||
orig_required=True,
|
||||
)
|
||||
|
||||
"""
|
||||
There is a conflict between the typing of invocation definitions and the typing of an invocation's
|
||||
`invoke()` function.
|
||||
|
||||
On instantiation of a node, the invocation definition is used to create the python class. At this time,
|
||||
any number of fields may be optional, because they may be provided by connections.
|
||||
|
||||
On calling of `invoke()`, however, those fields may be required.
|
||||
|
||||
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
|
||||
|
||||
`image` is required during the call to `invoke()`, but when the python class is instantiated,
|
||||
the field may not be present. This is fine, because that image field will be provided by a
|
||||
connection from an ancestor node, which outputs an image.
|
||||
|
||||
This means we want to type the `image` field as optional for the node class definition, but required
|
||||
for the `invoke()` function.
|
||||
|
||||
If we use `typing.Optional` in the node class definition, the field will be typed as optional in the
|
||||
`invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or
|
||||
any static type analysis tools will complain.
|
||||
|
||||
To get around this, in node class definitions, we type all fields correctly for the `invoke()` function,
|
||||
but secretly make them optional in `InputField()`. We also store the original required bool and/or default
|
||||
value. When we call `invoke()`, we use this stored information to do an additional check on the class.
|
||||
"""
|
||||
|
||||
if default_factory is not _Unset and default_factory is not None:
|
||||
default = default_factory()
|
||||
logger.warn('"default_factory" is not supported, calling it now to set "default"')
|
||||
|
||||
# These are the args we may wish pass to the pydantic `Field()` function
|
||||
field_args = {
|
||||
"default": default,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"pattern": pattern,
|
||||
"strict": strict,
|
||||
"gt": gt,
|
||||
"ge": ge,
|
||||
"lt": lt,
|
||||
"le": le,
|
||||
"multiple_of": multiple_of,
|
||||
"allow_inf_nan": allow_inf_nan,
|
||||
"max_digits": max_digits,
|
||||
"decimal_places": decimal_places,
|
||||
"min_length": min_length,
|
||||
"max_length": max_length,
|
||||
}
|
||||
|
||||
# We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected
|
||||
provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}
|
||||
|
||||
# Because we are manually making fields optional, we need to store the original required bool for reference later
|
||||
json_schema_extra_.orig_required = default is PydanticUndefined
|
||||
|
||||
# Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
||||
if input is Input.Any or input is Input.Connection:
|
||||
default_ = None if default is PydanticUndefined else default
|
||||
provided_args.update({"default": default_})
|
||||
if default is not PydanticUndefined:
|
||||
# Before invoking, we'll check for the original default value and set it on the field if the field has no value
|
||||
json_schema_extra_.default = default
|
||||
json_schema_extra_.orig_default = default
|
||||
elif default is not PydanticUndefined:
|
||||
default_ = default
|
||||
provided_args.update({"default": default_})
|
||||
json_schema_extra_.orig_default = default_
|
||||
|
||||
return Field(
|
||||
**provided_args,
|
||||
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
def OutputField(
|
||||
# copied from pydantic's Field
|
||||
default: Any = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
strict: bool | None = _Unset,
|
||||
gt: float | None = _Unset,
|
||||
ge: float | None = _Unset,
|
||||
lt: float | None = _Unset,
|
||||
le: float | None = _Unset,
|
||||
multiple_of: float | None = _Unset,
|
||||
allow_inf_nan: bool | None = _Unset,
|
||||
max_digits: int | None = _Unset,
|
||||
decimal_places: int | None = _Unset,
|
||||
min_length: int | None = _Unset,
|
||||
max_length: int | None = _Unset,
|
||||
# custom
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an output field for an invocation output.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
"""
|
||||
return Field(
|
||||
default=default,
|
||||
title=title,
|
||||
description=description,
|
||||
pattern=pattern,
|
||||
strict=strict,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
json_schema_extra=OutputFieldJSONSchemaExtra(
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
field_kind=FieldKind.Output,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
class UIConfigBase(BaseModel):
|
||||
"""
|
||||
Provides additional node configuration to the UI.
|
||||
@ -460,33 +76,6 @@ class UIConfigBase(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""Initialized and provided to on execution of invocations."""
|
||||
|
||||
services: InvocationServices
|
||||
graph_execution_state_id: str
|
||||
queue_id: str
|
||||
queue_item_id: int
|
||||
queue_batch_id: str
|
||||
workflow: Optional[WorkflowWithoutID]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
services: InvocationServices,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
workflow: Optional[WorkflowWithoutID],
|
||||
):
|
||||
self.services = services
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.queue_id = queue_id
|
||||
self.queue_item_id = queue_item_id
|
||||
self.queue_batch_id = queue_batch_id
|
||||
self.workflow = workflow
|
||||
|
||||
|
||||
class BaseInvocationOutput(BaseModel):
|
||||
"""
|
||||
Base class for all invocation outputs.
|
||||
@ -632,7 +221,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
"""Invoke with provided context and return outputs."""
|
||||
pass
|
||||
|
||||
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
def invoke_internal(self, context: InvocationContext, services: "InvocationServices") -> BaseInvocationOutput:
|
||||
"""
|
||||
Internal invoke method, calls `invoke()` after some prep.
|
||||
Handles optional fields that are required to call `invoke()` and invocation cache.
|
||||
@ -657,23 +246,23 @@ class BaseInvocation(ABC, BaseModel):
|
||||
raise MissingInputException(self.model_fields["type"].default, field_name)
|
||||
|
||||
# skip node cache codepath if it's disabled
|
||||
if context.services.configuration.node_cache_size == 0:
|
||||
if services.configuration.node_cache_size == 0:
|
||||
return self.invoke(context)
|
||||
|
||||
output: BaseInvocationOutput
|
||||
if self.use_cache:
|
||||
key = context.services.invocation_cache.create_key(self)
|
||||
cached_value = context.services.invocation_cache.get(key)
|
||||
key = services.invocation_cache.create_key(self)
|
||||
cached_value = services.invocation_cache.get(key)
|
||||
if cached_value is None:
|
||||
context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
|
||||
services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
|
||||
output = self.invoke(context)
|
||||
context.services.invocation_cache.save(key, output)
|
||||
services.invocation_cache.save(key, output)
|
||||
return output
|
||||
else:
|
||||
context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
|
||||
services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
|
||||
return cached_value
|
||||
else:
|
||||
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
|
||||
services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
|
||||
return self.invoke(context)
|
||||
|
||||
id: str = Field(
|
||||
@ -714,9 +303,7 @@ RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
|
||||
"workflow",
|
||||
}
|
||||
|
||||
RESERVED_INPUT_FIELD_NAMES = {
|
||||
"metadata",
|
||||
}
|
||||
RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"}
|
||||
|
||||
RESERVED_OUTPUT_FIELD_NAMES = {"type"}
|
||||
|
||||
@ -926,37 +513,3 @@ def invocation_output(
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class MetadataField(RootModel):
|
||||
"""
|
||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||
Metadata is stored without a strict schema.
|
||||
"""
|
||||
|
||||
root: dict[str, Any] = Field(description="The metadata")
|
||||
|
||||
|
||||
MetadataFieldValidator = TypeAdapter(MetadataField)
|
||||
|
||||
|
||||
class WithMetadata(BaseModel):
|
||||
metadata: Optional[MetadataField] = Field(
|
||||
default=None,
|
||||
description=FieldDescriptions.metadata,
|
||||
json_schema_extra=InputFieldJSONSchemaExtra(
|
||||
field_kind=FieldKind.Internal,
|
||||
input=Input.Connection,
|
||||
orig_required=False,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
class WithWorkflow:
|
||||
workflow = None
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
logger.warn(
|
||||
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
|
||||
)
|
||||
super().__init_subclass__()
|
||||
|
@ -5,9 +5,11 @@ import numpy as np
|
||||
from pydantic import ValidationInfo, field_validator
|
||||
|
||||
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField
|
||||
|
||||
|
||||
@invocation(
|
||||
|
@ -1,14 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
ExtraConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
@ -20,21 +27,12 @@ from ..util.ti_utils import extract_ti_triggers_from_prompt
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .model import ClipField
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[BasicConditioningInfo]
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
|
||||
|
||||
# class ConditioningAlgo(str, Enum):
|
||||
@ -48,7 +46,7 @@ class ConditioningFieldData:
|
||||
title="Prompt",
|
||||
tags=["prompt", "compel"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@ -66,25 +64,17 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
||||
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.model_dump(exclude={"weight"}), context=context
|
||||
)
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||
@ -93,11 +83,10 @@ class CompelInvocation(BaseInvocation):
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
context.models.load(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
@ -128,7 +117,7 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
|
||||
if context.services.configuration.log_tokenization:
|
||||
if context.config.get().log_tokenization:
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
@ -149,14 +138,9 @@ class CompelInvocation(BaseInvocation):
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
|
||||
|
||||
class SDXLPromptInvocationBase:
|
||||
@ -169,14 +153,8 @@ class SDXLPromptInvocationBase:
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**clip_field.text_encoder.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
|
||||
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
||||
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
@ -200,14 +178,12 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.model_dump(exclude={"weight"}), context=context
|
||||
)
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||
@ -216,11 +192,10 @@ class SDXLPromptInvocationBase:
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
context.models.load(
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
@ -253,7 +228,7 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
conjunction = Compel.parse_prompt_string(prompt)
|
||||
|
||||
if context.services.configuration.log_tokenization:
|
||||
if context.config.get().log_tokenization:
|
||||
# TODO: better logging for and syntax
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
|
||||
@ -286,7 +261,7 @@ class SDXLPromptInvocationBase:
|
||||
title="SDXL Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@ -368,14 +343,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -383,7 +353,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
title="SDXL Refiner Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@ -421,14 +391,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
|
||||
|
||||
@invocation_output("clip_skip_output")
|
||||
|
14
invokeai/app/invocations/constants.py
Normal file
14
invokeai/app/invocations/constants.py
Normal file
@ -0,0 +1,14 @@
|
||||
from typing import Literal
|
||||
|
||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
|
||||
LATENT_SCALE_FACTOR = 8
|
||||
"""
|
||||
HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
|
||||
be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale
|
||||
factor is hard-coded to a literal '8' rather than using this constant.
|
||||
The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
|
||||
"""
|
||||
|
||||
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
|
||||
"""A literal type representing the valid scheduler names."""
|
@ -25,22 +25,25 @@ from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
OutputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
||||
from invokeai.backend.model_management.models.base import BaseModelType
|
||||
|
||||
from ...backend.model_management import BaseModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
WithMetadata,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -140,7 +143,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
||||
class ImageProcessorInvocation(BaseInvocation, WithMetadata):
|
||||
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
@ -150,22 +153,13 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata):
|
||||
return image
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raw_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
raw_image = context.images.get_pil(self.image.image_name)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
|
||||
# currently can't see processed image in node UI without a showImage node,
|
||||
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||
image_dto = context.services.images.create(
|
||||
image=processed_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.CONTROL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
image_dto = context.images.save(image=processed_image)
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
processed_image_field = ImageField(image_name=image_dto.image_name)
|
||||
@ -184,7 +178,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata):
|
||||
title="Canny Processor",
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
@ -207,7 +201,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="HED (softedge) Processor",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
@ -236,7 +230,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Lineart Processor",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
@ -258,7 +252,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Lineart Anime Processor",
|
||||
tags=["controlnet", "lineart", "anime"],
|
||||
category="controlnet",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
@ -281,7 +275,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Midas Depth Processor",
|
||||
tags=["controlnet", "midas"],
|
||||
category="controlnet",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
@ -308,7 +302,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Normal BAE Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
@ -325,7 +319,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0"
|
||||
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.1"
|
||||
)
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
@ -348,7 +342,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0"
|
||||
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.1"
|
||||
)
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
@ -375,7 +369,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Content Shuffle Processor",
|
||||
tags=["controlnet", "contentshuffle"],
|
||||
category="controlnet",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
@ -405,7 +399,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Zoe (Depth) Processor",
|
||||
tags=["controlnet", "zoe", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
@ -421,7 +415,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Mediapipe Face Processor",
|
||||
tags=["controlnet", "mediapipe", "face"],
|
||||
category="controlnet",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
@ -444,7 +438,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Leres (Depth) Processor",
|
||||
tags=["controlnet", "leres", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies leres processing to image"""
|
||||
@ -473,7 +467,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Tile Resample Processor",
|
||||
tags=["controlnet", "tile"],
|
||||
category="controlnet",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Tile resampler processor"""
|
||||
@ -513,7 +507,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Segment Anything Processor",
|
||||
tags=["controlnet", "segmentanything"],
|
||||
category="controlnet",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies segment anything processing to image"""
|
||||
@ -555,7 +549,7 @@ class SamDetectorReproducibleColors(SamDetector):
|
||||
title="Color Map Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a color map from the provided image"""
|
||||
|
@ -5,22 +5,24 @@ import cv2 as cv
|
||||
import numpy
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, WithBoard, WithMetadata
|
||||
|
||||
|
||||
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0")
|
||||
class CvInpaintInvocation(BaseInvocation, WithMetadata):
|
||||
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.1")
|
||||
class CvInpaintInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Simple inpaint using opencv."""
|
||||
|
||||
image: ImageField = InputField(description="The image to inpaint")
|
||||
mask: ImageField = InputField(description="The mask to use when inpainting")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
mask = context.services.images.get_pil_image(self.mask.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
mask = context.images.get_pil(self.mask.image_name)
|
||||
|
||||
# Convert to cv image/mask
|
||||
# TODO: consider making these utility functions
|
||||
@ -34,18 +36,6 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata):
|
||||
# TODO: consider making a utility function
|
||||
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image_inpainted,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
image_dto = context.images.save(image=image_inpainted)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
@ -13,15 +13,13 @@ from pydantic import field_validator
|
||||
import invokeai.assets.fonts as font_assets
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
WithMetadata,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("face_mask_output")
|
||||
@ -306,37 +304,37 @@ def extract_face(
|
||||
|
||||
# Adjust the crop boundaries to stay within the original image's dimensions
|
||||
if x_min < 0:
|
||||
context.services.logger.warning("FaceTools --> -X-axis padding reached image edge.")
|
||||
context.logger.warning("FaceTools --> -X-axis padding reached image edge.")
|
||||
x_max -= x_min
|
||||
x_min = 0
|
||||
elif x_max > mask.width:
|
||||
context.services.logger.warning("FaceTools --> +X-axis padding reached image edge.")
|
||||
context.logger.warning("FaceTools --> +X-axis padding reached image edge.")
|
||||
x_min -= x_max - mask.width
|
||||
x_max = mask.width
|
||||
|
||||
if y_min < 0:
|
||||
context.services.logger.warning("FaceTools --> +Y-axis padding reached image edge.")
|
||||
context.logger.warning("FaceTools --> +Y-axis padding reached image edge.")
|
||||
y_max -= y_min
|
||||
y_min = 0
|
||||
elif y_max > mask.height:
|
||||
context.services.logger.warning("FaceTools --> -Y-axis padding reached image edge.")
|
||||
context.logger.warning("FaceTools --> -Y-axis padding reached image edge.")
|
||||
y_min -= y_max - mask.height
|
||||
y_max = mask.height
|
||||
|
||||
# Ensure the crop is square and adjust the boundaries if needed
|
||||
if x_max - x_min != crop_size:
|
||||
context.services.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.")
|
||||
context.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.")
|
||||
diff = crop_size - (x_max - x_min)
|
||||
x_min -= diff // 2
|
||||
x_max += diff - diff // 2
|
||||
|
||||
if y_max - y_min != crop_size:
|
||||
context.services.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.")
|
||||
context.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.")
|
||||
diff = crop_size - (y_max - y_min)
|
||||
y_min -= diff // 2
|
||||
y_max += diff - diff // 2
|
||||
|
||||
context.services.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}")
|
||||
context.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}")
|
||||
|
||||
# Crop the output image to the specified size with the center of the face mesh as the center.
|
||||
mask = mask.crop((x_min, y_min, x_max, y_max))
|
||||
@ -368,7 +366,7 @@ def get_faces_list(
|
||||
|
||||
# Generate the face box mask and get the center of the face.
|
||||
if not should_chunk:
|
||||
context.services.logger.info("FaceTools --> Attempting full image face detection.")
|
||||
context.logger.info("FaceTools --> Attempting full image face detection.")
|
||||
result = generate_face_box_mask(
|
||||
context=context,
|
||||
minimum_confidence=minimum_confidence,
|
||||
@ -380,7 +378,7 @@ def get_faces_list(
|
||||
draw_mesh=draw_mesh,
|
||||
)
|
||||
if should_chunk or len(result) == 0:
|
||||
context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
|
||||
context.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
|
||||
width, height = image.size
|
||||
image_chunks = []
|
||||
x_offsets = []
|
||||
@ -399,7 +397,7 @@ def get_faces_list(
|
||||
x_offsets.append(x)
|
||||
y_offsets.append(0)
|
||||
fx += increment
|
||||
context.services.logger.info(f"FaceTools --> Chunk starting at x = {x}")
|
||||
context.logger.info(f"FaceTools --> Chunk starting at x = {x}")
|
||||
elif height > width:
|
||||
# Portrait - slice the image vertically
|
||||
fy = 0.0
|
||||
@ -411,10 +409,10 @@ def get_faces_list(
|
||||
x_offsets.append(0)
|
||||
y_offsets.append(y)
|
||||
fy += increment
|
||||
context.services.logger.info(f"FaceTools --> Chunk starting at y = {y}")
|
||||
context.logger.info(f"FaceTools --> Chunk starting at y = {y}")
|
||||
|
||||
for idx in range(len(image_chunks)):
|
||||
context.services.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}")
|
||||
context.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}")
|
||||
result = result + generate_face_box_mask(
|
||||
context=context,
|
||||
minimum_confidence=minimum_confidence,
|
||||
@ -428,7 +426,7 @@ def get_faces_list(
|
||||
|
||||
if len(result) == 0:
|
||||
# Give up
|
||||
context.services.logger.warning(
|
||||
context.logger.warning(
|
||||
"FaceTools --> No face detected in chunked input image. Passing through original image."
|
||||
)
|
||||
|
||||
@ -437,7 +435,7 @@ def get_faces_list(
|
||||
return all_faces
|
||||
|
||||
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.0")
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.1")
|
||||
class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
||||
|
||||
@ -470,11 +468,11 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
)
|
||||
|
||||
if len(all_faces) == 0:
|
||||
context.services.logger.warning("FaceOff --> No faces detected. Passing through original image.")
|
||||
context.logger.warning("FaceOff --> No faces detected. Passing through original image.")
|
||||
return None
|
||||
|
||||
if self.face_id > len(all_faces) - 1:
|
||||
context.services.logger.warning(
|
||||
context.logger.warning(
|
||||
f"FaceOff --> Face ID {self.face_id} is outside of the number of faces detected ({len(all_faces)}). Passing through original image."
|
||||
)
|
||||
return None
|
||||
@ -486,7 +484,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
return face_data
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FaceOffOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
result = self.faceoff(context=context, image=image)
|
||||
|
||||
if result is None:
|
||||
@ -500,24 +498,9 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
x = result["x_min"]
|
||||
y = result["y_min"]
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=result_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
image_dto = context.images.save(image=result_image)
|
||||
|
||||
mask_dto = context.services.images.create(
|
||||
image=result_mask,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.MASK,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
mask_dto = context.images.save(image=result_mask, image_category=ImageCategory.MASK)
|
||||
|
||||
output = FaceOffOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
@ -531,7 +514,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
return output
|
||||
|
||||
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.0")
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.1")
|
||||
class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
"""Face mask creation using mediapipe face detection"""
|
||||
|
||||
@ -580,7 +563,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
if len(intersected_face_ids) == 0:
|
||||
id_range_str = ",".join([str(id) for id in id_range])
|
||||
context.services.logger.warning(
|
||||
context.logger.warning(
|
||||
f"Face IDs must be in range of detected faces - requested {self.face_ids}, detected {id_range_str}. Passing through original image."
|
||||
)
|
||||
return FaceMaskResult(
|
||||
@ -616,27 +599,12 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FaceMaskOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
result = self.facemask(context=context, image=image)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=result["image"],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
image_dto = context.images.save(image=result["image"])
|
||||
|
||||
mask_dto = context.services.images.create(
|
||||
image=result["mask"],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.MASK,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
mask_dto = context.images.save(image=result["mask"], image_category=ImageCategory.MASK)
|
||||
|
||||
output = FaceMaskOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
@ -649,9 +617,9 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
|
||||
@invocation(
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.0"
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.1"
|
||||
)
|
||||
class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
|
||||
class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
||||
|
||||
image: ImageField = InputField(description="Image to face detect")
|
||||
@ -705,21 +673,9 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
|
||||
return image
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
result_image = self.faceidentifier(context=context, image=image)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=result_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
image_dto = context.images.save(image=result_image)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
565
invokeai/app/invocations/fields.py
Normal file
565
invokeai/app/invocations/fields.py
Normal file
@ -0,0 +1,565 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
|
||||
from pydantic.fields import _Unset
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class UIType(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
Type hints for the UI for situations in which the field type is not enough to infer the correct UI type.
|
||||
|
||||
- Model Fields
|
||||
The most common node-author-facing use will be for model fields. Internally, there is no difference
|
||||
between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the
|
||||
base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that
|
||||
the field is an SDXL main model field.
|
||||
|
||||
- Any Field
|
||||
We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to
|
||||
indicate that the field accepts any type. Use with caution. This cannot be used on outputs.
|
||||
|
||||
- Scheduler Field
|
||||
Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field.
|
||||
|
||||
- Internal Fields
|
||||
Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate
|
||||
handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These
|
||||
should not be used by node authors.
|
||||
|
||||
- DEPRECATED Fields
|
||||
These types are deprecated and should not be used by node authors. A warning will be logged if one is
|
||||
used, and the type will be ignored. They are included here for backwards compatibility.
|
||||
"""
|
||||
|
||||
# region Model Field Types
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VaeModel = "VAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
Scheduler = "SchedulerField"
|
||||
Any = "AnyField"
|
||||
# endregion
|
||||
|
||||
# region Internal Field Types
|
||||
_Collection = "CollectionField"
|
||||
_CollectionItem = "CollectionItemField"
|
||||
# endregion
|
||||
|
||||
# region DEPRECATED
|
||||
Boolean = "DEPRECATED_Boolean"
|
||||
Color = "DEPRECATED_Color"
|
||||
Conditioning = "DEPRECATED_Conditioning"
|
||||
Control = "DEPRECATED_Control"
|
||||
Float = "DEPRECATED_Float"
|
||||
Image = "DEPRECATED_Image"
|
||||
Integer = "DEPRECATED_Integer"
|
||||
Latents = "DEPRECATED_Latents"
|
||||
String = "DEPRECATED_String"
|
||||
BooleanCollection = "DEPRECATED_BooleanCollection"
|
||||
ColorCollection = "DEPRECATED_ColorCollection"
|
||||
ConditioningCollection = "DEPRECATED_ConditioningCollection"
|
||||
ControlCollection = "DEPRECATED_ControlCollection"
|
||||
FloatCollection = "DEPRECATED_FloatCollection"
|
||||
ImageCollection = "DEPRECATED_ImageCollection"
|
||||
IntegerCollection = "DEPRECATED_IntegerCollection"
|
||||
LatentsCollection = "DEPRECATED_LatentsCollection"
|
||||
StringCollection = "DEPRECATED_StringCollection"
|
||||
BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic"
|
||||
ColorPolymorphic = "DEPRECATED_ColorPolymorphic"
|
||||
ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic"
|
||||
ControlPolymorphic = "DEPRECATED_ControlPolymorphic"
|
||||
FloatPolymorphic = "DEPRECATED_FloatPolymorphic"
|
||||
ImagePolymorphic = "DEPRECATED_ImagePolymorphic"
|
||||
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
|
||||
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
|
||||
StringPolymorphic = "DEPRECATED_StringPolymorphic"
|
||||
MainModel = "DEPRECATED_MainModel"
|
||||
UNet = "DEPRECATED_UNet"
|
||||
Vae = "DEPRECATED_Vae"
|
||||
CLIP = "DEPRECATED_CLIP"
|
||||
Collection = "DEPRECATED_Collection"
|
||||
CollectionItem = "DEPRECATED_CollectionItem"
|
||||
Enum = "DEPRECATED_Enum"
|
||||
WorkflowField = "DEPRECATED_WorkflowField"
|
||||
IsIntermediate = "DEPRECATED_IsIntermediate"
|
||||
BoardField = "DEPRECATED_BoardField"
|
||||
MetadataItem = "DEPRECATED_MetadataItem"
|
||||
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
|
||||
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
|
||||
MetadataDict = "DEPRECATED_MetadataDict"
|
||||
|
||||
|
||||
class UIComponent(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The type of UI component to use for a field, used to override the default components, which are
|
||||
inferred from the field type.
|
||||
"""
|
||||
|
||||
None_ = "none"
|
||||
Textarea = "textarea"
|
||||
Slider = "slider"
|
||||
|
||||
|
||||
class FieldDescriptions:
|
||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||
cfg_scale = "Classifier-Free Guidance scale"
|
||||
cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR"
|
||||
scheduler = "Scheduler to use during inference"
|
||||
positive_cond = "Positive conditioning tensor"
|
||||
negative_cond = "Negative conditioning tensor"
|
||||
noise = "Noise tensor"
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
vae = "VAE"
|
||||
cond = "Conditioning tensor"
|
||||
controlnet_model = "ControlNet model to load"
|
||||
vae_model = "VAE model to load"
|
||||
lora_model = "LoRA model to load"
|
||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||
raw_prompt = "Raw prompt text (no parsing)"
|
||||
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
|
||||
skipped_layers = "Number of layers to skip in text encoder"
|
||||
seed = "Seed for random number generation"
|
||||
steps = "Number of steps to run"
|
||||
width = "Width of output (px)"
|
||||
height = "Height of output (px)"
|
||||
control = "ControlNet(s) to apply"
|
||||
ip_adapter = "IP-Adapter to apply"
|
||||
t2i_adapter = "T2I-Adapter(s) to apply"
|
||||
denoised_latents = "Denoised latents tensor"
|
||||
latents = "Latents tensor"
|
||||
strength = "Strength of denoising (proportional to steps)"
|
||||
metadata = "Optional metadata to be saved with the image"
|
||||
metadata_collection = "Collection of Metadata"
|
||||
metadata_item_polymorphic = "A single metadata item or collection of metadata items"
|
||||
metadata_item_label = "Label for this metadata item"
|
||||
metadata_item_value = "The value for this metadata item (may be any type)"
|
||||
workflow = "Optional workflow to be saved with the image"
|
||||
interp_mode = "Interpolation mode"
|
||||
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||
fp32 = "Whether or not to use full float32 precision"
|
||||
precision = "Precision to use"
|
||||
tiled = "Processing using overlapping tiles (reduce memory consumption)"
|
||||
detect_res = "Pixel resolution for detection"
|
||||
image_res = "Pixel resolution for output image"
|
||||
safe_mode = "Whether or not to use safe mode"
|
||||
scribble_mode = "Whether or not to use scribble mode"
|
||||
scale_factor = "The factor by which to scale"
|
||||
blend_alpha = (
|
||||
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
|
||||
)
|
||||
num_1 = "The first number"
|
||||
num_2 = "The second number"
|
||||
mask = "The mask to use for the operation"
|
||||
board = "The board to save the image to"
|
||||
image = "The image to process"
|
||||
tile_size = "Tile size"
|
||||
inclusive_low = "The inclusive low value"
|
||||
exclusive_high = "The exclusive high value"
|
||||
decimal_places = "The number of decimal places to round to"
|
||||
freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
|
||||
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
|
||||
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
|
||||
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image primitive field"""
|
||||
|
||||
image_name: str = Field(description="The name of the image")
|
||||
|
||||
|
||||
class BoardField(BaseModel):
|
||||
"""A board primitive field"""
|
||||
|
||||
board_id: str = Field(description="The id of the board")
|
||||
|
||||
|
||||
class DenoiseMaskField(BaseModel):
|
||||
"""An inpaint mask field"""
|
||||
|
||||
mask_name: str = Field(description="The name of the mask image")
|
||||
masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents tensor primitive field"""
|
||||
|
||||
latents_name: str = Field(description="The name of the latents")
|
||||
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
"""A color primitive field"""
|
||||
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
# endregion
|
||||
|
||||
|
||||
class MetadataField(RootModel):
|
||||
"""
|
||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||
Metadata is stored without a strict schema.
|
||||
"""
|
||||
|
||||
root: dict[str, Any] = Field(description="The metadata")
|
||||
|
||||
|
||||
MetadataFieldValidator = TypeAdapter(MetadataField)
|
||||
|
||||
|
||||
class Input(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The type of input a field accepts.
|
||||
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
|
||||
are instantiated.
|
||||
- `Input.Connection`: The field must have its value provided by a connection.
|
||||
- `Input.Any`: The field may have its value provided either directly or by a connection.
|
||||
"""
|
||||
|
||||
Connection = "connection"
|
||||
Direct = "direct"
|
||||
Any = "any"
|
||||
|
||||
|
||||
class FieldKind(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The kind of field.
|
||||
- `Input`: An input field on a node.
|
||||
- `Output`: An output field on a node.
|
||||
- `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is
|
||||
one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name
|
||||
"metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic,
|
||||
allowing "metadata" for that field.
|
||||
- `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs,
|
||||
but which are used to store information about the node. For example, the `id` and `type` fields are node
|
||||
attributes.
|
||||
|
||||
The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app
|
||||
startup, and when generating the OpenAPI schema for the workflow editor.
|
||||
"""
|
||||
|
||||
Input = "input"
|
||||
Output = "output"
|
||||
Internal = "internal"
|
||||
NodeAttribute = "node_attribute"
|
||||
|
||||
|
||||
class InputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution,
|
||||
and by the workflow editor during schema parsing and UI rendering.
|
||||
"""
|
||||
|
||||
input: Input
|
||||
orig_required: bool
|
||||
field_kind: FieldKind
|
||||
default: Optional[Any] = None
|
||||
orig_default: Optional[Any] = None
|
||||
ui_hidden: bool = False
|
||||
ui_type: Optional[UIType] = None
|
||||
ui_component: Optional[UIComponent] = None
|
||||
ui_order: Optional[int] = None
|
||||
ui_choice_labels: Optional[dict[str, str]] = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
class WithMetadata(BaseModel):
|
||||
"""
|
||||
Inherit from this class if your node needs a metadata input field.
|
||||
"""
|
||||
|
||||
metadata: Optional[MetadataField] = Field(
|
||||
default=None,
|
||||
description=FieldDescriptions.metadata,
|
||||
json_schema_extra=InputFieldJSONSchemaExtra(
|
||||
field_kind=FieldKind.Internal,
|
||||
input=Input.Connection,
|
||||
orig_required=False,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
class WithWorkflow:
|
||||
workflow = None
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
logger.warn(
|
||||
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
|
||||
)
|
||||
super().__init_subclass__()
|
||||
|
||||
|
||||
class WithBoard(BaseModel):
|
||||
"""
|
||||
Inherit from this class if your node needs a board input field.
|
||||
"""
|
||||
|
||||
board: Optional[BoardField] = Field(
|
||||
default=None,
|
||||
description=FieldDescriptions.board,
|
||||
json_schema_extra=InputFieldJSONSchemaExtra(
|
||||
field_kind=FieldKind.Internal,
|
||||
input=Input.Direct,
|
||||
orig_required=False,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
class OutputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor
|
||||
during schema parsing and UI rendering.
|
||||
"""
|
||||
|
||||
field_kind: FieldKind
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_order: Optional[int]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
def InputField(
|
||||
# copied from pydantic's Field
|
||||
# TODO: Can we support default_factory?
|
||||
default: Any = _Unset,
|
||||
default_factory: Callable[[], Any] | None = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
strict: bool | None = _Unset,
|
||||
gt: float | None = _Unset,
|
||||
ge: float | None = _Unset,
|
||||
lt: float | None = _Unset,
|
||||
le: float | None = _Unset,
|
||||
multiple_of: float | None = _Unset,
|
||||
allow_inf_nan: bool | None = _Unset,
|
||||
max_digits: int | None = _Unset,
|
||||
decimal_places: int | None = _Unset,
|
||||
min_length: int | None = _Unset,
|
||||
max_length: int | None = _Unset,
|
||||
# custom
|
||||
input: Input = Input.Any,
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_component: Optional[UIComponent] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
ui_choice_labels: Optional[dict[str, str]] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an input field for an invocation.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param Input input: [Input.Any] The kind of input this field requires. \
|
||||
`Input.Direct` means a value must be provided on instantiation. \
|
||||
`Input.Connection` means the value must be provided by a connection. \
|
||||
`Input.Any` means either will do.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
|
||||
The UI will always render a suitable component, but sometimes you want something different than the default. \
|
||||
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
|
||||
For this case, you could provide `UIComponent.Textarea`.
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
|
||||
|
||||
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
|
||||
"""
|
||||
|
||||
json_schema_extra_ = InputFieldJSONSchemaExtra(
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
field_kind=FieldKind.Input,
|
||||
orig_required=True,
|
||||
)
|
||||
|
||||
"""
|
||||
There is a conflict between the typing of invocation definitions and the typing of an invocation's
|
||||
`invoke()` function.
|
||||
|
||||
On instantiation of a node, the invocation definition is used to create the python class. At this time,
|
||||
any number of fields may be optional, because they may be provided by connections.
|
||||
|
||||
On calling of `invoke()`, however, those fields may be required.
|
||||
|
||||
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
|
||||
|
||||
`image` is required during the call to `invoke()`, but when the python class is instantiated,
|
||||
the field may not be present. This is fine, because that image field will be provided by a
|
||||
connection from an ancestor node, which outputs an image.
|
||||
|
||||
This means we want to type the `image` field as optional for the node class definition, but required
|
||||
for the `invoke()` function.
|
||||
|
||||
If we use `typing.Optional` in the node class definition, the field will be typed as optional in the
|
||||
`invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or
|
||||
any static type analysis tools will complain.
|
||||
|
||||
To get around this, in node class definitions, we type all fields correctly for the `invoke()` function,
|
||||
but secretly make them optional in `InputField()`. We also store the original required bool and/or default
|
||||
value. When we call `invoke()`, we use this stored information to do an additional check on the class.
|
||||
"""
|
||||
|
||||
if default_factory is not _Unset and default_factory is not None:
|
||||
default = default_factory()
|
||||
logger.warn('"default_factory" is not supported, calling it now to set "default"')
|
||||
|
||||
# These are the args we may wish pass to the pydantic `Field()` function
|
||||
field_args = {
|
||||
"default": default,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"pattern": pattern,
|
||||
"strict": strict,
|
||||
"gt": gt,
|
||||
"ge": ge,
|
||||
"lt": lt,
|
||||
"le": le,
|
||||
"multiple_of": multiple_of,
|
||||
"allow_inf_nan": allow_inf_nan,
|
||||
"max_digits": max_digits,
|
||||
"decimal_places": decimal_places,
|
||||
"min_length": min_length,
|
||||
"max_length": max_length,
|
||||
}
|
||||
|
||||
# We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected
|
||||
provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}
|
||||
|
||||
# Because we are manually making fields optional, we need to store the original required bool for reference later
|
||||
json_schema_extra_.orig_required = default is PydanticUndefined
|
||||
|
||||
# Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
||||
if input is Input.Any or input is Input.Connection:
|
||||
default_ = None if default is PydanticUndefined else default
|
||||
provided_args.update({"default": default_})
|
||||
if default is not PydanticUndefined:
|
||||
# Before invoking, we'll check for the original default value and set it on the field if the field has no value
|
||||
json_schema_extra_.default = default
|
||||
json_schema_extra_.orig_default = default
|
||||
elif default is not PydanticUndefined:
|
||||
default_ = default
|
||||
provided_args.update({"default": default_})
|
||||
json_schema_extra_.orig_default = default_
|
||||
|
||||
return Field(
|
||||
**provided_args,
|
||||
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
def OutputField(
|
||||
# copied from pydantic's Field
|
||||
default: Any = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
strict: bool | None = _Unset,
|
||||
gt: float | None = _Unset,
|
||||
ge: float | None = _Unset,
|
||||
lt: float | None = _Unset,
|
||||
le: float | None = _Unset,
|
||||
multiple_of: float | None = _Unset,
|
||||
allow_inf_nan: bool | None = _Unset,
|
||||
max_digits: int | None = _Unset,
|
||||
decimal_places: int | None = _Unset,
|
||||
min_length: int | None = _Unset,
|
||||
max_length: int | None = _Unset,
|
||||
# custom
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an output field for an invocation output.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
"""
|
||||
return Field(
|
||||
default=default,
|
||||
title=title,
|
||||
description=description,
|
||||
pattern=pattern,
|
||||
strict=strict,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
json_schema_extra=OutputFieldJSONSchemaExtra(
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
field_kind=FieldKind.Output,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
File diff suppressed because it is too large
Load Diff
@ -6,14 +6,16 @@ from typing import Literal, Optional, get_args
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.invocations.fields import ColorField, ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, WithBoard, WithMetadata
|
||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||
|
||||
|
||||
@ -118,8 +120,8 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
||||
return si
|
||||
|
||||
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
|
||||
class InfillColorInvocation(BaseInvocation, WithMetadata):
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||
class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@ -129,33 +131,20 @@ class InfillColorInvocation(BaseInvocation, WithMetadata):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||
class InfillTileInvocation(BaseInvocation, WithMetadata):
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@ -168,33 +157,20 @@ class InfillTileInvocation(BaseInvocation, WithMetadata):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0"
|
||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1"
|
||||
)
|
||||
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
|
||||
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@ -202,7 +178,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
|
||||
image = context.images.get_pil(self.image.image_name).convert("RGBA")
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
@ -227,77 +203,38 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
|
||||
infilled.paste(image, (0, 0), mask=image.split()[-1])
|
||||
# image.paste(infilled, (0, 0), mask=image.split()[-1])
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
|
||||
class LaMaInfillInvocation(BaseInvocation, WithMetadata):
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||
class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
infilled = infill_lama(image.copy())
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
|
||||
class CV2InfillInvocation(BaseInvocation, WithMetadata):
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||
class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
infilled = infill_cv2(image.copy())
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
@ -7,16 +7,13 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_valida
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
|
||||
|
||||
@ -65,7 +62,7 @@ class IPAdapterOutput(BaseInvocationOutput):
|
||||
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
|
||||
|
||||
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.1")
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.2")
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
"""Collects IP-Adapter info to pass to other nodes."""
|
||||
|
||||
@ -98,7 +95,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.services.model_manager.model_info(
|
||||
ip_adapter_info = context.models.get_info(
|
||||
self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
|
||||
)
|
||||
# HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
|
||||
@ -107,7 +104,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
# is currently messy due to differences between how the model info is generated when installing a model from
|
||||
# disk vs. downloading the model.
|
||||
image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
|
||||
os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"])
|
||||
os.path.join(context.config.get().models_path, ip_adapter_info["path"])
|
||||
)
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
image_encoder_model = CLIPVisionModelField(
|
||||
|
@ -23,21 +23,29 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from pydantic import field_validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
OutputField,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||
from invokeai.app.invocations.primitives import (
|
||||
DenoiseMaskField,
|
||||
DenoiseMaskOutput,
|
||||
ImageField,
|
||||
ImageOutput,
|
||||
LatentsField,
|
||||
LatentsOutput,
|
||||
build_latents_output,
|
||||
)
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
||||
@ -59,16 +67,9 @@ from ...backend.util.devices import choose_precision, choose_torch_device
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
WithMetadata,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
|
||||
@ -77,18 +78,10 @@ if choose_torch_device() == torch.device("mps"):
|
||||
|
||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
|
||||
|
||||
# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
|
||||
# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale
|
||||
# factor is hard-coded to a literal '8' rather than using this constant.
|
||||
# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
|
||||
LATENT_SCALE_FACTOR = 8
|
||||
|
||||
|
||||
@invocation_output("scheduler_output")
|
||||
class SchedulerOutput(BaseInvocationOutput):
|
||||
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
||||
scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -101,7 +94,7 @@ class SchedulerOutput(BaseInvocationOutput):
|
||||
class SchedulerInvocation(BaseInvocation):
|
||||
"""Selects a scheduler."""
|
||||
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
description=FieldDescriptions.scheduler,
|
||||
ui_type=UIType.Scheduler,
|
||||
@ -116,7 +109,7 @@ class SchedulerInvocation(BaseInvocation):
|
||||
title="Create Denoise Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
@ -144,7 +137,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
if self.image is not None:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image.dim() == 3:
|
||||
image = image.unsqueeze(0)
|
||||
@ -152,33 +145,26 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
image = None
|
||||
|
||||
mask = self.prep_mask_tensor(
|
||||
context.services.images.get_pil_image(self.mask.image_name),
|
||||
context.images.get_pil(self.mask.image_name),
|
||||
)
|
||||
|
||||
if image is not None:
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
# TODO:
|
||||
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
|
||||
|
||||
masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents"
|
||||
context.services.latents.save(masked_latents_name, masked_latents)
|
||||
masked_latents_name = context.tensors.save(tensor=masked_latents)
|
||||
else:
|
||||
masked_latents_name = None
|
||||
|
||||
mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
|
||||
context.services.latents.save(mask_name, mask)
|
||||
mask_name = context.tensors.save(tensor=mask)
|
||||
|
||||
return DenoiseMaskOutput(
|
||||
denoise_mask=DenoiseMaskField(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=masked_latents_name,
|
||||
),
|
||||
return DenoiseMaskOutput.build(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=masked_latents_name,
|
||||
)
|
||||
|
||||
|
||||
@ -189,10 +175,7 @@ def get_scheduler(
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
**scheduler_info.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
@ -221,7 +204,7 @@ def get_scheduler(
|
||||
title="Denoise Latents",
|
||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
version="1.5.1",
|
||||
version="1.5.2",
|
||||
)
|
||||
class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Denoises noisy latents to decodable images"""
|
||||
@ -249,7 +232,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
description=FieldDescriptions.scheduler,
|
||||
ui_type=UIType.Scheduler,
|
||||
@ -307,22 +290,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
base_model: BaseModelType,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
base_model=base_model,
|
||||
)
|
||||
|
||||
def get_conditioning_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@ -330,11 +297,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
unet,
|
||||
seed,
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = c.extra_conditioning
|
||||
|
||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
@ -422,17 +389,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
context.models.load(
|
||||
model_name=control_info.control_model.model_name,
|
||||
model_type=ModelType.ControlNet,
|
||||
base_model=control_info.control_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
# control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
||||
input_image = context.images.get_pil(control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
@ -490,19 +456,17 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
context.models.load(
|
||||
model_name=single_ip_adapter.ip_adapter_model.model_name,
|
||||
model_type=ModelType.IPAdapter,
|
||||
base_model=single_ip_adapter.ip_adapter_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
image_encoder_model_info = context.services.model_manager.get_model(
|
||||
image_encoder_model_info = context.models.load(
|
||||
model_name=single_ip_adapter.image_encoder_model.model_name,
|
||||
model_type=ModelType.CLIPVision,
|
||||
base_model=single_ip_adapter.image_encoder_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
@ -510,7 +474,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if not isinstance(single_ipa_images, list):
|
||||
single_ipa_images = [single_ipa_images]
|
||||
|
||||
single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images]
|
||||
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_images]
|
||||
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
@ -554,13 +518,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_info = context.services.model_manager.get_model(
|
||||
t2i_adapter_model_info = context.models.load(
|
||||
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
|
||||
model_type=ModelType.T2IAdapter,
|
||||
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
|
||||
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
|
||||
@ -647,14 +610,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return num_inference_steps, timesteps, init_timestep
|
||||
|
||||
def prep_inpaint_mask(self, context, latents):
|
||||
def prep_inpaint_mask(self, context: InvocationContext, latents):
|
||||
if self.denoise_mask is None:
|
||||
return None, None
|
||||
|
||||
mask = context.services.latents.get(self.denoise_mask.mask_name)
|
||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
if self.denoise_mask.masked_latents_name is not None:
|
||||
masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name)
|
||||
masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
|
||||
else:
|
||||
masked_latents = None
|
||||
|
||||
@ -666,11 +629,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed = None
|
||||
noise = None
|
||||
if self.noise is not None:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
noise = context.tensors.load(self.noise.latents_name)
|
||||
seed = self.noise.seed
|
||||
|
||||
if self.latents is not None:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
if seed is None:
|
||||
seed = self.latents.seed
|
||||
|
||||
@ -696,27 +659,17 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
|
||||
context.util.sd_step_callback(state, self.unet.unet.base_model)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.model_dump(exclude={"weight"}),
|
||||
context=context,
|
||||
)
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
unet_info = context.models.load(**self.unet.unet.model_dump())
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
|
||||
@ -792,9 +745,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
||||
name = context.tensors.save(tensor=result_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed)
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -802,9 +754,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
title="Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i"],
|
||||
category="latents",
|
||||
version="1.2.0",
|
||||
version="1.2.1",
|
||||
)
|
||||
class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
||||
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
latents: LatentsField = InputField(
|
||||
@ -820,12 +772,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
|
||||
latents = latents.to(vae.device)
|
||||
@ -854,7 +803,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
||||
vae.to(dtype=torch.float16)
|
||||
latents = latents.half()
|
||||
|
||||
if self.tiled or context.services.configuration.tiled_decode:
|
||||
if self.tiled or context.config.get().tiled_decode:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
@ -878,22 +827,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
image_dto = context.images.save(image=image)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||
@ -904,7 +840,7 @@ LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic",
|
||||
title="Resize Latents",
|
||||
tags=["latents", "resize"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class ResizeLatentsInvocation(BaseInvocation):
|
||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||
@ -927,7 +863,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
@ -945,10 +881,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
name = context.tensors.save(tensor=resized_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -956,7 +890,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
title="Scale Latents",
|
||||
tags=["latents", "resize"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
"""Scales latents by a given factor."""
|
||||
@ -970,7 +904,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
@ -989,10 +923,8 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
name = context.tensors.save(tensor=resized_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -1000,7 +932,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
title="Image to Latents",
|
||||
tags=["latents", "image", "vae", "i2l"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
@ -1061,12 +993,9 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
@ -1074,10 +1003,9 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
|
||||
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
latents = latents.to("cpu")
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
@singledispatchmethod
|
||||
@staticmethod
|
||||
@ -1097,7 +1025,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
title="Blend Latents",
|
||||
tags=["latents", "blend"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class BlendLatentsInvocation(BaseInvocation):
|
||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||
@ -1113,8 +1041,8 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents_a = context.services.latents.get(self.latents_a.latents_name)
|
||||
latents_b = context.services.latents.get(self.latents_b.latents_name)
|
||||
latents_a = context.tensors.load(self.latents_a.latents_name)
|
||||
latents_b = context.tensors.load(self.latents_b.latents_name)
|
||||
|
||||
if latents_a.shape != latents_b.shape:
|
||||
raise Exception("Latents to blend must be the same size.")
|
||||
@ -1168,10 +1096,8 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, blended_latents)
|
||||
return build_latents_output(latents_name=name, latents=blended_latents)
|
||||
name = context.tensors.save(tensor=blended_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=blended_latents)
|
||||
|
||||
|
||||
# The Crop Latents node was copied from @skunkworxdark's implementation here:
|
||||
@ -1181,7 +1107,7 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
title="Crop Latents",
|
||||
tags=["latents", "crop"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
|
||||
# Currently, if the class names conflict then 'GET /openapi.json' fails.
|
||||
@ -1216,7 +1142,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
x1 = self.x // LATENT_SCALE_FACTOR
|
||||
y1 = self.y // LATENT_SCALE_FACTOR
|
||||
@ -1225,10 +1151,9 @@ class CropLatentsCoreInvocation(BaseInvocation):
|
||||
|
||||
cropped_latents = latents[..., y1:y2, x1:x2]
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, cropped_latents)
|
||||
name = context.tensors.save(tensor=cropped_latents)
|
||||
|
||||
return build_latents_output(latents_name=name, latents=cropped_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=cropped_latents)
|
||||
|
||||
|
||||
@invocation_output("ideal_size_output")
|
||||
|
@ -5,10 +5,11 @@ from typing import Literal
|
||||
import numpy as np
|
||||
from pydantic import ValidationInfo, field_validator
|
||||
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField
|
||||
from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
|
||||
|
||||
@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")
|
||||
|
@ -5,20 +5,22 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
MetadataField,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
MetadataField,
|
||||
OutputField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from ...version import __version__
|
||||
|
||||
|
@ -3,17 +3,14 @@ from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -105,7 +102,7 @@ class LoRAModelField(BaseModel):
|
||||
title="Main Model",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
@ -119,7 +116,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
model_type = ModelType.Main
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.models.exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
@ -206,7 +203,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0")
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
@ -232,7 +229,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.models.exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
@ -288,7 +285,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
title="SDXL LoRA",
|
||||
tags=["lora", "model"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
@ -321,7 +318,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.models.exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
@ -387,7 +384,7 @@ class VAEModelField(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
@ -402,7 +399,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
model_name = self.vae_model.model_name
|
||||
model_type = ModelType.Vae
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.models.exists(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
|
@ -4,17 +4,15 @@
|
||||
import torch
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.latent import LatentsField
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -69,13 +67,13 @@ class NoiseOutput(BaseInvocationOutput):
|
||||
width: int = OutputField(description=FieldDescriptions.width)
|
||||
height: int = OutputField(description=FieldDescriptions.height)
|
||||
|
||||
|
||||
def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
||||
return NoiseOutput(
|
||||
noise=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
@classmethod
|
||||
def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput":
|
||||
return cls(
|
||||
noise=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * LATENT_SCALE_FACTOR,
|
||||
height=latents.size()[2] * LATENT_SCALE_FACTOR,
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -96,13 +94,13 @@ class NoiseInvocation(BaseInvocation):
|
||||
)
|
||||
width: int = InputField(
|
||||
default=512,
|
||||
multiple_of=8,
|
||||
multiple_of=LATENT_SCALE_FACTOR,
|
||||
gt=0,
|
||||
description=FieldDescriptions.width,
|
||||
)
|
||||
height: int = InputField(
|
||||
default=512,
|
||||
multiple_of=8,
|
||||
multiple_of=LATENT_SCALE_FACTOR,
|
||||
gt=0,
|
||||
description=FieldDescriptions.height,
|
||||
)
|
||||
@ -124,6 +122,5 @@ class NoiseInvocation(BaseInvocation):
|
||||
seed=self.seed,
|
||||
use_cpu=self.use_cpu,
|
||||
)
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, noise)
|
||||
return build_noise_output(latents_name=name, latents=noise, seed=self.seed)
|
||||
name = context.tensors.save(tensor=noise)
|
||||
return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed)
|
||||
|
@ -1,508 +0,0 @@
|
||||
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
||||
|
||||
import inspect
|
||||
|
||||
# from contextlib import ExitStack
|
||||
from typing import List, Literal, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||
|
||||
from ...backend.model_management import ONNXModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util import choose_torch_device
|
||||
from ..util.ti_utils import extract_ti_triggers_from_prompt
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
UIType,
|
||||
WithMetadata,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
|
||||
from .model import ClipField, ModelInfo, UNetField, VaeField
|
||||
|
||||
ORT_TO_NP_TYPE = {
|
||||
"tensor(bool)": np.bool_,
|
||||
"tensor(int8)": np.int8,
|
||||
"tensor(uint8)": np.uint8,
|
||||
"tensor(int16)": np.int16,
|
||||
"tensor(uint16)": np.uint16,
|
||||
"tensor(int32)": np.int32,
|
||||
"tensor(uint32)": np.uint32,
|
||||
"tensor(int64)": np.int64,
|
||||
"tensor(uint64)": np.uint64,
|
||||
"tensor(float16)": np.float16,
|
||||
"tensor(float)": np.float32,
|
||||
"tensor(double)": np.float64,
|
||||
}
|
||||
|
||||
PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())]
|
||||
|
||||
|
||||
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
|
||||
class ONNXPromptInvocation(BaseInvocation):
|
||||
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.model_dump(),
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
)
|
||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
||||
loras = [
|
||||
(
|
||||
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
||||
lora.weight,
|
||||
)
|
||||
for lora in self.clip.loras
|
||||
]
|
||||
|
||||
ti_list = []
|
||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
if loras or ti_list:
|
||||
text_encoder.release_session()
|
||||
with (
|
||||
ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),
|
||||
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager),
|
||||
):
|
||||
text_encoder.create_session()
|
||||
|
||||
# copy from
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153
|
||||
text_inputs = tokenizer(
|
||||
self.prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
"""
|
||||
untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
|
||||
|
||||
if not np.array_equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
"""
|
||||
|
||||
prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
|
||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||
context.services.latents.save(conditioning_name, (prompt_embeds, None))
|
||||
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Text to image
|
||||
@invocation(
|
||||
"t2l_onnx",
|
||||
title="ONNX Text to Latents",
|
||||
tags=["latents", "inference", "txt2img", "onnx"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond,
|
||||
input=Input.Connection,
|
||||
)
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond,
|
||||
input=Input.Connection,
|
||||
)
|
||||
noise: LatentsField = InputField(
|
||||
description=FieldDescriptions.noise,
|
||||
input=Input.Connection,
|
||||
)
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: Union[float, List[float]] = InputField(
|
||||
default=7.5,
|
||||
ge=1,
|
||||
description=FieldDescriptions.cfg_scale,
|
||||
)
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
|
||||
)
|
||||
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
||||
unet: UNetField = InputField(
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
)
|
||||
control: Union[ControlField, list[ControlField]] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.control,
|
||||
)
|
||||
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
|
||||
@field_validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# based on
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
if isinstance(c, torch.Tensor):
|
||||
c = c.cpu().numpy()
|
||||
if isinstance(uc, torch.Tensor):
|
||||
uc = uc.cpu().numpy()
|
||||
device = torch.device(choose_torch_device())
|
||||
prompt_embeds = np.concatenate([uc, c])
|
||||
|
||||
latents = context.services.latents.get(self.noise.latents_name)
|
||||
if isinstance(latents, torch.Tensor):
|
||||
latents = latents.cpu().numpy()
|
||||
|
||||
# TODO: better execution device handling
|
||||
latents = latents.astype(ORT_TO_NP_TYPE[self.precision])
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
do_classifier_free_guidance = True
|
||||
# latents_dtype = prompt_embeds.dtype
|
||||
# latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
||||
# if latents.shape != latents_shape:
|
||||
# raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=0, # TODO: refactor this node
|
||||
)
|
||||
|
||||
def torch2numpy(latent: torch.Tensor):
|
||||
return latent.cpu().numpy()
|
||||
|
||||
def numpy2torch(latent, device):
|
||||
return torch.from_numpy(latent).to(device)
|
||||
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
scheduler.set_timesteps(self.steps)
|
||||
latents = latents * np.float64(scheduler.init_noise_sigma)
|
||||
|
||||
extra_step_kwargs = {}
|
||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||
extra_step_kwargs.update(
|
||||
eta=0.0,
|
||||
)
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump())
|
||||
|
||||
with unet_info as unet: # , ExitStack() as stack:
|
||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
loras = [
|
||||
(
|
||||
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
||||
lora.weight,
|
||||
)
|
||||
for lora in self.unet.loras
|
||||
]
|
||||
|
||||
if loras:
|
||||
unet.release_session()
|
||||
with ONNXModelPatcher.apply_lora_unet(unet, loras):
|
||||
# TODO:
|
||||
_, _, h, w = latents.shape
|
||||
unet.create_session(h, w)
|
||||
|
||||
timestep_dtype = next(
|
||||
(input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
|
||||
)
|
||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
for i in tqdm(range(len(scheduler.timesteps))):
|
||||
t = scheduler.timesteps[i]
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
|
||||
latent_model_input = latent_model_input.cpu().numpy()
|
||||
|
||||
# predict the noise residual
|
||||
timestep = np.array([t], dtype=timestep_dtype)
|
||||
noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
|
||||
noise_pred = noise_pred[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
scheduler_output = scheduler.step(
|
||||
numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
|
||||
)
|
||||
latents = torch2numpy(scheduler_output.prev_sample)
|
||||
|
||||
state = PipelineIntermediateState(
|
||||
run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample
|
||||
)
|
||||
dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state)
|
||||
|
||||
# call the callback, if provided
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))
|
||||
|
||||
|
||||
# Latent to image
|
||||
@invocation(
|
||||
"l2i_onnx",
|
||||
title="ONNX Latents to Image",
|
||||
tags=["latents", "image", "vae", "onnx"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
latents: LatentsField = InputField(
|
||||
description=FieldDescriptions.denoised_latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
**self.vae.vae.model_dump(),
|
||||
)
|
||||
|
||||
# clear memory as vae decode can request a lot
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
with vae_info as vae:
|
||||
vae.create_session()
|
||||
|
||||
# copied from
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427
|
||||
latents = 1 / 0.18215 * latents
|
||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||
image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])])
|
||||
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
image = VaeImageProcessor.numpy_to_pil(image)[0]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("model_loader_output_onnx")
|
||||
class ONNXModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
|
||||
vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
|
||||
|
||||
|
||||
class OnnxModelField(BaseModel):
|
||||
"""Onnx model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
|
||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
model: OnnxModelField = InputField(
|
||||
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.ONNX
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
"""
|
||||
|
||||
return ONNXModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae_decoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.VaeDecoder,
|
||||
),
|
||||
),
|
||||
vae_encoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.VaeEncoder,
|
||||
),
|
||||
),
|
||||
)
|
@ -40,8 +40,10 @@ from easing_functions import (
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
|
||||
from invokeai.app.invocations.primitives import FloatCollectionOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -109,7 +111,7 @@ EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
|
||||
title="Step Param Easing",
|
||||
tags=["step", "easing"],
|
||||
category="step",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
@ -148,19 +150,19 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
postlist = list(num_poststeps * [self.post_end_value])
|
||||
|
||||
if log_diagnostics:
|
||||
context.services.logger.debug("start_step: " + str(start_step))
|
||||
context.services.logger.debug("end_step: " + str(end_step))
|
||||
context.services.logger.debug("num_easing_steps: " + str(num_easing_steps))
|
||||
context.services.logger.debug("num_presteps: " + str(num_presteps))
|
||||
context.services.logger.debug("num_poststeps: " + str(num_poststeps))
|
||||
context.services.logger.debug("prelist size: " + str(len(prelist)))
|
||||
context.services.logger.debug("postlist size: " + str(len(postlist)))
|
||||
context.services.logger.debug("prelist: " + str(prelist))
|
||||
context.services.logger.debug("postlist: " + str(postlist))
|
||||
context.logger.debug("start_step: " + str(start_step))
|
||||
context.logger.debug("end_step: " + str(end_step))
|
||||
context.logger.debug("num_easing_steps: " + str(num_easing_steps))
|
||||
context.logger.debug("num_presteps: " + str(num_presteps))
|
||||
context.logger.debug("num_poststeps: " + str(num_poststeps))
|
||||
context.logger.debug("prelist size: " + str(len(prelist)))
|
||||
context.logger.debug("postlist size: " + str(len(postlist)))
|
||||
context.logger.debug("prelist: " + str(prelist))
|
||||
context.logger.debug("postlist: " + str(postlist))
|
||||
|
||||
easing_class = EASING_FUNCTIONS_MAP[self.easing]
|
||||
if log_diagnostics:
|
||||
context.services.logger.debug("easing class: " + str(easing_class))
|
||||
context.logger.debug("easing class: " + str(easing_class))
|
||||
easing_list = []
|
||||
if self.mirror: # "expected" mirroring
|
||||
# if number of steps is even, squeeze duration down to (number_of_steps)/2
|
||||
@ -171,7 +173,7 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
|
||||
base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
|
||||
if log_diagnostics:
|
||||
context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
||||
context.logger.debug("base easing duration: " + str(base_easing_duration))
|
||||
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
|
||||
easing_function = easing_class(
|
||||
start=self.start_value,
|
||||
@ -183,14 +185,14 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
easing_val = easing_function.ease(step_index)
|
||||
base_easing_vals.append(easing_val)
|
||||
if log_diagnostics:
|
||||
context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
|
||||
context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
|
||||
if even_num_steps:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals))
|
||||
else:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
|
||||
if log_diagnostics:
|
||||
context.services.logger.debug("base easing vals: " + str(base_easing_vals))
|
||||
context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
|
||||
context.logger.debug("base easing vals: " + str(base_easing_vals))
|
||||
context.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
|
||||
easing_list = base_easing_vals + mirror_easing_vals
|
||||
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
@ -225,12 +227,12 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
step_val = easing_function.ease(step_index)
|
||||
easing_list.append(step_val)
|
||||
if log_diagnostics:
|
||||
context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
|
||||
context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
|
||||
|
||||
if log_diagnostics:
|
||||
context.services.logger.debug("prelist size: " + str(len(prelist)))
|
||||
context.services.logger.debug("easing_list size: " + str(len(easing_list)))
|
||||
context.services.logger.debug("postlist size: " + str(len(postlist)))
|
||||
context.logger.debug("prelist size: " + str(len(prelist)))
|
||||
context.logger.debug("easing_list size: " + str(len(easing_list)))
|
||||
context.logger.debug("postlist size: " + str(len(postlist)))
|
||||
|
||||
param_list = prelist + easing_list + postlist
|
||||
|
||||
|
@ -1,20 +1,28 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
ColorField,
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -221,18 +229,6 @@ class StringCollectionInvocation(BaseInvocation):
|
||||
# region Image
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image primitive field"""
|
||||
|
||||
image_name: str = Field(description="The name of the image")
|
||||
|
||||
|
||||
class BoardField(BaseModel):
|
||||
"""A board primitive field"""
|
||||
|
||||
board_id: str = Field(description="The id of the board")
|
||||
|
||||
|
||||
@invocation_output("image_output")
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single image"""
|
||||
@ -241,6 +237,14 @@ class ImageOutput(BaseInvocationOutput):
|
||||
width: int = OutputField(description="The width of the image in pixels")
|
||||
height: int = OutputField(description="The height of the image in pixels")
|
||||
|
||||
@classmethod
|
||||
def build(cls, image_dto: ImageDTO) -> "ImageOutput":
|
||||
return cls(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("image_collection_output")
|
||||
class ImageCollectionOutput(BaseInvocationOutput):
|
||||
@ -251,16 +255,14 @@ class ImageCollectionOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0")
|
||||
class ImageInvocation(
|
||||
BaseInvocation,
|
||||
):
|
||||
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1")
|
||||
class ImageInvocation(BaseInvocation):
|
||||
"""An image primitive value"""
|
||||
|
||||
image: ImageField = InputField(description="The image to load")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=self.image.image_name),
|
||||
@ -290,42 +292,40 @@ class ImageCollectionInvocation(BaseInvocation):
|
||||
# region DenoiseMask
|
||||
|
||||
|
||||
class DenoiseMaskField(BaseModel):
|
||||
"""An inpaint mask field"""
|
||||
|
||||
mask_name: str = Field(description="The name of the mask image")
|
||||
masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
|
||||
|
||||
|
||||
@invocation_output("denoise_mask_output")
|
||||
class DenoiseMaskOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single image"""
|
||||
|
||||
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||
|
||||
@classmethod
|
||||
def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput":
|
||||
return cls(
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name),
|
||||
)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Latents
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents tensor primitive field"""
|
||||
|
||||
latents_name: str = Field(description="The name of the latents")
|
||||
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
|
||||
|
||||
|
||||
@invocation_output("latents_output")
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single latents tensor"""
|
||||
|
||||
latents: LatentsField = OutputField(
|
||||
description=FieldDescriptions.latents,
|
||||
)
|
||||
latents: LatentsField = OutputField(description=FieldDescriptions.latents)
|
||||
width: int = OutputField(description=FieldDescriptions.width)
|
||||
height: int = OutputField(description=FieldDescriptions.height)
|
||||
|
||||
@classmethod
|
||||
def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput":
|
||||
return cls(
|
||||
latents=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * LATENT_SCALE_FACTOR,
|
||||
height=latents.size()[2] * LATENT_SCALE_FACTOR,
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("latents_collection_output")
|
||||
class LatentsCollectionOutput(BaseInvocationOutput):
|
||||
@ -337,7 +337,7 @@ class LatentsCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@invocation(
|
||||
"latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0"
|
||||
"latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.1"
|
||||
)
|
||||
class LatentsInvocation(BaseInvocation):
|
||||
"""A latents tensor primitive value"""
|
||||
@ -345,9 +345,9 @@ class LatentsInvocation(BaseInvocation):
|
||||
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
return build_latents_output(self.latents.latents_name, latents)
|
||||
return LatentsOutput.build(self.latents.latents_name, latents)
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -368,31 +368,11 @@ class LatentsCollectionInvocation(BaseInvocation):
|
||||
return LatentsCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None):
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Color
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
"""A color primitive field"""
|
||||
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
@invocation_output("color_output")
|
||||
class ColorOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single color"""
|
||||
@ -424,18 +404,16 @@ class ColorInvocation(BaseInvocation):
|
||||
# region Conditioning
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
@invocation_output("conditioning_output")
|
||||
class ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
||||
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
@classmethod
|
||||
def build(cls, conditioning_name: str) -> "ConditioningOutput":
|
||||
return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("conditioning_collection_output")
|
||||
class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||
|
@ -6,8 +6,10 @@ from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPrompt
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, UIComponent
|
||||
|
||||
|
||||
@invocation(
|
||||
|
@ -1,14 +1,10 @@
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from ...backend.model_management import ModelType, SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -34,7 +30,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0")
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
@ -49,7 +45,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
model_type = ModelType.Main
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.models.exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
@ -120,7 +116,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
title="SDXL Refiner Model",
|
||||
tags=["model", "sdxl", "refiner"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
version="1.0.1",
|
||||
)
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
@ -138,7 +134,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
model_type = ModelType.Main
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.models.exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
|
@ -2,16 +2,15 @@
|
||||
|
||||
import re
|
||||
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .fields import InputField, OutputField, UIComponent
|
||||
from .primitives import StringOutput
|
||||
|
||||
|
||||
|
@ -5,17 +5,13 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_valida
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_management.models.base import BaseModelType
|
||||
|
||||
|
||||
|
@ -8,16 +8,12 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
WithMetadata,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.tiles.tiles import (
|
||||
calc_tiles_even_split,
|
||||
calc_tiles_min_overlap,
|
||||
@ -236,7 +232,7 @@ BLEND_MODES = Literal["Linear", "Seam"]
|
||||
version="1.1.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
|
||||
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Merge multiple tile images into a single image."""
|
||||
|
||||
# Inputs
|
||||
@ -268,7 +264,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
|
||||
# existed in memory at an earlier point in the graph.
|
||||
tile_np_images: list[np.ndarray] = []
|
||||
for image in images:
|
||||
pil_image = context.services.images.get_pil_image(image.image_name)
|
||||
pil_image = context.images.get_pil(image.image_name)
|
||||
pil_image = pil_image.convert("RGB")
|
||||
tile_np_images.append(np.array(pil_image))
|
||||
|
||||
@ -291,18 +287,5 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
|
||||
# Convert into a PIL image and save
|
||||
pil_image = Image.fromarray(np_image)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
@ -8,13 +8,15 @@ import torch
|
||||
from PIL import Image
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, WithBoard, WithMetadata
|
||||
|
||||
# TODO: Populate this from disk?
|
||||
# TODO: Use model manager to load?
|
||||
@ -29,8 +31,8 @@ if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0")
|
||||
class ESRGANInvocation(BaseInvocation, WithMetadata):
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.1")
|
||||
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Upscales an image using RealESRGAN."""
|
||||
|
||||
image: ImageField = InputField(description="The input image")
|
||||
@ -42,8 +44,8 @@ class ESRGANInvocation(BaseInvocation, WithMetadata):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
models_path = context.services.configuration.models_path
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
models_path = context.config.get().models_path
|
||||
|
||||
rrdbnet_model = None
|
||||
netscale = None
|
||||
@ -87,7 +89,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata):
|
||||
netscale = 2
|
||||
else:
|
||||
msg = f"Invalid RealESRGAN model: {self.model_name}"
|
||||
context.services.logger.error(msg)
|
||||
context.logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
|
||||
@ -110,19 +112,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata):
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
@ -11,7 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_management.model_manager import ModelInfo
|
||||
from invokeai.backend.model_management.model_manager import LoadedModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ class EventServiceBase:
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
node_id: str,
|
||||
source_node_id: str,
|
||||
progress_image: Optional[ProgressImage],
|
||||
step: int,
|
||||
@ -70,7 +70,7 @@ class EventServiceBase:
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node_id": node.get("id"),
|
||||
"node_id": node_id,
|
||||
"source_node_id": source_node_id,
|
||||
"progress_image": progress_image.model_dump() if progress_image is not None else None,
|
||||
"step": step,
|
||||
@ -201,7 +201,7 @@ class EventServiceBase:
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: SubModelType,
|
||||
model_info: ModelInfo,
|
||||
loaded_model_info: LoadedModelInfo,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_queue_event(
|
||||
@ -215,9 +215,9 @@ class EventServiceBase:
|
||||
"base_model": base_model,
|
||||
"model_type": model_type,
|
||||
"submodel": submodel,
|
||||
"hash": model_info.hash,
|
||||
"location": str(model_info.location),
|
||||
"precision": str(model_info.precision),
|
||||
"hash": loaded_model_info.hash,
|
||||
"location": str(loaded_model_info.location),
|
||||
"precision": str(loaded_model_info.precision),
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -4,7 +4,7 @@ from typing import Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
|
||||
|
@ -7,7 +7,7 @@ from PIL import Image, PngImagePlugin
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.metadata import MetadataField
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
|
||||
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
|
||||
|
@ -3,7 +3,7 @@ import threading
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
@ -3,7 +3,7 @@ from typing import Callable, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageRecord,
|
||||
|
@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
@ -37,7 +37,8 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
if self._max_cache_size == 0:
|
||||
return
|
||||
self._invoker.services.images.on_deleted(self._delete_by_match)
|
||||
self._invoker.services.latents.on_deleted(self._delete_by_match)
|
||||
self._invoker.services.tensors.on_deleted(self._delete_by_match)
|
||||
self._invoker.services.conditioning.on_deleted(self._delete_by_match)
|
||||
|
||||
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
||||
with self._lock:
|
||||
|
@ -5,11 +5,11 @@ from threading import BoundedSemaphore, Event, Thread
|
||||
from typing import Optional
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import (
|
||||
GESStatsNotFoundError,
|
||||
)
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
|
||||
from ..invoker import Invoker
|
||||
@ -131,16 +131,20 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# which handles a few things:
|
||||
# - nodes that require a value, but get it only from a connection
|
||||
# - referencing the invocation cache instead of executing the node
|
||||
outputs = invocation.invoke_internal(
|
||||
InvocationContext(
|
||||
services=self.__invoker.services,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
workflow=queue_item.workflow,
|
||||
)
|
||||
context_data = InvocationContextData(
|
||||
invocation=invocation,
|
||||
session_id=graph_id,
|
||||
workflow=queue_item.workflow,
|
||||
source_node_id=source_node_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
batch_id=queue_item.session_queue_batch_id,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
services=self.__invoker.services,
|
||||
context_data=context_data,
|
||||
)
|
||||
outputs = invocation.invoke_internal(context=context, services=self.__invoker.services)
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||
|
@ -3,9 +3,15 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
from .board_image_records.board_image_records_base import BoardImageRecordStorageBase
|
||||
from .board_images.board_images_base import BoardImagesServiceABC
|
||||
from .board_records.board_records_base import BoardRecordStorageBase
|
||||
@ -21,7 +27,6 @@ if TYPE_CHECKING:
|
||||
from .invocation_queue.invocation_queue_base import InvocationQueueABC
|
||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||
from .item_storage.item_storage_base import ItemStorageABC
|
||||
from .latents_storage.latents_storage_base import LatentsStorageBase
|
||||
from .model_install import ModelInstallServiceBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .model_records import ModelRecordServiceBase
|
||||
@ -36,33 +41,6 @@ if TYPE_CHECKING:
|
||||
class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
|
||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||
board_images: "BoardImagesServiceABC"
|
||||
board_image_record_storage: "BoardImageRecordStorageBase"
|
||||
boards: "BoardServiceABC"
|
||||
board_records: "BoardRecordStorageBase"
|
||||
configuration: "InvokeAIAppConfig"
|
||||
events: "EventServiceBase"
|
||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
|
||||
images: "ImageServiceABC"
|
||||
image_records: "ImageRecordStorageBase"
|
||||
image_files: "ImageFileStorageBase"
|
||||
latents: "LatentsStorageBase"
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
model_records: "ModelRecordServiceBase"
|
||||
download_queue: "DownloadQueueServiceBase"
|
||||
model_install: "ModelInstallServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
queue: "InvocationQueueABC"
|
||||
session_queue: "SessionQueueBase"
|
||||
session_processor: "SessionProcessorBase"
|
||||
invocation_cache: "InvocationCacheBase"
|
||||
names: "NameServiceBase"
|
||||
urls: "UrlServiceBase"
|
||||
workflow_records: "WorkflowRecordsStorageBase"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
board_images: "BoardImagesServiceABC",
|
||||
@ -75,7 +53,6 @@ class InvocationServices:
|
||||
images: "ImageServiceABC",
|
||||
image_files: "ImageFileStorageBase",
|
||||
image_records: "ImageRecordStorageBase",
|
||||
latents: "LatentsStorageBase",
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
model_records: "ModelRecordServiceBase",
|
||||
@ -90,6 +67,8 @@ class InvocationServices:
|
||||
names: "NameServiceBase",
|
||||
urls: "UrlServiceBase",
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
tensors: "ObjectSerializerBase[torch.Tensor]",
|
||||
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
@ -101,7 +80,6 @@ class InvocationServices:
|
||||
self.images = images
|
||||
self.image_files = image_files
|
||||
self.image_records = image_records
|
||||
self.latents = latents
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.model_records = model_records
|
||||
@ -116,3 +94,5 @@ class InvocationServices:
|
||||
self.names = names
|
||||
self.urls = urls
|
||||
self.workflow_records = workflow_records
|
||||
self.tensors = tensors
|
||||
self.conditioning = conditioning
|
||||
|
@ -30,7 +30,7 @@ class ItemStorageABC(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def set(self, item: T) -> None:
|
||||
"""
|
||||
Sets the item. The id will be extracted based on id_field.
|
||||
Sets the item.
|
||||
:param item: the item to set
|
||||
"""
|
||||
pass
|
||||
|
@ -1,45 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class LatentsStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving latents."""
|
||||
|
||||
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
|
||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_changed_callbacks = []
|
||||
self._on_deleted_callbacks = []
|
||||
|
||||
@abstractmethod
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, name: str) -> None:
|
||||
pass
|
||||
|
||||
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
|
||||
"""Register a callback for when an item is changed"""
|
||||
self._on_changed_callbacks.append(on_changed)
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an item is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_changed(self, item: torch.Tensor) -> None:
|
||||
for callback in self._on_changed_callbacks:
|
||||
callback(item)
|
||||
|
||||
def _on_deleted(self, item_id: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(item_id)
|
@ -1,58 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
from .latents_storage_base import LatentsStorageBase
|
||||
|
||||
|
||||
class DiskLatentsStorage(LatentsStorageBase):
|
||||
"""Stores latents in a folder on disk without caching"""
|
||||
|
||||
__output_folder: Path
|
||||
|
||||
def __init__(self, output_folder: Union[str, Path]):
|
||||
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
self._delete_all_latents()
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
latent_path = self.get_path(name)
|
||||
return torch.load(latent_path)
|
||||
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||
latent_path = self.get_path(name)
|
||||
torch.save(data, latent_path)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
latent_path.unlink()
|
||||
|
||||
def get_path(self, name: str) -> Path:
|
||||
return self.__output_folder / name
|
||||
|
||||
def _delete_all_latents(self) -> None:
|
||||
"""
|
||||
Deletes all latents from disk.
|
||||
Must be called after we have access to `self._invoker` (e.g. in `start()`).
|
||||
"""
|
||||
deleted_latents_count = 0
|
||||
freed_space = 0
|
||||
for latents_file in Path(self.__output_folder).glob("*"):
|
||||
if latents_file.is_file():
|
||||
freed_space += latents_file.stat().st_size
|
||||
deleted_latents_count += 1
|
||||
latents_file.unlink()
|
||||
if deleted_latents_count > 0:
|
||||
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
|
||||
self._invoker.services.logger.info(
|
||||
f"Deleted {deleted_latents_count} latents files (freed {freed_space_in_mb}MB)"
|
||||
)
|
@ -1,68 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
from .latents_storage_base import LatentsStorageBase
|
||||
|
||||
|
||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||
|
||||
__cache: Dict[str, torch.Tensor]
|
||||
__cache_ids: Queue
|
||||
__max_cache_size: int
|
||||
__underlying_storage: LatentsStorageBase
|
||||
|
||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self.__underlying_storage = underlying_storage
|
||||
self.__cache = {}
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = max_cache_size
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
start_op = getattr(self.__underlying_storage, "start", None)
|
||||
if callable(start_op):
|
||||
start_op(invoker)
|
||||
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
stop_op = getattr(self.__underlying_storage, "stop", None)
|
||||
if callable(stop_op):
|
||||
stop_op(invoker)
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
cache_item = self.__get_cache(name)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
latent = self.__underlying_storage.get(name)
|
||||
self.__set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.save(name, data)
|
||||
self.__set_cache(name, data)
|
||||
self._on_changed(data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self.__underlying_storage.delete(name)
|
||||
if name in self.__cache:
|
||||
del self.__cache[name]
|
||||
self._on_deleted(name)
|
||||
|
||||
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
|
||||
return None if name not in self.__cache else self.__cache[name]
|
||||
|
||||
def __set_cache(self, name: str, data: torch.Tensor):
|
||||
if name not in self.__cache:
|
||||
self.__cache[name] = data
|
||||
self.__cache_ids.put(name)
|
||||
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||
self.__cache.pop(self.__cache_ids.get())
|
@ -5,25 +5,23 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
||||
from typing import Callable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_management import (
|
||||
AddModelResult,
|
||||
BaseModelType,
|
||||
LoadedModelInfo,
|
||||
MergeInterpolationMethod,
|
||||
ModelInfo,
|
||||
ModelType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||
|
||||
|
||||
class ModelManagerServiceBase(ABC):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
@ -49,9 +47,8 @@ class ModelManagerServiceBase(ABC):
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
node: Optional[BaseInvocation] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModelInfo:
|
||||
"""Retrieve the indicated model with name and type.
|
||||
submodel can be used to get a part (such as the vae)
|
||||
of a diffusers pipeline."""
|
||||
|
@ -11,11 +11,13 @@ from pydantic import Field
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_management import (
|
||||
AddModelResult,
|
||||
BaseModelType,
|
||||
LoadedModelInfo,
|
||||
MergeInterpolationMethod,
|
||||
ModelInfo,
|
||||
ModelManager,
|
||||
ModelMerger,
|
||||
ModelNotFoundException,
|
||||
@ -30,7 +32,7 @@ from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from .model_manager_base import ModelManagerServiceBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
pass
|
||||
|
||||
|
||||
# simple implementation
|
||||
@ -86,47 +88,50 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
)
|
||||
logger.info("Model manager service initialized")
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker: Optional[Invoker] = invoker
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModelInfo:
|
||||
"""
|
||||
Retrieve the indicated model. submodel can be used to get a
|
||||
part (such as the vae) of a diffusers mode.
|
||||
"""
|
||||
|
||||
# we can emit model loading events if we are executing with access to the invocation context
|
||||
if context:
|
||||
if context_data is not None:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
context_data=context_data,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
)
|
||||
|
||||
model_info = self.mgr.get_model(
|
||||
loaded_model_info = self.mgr.get_model(
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
submodel,
|
||||
)
|
||||
|
||||
if context:
|
||||
if context_data is not None:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
context_data=context_data,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
loaded_model_info=loaded_model_info,
|
||||
)
|
||||
|
||||
return model_info
|
||||
return loaded_model_info
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
@ -263,34 +268,37 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
context_data: InvocationContextData,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
loaded_model_info: Optional[LoadedModelInfo] = None,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
if self._invoker is None:
|
||||
return
|
||||
|
||||
if self._invoker.services.queue.is_canceled(context_data.session_id):
|
||||
raise CanceledException()
|
||||
|
||||
if model_info:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
if loaded_model_info:
|
||||
self._invoker.services.events.emit_model_load_completed(
|
||||
queue_id=context_data.queue_id,
|
||||
queue_item_id=context_data.queue_item_id,
|
||||
queue_batch_id=context_data.batch_id,
|
||||
graph_execution_state_id=context_data.session_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
loaded_model_info=loaded_model_info,
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
self._invoker.services.events.emit_model_load_started(
|
||||
queue_id=context_data.queue_id,
|
||||
queue_item_id=context_data.queue_item_id,
|
||||
queue_batch_id=context_data.batch_id,
|
||||
graph_execution_state_id=context_data.session_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
|
@ -0,0 +1,44 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ObjectSerializerBase(ABC, Generic[T]):
|
||||
"""Saves and loads arbitrary python objects."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_deleted_callbacks: list[Callable[[str], None]] = []
|
||||
|
||||
@abstractmethod
|
||||
def load(self, name: str) -> T:
|
||||
"""
|
||||
Loads the object.
|
||||
:param name: The name of the object to load.
|
||||
:raises ObjectNotFoundError: if the object is not found
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, obj: T) -> str:
|
||||
"""
|
||||
Saves the object, returning its name.
|
||||
:param obj: The object to save.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, name: str) -> None:
|
||||
"""
|
||||
Deletes the object, if it exists.
|
||||
:param name: The name of the object to delete.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an object is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_deleted(self, name: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(name)
|
@ -0,0 +1,5 @@
|
||||
class ObjectNotFoundError(KeyError):
|
||||
"""Raised when an object is not found while loading"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(f"Object with name {name} not found")
|
@ -0,0 +1,85 @@
|
||||
import tempfile
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeleteAllResult:
|
||||
deleted_count: int
|
||||
freed_space_bytes: float
|
||||
|
||||
|
||||
class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
|
||||
|
||||
:param output_dir: The folder where the serialized objects will be stored
|
||||
:param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir: Path, ephemeral: bool = False):
|
||||
super().__init__()
|
||||
self._ephemeral = ephemeral
|
||||
self._base_output_dir = output_dir
|
||||
self._base_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
# Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows
|
||||
self._tempdir = (
|
||||
tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None
|
||||
)
|
||||
self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir
|
||||
self.__obj_class_name: Optional[str] = None
|
||||
|
||||
def load(self, name: str) -> T:
|
||||
file_path = self._get_path(name)
|
||||
try:
|
||||
return torch.load(file_path) # pyright: ignore [reportUnknownMemberType]
|
||||
except FileNotFoundError as e:
|
||||
raise ObjectNotFoundError(name) from e
|
||||
|
||||
def save(self, obj: T) -> str:
|
||||
name = self._new_name()
|
||||
file_path = self._get_path(name)
|
||||
torch.save(obj, file_path) # pyright: ignore [reportUnknownMemberType]
|
||||
return name
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
file_path = self._get_path(name)
|
||||
file_path.unlink()
|
||||
|
||||
@property
|
||||
def _obj_class_name(self) -> str:
|
||||
if not self.__obj_class_name:
|
||||
# `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
|
||||
self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue]
|
||||
return self.__obj_class_name
|
||||
|
||||
def _get_path(self, name: str) -> Path:
|
||||
return self._output_dir / name
|
||||
|
||||
def _new_name(self) -> str:
|
||||
return f"{self._obj_class_name}_{uuid_string()}"
|
||||
|
||||
def _tempdir_cleanup(self) -> None:
|
||||
"""Calls `cleanup` on the temporary directory, if it exists."""
|
||||
if self._tempdir:
|
||||
self._tempdir.cleanup()
|
||||
|
||||
def __del__(self) -> None:
|
||||
# In case the service is not properly stopped, clean up the temporary directory when the class instance is GC'd.
|
||||
self._tempdir_cleanup()
|
||||
|
||||
def stop(self, invoker: "Invoker") -> None:
|
||||
self._tempdir_cleanup()
|
@ -0,0 +1,65 @@
|
||||
from queue import Queue
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
|
||||
class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
|
||||
"""
|
||||
Provides a LRU cache for an instance of `ObjectSerializerBase`.
|
||||
Saving an object to the cache always writes through to the underlying storage.
|
||||
"""
|
||||
|
||||
def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self._underlying_storage = underlying_storage
|
||||
self._cache: dict[str, T] = {}
|
||||
self._cache_ids = Queue[str]()
|
||||
self._max_cache_size = max_cache_size
|
||||
|
||||
def start(self, invoker: "Invoker") -> None:
|
||||
self._invoker = invoker
|
||||
start_op = getattr(self._underlying_storage, "start", None)
|
||||
if callable(start_op):
|
||||
start_op(invoker)
|
||||
|
||||
def stop(self, invoker: "Invoker") -> None:
|
||||
self._invoker = invoker
|
||||
stop_op = getattr(self._underlying_storage, "stop", None)
|
||||
if callable(stop_op):
|
||||
stop_op(invoker)
|
||||
|
||||
def load(self, name: str) -> T:
|
||||
cache_item = self._get_cache(name)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
obj = self._underlying_storage.load(name)
|
||||
self._set_cache(name, obj)
|
||||
return obj
|
||||
|
||||
def save(self, obj: T) -> str:
|
||||
name = self._underlying_storage.save(obj)
|
||||
self._set_cache(name, obj)
|
||||
return name
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self._underlying_storage.delete(name)
|
||||
if name in self._cache:
|
||||
del self._cache[name]
|
||||
self._on_deleted(name)
|
||||
|
||||
def _get_cache(self, name: str) -> Optional[T]:
|
||||
return None if name not in self._cache else self._cache[name]
|
||||
|
||||
def _set_cache(self, name: str, data: T):
|
||||
if name not in self._cache:
|
||||
self._cache[name] = data
|
||||
self._cache_ids.put(name)
|
||||
if self._cache_ids.qsize() > self._max_cache_size:
|
||||
self._cache.pop(self._cache_ids.get())
|
@ -13,14 +13,11 @@ from invokeai.app.invocations import * # noqa: F401 F403
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
# in 3.10 this would be "from types import NoneType"
|
||||
|
409
invokeai/app/services/shared/invocation_context.py
Normal file
409
invokeai/app/services/shared/invocation_context.py
Normal file
@ -0,0 +1,409 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.model_manager import LoadedModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
|
||||
"""
|
||||
The InvocationContext provides access to various services and data about the current invocation.
|
||||
|
||||
We do not provide the invocation services directly, as their methods are both dangerous and
|
||||
inconvenient to use.
|
||||
|
||||
For example:
|
||||
- The `images` service allows nodes to delete or unsafely modify existing images.
|
||||
- The `configuration` service allows nodes to change the app's config at runtime.
|
||||
- The `events` service allows nodes to emit arbitrary events.
|
||||
|
||||
Wrapping these services provides a simpler and safer interface for nodes to use.
|
||||
|
||||
When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere
|
||||
with each other.
|
||||
|
||||
Many of the wrappers have the same signature as the methods they wrap. This allows us to write
|
||||
user-facing docstrings and not need to go and update the internal services to match.
|
||||
|
||||
Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvocationContextData:
|
||||
invocation: "BaseInvocation"
|
||||
"""The invocation that is being executed."""
|
||||
session_id: str
|
||||
"""The session that is being executed."""
|
||||
queue_id: str
|
||||
"""The queue in which the session is being executed."""
|
||||
source_node_id: str
|
||||
"""The ID of the node from which the currently executing invocation was prepared."""
|
||||
queue_item_id: int
|
||||
"""The ID of the queue item that is being executed."""
|
||||
batch_id: str
|
||||
"""The ID of the batch that is being executed."""
|
||||
workflow: Optional[WorkflowWithoutID] = None
|
||||
"""The workflow associated with this queue item, if any."""
|
||||
|
||||
|
||||
class InvocationContextInterface:
|
||||
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
|
||||
self._services = services
|
||||
self._context_data = context_data
|
||||
|
||||
|
||||
class BoardsInterface(InvocationContextInterface):
|
||||
def create(self, board_name: str) -> BoardDTO:
|
||||
"""
|
||||
Creates a board.
|
||||
|
||||
:param board_name: The name of the board to create.
|
||||
"""
|
||||
return self._services.boards.create(board_name)
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
"""
|
||||
Gets a board DTO.
|
||||
|
||||
:param board_id: The ID of the board to get.
|
||||
"""
|
||||
return self._services.boards.get_dto(board_id)
|
||||
|
||||
def get_all(self) -> list[BoardDTO]:
|
||||
"""
|
||||
Gets all boards.
|
||||
"""
|
||||
return self._services.boards.get_all()
|
||||
|
||||
def add_image_to_board(self, board_id: str, image_name: str) -> None:
|
||||
"""
|
||||
Adds an image to a board.
|
||||
|
||||
:param board_id: The ID of the board to add the image to.
|
||||
:param image_name: The name of the image to add to the board.
|
||||
"""
|
||||
return self._services.board_images.add_image_to_board(board_id, image_name)
|
||||
|
||||
def get_all_image_names_for_board(self, board_id: str) -> list[str]:
|
||||
"""
|
||||
Gets all image names for a board.
|
||||
|
||||
:param board_id: The ID of the board to get the image names for.
|
||||
"""
|
||||
return self._services.board_images.get_all_board_image_names_for_board(board_id)
|
||||
|
||||
|
||||
class LoggerInterface(InvocationContextInterface):
|
||||
def debug(self, message: str) -> None:
|
||||
"""
|
||||
Logs a debug message.
|
||||
|
||||
:param message: The message to log.
|
||||
"""
|
||||
self._services.logger.debug(message)
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""
|
||||
Logs an info message.
|
||||
|
||||
:param message: The message to log.
|
||||
"""
|
||||
self._services.logger.info(message)
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""
|
||||
Logs a warning message.
|
||||
|
||||
:param message: The message to log.
|
||||
"""
|
||||
self._services.logger.warning(message)
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""
|
||||
Logs an error message.
|
||||
|
||||
:param message: The message to log.
|
||||
"""
|
||||
self._services.logger.error(message)
|
||||
|
||||
|
||||
class ImagesInterface(InvocationContextInterface):
|
||||
def save(
|
||||
self,
|
||||
image: Image,
|
||||
board_id: Optional[str] = None,
|
||||
image_category: ImageCategory = ImageCategory.GENERAL,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
) -> ImageDTO:
|
||||
"""
|
||||
Saves an image, returning its DTO.
|
||||
|
||||
If the current queue item has a workflow or metadata, it is automatically saved with the image.
|
||||
|
||||
:param image: The image to save, as a PIL image.
|
||||
:param board_id: The board ID to add the image to, if it should be added. It the invocation \
|
||||
inherits from `WithBoard`, that board will be used automatically. **Use this only if \
|
||||
you want to override or provide a board manually!**
|
||||
:param image_category: The category of the image. Only the GENERAL category is added \
|
||||
to the gallery.
|
||||
:param metadata: The metadata to save with the image, if it should have any. If the \
|
||||
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
|
||||
**Use this only if you want to override or provide metadata manually!**
|
||||
"""
|
||||
|
||||
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
|
||||
metadata_ = None
|
||||
if metadata:
|
||||
metadata_ = metadata
|
||||
elif isinstance(self._context_data.invocation, WithMetadata):
|
||||
metadata_ = self._context_data.invocation.metadata
|
||||
|
||||
# If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None.
|
||||
board_id_ = None
|
||||
if board_id:
|
||||
board_id_ = board_id
|
||||
elif isinstance(self._context_data.invocation, WithBoard) and self._context_data.invocation.board:
|
||||
board_id_ = self._context_data.invocation.board.board_id
|
||||
|
||||
return self._services.images.create(
|
||||
image=image,
|
||||
is_intermediate=self._context_data.invocation.is_intermediate,
|
||||
image_category=image_category,
|
||||
board_id=board_id_,
|
||||
metadata=metadata_,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
workflow=self._context_data.workflow,
|
||||
session_id=self._context_data.session_id,
|
||||
node_id=self._context_data.invocation.id,
|
||||
)
|
||||
|
||||
def get_pil(self, image_name: str) -> Image:
|
||||
"""
|
||||
Gets an image as a PIL Image object.
|
||||
|
||||
:param image_name: The name of the image to get.
|
||||
"""
|
||||
return self._services.images.get_pil_image(image_name)
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
"""
|
||||
Gets an image's metadata, if it has any.
|
||||
|
||||
:param image_name: The name of the image to get the metadata for.
|
||||
"""
|
||||
return self._services.images.get_metadata(image_name)
|
||||
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
"""
|
||||
Gets an image as an ImageDTO object.
|
||||
|
||||
:param image_name: The name of the image to get.
|
||||
"""
|
||||
return self._services.images.get_dto(image_name)
|
||||
|
||||
|
||||
class TensorsInterface(InvocationContextInterface):
|
||||
def save(self, tensor: Tensor) -> str:
|
||||
"""
|
||||
Saves a tensor, returning its name.
|
||||
|
||||
:param tensor: The tensor to save.
|
||||
"""
|
||||
|
||||
name = self._services.tensors.save(obj=tensor)
|
||||
return name
|
||||
|
||||
def load(self, name: str) -> Tensor:
|
||||
"""
|
||||
Loads a tensor by name.
|
||||
|
||||
:param name: The name of the tensor to load.
|
||||
"""
|
||||
return self._services.tensors.load(name)
|
||||
|
||||
|
||||
class ConditioningInterface(InvocationContextInterface):
|
||||
def save(self, conditioning_data: ConditioningFieldData) -> str:
|
||||
"""
|
||||
Saves a conditioning data object, returning its name.
|
||||
|
||||
:param conditioning_context_data: The conditioning data to save.
|
||||
"""
|
||||
|
||||
name = self._services.conditioning.save(obj=conditioning_data)
|
||||
return name
|
||||
|
||||
def load(self, name: str) -> ConditioningFieldData:
|
||||
"""
|
||||
Loads conditioning data by name.
|
||||
|
||||
:param name: The name of the conditioning data to load.
|
||||
"""
|
||||
|
||||
return self._services.conditioning.load(name)
|
||||
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
|
||||
"""
|
||||
Checks if a model exists.
|
||||
|
||||
:param model_name: The name of the model to check.
|
||||
:param base_model: The base model of the model to check.
|
||||
:param model_type: The type of the model to check.
|
||||
"""
|
||||
return self._services.model_manager.model_exists(model_name, base_model, model_type)
|
||||
|
||||
def load(
|
||||
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
|
||||
) -> LoadedModelInfo:
|
||||
"""
|
||||
Loads a model.
|
||||
|
||||
:param model_name: The name of the model to get.
|
||||
:param base_model: The base model of the model to get.
|
||||
:param model_type: The type of the model to get.
|
||||
:param submodel: The submodel of the model to get.
|
||||
:returns: An object representing the loaded model.
|
||||
"""
|
||||
|
||||
# The model manager emits events as it loads the model. It needs the context data to build
|
||||
# the event payloads.
|
||||
|
||||
return self._services.model_manager.get_model(
|
||||
model_name, base_model, model_type, submodel, context_data=self._context_data
|
||||
)
|
||||
|
||||
def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Gets a model's info, an dict-like object.
|
||||
|
||||
:param model_name: The name of the model to get.
|
||||
:param base_model: The base model of the model to get.
|
||||
:param model_type: The type of the model to get.
|
||||
"""
|
||||
return self._services.model_manager.model_info(model_name, base_model, model_type)
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
def get(self) -> InvokeAIAppConfig:
|
||||
"""Gets the app's config."""
|
||||
|
||||
return self._services.configuration.get_config()
|
||||
|
||||
|
||||
class UtilInterface(InvocationContextInterface):
|
||||
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
|
||||
"""
|
||||
The step callback emits a progress event with the current step, the total number of
|
||||
steps, a preview image, and some other internal metadata.
|
||||
|
||||
This should be called after each denoising step.
|
||||
|
||||
:param intermediate_state: The intermediate state of the diffusion pipeline.
|
||||
:param base_model: The base model for the current denoising step.
|
||||
"""
|
||||
|
||||
# The step callback needs access to the events and the invocation queue services, but this
|
||||
# represents a dangerous level of access.
|
||||
#
|
||||
# We wrap the step callback so that nodes do not have direct access to these services.
|
||||
|
||||
stable_diffusion_step_callback(
|
||||
context_data=self._context_data,
|
||||
intermediate_state=intermediate_state,
|
||||
base_model=base_model,
|
||||
invocation_queue=self._services.queue,
|
||||
events=self._services.events,
|
||||
)
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""
|
||||
The `InvocationContext` provides access to various services and data for the current invocation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
images: ImagesInterface,
|
||||
tensors: TensorsInterface,
|
||||
conditioning: ConditioningInterface,
|
||||
models: ModelsInterface,
|
||||
logger: LoggerInterface,
|
||||
config: ConfigInterface,
|
||||
util: UtilInterface,
|
||||
boards: BoardsInterface,
|
||||
context_data: InvocationContextData,
|
||||
services: InvocationServices,
|
||||
) -> None:
|
||||
self.images = images
|
||||
"""Provides methods to save, get and update images and their metadata."""
|
||||
self.tensors = tensors
|
||||
"""Provides methods to save and get tensors, including image, noise, masks, and masked images."""
|
||||
self.conditioning = conditioning
|
||||
"""Provides methods to save and get conditioning data."""
|
||||
self.models = models
|
||||
"""Provides methods to check if a model exists, get a model, and get a model's info."""
|
||||
self.logger = logger
|
||||
"""Provides access to the app logger."""
|
||||
self.config = config
|
||||
"""Provides access to the app's config."""
|
||||
self.util = util
|
||||
"""Provides utility methods."""
|
||||
self.boards = boards
|
||||
"""Provides methods to interact with boards."""
|
||||
self._data = context_data
|
||||
"""Provides data about the current queue item and invocation. This is an internal API and may change without warning."""
|
||||
self._services = services
|
||||
"""Provides access to the full application services. This is an internal API and may change without warning."""
|
||||
|
||||
|
||||
def build_invocation_context(
|
||||
services: InvocationServices,
|
||||
context_data: InvocationContextData,
|
||||
) -> InvocationContext:
|
||||
"""
|
||||
Builds the invocation context for a specific invocation execution.
|
||||
|
||||
:param invocation_services: The invocation services to wrap.
|
||||
:param invocation_context_data: The invocation context data.
|
||||
"""
|
||||
|
||||
logger = LoggerInterface(services=services, context_data=context_data)
|
||||
images = ImagesInterface(services=services, context_data=context_data)
|
||||
tensors = TensorsInterface(services=services, context_data=context_data)
|
||||
models = ModelsInterface(services=services, context_data=context_data)
|
||||
config = ConfigInterface(services=services, context_data=context_data)
|
||||
util = UtilInterface(services=services, context_data=context_data)
|
||||
conditioning = ConditioningInterface(services=services, context_data=context_data)
|
||||
boards = BoardsInterface(services=services, context_data=context_data)
|
||||
|
||||
ctx = InvocationContext(
|
||||
images=images,
|
||||
logger=logger,
|
||||
config=config,
|
||||
tensors=tensors,
|
||||
models=models,
|
||||
context_data=context_data,
|
||||
util=util,
|
||||
conditioning=conditioning,
|
||||
services=services,
|
||||
boards=boards,
|
||||
)
|
||||
|
||||
return ctx
|
@ -1,67 +0,0 @@
|
||||
class FieldDescriptions:
|
||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||
cfg_scale = "Classifier-Free Guidance scale"
|
||||
cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR"
|
||||
scheduler = "Scheduler to use during inference"
|
||||
positive_cond = "Positive conditioning tensor"
|
||||
negative_cond = "Negative conditioning tensor"
|
||||
noise = "Noise tensor"
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
vae = "VAE"
|
||||
cond = "Conditioning tensor"
|
||||
controlnet_model = "ControlNet model to load"
|
||||
vae_model = "VAE model to load"
|
||||
lora_model = "LoRA model to load"
|
||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||
raw_prompt = "Raw prompt text (no parsing)"
|
||||
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
|
||||
skipped_layers = "Number of layers to skip in text encoder"
|
||||
seed = "Seed for random number generation"
|
||||
steps = "Number of steps to run"
|
||||
width = "Width of output (px)"
|
||||
height = "Height of output (px)"
|
||||
control = "ControlNet(s) to apply"
|
||||
ip_adapter = "IP-Adapter to apply"
|
||||
t2i_adapter = "T2I-Adapter(s) to apply"
|
||||
denoised_latents = "Denoised latents tensor"
|
||||
latents = "Latents tensor"
|
||||
strength = "Strength of denoising (proportional to steps)"
|
||||
metadata = "Optional metadata to be saved with the image"
|
||||
metadata_collection = "Collection of Metadata"
|
||||
metadata_item_polymorphic = "A single metadata item or collection of metadata items"
|
||||
metadata_item_label = "Label for this metadata item"
|
||||
metadata_item_value = "The value for this metadata item (may be any type)"
|
||||
workflow = "Optional workflow to be saved with the image"
|
||||
interp_mode = "Interpolation mode"
|
||||
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||
fp32 = "Whether or not to use full float32 precision"
|
||||
precision = "Precision to use"
|
||||
tiled = "Processing using overlapping tiles (reduce memory consumption)"
|
||||
detect_res = "Pixel resolution for detection"
|
||||
image_res = "Pixel resolution for output image"
|
||||
safe_mode = "Whether or not to use safe mode"
|
||||
scribble_mode = "Whether or not to use scribble mode"
|
||||
scale_factor = "The factor by which to scale"
|
||||
blend_alpha = (
|
||||
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
|
||||
)
|
||||
num_1 = "The first number"
|
||||
num_2 = "The second number"
|
||||
mask = "The mask to use for the operation"
|
||||
board = "The board to save the image to"
|
||||
image = "The image to process"
|
||||
tile_size = "Tile size"
|
||||
inclusive_low = "The inclusive low value"
|
||||
exclusive_high = "The exclusive high value"
|
||||
decimal_places = "The number of decimal places to round to"
|
||||
freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
|
||||
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
|
||||
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
|
||||
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
|
@ -1,6 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.invocations.fields import FieldDescriptions
|
||||
|
||||
|
||||
class FreeUConfig(BaseModel):
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
@ -6,7 +8,11 @@ from invokeai.app.services.invocation_processor.invocation_processor_common impo
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
||||
@ -25,13 +31,13 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=
|
||||
|
||||
|
||||
def stable_diffusion_step_callback(
|
||||
context: InvocationContext,
|
||||
context_data: "InvocationContextData",
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
base_model: BaseModelType,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
invocation_queue: "InvocationQueueABC",
|
||||
events: "EventServiceBase",
|
||||
) -> None:
|
||||
if invocation_queue.is_canceled(context_data.session_id):
|
||||
raise CanceledException
|
||||
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
@ -108,13 +114,13 @@ def stable_diffusion_step_callback(
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
events.emit_generator_progress(
|
||||
queue_id=context_data.queue_id,
|
||||
queue_item_id=context_data.queue_item_id,
|
||||
queue_batch_id=context_data.batch_id,
|
||||
graph_execution_state_id=context_data.session_id,
|
||||
node_id=context_data.invocation.id,
|
||||
source_node_id=context_data.source_node_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=intermediate_state.step,
|
||||
order=intermediate_state.order,
|
||||
|
@ -1,5 +1,12 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend
|
||||
"""
|
||||
from .model_management import BaseModelType, ModelCache, ModelInfo, ModelManager, ModelType, SubModelType # noqa: F401
|
||||
from .model_management import ( # noqa: F401
|
||||
BaseModelType,
|
||||
LoadedModelInfo,
|
||||
ModelCache,
|
||||
ModelManager,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from .model_management.models import SilenceWarnings # noqa: F401
|
||||
|
@ -3,7 +3,7 @@
|
||||
Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
# This import must be first
|
||||
from .model_manager import AddModelResult, ModelInfo, ModelManager, SchedulerPredictionType
|
||||
from .model_manager import AddModelResult, LoadedModelInfo, ModelManager, SchedulerPredictionType
|
||||
from .lora import ModelPatcher, ONNXModelPatcher
|
||||
from .model_cache import ModelCache
|
||||
|
||||
|
@ -271,7 +271,7 @@ CONFIG_FILE_VERSION = "3.0.0"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
class LoadedModelInfo:
|
||||
context: ModelLocker
|
||||
name: str
|
||||
base_model: BaseModelType
|
||||
@ -450,7 +450,7 @@ class ModelManager(object):
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> ModelInfo:
|
||||
) -> LoadedModelInfo:
|
||||
"""Given a model named identified in models.yaml, return
|
||||
an ModelInfo object describing it.
|
||||
:param model_name: symbolic name of the model in models.yaml
|
||||
@ -508,7 +508,7 @@ class ModelManager(object):
|
||||
|
||||
model_hash = "<NO_HASH>" # TODO:
|
||||
|
||||
return ModelInfo(
|
||||
return LoadedModelInfo(
|
||||
context=model_context,
|
||||
name=model_name,
|
||||
base_model=base_model,
|
||||
|
@ -32,6 +32,11 @@ class BasicConditioningInfo:
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[BasicConditioningInfo]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
pooled_embeds: torch.Tensor
|
||||
|
@ -3,7 +3,7 @@ from typing import Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.backend.tiles.utils import TBLR, Tile, paste, seam_blend
|
||||
|
||||
|
||||
|
@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
from invokeai.backend.model_management.model_manager import ModelInfo
|
||||
from invokeai.backend.model_management.model_manager import LoadedModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType
|
||||
|
||||
|
||||
@ -34,8 +34,8 @@ def install_and_load_model(
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> ModelInfo:
|
||||
"""Install a model if it is not already installed, then get the ModelInfo for that model.
|
||||
) -> LoadedModelInfo:
|
||||
"""Install a model if it is not already installed, then get the LoadedModelInfo for that model.
|
||||
|
||||
This is intended as a utility function for tests.
|
||||
|
||||
@ -49,9 +49,9 @@ def install_and_load_model(
|
||||
submodel_type (Optional[SubModelType]): The submodel type, forwarded to ModelManager.get_model(...).
|
||||
|
||||
Returns:
|
||||
ModelInfo
|
||||
LoadedModelInfo
|
||||
"""
|
||||
# If the requested model is already installed, return its ModelInfo.
|
||||
# If the requested model is already installed, return its LoadedModelInfo.
|
||||
with contextlib.suppress(ModelNotFoundException):
|
||||
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
|
||||
|
||||
|
@ -4,7 +4,7 @@ import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
import { LINEAR_UI_OUTPUT, nodeIDDenyList } from 'features/nodes/util/graph/constants';
|
||||
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { imagesAdapter } from 'services/api/util';
|
||||
@ -24,10 +24,9 @@ export const addInvocationCompleteEventListener = () => {
|
||||
const { data } = action.payload;
|
||||
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
|
||||
|
||||
const { result, node, queue_batch_id, source_node_id } = data;
|
||||
|
||||
const { result, node, queue_batch_id } = data;
|
||||
// This complete event has an associated image output
|
||||
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type) && !nodeIDDenyList.includes(source_node_id)) {
|
||||
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
|
||||
const { image_name } = result.image;
|
||||
const { canvas, gallery } = getState();
|
||||
|
||||
@ -42,7 +41,7 @@ export const addInvocationCompleteEventListener = () => {
|
||||
imageDTORequest.unsubscribe();
|
||||
|
||||
// Add canvas images to the staging area
|
||||
if (canvas.batchIds.includes(queue_batch_id) && [LINEAR_UI_OUTPUT].includes(data.source_node_id)) {
|
||||
if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) {
|
||||
dispatch(addImageToStagingArea(imageDTO));
|
||||
}
|
||||
|
||||
|
@ -39,16 +39,12 @@ export const addUpscaleRequestedListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const { esrganModelName } = state.postprocessing;
|
||||
const { autoAddBoardId } = state.gallery;
|
||||
|
||||
const enqueueBatchArg: BatchConfig = {
|
||||
prepend: true,
|
||||
batch: {
|
||||
graph: buildAdHocUpscaleGraph({
|
||||
image_name,
|
||||
esrganModelName,
|
||||
autoAddBoardId,
|
||||
state,
|
||||
}),
|
||||
runs: 1,
|
||||
},
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { roundToMultiple } from 'common/util/roundDownToMultiple';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import type {
|
||||
DenoiseLatentsInvocation,
|
||||
@ -314,11 +315,16 @@ export const addHrfToGraph = (state: RootState, graph: NonNullableGraph): void =
|
||||
);
|
||||
copyConnectionsToDenoiseLatentsHrf(graph);
|
||||
|
||||
// The original l2i node is unnecessary now, remove it
|
||||
graph.edges = graph.edges.filter((edge) => edge.destination.node_id !== LATENTS_TO_IMAGE);
|
||||
delete graph.nodes[LATENTS_TO_IMAGE];
|
||||
|
||||
graph.nodes[LATENTS_TO_IMAGE_HRF_HR] = {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE_HRF_HR,
|
||||
fp32: originalLatentsToImageNode?.fp32,
|
||||
is_intermediate: true,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
};
|
||||
graph.edges.push(
|
||||
{
|
||||
|
@ -1,78 +0,0 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import type { LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import {
|
||||
CANVAS_OUTPUT,
|
||||
LATENTS_TO_IMAGE,
|
||||
LATENTS_TO_IMAGE_HRF_HR,
|
||||
LINEAR_UI_OUTPUT,
|
||||
NSFW_CHECKER,
|
||||
WATERMARKER,
|
||||
} from './constants';
|
||||
|
||||
/**
|
||||
* Set the `use_cache` field on the linear/canvas graph's final image output node to False.
|
||||
*/
|
||||
export const addLinearUIOutputNode = (state: RootState, graph: NonNullableGraph): void => {
|
||||
const activeTabName = activeTabNameSelector(state);
|
||||
const is_intermediate = activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false;
|
||||
const { autoAddBoardId } = state.gallery;
|
||||
|
||||
const linearUIOutputNode: LinearUIOutputInvocation = {
|
||||
id: LINEAR_UI_OUTPUT,
|
||||
type: 'linear_ui_output',
|
||||
is_intermediate,
|
||||
use_cache: false,
|
||||
board: autoAddBoardId === 'none' ? undefined : { board_id: autoAddBoardId },
|
||||
};
|
||||
|
||||
graph.nodes[LINEAR_UI_OUTPUT] = linearUIOutputNode;
|
||||
|
||||
const destination = {
|
||||
node_id: LINEAR_UI_OUTPUT,
|
||||
field: 'image',
|
||||
};
|
||||
|
||||
if (WATERMARKER in graph.nodes) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: WATERMARKER,
|
||||
field: 'image',
|
||||
},
|
||||
destination,
|
||||
});
|
||||
} else if (NSFW_CHECKER in graph.nodes) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: NSFW_CHECKER,
|
||||
field: 'image',
|
||||
},
|
||||
destination,
|
||||
});
|
||||
} else if (CANVAS_OUTPUT in graph.nodes) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: CANVAS_OUTPUT,
|
||||
field: 'image',
|
||||
},
|
||||
destination,
|
||||
});
|
||||
} else if (LATENTS_TO_IMAGE_HRF_HR in graph.nodes) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: LATENTS_TO_IMAGE_HRF_HR,
|
||||
field: 'image',
|
||||
},
|
||||
destination,
|
||||
});
|
||||
} else if (LATENTS_TO_IMAGE in graph.nodes) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'image',
|
||||
},
|
||||
destination,
|
||||
});
|
||||
}
|
||||
};
|
@ -2,6 +2,7 @@ import type { RootState } from 'app/store/store';
|
||||
import type { ImageNSFWBlurInvocation, LatentsToImageInvocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants';
|
||||
import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
|
||||
|
||||
export const addNSFWCheckerToGraph = (
|
||||
state: RootState,
|
||||
@ -21,7 +22,8 @@ export const addNSFWCheckerToGraph = (
|
||||
const nsfwCheckerNode: ImageNSFWBlurInvocation = {
|
||||
id: NSFW_CHECKER,
|
||||
type: 'img_nsfw',
|
||||
is_intermediate: true,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
};
|
||||
|
||||
graph.nodes[NSFW_CHECKER] = nsfwCheckerNode as ImageNSFWBlurInvocation;
|
||||
|
@ -24,7 +24,7 @@ import {
|
||||
SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
} from './constants';
|
||||
import { getSDXLStylePrompts } from './getSDXLStylePrompt';
|
||||
import { getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { upsertMetadata } from './metadata';
|
||||
|
||||
export const addSDXLRefinerToGraph = (
|
||||
|
@ -1,5 +1,4 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import type {
|
||||
ImageNSFWBlurInvocation,
|
||||
ImageWatermarkInvocation,
|
||||
@ -8,16 +7,13 @@ import type {
|
||||
} from 'services/api/types';
|
||||
|
||||
import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants';
|
||||
import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
|
||||
|
||||
export const addWatermarkerToGraph = (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
nodeIdToAddTo = LATENTS_TO_IMAGE
|
||||
): void => {
|
||||
const activeTabName = activeTabNameSelector(state);
|
||||
|
||||
const is_intermediate = activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false;
|
||||
|
||||
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined;
|
||||
|
||||
const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as ImageNSFWBlurInvocation | undefined;
|
||||
@ -30,7 +26,8 @@ export const addWatermarkerToGraph = (
|
||||
const watermarkerNode: ImageWatermarkInvocation = {
|
||||
id: WATERMARKER,
|
||||
type: 'img_watermark',
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
};
|
||||
|
||||
graph.nodes[WATERMARKER] = watermarkerNode;
|
||||
|
@ -1,51 +1,33 @@
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import type { ParamESRGANModelName } from 'features/parameters/store/postprocessingSlice';
|
||||
import type { ESRGANInvocation, Graph, LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ESRGANInvocation, Graph, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { ESRGAN, LINEAR_UI_OUTPUT } from './constants';
|
||||
import { ESRGAN } from './constants';
|
||||
import { addCoreMetadataNode, upsertMetadata } from './metadata';
|
||||
|
||||
type Arg = {
|
||||
image_name: string;
|
||||
esrganModelName: ParamESRGANModelName;
|
||||
autoAddBoardId: BoardId;
|
||||
state: RootState;
|
||||
};
|
||||
|
||||
export const buildAdHocUpscaleGraph = ({ image_name, esrganModelName, autoAddBoardId }: Arg): Graph => {
|
||||
export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => {
|
||||
const { esrganModelName } = state.postprocessing;
|
||||
|
||||
const realesrganNode: ESRGANInvocation = {
|
||||
id: ESRGAN,
|
||||
type: 'esrgan',
|
||||
image: { image_name },
|
||||
model_name: esrganModelName,
|
||||
is_intermediate: true,
|
||||
};
|
||||
|
||||
const linearUIOutputNode: LinearUIOutputInvocation = {
|
||||
id: LINEAR_UI_OUTPUT,
|
||||
type: 'linear_ui_output',
|
||||
use_cache: false,
|
||||
is_intermediate: false,
|
||||
board: autoAddBoardId === 'none' ? undefined : { board_id: autoAddBoardId },
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
};
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
id: `adhoc-esrgan-graph`,
|
||||
nodes: {
|
||||
[ESRGAN]: realesrganNode,
|
||||
[LINEAR_UI_OUTPUT]: linearUIOutputNode,
|
||||
},
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: ESRGAN,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: LINEAR_UI_OUTPUT,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
],
|
||||
edges: [],
|
||||
};
|
||||
|
||||
addCoreMetadataNode(graph, {}, ESRGAN);
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
@ -132,7 +132,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
|
||||
[CANVAS_OUTPUT]: {
|
||||
type: 'l2i',
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
@ -242,7 +243,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
|
||||
graph.nodes[CANVAS_OUTPUT] = {
|
||||
id: CANVAS_OUTPUT,
|
||||
type: 'img_resize',
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
width: width,
|
||||
height: height,
|
||||
use_cache: false,
|
||||
@ -284,7 +286,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
|
||||
graph.nodes[CANVAS_OUTPUT] = {
|
||||
type: 'l2i',
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
fp32,
|
||||
use_cache: false,
|
||||
};
|
||||
@ -355,7 +358,5 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
|
||||
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
|
||||
}
|
||||
|
||||
addLinearUIOutputNode(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type {
|
||||
CreateDenoiseMaskInvocation,
|
||||
ImageBlurInvocation,
|
||||
@ -12,7 +13,6 @@ import type {
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
@ -191,7 +191,8 @@ export const buildCanvasInpaintGraph = (
|
||||
[CANVAS_OUTPUT]: {
|
||||
type: 'color_correct',
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
reference: canvasInitImage,
|
||||
use_cache: false,
|
||||
},
|
||||
@ -663,7 +664,5 @@ export const buildCanvasInpaintGraph = (
|
||||
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
|
||||
}
|
||||
|
||||
addLinearUIOutputNode(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type {
|
||||
ImageDTO,
|
||||
ImageToLatentsInvocation,
|
||||
@ -11,7 +12,6 @@ import type {
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
@ -200,7 +200,8 @@ export const buildCanvasOutpaintGraph = (
|
||||
[CANVAS_OUTPUT]: {
|
||||
type: 'color_correct',
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
@ -769,7 +770,5 @@ export const buildCanvasOutpaintGraph = (
|
||||
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
|
||||
}
|
||||
|
||||
addLinearUIOutputNode(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -4,7 +4,6 @@ import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'servi
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
@ -26,7 +25,7 @@ import {
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getSDXLStylePrompts } from './getSDXLStylePrompt';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
|
||||
/**
|
||||
@ -246,7 +245,8 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
|
||||
graph.nodes[CANVAS_OUTPUT] = {
|
||||
id: CANVAS_OUTPUT,
|
||||
type: 'img_resize',
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
width: width,
|
||||
height: height,
|
||||
use_cache: false,
|
||||
@ -368,7 +368,5 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
|
||||
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
|
||||
}
|
||||
|
||||
addLinearUIOutputNode(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -12,7 +12,6 @@ import type {
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
@ -44,7 +43,7 @@ import {
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getSDXLStylePrompts } from './getSDXLStylePrompt';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Inpaint graph.
|
||||
@ -190,7 +189,8 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
[CANVAS_OUTPUT]: {
|
||||
type: 'color_correct',
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
reference: canvasInitImage,
|
||||
use_cache: false,
|
||||
},
|
||||
@ -687,7 +687,5 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
|
||||
}
|
||||
|
||||
addLinearUIOutputNode(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -11,7 +11,6 @@ import type {
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
@ -46,7 +45,7 @@ import {
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getSDXLStylePrompts } from './getSDXLStylePrompt';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Outpaint graph.
|
||||
@ -199,7 +198,8 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
[CANVAS_OUTPUT]: {
|
||||
type: 'color_correct',
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
@ -786,7 +786,5 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
|
||||
}
|
||||
|
||||
addLinearUIOutputNode(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -4,7 +4,6 @@ import type { NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
@ -24,7 +23,7 @@ import {
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getSDXLStylePrompts } from './getSDXLStylePrompt';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
|
||||
/**
|
||||
@ -222,7 +221,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
graph.nodes[CANVAS_OUTPUT] = {
|
||||
id: CANVAS_OUTPUT,
|
||||
type: 'img_resize',
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
width: width,
|
||||
height: height,
|
||||
use_cache: false,
|
||||
@ -254,7 +254,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
graph.nodes[CANVAS_OUTPUT] = {
|
||||
type: 'l2i',
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
fp32,
|
||||
use_cache: false,
|
||||
};
|
||||
@ -330,7 +331,5 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
|
||||
}
|
||||
|
||||
addLinearUIOutputNode(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
@ -211,7 +211,8 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
graph.nodes[CANVAS_OUTPUT] = {
|
||||
id: CANVAS_OUTPUT,
|
||||
type: 'img_resize',
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
width: width,
|
||||
height: height,
|
||||
use_cache: false,
|
||||
@ -243,7 +244,8 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
graph.nodes[CANVAS_OUTPUT] = {
|
||||
type: 'l2i',
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
fp32,
|
||||
use_cache: false,
|
||||
};
|
||||
@ -310,7 +312,5 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
|
||||
}
|
||||
|
||||
addLinearUIOutputNode(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
@ -117,7 +117,8 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32,
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
},
|
||||
[DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
@ -358,7 +359,5 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
|
||||
addWatermarkerToGraph(state, graph);
|
||||
}
|
||||
|
||||
addLinearUIOutputNode(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -4,7 +4,6 @@ import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
@ -25,7 +24,7 @@ import {
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getSDXLStylePrompts } from './getSDXLStylePrompt';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
|
||||
/**
|
||||
@ -120,7 +119,8 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32,
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
},
|
||||
[SDXL_DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
@ -380,7 +380,5 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
|
||||
addWatermarkerToGraph(state, graph);
|
||||
}
|
||||
|
||||
addLinearUIOutputNode(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -4,7 +4,6 @@ import type { NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
@ -23,7 +22,7 @@ import {
|
||||
SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getSDXLStylePrompts } from './getSDXLStylePrompt';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
|
||||
export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
|
||||
@ -120,7 +119,8 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32,
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
@ -281,7 +281,5 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
addWatermarkerToGraph(state, graph);
|
||||
}
|
||||
|
||||
addLinearUIOutputNode(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -1,11 +1,11 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addHrfToGraph } from './addHrfToGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
import { addLinearUIOutputNode } from './addLinearUIOutputNode';
|
||||
import { addLoRAsToGraph } from './addLoRAsToGraph';
|
||||
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
|
||||
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
|
||||
@ -119,7 +119,8 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32,
|
||||
is_intermediate,
|
||||
is_intermediate: getIsIntermediate(state),
|
||||
board: getBoardField(state),
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
@ -267,7 +268,5 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
addWatermarkerToGraph(state, graph);
|
||||
}
|
||||
|
||||
addLinearUIOutputNode(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
||||
|
@ -9,7 +9,6 @@ export const LATENTS_TO_IMAGE_HRF_LR = 'latents_to_image_hrf_lr';
|
||||
export const IMAGE_TO_LATENTS_HRF = 'image_to_latents_hrf';
|
||||
export const RESIZE_HRF = 'resize_hrf';
|
||||
export const ESRGAN_HRF = 'esrgan_hrf';
|
||||
export const LINEAR_UI_OUTPUT = 'linear_ui_output';
|
||||
export const NSFW_CHECKER = 'nsfw_checker';
|
||||
export const WATERMARKER = 'invisible_watermark';
|
||||
export const NOISE = 'noise';
|
||||
|
@ -1,11 +0,0 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
|
||||
export const getSDXLStylePrompts = (state: RootState): { positiveStylePrompt: string; negativeStylePrompt: string } => {
|
||||
const { positivePrompt, negativePrompt } = state.generation;
|
||||
const { positiveStylePrompt, negativeStylePrompt, shouldConcatSDXLStylePrompt } = state.sdxl;
|
||||
|
||||
return {
|
||||
positiveStylePrompt: shouldConcatSDXLStylePrompt ? positivePrompt : positiveStylePrompt,
|
||||
negativeStylePrompt: shouldConcatSDXLStylePrompt ? negativePrompt : negativeStylePrompt,
|
||||
};
|
||||
};
|
@ -0,0 +1,38 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { BoardField } from 'features/nodes/types/common';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
|
||||
/**
|
||||
* Gets the board field, based on the autoAddBoardId setting.
|
||||
*/
|
||||
export const getBoardField = (state: RootState): BoardField | undefined => {
|
||||
const { autoAddBoardId } = state.gallery;
|
||||
if (autoAddBoardId === 'none') {
|
||||
return undefined;
|
||||
}
|
||||
return { board_id: autoAddBoardId };
|
||||
};
|
||||
|
||||
/**
|
||||
* Gets the SDXL style prompts, based on the concat setting.
|
||||
*/
|
||||
export const getSDXLStylePrompts = (state: RootState): { positiveStylePrompt: string; negativeStylePrompt: string } => {
|
||||
const { positivePrompt, negativePrompt } = state.generation;
|
||||
const { positiveStylePrompt, negativeStylePrompt, shouldConcatSDXLStylePrompt } = state.sdxl;
|
||||
|
||||
return {
|
||||
positiveStylePrompt: shouldConcatSDXLStylePrompt ? positivePrompt : positiveStylePrompt,
|
||||
negativeStylePrompt: shouldConcatSDXLStylePrompt ? negativePrompt : negativeStylePrompt,
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Gets the is_intermediate field, based on the active tab and shouldAutoSave setting.
|
||||
*/
|
||||
export const getIsIntermediate = (state: RootState) => {
|
||||
const activeTabName = activeTabNameSelector(state);
|
||||
if (activeTabName === 'unifiedCanvas') {
|
||||
return !state.canvas.shouldAutoSave;
|
||||
}
|
||||
return false;
|
||||
};
|
File diff suppressed because one or more lines are too long
@ -132,7 +132,6 @@ export type DivideInvocation = s['DivideInvocation'];
|
||||
export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation'];
|
||||
export type ImageWatermarkInvocation = s['ImageWatermarkInvocation'];
|
||||
export type SeamlessModeInvocation = s['SeamlessModeInvocation'];
|
||||
export type LinearUIOutputInvocation = s['LinearUIOutputInvocation'];
|
||||
export type MetadataInvocation = s['MetadataInvocation'];
|
||||
export type CoreMetadataInvocation = s['CoreMetadataInvocation'];
|
||||
export type MetadataItemInvocation = s['MetadataItemInvocation'];
|
||||
|
187
invokeai/invocation_api/__init__.py
Normal file
187
invokeai/invocation_api/__init__.py
Normal file
@ -0,0 +1,187 @@
|
||||
"""
|
||||
This file re-exports all the public API for invocations. This is the only file that should be imported by custom nodes.
|
||||
|
||||
TODO(psyche): Do we want to dogfood this?
|
||||
"""
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.fields import (
|
||||
BoardField,
|
||||
ColorField,
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FieldKind,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
MetadataField,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
UIType,
|
||||
WithMetadata,
|
||||
WithWorkflow,
|
||||
)
|
||||
from invokeai.app.invocations.latent import SchedulerOutput
|
||||
from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOutput, MetadataOutput
|
||||
from invokeai.app.invocations.model import (
|
||||
ClipField,
|
||||
CLIPOutput,
|
||||
LoraInfo,
|
||||
LoraLoaderOutput,
|
||||
LoRAModelField,
|
||||
MainModelField,
|
||||
ModelInfo,
|
||||
ModelLoaderOutput,
|
||||
SDXLLoraLoaderOutput,
|
||||
UNetField,
|
||||
UNetOutput,
|
||||
VaeField,
|
||||
VAEModelField,
|
||||
VAEOutput,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import (
|
||||
BooleanCollectionOutput,
|
||||
BooleanOutput,
|
||||
ColorCollectionOutput,
|
||||
ColorOutput,
|
||||
ConditioningCollectionOutput,
|
||||
ConditioningOutput,
|
||||
DenoiseMaskOutput,
|
||||
FloatCollectionOutput,
|
||||
FloatOutput,
|
||||
ImageCollectionOutput,
|
||||
ImageOutput,
|
||||
IntegerCollectionOutput,
|
||||
IntegerOutput,
|
||||
LatentsCollectionOutput,
|
||||
LatentsOutput,
|
||||
StringCollectionOutput,
|
||||
StringOutput,
|
||||
)
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.model_management.model_manager import LoadedModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
ExtraConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device
|
||||
from invokeai.version import __version__
|
||||
|
||||
__all__ = [
|
||||
# invokeai.app.invocations.baseinvocation
|
||||
"BaseInvocation",
|
||||
"BaseInvocationOutput",
|
||||
"Classification",
|
||||
"invocation",
|
||||
"invocation_output",
|
||||
# invokeai.app.services.shared.invocation_context
|
||||
"InvocationContext",
|
||||
# invokeai.app.invocations.fields
|
||||
"BoardField",
|
||||
"ColorField",
|
||||
"ConditioningField",
|
||||
"DenoiseMaskField",
|
||||
"FieldDescriptions",
|
||||
"FieldKind",
|
||||
"ImageField",
|
||||
"Input",
|
||||
"InputField",
|
||||
"LatentsField",
|
||||
"MetadataField",
|
||||
"OutputField",
|
||||
"UIComponent",
|
||||
"UIType",
|
||||
"WithMetadata",
|
||||
"WithWorkflow",
|
||||
# invokeai.app.invocations.latent
|
||||
"SchedulerOutput",
|
||||
# invokeai.app.invocations.metadata
|
||||
"MetadataItemField",
|
||||
"MetadataItemOutput",
|
||||
"MetadataOutput",
|
||||
# invokeai.app.invocations.model
|
||||
"ModelInfo",
|
||||
"LoraInfo",
|
||||
"UNetField",
|
||||
"ClipField",
|
||||
"VaeField",
|
||||
"MainModelField",
|
||||
"LoRAModelField",
|
||||
"VAEModelField",
|
||||
"UNetOutput",
|
||||
"VAEOutput",
|
||||
"CLIPOutput",
|
||||
"ModelLoaderOutput",
|
||||
"LoraLoaderOutput",
|
||||
"SDXLLoraLoaderOutput",
|
||||
# invokeai.app.invocations.primitives
|
||||
"BooleanCollectionOutput",
|
||||
"BooleanOutput",
|
||||
"ColorCollectionOutput",
|
||||
"ColorOutput",
|
||||
"ConditioningCollectionOutput",
|
||||
"ConditioningOutput",
|
||||
"DenoiseMaskOutput",
|
||||
"FloatCollectionOutput",
|
||||
"FloatOutput",
|
||||
"ImageCollectionOutput",
|
||||
"ImageOutput",
|
||||
"IntegerCollectionOutput",
|
||||
"IntegerOutput",
|
||||
"LatentsCollectionOutput",
|
||||
"LatentsOutput",
|
||||
"StringCollectionOutput",
|
||||
"StringOutput",
|
||||
# invokeai.app.services.image_records.image_records_common
|
||||
"ImageCategory",
|
||||
# invokeai.app.services.boards.boards_common
|
||||
"BoardDTO",
|
||||
# invokeai.backend.stable_diffusion.diffusion.conditioning_data
|
||||
"BasicConditioningInfo",
|
||||
"ConditioningFieldData",
|
||||
"ExtraConditioningInfo",
|
||||
"SDXLConditioningInfo",
|
||||
# invokeai.backend.stable_diffusion.diffusers_pipeline
|
||||
"PipelineIntermediateState",
|
||||
# invokeai.app.services.workflow_records.workflow_records_common
|
||||
"WorkflowWithoutID",
|
||||
# invokeai.app.services.config.config_default
|
||||
"InvokeAIAppConfig",
|
||||
# invokeai.backend.model_management.model_manager
|
||||
"LoadedModelInfo",
|
||||
# invokeai.backend.model_management.models.base
|
||||
"BaseModelType",
|
||||
"ModelType",
|
||||
"SubModelType",
|
||||
# invokeai.app.invocations.constants
|
||||
"SCHEDULER_NAME_VALUES",
|
||||
# invokeai.version
|
||||
"__version__",
|
||||
# invokeai.backend.util.devices
|
||||
"choose_precision",
|
||||
"choose_torch_device",
|
||||
"CPU_DEVICE",
|
||||
"CUDA_DEVICE",
|
||||
"MPS_DEVICE",
|
||||
# invokeai.app.util.misc
|
||||
"SEED_MAX",
|
||||
"get_random_seed",
|
||||
]
|
@ -66,6 +66,7 @@ dependencies = [
|
||||
"albumentations",
|
||||
"click",
|
||||
"datasets",
|
||||
"Deprecated",
|
||||
"dnspython~=2.4.0",
|
||||
"dynamicprompts",
|
||||
"easing-functions",
|
||||
@ -169,6 +170,7 @@ version = { attr = "invokeai.version.__version__" }
|
||||
"invokeai.frontend.web.static*",
|
||||
"invokeai.configs*",
|
||||
"invokeai.app*",
|
||||
"invokeai.invocation_api*",
|
||||
]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
@ -280,3 +282,38 @@ module = [
|
||||
"invokeai.frontend.install.model_install",
|
||||
]
|
||||
#=== End: MyPy
|
||||
|
||||
[tool.pyright]
|
||||
# Start from strict mode
|
||||
typeCheckingMode = "strict"
|
||||
# This errors whenever an import is missing a type stub file - way too noisy
|
||||
reportMissingTypeStubs = "none"
|
||||
# These are the rest of the rules enabled by strict mode - enable them @ warning
|
||||
reportConstantRedefinition = "warning"
|
||||
reportDeprecated = "warning"
|
||||
reportDuplicateImport = "warning"
|
||||
reportIncompleteStub = "warning"
|
||||
reportInconsistentConstructor = "warning"
|
||||
reportInvalidStubStatement = "warning"
|
||||
reportMatchNotExhaustive = "warning"
|
||||
reportMissingParameterType = "warning"
|
||||
reportMissingTypeArgument = "warning"
|
||||
reportPrivateUsage = "warning"
|
||||
reportTypeCommentUsage = "warning"
|
||||
reportUnknownArgumentType = "warning"
|
||||
reportUnknownLambdaType = "warning"
|
||||
reportUnknownMemberType = "warning"
|
||||
reportUnknownParameterType = "warning"
|
||||
reportUnknownVariableType = "warning"
|
||||
reportUnnecessaryCast = "warning"
|
||||
reportUnnecessaryComparison = "warning"
|
||||
reportUnnecessaryContains = "warning"
|
||||
reportUnnecessaryIsInstance = "warning"
|
||||
reportUnusedClass = "warning"
|
||||
reportUnusedImport = "warning"
|
||||
reportUnusedFunction = "warning"
|
||||
reportUnusedVariable = "warning"
|
||||
reportUntypedBaseClass = "warning"
|
||||
reportUntypedClassDecorator = "warning"
|
||||
reportUntypedFunctionDecorator = "warning"
|
||||
reportUntypedNamedTuple = "warning"
|
||||
|
@ -21,7 +21,6 @@ from invokeai.app.services.invocation_processor.invocation_processor_default imp
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.shared.graph import (
|
||||
CollectInvocation,
|
||||
Graph,
|
||||
@ -61,7 +60,6 @@ def mock_services() -> InvocationServices:
|
||||
image_records=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=0),
|
||||
latents=None, # type: ignore
|
||||
logger=logging, # type: ignore
|
||||
model_manager=None, # type: ignore
|
||||
model_records=None, # type: ignore
|
||||
@ -75,6 +73,8 @@ def mock_services() -> InvocationServices:
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
tensors=None,
|
||||
conditioning=None,
|
||||
)
|
||||
|
||||
|
||||
@ -86,12 +86,16 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B
|
||||
print(f"invoking {n.id}: {type(n)}")
|
||||
o = n.invoke(
|
||||
InvocationContext(
|
||||
queue_batch_id="1",
|
||||
queue_item_id=1,
|
||||
queue_id=DEFAULT_QUEUE_ID,
|
||||
services=services,
|
||||
graph_execution_state_id="1",
|
||||
workflow=None,
|
||||
conditioning=None,
|
||||
config=None,
|
||||
context_data=None,
|
||||
images=None,
|
||||
tensors=None,
|
||||
logger=None,
|
||||
models=None,
|
||||
util=None,
|
||||
boards=None,
|
||||
services=None,
|
||||
)
|
||||
)
|
||||
g.complete(n.id, o)
|
||||
|
@ -63,7 +63,6 @@ def mock_services() -> InvocationServices:
|
||||
image_records=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
invocation_cache=MemoryInvocationCache(max_cache_size=0),
|
||||
latents=None, # type: ignore
|
||||
logger=logging, # type: ignore
|
||||
model_manager=None, # type: ignore
|
||||
model_records=None, # type: ignore
|
||||
@ -77,6 +76,8 @@ def mock_services() -> InvocationServices:
|
||||
session_queue=None, # type: ignore
|
||||
urls=None, # type: ignore
|
||||
workflow_records=None, # type: ignore
|
||||
tensors=None,
|
||||
conditioning=None,
|
||||
)
|
||||
|
||||
|
||||
|
@ -3,13 +3,12 @@ from typing import Any, Callable, Union
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import InputField, OutputField
|
||||
from invokeai.app.invocations.image import ImageField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
# Define test invocations before importing anything that uses invocations
|
||||
|
172
tests/test_object_serializer_disk.py
Normal file
172
tests/test_object_serializer_disk.py
Normal file
@ -0,0 +1,172 @@
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
|
||||
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
||||
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockDataclass:
|
||||
foo: str
|
||||
|
||||
|
||||
def count_files(path: Path):
|
||||
return len(list(path.iterdir()))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def obj_serializer(tmp_path: Path):
|
||||
return ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fwd_cache(tmp_path: Path):
|
||||
return ObjectSerializerForwardCache(ObjectSerializerDisk[MockDataclass](tmp_path), max_cache_size=2)
|
||||
|
||||
|
||||
def test_obj_serializer_disk_initializes(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||
assert obj_serializer._output_dir == tmp_path
|
||||
|
||||
|
||||
def test_obj_serializer_disk_saves(obj_serializer: ObjectSerializerDisk[MockDataclass]):
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
obj_1_name = obj_serializer.save(obj_1)
|
||||
assert Path(obj_serializer._output_dir, obj_1_name).exists()
|
||||
|
||||
obj_2 = MockDataclass(foo="baz")
|
||||
obj_2_name = obj_serializer.save(obj_2)
|
||||
assert Path(obj_serializer._output_dir, obj_2_name).exists()
|
||||
|
||||
|
||||
def test_obj_serializer_disk_loads(obj_serializer: ObjectSerializerDisk[MockDataclass]):
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
obj_1_name = obj_serializer.save(obj_1)
|
||||
assert obj_serializer.load(obj_1_name).foo == "bar"
|
||||
|
||||
obj_2 = MockDataclass(foo="baz")
|
||||
obj_2_name = obj_serializer.save(obj_2)
|
||||
assert obj_serializer.load(obj_2_name).foo == "baz"
|
||||
|
||||
with pytest.raises(ObjectNotFoundError):
|
||||
obj_serializer.load("nonexistent_object_name")
|
||||
|
||||
|
||||
def test_obj_serializer_disk_deletes(obj_serializer: ObjectSerializerDisk[MockDataclass]):
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
obj_1_name = obj_serializer.save(obj_1)
|
||||
|
||||
obj_2 = MockDataclass(foo="bar")
|
||||
obj_2_name = obj_serializer.save(obj_2)
|
||||
|
||||
obj_serializer.delete(obj_1_name)
|
||||
assert not Path(obj_serializer._output_dir, obj_1_name).exists()
|
||||
assert Path(obj_serializer._output_dir, obj_2_name).exists()
|
||||
|
||||
|
||||
def test_obj_serializer_ephemeral_creates_tempdir(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
assert isinstance(obj_serializer._tempdir, tempfile.TemporaryDirectory)
|
||||
assert obj_serializer._base_output_dir == tmp_path
|
||||
assert obj_serializer._output_dir != tmp_path
|
||||
assert obj_serializer._output_dir == Path(obj_serializer._tempdir.name)
|
||||
|
||||
|
||||
def test_obj_serializer_ephemeral_deletes_tempdir(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
tempdir_path = obj_serializer._output_dir
|
||||
del obj_serializer
|
||||
assert not tempdir_path.exists()
|
||||
|
||||
|
||||
def test_obj_serializer_ephemeral_deletes_tempdir_on_stop(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
tempdir_path = obj_serializer._output_dir
|
||||
obj_serializer.stop(None) # pyright: ignore [reportArgumentType]
|
||||
assert not tempdir_path.exists()
|
||||
|
||||
|
||||
def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
|
||||
obj_serializer = ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
obj_1_name = obj_serializer.save(obj_1)
|
||||
assert Path(obj_serializer._output_dir, obj_1_name).exists()
|
||||
assert not Path(tmp_path, obj_1_name).exists()
|
||||
|
||||
|
||||
def test_obj_serializer_disk_different_types(tmp_path: Path):
|
||||
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
obj_1_name = obj_serializer_1.save(obj_1)
|
||||
obj_1_loaded = obj_serializer_1.load(obj_1_name)
|
||||
assert obj_serializer_1._obj_class_name == "MockDataclass"
|
||||
assert isinstance(obj_1_loaded, MockDataclass)
|
||||
assert obj_1_loaded.foo == "bar"
|
||||
assert obj_1_name.startswith("MockDataclass_")
|
||||
|
||||
obj_serializer_2 = ObjectSerializerDisk[int](tmp_path)
|
||||
obj_2_name = obj_serializer_2.save(9001)
|
||||
assert obj_serializer_2._obj_class_name == "int"
|
||||
assert obj_serializer_2.load(obj_2_name) == 9001
|
||||
assert obj_2_name.startswith("int_")
|
||||
|
||||
obj_serializer_3 = ObjectSerializerDisk[str](tmp_path)
|
||||
obj_3_name = obj_serializer_3.save("foo")
|
||||
assert obj_serializer_3._obj_class_name == "str"
|
||||
assert obj_serializer_3.load(obj_3_name) == "foo"
|
||||
assert obj_3_name.startswith("str_")
|
||||
|
||||
obj_serializer_4 = ObjectSerializerDisk[torch.Tensor](tmp_path)
|
||||
obj_4_name = obj_serializer_4.save(torch.tensor([1, 2, 3]))
|
||||
obj_4_loaded = obj_serializer_4.load(obj_4_name)
|
||||
assert obj_serializer_4._obj_class_name == "Tensor"
|
||||
assert isinstance(obj_4_loaded, torch.Tensor)
|
||||
assert torch.equal(obj_4_loaded, torch.tensor([1, 2, 3]))
|
||||
assert obj_4_name.startswith("Tensor_")
|
||||
|
||||
|
||||
def test_obj_serializer_fwd_cache_initializes(obj_serializer: ObjectSerializerDisk[MockDataclass]):
|
||||
fwd_cache = ObjectSerializerForwardCache(obj_serializer)
|
||||
assert fwd_cache._underlying_storage == obj_serializer
|
||||
|
||||
|
||||
def test_obj_serializer_fwd_cache_saves_and_loads(fwd_cache: ObjectSerializerForwardCache[MockDataclass]):
|
||||
obj = MockDataclass(foo="bar")
|
||||
obj_name = fwd_cache.save(obj)
|
||||
obj_loaded = fwd_cache.load(obj_name)
|
||||
obj_underlying = fwd_cache._underlying_storage.load(obj_name)
|
||||
assert obj_loaded == obj_underlying
|
||||
assert obj_loaded.foo == "bar"
|
||||
|
||||
|
||||
def test_obj_serializer_fwd_cache_respects_cache_size(fwd_cache: ObjectSerializerForwardCache[MockDataclass]):
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
obj_1_name = fwd_cache.save(obj_1)
|
||||
obj_2 = MockDataclass(foo="baz")
|
||||
obj_2_name = fwd_cache.save(obj_2)
|
||||
obj_3 = MockDataclass(foo="qux")
|
||||
obj_3_name = fwd_cache.save(obj_3)
|
||||
assert obj_1_name not in fwd_cache._cache
|
||||
assert obj_2_name in fwd_cache._cache
|
||||
assert obj_3_name in fwd_cache._cache
|
||||
# apparently qsize is "not reliable"?
|
||||
assert fwd_cache._cache_ids.qsize() == 2
|
||||
|
||||
|
||||
def test_obj_serializer_fwd_cache_calls_delete_callback(fwd_cache: ObjectSerializerForwardCache[MockDataclass]):
|
||||
called_name = None
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
|
||||
def on_deleted(name: str):
|
||||
nonlocal called_name
|
||||
called_name = name
|
||||
|
||||
fwd_cache.on_deleted(on_deleted)
|
||||
obj_1_name = fwd_cache.save(obj_1)
|
||||
fwd_cache.delete(obj_1_name)
|
||||
assert called_name == obj_1_name
|
Reference in New Issue
Block a user