nodes phase 5: workflow saving and loading (#4353)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [x] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Description

- Workflows are saved to image files directly
- Image-outputting nodes have an `Embed Workflow` checkbox which, if
enabled, saves the workflow to the output image
- `BaseInvocation` now has a `workflow: Optional[str]` field, so all
nodes automatically have the field (but again, only image-outputting
nodes display it in the UI)
- If the checkbox is enabled, the workflow is stringified and set in this
field when the graph is created
- Nodes should pass `workflow=self.workflow` when they save their output
image to have the workflow written to the image (see the sketch below)
- Uploads now have their metadata retained, so you can upload somebody
else's image and have access to that workflow
- Graphs are no longer saved to images; workflows replace them
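
For illustration, here is a minimal sketch of an image-outputting node that passes the workflow along when saving its result, following the `images.create(..., workflow=self.workflow)` pattern used by the image processors in this PR. The node type `example_passthrough` and the class itself are hypothetical, and the `ResourceOrigin`/`ImageCategory` import path is assumed from the existing docs:

```python
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
from .image import ImageOutput
from .primitives import ImageField
from ..models.image import ImageCategory, ResourceOrigin  # assumed path, as used in the current docs


@invocation("example_passthrough")
class ExamplePassthroughInvocation(BaseInvocation):
    """Hypothetical node used only to illustrate workflow embedding."""

    image: ImageField = InputField(description="The input image")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Load the input image via the image service
        image = context.services.images.get_pil_image(self.image.image_name)

        # Passing `workflow=self.workflow` is what writes the workflow JSON
        # into the saved image (when the user enabled `Embed Workflow`)
        image_dto = context.services.images.create(
            image=image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
            workflow=self.workflow,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
```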

### TODO
- Images created in the linear UI do not have a workflow saved yet. We need
to write a function that builds a workflow around the linear UI graph when
using the linear tabs. Unfortunately it will not have the nice positioning
and size data the node editor gives you when you save a workflow...
we'll have to figure out how to handle this.

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example, having the text "closes #1234" would connect the current pull
request to issue 1234. And when we merge the pull request, GitHub will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->
Commit 2bd3cf28ea by Kent Keirsey, 2023-08-30 15:05:17 -04:00 (committed via GitHub).
103 changed files with 3434 additions and 2729 deletions.
View File
@ -29,12 +29,13 @@ The first set of things we need to do when creating a new Invocation are -
 - Create a new class that derives from a predefined parent class called
   `BaseInvocation`.
-- The name of every Invocation must end with the word `Invocation` in order for
-  it to be recognized as an Invocation.
 - Every Invocation must have a `docstring` that describes what this Invocation
   does.
-- Every Invocation must have a unique `type` field defined which becomes its
-  indentifier.
+- While not strictly required, we suggest every invocation class name ends in
+  "Invocation", eg "CropImageInvocation".
+- Every Invocation must use the `@invocation` decorator to provide its unique
+  invocation type. You may also provide its title, tags and category using the
+  decorator.
 - Invocations are strictly typed. We make use of the native
   [typing](https://docs.python.org/3/library/typing.html) library and the
   installed [pydantic](https://pydantic-docs.helpmanual.io/) library for

@ -43,12 +44,11 @@ The first set of things we need to do when creating a new Invocation are -
 So let us do that.
 ```python
-from typing import Literal
-from .baseinvocation import BaseInvocation
+from .baseinvocation import BaseInvocation, invocation
+@invocation('resize')
 class ResizeInvocation(BaseInvocation):
     '''Resizes an image'''
-    type: Literal['resize'] = 'resize'
 ```
 That's great.

@ -62,8 +62,10 @@ our Invocation takes.
 ### **Inputs**
-Every Invocation input is a pydantic `Field` and like everything else should be
-strictly typed and defined.
+Every Invocation input must be defined using the `InputField` function. This is
+a wrapper around the pydantic `Field` function, which handles a few extra things
+and provides type hints. Like everything else, this should be strictly typed and
+defined.
 So let us create these inputs for our Invocation. First up, the `image` input we
 need. Generally, we can use standard variable types in Python but InvokeAI

@ -76,55 +78,51 @@ create your own custom field types later in this guide. For now, let's go ahead
 and use it.
 ```python
-from typing import Literal, Union
-from pydantic import Field
-from .baseinvocation import BaseInvocation
-from ..models.image import ImageField
+from .baseinvocation import BaseInvocation, InputField, invocation
+from .primitives import ImageField
+@invocation('resize')
 class ResizeInvocation(BaseInvocation):
-    '''Resizes an image'''
-    type: Literal['resize'] = 'resize'
     # Inputs
-    image: Union[ImageField, None] = Field(description="The input image", default=None)
+    image: ImageField = InputField(description="The input image")
 ```
 Let us break down our input code.
 ```python
-image: Union[ImageField, None] = Field(description="The input image", default=None)
+image: ImageField = InputField(description="The input image")
 ```
 | Part | Value | Description |
-| --------- | ---------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
+| --------- | ------------------------------------------- | -------------------------------------------------------------------------------- |
 | Name | `image` | The variable that will hold our image |
-| Type Hint | `Union[ImageField, None]` | The types for our field. Indicates that the image can either be an `ImageField` type or `None` |
-| Field | `Field(description="The input image", default=None)` | The image variable is a field which needs a description and a default value that we set to `None`. |
+| Type Hint | `ImageField` | The types for our field. Indicates that the image must be an `ImageField` type. |
+| Field | `InputField(description="The input image")` | The image variable is an `InputField` which needs a description. |
 Great. Now let us create our other inputs for `width` and `height`
 ```python
-from typing import Literal, Union
-from pydantic import Field
-from .baseinvocation import BaseInvocation
-from ..models.image import ImageField
+from .baseinvocation import BaseInvocation, InputField, invocation
+from .primitives import ImageField
+@invocation('resize')
 class ResizeInvocation(BaseInvocation):
     '''Resizes an image'''
-    type: Literal['resize'] = 'resize'
     # Inputs
-    image: Union[ImageField, None] = Field(description="The input image", default=None)
-    width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
-    height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
+    image: ImageField = InputField(description="The input image")
+    width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
+    height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
 ```
-As you might have noticed, we added two new parameters to the field type for
-`width` and `height` called `gt` and `le`. These basically stand for _greater
-than or equal to_ and _less than or equal to_. There are various other param
-types for field that you can find on the **pydantic** documentation.
+As you might have noticed, we added two new arguments to the `InputField`
+definition for `width` and `height`, called `gt` and `le`. They stand for
+_greater than or equal to_ and _less than or equal to_.
+These impose contraints on those fields, and will raise an exception if the
+values do not meet the constraints. Field constraints are provided by
+**pydantic**, so anything you see in the **pydantic docs** will work.
 **Note:** _Any time it is possible to define constraints for our field, we
 should do it so the frontend has more information on how to parse this field._

@ -141,20 +139,17 @@ that are provided by it by InvokeAI.
 Let us create this function first.
 ```python
-from typing import Literal, Union
-from pydantic import Field
-from .baseinvocation import BaseInvocation, InvocationContext
-from ..models.image import ImageField
+from .baseinvocation import BaseInvocation, InputField, invocation
+from .primitives import ImageField
+@invocation('resize')
 class ResizeInvocation(BaseInvocation):
     '''Resizes an image'''
-    type: Literal['resize'] = 'resize'
     # Inputs
-    image: Union[ImageField, None] = Field(description="The input image", default=None)
-    width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
-    height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
+    image: ImageField = InputField(description="The input image")
+    width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
+    height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
     def invoke(self, context: InvocationContext):
         pass

@ -173,21 +168,18 @@ all the necessary info related to image outputs. So let us use that.
 We will cover how to create your own output types later in this guide.
 ```python
-from typing import Literal, Union
-from pydantic import Field
-from .baseinvocation import BaseInvocation, InvocationContext
-from ..models.image import ImageField
+from .baseinvocation import BaseInvocation, InputField, invocation
+from .primitives import ImageField
 from .image import ImageOutput
+@invocation('resize')
 class ResizeInvocation(BaseInvocation):
     '''Resizes an image'''
-    type: Literal['resize'] = 'resize'
     # Inputs
-    image: Union[ImageField, None] = Field(description="The input image", default=None)
-    width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
-    height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
+    image: ImageField = InputField(description="The input image")
+    width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
+    height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
     def invoke(self, context: InvocationContext) -> ImageOutput:
         pass

@ -195,39 +187,34 @@ class ResizeInvocation(BaseInvocation):
 Perfect. Now that we have our Invocation setup, let us do what we want to do.
-- We will first load the image. Generally we do this using the `PIL` library but
-  we can use one of the services provided by InvokeAI to load the image.
+- We will first load the image using one of the services provided by InvokeAI to
+  load the image.
 - We will resize the image using `PIL` to our input data.
 - We will output this image in the format we set above.
 So let's do that.
 ```python
-from typing import Literal, Union
-from pydantic import Field
-from .baseinvocation import BaseInvocation, InvocationContext
-from ..models.image import ImageField, ResourceOrigin, ImageCategory
+from .baseinvocation import BaseInvocation, InputField, invocation
+from .primitives import ImageField
 from .image import ImageOutput
+@invocation("resize")
 class ResizeInvocation(BaseInvocation):
-    '''Resizes an image'''
-    type: Literal['resize'] = 'resize'
-    # Inputs
-    image: Union[ImageField, None] = Field(description="The input image", default=None)
-    width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
-    height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
+    """Resizes an image"""
+    image: ImageField = InputField(description="The input image")
+    width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
+    height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Load the image using InvokeAI's predefined Image Service.
-        image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
+        # Load the image using InvokeAI's predefined Image Service. Returns the PIL image.
+        image = context.services.images.get_pil_image(self.image.image_name)
         # Resizing the image
-        # Because we used the above service, we already have a PIL image. So we can simply resize.
         resized_image = image.resize((self.width, self.height))
-        # Preparing the image for output using InvokeAI's predefined Image Service.
+        # Save the image using InvokeAI's predefined Image Service. Returns the prepared PIL image.
         output_image = context.services.images.create(
             image=resized_image,
             image_origin=ResourceOrigin.INTERNAL,

@ -241,7 +228,6 @@ class ResizeInvocation(BaseInvocation):
         return ImageOutput(
             image=ImageField(
                 image_name=output_image.image_name,
-                image_origin=output_image.image_origin,
             ),
             width=output_image.width,
             height=output_image.height,

@ -253,6 +239,20 @@ certain way that the images need to be dispatched in order to be stored and read
 correctly. In 99% of the cases when dealing with an image output, you can simply
 copy-paste the template above.
+### Customization
+We can use the `@invocation` decorator to provide some additional info to the
+UI, like a custom title, tags and category.
+```python
+@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations")
+class ResizeInvocation(BaseInvocation):
+    """Resizes an image"""
+    image: ImageField = InputField(description="The input image")
+    ...
+```
 That's it. You made your own **Resize Invocation**.
 ## Result

@ -271,10 +271,57 @@ new Invocation ready to be used.
 ![resize node editor](../assets/contributing/resize_node_editor.png)
 ## Contributing Nodes
-Once you've created a Node, the next step is to share it with the community! The best way to do this is to submit a Pull Request to add the Node to the [Community Nodes](nodes/communityNodes) list. If you're not sure how to do that, take a look a at our [contributing nodes overview](contributingNodes).
+Once you've created a Node, the next step is to share it with the community! The
+best way to do this is to submit a Pull Request to add the Node to the
+[Community Nodes](nodes/communityNodes) list. If you're not sure how to do that,
+take a look a at our [contributing nodes overview](contributingNodes).
 ## Advanced
--->
+### Custom Output Types
+Like with custom inputs, sometimes you might find yourself needing custom
+outputs that InvokeAI does not provide. We can easily set one up.
+Now that you are familiar with Invocations and Inputs, let us use that knowledge
+to create an output that has an `image` field, a `color` field and a `string`
+field.
+- An invocation output is a class that derives from the parent class of
+  `BaseInvocationOutput`.
+- All invocation outputs must use the `@invocation_output` decorator to provide
+  their unique output type.
+- Output fields must use the provided `OutputField` function. This is very
+  similar to the `InputField` function described earlier - it's a wrapper around
+  `pydantic`'s `Field()`.
+- It is not mandatory but we recommend using names ending with `Output` for
+  output types.
+- It is not mandatory but we highly recommend adding a `docstring` to describe
+  what your output type is for.
+Now that we know the basic rules for creating a new output type, let us go ahead
+and make it.
+```python
+from .baseinvocation import BaseInvocationOutput, OutputField, invocation_output
+from .primitives import ImageField, ColorField
+@invocation_output('image_color_string_output')
+class ImageColorStringOutput(BaseInvocationOutput):
+    '''Base class for nodes that output a single image'''
+    image: ImageField = OutputField(description="The image")
+    color: ColorField = OutputField(description="The color")
+    text: str = OutputField(description="The string")
+```
+That's all there is to it.
+<!-- TODO: DANGER - we probably do not want people to create their own field types, because this requires a lot of work on the frontend to accomodate.
 ### Custom Input Fields
 Now that you know how to create your own Invocations, let us dive into slightly
@ -329,172 +376,6 @@ like this.
color: ColorField = Field(default=ColorField(r=0, g=0, b=0, a=0), description='Background color of an image')
```
**Extra Config**
All input fields also take an additional `Config` class that you can use to do
various advanced things like setting required parameters and etc.
Let us do that for our _ColorField_ and enforce all the values because we did
not define any defaults for our fields.
```python
class ColorField(BaseModel):
'''A field that holds the rgba values of a color'''
r: int = Field(ge=0, le=255, description="The red channel")
g: int = Field(ge=0, le=255, description="The green channel")
b: int = Field(ge=0, le=255, description="The blue channel")
a: int = Field(ge=0, le=255, description="The alpha channel")
class Config:
schema_extra = {"required": ["r", "g", "b", "a"]}
```
Now it becomes mandatory for the user to supply all the values required by our
input field.
We will discuss the `Config` class in extra detail later in this guide and how
you can use it to make your Invocations more robust.
### Custom Output Types
Like with custom inputs, sometimes you might find yourself needing custom
outputs that InvokeAI does not provide. We can easily set one up.
Now that you are familiar with Invocations and Inputs, let us use that knowledge
to put together a custom output type for an Invocation that returns _width_,
_height_ and _background_color_ that we need to create a blank image.
- A custom output type is a class that derives from the parent class of
`BaseInvocationOutput`.
- It is not mandatory but we recommend using names ending with `Output` for
output types. So we'll call our class `BlankImageOutput`
- It is not mandatory but we highly recommend adding a `docstring` to describe
what your output type is for.
- Like Invocations, each output type should have a `type` variable that is
**unique**
Now that we know the basic rules for creating a new output type, let us go ahead
and make it.
```python
from typing import Literal
from pydantic import Field
from .baseinvocation import BaseInvocationOutput
class BlankImageOutput(BaseInvocationOutput):
'''Base output type for creating a blank image'''
type: Literal['blank_image_output'] = 'blank_image_output'
# Inputs
width: int = Field(description='Width of blank image')
height: int = Field(description='Height of blank image')
bg_color: ColorField = Field(description='Background color of blank image')
class Config:
schema_extra = {"required": ["type", "width", "height", "bg_color"]}
```
All set. We now have an output type that requires what we need to create a
blank_image. And if you noticed it, we even used the `Config` class to ensure
the fields are required.
### Custom Configuration
As you might have noticed when making inputs and outputs, we used a class called
`Config` from _pydantic_ to further customize them. Because our inputs and
outputs essentially inherit from _pydantic_'s `BaseModel` class, all
[configuration options](https://docs.pydantic.dev/latest/usage/schema/#schema-customization)
that are valid for _pydantic_ classes are also valid for our inputs and outputs.
You can do the same for your Invocations too but InvokeAI makes our life a
little bit easier on that end.
InvokeAI provides a custom configuration class called `InvocationConfig`
particularly for configuring Invocations. This is exactly the same as the raw
`Config` class from _pydantic_ with some extra stuff on top to help faciliate
parsing of the scheme in the frontend UI.
At the current moment, tihs `InvocationConfig` class is further improved with
the following features related the `ui`.
| Config Option | Field Type | Example |
| ------------- | ------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------- |
| type_hints | `Dict[str, Literal["integer", "float", "boolean", "string", "enum", "image", "latents", "model", "control"]]` | `type_hint: "model"` provides type hints related to the model like displaying a list of available models |
| tags | `List[str]` | `tags: ['resize', 'image']` will classify your invocation under the tags of resize and image. |
| title | `str` | `title: 'Resize Image` will rename your to this custom title rather than infer from the name of the Invocation class. |
So let us update your `ResizeInvocation` with some extra configuration and see
how that works.
```python
from typing import Literal, Union
from pydantic import Field
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from ..models.image import ImageField, ResourceOrigin, ImageCategory
from .image import ImageOutput
class ResizeInvocation(BaseInvocation):
'''Resizes an image'''
type: Literal['resize'] = 'resize'
# Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None)
width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
class Config(InvocationConfig):
schema_extra: {
ui: {
tags: ['resize', 'image'],
title: ['My Custom Resize']
}
}
def invoke(self, context: InvocationContext) -> ImageOutput:
# Load the image using InvokeAI's predefined Image Service.
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
# Resizing the image
# Because we used the above service, we already have a PIL image. So we can simply resize.
resized_image = image.resize((self.width, self.height))
# Preparing the image for output using InvokeAI's predefined Image Service.
output_image = context.services.images.create(
image=resized_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
# Returning the Image
return ImageOutput(
image=ImageField(
image_name=output_image.image_name,
image_origin=output_image.image_origin,
),
width=output_image.width,
height=output_image.height,
)
```
We now customized our code to let the frontend know that our Invocation falls
under `resize` and `image` categories. So when the user searches for these
particular words, our Invocation will show up too.
We also set a custom title for our Invocation. So instead of being called
`Resize`, it will be called `My Custom Resize`.
As simple as that.
As time goes by, InvokeAI will further improve and add more customizability for
Invocation configuration. We will have more documentation regarding this at a
later time.
# **[TODO]**
### Custom Components For Frontend
Every backend input type should have a corresponding frontend component so the
@ -513,282 +394,4 @@ Let us create a new component for our custom color field we created above. When
we use a color field, let us say we want the UI to display a color picker for
the user to pick from rather than entering values. That is what we will build
now.
-->
---
<!-- # OLD -- TO BE DELETED OR MOVED LATER
---
## Creating a new invocation
To create a new invocation, either find the appropriate module file in
`/ldm/invoke/app/invocations` to add your invocation to, or create a new one in
that folder. All invocations in that folder will be discovered and made
available to the CLI and API automatically. Invocations make use of
[typing](https://docs.python.org/3/library/typing.html) and
[pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration
into the CLI and API.
An invocation looks like this:
```py
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
# fmt: off
type: Literal["upscale"] = "upscale"
# Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2, 4] = Field(default=2, description="The upscale level")
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["upscaling", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
strength=0.0, # GFPGAN strength
save_original=False,
image_callback=None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
```
Each portion is important to implement correctly.
### Class definition and type
```py
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
type: Literal['upscale'] = 'upscale'
```
All invocations must derive from `BaseInvocation`. They should have a docstring
that declares what they do in a single, short line. They should also have a
`type` with a type hint that's `Literal["command_name"]`, where `command_name`
is what the user will type on the CLI or use in the API to create this
invocation. The `command_name` must be unique. The `type` must be assigned to
the value of the literal in the type hint.
### Inputs
```py
# Inputs
image: Union[ImageField,None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2,4] = Field(default=2, description="The upscale level")
```
Inputs consist of three parts: a name, a type hint, and a `Field` with default,
description, and validation information. For example:
| Part | Value | Description |
| --------- | ------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- |
| Name | `strength` | This field is referred to as `strength` |
| Type Hint | `float` | This field must be of type `float` |
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this
field to be parsed with `None` as a value, which enables linking to previous
invocations. All fields should either provide a default value or allow `None` as
a value, so that they can be overwritten with a linked output from another
invocation.
The special type `ImageField` is also used here. All images are passed as
`ImageField`, which protects them from pydantic validation errors (since images
only ever come from links).
Finally, note that for all linking, the `type` of the linked fields must match.
If the `name` also matches, then the field can be **automatically linked** to a
previous invocation by name and matching.
### Config
```py
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["upscaling", "image"],
},
}
```
This is an optional configuration for the invocation. It inherits from
pydantic's model `Config` class, and it used primarily to customize the
autogenerated OpenAPI schema.
The UI relies on the OpenAPI schema in two ways:
- An API client & Typescript types are generated from it. This happens at build
time.
- The node editor parses the schema into a template used by the UI to create the
node editor UI. This parsing happens at runtime.
In this example, a `ui` key has been added to the `schema_extra` dict to provide
some tags for the UI, to facilitate filtering nodes.
See the Schema Generation section below for more information.
### Invoke Function
```py
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
strength=0.0, # GFPGAN strength
save_original=False,
image_callback=None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
```
The `invoke` function is the last portion of an invocation. It is provided an
`InvocationContext` which contains services to perform work as well as a
`session_id` for use as needed. It should return a class with output values that
derives from `BaseInvocationOutput`.
Before being called, the invocation will have all of its fields set from
defaults, inputs, and finally links (overriding in that order).
Assume that this invocation may be running simultaneously with other
invocations, may be running on another machine, or in other interesting
scenarios. If you need functionality, please provide it as a service in the
`InvocationServices` class, and make sure it can be overridden.
### Outputs
```py
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
```
Output classes look like an invocation class without the invoke method. Prefer
to use an existing output class if available, and prefer to name inputs the same
as outputs when possible, to promote automatic invocation linking.
## Schema Generation
Invocation, output and related classes are used to generate an OpenAPI schema.
### Required Properties
The schema generation treat all properties with default values as optional. This
makes sense internally, but when when using these classes via the generated
schema, we end up with e.g. the `ImageOutput` class having its `image` property
marked as optional.
We know that this property will always be present, so the additional logic
needed to always check if the property exists adds a lot of extraneous cruft.
To fix this, we can leverage `pydantic`'s
[schema customisation](https://docs.pydantic.dev/usage/schema/#schema-customization)
to mark properties that we know will always be present as required.
Here's that `ImageOutput` class, without the needed schema customisation:
```python
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
```
The OpenAPI schema that results from this `ImageOutput` will have the `type`,
`image`, `width` and `height` properties marked as optional, even though we know
they will always have a value.
```python
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
# Add schema customization
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
```
With the customization in place, the schema will now show these properties as
required, obviating the need for extensive null checks in client code.
See this `pydantic` issue for discussion on this solution:
<https://github.com/pydantic/pydantic/discussions/4577> -->

View File
@ -2,15 +2,18 @@
 from __future__ import annotations
+import json
 from abc import ABC, abstractmethod
 from enum import Enum
 from inspect import signature
+import re
 from typing import (
     TYPE_CHECKING,
     AbstractSet,
     Any,
     Callable,
     ClassVar,
+    Literal,
     Mapping,
     Optional,
     Type,

@ -20,8 +23,8 @@ from typing import (
     get_type_hints,
 )
-from pydantic import BaseModel, Field
-from pydantic.fields import Undefined
+from pydantic import BaseModel, Field, validator
+from pydantic.fields import Undefined, ModelField
 from pydantic.typing import NoArgAnyCallable
 if TYPE_CHECKING:

@ -141,9 +144,11 @@ class UIType(str, Enum):
     # endregion
     # region Misc
-    FilePath = "FilePath"
     Enum = "enum"
     Scheduler = "Scheduler"
+    WorkflowField = "WorkflowField"
+    IsIntermediate = "IsIntermediate"
+    MetadataField = "MetadataField"
     # endregion

@ -365,12 +370,12 @@ def OutputField(
 class UIConfigBase(BaseModel):
     """
     Provides additional node configuration to the UI.
-    This is used internally by the @tags and @title decorator logic. You probably want to use those
-    decorators, though you may add this class to a node definition to specify the title and tags.
+    This is used internally by the @invocation decorator logic. Do not use this directly.
     """
-    tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
-    title: Optional[str] = Field(default=None, description="The display name of the node")
+    tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
+    title: Optional[str] = Field(default=None, description="The node's display name")
+    category: Optional[str] = Field(default=None, description="The node's category")
 class InvocationContext:

@ -383,10 +388,11 @@ class InvocationContext:
 class BaseInvocationOutput(BaseModel):
-    """Base class for all invocation outputs"""
-    # All outputs must include a type name like this:
-    # type: Literal['your_output_name'] # noqa f821
+    """
+    Base class for all invocation outputs.
+    All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
+    """
     @classmethod
     def get_all_subclasses_tuple(cls):

@ -422,12 +428,12 @@ class MissingInputException(Exception):
 class BaseInvocation(ABC, BaseModel):
-    """A node to process inputs and produce outputs.
-    May use dependency injection in __init__ to receive providers.
-    # All invocations must include a type name like this:
-    # type: Literal['your_output_name'] # noqa f821
+    """
+    A node to process inputs and produce outputs.
+    May use dependency injection in __init__ to receive providers.
+    All invocations must use the `@invocation` decorator to provide their unique type.
+    """
     @classmethod
     def get_all_subclasses(cls):

@ -466,6 +472,8 @@ class BaseInvocation(ABC, BaseModel):
             schema["title"] = uiconfig.title
         if uiconfig and hasattr(uiconfig, "tags"):
             schema["tags"] = uiconfig.tags
+        if uiconfig and hasattr(uiconfig, "category"):
+            schema["category"] = uiconfig.category
         if "required" not in schema or not isinstance(schema["required"], list):
             schema["required"] = list()
         schema["required"].extend(["type", "id"])

@ -505,37 +513,110 @@ class BaseInvocation(ABC, BaseModel):
            raise MissingInputException(self.__fields__["type"].default, field_name)
         return self.invoke(context)
-    id: str = Field(description="The id of this node. Must be unique among all nodes.")
-    is_intermediate: bool = InputField(
-        default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
+    id: str = Field(
+        description="The id of this instance of an invocation. Must be unique among all instances of invocations."
     )
+    is_intermediate: bool = InputField(
+        default=False, description="Whether or not this is an intermediate invocation.", ui_type=UIType.IsIntermediate
+    )
+    workflow: Optional[str] = InputField(
+        default=None,
+        description="The workflow to save with the image",
+        ui_type=UIType.WorkflowField,
+    )
+    @validator("workflow", pre=True)
+    def validate_workflow_is_json(cls, v):
+        if v is None:
+            return None
+        try:
+            json.loads(v)
+        except json.decoder.JSONDecodeError:
+            raise ValueError("Workflow must be valid JSON")
+        return v
     UIConfig: ClassVar[Type[UIConfigBase]]
-T = TypeVar("T", bound=BaseInvocation)
+GenericBaseInvocation = TypeVar("GenericBaseInvocation", bound=BaseInvocation)
-def title(title: str) -> Callable[[Type[T]], Type[T]]:
-    """Adds a title to the invocation. Use this to override the default title generation, which is based on the class name."""
-    def wrapper(cls: Type[T]) -> Type[T]:
+def invocation(
+    invocation_type: str, title: Optional[str] = None, tags: Optional[list[str]] = None, category: Optional[str] = None
+) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
+    """
+    Adds metadata to an invocation.
+    :param str invocation_type: The type of the invocation. Must be unique among all invocations.
+    :param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None.
+    :param Optional[list[str]] tags: Adds tags to the invocation. Invocations may be searched for by their tags. Defaults to None.
+    :param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
+    """
+    def wrapper(cls: Type[GenericBaseInvocation]) -> Type[GenericBaseInvocation]:
+        # Validate invocation types on creation of invocation classes
+        # TODO: ensure unique?
+        if re.compile(r"^\S+$").match(invocation_type) is None:
+            raise ValueError(f'"invocation_type" must consist of non-whitespace characters, got "{invocation_type}"')
+        # Add OpenAPI schema extras
         uiconf_name = cls.__qualname__ + ".UIConfig"
         if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
             cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
-        cls.UIConfig.title = title
+        if title is not None:
+            cls.UIConfig.title = title
+        if tags is not None:
+            cls.UIConfig.tags = tags
+        if category is not None:
+            cls.UIConfig.category = category
+        # Add the invocation type to the pydantic model of the invocation
+        invocation_type_annotation = Literal[invocation_type]  # type: ignore
+        invocation_type_field = ModelField.infer(
+            name="type",
+            value=invocation_type,
+            annotation=invocation_type_annotation,
+            class_validators=None,
+            config=cls.__config__,
+        )
+        cls.__fields__.update({"type": invocation_type_field})
+        cls.__annotations__.update({"type": invocation_type_annotation})
         return cls
     return wrapper
-def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
-    """Adds tags to the invocation. Use this to improve the streamline finding the invocation in the UI."""
+GenericBaseInvocationOutput = TypeVar("GenericBaseInvocationOutput", bound=BaseInvocationOutput)
+def invocation_output(
+    output_type: str,
+) -> Callable[[Type[GenericBaseInvocationOutput]], Type[GenericBaseInvocationOutput]]:
+    """
+    Adds metadata to an invocation output.
+    :param str output_type: The type of the invocation output. Must be unique among all invocation outputs.
+    """
+    def wrapper(cls: Type[GenericBaseInvocationOutput]) -> Type[GenericBaseInvocationOutput]:
+        # Validate output types on creation of invocation output classes
+        # TODO: ensure unique?
+        if re.compile(r"^\S+$").match(output_type) is None:
+            raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
+        # Add the output type to the pydantic model of the invocation output
+        output_type_annotation = Literal[output_type]  # type: ignore
+        output_type_field = ModelField.infer(
+            name="type",
+            value=output_type,
+            annotation=output_type_annotation,
+            class_validators=None,
+            config=cls.__config__,
+        )
+        cls.__fields__.update({"type": output_type_field})
+        cls.__annotations__.update({"type": output_type_annotation})
-    def wrapper(cls: Type[T]) -> Type[T]:
-        uiconf_name = cls.__qualname__ + ".UIConfig"
-        if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
-            cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
-        cls.UIConfig.tags = list(tags)
         return cls
     return wrapper

View File
@ -1,6 +1,5 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
-from typing import Literal
 import numpy as np
 from pydantic import validator

@ -8,17 +7,13 @@ from pydantic import validator
 from invokeai.app.invocations.primitives import IntegerCollectionOutput
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
-from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
-@title("Integer Range")
-@tags("collection", "integer", "range")
+@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="collections")
 class RangeInvocation(BaseInvocation):
     """Creates a range of numbers from start to stop with step"""
-    type: Literal["range"] = "range"
-    # Inputs
     start: int = InputField(default=0, description="The start of the range")
     stop: int = InputField(default=10, description="The stop of the range")
     step: int = InputField(default=1, description="The step of the range")

@ -33,14 +28,15 @@ class RangeInvocation(BaseInvocation):
         return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
-@title("Integer Range of Size")
-@tags("range", "integer", "size", "collection")
+@invocation(
+    "range_of_size",
+    title="Integer Range of Size",
+    tags=["collection", "integer", "size", "range"],
+    category="collections",
+)
 class RangeOfSizeInvocation(BaseInvocation):
     """Creates a range from start to start + size with step"""
-    type: Literal["range_of_size"] = "range_of_size"
-    # Inputs
     start: int = InputField(default=0, description="The start of the range")
     size: int = InputField(default=1, description="The number of values")
     step: int = InputField(default=1, description="The step of the range")

@ -49,14 +45,15 @@ class RangeOfSizeInvocation(BaseInvocation):
         return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
-@title("Random Range")
-@tags("range", "integer", "random", "collection")
+@invocation(
+    "random_range",
+    title="Random Range",
+    tags=["range", "integer", "random", "collection"],
+    category="collections",
+)
 class RandomRangeInvocation(BaseInvocation):
     """Creates a collection of random numbers"""
-    type: Literal["random_range"] = "random_range"
-    # Inputs
     low: int = InputField(default=0, description="The inclusive low value")
     high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
     size: int = InputField(default=1, description="The number of values to generate")

View File
@ -1,6 +1,6 @@
 import re
 from dataclasses import dataclass
-from typing import List, Literal, Union
+from typing import List, Union
 import torch
 from compel import Compel, ReturnedEmbeddingsType

@ -26,8 +26,8 @@ from .baseinvocation import (
     InvocationContext,
     OutputField,
     UIComponent,
-    tags,
-    title,
+    invocation,
+    invocation_output,
 )
 from .model import ClipField

@ -44,13 +44,10 @@ class ConditioningFieldData:
     # PerpNeg = "perp_neg"
-@title("Compel Prompt")
-@tags("prompt", "compel")
+@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning")
 class CompelInvocation(BaseInvocation):
     """Parse prompt using compel package to conditioning."""
-    type: Literal["compel"] = "compel"
     prompt: str = InputField(
         default="",
         description=FieldDescriptions.compel_prompt,

@ -265,13 +262,15 @@ class SDXLPromptInvocationBase:
         return c, c_pooled, ec
-@title("SDXL Compel Prompt")
-@tags("sdxl", "compel", "prompt")
+@invocation(
+    "sdxl_compel_prompt",
+    title="SDXL Prompt",
+    tags=["sdxl", "compel", "prompt"],
+    category="conditioning",
+)
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
-    type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
     prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
     style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
     original_width: int = InputField(default=1024, description="")

@ -324,13 +323,15 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     )
-@title("SDXL Refiner Compel Prompt")
-@tags("sdxl", "compel", "prompt")
+@invocation(
+    "sdxl_refiner_compel_prompt",
+    title="SDXL Refiner Prompt",
+    tags=["sdxl", "compel", "prompt"],
+    category="conditioning",
+)
 class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
-    type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
     style: str = InputField(
         default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
     )  # TODO: ?

@ -372,20 +373,17 @@ class SDXLRefinerCompelPromptInvocationBase
     )
+@invocation_output("clip_skip_output")
 class ClipSkipInvocationOutput(BaseInvocationOutput):
     """Clip skip node output"""
-    type: Literal["clip_skip_output"] = "clip_skip_output"
     clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
-@title("CLIP Skip")
-@tags("clipskip", "clip", "skip")
+@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning")
 class ClipSkipInvocation(BaseInvocation):
     """Skip layers in clip text_encoder model."""
-    type: Literal["clip_skip"] = "clip_skip"
     clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
     skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)

View File
@ -40,8 +40,8 @@ from .baseinvocation import (
     InvocationContext,
     OutputField,
     UIType,
-    tags,
-    title,
+    invocation,
+    invocation_output,
 )

@ -87,23 +87,18 @@ class ControlField(BaseModel):
         return v
+@invocation_output("control_output")
 class ControlOutput(BaseInvocationOutput):
     """node output for ControlNet info"""
-    type: Literal["control_output"] = "control_output"
     # Outputs
     control: ControlField = OutputField(description=FieldDescriptions.control)
-@title("ControlNet")
-@tags("controlnet")
+@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet")
 class ControlNetInvocation(BaseInvocation):
     """Collects ControlNet info to pass to other nodes"""
-    type: Literal["controlnet"] = "controlnet"
-    # Inputs
     image: ImageField = InputField(description="The control image")
     control_model: ControlNetModelField = InputField(
         default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct

@ -134,12 +129,10 @@ class ControlNetInvocation(BaseInvocation):
     )
+@invocation("image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet")
 class ImageProcessorInvocation(BaseInvocation):
     """Base class for invocations that preprocess images for ControlNet"""
-    type: Literal["image_processor"] = "image_processor"
-    # Inputs
     image: ImageField = InputField(description="The image to process")
     def run_processor(self, image):

@ -151,11 +144,6 @@ class ImageProcessorInvocation(BaseInvocation):
         # image type should be PIL.PngImagePlugin.PngImageFile ?
         processed_image = self.run_processor(raw_image)
-        # FIXME: what happened to image metadata?
-        # metadata = context.services.metadata.build_metadata(
-        #     session_id=context.graph_execution_state_id, node=self
-        # )
         # currently can't see processed image in node UI without a showImage node,
         # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
         image_dto = context.services.images.create(

@ -165,6 +153,7 @@ class ImageProcessorInvocation(BaseInvocation):
             session_id=context.graph_execution_state_id,
             node_id=self.id,
             is_intermediate=self.is_intermediate,
+            workflow=self.workflow,
         )
         """Builds an ImageOutput and its ImageField"""

@ -179,14 +168,15 @@ class ImageProcessorInvocation(BaseInvocation):
         )
-@title("Canny Processor")
-@tags("controlnet", "canny")
+@invocation(
+    "canny_image_processor",
+    title="Canny Processor",
+    tags=["controlnet", "canny"],
+    category="controlnet",
+)
 class CannyImageProcessorInvocation(ImageProcessorInvocation):
     """Canny edge detection for ControlNet"""
-    type: Literal["canny_image_processor"] = "canny_image_processor"
-    # Input
     low_threshold: int = InputField(
         default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
     )

@ -200,14 +190,15 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
         return processed_image
-@title("HED (softedge) Processor")
-@tags("controlnet", "hed", "softedge")
+@invocation(
+    "hed_image_processor",
+    title="HED (softedge) Processor",
+    tags=["controlnet", "hed", "softedge"],
+    category="controlnet",
+)
 class HedImageProcessorInvocation(ImageProcessorInvocation):
     """Applies HED edge detection to image"""
-    type: Literal["hed_image_processor"] = "hed_image_processor"
-    # Inputs
     detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
     # safe not supported in controlnet_aux v0.0.3

@ -227,14 +218,15 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
         return processed_image
-@title("Lineart Processor")
-@tags("controlnet", "lineart")
+@invocation(
+    "lineart_image_processor",
+    title="Lineart Processor",
+    tags=["controlnet", "lineart"],
+    category="controlnet",
+)
 class LineartImageProcessorInvocation(ImageProcessorInvocation):
     """Applies line art processing to image"""
-    type: Literal["lineart_image_processor"] = "lineart_image_processor"
-    # Inputs
     detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
     coarse: bool = InputField(default=False, description="Whether to use coarse mode")

@ -247,14 +239,15 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
         return processed_image
-@title("Lineart Anime Processor")
-@tags("controlnet", "lineart", "anime")
+@invocation(
+    "lineart_anime_image_processor",
+    title="Lineart Anime Processor",
+    tags=["controlnet", "lineart", "anime"],
+    category="controlnet",
+)
 class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
     """Applies line art anime processing to image"""
-    type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
-    # Inputs
     detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

@ -268,14 +261,15 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
         return processed_image
-@title("Openpose Processor")
-@tags("controlnet", "openpose", "pose")
+@invocation(
+    "openpose_image_processor",
title="Openpose Processor",
tags=["controlnet", "openpose", "pose"],
category="controlnet",
)
class OpenposeImageProcessorInvocation(ImageProcessorInvocation): class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Openpose processing to image""" """Applies Openpose processing to image"""
type: Literal["openpose_image_processor"] = "openpose_image_processor"
# Inputs
hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode") hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
@ -291,14 +285,15 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
return processed_image return processed_image
@title("Midas (Depth) Processor") @invocation(
@tags("controlnet", "midas", "depth") "midas_depth_image_processor",
title="Midas Depth Processor",
tags=["controlnet", "midas"],
category="controlnet",
)
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image""" """Applies Midas depth processing to image"""
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
# Inputs
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)") a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`") bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
# depth_and_normal not supported in controlnet_aux v0.0.3 # depth_and_normal not supported in controlnet_aux v0.0.3
@ -316,14 +311,15 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
return processed_image return processed_image
@title("Normal BAE Processor") @invocation(
@tags("controlnet", "normal", "bae") "normalbae_image_processor",
title="Normal BAE Processor",
tags=["controlnet"],
category="controlnet",
)
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image""" """Applies NormalBae processing to image"""
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
@ -335,14 +331,10 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
return processed_image return processed_image
@title("MLSD Processor") @invocation("mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet")
@tags("controlnet", "mlsd")
class MlsdImageProcessorInvocation(ImageProcessorInvocation): class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image""" """Applies MLSD processing to image"""
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`") thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
@ -360,14 +352,10 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
return processed_image return processed_image
@title("PIDI Processor") @invocation("pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet")
@tags("controlnet", "pidi")
class PidiImageProcessorInvocation(ImageProcessorInvocation): class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image""" """Applies PIDI processing to image"""
type: Literal["pidi_image_processor"] = "pidi_image_processor"
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
@ -385,14 +373,15 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
return processed_image return processed_image
@title("Content Shuffle Processor") @invocation(
@tags("controlnet", "contentshuffle") "content_shuffle_image_processor",
title="Content Shuffle Processor",
tags=["controlnet", "contentshuffle"],
category="controlnet",
)
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image""" """Applies content shuffle processing to image"""
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter") h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
@ -413,27 +402,30 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13 # should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
@title("Zoe (Depth) Processor") @invocation(
@tags("controlnet", "zoe", "depth") "zoe_depth_image_processor",
title="Zoe (Depth) Processor",
tags=["controlnet", "zoe", "depth"],
category="controlnet",
)
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image""" """Applies Zoe depth processing to image"""
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
def run_processor(self, image): def run_processor(self, image):
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image) processed_image = zoe_depth_processor(image)
return processed_image return processed_image
@title("Mediapipe Face Processor") @invocation(
@tags("controlnet", "mediapipe", "face") "mediapipe_face_processor",
title="Mediapipe Face Processor",
tags=["controlnet", "mediapipe", "face"],
category="controlnet",
)
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image""" """Applies mediapipe face processing to image"""
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
# Inputs
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect") max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection") min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
@ -447,14 +439,15 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
return processed_image return processed_image
@title("Leres (Depth) Processor") @invocation(
@tags("controlnet", "leres", "depth") "leres_image_processor",
title="Leres (Depth) Processor",
tags=["controlnet", "leres", "depth"],
category="controlnet",
)
class LeresImageProcessorInvocation(ImageProcessorInvocation): class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image""" """Applies leres processing to image"""
type: Literal["leres_image_processor"] = "leres_image_processor"
# Inputs
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`") thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`") thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
boost: bool = InputField(default=False, description="Whether to use boost mode") boost: bool = InputField(default=False, description="Whether to use boost mode")
@ -474,14 +467,15 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
return processed_image return processed_image
@title("Tile Resample Processor") @invocation(
@tags("controlnet", "tile") "tile_image_processor",
title="Tile Resample Processor",
tags=["controlnet", "tile"],
category="controlnet",
)
class TileResamplerProcessorInvocation(ImageProcessorInvocation): class TileResamplerProcessorInvocation(ImageProcessorInvocation):
"""Tile resampler processor""" """Tile resampler processor"""
type: Literal["tile_image_processor"] = "tile_image_processor"
# Inputs
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile") # res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate") down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
@ -512,13 +506,15 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
return processed_image return processed_image
@title("Segment Anything Processor") @invocation(
@tags("controlnet", "segmentanything") "segment_anything_processor",
title="Segment Anything Processor",
tags=["controlnet", "segmentanything"],
category="controlnet",
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image""" """Applies segment anything processing to image"""
type: Literal["segment_anything_processor"] = "segment_anything_processor"
def run_processor(self, image): def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained( segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(


@@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

-from typing import Literal

import cv2 as cv
import numpy

@@ -8,17 +7,18 @@ from PIL import Image, ImageOps
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ResourceOrigin
-from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation


-@title("OpenCV Inpaint")
-@tags("opencv", "inpaint")
+@invocation(
+    "cv_inpaint",
+    title="OpenCV Inpaint",
+    tags=["opencv", "inpaint"],
+    category="inpaint",
+)
class CvInpaintInvocation(BaseInvocation):
    """Simple inpaint using opencv."""

-    type: Literal["cv_inpaint"] = "cv_inpaint"
-
-    # Inputs
    image: ImageField = InputField(description="The image to inpaint")
    mask: ImageField = InputField(description="The mask to use when inpainting")

@@ -45,6 +45,7 @@ class CvInpaintInvocation(BaseInvocation):
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
+            workflow=self.workflow,
        )

        return ImageOutput(
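For node authors, the practical upshot of the two changes visible in this file (the `@invocation` decorator and the `workflow=self.workflow` pass-through) is a pattern roughly like the sketch below. This is a hypothetical example, not code from the commit: the node id `img_invert_example`, the inversion logic, and the `image_origin`/`image_category` arguments (which fall outside the hunks shown here) are assumptions based on the surrounding codebase.

```python
from PIL import ImageOps

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    InputField,
    InvocationContext,
    invocation,
)
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ResourceOrigin


@invocation("img_invert_example", title="Invert Image (example)", tags=["image", "example"], category="image")
class InvertImageExampleInvocation(BaseInvocation):
    """Inverts an image (illustrative example of the new node API)."""

    image: ImageField = InputField(description="The image to invert")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        pil_image = context.services.images.get_pil_image(self.image.image_name)
        inverted = ImageOps.invert(pil_image.convert("RGB"))

        # Passing workflow=self.workflow is what writes the attached workflow
        # into the saved image file.
        image_dto = context.services.images.create(
            image=inverted,
            image_origin=ResourceOrigin.INTERNAL,  # assumed, not shown in this diff
            image_category=ImageCategory.GENERAL,  # assumed, not shown in this diff
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
            workflow=self.workflow,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
```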


@@ -13,18 +13,13 @@ from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker

from ..models.image import ImageCategory, ResourceOrigin
-from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
+from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation


-@title("Show Image")
-@tags("image")
+@invocation("show_image", title="Show Image", tags=["image"], category="image")
class ShowImageInvocation(BaseInvocation):
-    """Displays a provided image, and passes it forward in the pipeline."""
+    """Displays a provided image using the OS image viewer, and passes it forward in the pipeline."""

-    # Metadata
-    type: Literal["show_image"] = "show_image"
-
-    # Inputs
    image: ImageField = InputField(description="The image to show")

    def invoke(self, context: InvocationContext) -> ImageOutput:

@@ -41,15 +36,10 @@ class ShowImageInvocation(BaseInvocation):
    )

-@title("Blank Image")
-@tags("image")
+@invocation("blank_image", title="Blank Image", tags=["image"], category="image")
class BlankImageInvocation(BaseInvocation):
    """Creates a blank image and forwards it to the pipeline"""

-    # Metadata
-    type: Literal["blank_image"] = "blank_image"
-
-    # Inputs
    width: int = InputField(default=512, description="The width of the image")
    height: int = InputField(default=512, description="The height of the image")
    mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image")

@@ -65,6 +55,7 @@ class BlankImageInvocation(BaseInvocation):
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
+            workflow=self.workflow,
        )

        return ImageOutput(

@@ -74,15 +65,10 @@ class BlankImageInvocation(BaseInvocation):
    )
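Since the `workflow` value is handed to `images.create(...)` above, a saved image can carry its workflow with it and be read back without InvokeAI running. A rough sketch follows, assuming the workflow is stored as a PNG text chunk; the `"invokeai_workflow"` key name is a guess for illustration and does not appear anywhere in this diff.

```python
import json
from typing import Optional

from PIL import Image


def read_embedded_workflow(path: str) -> Optional[dict]:
    """Return the workflow embedded in a PNG, if any (hypothetical key name)."""
    with Image.open(path) as img:
        # "invokeai_workflow" is an assumed chunk name, not taken from this commit.
        raw = img.text.get("invokeai_workflow") if hasattr(img, "text") else None
    return json.loads(raw) if raw else None


if __name__ == "__main__":
    workflow = read_embedded_workflow("example_output.png")
    if workflow:
        print(f"workflow has {len(workflow.get('nodes', []))} nodes")
```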
@title("Crop Image") @invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image")
@tags("image", "crop")
class ImageCropInvocation(BaseInvocation): class ImageCropInvocation(BaseInvocation):
"""Crops an image to a specified box. The box can be outside of the image.""" """Crops an image to a specified box. The box can be outside of the image."""
# Metadata
type: Literal["img_crop"] = "img_crop"
# Inputs
image: ImageField = InputField(description="The image to crop") image: ImageField = InputField(description="The image to crop")
x: int = InputField(default=0, description="The left x coordinate of the crop rectangle") x: int = InputField(default=0, description="The left x coordinate of the crop rectangle")
y: int = InputField(default=0, description="The top y coordinate of the crop rectangle") y: int = InputField(default=0, description="The top y coordinate of the crop rectangle")
@ -102,6 +88,7 @@ class ImageCropInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -111,15 +98,10 @@ class ImageCropInvocation(BaseInvocation):
) )
@title("Paste Image") @invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image")
@tags("image", "paste")
class ImagePasteInvocation(BaseInvocation): class ImagePasteInvocation(BaseInvocation):
"""Pastes an image into another image.""" """Pastes an image into another image."""
# Metadata
type: Literal["img_paste"] = "img_paste"
# Inputs
base_image: ImageField = InputField(description="The base image") base_image: ImageField = InputField(description="The base image")
image: ImageField = InputField(description="The image to paste") image: ImageField = InputField(description="The image to paste")
mask: Optional[ImageField] = InputField( mask: Optional[ImageField] = InputField(
@ -154,6 +136,7 @@ class ImagePasteInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -163,15 +146,10 @@ class ImagePasteInvocation(BaseInvocation):
) )
@title("Mask from Alpha") @invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image")
@tags("image", "mask")
class MaskFromAlphaInvocation(BaseInvocation): class MaskFromAlphaInvocation(BaseInvocation):
"""Extracts the alpha channel of an image as a mask.""" """Extracts the alpha channel of an image as a mask."""
# Metadata
type: Literal["tomask"] = "tomask"
# Inputs
image: ImageField = InputField(description="The image to create the mask from") image: ImageField = InputField(description="The image to create the mask from")
invert: bool = InputField(default=False, description="Whether or not to invert the mask") invert: bool = InputField(default=False, description="Whether or not to invert the mask")
@ -189,6 +167,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -198,15 +177,10 @@ class MaskFromAlphaInvocation(BaseInvocation):
) )
@title("Multiply Images") @invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image")
@tags("image", "multiply")
class ImageMultiplyInvocation(BaseInvocation): class ImageMultiplyInvocation(BaseInvocation):
"""Multiplies two images together using `PIL.ImageChops.multiply()`.""" """Multiplies two images together using `PIL.ImageChops.multiply()`."""
# Metadata
type: Literal["img_mul"] = "img_mul"
# Inputs
image1: ImageField = InputField(description="The first image to multiply") image1: ImageField = InputField(description="The first image to multiply")
image2: ImageField = InputField(description="The second image to multiply") image2: ImageField = InputField(description="The second image to multiply")
@ -223,6 +197,7 @@ class ImageMultiplyInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -235,15 +210,10 @@ class ImageMultiplyInvocation(BaseInvocation):
IMAGE_CHANNELS = Literal["A", "R", "G", "B"] IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
@title("Extract Image Channel") @invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image")
@tags("image", "channel")
class ImageChannelInvocation(BaseInvocation): class ImageChannelInvocation(BaseInvocation):
"""Gets a channel from an image.""" """Gets a channel from an image."""
# Metadata
type: Literal["img_chan"] = "img_chan"
# Inputs
image: ImageField = InputField(description="The image to get the channel from") image: ImageField = InputField(description="The image to get the channel from")
channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get") channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
@ -259,6 +229,7 @@ class ImageChannelInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -271,15 +242,10 @@ class ImageChannelInvocation(BaseInvocation):
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
@title("Convert Image Mode") @invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image")
@tags("image", "convert")
class ImageConvertInvocation(BaseInvocation): class ImageConvertInvocation(BaseInvocation):
"""Converts an image to a different mode.""" """Converts an image to a different mode."""
# Metadata
type: Literal["img_conv"] = "img_conv"
# Inputs
image: ImageField = InputField(description="The image to convert") image: ImageField = InputField(description="The image to convert")
mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to") mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
@ -295,6 +261,7 @@ class ImageConvertInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -304,15 +271,10 @@ class ImageConvertInvocation(BaseInvocation):
) )
@title("Blur Image") @invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image")
@tags("image", "blur")
class ImageBlurInvocation(BaseInvocation): class ImageBlurInvocation(BaseInvocation):
"""Blurs an image""" """Blurs an image"""
# Metadata
type: Literal["img_blur"] = "img_blur"
# Inputs
image: ImageField = InputField(description="The image to blur") image: ImageField = InputField(description="The image to blur")
radius: float = InputField(default=8.0, ge=0, description="The blur radius") radius: float = InputField(default=8.0, ge=0, description="The blur radius")
# Metadata # Metadata
@ -333,6 +295,7 @@ class ImageBlurInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -362,15 +325,10 @@ PIL_RESAMPLING_MAP = {
} }
@title("Resize Image") @invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image")
@tags("image", "resize")
class ImageResizeInvocation(BaseInvocation): class ImageResizeInvocation(BaseInvocation):
"""Resizes an image to specific dimensions""" """Resizes an image to specific dimensions"""
# Metadata
type: Literal["img_resize"] = "img_resize"
# Inputs
image: ImageField = InputField(description="The image to resize") image: ImageField = InputField(description="The image to resize")
width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)") width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)") height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
@ -397,6 +355,7 @@ class ImageResizeInvocation(BaseInvocation):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None, metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -406,15 +365,10 @@ class ImageResizeInvocation(BaseInvocation):
) )
@title("Scale Image") @invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image")
@tags("image", "scale")
class ImageScaleInvocation(BaseInvocation): class ImageScaleInvocation(BaseInvocation):
"""Scales an image by a factor""" """Scales an image by a factor"""
# Metadata
type: Literal["img_scale"] = "img_scale"
# Inputs
image: ImageField = InputField(description="The image to scale") image: ImageField = InputField(description="The image to scale")
scale_factor: float = InputField( scale_factor: float = InputField(
default=2.0, default=2.0,
@ -442,6 +396,7 @@ class ImageScaleInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -451,15 +406,10 @@ class ImageScaleInvocation(BaseInvocation):
) )
@title("Lerp Image") @invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image")
@tags("image", "lerp")
class ImageLerpInvocation(BaseInvocation): class ImageLerpInvocation(BaseInvocation):
"""Linear interpolation of all pixels of an image""" """Linear interpolation of all pixels of an image"""
# Metadata
type: Literal["img_lerp"] = "img_lerp"
# Inputs
image: ImageField = InputField(description="The image to lerp") image: ImageField = InputField(description="The image to lerp")
min: int = InputField(default=0, ge=0, le=255, description="The minimum output value") min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
max: int = InputField(default=255, ge=0, le=255, description="The maximum output value") max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
@ -479,6 +429,7 @@ class ImageLerpInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -488,15 +439,10 @@ class ImageLerpInvocation(BaseInvocation):
) )
@title("Inverse Lerp Image") @invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image")
@tags("image", "ilerp")
class ImageInverseLerpInvocation(BaseInvocation): class ImageInverseLerpInvocation(BaseInvocation):
"""Inverse linear interpolation of all pixels of an image""" """Inverse linear interpolation of all pixels of an image"""
# Metadata
type: Literal["img_ilerp"] = "img_ilerp"
# Inputs
image: ImageField = InputField(description="The image to lerp") image: ImageField = InputField(description="The image to lerp")
min: int = InputField(default=0, ge=0, le=255, description="The minimum input value") min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
max: int = InputField(default=255, ge=0, le=255, description="The maximum input value") max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
@ -516,6 +462,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -525,15 +472,10 @@ class ImageInverseLerpInvocation(BaseInvocation):
) )
@title("Blur NSFW Image") @invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image")
@tags("image", "nsfw")
class ImageNSFWBlurInvocation(BaseInvocation): class ImageNSFWBlurInvocation(BaseInvocation):
"""Add blur to NSFW-flagged images""" """Add blur to NSFW-flagged images"""
# Metadata
type: Literal["img_nsfw"] = "img_nsfw"
# Inputs
image: ImageField = InputField(description="The image to check") image: ImageField = InputField(description="The image to check")
metadata: Optional[CoreMetadata] = InputField( metadata: Optional[CoreMetadata] = InputField(
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
@ -559,6 +501,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None, metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -574,15 +517,10 @@ class ImageNSFWBlurInvocation(BaseInvocation):
return caution.resize((caution.width // 2, caution.height // 2)) return caution.resize((caution.width // 2, caution.height // 2))
@title("Add Invisible Watermark") @invocation("img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image")
@tags("image", "watermark")
class ImageWatermarkInvocation(BaseInvocation): class ImageWatermarkInvocation(BaseInvocation):
"""Add an invisible watermark to an image""" """Add an invisible watermark to an image"""
# Metadata
type: Literal["img_watermark"] = "img_watermark"
# Inputs
image: ImageField = InputField(description="The image to check") image: ImageField = InputField(description="The image to check")
text: str = InputField(default="InvokeAI", description="Watermark text") text: str = InputField(default="InvokeAI", description="Watermark text")
metadata: Optional[CoreMetadata] = InputField( metadata: Optional[CoreMetadata] = InputField(
@ -600,6 +538,7 @@ class ImageWatermarkInvocation(BaseInvocation):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None, metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -609,14 +548,10 @@ class ImageWatermarkInvocation(BaseInvocation):
) )
@title("Mask Edge") @invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image")
@tags("image", "mask", "inpaint")
class MaskEdgeInvocation(BaseInvocation): class MaskEdgeInvocation(BaseInvocation):
"""Applies an edge mask to an image""" """Applies an edge mask to an image"""
type: Literal["mask_edge"] = "mask_edge"
# Inputs
image: ImageField = InputField(description="The image to apply the mask to") image: ImageField = InputField(description="The image to apply the mask to")
edge_size: int = InputField(description="The size of the edge") edge_size: int = InputField(description="The size of the edge")
edge_blur: int = InputField(description="The amount of blur on the edge") edge_blur: int = InputField(description="The amount of blur on the edge")
@ -648,6 +583,7 @@ class MaskEdgeInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -657,14 +593,10 @@ class MaskEdgeInvocation(BaseInvocation):
) )
@title("Combine Mask") @invocation("mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image")
@tags("image", "mask", "multiply")
class MaskCombineInvocation(BaseInvocation): class MaskCombineInvocation(BaseInvocation):
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.""" """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
type: Literal["mask_combine"] = "mask_combine"
# Inputs
mask1: ImageField = InputField(description="The first mask to combine") mask1: ImageField = InputField(description="The first mask to combine")
mask2: ImageField = InputField(description="The second image to combine") mask2: ImageField = InputField(description="The second image to combine")
@ -681,6 +613,7 @@ class MaskCombineInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -690,17 +623,13 @@ class MaskCombineInvocation(BaseInvocation):
) )
@title("Color Correct") @invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image")
@tags("image", "color")
class ColorCorrectInvocation(BaseInvocation): class ColorCorrectInvocation(BaseInvocation):
""" """
Shifts the colors of a target image to match the reference image, optionally Shifts the colors of a target image to match the reference image, optionally
using a mask to only color-correct certain regions of the target image. using a mask to only color-correct certain regions of the target image.
""" """
type: Literal["color_correct"] = "color_correct"
# Inputs
image: ImageField = InputField(description="The image to color-correct") image: ImageField = InputField(description="The image to color-correct")
reference: ImageField = InputField(description="Reference image for color-correction") reference: ImageField = InputField(description="Reference image for color-correction")
mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction") mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
@ -789,6 +718,7 @@ class ColorCorrectInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -798,14 +728,10 @@ class ColorCorrectInvocation(BaseInvocation):
) )
@title("Image Hue Adjustment") @invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image")
@tags("image", "hue", "hsl")
class ImageHueAdjustmentInvocation(BaseInvocation): class ImageHueAdjustmentInvocation(BaseInvocation):
"""Adjusts the Hue of an image.""" """Adjusts the Hue of an image."""
type: Literal["img_hue_adjust"] = "img_hue_adjust"
# Inputs
image: ImageField = InputField(description="The image to adjust") image: ImageField = InputField(description="The image to adjust")
hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360") hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
@ -831,6 +757,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -842,14 +769,15 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
) )
@title("Image Luminosity Adjustment") @invocation(
@tags("image", "luminosity", "hsl") "img_luminosity_adjust",
title="Adjust Image Luminosity",
tags=["image", "luminosity", "hsl"],
category="image",
)
class ImageLuminosityAdjustmentInvocation(BaseInvocation): class ImageLuminosityAdjustmentInvocation(BaseInvocation):
"""Adjusts the Luminosity (Value) of an image.""" """Adjusts the Luminosity (Value) of an image."""
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
# Inputs
image: ImageField = InputField(description="The image to adjust") image: ImageField = InputField(description="The image to adjust")
luminosity: float = InputField( luminosity: float = InputField(
default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)" default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
@ -881,6 +809,7 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -892,14 +821,15 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
) )
@title("Image Saturation Adjustment") @invocation(
@tags("image", "saturation", "hsl") "img_saturation_adjust",
title="Adjust Image Saturation",
tags=["image", "saturation", "hsl"],
category="image",
)
class ImageSaturationAdjustmentInvocation(BaseInvocation): class ImageSaturationAdjustmentInvocation(BaseInvocation):
"""Adjusts the Saturation of an image.""" """Adjusts the Saturation of an image."""
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
# Inputs
image: ImageField = InputField(description="The image to adjust") image: ImageField = InputField(description="The image to adjust")
saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation") saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
@ -929,6 +859,7 @@ class ImageSaturationAdjustmentInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(


@ -12,7 +12,7 @@ from invokeai.backend.image_util.lama import LaMA
from invokeai.backend.image_util.patchmatch import PatchMatch from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ImageCategory, ResourceOrigin from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
def infill_methods() -> list[str]: def infill_methods() -> list[str]:
@ -116,14 +116,10 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
return si return si
@title("Solid Color Infill") @invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint")
@tags("image", "inpaint")
class InfillColorInvocation(BaseInvocation): class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color""" """Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba"
# Inputs
image: ImageField = InputField(description="The image to infill") image: ImageField = InputField(description="The image to infill")
color: ColorField = InputField( color: ColorField = InputField(
default=ColorField(r=127, g=127, b=127, a=255), default=ColorField(r=127, g=127, b=127, a=255),
@ -145,6 +141,7 @@ class InfillColorInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -154,14 +151,10 @@ class InfillColorInvocation(BaseInvocation):
) )
@title("Tile Infill") @invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint")
@tags("image", "inpaint")
class InfillTileInvocation(BaseInvocation): class InfillTileInvocation(BaseInvocation):
"""Infills transparent areas of an image with tiles of the image""" """Infills transparent areas of an image with tiles of the image"""
type: Literal["infill_tile"] = "infill_tile"
# Input
image: ImageField = InputField(description="The image to infill") image: ImageField = InputField(description="The image to infill")
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)") tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
seed: int = InputField( seed: int = InputField(
@ -184,6 +177,7 @@ class InfillTileInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -193,14 +187,10 @@ class InfillTileInvocation(BaseInvocation):
) )
@title("PatchMatch Infill") @invocation("infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint")
@tags("image", "inpaint")
class InfillPatchMatchInvocation(BaseInvocation): class InfillPatchMatchInvocation(BaseInvocation):
"""Infills transparent areas of an image using the PatchMatch algorithm""" """Infills transparent areas of an image using the PatchMatch algorithm"""
type: Literal["infill_patchmatch"] = "infill_patchmatch"
# Inputs
image: ImageField = InputField(description="The image to infill") image: ImageField = InputField(description="The image to infill")
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@ -218,6 +208,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -227,14 +218,10 @@ class InfillPatchMatchInvocation(BaseInvocation):
) )
@title("LaMa Infill") @invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint")
@tags("image", "inpaint")
class LaMaInfillInvocation(BaseInvocation): class LaMaInfillInvocation(BaseInvocation):
"""Infills transparent areas of an image using the LaMa model""" """Infills transparent areas of an image using the LaMa model"""
type: Literal["infill_lama"] = "infill_lama"
# Inputs
image: ImageField = InputField(description="The image to infill") image: ImageField = InputField(description="The image to infill")
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:


@@ -47,7 +47,18 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import Post
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from ..models.image import ImageCategory, ResourceOrigin
-from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, UIType, tags, title
+from .baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    FieldDescriptions,
+    Input,
+    InputField,
+    InvocationContext,
+    OutputField,
+    UIType,
+    invocation,
+    invocation_output,
+)
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .model import ModelInfo, UNetField, VaeField

@@ -58,15 +69,27 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]

-@title("Create Denoise Mask")
-@tags("mask", "denoise")
+@invocation_output("scheduler_output")
+class SchedulerOutput(BaseInvocationOutput):
+    scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
+
+
+@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents")
+class SchedulerInvocation(BaseInvocation):
+    """Selects a scheduler."""
+
+    scheduler: SAMPLER_NAME_VALUES = InputField(
+        default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
+    )
+
+    def invoke(self, context: InvocationContext) -> SchedulerOutput:
+        return SchedulerOutput(scheduler=self.scheduler)
+
+
+@invocation("create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents")
class CreateDenoiseMaskInvocation(BaseInvocation):
    """Creates mask for denoising model run."""

-    # Metadata
-    type: Literal["create_denoise_mask"] = "create_denoise_mask"
-
-    # Inputs
    vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
    image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
    mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)

@@ -158,14 +181,15 @@ def get_scheduler(
    return scheduler

-@title("Denoise Latents")
-@tags("latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l")
+@invocation(
+    "denoise_latents",
+    title="Denoise Latents",
+    tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
+    category="latents",
+)
class DenoiseLatentsInvocation(BaseInvocation):
    """Denoises noisy latents to decodable images"""

-    type: Literal["denoise_latents"] = "denoise_latents"
-
-    # Inputs
    positive_conditioning: ConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
    )
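The new `SchedulerOutput`/`SchedulerInvocation` pair above also shows off the second new decorator, `@invocation_output`, which registers a custom output type the same way `@invocation` registers a node. Below is a hedged sketch of a hypothetical pair built on the same pattern; the class names, node id, category and the string field are invented for illustration only.

```python
from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    InputField,
    InvocationContext,
    OutputField,
    invocation,
    invocation_output,
)


@invocation_output("prompt_prefix_output")
class PromptPrefixOutput(BaseInvocationOutput):
    """Output type carrying a single prefixed prompt string."""

    prompt: str = OutputField(description="The prefixed prompt")


@invocation("prompt_prefix", title="Prompt Prefix (example)", tags=["example"], category="prompt")
class PromptPrefixInvocation(BaseInvocation):
    """Prepends a fixed prefix to a prompt (illustrative example)."""

    prompt: str = InputField(default="", description="The prompt to prefix")
    prefix: str = InputField(default="masterpiece, ", description="Text to prepend")

    def invoke(self, context: InvocationContext) -> PromptPrefixOutput:
        return PromptPrefixOutput(prompt=self.prefix + self.prompt)
```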
@ -512,14 +536,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=result_latents, seed=seed) return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
@title("Latents to Image") @invocation("l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents")
@tags("latents", "image", "vae", "l2i")
class LatentsToImageInvocation(BaseInvocation): class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents.""" """Generates an image from latents."""
type: Literal["l2i"] = "l2i"
# Inputs
latents: LatentsField = InputField( latents: LatentsField = InputField(
description=FieldDescriptions.latents, description=FieldDescriptions.latents,
input=Input.Connection, input=Input.Connection,
@ -600,6 +620,7 @@ class LatentsToImageInvocation(BaseInvocation):
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None, metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
) )
return ImageOutput( return ImageOutput(
@ -612,14 +633,10 @@ class LatentsToImageInvocation(BaseInvocation):
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
@title("Resize Latents") @invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents")
@tags("latents", "resize")
class ResizeLatentsInvocation(BaseInvocation): class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.""" """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
type: Literal["lresize"] = "lresize"
# Inputs
latents: LatentsField = InputField( latents: LatentsField = InputField(
description=FieldDescriptions.latents, description=FieldDescriptions.latents,
input=Input.Connection, input=Input.Connection,
@ -660,14 +677,10 @@ class ResizeLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@title("Scale Latents") @invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents")
@tags("latents", "resize")
class ScaleLatentsInvocation(BaseInvocation): class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor.""" """Scales latents by a given factor."""
type: Literal["lscale"] = "lscale"
# Inputs
latents: LatentsField = InputField( latents: LatentsField = InputField(
description=FieldDescriptions.latents, description=FieldDescriptions.latents,
input=Input.Connection, input=Input.Connection,
@ -700,14 +713,10 @@ class ScaleLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@title("Image to Latents") @invocation("i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents")
@tags("latents", "image", "vae", "i2l")
class ImageToLatentsInvocation(BaseInvocation): class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents.""" """Encodes an image into latents."""
type: Literal["i2l"] = "i2l"
# Inputs
image: ImageField = InputField( image: ImageField = InputField(
description="The image to encode", description="The image to encode",
) )
@ -784,14 +793,10 @@ class ImageToLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=latents, seed=None) return build_latents_output(latents_name=name, latents=latents, seed=None)
@title("Blend Latents") @invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents")
@tags("latents", "blend")
class BlendLatentsInvocation(BaseInvocation): class BlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha. Latents must have same size.""" """Blend two latents using a given alpha. Latents must have same size."""
type: Literal["lblend"] = "lblend"
# Inputs
latents_a: LatentsField = InputField( latents_a: LatentsField = InputField(
description=FieldDescriptions.latents, description=FieldDescriptions.latents,
input=Input.Connection, input=Input.Connection,


@@ -1,22 +1,16 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

-from typing import Literal
-
 import numpy as np

 from invokeai.app.invocations.primitives import IntegerOutput

-from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
+from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation


-@title("Add Integers")
-@tags("math")
+@invocation("add", title="Add Integers", tags=["math", "add"], category="math")
 class AddInvocation(BaseInvocation):
     """Adds two numbers"""

-    type: Literal["add"] = "add"
-    # Inputs
     a: int = InputField(default=0, description=FieldDescriptions.num_1)
     b: int = InputField(default=0, description=FieldDescriptions.num_2)
@@ -24,14 +18,10 @@ class AddInvocation(BaseInvocation):
         return IntegerOutput(value=self.a + self.b)


-@title("Subtract Integers")
-@tags("math")
+@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math")
 class SubtractInvocation(BaseInvocation):
     """Subtracts two numbers"""

-    type: Literal["sub"] = "sub"
-    # Inputs
     a: int = InputField(default=0, description=FieldDescriptions.num_1)
     b: int = InputField(default=0, description=FieldDescriptions.num_2)
@@ -39,14 +29,10 @@ class SubtractInvocation(BaseInvocation):
         return IntegerOutput(value=self.a - self.b)


-@title("Multiply Integers")
-@tags("math")
+@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math")
 class MultiplyInvocation(BaseInvocation):
     """Multiplies two numbers"""

-    type: Literal["mul"] = "mul"
-    # Inputs
     a: int = InputField(default=0, description=FieldDescriptions.num_1)
     b: int = InputField(default=0, description=FieldDescriptions.num_2)
@@ -54,14 +40,10 @@ class MultiplyInvocation(BaseInvocation):
         return IntegerOutput(value=self.a * self.b)


-@title("Divide Integers")
-@tags("math")
+@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math")
 class DivideInvocation(BaseInvocation):
     """Divides two numbers"""

-    type: Literal["div"] = "div"
-    # Inputs
     a: int = InputField(default=0, description=FieldDescriptions.num_1)
     b: int = InputField(default=0, description=FieldDescriptions.num_2)
@@ -69,14 +51,10 @@ class DivideInvocation(BaseInvocation):
         return IntegerOutput(value=int(self.a / self.b))


-@title("Random Integer")
-@tags("math")
+@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math")
 class RandomIntInvocation(BaseInvocation):
     """Outputs a single random integer."""

-    type: Literal["rand_int"] = "rand_int"
-    # Inputs
     low: int = InputField(default=0, description="The inclusive low value")
     high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
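Seen end to end, the pattern in these hunks is the same everywhere: the `type: Literal[...]` field and the `@title`/`@tags` decorators collapse into a single `@invocation` decorator carrying the type, title, tags and category. Reassembled from the fragments above, the smallest of these nodes now reads roughly as follows (a sketch stitched together from the hunks, not a verbatim copy of the file):

```python
from invokeai.app.invocations.primitives import IntegerOutput

from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation


@invocation("add", title="Add Integers", tags=["math", "add"], category="math")
class AddInvocation(BaseInvocation):
    """Adds two numbers"""

    a: int = InputField(default=0, description=FieldDescriptions.num_1)
    b: int = InputField(default=0, description=FieldDescriptions.num_2)

    def invoke(self, context: InvocationContext) -> IntegerOutput:
        # The unique "add" type now lives in the decorator instead of a Literal field.
        return IntegerOutput(value=self.a + self.b)
```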

View File

@@ -1,4 +1,4 @@
-from typing import Literal, Optional
+from typing import Optional

 from pydantic import Field
@@ -8,8 +8,8 @@ from invokeai.app.invocations.baseinvocation import (
     InputField,
     InvocationContext,
     OutputField,
-    tags,
-    title,
+    invocation,
+    invocation_output,
 )
 from invokeai.app.invocations.controlnet_image_processors import ControlField
 from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
@@ -91,21 +91,17 @@ class ImageMetadata(BaseModelExcludeNull):
     graph: Optional[dict] = Field(default=None, description="The graph that created the image")


+@invocation_output("metadata_accumulator_output")
 class MetadataAccumulatorOutput(BaseInvocationOutput):
     """The output of the MetadataAccumulator node"""

-    type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
     metadata: CoreMetadata = OutputField(description="The core metadata for the image")


-@title("Metadata Accumulator")
-@tags("metadata")
+@invocation("metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata")
 class MetadataAccumulatorInvocation(BaseInvocation):
     """Outputs a Core Metadata Object"""

-    type: Literal["metadata_accumulator"] = "metadata_accumulator"
     generation_mode: str = InputField(
         description="The generation mode that output this image",
     )

View File

@@ -1,5 +1,5 @@
 import copy
-from typing import List, Literal, Optional
+from typing import List, Optional

 from pydantic import BaseModel, Field
@@ -13,8 +13,8 @@ from .baseinvocation import (
     InvocationContext,
     OutputField,
     UIType,
-    tags,
-    title,
+    invocation,
+    invocation_output,
 )
@@ -49,11 +49,10 @@ class VaeField(BaseModel):
     seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')


+@invocation_output("model_loader_output")
 class ModelLoaderOutput(BaseInvocationOutput):
     """Model loader output"""

-    type: Literal["model_loader_output"] = "model_loader_output"
     unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
     clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
     vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@@ -74,14 +73,10 @@ class LoRAModelField(BaseModel):
     base_model: BaseModelType = Field(description="Base model")


-@title("Main Model")
-@tags("model")
+@invocation("main_model_loader", title="Main Model", tags=["model"], category="model")
 class MainModelLoaderInvocation(BaseInvocation):
     """Loads a main model, outputting its submodels."""

-    type: Literal["main_model_loader"] = "main_model_loader"
-    # Inputs
     model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
     # TODO: precision?
@@ -170,25 +165,18 @@ class MainModelLoaderInvocation(BaseInvocation):
         )


+@invocation_output("lora_loader_output")
 class LoraLoaderOutput(BaseInvocationOutput):
     """Model loader output"""

-    # fmt: off
-    type: Literal["lora_loader_output"] = "lora_loader_output"
     unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
     clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
-    # fmt: on


-@title("LoRA")
-@tags("lora", "model")
+@invocation("lora_loader", title="LoRA", tags=["model"], category="model")
 class LoraLoaderInvocation(BaseInvocation):
     """Apply selected lora to unet and text_encoder."""

-    type: Literal["lora_loader"] = "lora_loader"
-    # Inputs
     lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
     weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
     unet: Optional[UNetField] = InputField(
@@ -247,25 +235,19 @@ class LoraLoaderInvocation(BaseInvocation):
         return output


+@invocation_output("sdxl_lora_loader_output")
 class SDXLLoraLoaderOutput(BaseInvocationOutput):
     """SDXL LoRA Loader Output"""

-    # fmt: off
-    type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
     unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
     clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
     clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
-    # fmt: on


-@title("SDXL LoRA")
-@tags("sdxl", "lora", "model")
+@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model")
 class SDXLLoraLoaderInvocation(BaseInvocation):
     """Apply selected lora to unet and text_encoder."""

-    type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
     lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
     weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight)
     unet: Optional[UNetField] = Field(
@@ -349,23 +331,17 @@ class VAEModelField(BaseModel):
     base_model: BaseModelType = Field(description="Base model")


+@invocation_output("vae_loader_output")
 class VaeLoaderOutput(BaseInvocationOutput):
-    """Model loader output"""
+    """VAE output"""

-    type: Literal["vae_loader_output"] = "vae_loader_output"
-    # Outputs
     vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")


-@title("VAE")
-@tags("vae", "model")
+@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model")
 class VaeLoaderInvocation(BaseInvocation):
     """Loads a VAE model, outputting a VaeLoaderOutput"""

-    type: Literal["vae_loader"] = "vae_loader"
-    # Inputs
     vae_model: VAEModelField = InputField(
         description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
     )
@@ -392,24 +368,18 @@ class VaeLoaderInvocation(BaseInvocation):
         )


+@invocation_output("seamless_output")
 class SeamlessModeOutput(BaseInvocationOutput):
     """Modified Seamless Model output"""

-    type: Literal["seamless_output"] = "seamless_output"
-    # Outputs
     unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
     vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")


-@title("Seamless")
-@tags("seamless", "model")
+@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model")
 class SeamlessModeInvocation(BaseInvocation):
     """Applies the seamless transformation to the Model UNet and VAE."""

-    type: Literal["seamless"] = "seamless"
-    # Inputs
     unet: Optional[UNetField] = InputField(
         default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
     )

View File

@@ -1,6 +1,5 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team

-from typing import Literal

 import torch
 from pydantic import validator
@@ -16,8 +15,8 @@ from .baseinvocation import (
     InputField,
     InvocationContext,
     OutputField,
-    tags,
-    title,
+    invocation,
+    invocation_output,
 )

 """
@@ -62,12 +61,10 @@ Nodes
 """


+@invocation_output("noise_output")
 class NoiseOutput(BaseInvocationOutput):
     """Invocation noise output"""

-    type: Literal["noise_output"] = "noise_output"
-    # Inputs
     noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
     width: int = OutputField(description=FieldDescriptions.width)
     height: int = OutputField(description=FieldDescriptions.height)
@@ -81,14 +78,10 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
     )


-@title("Noise")
-@tags("latents", "noise")
+@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents")
 class NoiseInvocation(BaseInvocation):
     """Generates latent noise."""

-    type: Literal["noise"] = "noise"
-    # Inputs
     seed: int = InputField(
         ge=0,
         le=SEED_MAX,

View File

@@ -31,8 +31,8 @@ from .baseinvocation import (
     OutputField,
     UIComponent,
     UIType,
-    tags,
-    title,
+    invocation,
+    invocation_output,
 )
 from .controlnet_image_processors import ControlField
 from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
@@ -56,11 +56,8 @@ ORT_TO_NP_TYPE = {
 PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]


-@title("ONNX Prompt (Raw)")
-@tags("onnx", "prompt")
+@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning")
 class ONNXPromptInvocation(BaseInvocation):
-    type: Literal["prompt_onnx"] = "prompt_onnx"
     prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
     clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@@ -141,14 +138,15 @@ class ONNXPromptInvocation(BaseInvocation):
 # Text to image
-@title("ONNX Text to Latents")
-@tags("latents", "inference", "txt2img", "onnx")
+@invocation(
+    "t2l_onnx",
+    title="ONNX Text to Latents",
+    tags=["latents", "inference", "txt2img", "onnx"],
+    category="latents",
+)
 class ONNXTextToLatentsInvocation(BaseInvocation):
     """Generates latents from conditionings."""

-    type: Literal["t2l_onnx"] = "t2l_onnx"
-    # Inputs
     positive_conditioning: ConditioningField = InputField(
         description=FieldDescriptions.positive_cond,
         input=Input.Connection,
@@ -316,14 +314,15 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
 # Latent to image
-@title("ONNX Latents to Image")
-@tags("latents", "image", "vae", "onnx")
+@invocation(
+    "l2i_onnx",
+    title="ONNX Latents to Image",
+    tags=["latents", "image", "vae", "onnx"],
+    category="image",
+)
 class ONNXLatentsToImageInvocation(BaseInvocation):
     """Generates an image from latents."""

-    type: Literal["l2i_onnx"] = "l2i_onnx"
-    # Inputs
     latents: LatentsField = InputField(
         description=FieldDescriptions.denoised_latents,
         input=Input.Connection,
@@ -376,6 +375,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata.dict() if self.metadata else None,
+            workflow=self.workflow,
         )

         return ImageOutput(
@@ -385,17 +385,14 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
         )


+@invocation_output("model_loader_output_onnx")
 class ONNXModelLoaderOutput(BaseInvocationOutput):
     """Model loader output"""

-    # fmt: off
-    type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
     unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
     clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
     vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
     vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
-    # fmt: on


 class OnnxModelField(BaseModel):
@@ -406,14 +403,10 @@ class OnnxModelField(BaseModel):
     model_type: ModelType = Field(description="Model Type")


-@title("ONNX Main Model")
-@tags("onnx", "model")
+@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model")
 class OnnxModelLoaderInvocation(BaseInvocation):
     """Loads a main model, outputting its submodels."""

-    type: Literal["onnx_model_loader"] = "onnx_model_loader"
-    # Inputs
     model: OnnxModelField = InputField(
         description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
     )

View File

@@ -42,17 +42,13 @@ from matplotlib.ticker import MaxNLocator
 from invokeai.app.invocations.primitives import FloatCollectionOutput

-from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation


-@title("Float Range")
-@tags("math", "range")
+@invocation("float_range", title="Float Range", tags=["math", "range"], category="math")
 class FloatLinearRangeInvocation(BaseInvocation):
     """Creates a range"""

-    type: Literal["float_range"] = "float_range"
-    # Inputs
     start: float = InputField(default=5, description="The first value of the range")
     stop: float = InputField(default=10, description="The last value of the range")
     steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
@@ -100,14 +96,10 @@ EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
 # actually I think for now could just use CollectionOutput (which is list[Any]


-@title("Step Param Easing")
-@tags("step", "easing")
+@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step")
 class StepParamEasingInvocation(BaseInvocation):
     """Experimental per-step parameter easing for denoising steps"""

-    type: Literal["step_param_easing"] = "step_param_easing"
-    # Inputs
     easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
     num_steps: int = InputField(default=20, description="number of denoising steps")
     start_value: float = InputField(default=0.0, description="easing starting value")

View File

@@ -1,6 +1,6 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

-from typing import Literal, Optional, Tuple
+from typing import Optional, Tuple

 import torch
 from pydantic import BaseModel, Field
@@ -15,8 +15,8 @@ from .baseinvocation import (
     OutputField,
     UIComponent,
     UIType,
-    tags,
-    title,
+    invocation,
+    invocation_output,
 )

 """
@@ -29,44 +29,39 @@ Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color
 # region Boolean


+@invocation_output("boolean_output")
 class BooleanOutput(BaseInvocationOutput):
     """Base class for nodes that output a single boolean"""

-    type: Literal["boolean_output"] = "boolean_output"
     value: bool = OutputField(description="The output boolean")


+@invocation_output("boolean_collection_output")
 class BooleanCollectionOutput(BaseInvocationOutput):
     """Base class for nodes that output a collection of booleans"""

-    type: Literal["boolean_collection_output"] = "boolean_collection_output"
-    # Outputs
     collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)


-@title("Boolean Primitive")
-@tags("primitives", "boolean")
+@invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives")
 class BooleanInvocation(BaseInvocation):
     """A boolean primitive value"""

-    type: Literal["boolean"] = "boolean"
-    # Inputs
     value: bool = InputField(default=False, description="The boolean value")

     def invoke(self, context: InvocationContext) -> BooleanOutput:
         return BooleanOutput(value=self.value)


-@title("Boolean Primitive Collection")
-@tags("primitives", "boolean", "collection")
+@invocation(
+    "boolean_collection",
+    title="Boolean Collection Primitive",
+    tags=["primitives", "boolean", "collection"],
+    category="primitives",
+)
 class BooleanCollectionInvocation(BaseInvocation):
     """A collection of boolean primitive values"""

-    type: Literal["boolean_collection"] = "boolean_collection"
-    # Inputs
     collection: list[bool] = InputField(
         default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
     )
@@ -80,44 +75,39 @@ class BooleanCollectionInvocation(BaseInvocation):
 # region Integer


+@invocation_output("integer_output")
 class IntegerOutput(BaseInvocationOutput):
     """Base class for nodes that output a single integer"""

-    type: Literal["integer_output"] = "integer_output"
     value: int = OutputField(description="The output integer")


+@invocation_output("integer_collection_output")
 class IntegerCollectionOutput(BaseInvocationOutput):
     """Base class for nodes that output a collection of integers"""

-    type: Literal["integer_collection_output"] = "integer_collection_output"
-    # Outputs
     collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)


-@title("Integer Primitive")
-@tags("primitives", "integer")
+@invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives")
 class IntegerInvocation(BaseInvocation):
     """An integer primitive value"""

-    type: Literal["integer"] = "integer"
-    # Inputs
     value: int = InputField(default=0, description="The integer value")

     def invoke(self, context: InvocationContext) -> IntegerOutput:
         return IntegerOutput(value=self.value)


-@title("Integer Primitive Collection")
-@tags("primitives", "integer", "collection")
+@invocation(
+    "integer_collection",
+    title="Integer Collection Primitive",
+    tags=["primitives", "integer", "collection"],
+    category="primitives",
+)
 class IntegerCollectionInvocation(BaseInvocation):
     """A collection of integer primitive values"""

-    type: Literal["integer_collection"] = "integer_collection"
-    # Inputs
     collection: list[int] = InputField(
         default=0, description="The collection of integer values", ui_type=UIType.IntegerCollection
     )
@@ -131,44 +121,39 @@ class IntegerCollectionInvocation(BaseInvocation):
 # region Float


+@invocation_output("float_output")
 class FloatOutput(BaseInvocationOutput):
     """Base class for nodes that output a single float"""

-    type: Literal["float_output"] = "float_output"
     value: float = OutputField(description="The output float")


+@invocation_output("float_collection_output")
 class FloatCollectionOutput(BaseInvocationOutput):
     """Base class for nodes that output a collection of floats"""

-    type: Literal["float_collection_output"] = "float_collection_output"
-    # Outputs
     collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)


-@title("Float Primitive")
-@tags("primitives", "float")
+@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives")
 class FloatInvocation(BaseInvocation):
     """A float primitive value"""

-    type: Literal["float"] = "float"
-    # Inputs
     value: float = InputField(default=0.0, description="The float value")

     def invoke(self, context: InvocationContext) -> FloatOutput:
         return FloatOutput(value=self.value)


-@title("Float Primitive Collection")
-@tags("primitives", "float", "collection")
+@invocation(
+    "float_collection",
+    title="Float Collection Primitive",
+    tags=["primitives", "float", "collection"],
+    category="primitives",
+)
 class FloatCollectionInvocation(BaseInvocation):
     """A collection of float primitive values"""

-    type: Literal["float_collection"] = "float_collection"
-    # Inputs
     collection: list[float] = InputField(
         default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
     )
@@ -182,44 +167,39 @@ class FloatCollectionInvocation(BaseInvocation):
 # region String


+@invocation_output("string_output")
 class StringOutput(BaseInvocationOutput):
     """Base class for nodes that output a single string"""

-    type: Literal["string_output"] = "string_output"
     value: str = OutputField(description="The output string")


+@invocation_output("string_collection_output")
 class StringCollectionOutput(BaseInvocationOutput):
     """Base class for nodes that output a collection of strings"""

-    type: Literal["string_collection_output"] = "string_collection_output"
-    # Outputs
     collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)


-@title("String Primitive")
-@tags("primitives", "string")
+@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives")
 class StringInvocation(BaseInvocation):
     """A string primitive value"""

-    type: Literal["string"] = "string"
-    # Inputs
     value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)

     def invoke(self, context: InvocationContext) -> StringOutput:
         return StringOutput(value=self.value)


-@title("String Primitive Collection")
-@tags("primitives", "string", "collection")
+@invocation(
+    "string_collection",
+    title="String Collection Primitive",
+    tags=["primitives", "string", "collection"],
+    category="primitives",
+)
 class StringCollectionInvocation(BaseInvocation):
     """A collection of string primitive values"""

-    type: Literal["string_collection"] = "string_collection"
-    # Inputs
     collection: list[str] = InputField(
         default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
     )
@@ -239,33 +219,26 @@ class ImageField(BaseModel):
     image_name: str = Field(description="The name of the image")


+@invocation_output("image_output")
 class ImageOutput(BaseInvocationOutput):
     """Base class for nodes that output a single image"""

-    type: Literal["image_output"] = "image_output"
     image: ImageField = OutputField(description="The output image")
     width: int = OutputField(description="The width of the image in pixels")
     height: int = OutputField(description="The height of the image in pixels")


+@invocation_output("image_collection_output")
 class ImageCollectionOutput(BaseInvocationOutput):
     """Base class for nodes that output a collection of images"""

-    type: Literal["image_collection_output"] = "image_collection_output"
-    # Outputs
     collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)


-@title("Image Primitive")
-@tags("primitives", "image")
+@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives")
 class ImageInvocation(BaseInvocation):
     """An image primitive value"""

-    # Metadata
-    type: Literal["image"] = "image"
-    # Inputs
     image: ImageField = InputField(description="The image to load")

     def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -278,14 +251,15 @@ class ImageInvocation(BaseInvocation):
         )


-@title("Image Primitive Collection")
-@tags("primitives", "image", "collection")
+@invocation(
+    "image_collection",
+    title="Image Collection Primitive",
+    tags=["primitives", "image", "collection"],
+    category="primitives",
+)
 class ImageCollectionInvocation(BaseInvocation):
     """A collection of image primitive values"""

-    type: Literal["image_collection"] = "image_collection"
-    # Inputs
     collection: list[ImageField] = InputField(
         default=0, description="The collection of image values", ui_type=UIType.ImageCollection
     )
@@ -306,10 +280,10 @@ class DenoiseMaskField(BaseModel):
     masked_latents_name: Optional[str] = Field(description="The name of the masked image latents")


+@invocation_output("denoise_mask_output")
 class DenoiseMaskOutput(BaseInvocationOutput):
     """Base class for nodes that output a single image"""

-    type: Literal["denoise_mask_output"] = "denoise_mask_output"
     denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
@@ -325,11 +299,10 @@ class LatentsField(BaseModel):
     seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")


+@invocation_output("latents_output")
 class LatentsOutput(BaseInvocationOutput):
     """Base class for nodes that output a single latents tensor"""

-    type: Literal["latents_output"] = "latents_output"
     latents: LatentsField = OutputField(
         description=FieldDescriptions.latents,
     )
@@ -337,25 +310,20 @@ class LatentsOutput(BaseInvocationOutput):
     height: int = OutputField(description=FieldDescriptions.height)


+@invocation_output("latents_collection_output")
 class LatentsCollectionOutput(BaseInvocationOutput):
     """Base class for nodes that output a collection of latents tensors"""

-    type: Literal["latents_collection_output"] = "latents_collection_output"
     collection: list[LatentsField] = OutputField(
         description=FieldDescriptions.latents,
         ui_type=UIType.LatentsCollection,
     )


-@title("Latents Primitive")
-@tags("primitives", "latents")
+@invocation("latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives")
 class LatentsInvocation(BaseInvocation):
     """A latents tensor primitive value"""

-    type: Literal["latents"] = "latents"
-    # Inputs
     latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)

     def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -364,14 +332,15 @@ class LatentsInvocation(BaseInvocation):
         return build_latents_output(self.latents.latents_name, latents)


-@title("Latents Primitive Collection")
-@tags("primitives", "latents", "collection")
+@invocation(
+    "latents_collection",
+    title="Latents Collection Primitive",
+    tags=["primitives", "latents", "collection"],
+    category="primitives",
+)
 class LatentsCollectionInvocation(BaseInvocation):
     """A collection of latents tensor primitive values"""

-    type: Literal["latents_collection"] = "latents_collection"
-    # Inputs
     collection: list[LatentsField] = InputField(
         description="The collection of latents tensors", ui_type=UIType.LatentsCollection
     )
@@ -405,30 +374,24 @@ class ColorField(BaseModel):
         return (self.r, self.g, self.b, self.a)


+@invocation_output("color_output")
 class ColorOutput(BaseInvocationOutput):
     """Base class for nodes that output a single color"""

-    type: Literal["color_output"] = "color_output"
     color: ColorField = OutputField(description="The output color")


+@invocation_output("color_collection_output")
 class ColorCollectionOutput(BaseInvocationOutput):
     """Base class for nodes that output a collection of colors"""

-    type: Literal["color_collection_output"] = "color_collection_output"
-    # Outputs
     collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)


-@title("Color Primitive")
-@tags("primitives", "color")
+@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives")
 class ColorInvocation(BaseInvocation):
     """A color primitive value"""

-    type: Literal["color"] = "color"
-    # Inputs
     color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")

     def invoke(self, context: InvocationContext) -> ColorOutput:
@@ -446,47 +409,47 @@ class ConditioningField(BaseModel):
     conditioning_name: str = Field(description="The name of conditioning tensor")


+@invocation_output("conditioning_output")
 class ConditioningOutput(BaseInvocationOutput):
     """Base class for nodes that output a single conditioning tensor"""

-    type: Literal["conditioning_output"] = "conditioning_output"
     conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)


+@invocation_output("conditioning_collection_output")
 class ConditioningCollectionOutput(BaseInvocationOutput):
     """Base class for nodes that output a collection of conditioning tensors"""

-    type: Literal["conditioning_collection_output"] = "conditioning_collection_output"
-    # Outputs
     collection: list[ConditioningField] = OutputField(
         description="The output conditioning tensors",
         ui_type=UIType.ConditioningCollection,
     )


-@title("Conditioning Primitive")
-@tags("primitives", "conditioning")
+@invocation(
+    "conditioning",
+    title="Conditioning Primitive",
+    tags=["primitives", "conditioning"],
+    category="primitives",
+)
 class ConditioningInvocation(BaseInvocation):
     """A conditioning tensor primitive value"""

-    type: Literal["conditioning"] = "conditioning"
     conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)

     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         return ConditioningOutput(conditioning=self.conditioning)


-@title("Conditioning Primitive Collection")
-@tags("primitives", "conditioning", "collection")
+@invocation(
+    "conditioning_collection",
+    title="Conditioning Collection Primitive",
+    tags=["primitives", "conditioning", "collection"],
+    category="primitives",
+)
 class ConditioningCollectionInvocation(BaseInvocation):
     """A collection of conditioning tensor primitive values"""

-    type: Literal["conditioning_collection"] = "conditioning_collection"
-    # Inputs
     collection: list[ConditioningField] = InputField(
         default=0, description="The collection of conditioning tensors", ui_type=UIType.ConditioningCollection
     )

View File

@@ -1,5 +1,5 @@
 from os.path import exists
-from typing import Literal, Optional, Union
+from typing import Optional, Union

 import numpy as np
 from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
@@ -7,17 +7,13 @@ from pydantic import validator
 from invokeai.app.invocations.primitives import StringCollectionOutput

-from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, UIType, tags, title
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation


-@title("Dynamic Prompt")
-@tags("prompt", "collection")
+@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt")
 class DynamicPromptInvocation(BaseInvocation):
     """Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""

-    type: Literal["dynamic_prompt"] = "dynamic_prompt"
-    # Inputs
     prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
     max_prompts: int = InputField(default=1, description="The number of prompts to generate")
     combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
@@ -33,15 +29,11 @@ class DynamicPromptInvocation(BaseInvocation):
         return StringCollectionOutput(collection=prompts)


-@title("Prompts from File")
-@tags("prompt", "file")
+@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt")
 class PromptsFromFileInvocation(BaseInvocation):
     """Loads prompts from a text file"""

-    type: Literal["prompt_from_file"] = "prompt_from_file"
-    # Inputs
-    file_path: str = InputField(description="Path to prompt text file", ui_type=UIType.FilePath)
+    file_path: str = InputField(description="Path to prompt text file")
     pre_prompt: Optional[str] = InputField(
         default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
     )

View File

@@ -1,5 +1,3 @@
-from typing import Literal
-
 from ...backend.model_management import ModelType, SubModelType
 from .baseinvocation import (
     BaseInvocation,
@@ -10,41 +8,35 @@ from .baseinvocation import (
     InvocationContext,
     OutputField,
     UIType,
-    tags,
-    title,
+    invocation,
+    invocation_output,
 )
 from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField


+@invocation_output("sdxl_model_loader_output")
 class SDXLModelLoaderOutput(BaseInvocationOutput):
     """SDXL base model loader output"""

-    type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
     unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
     clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
     clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
     vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")


+@invocation_output("sdxl_refiner_model_loader_output")
 class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
     """SDXL refiner model loader output"""

-    type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
     unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
     clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
     vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")


-@title("SDXL Main Model")
-@tags("model", "sdxl")
+@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model")
 class SDXLModelLoaderInvocation(BaseInvocation):
     """Loads an sdxl base model, outputting its submodels."""

-    type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
-    # Inputs
     model: MainModelField = InputField(
         description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
     )
@@ -122,14 +114,15 @@ class SDXLModelLoaderInvocation(BaseInvocation):
         )


-@title("SDXL Refiner Model")
-@tags("model", "sdxl", "refiner")
+@invocation(
+    "sdxl_refiner_model_loader",
+    title="SDXL Refiner Model",
+    tags=["model", "sdxl", "refiner"],
+    category="model",
+)
 class SDXLRefinerModelLoaderInvocation(BaseInvocation):
     """Loads an sdxl refiner model, outputting its submodels."""

-    type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
-    # Inputs
     model: MainModelField = InputField(
         description=FieldDescriptions.sdxl_refiner_model,
         input=Input.Direct,

View File

@@ -11,7 +11,7 @@ from invokeai.app.invocations.primitives import ImageField, ImageOutput
 from invokeai.app.models.image import ImageCategory, ResourceOrigin

-from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation

 # TODO: Populate this from disk?
 # TODO: Use model manager to load?
@@ -23,14 +23,10 @@ ESRGAN_MODELS = Literal[
 ]


-@title("Upscale (RealESRGAN)")
-@tags("esrgan", "upscale")
+@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan")
 class ESRGANInvocation(BaseInvocation):
     """Upscales an image using RealESRGAN."""

-    type: Literal["esrgan"] = "esrgan"
-    # Inputs
     image: ImageField = InputField(description="The input image")
     model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
@@ -110,6 +106,7 @@ class ESRGANInvocation(BaseInvocation):
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
+            workflow=self.workflow,
         )

         return ImageOutput(
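The `workflow=self.workflow` additions here and in the ONNX latents-to-image node are the hook node authors are expected to copy: any image-outputting node passes its `workflow` field through to the image service so it ends up in the saved PNG. A rough sketch of a custom node doing the same follows; the `get_pil_image()` call, the `image_dto` attributes and the origin/category values are assumptions based on the surrounding imports, not taken verbatim from this diff:

```python
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ResourceOrigin

from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation


@invocation("image_passthrough_example", title="Image Passthrough (example)", tags=["image"], category="image")
class ImagePassthroughExampleInvocation(BaseInvocation):
    """Hypothetical node that re-saves an image with the workflow embedded."""

    image: ImageField = InputField(description="The image to re-save")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        pil_image = context.services.images.get_pil_image(self.image.image_name)
        image_dto = context.services.images.create(
            image=pil_image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
            workflow=self.workflow,  # writes the workflow into the PNG, as in the hunk above
        )
        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
```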

View File

@ -3,7 +3,7 @@
import copy import copy
import itertools import itertools
import uuid import uuid
from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin, get_type_hints from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
import networkx as nx import networkx as nx
from pydantic import BaseModel, root_validator, validator from pydantic import BaseModel, root_validator, validator
@ -14,11 +14,13 @@ from ..invocations import * # noqa: F401 F403
from ..invocations.baseinvocation import ( from ..invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
invocation,
Input, Input,
InputField, InputField,
InvocationContext, InvocationContext,
OutputField, OutputField,
UIType, UIType,
invocation_output,
) )
# in 3.10 this would be "from types import NoneType" # in 3.10 this would be "from types import NoneType"
@ -148,24 +150,16 @@ class NodeAlreadyExecutedError(Exception):
# TODO: Create and use an Empty output? # TODO: Create and use an Empty output?
@invocation_output("graph_output")
class GraphInvocationOutput(BaseInvocationOutput): class GraphInvocationOutput(BaseInvocationOutput):
type: Literal["graph_output"] = "graph_output" pass
class Config:
schema_extra = {
"required": [
"type",
"image",
]
}
# TODO: Fill this out and move to invocations # TODO: Fill this out and move to invocations
@invocation("graph")
class GraphInvocation(BaseInvocation): class GraphInvocation(BaseInvocation):
"""Execute a graph""" """Execute a graph"""
type: Literal["graph"] = "graph"
# TODO: figure out how to create a default here # TODO: figure out how to create a default here
graph: "Graph" = Field(description="The graph to run", default=None) graph: "Graph" = Field(description="The graph to run", default=None)
@ -174,22 +168,20 @@ class GraphInvocation(BaseInvocation):
return GraphInvocationOutput() return GraphInvocationOutput()
@invocation_output("iterate_output")
class IterateInvocationOutput(BaseInvocationOutput): class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output.""" """Used to connect iteration outputs. Will be expanded to a specific output."""
type: Literal["iterate_output"] = "iterate_output"
item: Any = OutputField( item: Any = OutputField(
description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem
) )
# TODO: Fill this out and move to invocations # TODO: Fill this out and move to invocations
@invocation("iterate")
class IterateInvocation(BaseInvocation): class IterateInvocation(BaseInvocation):
"""Iterates over a list of items""" """Iterates over a list of items"""
type: Literal["iterate"] = "iterate"
collection: list[Any] = InputField( collection: list[Any] = InputField(
description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection
) )
@ -200,19 +192,17 @@ class IterateInvocation(BaseInvocation):
return IterateInvocationOutput(item=self.collection[self.index]) return IterateInvocationOutput(item=self.collection[self.index])
@invocation_output("collect_output")
class CollectInvocationOutput(BaseInvocationOutput): class CollectInvocationOutput(BaseInvocationOutput):
type: Literal["collect_output"] = "collect_output"
collection: list[Any] = OutputField( collection: list[Any] = OutputField(
description="The collection of input items", title="Collection", ui_type=UIType.Collection description="The collection of input items", title="Collection", ui_type=UIType.Collection
) )
@invocation("collect")
class CollectInvocation(BaseInvocation): class CollectInvocation(BaseInvocation):
"""Collects values into a collection""" """Collects values into a collection"""
type: Literal["collect"] = "collect"
item: Any = InputField( item: Any = InputField(
description="The item to collect (all inputs must be of the same type)", description="The item to collect (all inputs must be of the same type)",
ui_type=UIType.CollectionItem, ui_type=UIType.CollectionItem,
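
As the hunks above show, the system nodes (`GraphInvocation`, `IterateInvocation`, `CollectInvocation` and their outputs) drop their hand-written `type: Literal["..."] = "..."` fields; the `@invocation` / `@invocation_output` decorators now supply the type. A minimal sketch of the same pattern on a hypothetical node — the names, title, tags and category here are placeholders, not upstream code:

```python
# Minimal sketch of decorator-based registration: the string passed to the
# decorator becomes the node's `type`, so no explicit Literal field is written.
from typing import Any

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    InputField,
    InvocationContext,
    OutputField,
    invocation,
    invocation_output,
)


@invocation_output("echo_output")
class EchoOutput(BaseInvocationOutput):
    """Output of the hypothetical echo node below."""

    value: Any = OutputField(description="The value that was passed in")


@invocation("echo", title="Echo", tags=["example"], category="examples")
class EchoInvocation(BaseInvocation):
    """Hypothetical node that returns its input unchanged."""

    value: Any = InputField(description="The value to echo", default=None)

    def invoke(self, context: InvocationContext) -> EchoOutput:
        return EchoOutput(value=self.value)
```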

View File

@ -60,7 +60,7 @@ class ImageFileStorageBase(ABC):
image: PILImageType, image: PILImageType,
image_name: str, image_name: str,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
graph: Optional[dict] = None, workflow: Optional[str] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp.""" """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -110,7 +110,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image: PILImageType, image: PILImageType,
image_name: str, image_name: str,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
graph: Optional[dict] = None, workflow: Optional[str] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
try: try:
@ -119,12 +119,23 @@ class DiskImageFileStorage(ImageFileStorageBase):
pnginfo = PngImagePlugin.PngInfo() pnginfo = PngImagePlugin.PngInfo()
if metadata is not None or workflow is not None:
if metadata is not None: if metadata is not None:
pnginfo.add_text("invokeai_metadata", json.dumps(metadata)) pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
if graph is not None: if workflow is not None:
pnginfo.add_text("invokeai_graph", json.dumps(graph)) pnginfo.add_text("invokeai_workflow", workflow)
else:
# For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back
# TODO: retain non-invokeai metadata on save...
original_metadata = image.info.get("invokeai_metadata", None)
if original_metadata is not None:
pnginfo.add_text("invokeai_metadata", original_metadata)
original_workflow = image.info.get("invokeai_workflow", None)
if original_workflow is not None:
pnginfo.add_text("invokeai_workflow", original_workflow)
image.save(image_path, "PNG", pnginfo=pnginfo) image.save(image_path, "PNG", pnginfo=pnginfo)
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True) thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size) thumbnail_image = make_thumbnail(image, thumbnail_size)
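
The hunk above is where the workflow actually lands on disk: it is written as a PNG text chunk named `invokeai_workflow` alongside the existing `invokeai_metadata` chunk, and for uploads the same chunks are copied back from `image.info` so they survive PIL's re-save. A standalone sketch of that mechanism with Pillow — file name and payloads are made up for illustration:

```python
# Write the workflow as an `invokeai_workflow` tEXt chunk, then read it back via
# PIL's `Image.info` (the same attribute the upload path uses to retain metadata).
import json

from PIL import Image, PngImagePlugin

workflow = json.dumps({"name": "example workflow", "nodes": [], "edges": []})
metadata = {"positive_prompt": "a banana", "seed": 123}

pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
pnginfo.add_text("invokeai_workflow", workflow)  # already a string, not re-dumped

image = Image.new("RGB", (64, 64))
image.save("example.png", "PNG", pnginfo=pnginfo)

reloaded = Image.open("example.png")
print(reloaded.info.get("invokeai_workflow"))
```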

View File

@ -54,6 +54,7 @@ class ImageServiceABC(ABC):
board_id: Optional[str] = None, board_id: Optional[str] = None,
is_intermediate: bool = False, is_intermediate: bool = False,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
workflow: Optional[str] = None,
) -> ImageDTO: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata.""" """Creates an image, storing the file and its metadata."""
pass pass
@ -177,6 +178,7 @@ class ImageService(ImageServiceABC):
board_id: Optional[str] = None, board_id: Optional[str] = None,
is_intermediate: bool = False, is_intermediate: bool = False,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
workflow: Optional[str] = None,
) -> ImageDTO: ) -> ImageDTO:
if image_origin not in ResourceOrigin: if image_origin not in ResourceOrigin:
raise InvalidOriginException raise InvalidOriginException
@ -186,16 +188,16 @@ class ImageService(ImageServiceABC):
image_name = self._services.names.create_image_name() image_name = self._services.names.create_image_name()
graph = None # TODO: Do we want to store the graph in the image at all? I don't think so...
# graph = None
if session_id is not None: # if session_id is not None:
session_raw = self._services.graph_execution_manager.get_raw(session_id) # session_raw = self._services.graph_execution_manager.get_raw(session_id)
if session_raw is not None: # if session_raw is not None:
try: # try:
graph = get_metadata_graph_from_raw_session(session_raw) # graph = get_metadata_graph_from_raw_session(session_raw)
except Exception as e: # except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}") # self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None # graph = None
(width, height) = image.size (width, height) = image.size
@ -217,7 +219,7 @@ class ImageService(ImageServiceABC):
) )
if board_id is not None: if board_id is not None:
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name) self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, graph=graph) self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
image_dto = self.get_dto(image_name) image_dto = self.get_dto(image_name)
return image_dto return image_dto

View File

@ -53,7 +53,7 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
- `starred`: change whether the image is starred - `starred`: change whether the image is starred
""" """
image_category: Optional[ImageCategory] = Field(description="The image's new category.") image_category: Optional[ImageCategory] = Field(default=None, description="The image's new category.")
"""The image's new category.""" """The image's new category."""
session_id: Optional[StrictStr] = Field( session_id: Optional[StrictStr] = Field(
default=None, default=None,

View File

@ -7,5 +7,4 @@ stats.html
index.html index.html
.yarn/ .yarn/
*.scss *.scss
src/services/api/ src/services/api/schema.d.ts
src/services/fixtures/*

View File

@ -7,8 +7,7 @@ index.html
.yarn/ .yarn/
.yalc/ .yalc/
*.scss *.scss
src/services/api/ src/services/api/schema.d.ts
src/services/fixtures/*
docs/ docs/
static/ static/
src/theme/css/overlayscrollbars.css src/theme/css/overlayscrollbars.css

View File

@ -74,6 +74,7 @@
"@nanostores/react": "^0.7.1", "@nanostores/react": "^0.7.1",
"@reduxjs/toolkit": "^1.9.5", "@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5", "@roarr/browser-log-writer": "^1.1.5",
"@stevebel/png": "^1.5.1",
"dateformat": "^5.0.3", "dateformat": "^5.0.3",
"formik": "^2.4.3", "formik": "^2.4.3",
"framer-motion": "^10.16.1", "framer-motion": "^10.16.1",
@ -110,6 +111,7 @@
"roarr": "^7.15.1", "roarr": "^7.15.1",
"serialize-error": "^11.0.1", "serialize-error": "^11.0.1",
"socket.io-client": "^4.7.2", "socket.io-client": "^4.7.2",
"type-fest": "^4.2.0",
"use-debounce": "^9.0.4", "use-debounce": "^9.0.4",
"use-image": "^1.1.1", "use-image": "^1.1.1",
"uuid": "^9.0.0", "uuid": "^9.0.0",

View File

@ -719,7 +719,7 @@
}, },
"nodes": { "nodes": {
"reloadNodeTemplates": "Reload Node Templates", "reloadNodeTemplates": "Reload Node Templates",
"saveWorkflow": "Save Workflow", "downloadWorkflow": "Download Workflow JSON",
"loadWorkflow": "Load Workflow", "loadWorkflow": "Load Workflow",
"resetWorkflow": "Reset Workflow", "resetWorkflow": "Reset Workflow",
"resetWorkflowDesc": "Are you sure you want to reset this workflow?", "resetWorkflowDesc": "Are you sure you want to reset this workflow?",

View File

@ -1,10 +1,12 @@
import { Box } from '@chakra-ui/react'; import { Box } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { selectIsBusy } from 'features/system/store/systemSelectors'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { AnimatePresence, motion } from 'framer-motion';
import { import {
KeyboardEvent, KeyboardEvent,
ReactNode, ReactNode,
@ -18,8 +20,6 @@ import { useTranslation } from 'react-i18next';
import { useUploadImageMutation } from 'services/api/endpoints/images'; import { useUploadImageMutation } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types'; import { PostUploadAction } from 'services/api/types';
import ImageUploadOverlay from './ImageUploadOverlay'; import ImageUploadOverlay from './ImageUploadOverlay';
import { AnimatePresence, motion } from 'framer-motion';
import { stateSelector } from 'app/store/store';
const selector = createSelector( const selector = createSelector(
[stateSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],

View File

@ -0,0 +1,56 @@
import { Box } from '@chakra-ui/react';
import { memo, useMemo } from 'react';
type Props = {
isSelected: boolean;
isHovered: boolean;
};
const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
const shadow = useMemo(() => {
if (isSelected && isHovered) {
return 'nodeHoveredSelected.light';
}
if (isSelected) {
return 'nodeSelected.light';
}
if (isHovered) {
return 'nodeHovered.light';
}
return undefined;
}, [isHovered, isSelected]);
const shadowDark = useMemo(() => {
if (isSelected && isHovered) {
return 'nodeHoveredSelected.dark';
}
if (isSelected) {
return 'nodeSelected.dark';
}
if (isHovered) {
return 'nodeHovered.dark';
}
return undefined;
}, [isHovered, isSelected]);
return (
<Box
className="selection-box"
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
bottom: 0,
insetInlineStart: 0,
borderRadius: 'base',
opacity: isSelected || isHovered ? 1 : 0.5,
transitionProperty: 'common',
transitionDuration: '0.1s',
pointerEvents: 'none',
shadow,
_dark: {
shadow: shadowDark,
},
}}
/>
);
};
export default memo(SelectionOverlay);

View File

@ -15,6 +15,7 @@ import { BoardDTO } from 'services/api/types';
import { menuListMotionProps } from 'theme/components/menu'; import { menuListMotionProps } from 'theme/components/menu';
import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems'; import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems';
import NoBoardContextMenuItems from './NoBoardContextMenuItems'; import NoBoardContextMenuItems from './NoBoardContextMenuItems';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
type Props = { type Props = {
board?: BoardDTO; board?: BoardDTO;
@ -33,12 +34,16 @@ const BoardContextMenu = ({
const selector = useMemo( const selector = useMemo(
() => () =>
createSelector(stateSelector, ({ gallery, system }) => { createSelector(
stateSelector,
({ gallery, system }) => {
const isAutoAdd = gallery.autoAddBoardId === board_id; const isAutoAdd = gallery.autoAddBoardId === board_id;
const isProcessing = system.isProcessing; const isProcessing = system.isProcessing;
const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick; const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick;
return { isAutoAdd, isProcessing, autoAssignBoardOnClick }; return { isAutoAdd, isProcessing, autoAssignBoardOnClick };
}), },
defaultSelectorOptions
),
[board_id] [board_id]
); );

View File

@ -9,20 +9,24 @@ import {
MenuButton, MenuButton,
MenuList, MenuList,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested'; import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton'; import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings'; import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions'; import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { import {
setActiveTab,
setShouldShowImageDetails, setShouldShowImageDetails,
setShouldShowProgressInViewer, setShouldShowProgressInViewer,
} from 'features/ui/store/uiSlice'; } from 'features/ui/store/uiSlice';
@ -37,12 +41,12 @@ import {
FaSeedling, FaSeedling,
FaShareAlt, FaShareAlt,
} from 'react-icons/fa'; } from 'react-icons/fa';
import { MdDeviceHub } from 'react-icons/md';
import { import {
useGetImageDTOQuery, useGetImageDTOQuery,
useGetImageMetadataQuery, useGetImageMetadataFromFileQuery,
} from 'services/api/endpoints/images'; } from 'services/api/endpoints/images';
import { menuListMotionProps } from 'theme/components/menu'; import { menuListMotionProps } from 'theme/components/menu';
import { useDebounce } from 'use-debounce';
import { sentImageToImg2Img } from '../../store/actions'; import { sentImageToImg2Img } from '../../store/actions';
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems'; import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
@ -101,22 +105,36 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const { recallBothPrompts, recallSeed, recallAllParameters } = const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters(); useRecallParameters();
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
lastSelectedImage,
500
);
const { currentData: imageDTO } = useGetImageDTOQuery( const { currentData: imageDTO } = useGetImageDTOQuery(
lastSelectedImage?.image_name ?? skipToken lastSelectedImage?.image_name ?? skipToken
); );
const { currentData: metadataData } = useGetImageMetadataQuery( const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
debounceState.isPending() lastSelectedImage?.image_name ?? skipToken,
? skipToken {
: debouncedMetadataQueryArg?.image_name ?? skipToken selectFromResult: (res) => ({
isLoading: res.isFetching,
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
}
); );
const metadata = metadataData?.metadata; const handleLoadWorkflow = useCallback(() => {
if (!workflow) {
return;
}
dispatch(workflowLoaded(workflow));
dispatch(setActiveTab('nodes'));
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded',
status: 'success',
})
)
);
}, [dispatch, workflow]);
const handleClickUseAllParameters = useCallback(() => { const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(metadata); recallAllParameters(metadata);
@ -153,6 +171,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys('p', handleUsePrompt, [imageDTO]); useHotkeys('p', handleUsePrompt, [imageDTO]);
useHotkeys('w', handleLoadWorkflow, [workflow]);
const handleSendToImageToImage = useCallback(() => { const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img()); dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(imageDTO)); dispatch(initialImageSelected(imageDTO));
@ -259,22 +279,31 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIIconButton <IAIIconButton
isLoading={isLoading}
icon={<MdDeviceHub />}
tooltip={`${t('nodes.loadWorkflow')} (W)`}
aria-label={`${t('nodes.loadWorkflow')} (W)`}
isDisabled={!workflow}
onClick={handleLoadWorkflow}
/>
<IAIIconButton
isLoading={isLoading}
icon={<FaQuoteRight />} icon={<FaQuoteRight />}
tooltip={`${t('parameters.usePrompt')} (P)`} tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`} aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!metadata?.positive_prompt} isDisabled={!metadata?.positive_prompt}
onClick={handleUsePrompt} onClick={handleUsePrompt}
/> />
<IAIIconButton <IAIIconButton
isLoading={isLoading}
icon={<FaSeedling />} icon={<FaSeedling />}
tooltip={`${t('parameters.useSeed')} (S)`} tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`} aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!metadata?.seed} isDisabled={!metadata?.seed}
onClick={handleUseSeed} onClick={handleUseSeed}
/> />
<IAIIconButton <IAIIconButton
isLoading={isLoading}
icon={<FaAsterisk />} icon={<FaAsterisk />}
tooltip={`${t('parameters.useAll')} (A)`} tooltip={`${t('parameters.useAll')} (A)`}
aria-label={`${t('parameters.useAll')} (A)`} aria-label={`${t('parameters.useAll')} (A)`}

View File

@ -1,5 +1,4 @@
import { Flex, MenuItem, Text } from '@chakra-ui/react'; import { Flex, MenuItem, Spinner } from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
@ -8,9 +7,12 @@ import {
isModalOpenChanged, isModalOpenChanged,
} from 'features/changeBoardModal/store/slice'; } from 'features/changeBoardModal/store/slice';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions'; import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard'; import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { setActiveTab } from 'features/ui/store/uiSlice'; import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
@ -26,14 +28,13 @@ import {
FaShare, FaShare,
FaTrash, FaTrash,
} from 'react-icons/fa'; } from 'react-icons/fa';
import { MdStar, MdStarBorder } from 'react-icons/md'; import { MdDeviceHub, MdStar, MdStarBorder } from 'react-icons/md';
import { import {
useGetImageMetadataQuery, useGetImageMetadataFromFileQuery,
useStarImagesMutation, useStarImagesMutation,
useUnstarImagesMutation, useUnstarImagesMutation,
} from 'services/api/endpoints/images'; } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions'; import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
type SingleSelectionMenuItemsProps = { type SingleSelectionMenuItemsProps = {
@ -50,15 +51,15 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled; const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const [debouncedMetadataQueryArg, debounceState] = useDebounce( const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
imageDTO.image_name, imageDTO.image_name,
500 {
); selectFromResult: (res) => ({
isLoading: res.isFetching,
const { currentData } = useGetImageMetadataQuery( metadata: res?.currentData?.metadata,
debounceState.isPending() workflow: res?.currentData?.workflow,
? skipToken }),
: debouncedMetadataQueryArg ?? skipToken }
); );
const [starImages] = useStarImagesMutation(); const [starImages] = useStarImagesMutation();
@ -67,8 +68,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { isClipboardAPIAvailable, copyImageToClipboard } = const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard(); useCopyImageToClipboard();
const metadata = currentData?.metadata;
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
if (!imageDTO) { if (!imageDTO) {
return; return;
@ -99,6 +98,22 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
recallSeed(metadata?.seed); recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]); }, [metadata?.seed, recallSeed]);
const handleLoadWorkflow = useCallback(() => {
if (!workflow) {
return;
}
dispatch(workflowLoaded(workflow));
dispatch(setActiveTab('nodes'));
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded',
status: 'success',
})
)
);
}, [dispatch, workflow]);
const handleSendToImageToImage = useCallback(() => { const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img()); dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(imageDTO)); dispatch(initialImageSelected(imageDTO));
@ -118,7 +133,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
}, [dispatch, imageDTO, t, toaster]); }, [dispatch, imageDTO, t, toaster]);
const handleUseAllParameters = useCallback(() => { const handleUseAllParameters = useCallback(() => {
console.log(metadata);
recallAllParameters(metadata); recallAllParameters(metadata);
}, [metadata, recallAllParameters]); }, [metadata, recallAllParameters]);
@ -169,27 +183,34 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
{t('parameters.downloadImage')} {t('parameters.downloadImage')}
</MenuItem> </MenuItem>
<MenuItem <MenuItem
icon={<FaQuoteRight />} icon={isLoading ? <SpinnerIcon /> : <MdDeviceHub />}
onClickCapture={handleLoadWorkflow}
isDisabled={isLoading || !workflow}
>
{t('nodes.loadWorkflow')}
</MenuItem>
<MenuItem
icon={isLoading ? <SpinnerIcon /> : <FaQuoteRight />}
onClickCapture={handleRecallPrompt} onClickCapture={handleRecallPrompt}
isDisabled={ isDisabled={
metadata?.positive_prompt === undefined && isLoading ||
metadata?.negative_prompt === undefined (metadata?.positive_prompt === undefined &&
metadata?.negative_prompt === undefined)
} }
> >
{t('parameters.usePrompt')} {t('parameters.usePrompt')}
</MenuItem> </MenuItem>
<MenuItem <MenuItem
icon={<FaSeedling />} icon={isLoading ? <SpinnerIcon /> : <FaSeedling />}
onClickCapture={handleRecallSeed} onClickCapture={handleRecallSeed}
isDisabled={metadata?.seed === undefined} isDisabled={isLoading || metadata?.seed === undefined}
> >
{t('parameters.useSeed')} {t('parameters.useSeed')}
</MenuItem> </MenuItem>
<MenuItem <MenuItem
icon={<FaAsterisk />} icon={isLoading ? <SpinnerIcon /> : <FaAsterisk />}
onClickCapture={handleUseAllParameters} onClickCapture={handleUseAllParameters}
isDisabled={!metadata} isDisabled={isLoading || !metadata}
> >
{t('parameters.useAll')} {t('parameters.useAll')}
</MenuItem> </MenuItem>
@ -228,20 +249,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
> >
{t('gallery.deleteImage')} {t('gallery.deleteImage')}
</MenuItem> </MenuItem>
{metadata?.created_by && (
<Flex
sx={{
padding: '5px 10px',
marginTop: '5px',
}}
>
<Text fontSize="xs" fontWeight="bold">
Created by {metadata?.created_by}
</Text>
</Flex>
)}
</> </>
); );
}; };
export default memo(SingleSelectionMenuItems); export default memo(SingleSelectionMenuItems);
const SpinnerIcon = () => (
<Flex w="14px" alignItems="center" justifyContent="center">
<Spinner size="xs" />
</Flex>
);

View File

@ -2,7 +2,7 @@ import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
import { isString } from 'lodash-es'; import { isString } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { FaCopy, FaSave } from 'react-icons/fa'; import { FaCopy, FaDownload } from 'react-icons/fa';
type Props = { type Props = {
label: string; label: string;
@ -23,7 +23,7 @@ const DataViewer = (props: Props) => {
navigator.clipboard.writeText(dataString); navigator.clipboard.writeText(dataString);
}, [dataString]); }, [dataString]);
const handleSave = useCallback(() => { const handleDownload = useCallback(() => {
const blob = new Blob([dataString]); const blob = new Blob([dataString]);
const a = document.createElement('a'); const a = document.createElement('a');
a.href = URL.createObjectURL(blob); a.href = URL.createObjectURL(blob);
@ -73,13 +73,13 @@ const DataViewer = (props: Props) => {
</Box> </Box>
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}> <Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
{withDownload && ( {withDownload && (
<Tooltip label={`Save ${label} JSON`}> <Tooltip label={`Download ${label} JSON`}>
<IconButton <IconButton
aria-label={`Save ${label} JSON`} aria-label={`Download ${label} JSON`}
icon={<FaSave />} icon={<FaDownload />}
variant="ghost" variant="ghost"
opacity={0.7} opacity={0.7}
onClick={handleSave} onClick={handleDownload}
/> />
</Tooltip> </Tooltip>
)} )}

View File

@ -1,10 +1,10 @@
import { CoreMetadata } from 'features/nodes/types/types';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { UnsafeImageMetadata } from 'services/api/types';
import ImageMetadataItem from './ImageMetadataItem'; import ImageMetadataItem from './ImageMetadataItem';
type Props = { type Props = {
metadata?: UnsafeImageMetadata['metadata']; metadata?: CoreMetadata;
}; };
const ImageMetadataActions = (props: Props) => { const ImageMetadataActions = (props: Props) => {
@ -94,14 +94,14 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallNegativePrompt} onClick={handleRecallNegativePrompt}
/> />
)} )}
{metadata.seed !== undefined && ( {metadata.seed !== undefined && metadata.seed !== null && (
<ImageMetadataItem <ImageMetadataItem
label="Seed" label="Seed"
value={metadata.seed} value={metadata.seed}
onClick={handleRecallSeed} onClick={handleRecallSeed}
/> />
)} )}
{metadata.model !== undefined && ( {metadata.model !== undefined && metadata.model !== null && (
<ImageMetadataItem <ImageMetadataItem
label="Model" label="Model"
value={metadata.model.model_name} value={metadata.model.model_name}
@ -150,7 +150,7 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallSteps} onClick={handleRecallSteps}
/> />
)} )}
{metadata.cfg_scale !== undefined && ( {metadata.cfg_scale !== undefined && metadata.cfg_scale !== null && (
<ImageMetadataItem <ImageMetadataItem
label="CFG scale" label="CFG scale"
value={metadata.cfg_scale} value={metadata.cfg_scale}

View File

@ -9,14 +9,12 @@ import {
Tabs, Tabs,
Text, Text,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react'; import { memo } from 'react';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images'; import { useGetImageMetadataFromFileQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
import ImageMetadataActions from './ImageMetadataActions';
import DataViewer from './DataViewer'; import DataViewer from './DataViewer';
import ImageMetadataActions from './ImageMetadataActions';
type ImageMetadataViewerProps = { type ImageMetadataViewerProps = {
image: ImageDTO; image: ImageDTO;
@ -29,19 +27,16 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
// dispatch(setShouldShowImageDetails(false)); // dispatch(setShouldShowImageDetails(false));
// }); // });
const [debouncedMetadataQueryArg, debounceState] = useDebounce( const { metadata, workflow } = useGetImageMetadataFromFileQuery(
image.image_name, image.image_name,
500 {
selectFromResult: (res) => ({
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
}
); );
const { currentData } = useGetImageMetadataQuery(
debounceState.isPending()
? skipToken
: debouncedMetadataQueryArg ?? skipToken
);
const metadata = currentData?.metadata;
const graph = currentData?.graph;
return ( return (
<Flex <Flex
layerStyle="first" layerStyle="first"
@ -71,17 +66,17 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }} sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
> >
<TabList> <TabList>
<Tab>Core Metadata</Tab> <Tab>Metadata</Tab>
<Tab>Image Details</Tab> <Tab>Image Details</Tab>
<Tab>Graph</Tab> <Tab>Workflow</Tab>
</TabList> </TabList>
<TabPanels> <TabPanels>
<TabPanel> <TabPanel>
{metadata ? ( {metadata ? (
<DataViewer data={metadata} label="Core Metadata" /> <DataViewer data={metadata} label="Metadata" />
) : ( ) : (
<IAINoContentFallback label="No core metadata found" /> <IAINoContentFallback label="No metadata found" />
)} )}
</TabPanel> </TabPanel>
<TabPanel> <TabPanel>
@ -92,10 +87,10 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
)} )}
</TabPanel> </TabPanel>
<TabPanel> <TabPanel>
{graph ? ( {workflow ? (
<DataViewer data={graph} label="Graph" /> <DataViewer data={workflow} label="Workflow" />
) : ( ) : (
<IAINoContentFallback label="No graph found" /> <IAINoContentFallback label="No workflow found" />
)} )}
</TabPanel> </TabPanel>
</TabPanels> </TabPanels>

View File

@ -0,0 +1,41 @@
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEmbedWorkflow } from 'features/nodes/hooks/useEmbedWorkflow';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { nodeEmbedWorkflowChanged } from 'features/nodes/store/nodesSlice';
import { ChangeEvent, memo, useCallback } from 'react';
const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const embedWorkflow = useEmbedWorkflow(nodeId);
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
nodeEmbedWorkflowChanged({
nodeId,
embedWorkflow: e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Embed Workflow</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChange}
isChecked={embedWorkflow}
/>
</FormControl>
);
};
export default memo(EmbedWorkflowCheckbox);

View File

@ -41,7 +41,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
flexDirection: 'column', flexDirection: 'column',
w: 'full', w: 'full',
h: 'full', h: 'full',
py: 1, py: 2,
gap: 1, gap: 1,
borderBottomRadius: withFooter ? 0 : 'base', borderBottomRadius: withFooter ? 0 : 'base',
}} }}

View File

@ -1,16 +1,8 @@
import { import { Flex } from '@chakra-ui/react';
Checkbox,
Flex,
FormControl,
FormLabel,
Spacer,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { useIsIntermediate } from 'features/nodes/hooks/useIsIntermediate';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { ChangeEvent, memo, useCallback } from 'react'; import { memo } from 'react';
import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox';
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
type Props = { type Props = {
nodeId: string; nodeId: string;
@ -27,48 +19,13 @@ const InvocationNodeFooter = ({ nodeId }: Props) => {
px: 2, px: 2,
py: 0, py: 0,
h: 6, h: 6,
justifyContent: 'space-between',
}} }}
> >
<Spacer /> <EmbedWorkflowCheckbox nodeId={nodeId} />
<SaveImageCheckbox nodeId={nodeId} /> <SaveToGalleryCheckbox nodeId={nodeId} />
</Flex> </Flex>
); );
}; };
export default memo(InvocationNodeFooter); export default memo(InvocationNodeFooter);
const SaveImageCheckbox = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const is_intermediate = useIsIntermediate(nodeId);
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId,
fieldName: 'is_intermediate',
value: !e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChangeIsIntermediate}
isChecked={!is_intermediate}
/>
</FormControl>
);
});
SaveImageCheckbox.displayName = 'SaveImageCheckbox';

View File

@ -1,7 +1,5 @@
import { import {
Flex, Flex,
FormControl,
FormLabel,
Icon, Icon,
Modal, Modal,
ModalBody, ModalBody,
@ -14,16 +12,14 @@ import {
Tooltip, Tooltip,
useDisclosure, useDisclosure,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import { useNodeData } from 'features/nodes/hooks/useNodeData'; import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel'; import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle'; import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { isInvocationNodeData } from 'features/nodes/types/types'; import { isInvocationNodeData } from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react'; import { memo, useMemo } from 'react';
import { FaInfoCircle } from 'react-icons/fa'; import { FaInfoCircle } from 'react-icons/fa';
import NotesTextarea from './NotesTextarea';
interface Props { interface Props {
nodeId: string; nodeId: string;
@ -80,13 +76,29 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
const data = useNodeData(nodeId); const data = useNodeData(nodeId);
const nodeTemplate = useNodeTemplate(nodeId); const nodeTemplate = useNodeTemplate(nodeId);
const title = useMemo(() => {
if (data?.label && nodeTemplate?.title) {
return `${data.label} (${nodeTemplate.title})`;
}
if (data?.label && !nodeTemplate) {
return data.label;
}
if (!data?.label && nodeTemplate) {
return nodeTemplate.title;
}
return 'Unknown Node';
}, [data, nodeTemplate]);
if (!isInvocationNodeData(data)) { if (!isInvocationNodeData(data)) {
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>; return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
} }
return ( return (
<Flex sx={{ flexDir: 'column' }}> <Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}>{nodeTemplate?.title}</Text> <Text sx={{ fontWeight: 600 }}>{title}</Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}> <Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{nodeTemplate?.description} {nodeTemplate?.description}
</Text> </Text>
@ -96,29 +108,3 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
}); });
TooltipContent.displayName = 'TooltipContent'; TooltipContent.displayName = 'TooltipContent';
const NotesTextarea = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const data = useNodeData(nodeId);
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
},
[dispatch, nodeId]
);
if (!isInvocationNodeData(data)) {
return null;
}
return (
<FormControl>
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data?.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
);
});
NotesTextarea.displayName = 'NodesTextarea';

View File

@ -0,0 +1,33 @@
import { FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { isInvocationNodeData } from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const data = useNodeData(nodeId);
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
},
[dispatch, nodeId]
);
if (!isInvocationNodeData(data)) {
return null;
}
return (
<FormControl>
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data?.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
);
};
export default memo(NotesTextarea);

View File

@ -0,0 +1,41 @@
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { useIsIntermediate } from 'features/nodes/hooks/useIsIntermediate';
import { nodeIsIntermediateChanged } from 'features/nodes/store/nodesSlice';
import { ChangeEvent, memo, useCallback } from 'react';
const SaveToGalleryCheckbox = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const isIntermediate = useIsIntermediate(nodeId);
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
nodeIsIntermediateChanged({
nodeId,
isIntermediate: !e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save to Gallery</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChange}
isChecked={!isIntermediate}
/>
</FormControl>
);
};
export default memo(SaveToGalleryCheckbox);

View File

@ -0,0 +1,167 @@
import {
Editable,
EditableInput,
EditablePreview,
Flex,
Tooltip,
forwardRef,
useEditableControls,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
import FieldTooltipContent from './FieldTooltipContent';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
interface Props {
nodeId: string;
fieldName: string;
kind: 'input' | 'output';
isMissingInput?: boolean;
withTooltip?: boolean;
}
const EditableFieldTitle = forwardRef((props: Props, ref) => {
const {
nodeId,
fieldName,
kind,
isMissingInput = false,
withTooltip = false,
} = props;
const label = useFieldLabel(nodeId, fieldName);
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind);
const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(
label || fieldTemplateTitle || 'Unknown Field'
);
const handleSubmit = useCallback(
async (newTitle: string) => {
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
return;
}
setLocalTitle(newTitle || fieldTemplateTitle || 'Unknown Field');
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
},
[label, fieldTemplateTitle, dispatch, nodeId, fieldName]
);
const handleChange = useCallback((newTitle: string) => {
setLocalTitle(newTitle);
}, []);
useEffect(() => {
// Another component may change the title; sync local title with global state
setLocalTitle(label || fieldTemplateTitle || 'Unknown Field');
}, [label, fieldTemplateTitle]);
return (
<Tooltip
label={
withTooltip ? (
<FieldTooltipContent
nodeId={nodeId}
fieldName={fieldName}
kind="input"
/>
) : undefined
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
placement="top"
hasArrow
>
<Flex
ref={ref}
sx={{
position: 'relative',
overflow: 'hidden',
alignItems: 'center',
justifyContent: 'flex-start',
gap: 1,
h: 'full',
}}
>
<Editable
value={localTitle}
onChange={handleChange}
onSubmit={handleSubmit}
as={Flex}
sx={{
position: 'relative',
alignItems: 'center',
h: 'full',
}}
>
<EditablePreview
sx={{
p: 0,
fontWeight: isMissingInput ? 600 : 400,
textAlign: 'left',
_hover: {
fontWeight: '600 !important',
},
}}
noOfLines={1}
/>
<EditableInput
className="nodrag"
sx={{
p: 0,
w: 'full',
fontWeight: 600,
color: 'base.900',
_dark: {
color: 'base.100',
},
_focusVisible: {
p: 0,
textAlign: 'left',
boxShadow: 'none',
},
}}
/>
<EditableControls />
</Editable>
</Flex>
</Tooltip>
);
});
export default memo(EditableFieldTitle);
const EditableControls = memo(() => {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
const { onClick } = getEditButtonProps();
if (!onClick) {
return;
}
onClick(e);
e.preventDefault();
},
[getEditButtonProps]
);
if (isEditing) {
return null;
}
return (
<Flex
onClick={handleClick}
position="absolute"
w="full"
h="full"
top={0}
insetInlineStart={0}
cursor="text"
/>
);
});
EditableControls.displayName = 'EditableControls';

View File

@ -1,16 +1,7 @@
import { import { Flex, Text, forwardRef } from '@chakra-ui/react';
Editable,
EditableInput,
EditablePreview,
Flex,
forwardRef,
useEditableControls,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel'; import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle'; import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice'; import { memo } from 'react';
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
interface Props { interface Props {
nodeId: string; nodeId: string;
@ -24,31 +15,6 @@ const FieldTitle = forwardRef((props: Props, ref) => {
const label = useFieldLabel(nodeId, fieldName); const label = useFieldLabel(nodeId, fieldName);
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind); const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind);
const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(
label || fieldTemplateTitle || 'Unknown Field'
);
const handleSubmit = useCallback(
async (newTitle: string) => {
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
return;
}
setLocalTitle(newTitle || fieldTemplateTitle || 'Unknown Field');
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
},
[label, fieldTemplateTitle, dispatch, nodeId, fieldName]
);
const handleChange = useCallback((newTitle: string) => {
setLocalTitle(newTitle);
}, []);
useEffect(() => {
// Another component may change the title; sync local title with global state
setLocalTitle(label || fieldTemplateTitle || 'Unknown Field');
}, [label, fieldTemplateTitle]);
return ( return (
<Flex <Flex
ref={ref} ref={ref}
@ -62,82 +28,11 @@ const FieldTitle = forwardRef((props: Props, ref) => {
w: 'full', w: 'full',
}} }}
> >
<Editable <Text sx={{ fontWeight: isMissingInput ? 600 : 400 }}>
value={localTitle} {label || fieldTemplateTitle}
onChange={handleChange} </Text>
onSubmit={handleSubmit}
as={Flex}
sx={{
position: 'relative',
alignItems: 'center',
h: 'full',
w: 'full',
}}
>
<EditablePreview
sx={{
p: 0,
fontWeight: isMissingInput ? 600 : 400,
textAlign: 'left',
_hover: {
fontWeight: '600 !important',
},
}}
noOfLines={1}
/>
<EditableInput
className="nodrag"
sx={{
p: 0,
fontWeight: 600,
color: 'base.900',
_dark: {
color: 'base.100',
},
_focusVisible: {
p: 0,
textAlign: 'left',
boxShadow: 'none',
},
}}
/>
<EditableControls />
</Editable>
</Flex> </Flex>
); );
}); });
export default memo(FieldTitle); export default memo(FieldTitle);
const EditableControls = memo(() => {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
const { onClick } = getEditButtonProps();
if (!onClick) {
return;
}
onClick(e);
e.preventDefault();
},
[getEditButtonProps]
);
if (isEditing) {
return null;
}
return (
<Flex
onClick={handleClick}
position="absolute"
w="full"
h="full"
top={0}
insetInlineStart={0}
cursor="text"
/>
);
});
EditableControls.displayName = 'EditableControls';

View File

@ -34,6 +34,8 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
} }
return 'Unknown Field'; return 'Unknown Field';
} else {
return fieldTemplate?.title || 'Unknown Field';
} }
}, [field, fieldTemplate]); }, [field, fieldTemplate]);

View File

@ -1,16 +1,11 @@
import { Box, Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; import { Box, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import SelectionOverlay from 'common/components/SelectionOverlay';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { useDoesInputHaveValue } from 'features/nodes/hooks/useDoesInputHaveValue'; import { useDoesInputHaveValue } from 'features/nodes/hooks/useDoesInputHaveValue';
import { useFieldInputKind } from 'features/nodes/hooks/useFieldInputKind';
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { useIsMouseOverField } from 'features/nodes/hooks/useIsMouseOverField';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { PropsWithChildren, memo, useMemo } from 'react'; import { PropsWithChildren, memo, useMemo } from 'react';
import EditableFieldTitle from './EditableFieldTitle';
import FieldContextMenu from './FieldContextMenu'; import FieldContextMenu from './FieldContextMenu';
import FieldHandle from './FieldHandle'; import FieldHandle from './FieldHandle';
import FieldTitle from './FieldTitle';
import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer'; import InputFieldRenderer from './InputFieldRenderer';
interface Props { interface Props {
@ -21,7 +16,6 @@ interface Props {
const InputField = ({ nodeId, fieldName }: Props) => { const InputField = ({ nodeId, fieldName }: Props) => {
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input'); const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName); const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
const input = useFieldInputKind(nodeId, fieldName);
const { const {
isConnected, isConnected,
@ -51,11 +45,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
if (fieldTemplate?.fieldKind !== 'input') { if (fieldTemplate?.fieldKind !== 'input') {
return ( return (
<InputFieldWrapper <InputFieldWrapper shouldDim={shouldDim}>
nodeId={nodeId}
fieldName={fieldName}
shouldDim={shouldDim}
>
<FormControl <FormControl
sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }} sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }}
> >
@ -66,19 +56,14 @@ const InputField = ({ nodeId, fieldName }: Props) => {
} }
return ( return (
<InputFieldWrapper <InputFieldWrapper shouldDim={shouldDim}>
nodeId={nodeId}
fieldName={fieldName}
shouldDim={shouldDim}
>
<FormControl <FormControl
as={Flex}
isInvalid={isMissingInput} isInvalid={isMissingInput}
isDisabled={isConnected} isDisabled={isConnected}
sx={{ sx={{
alignItems: 'stretch', alignItems: 'stretch',
justifyContent: 'space-between', justifyContent: 'space-between',
ps: 2, ps: fieldTemplate.input === 'direct' ? 0 : 2,
gap: 2, gap: 2,
h: 'full', h: 'full',
w: 'full', w: 'full',
@ -86,42 +71,27 @@ const InputField = ({ nodeId, fieldName }: Props) => {
> >
<FieldContextMenu nodeId={nodeId} fieldName={fieldName} kind="input"> <FieldContextMenu nodeId={nodeId} fieldName={fieldName} kind="input">
{(ref) => ( {(ref) => (
<Tooltip
label={
<FieldTooltipContent
nodeId={nodeId}
fieldName={fieldName}
kind="input"
/>
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
placement="top"
hasArrow
>
<FormLabel <FormLabel
sx={{ sx={{
display: 'flex',
alignItems: 'center',
mb: 0, mb: 0,
width: input === 'connection' ? 'auto' : '25%', px: 1,
flexShrink: 0, gap: 2,
flexGrow: 0,
}} }}
> >
<FieldTitle <EditableFieldTitle
ref={ref} ref={ref}
nodeId={nodeId} nodeId={nodeId}
fieldName={fieldName} fieldName={fieldName}
kind="input" kind="input"
isMissingInput={isMissingInput} isMissingInput={isMissingInput}
withTooltip
/> />
</FormLabel> </FormLabel>
</Tooltip>
)} )}
</FieldContextMenu> </FieldContextMenu>
<Box <Box>
sx={{
width: input === 'connection' ? 'auto' : '75%',
}}
>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} /> <InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
</Box> </Box>
</FormControl> </FormControl>
@ -143,19 +113,12 @@ export default memo(InputField);
type InputFieldWrapperProps = PropsWithChildren<{ type InputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean; shouldDim: boolean;
nodeId: string;
fieldName: string;
}>; }>;
const InputFieldWrapper = memo( const InputFieldWrapper = memo(
({ shouldDim, nodeId, fieldName, children }: InputFieldWrapperProps) => { ({ shouldDim, children }: InputFieldWrapperProps) => {
const { isMouseOverField, handleMouseOver, handleMouseOut } =
useIsMouseOverField(nodeId, fieldName);
return ( return (
<Flex <Flex
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
sx={{ sx={{
position: 'relative', position: 'relative',
minH: 8, minH: 8,
@ -169,7 +132,6 @@ const InputFieldWrapper = memo(
}} }}
> >
{children} {children}
<SelectionOverlay isSelected={false} isHovered={isMouseOverField} />
</Flex> </Flex>
); );
} }

View File

@ -1,13 +1,20 @@
import { Flex, FormControl, FormLabel, Icon, Tooltip } from '@chakra-ui/react'; import {
Flex,
FormControl,
FormLabel,
Icon,
Spacer,
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import SelectionOverlay from 'common/components/SelectionOverlay'; import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { useIsMouseOverField } from 'features/nodes/hooks/useIsMouseOverField'; import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { workflowExposedFieldRemoved } from 'features/nodes/store/nodesSlice'; import { workflowExposedFieldRemoved } from 'features/nodes/store/nodesSlice';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { FaInfoCircle, FaTrash } from 'react-icons/fa'; import { FaInfoCircle, FaTrash } from 'react-icons/fa';
import FieldTitle from './FieldTitle'; import EditableFieldTitle from './EditableFieldTitle';
import FieldTooltipContent from './FieldTooltipContent'; import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer'; import InputFieldRenderer from './InputFieldRenderer';
@ -18,8 +25,8 @@ type Props = {
const LinearViewField = ({ nodeId, fieldName }: Props) => { const LinearViewField = ({ nodeId, fieldName }: Props) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { isMouseOverField, handleMouseOut, handleMouseOver } = const { isMouseOverNode, handleMouseOut, handleMouseOver } =
useIsMouseOverField(nodeId, fieldName); useMouseOverNode(nodeId);
const handleRemoveField = useCallback(() => { const handleRemoveField = useCallback(() => {
dispatch(workflowExposedFieldRemoved({ nodeId, fieldName })); dispatch(workflowExposedFieldRemoved({ nodeId, fieldName }));
@ -27,8 +34,8 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
return ( return (
<Flex <Flex
onMouseOver={handleMouseOver} onMouseEnter={handleMouseOver}
onMouseOut={handleMouseOut} onMouseLeave={handleMouseOut}
layerStyle="second" layerStyle="second"
sx={{ sx={{
position: 'relative', position: 'relative',
@ -42,11 +49,15 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
sx={{ sx={{
display: 'flex', display: 'flex',
alignItems: 'center', alignItems: 'center',
justifyContent: 'space-between',
mb: 0, mb: 0,
}} }}
> >
<FieldTitle nodeId={nodeId} fieldName={fieldName} kind="input" /> <EditableFieldTitle
nodeId={nodeId}
fieldName={fieldName}
kind="input"
/>
<Spacer />
<Tooltip <Tooltip
label={ label={
<FieldTooltipContent <FieldTooltipContent
@ -74,7 +85,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
</FormLabel> </FormLabel>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} /> <InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
</FormControl> </FormControl>
<SelectionOverlay isSelected={false} isHovered={isMouseOverField} /> <NodeSelectionOverlay isSelected={false} isHovered={isMouseOverNode} />
</Flex> </Flex>
); );
}; };


@ -92,6 +92,7 @@ const ControlNetModelInputFieldComponent = (
error={!selectedModel} error={!selectedModel}
data={data} data={data}
onChange={handleValueChanged} onChange={handleValueChanged}
sx={{ width: '100%' }}
/> />
); );
}; };


@ -101,8 +101,10 @@ const LoRAModelInputFieldComponent = (
item.label?.toLowerCase().includes(value.toLowerCase().trim()) || item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()) item.value.toLowerCase().includes(value.toLowerCase().trim())
} }
error={!selectedLoRAModel}
onChange={handleChange} onChange={handleChange}
sx={{ sx={{
width: '100%',
'.mantine-Select-dropdown': { '.mantine-Select-dropdown': {
width: '16rem !important', width: '16rem !important',
}, },


@ -134,6 +134,7 @@ const MainModelInputFieldComponent = (
disabled={data.length === 0} disabled={data.length === 0}
onChange={handleChangeModel} onChange={handleChangeModel}
sx={{ sx={{
width: '100%',
'.mantine-Select-dropdown': { '.mantine-Select-dropdown': {
width: '16rem !important', width: '16rem !important',
}, },


@ -1,12 +1,12 @@
import { Box, Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
FieldComponentProps,
SDXLRefinerModelInputFieldTemplate, SDXLRefinerModelInputFieldTemplate,
SDXLRefinerModelInputFieldValue, SDXLRefinerModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
@ -101,20 +101,17 @@ const RefinerModelInputFieldComponent = (
value={selectedModel?.id} value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'} placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data} data={data}
error={data.length === 0} error={!selectedModel}
disabled={data.length === 0} disabled={data.length === 0}
onChange={handleChangeModel} onChange={handleChangeModel}
sx={{ sx={{
width: '100%',
'.mantine-Select-dropdown': { '.mantine-Select-dropdown': {
width: '16rem !important', width: '16rem !important',
}, },
}} }}
/> />
{isSyncModelEnabled && ( {isSyncModelEnabled && <SyncModelsButton className="nodrag" iconMode />}
<Box mt={7}>
<SyncModelsButton className="nodrag" iconMode />
</Box>
)}
</Flex> </Flex>
); );
}; };


@ -128,10 +128,11 @@ const ModelInputFieldComponent = (
value={selectedModel?.id} value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'} placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data} data={data}
error={data.length === 0} error={!selectedModel}
disabled={data.length === 0} disabled={data.length === 0}
onChange={handleChangeModel} onChange={handleChangeModel}
sx={{ sx={{
width: '100%',
'.mantine-Select-dropdown': { '.mantine-Select-dropdown': {
width: '16rem !important', width: '16rem !important',
}, },


@ -4,9 +4,9 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
FieldComponentProps,
VaeModelInputFieldTemplate, VaeModelInputFieldTemplate,
VaeModelInputFieldValue, VaeModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam'; import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
@ -88,17 +88,15 @@ const VaeModelInputFieldComponent = (
className="nowheel nodrag" className="nowheel nodrag"
itemComponent={IAIMantineSelectItemWithTooltip} itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description} tooltip={selectedVaeModel?.description}
label={
selectedVaeModel?.base_model &&
MODEL_TYPE_MAP[selectedVaeModel?.base_model]
}
value={selectedVaeModel?.id ?? 'default'} value={selectedVaeModel?.id ?? 'default'}
placeholder="Default" placeholder="Default"
data={data} data={data}
onChange={handleChangeModel} onChange={handleChangeModel}
disabled={data.length === 0} disabled={data.length === 0}
error={!selectedVaeModel}
clearable clearable
sx={{ sx={{
width: '100%',
'.mantine-Select-dropdown': { '.mantine-Select-dropdown': {
width: '16rem !important', width: '16rem !important',
}, },


@ -27,9 +27,11 @@ const NodeTitle = ({ nodeId, title }: Props) => {
const handleSubmit = useCallback( const handleSubmit = useCallback(
async (newTitle: string) => { async (newTitle: string) => {
dispatch(nodeLabelChanged({ nodeId, label: newTitle })); dispatch(nodeLabelChanged({ nodeId, label: newTitle }));
setLocalTitle(newTitle || title || 'Problem Setting Title'); setLocalTitle(
newTitle || title || templateTitle || 'Problem Setting Title'
);
}, },
[nodeId, dispatch, title] [dispatch, nodeId, title, templateTitle]
); );
const handleChange = useCallback((newTitle: string) => { const handleChange = useCallback((newTitle: string) => {


@ -7,6 +7,8 @@ import {
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { import {
DRAG_HANDLE_CLASSNAME, DRAG_HANDLE_CLASSNAME,
NODE_WIDTH, NODE_WIDTH,
@ -23,6 +25,8 @@ type NodeWrapperProps = PropsWithChildren & {
const NodeWrapper = (props: NodeWrapperProps) => { const NodeWrapper = (props: NodeWrapperProps) => {
const { nodeId, width, children, selected } = props; const { nodeId, width, children, selected } = props;
const { isMouseOverNode, handleMouseOut, handleMouseOver } =
useMouseOverNode(nodeId);
const selectIsInProgress = useMemo( const selectIsInProgress = useMemo(
() => () =>
@ -36,16 +40,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
const isInProgress = useAppSelector(selectIsInProgress); const isInProgress = useAppSelector(selectIsInProgress);
const [ const [nodeInProgressLight, nodeInProgressDark, shadowsXl, shadowsBase] =
nodeSelectedLight, useToken('shadows', [
nodeSelectedDark,
nodeInProgressLight,
nodeInProgressDark,
shadowsXl,
shadowsBase,
] = useToken('shadows', [
'nodeSelected.light',
'nodeSelected.dark',
'nodeInProgress.light', 'nodeInProgress.light',
'nodeInProgress.dark', 'nodeInProgress.dark',
'shadows.xl', 'shadows.xl',
@ -54,7 +50,6 @@ const NodeWrapper = (props: NodeWrapperProps) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selectedShadow = useColorModeValue(nodeSelectedLight, nodeSelectedDark);
const inProgressShadow = useColorModeValue( const inProgressShadow = useColorModeValue(
nodeInProgressLight, nodeInProgressLight,
nodeInProgressDark nodeInProgressDark
@ -69,6 +64,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
return ( return (
<Box <Box
onClick={handleClick} onClick={handleClick}
onMouseEnter={handleMouseOver}
onMouseLeave={handleMouseOut}
className={DRAG_HANDLE_CLASSNAME} className={DRAG_HANDLE_CLASSNAME}
sx={{ sx={{
h: 'full', h: 'full',
@ -77,11 +74,6 @@ const NodeWrapper = (props: NodeWrapperProps) => {
w: width ?? NODE_WIDTH, w: width ?? NODE_WIDTH,
transitionProperty: 'common', transitionProperty: 'common',
transitionDuration: '0.1s', transitionDuration: '0.1s',
shadow: selected
? isInProgress
? undefined
: selectedShadow
: undefined,
cursor: 'grab', cursor: 'grab',
opacity, opacity,
}} }}
@ -116,6 +108,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
}} }}
/> />
{children} {children}
<NodeSelectionOverlay isSelected={selected} isHovered={isMouseOverNode} />
</Box> </Box>
); );
}; };


@ -2,12 +2,12 @@ import IAIIconButton from 'common/components/IAIIconButton';
import { useWorkflow } from 'features/nodes/hooks/useWorkflow'; import { useWorkflow } from 'features/nodes/hooks/useWorkflow';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaSave } from 'react-icons/fa'; import { FaDownload } from 'react-icons/fa';
const SaveWorkflowButton = () => { const DownloadWorkflowButton = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const workflow = useWorkflow(); const workflow = useWorkflow();
const handleSave = useCallback(() => { const handleDownload = useCallback(() => {
const blob = new Blob([JSON.stringify(workflow, null, 2)]); const blob = new Blob([JSON.stringify(workflow, null, 2)]);
const a = document.createElement('a'); const a = document.createElement('a');
a.href = URL.createObjectURL(blob); a.href = URL.createObjectURL(blob);
@ -18,12 +18,12 @@ const SaveWorkflowButton = () => {
}, [workflow]); }, [workflow]);
return ( return (
<IAIIconButton <IAIIconButton
icon={<FaSave />} icon={<FaDownload />}
tooltip={t('nodes.saveWorkflow')} tooltip={t('nodes.downloadWorkflow')}
aria-label={t('nodes.saveWorkflow')} aria-label={t('nodes.downloadWorkflow')}
onClick={handleSave} onClick={handleDownload}
/> />
); );
}; };
export default memo(SaveWorkflowButton); export default memo(DownloadWorkflowButton);
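The renamed button serializes the current workflow and triggers a browser download through a temporary object URL. A minimal standalone sketch of that pattern (the `downloadJson` helper and its default filename are illustrative, not part of this PR):

```ts
// Sketch: download any serializable object as a pretty-printed JSON file.
const downloadJson = (data: unknown, filename = 'workflow.json') => {
  const blob = new Blob([JSON.stringify(data, null, 2)], {
    type: 'application/json',
  });
  const url = URL.createObjectURL(blob);
  const a = document.createElement('a');
  a.href = url;
  a.download = filename;
  a.click();
  // Free the object URL once the click has been dispatched.
  URL.revokeObjectURL(url);
};
```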


@ -2,7 +2,7 @@ import { Flex } from '@chakra-ui/layout';
import { memo } from 'react'; import { memo } from 'react';
import LoadWorkflowButton from './LoadWorkflowButton'; import LoadWorkflowButton from './LoadWorkflowButton';
import ResetWorkflowButton from './ResetWorkflowButton'; import ResetWorkflowButton from './ResetWorkflowButton';
import SaveWorkflowButton from './SaveWorkflowButton'; import DownloadWorkflowButton from './DownloadWorkflowButton';
const TopCenterPanel = () => { const TopCenterPanel = () => {
return ( return (
@ -15,7 +15,7 @@ const TopCenterPanel = () => {
transform: 'translate(-50%)', transform: 'translate(-50%)',
}} }}
> >
<SaveWorkflowButton /> <DownloadWorkflowButton />
<LoadWorkflowButton /> <LoadWorkflowButton />
<ResetWorkflowButton /> <ResetWorkflowButton />
</Flex> </Flex>


@ -0,0 +1,74 @@
import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { InvocationTemplate, NodeData } from 'features/nodes/types/types';
import { memo } from 'react';
import NotesTextarea from '../../flow/nodes/Invocation/NotesTextarea';
import NodeTitle from '../../flow/nodes/common/NodeTitle';
import ScrollableContent from '../ScrollableContent';
const selector = createSelector(
stateSelector,
({ nodes }) => {
const lastSelectedNodeId =
nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find(
(node) => node.id === lastSelectedNodeId
);
const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type]
: undefined;
return {
data: lastSelectedNode?.data,
template: lastSelectedNodeTemplate,
};
},
defaultSelectorOptions
);
const InspectorDetailsTab = () => {
const { data, template } = useAppSelector(selector);
if (!template || !data) {
return <IAINoContentFallback label="No node selected" icon={null} />;
}
return <Content data={data} template={template} />;
};
export default memo(InspectorDetailsTab);
const Content = (props: { data: NodeData; template: InvocationTemplate }) => {
const { data } = props;
return (
<Box
sx={{
position: 'relative',
w: 'full',
h: 'full',
}}
>
<ScrollableContent>
<Flex
sx={{
flexDir: 'column',
position: 'relative',
p: 1,
gap: 2,
w: 'full',
}}
>
<NodeTitle nodeId={data.id} />
<NotesTextarea nodeId={data.id} />
</Flex>
</ScrollableContent>
</Box>
);
};


@ -4,12 +4,13 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { isInvocationNode } from 'features/nodes/types/types';
import { memo } from 'react'; import { memo } from 'react';
import ImageOutputPreview from './outputs/ImageOutputPreview'; import { ImageOutput } from 'services/api/types';
import ScrollableContent from '../ScrollableContent';
import { AnyResult } from 'services/events/types'; import { AnyResult } from 'services/events/types';
import StringOutputPreview from './outputs/StringOutputPreview'; import ScrollableContent from '../ScrollableContent';
import NumberOutputPreview from './outputs/NumberOutputPreview'; import ImageOutputPreview from './outputs/ImageOutputPreview';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
@ -21,11 +22,16 @@ const selector = createSelector(
(node) => node.id === lastSelectedNodeId (node) => node.id === lastSelectedNodeId
); );
const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type]
: undefined;
const nes = const nes =
nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__']; nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
return { return {
node: lastSelectedNode, node: lastSelectedNode,
template: lastSelectedNodeTemplate,
nes, nes,
}; };
}, },
@ -33,9 +39,9 @@ const selector = createSelector(
); );
const InspectorOutputsTab = () => { const InspectorOutputsTab = () => {
const { node, nes } = useAppSelector(selector); const { node, template, nes } = useAppSelector(selector);
if (!node || !nes) { if (!node || !nes || !isInvocationNode(node)) {
return <IAINoContentFallback label="No node selected" icon={null} />; return <IAINoContentFallback label="No node selected" icon={null} />;
} }
@ -63,33 +69,16 @@ const InspectorOutputsTab = () => {
w: 'full', w: 'full',
}} }}
> >
{nes.outputs.map((result, i) => { {template?.outputType === 'image_output' ? (
if (result.type === 'string_output') { nes.outputs.map((result, i) => (
return ( <ImageOutputPreview
<StringOutputPreview key={getKey(result, i)} output={result} /> key={getKey(result, i)}
); output={result as ImageOutput}
} />
if (result.type === 'float_output') { ))
return ( ) : (
<NumberOutputPreview key={getKey(result, i)} output={result} /> <DataViewer data={nes.outputs} label="Node Outputs" />
); )}
}
if (result.type === 'integer_output') {
return (
<NumberOutputPreview key={getKey(result, i)} output={result} />
);
}
if (result.type === 'image_output') {
return (
<ImageOutputPreview key={getKey(result, i)} output={result} />
);
}
return (
<pre key={getKey(result, i)}>
{JSON.stringify(result, null, 2)}
</pre>
);
})}
</Flex> </Flex>
</ScrollableContent> </ScrollableContent>
</Box> </Box>


@ -10,6 +10,7 @@ import { memo } from 'react';
import InspectorDataTab from './InspectorDataTab'; import InspectorDataTab from './InspectorDataTab';
import InspectorOutputsTab from './InspectorOutputsTab'; import InspectorOutputsTab from './InspectorOutputsTab';
import InspectorTemplateTab from './InspectorTemplateTab'; import InspectorTemplateTab from './InspectorTemplateTab';
// import InspectorDetailsTab from './InspectorDetailsTab';
const InspectorPanel = () => { const InspectorPanel = () => {
return ( return (
@ -29,12 +30,16 @@ const InspectorPanel = () => {
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }} sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
> >
<TabList> <TabList>
{/* <Tab>Details</Tab> */}
<Tab>Outputs</Tab> <Tab>Outputs</Tab>
<Tab>Data</Tab> <Tab>Data</Tab>
<Tab>Template</Tab> <Tab>Template</Tab>
</TabList> </TabList>
<TabPanels> <TabPanels>
{/* <TabPanel>
<InspectorDetailsTab />
</TabPanel> */}
<TabPanel> <TabPanel>
<InspectorOutputsTab /> <InspectorOutputsTab />
</TabPanel> </TabPanel>


@ -1,13 +0,0 @@
import { Text } from '@chakra-ui/react';
import { memo } from 'react';
import { FloatOutput, IntegerOutput } from 'services/api/types';
type Props = {
output: IntegerOutput | FloatOutput;
};
const NumberOutputPreview = ({ output }: Props) => {
return <Text>{output.value}</Text>;
};
export default memo(NumberOutputPreview);


@ -1,13 +0,0 @@
import { Text } from '@chakra-ui/react';
import { memo } from 'react';
import { StringOutput } from 'services/api/types';
type Props = {
output: StringOutput;
};
const StringOutputPreview = ({ output }: Props) => {
return <Text>{output.value}</Text>;
};
export default memo(StringOutputPreview);


@ -22,6 +22,7 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
} }
return map(nodeTemplate.inputs) return map(nodeTemplate.inputs)
.filter((field) => ['any', 'direct'].includes(field.input)) .filter((field) => ['any', 'direct'].includes(field.input))
.filter((field) => !field.ui_hidden)
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0)) .sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
.map((field) => field.name) .map((field) => field.name)
.filter((fieldName) => fieldName !== 'is_intermediate'); .filter((fieldName) => fieldName !== 'is_intermediate');


@ -143,6 +143,8 @@ export const useBuildNodeData = () => {
isOpen: true, isOpen: true,
label: '', label: '',
notes: '', notes: '',
embedWorkflow: false,
isIntermediate: true,
}, },
}; };


@ -22,6 +22,7 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
} }
return map(nodeTemplate.inputs) return map(nodeTemplate.inputs)
.filter((field) => field.input === 'connection') .filter((field) => field.input === 'connection')
.filter((field) => !field.ui_hidden)
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0)) .sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
.map((field) => field.name) .map((field) => field.name)
.filter((fieldName) => fieldName !== 'is_intermediate'); .filter((fieldName) => fieldName !== 'is_intermediate');


@ -0,0 +1,27 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
export const useEmbedWorkflow = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return node.data.embedWorkflow;
},
defaultSelectorOptions
),
[nodeId]
);
const embedWorkflow = useAppSelector(selector);
return embedWorkflow;
};
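`useEmbedWorkflow` memoizes a per-node selector, so a consuming component only re-renders when its own node's `embedWorkflow` flag changes. A hedged sketch of a checkbox wired to the hook and the new `nodeEmbedWorkflowChanged` action (the component itself is hypothetical, and the hook's import path is assumed from its location under `features/nodes/hooks`):

```tsx
import { Checkbox } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEmbedWorkflow } from 'features/nodes/hooks/useEmbedWorkflow';
import { nodeEmbedWorkflowChanged } from 'features/nodes/store/nodesSlice';
import { ChangeEvent, memo, useCallback } from 'react';

type Props = { nodeId: string };

// Hypothetical checkbox: toggles whether this node embeds the workflow.
const EmbedWorkflowCheckbox = ({ nodeId }: Props) => {
  const dispatch = useAppDispatch();
  const embedWorkflow = useEmbedWorkflow(nodeId);
  const handleChange = useCallback(
    (e: ChangeEvent<HTMLInputElement>) => {
      dispatch(
        nodeEmbedWorkflowChanged({ nodeId, embedWorkflow: e.target.checked })
      );
    },
    [dispatch, nodeId]
  );
  return <Checkbox isChecked={embedWorkflow} onChange={handleChange} />;
};

export default memo(EmbedWorkflowCheckbox);
```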


@ -15,7 +15,7 @@ export const useIsIntermediate = (nodeId: string) => {
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
return false; return false;
} }
return Boolean(node.data.inputs.is_intermediate?.value); return node.data.isIntermediate;
}, },
defaultSelectorOptions defaultSelectorOptions
), ),


@ -3,7 +3,7 @@ import { useLogger } from 'app/logging/useLogger';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { workflowLoaded } from 'features/nodes/store/nodesSlice'; import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { zWorkflow } from 'features/nodes/types/types'; import { zValidatedWorkflow } from 'features/nodes/types/types';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
@ -24,31 +24,30 @@ export const useLoadWorkflowFromFile = () => {
try { try {
const parsedJSON = JSON.parse(String(rawJSON)); const parsedJSON = JSON.parse(String(rawJSON));
const result = zWorkflow.safeParse(parsedJSON); const result = zValidatedWorkflow.safeParse(parsedJSON);
if (!result.success) { if (!result.success) {
const message = fromZodError(result.error, { const { message } = fromZodError(result.error, {
prefix: 'Workflow Validation Error', prefix: 'Workflow Validation Error',
}).toString(); });
logger.error({ error: parseify(result.error) }, message); logger.error({ error: parseify(result.error) }, message);
dispatch( dispatch(
addToast( addToast(
makeToast({ makeToast({
title: 'Unable to Validate Workflow', title: 'Unable to Validate Workflow',
description: (
<WorkflowValidationErrorContent error={result.error} />
),
status: 'error', status: 'error',
duration: 5000, duration: 5000,
}) })
) )
); );
reader.abort();
return; return;
} }
dispatch(workflowLoaded(result.data.workflow));
dispatch(workflowLoaded(result.data)); if (!result.data.warnings.length) {
dispatch( dispatch(
addToast( addToast(
makeToast({ makeToast({
@ -58,9 +57,24 @@ export const useLoadWorkflowFromFile = () => {
) )
); );
reader.abort(); reader.abort();
} catch (error) { return;
}
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded with Warnings',
status: 'warning',
})
)
);
result.data.warnings.forEach(({ message, ...rest }) => {
logger.warn(rest, message);
});
reader.abort();
} catch {
// file reader error // file reader error
if (error) {
dispatch( dispatch(
addToast( addToast(
makeToast({ makeToast({
@ -70,7 +84,6 @@ export const useLoadWorkflowFromFile = () => {
) )
); );
} }
}
}; };
reader.readAsText(file); reader.readAsText(file);
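The reworked loader validates with `zValidatedWorkflow`, loads the workflow even when warnings are present, and surfaces each warning as a log entry plus a toast. A sketch of feeding a user-selected file into it, assuming the hook returns a callback that accepts a `File` (as the `reader.readAsText(file)` call suggests); the input component and import path are illustrative:

```tsx
import { useLoadWorkflowFromFile } from 'features/nodes/hooks/useLoadWorkflowFromFile';
import { ChangeEvent, memo, useCallback } from 'react';

// Hypothetical file picker that hands the selected .json to the loader.
const LoadWorkflowInput = () => {
  const loadWorkflowFromFile = useLoadWorkflowFromFile();
  const handleChange = useCallback(
    (e: ChangeEvent<HTMLInputElement>) => {
      const file = e.target.files?.[0];
      if (file) {
        loadWorkflowFromFile(file);
      }
    },
    [loadWorkflowFromFile]
  );
  return <input type="file" accept=".json" onChange={handleChange} />;
};

export default memo(LoadWorkflowInput);
```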


@ -0,0 +1,31 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback, useMemo } from 'react';
import { mouseOverNodeChanged } from '../store/nodesSlice';
export const useMouseOverNode = (nodeId: string) => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => nodes.mouseOverNode === nodeId,
defaultSelectorOptions
),
[nodeId]
);
const isMouseOverNode = useAppSelector(selector);
const handleMouseOver = useCallback(() => {
!isMouseOverNode && dispatch(mouseOverNodeChanged(nodeId));
}, [dispatch, nodeId, isMouseOverNode]);
const handleMouseOut = useCallback(() => {
isMouseOverNode && dispatch(mouseOverNodeChanged(null));
}, [dispatch, isMouseOverNode]);
return { isMouseOverNode, handleMouseOver, handleMouseOut };
};
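The hook keeps a single `mouseOverNode` id in the slice and only dispatches when the hovered state actually changes, which keeps hover tracking cheap even with many nodes on the canvas. A sketch of attaching it to any node-scoped container (the wrapper component and its styling are illustrative):

```tsx
import { Box } from '@chakra-ui/react';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { memo } from 'react';

// Illustrative wrapper: dims itself unless its node is hovered.
const HoverAwareContainer = ({ nodeId }: { nodeId: string }) => {
  const { isMouseOverNode, handleMouseOver, handleMouseOut } =
    useMouseOverNode(nodeId);
  return (
    <Box
      onMouseEnter={handleMouseOver}
      onMouseLeave={handleMouseOut}
      sx={{ opacity: isMouseOverNode ? 1 : 0.8 }}
    />
  );
};

export default memo(HoverAwareContainer);
```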


@ -21,6 +21,7 @@ export const useOutputFieldNames = (nodeId: string) => {
return []; return [];
} }
return map(nodeTemplate.outputs) return map(nodeTemplate.outputs)
.filter((field) => !field.ui_hidden)
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0)) .sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
.map((field) => field.name) .map((field) => field.name)
.filter((fieldName) => fieldName !== 'is_intermediate'); .filter((fieldName) => fieldName !== 'is_intermediate');


@ -1,5 +1,5 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit'; import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { cloneDeep, forEach, isEqual, uniqBy } from 'lodash-es'; import { cloneDeep, forEach, isEqual, map, uniqBy } from 'lodash-es';
import { import {
addEdge, addEdge,
applyEdgeChanges, applyEdgeChanges,
@ -18,7 +18,7 @@ import {
Viewport, Viewport,
} from 'reactflow'; } from 'reactflow';
import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { sessionInvoked } from 'services/api/thunks/session'; import { sessionCanceled, sessionInvoked } from 'services/api/thunks/session';
import { ImageField } from 'services/api/types'; import { ImageField } from 'services/api/types';
import { import {
appSocketGeneratorProgress, appSocketGeneratorProgress,
@ -102,6 +102,7 @@ export const initialNodesState: NodesState = {
nodeExecutionStates: {}, nodeExecutionStates: {},
viewport: { x: 0, y: 0, zoom: 1 }, viewport: { x: 0, y: 0, zoom: 1 },
mouseOverField: null, mouseOverField: null,
mouseOverNode: null,
nodesToCopy: [], nodesToCopy: [],
edgesToCopy: [], edgesToCopy: [],
selectionMode: SelectionMode.Partial, selectionMode: SelectionMode.Partial,
@ -245,6 +246,34 @@ const nodesSlice = createSlice({
} }
field.label = label; field.label = label;
}, },
nodeEmbedWorkflowChanged: (
state,
action: PayloadAction<{ nodeId: string; embedWorkflow: boolean }>
) => {
const { nodeId, embedWorkflow } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
node.data.embedWorkflow = embedWorkflow;
},
nodeIsIntermediateChanged: (
state,
action: PayloadAction<{ nodeId: string; isIntermediate: boolean }>
) => {
const { nodeId, isIntermediate } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
node.data.isIntermediate = isIntermediate;
},
nodeIsOpenChanged: ( nodeIsOpenChanged: (
state, state,
action: PayloadAction<{ nodeId: string; isOpen: boolean }> action: PayloadAction<{ nodeId: string; isOpen: boolean }>
@ -561,7 +590,7 @@ const nodesSlice = createSlice({
nodeEditorReset: (state) => { nodeEditorReset: (state) => {
state.nodes = []; state.nodes = [];
state.edges = []; state.edges = [];
state.workflow.exposedFields = []; state.workflow = cloneDeep(initialWorkflow);
}, },
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => { shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
state.shouldValidateGraph = action.payload; state.shouldValidateGraph = action.payload;
@ -637,6 +666,9 @@ const nodesSlice = createSlice({
) => { ) => {
state.mouseOverField = action.payload; state.mouseOverField = action.payload;
}, },
mouseOverNodeChanged: (state, action: PayloadAction<string | null>) => {
state.mouseOverNode = action.payload;
},
selectedAll: (state) => { selectedAll: (state) => {
state.nodes = applyNodeChanges( state.nodes = applyNodeChanges(
state.nodes.map((n) => ({ id: n.id, type: 'select', selected: true })), state.nodes.map((n) => ({ id: n.id, type: 'select', selected: true })),
@ -790,6 +822,13 @@ const nodesSlice = createSlice({
nes.outputs = []; nes.outputs = [];
}); });
}); });
builder.addCase(sessionCanceled.fulfilled, (state) => {
map(state.nodeExecutionStates, (nes) => {
if (nes.status === NodeStatus.IN_PROGRESS) {
nes.status = NodeStatus.PENDING;
}
});
});
}, },
}); });
@ -850,6 +889,9 @@ export const {
addNodePopoverClosed, addNodePopoverClosed,
addNodePopoverToggled, addNodePopoverToggled,
selectionModeChanged, selectionModeChanged,
nodeEmbedWorkflowChanged,
nodeIsIntermediateChanged,
mouseOverNodeChanged,
} = nodesSlice.actions; } = nodesSlice.actions;
export default nodesSlice.reducer; export default nodesSlice.reducer;
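The new reducers are plain payload actions, so any consumer can flip the per-node flags directly; both node reducers no-op when the id does not belong to an invocation node. A short sketch of dispatching them against the store (the node id is a placeholder):

```ts
import { store } from 'app/store/store';
import {
  mouseOverNodeChanged,
  nodeEmbedWorkflowChanged,
  nodeIsIntermediateChanged,
} from 'features/nodes/store/nodesSlice';

// Placeholder id for illustration only.
const nodeId = 'example-node-id';

store.dispatch(nodeEmbedWorkflowChanged({ nodeId, embedWorkflow: true }));
store.dispatch(nodeIsIntermediateChanged({ nodeId, isIntermediate: false }));
store.dispatch(mouseOverNodeChanged(null)); // clear any hover highlight
```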


@ -35,6 +35,7 @@ export type NodesState = {
viewport: Viewport; viewport: Viewport;
isReady: boolean; isReady: boolean;
mouseOverField: FieldIdentifier | null; mouseOverField: FieldIdentifier | null;
mouseOverNode: string | null;
nodesToCopy: Node<NodeData>[]; nodesToCopy: Node<NodeData>[];
edgesToCopy: Edge<InvocationEdgeExtra>[]; edgesToCopy: Edge<InvocationEdgeExtra>[];
isAddNodePopoverOpen: boolean; isAddNodePopoverOpen: boolean;


@ -62,7 +62,7 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
DenoiseMaskField: { DenoiseMaskField: {
title: 'Denoise Mask', title: 'Denoise Mask',
description: 'Denoise Mask may be passed between nodes', description: 'Denoise Mask may be passed between nodes',
color: 'red.700', color: 'base.500',
}, },
LatentsField: { LatentsField: {
title: 'Latents', title: 'Latents',
@ -174,11 +174,6 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Color Collection', title: 'Color Collection',
description: 'A collection of colors.', description: 'A collection of colors.',
}, },
FilePath: {
color: 'base.500',
title: 'File Path',
description: 'A path to a file.',
},
ONNXModelField: { ONNXModelField: {
color: 'base.500', color: 'base.500',
title: 'ONNX Model', title: 'ONNX Model',


@ -1,12 +1,16 @@
import { store } from 'app/store/store';
import { import {
SchedulerParam, SchedulerParam,
zBaseModel, zBaseModel,
zMainOrOnnxModel, zMainOrOnnxModel,
zSDXLRefinerModel,
zScheduler, zScheduler,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
import { keyBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types'; import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful'; import { RgbaColor } from 'react-colorful';
import { Node } from 'reactflow'; import { Node } from 'reactflow';
import { JsonObject } from 'type-fest';
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types'; import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
import { import {
AnyInvocationType, AnyInvocationType,
@ -98,7 +102,6 @@ export const zFieldType = z.enum([
// endregion // endregion
// region Misc // region Misc
'FilePath',
'enum', 'enum',
'Scheduler', 'Scheduler',
// endregion // endregion
@ -106,8 +109,17 @@ export const zFieldType = z.enum([
export type FieldType = z.infer<typeof zFieldType>; export type FieldType = z.infer<typeof zFieldType>;
export const zReservedFieldType = z.enum([
'WorkflowField',
'IsIntermediate',
'MetadataField',
]);
export type ReservedFieldType = z.infer<typeof zReservedFieldType>;
export const isFieldType = (value: unknown): value is FieldType => export const isFieldType = (value: unknown): value is FieldType =>
zFieldType.safeParse(value).success; zFieldType.safeParse(value).success ||
zReservedFieldType.safeParse(value).success;
/** /**
* An input field template is generated on each page load from the OpenAPI schema. * An input field template is generated on each page load from the OpenAPI schema.
@ -215,7 +227,7 @@ export type DenoiseMaskFieldValue = z.infer<typeof zDenoiseMaskField>;
export const zIntegerInputFieldValue = zInputFieldValueBase.extend({ export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('integer'), type: z.literal('integer'),
value: z.number().optional(), value: z.number().int().optional(),
}); });
export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>; export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>;
@ -641,6 +653,11 @@ export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
type: 'Scheduler'; type: 'Scheduler';
}; };
export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'WorkflowField';
};
export const isInputFieldValue = ( export const isInputFieldValue = (
field?: InputFieldValue | OutputFieldValue field?: InputFieldValue | OutputFieldValue
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input'); ): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
@ -661,6 +678,7 @@ export type TypeHints = {
export type InvocationSchemaExtra = { export type InvocationSchemaExtra = {
output: OpenAPIV3.ReferenceObject; // the output of the invocation output: OpenAPIV3.ReferenceObject; // the output of the invocation
title: string; title: string;
category?: string;
tags?: string[]; tags?: string[];
properties: Omit< properties: Omit<
NonNullable<OpenAPIV3.SchemaObject['properties']> & NonNullable<OpenAPIV3.SchemaObject['properties']> &
@ -737,6 +755,48 @@ export const isInvocationFieldSchema = (
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' }; export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
export const zCoreMetadata = z
.object({
app_version: z.string().nullish(),
generation_mode: z.string().nullish(),
created_by: z.string().nullish(),
positive_prompt: z.string().nullish(),
negative_prompt: z.string().nullish(),
width: z.number().int().nullish(),
height: z.number().int().nullish(),
seed: z.number().int().nullish(),
rand_device: z.string().nullish(),
cfg_scale: z.number().nullish(),
steps: z.number().int().nullish(),
scheduler: z.string().nullish(),
clip_skip: z.number().int().nullish(),
model: zMainOrOnnxModel.nullish(),
controlnets: z.array(zControlField).nullish(),
loras: z
.array(
z.object({
lora: zLoRAModelField,
weight: z.number(),
})
)
.nullish(),
vae: zVaeModelField.nullish(),
strength: z.number().nullish(),
init_image: z.string().nullish(),
positive_style_prompt: z.string().nullish(),
negative_style_prompt: z.string().nullish(),
refiner_model: zSDXLRefinerModel.nullish(),
refiner_cfg_scale: z.number().nullish(),
refiner_steps: z.number().int().nullish(),
refiner_scheduler: z.string().nullish(),
refiner_positive_aesthetic_store: z.number().nullish(),
refiner_negative_aesthetic_store: z.number().nullish(),
refiner_start: z.number().nullish(),
})
.catchall(z.record(z.any()));
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
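Every field in `zCoreMetadata` is `nullish` and the `catchall` keeps unrecognized keys, so partial or extended metadata still validates. A quick sketch of parsing a raw metadata payload (the sample values are made up):

```ts
import { zCoreMetadata } from 'features/nodes/types/types';

// Made-up payload: missing keys are fine, unknown keys pass the catchall.
const result = zCoreMetadata.safeParse({
  positive_prompt: 'a watercolor fox',
  steps: 30,
  cfg_scale: 7.5,
  custom_extension: { anything: 'goes' },
});

if (result.success) {
  console.log(result.data.steps); // 30
}
```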
export const zInvocationNodeData = z.object({ export const zInvocationNodeData = z.object({
id: z.string().trim().min(1), id: z.string().trim().min(1),
// no easy way to build this dynamically, and we don't want to anyways, because this will be used // no easy way to build this dynamically, and we don't want to anyways, because this will be used
@ -747,6 +807,8 @@ export const zInvocationNodeData = z.object({
label: z.string(), label: z.string(),
isOpen: z.boolean(), isOpen: z.boolean(),
notes: z.string(), notes: z.string(),
embedWorkflow: z.boolean(),
isIntermediate: z.boolean(),
}); });
// Massage this to get better type safety while developing // Massage this to get better type safety while developing
@ -767,28 +829,38 @@ export const zNotesNodeData = z.object({
export type NotesNodeData = z.infer<typeof zNotesNodeData>; export type NotesNodeData = z.infer<typeof zNotesNodeData>;
const zPosition = z
.object({
x: z.number(),
y: z.number(),
})
.default({ x: 0, y: 0 });
const zDimension = z.number().gt(0).nullish();
export const zWorkflowInvocationNode = z.object({ export const zWorkflowInvocationNode = z.object({
id: z.string().trim().min(1), id: z.string().trim().min(1),
type: z.literal('invocation'), type: z.literal('invocation'),
data: zInvocationNodeData, data: zInvocationNodeData,
width: z.number().gt(0), width: zDimension,
height: z.number().gt(0), height: zDimension,
position: z.object({ position: zPosition,
x: z.number(),
y: z.number(),
}),
}); });
export type WorkflowInvocationNode = z.infer<typeof zWorkflowInvocationNode>;
export const isWorkflowInvocationNode = (
val: unknown
): val is WorkflowInvocationNode =>
zWorkflowInvocationNode.safeParse(val).success;
export const zWorkflowNotesNode = z.object({ export const zWorkflowNotesNode = z.object({
id: z.string().trim().min(1), id: z.string().trim().min(1),
type: z.literal('notes'), type: z.literal('notes'),
data: zNotesNodeData, data: zNotesNodeData,
width: z.number().gt(0), width: zDimension,
height: z.number().gt(0), height: zDimension,
position: z.object({ position: zPosition,
x: z.number(),
y: z.number(),
}),
}); });
export const zWorkflowNode = z.discriminatedUnion('type', [ export const zWorkflowNode = z.discriminatedUnion('type', [
@ -798,14 +870,25 @@ export const zWorkflowNode = z.discriminatedUnion('type', [
export type WorkflowNode = z.infer<typeof zWorkflowNode>; export type WorkflowNode = z.infer<typeof zWorkflowNode>;
export const zWorkflowEdge = z.object({ export const zDefaultWorkflowEdge = z.object({
source: z.string().trim().min(1), source: z.string().trim().min(1),
sourceHandle: z.string().trim().min(1), sourceHandle: z.string().trim().min(1),
target: z.string().trim().min(1), target: z.string().trim().min(1),
targetHandle: z.string().trim().min(1), targetHandle: z.string().trim().min(1),
id: z.string().trim().min(1), id: z.string().trim().min(1),
type: z.enum(['default', 'collapsed']), type: z.literal('default'),
}); });
export const zCollapsedWorkflowEdge = z.object({
source: z.string().trim().min(1),
target: z.string().trim().min(1),
id: z.string().trim().min(1),
type: z.literal('collapsed'),
});
export const zWorkflowEdge = z.union([
zDefaultWorkflowEdge,
zCollapsedWorkflowEdge,
]);
export const zFieldIdentifier = z.object({ export const zFieldIdentifier = z.object({
nodeId: z.string().trim().min(1), nodeId: z.string().trim().min(1),
@ -828,21 +911,92 @@ export const zSemVer = z.string().refine((val) => {
export type SemVer = z.infer<typeof zSemVer>; export type SemVer = z.infer<typeof zSemVer>;
export type WorkflowWarning = {
message: string;
issues: string[];
data: JsonObject;
};
export const zWorkflow = z.object({ export const zWorkflow = z.object({
name: z.string(), name: z.string().default(''),
author: z.string(), author: z.string().default(''),
description: z.string(), description: z.string().default(''),
version: z.string(), version: z.string().default(''),
contact: z.string(), contact: z.string().default(''),
tags: z.string(), tags: z.string().default(''),
notes: z.string(), notes: z.string().default(''),
nodes: z.array(zWorkflowNode), nodes: z.array(zWorkflowNode).default([]),
edges: z.array(zWorkflowEdge), edges: z.array(zWorkflowEdge).default([]),
exposedFields: z.array(zFieldIdentifier), exposedFields: z.array(zFieldIdentifier).default([]),
meta: z
.object({
version: zSemVer,
})
.default({ version: '1.0.0' }),
});
export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
const nodeTemplates = store.getState().nodes.nodeTemplates;
const { nodes, edges } = workflow;
const warnings: WorkflowWarning[] = [];
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
const keyedNodes = keyBy(invocationNodes, 'id');
invocationNodes.forEach((node, i) => {
const nodeTemplate = nodeTemplates[node.data.type];
if (!nodeTemplate) {
warnings.push({
message: `Node "${node.data.label || node.data.id}" skipped`,
issues: [`Unable to find template for type "${node.data.type}"`],
data: node,
});
delete nodes[i];
}
});
edges.forEach((edge, i) => {
const sourceNode = keyedNodes[edge.source];
const targetNode = keyedNodes[edge.target];
const issues: string[] = [];
if (!sourceNode) {
issues.push(`Output node ${edge.source} does not exist`);
} else if (
edge.type === 'default' &&
!(edge.sourceHandle in sourceNode.data.outputs)
) {
issues.push(
`Output field "${edge.source}.${edge.sourceHandle}" does not exist`
);
}
if (!targetNode) {
issues.push(`Input node ${edge.target} does not exist`);
} else if (
edge.type === 'default' &&
!(edge.targetHandle in targetNode.data.inputs)
) {
issues.push(
`Input field "${edge.target}.${edge.targetHandle}" does not exist`
);
}
if (issues.length) {
delete edges[i];
const src = edge.type === 'default' ? edge.sourceHandle : edge.source;
const tgt = edge.type === 'default' ? edge.targetHandle : edge.target;
warnings.push({
message: `Edge "${src} -> ${tgt}" skipped`,
issues,
data: edge,
});
}
});
return { workflow, warnings };
}); });
export type Workflow = z.infer<typeof zWorkflow>; export type Workflow = z.infer<typeof zWorkflow>;
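Because every `zWorkflow` field now has a default, a bare object parses to an empty workflow, and `zValidatedWorkflow` layers template and edge checks on top, returning the workflow plus non-fatal warnings instead of rejecting outright. A hedged usage sketch (the JSON string is a placeholder, and the transform reads node templates from the store, so it only works inside the running app):

```ts
import { zValidatedWorkflow, zWorkflow } from 'features/nodes/types/types';

// Defaults make an empty object a valid, empty workflow.
const empty = zWorkflow.parse({}); // { name: '', nodes: [], edges: [], ... }

// Placeholder payload standing in for the contents of a .json file.
const rawJSON = '{"name": "demo", "nodes": [], "edges": []}';
const result = zValidatedWorkflow.safeParse(JSON.parse(rawJSON));

if (result.success) {
  const { workflow, warnings } = result.data;
  warnings.forEach(({ message }) => console.warn(message));
  // e.g. dispatch(workflowLoaded(workflow));
}
```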
export type ImageMetadataAndWorkflow = {
metadata?: CoreMetadata;
workflow?: Workflow;
};
export type CurrentImageNodeData = { export type CurrentImageNodeData = {
id: string; id: string;
type: 'current_image'; type: 'current_image';


@ -1,7 +1,8 @@
import { createSelector } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger';
import { stateSelector } from 'app/store/store';
import { NodesState } from '../store/types'; import { NodesState } from '../store/types';
import { Workflow, zWorkflowEdge, zWorkflowNode } from '../types/types'; import { Workflow, zWorkflowEdge, zWorkflowNode } from '../types/types';
import { fromZodError } from 'zod-validation-error';
import { parseify } from 'common/util/serialize';
export const buildWorkflow = (nodesState: NodesState): Workflow => { export const buildWorkflow = (nodesState: NodesState): Workflow => {
const { workflow: workflowMeta, nodes, edges } = nodesState; const { workflow: workflowMeta, nodes, edges } = nodesState;
@ -14,6 +15,10 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
nodes.forEach((node) => { nodes.forEach((node) => {
const result = zWorkflowNode.safeParse(node); const result = zWorkflowNode.safeParse(node);
if (!result.success) { if (!result.success) {
const { message } = fromZodError(result.error, {
prefix: 'Unable to parse node',
});
logger('nodes').warn({ node: parseify(node) }, message);
return; return;
} }
workflow.nodes.push(result.data); workflow.nodes.push(result.data);
@ -22,6 +27,10 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
edges.forEach((edge) => { edges.forEach((edge) => {
const result = zWorkflowEdge.safeParse(edge); const result = zWorkflowEdge.safeParse(edge);
if (!result.success) { if (!result.success) {
const { message } = fromZodError(result.error, {
prefix: 'Unable to parse edge',
});
logger('nodes').warn({ edge: parseify(edge) }, message);
return; return;
} }
workflow.edges.push(result.data); workflow.edges.push(result.data);
@ -29,7 +38,3 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
return workflow; return workflow;
}; };
export const workflowSelector = createSelector(stateSelector, ({ nodes }) =>
buildWorkflow(nodes)
);
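With the memoized `workflowSelector` removed, callers build the workflow on demand from nodes state, and invalid nodes or edges are now logged with a readable zod error rather than dropped silently. A sketch of an on-demand hook in that spirit (this is an assumed shape, not necessarily the PR's `useWorkflow` implementation, and the import path for `buildWorkflow` is a guess):

```ts
import { useAppSelector } from 'app/store/storeHooks';
import { buildWorkflow } from 'features/nodes/util/buildWorkflow';
import { useMemo } from 'react';

// Assumed shape: rebuild the workflow whenever nodes state changes.
export const useWorkflowSketch = () => {
  const nodesState = useAppSelector((state) => state.nodes);
  return useMemo(() => buildWorkflow(nodesState), [nodesState]);
};
```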


@ -28,7 +28,6 @@ import {
UNetInputFieldTemplate, UNetInputFieldTemplate,
VaeInputFieldTemplate, VaeInputFieldTemplate,
VaeModelInputFieldTemplate, VaeModelInputFieldTemplate,
isFieldType,
} from '../types/types'; } from '../types/types';
export type BaseFieldProperties = 'name' | 'title' | 'description'; export type BaseFieldProperties = 'name' | 'title' | 'description';
@ -422,9 +421,7 @@ const buildSchedulerInputFieldTemplate = ({
return template; return template;
}; };
export const getFieldType = ( export const getFieldType = (schemaObject: InvocationFieldSchema): string => {
schemaObject: InvocationFieldSchema
): FieldType => {
let fieldType = ''; let fieldType = '';
const { ui_type } = schemaObject; const { ui_type } = schemaObject;
@ -460,10 +457,6 @@ export const getFieldType = (
} }
} }
if (!isFieldType(fieldType)) {
throw `Field type "${fieldType}" is unknown!`;
}
return fieldType; return fieldType;
}; };
@ -475,12 +468,9 @@ export const getFieldType = (
export const buildInputFieldTemplate = ( export const buildInputFieldTemplate = (
nodeSchema: InvocationSchemaObject, nodeSchema: InvocationSchemaObject,
fieldSchema: InvocationFieldSchema, fieldSchema: InvocationFieldSchema,
name: string name: string,
fieldType: FieldType
) => { ) => {
// console.log('input', schemaObject);
const fieldType = getFieldType(fieldSchema);
// console.log('input fieldType', fieldType);
const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema; const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;
const extra = { const extra = {


@ -0,0 +1,33 @@
import * as png from '@stevebel/png';
import {
ImageMetadataAndWorkflow,
zCoreMetadata,
zWorkflow,
} from 'features/nodes/types/types';
import { get } from 'lodash-es';
export const getMetadataAndWorkflowFromImageBlob = async (
image: Blob
): Promise<ImageMetadataAndWorkflow> => {
const data: ImageMetadataAndWorkflow = {};
const buffer = await image.arrayBuffer();
const text = png.decode(buffer).text;
const rawMetadata = get(text, 'invokeai_metadata');
if (rawMetadata) {
const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
if (metadataResult.success) {
data.metadata = metadataResult.data;
}
}
const rawWorkflow = get(text, 'invokeai_workflow');
if (rawWorkflow) {
const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
if (workflowResult.success) {
data.workflow = workflowResult.data;
}
}
return data;
};
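The helper decodes the PNG's text chunks and validates the `invokeai_metadata` and `invokeai_workflow` entries, which is what lets uploads keep their embedded workflow. A hedged usage sketch (the image URL and the import path are assumptions):

```ts
import { getMetadataAndWorkflowFromImageBlob } from 'features/nodes/util/getMetadataAndWorkflowFromImageBlob';

const loadFromImage = async (url: string) => {
  const blob = await fetch(url).then((res) => res.blob());
  const { metadata, workflow } = await getMetadataAndWorkflowFromImageBlob(blob);
  if (workflow) {
    // e.g. dispatch(workflowLoaded(workflow));
  }
  return { metadata, workflow };
};

// Placeholder URL for illustration only.
void loadFromImage('/api/v1/images/i/example.png');
```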


@ -11,10 +11,10 @@ import {
METADATA_ACCUMULATOR, METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_CANVAS_INPAINT_GRAPH, SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_INPAINT_CREATE_MASK,
SEAMLESS, SEAMLESS,
} from './constants'; } from './constants';
@ -41,7 +41,9 @@ export const addSDXLLoRAsToGraph = (
// Handle Seamless Plugs // Handle Seamless Plugs
const unetLoaderId = modelLoaderNodeId; const unetLoaderId = modelLoaderNodeId;
let clipLoaderId = modelLoaderNodeId; let clipLoaderId = modelLoaderNodeId;
if ([SEAMLESS, REFINER_SEAMLESS].includes(modelLoaderNodeId)) { if (
[SEAMLESS, SDXL_REFINER_INPAINT_CREATE_MASK].includes(modelLoaderNodeId)
) {
clipLoaderId = SDXL_MODEL_LOADER; clipLoaderId = SDXL_MODEL_LOADER;
} }


@ -1,24 +1,28 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { import {
CreateDenoiseMaskInvocation,
ImageDTO,
MetadataAccumulatorInvocation, MetadataAccumulatorInvocation,
SeamlessModeInvocation, SeamlessModeInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { NonNullableGraph } from '../../types/types'; import { NonNullableGraph } from '../../types/types';
import { import {
CANVAS_OUTPUT, CANVAS_OUTPUT,
INPAINT_IMAGE_RESIZE_UP,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MASK_BLUR, MASK_BLUR,
METADATA_ACCUMULATOR, METADATA_ACCUMULATOR,
REFINER_SEAMLESS,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH, SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_CANVAS_INPAINT_GRAPH, SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH, SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_DENOISE_LATENTS, SDXL_REFINER_DENOISE_LATENTS,
SDXL_REFINER_INPAINT_CREATE_MASK,
SDXL_REFINER_MODEL_LOADER, SDXL_REFINER_MODEL_LOADER,
SDXL_REFINER_NEGATIVE_CONDITIONING, SDXL_REFINER_NEGATIVE_CONDITIONING,
SDXL_REFINER_POSITIVE_CONDITIONING, SDXL_REFINER_POSITIVE_CONDITIONING,
SDXL_REFINER_SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -26,7 +30,8 @@ export const addSDXLRefinerToGraph = (
state: RootState, state: RootState,
graph: NonNullableGraph, graph: NonNullableGraph,
baseNodeId: string, baseNodeId: string,
modelLoaderNodeId?: string modelLoaderNodeId?: string,
canvasInitImage?: ImageDTO
): void => { ): void => {
const { const {
refinerModel, refinerModel,
@ -38,7 +43,12 @@ export const addSDXLRefinerToGraph = (
refinerStart, refinerStart,
} = state.sdxl; } = state.sdxl;
const { seamlessXAxis, seamlessYAxis } = state.generation; const { seamlessXAxis, seamlessYAxis, vaePrecision } = state.generation;
const { boundingBoxScaleMethod } = state.canvas;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
if (!refinerModel) { if (!refinerModel) {
return; return;
@ -108,8 +118,8 @@ export const addSDXLRefinerToGraph = (
// Add Seamless To Refiner // Add Seamless To Refiner
if (seamlessXAxis || seamlessYAxis) { if (seamlessXAxis || seamlessYAxis) {
graph.nodes[REFINER_SEAMLESS] = { graph.nodes[SDXL_REFINER_SEAMLESS] = {
id: REFINER_SEAMLESS, id: SDXL_REFINER_SEAMLESS,
type: 'seamless', type: 'seamless',
seamless_x: seamlessXAxis, seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis, seamless_y: seamlessYAxis,
@ -122,13 +132,23 @@ export const addSDXLRefinerToGraph = (
field: 'unet', field: 'unet',
}, },
destination: { destination: {
node_id: REFINER_SEAMLESS, node_id: SDXL_REFINER_SEAMLESS,
field: 'unet', field: 'unet',
}, },
}, },
{ {
source: { source: {
node_id: REFINER_SEAMLESS, node_id: SDXL_REFINER_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: SDXL_REFINER_SEAMLESS,
field: 'vae',
},
},
{
source: {
node_id: SDXL_REFINER_SEAMLESS,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -203,6 +223,61 @@ export const addSDXLRefinerToGraph = (
} }
); );
if (
graph.id === SDXL_CANVAS_INPAINT_GRAPH ||
graph.id === SDXL_CANVAS_OUTPAINT_GRAPH
) {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
type: 'create_denoise_mask',
id: SDXL_REFINER_INPAINT_CREATE_MASK,
is_intermediate: true,
fp32: vaePrecision === 'fp32' ? true : false,
};
if (isUsingScaledDimensions) {
graph.edges.push({
source: {
node_id: INPAINT_IMAGE_RESIZE_UP,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'image',
},
});
} else {
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
...(graph.nodes[
SDXL_REFINER_INPAINT_CREATE_MASK
] as CreateDenoiseMaskInvocation),
image: canvasInitImage,
};
}
graph.edges.push(
{
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'mask',
},
},
{
source: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'denoise_mask',
},
destination: {
node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'denoise_mask',
},
}
);
}
if ( if (
graph.id === SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH || graph.id === SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH ||
graph.id === SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH graph.id === SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH
@ -213,7 +288,7 @@ export const addSDXLRefinerToGraph = (
field: 'latents', field: 'latents',
}, },
destination: { destination: {
node_id: CANVAS_OUTPUT, node_id: isUsingScaledDimensions ? LATENTS_TO_IMAGE : CANVAS_OUTPUT,
field: 'latents', field: 'latents',
}, },
}); });
@ -229,20 +304,4 @@ export const addSDXLRefinerToGraph = (
}, },
}); });
} }
if (
graph.id === SDXL_CANVAS_INPAINT_GRAPH ||
graph.id === SDXL_CANVAS_OUTPAINT_GRAPH
) {
graph.edges.push({
source: {
node_id: MASK_BLUR,
field: 'image',
},
destination: {
node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'mask',
},
});
}
}; };
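The refiner inpaint path follows the same splice pattern used throughout these graph builders: register a node under its constant id, then push edges that connect it. A reduced sketch of that pattern in isolation (the graph value is a placeholder and node fields beyond `type`/`id`/`is_intermediate` are elided; imports are shown relative to the graph-builder directory):

```ts
import { NonNullableGraph } from 'features/nodes/types/types';
import {
  MASK_BLUR,
  SDXL_REFINER_DENOISE_LATENTS,
  SDXL_REFINER_INPAINT_CREATE_MASK,
} from './constants';

// Placeholder graph for illustration.
const graph: NonNullableGraph = { id: 'example-graph', nodes: {}, edges: [] };

// 1. Register the mask-creation node under its id.
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
  type: 'create_denoise_mask',
  id: SDXL_REFINER_INPAINT_CREATE_MASK,
  is_intermediate: true,
};

// 2. Wire it in: blurred mask in, denoise mask out to the refiner pass.
graph.edges.push(
  {
    source: { node_id: MASK_BLUR, field: 'image' },
    destination: { node_id: SDXL_REFINER_INPAINT_CREATE_MASK, field: 'mask' },
  },
  {
    source: { node_id: SDXL_REFINER_INPAINT_CREATE_MASK, field: 'denoise_mask' },
    destination: { node_id: SDXL_REFINER_DENOISE_LATENTS, field: 'denoise_mask' },
  }
);
```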


@ -20,6 +20,7 @@ import {
SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH, SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
SDXL_IMAGE_TO_IMAGE_GRAPH, SDXL_IMAGE_TO_IMAGE_GRAPH,
SDXL_REFINER_INPAINT_CREATE_MASK,
SDXL_TEXT_TO_IMAGE_GRAPH, SDXL_TEXT_TO_IMAGE_GRAPH,
TEXT_TO_IMAGE_GRAPH, TEXT_TO_IMAGE_GRAPH,
VAE_LOADER, VAE_LOADER,
@ -32,6 +33,7 @@ export const addVAEToGraph = (
): void => { ): void => {
const { vae } = state.generation; const { vae } = state.generation;
const { boundingBoxScaleMethod } = state.canvas; const { boundingBoxScaleMethod } = state.canvas;
const { shouldUseSDXLRefiner } = state.sdxl;
const isUsingScaledDimensions = ['auto', 'manual'].includes( const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod boundingBoxScaleMethod
@ -146,6 +148,24 @@ export const addVAEToGraph = (
); );
} }
if (shouldUseSDXLRefiner) {
if (
graph.id === SDXL_CANVAS_INPAINT_GRAPH ||
graph.id === SDXL_CANVAS_OUTPAINT_GRAPH
) {
graph.edges.push({
source: {
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
},
destination: {
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
field: 'vae',
},
});
}
}
if (vae && metadataAccumulator) { if (vae && metadataAccumulator) {
metadataAccumulator.vae = vae; metadataAccumulator.vae = vae;
} }


@ -20,10 +20,10 @@ import {
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH, SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SEAMLESS, SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -367,8 +367,15 @@ export const buildCanvasSDXLImageToImageGraph = (
// Add Refiner if enabled // Add Refiner if enabled
if (shouldUseSDXLRefiner) { if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); addSDXLRefinerToGraph(
modelLoaderNodeId = REFINER_SEAMLESS; state,
graph,
SDXL_DENOISE_LATENTS,
modelLoaderNodeId
);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
} }
// optionally add custom VAE // optionally add custom VAE


@ -36,10 +36,10 @@ import {
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
RANGE_OF_SIZE, RANGE_OF_SIZE,
REFINER_SEAMLESS,
SDXL_CANVAS_INPAINT_GRAPH, SDXL_CANVAS_INPAINT_GRAPH,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SEAMLESS, SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -628,9 +628,12 @@ export const buildCanvasSDXLInpaintGraph = (
state, state,
graph, graph,
CANVAS_COHERENCE_DENOISE_LATENTS, CANVAS_COHERENCE_DENOISE_LATENTS,
modelLoaderNodeId modelLoaderNodeId,
canvasInitImage
); );
modelLoaderNodeId = REFINER_SEAMLESS; if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
} }
// optionally add custom VAE // optionally add custom VAE


@ -41,10 +41,10 @@ import {
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
RANGE_OF_SIZE, RANGE_OF_SIZE,
REFINER_SEAMLESS,
SDXL_CANVAS_OUTPAINT_GRAPH, SDXL_CANVAS_OUTPAINT_GRAPH,
SDXL_DENOISE_LATENTS, SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER, SDXL_MODEL_LOADER,
SDXL_REFINER_SEAMLESS,
SEAMLESS, SEAMLESS,
} from './constants'; } from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt'; import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -766,9 +766,12 @@ export const buildCanvasSDXLOutpaintGraph = (
state, state,
graph, graph,
CANVAS_COHERENCE_DENOISE_LATENTS, CANVAS_COHERENCE_DENOISE_LATENTS,
modelLoaderNodeId modelLoaderNodeId,
canvasInitImage
); );
modelLoaderNodeId = REFINER_SEAMLESS; if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
} }
// optionally add custom VAE // optionally add custom VAE


@@ -22,10 +22,10 @@ import {
   NOISE,
   ONNX_MODEL_LOADER,
   POSITIVE_CONDITIONING,
-  REFINER_SEAMLESS,
   SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
   SDXL_DENOISE_LATENTS,
   SDXL_MODEL_LOADER,
+  SDXL_REFINER_SEAMLESS,
   SEAMLESS,
 } from './constants';
 import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@@ -347,8 +347,15 @@ export const buildCanvasSDXLTextToImageGraph = (
   // Add Refiner if enabled
   if (shouldUseSDXLRefiner) {
-    addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
-    modelLoaderNodeId = REFINER_SEAMLESS;
+    addSDXLRefinerToGraph(
+      state,
+      graph,
+      SDXL_DENOISE_LATENTS,
+      modelLoaderNodeId
+    );
+    if (seamlessXAxis || seamlessYAxis) {
+      modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
+    }
   }
   // add LoRA support

View File

@@ -21,11 +21,11 @@ import {
   NEGATIVE_CONDITIONING,
   NOISE,
   POSITIVE_CONDITIONING,
-  REFINER_SEAMLESS,
   RESIZE,
   SDXL_DENOISE_LATENTS,
   SDXL_IMAGE_TO_IMAGE_GRAPH,
   SDXL_MODEL_LOADER,
+  SDXL_REFINER_SEAMLESS,
   SEAMLESS,
 } from './constants';
 import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@@ -368,7 +368,9 @@ export const buildLinearSDXLImageToImageGraph = (
   // Add Refiner if enabled
   if (shouldUseSDXLRefiner) {
     addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
-    modelLoaderNodeId = REFINER_SEAMLESS;
+    if (seamlessXAxis || seamlessYAxis) {
+      modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
+    }
   }
   // optionally add custom VAE

View File

@@ -16,9 +16,9 @@ import {
   NEGATIVE_CONDITIONING,
   NOISE,
   POSITIVE_CONDITIONING,
-  REFINER_SEAMLESS,
   SDXL_DENOISE_LATENTS,
   SDXL_MODEL_LOADER,
+  SDXL_REFINER_SEAMLESS,
   SDXL_TEXT_TO_IMAGE_GRAPH,
   SEAMLESS,
 } from './constants';
@@ -261,7 +261,9 @@ export const buildLinearSDXLTextToImageGraph = (
   // Add Refiner if enabled
   if (shouldUseSDXLRefiner) {
     addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
-    modelLoaderNodeId = REFINER_SEAMLESS;
+    if (seamlessXAxis || seamlessYAxis) {
+      modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
+    }
   }
   // optionally add custom VAE

View File

@@ -4,6 +4,7 @@ import { cloneDeep, omit, reduce } from 'lodash-es';
 import { Graph } from 'services/api/types';
 import { AnyInvocation } from 'services/events/types';
 import { v4 as uuidv4 } from 'uuid';
+import { buildWorkflow } from '../buildWorkflow';

 /**
  * We need to do special handling for some fields
@@ -34,12 +35,13 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
   const { nodes, edges } = nodesState;

   const filteredNodes = nodes.filter(isInvocationNode);
+  const workflowJSON = JSON.stringify(buildWorkflow(nodesState));

   // Reduce the node editor nodes into invocation graph nodes
   const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>(
     (nodesAccumulator, node) => {
       const { id, data } = node;
-      const { type, inputs } = data;
+      const { type, inputs, isIntermediate, embedWorkflow } = data;

       // Transform each node's inputs to simple key-value pairs
       const transformedInputs = reduce(
@@ -58,8 +60,14 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
         type,
         id,
         ...transformedInputs,
+        is_intermediate: isIntermediate,
       };

+      if (embedWorkflow) {
+        // add the workflow to the node
+        Object.assign(graphNode, { workflow: workflowJSON });
+      }
+
       // Add it to the nodes object
       Object.assign(nodesAccumulator, {
         [id]: graphNode,
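The substantive change in this builder is that the editor state is serialized to a workflow exactly once per graph build, and the resulting JSON string is attached only to nodes whose `Embed Workflow` checkbox (`embedWorkflow`) is set. A reduced sketch of that reduce step, using loose stand-in types and a placeholder `buildWorkflow`; the real implementation flattens each input to its value with lodash `reduce` and handles several special field kinds:

```ts
// Loose stand-ins for the editor's node types (sketch only).
type EditorNode = {
  id: string;
  data: {
    type: string;
    inputs: Record<string, { value: unknown }>;
    isIntermediate: boolean;
    embedWorkflow: boolean;
  };
};

// Placeholder for the real buildWorkflow(nodesState) serializer.
const buildWorkflow = (nodes: EditorNode[]): object => ({ nodeCount: nodes.length });

const buildGraphNodes = (nodes: EditorNode[]) => {
  // Stringify the workflow once; every flagged node embeds the same payload.
  const workflowJSON = JSON.stringify(buildWorkflow(nodes));

  return nodes.reduce<Record<string, Record<string, unknown>>>((acc, node) => {
    const { id, data } = node;
    const { type, inputs, isIntermediate, embedWorkflow } = data;

    // Flatten each input field to a plain key-value pair.
    const transformedInputs: Record<string, unknown> = {};
    for (const [name, input] of Object.entries(inputs)) {
      transformedInputs[name] = input.value;
    }

    const graphNode: Record<string, unknown> = {
      type,
      id,
      ...transformedInputs,
      is_intermediate: isIntermediate,
    };

    if (embedWorkflow) {
      // Image-outputting nodes pass this through to the saved image file.
      graphNode.workflow = workflowJSON;
    }

    acc[id] = graphNode;
    return acc;
  }, {});
};
```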

View File

@@ -56,8 +56,9 @@ export const SDXL_REFINER_POSITIVE_CONDITIONING =
 export const SDXL_REFINER_NEGATIVE_CONDITIONING =
   'sdxl_refiner_negative_conditioning';
 export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
+export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
 export const SEAMLESS = 'seamless';
-export const REFINER_SEAMLESS = 'refiner_seamless';
+export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';

 // friendly graph ids
 export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';

View File

@@ -4,10 +4,12 @@ import { reduce } from 'lodash-es';
 import { OpenAPIV3 } from 'openapi-types';
 import { AnyInvocationType } from 'services/events/types';
 import {
+  FieldType,
   InputFieldTemplate,
   InvocationSchemaObject,
   InvocationTemplate,
   OutputFieldTemplate,
+  isFieldType,
   isInvocationFieldSchema,
   isInvocationOutputSchemaObject,
   isInvocationSchemaObject,
@@ -16,23 +18,35 @@ import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
 const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'metadata'];
 const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
+const RESERVED_FIELD_TYPES = [
+  'WorkflowField',
+  'MetadataField',
+  'IsIntermediate',
+];

 const invocationDenylist: AnyInvocationType[] = [
   'graph',
   'metadata_accumulator',
 ];

-const isAllowedInputField = (nodeType: string, fieldName: string) => {
+const isReservedInputField = (nodeType: string, fieldName: string) => {
   if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
-    return false;
+    return true;
   }
   if (nodeType === 'collect' && fieldName === 'collection') {
-    return false;
+    return true;
   }
   if (nodeType === 'iterate' && fieldName === 'index') {
-    return false;
+    return true;
   }
-  return true;
+  return false;
+};
+
+const isReservedFieldType = (fieldType: FieldType) => {
+  if (RESERVED_FIELD_TYPES.includes(fieldType)) {
+    return true;
+  }
+  return false;
 };

 const isAllowedOutputField = (nodeType: string, fieldName: string) => {
@@ -54,7 +68,7 @@ export const parseSchema = (
   const invocations = filteredSchemas.reduce<
     Record<string, InvocationTemplate>
-  >((acc, schema) => {
+  >((invocationsAccumulator, schema) => {
     const type = schema.properties.type.default;
     const title = schema.title.replace('Invocation', '');
     const tags = schema.tags ?? [];
@@ -62,10 +76,14 @@
     const inputs = reduce(
       schema.properties,
-      (inputsAccumulator, property, propertyName) => {
-        if (!isAllowedInputField(type, propertyName)) {
+      (
+        inputsAccumulator: Record<string, InputFieldTemplate>,
+        property,
+        propertyName
+      ) => {
+        if (isReservedInputField(type, propertyName)) {
           logger('nodes').trace(
-            { type, propertyName, property: parseify(property) },
+            { node: type, fieldName: propertyName, field: parseify(property) },
             'Skipped reserved input field'
           );
           return inputsAccumulator;
@@ -73,37 +91,80 @@
         if (!isInvocationFieldSchema(property)) {
           logger('nodes').warn(
-            { type, propertyName, property: parseify(property) },
+            { node: type, propertyName, property: parseify(property) },
             'Unhandled input property'
           );
           return inputsAccumulator;
         }

-        const field = buildInputFieldTemplate(schema, property, propertyName);
-        if (field) {
-          inputsAccumulator[propertyName] = field;
+        const fieldType = getFieldType(property);
+        if (!isFieldType(fieldType)) {
+          logger('nodes').warn(
+            {
+              node: type,
+              fieldName: propertyName,
+              fieldType,
+              field: parseify(property),
+            },
+            'Skipping unknown input field type'
+          );
+          return inputsAccumulator;
         }
+        if (isReservedFieldType(fieldType)) {
+          logger('nodes').trace(
+            {
+              node: type,
+              fieldName: propertyName,
+              fieldType,
+              field: parseify(property),
+            },
+            'Skipping reserved field type'
+          );
+          return inputsAccumulator;
+        }
+        const field = buildInputFieldTemplate(
+          schema,
+          property,
+          propertyName,
+          fieldType
+        );
+        if (!field) {
+          logger('nodes').debug(
+            {
+              node: type,
+              fieldName: propertyName,
+              fieldType,
+              field: parseify(property),
+            },
+            'Skipping input field with no template'
+          );
+          return inputsAccumulator;
+        }
+        inputsAccumulator[propertyName] = field;

         return inputsAccumulator;
       },
-      {} as Record<string, InputFieldTemplate>
+      {}
     );

     const outputSchemaName = schema.output.$ref.split('/').pop();

     if (!outputSchemaName) {
-      logger('nodes').error(
+      logger('nodes').warn(
         { outputRefObject: parseify(schema.output) },
         'No output schema name found in ref object'
       );
-      throw 'No output schema name found in ref object';
+      return invocationsAccumulator;
     }

     const outputSchema = openAPI.components?.schemas?.[outputSchemaName];
     if (!outputSchema) {
-      logger('nodes').error({ outputSchemaName }, 'Output schema not found');
-      throw 'Output schema not found';
+      logger('nodes').warn({ outputSchemaName }, 'Output schema not found');
+      return invocationsAccumulator;
     }

     if (!isInvocationOutputSchemaObject(outputSchema)) {
@@ -111,7 +172,7 @@ export const parseSchema = (
         { outputSchema: parseify(outputSchema) },
         'Invalid output schema'
       );
-      throw 'Invalid output schema';
+      return invocationsAccumulator;
     }

     const outputType = outputSchema.properties.type.default;
@@ -136,6 +197,15 @@ export const parseSchema = (
         }

         const fieldType = getFieldType(property);
+        if (!isFieldType(fieldType)) {
+          logger('nodes').warn(
+            { fieldName: propertyName, fieldType, field: parseify(property) },
+            'Skipping unknown output field type'
+          );
+          return outputsAccumulator;
+        }
+
         outputsAccumulator[propertyName] = {
           fieldKind: 'output',
           name: propertyName,
@@ -162,9 +232,9 @@ export const parseSchema = (
       outputType,
     };

-    Object.assign(acc, { [type]: invocation });
+    Object.assign(invocationsAccumulator, { [type]: invocation });

-    return acc;
+    return invocationsAccumulator;
   }, {});

   return invocations;
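The parser now favors guard clauses over a single allow-check and over thrown strings: reserved field names, unknown field types, reserved field types (`WorkflowField`, `MetadataField`, `IsIntermediate`), and fields for which no template can be built are each skipped with a log entry instead of aborting the whole schema parse. A compressed sketch of that skip-don't-throw shape, with hypothetical stand-ins for the template helpers:

```ts
const RESERVED_FIELD_TYPES = ['WorkflowField', 'MetadataField', 'IsIntermediate'];

type FieldTemplate = { name: string; type: string };

// Hypothetical stand-ins for getFieldType / isFieldType / buildInputFieldTemplate.
const getFieldType = (property: { type?: string }): string | undefined => property.type;
const isFieldType = (t: string | undefined): t is string => typeof t === 'string';
const buildInputFieldTemplate = (
  name: string,
  fieldType: string
): FieldTemplate | undefined => (fieldType ? { name, type: fieldType } : undefined);

const parseInputs = (properties: Record<string, { type?: string }>) =>
  Object.entries(properties).reduce<Record<string, FieldTemplate>>(
    (acc, [name, property]) => {
      const fieldType = getFieldType(property);
      if (!isFieldType(fieldType)) {
        return acc; // unknown field type: skip it, don't throw
      }
      if (RESERVED_FIELD_TYPES.includes(fieldType)) {
        return acc; // workflow/metadata/is_intermediate are handled elsewhere
      }
      const field = buildInputFieldTemplate(name, fieldType);
      if (!field) {
        return acc; // no template could be built: skip it
      }
      acc[name] = field;
      return acc;
    },
    {}
  );
```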

View File

@@ -68,7 +68,7 @@ const ParamControlNetCollapse = () => {
   }

   return (
-    <IAICollapse label="ControlNet" activeLabel={activeLabel}>
+    <IAICollapse label="Control Adapters" activeLabel={activeLabel}>
       <Flex sx={{ flexDir: 'column', gap: 2 }}>
         <Flex
           sx={{

View File

@@ -1,4 +1,5 @@
 import { skipToken } from '@reduxjs/toolkit/dist/query';
+import { CoreMetadata } from 'features/nodes/types/types';
 import { t } from 'i18next';
 import { useCallback, useState } from 'react';
 import { useAppToaster } from '../../../app/components/Toaster';
@@ -64,7 +65,7 @@ export const usePreselectedImage = () => {
     if (selectedImage.action === 'useAllParameters') {
       setImageNameForMetadata(selectedImage?.imageName);
       if (selectedImageMetadata) {
-        recallAllParameters(selectedImageMetadata.metadata);
+        recallAllParameters(selectedImageMetadata.metadata as CoreMetadata);
       }
     }
   },

View File

@@ -1,5 +1,6 @@
 import { useAppToaster } from 'app/components/Toaster';
 import { useAppDispatch } from 'app/store/storeHooks';
+import { CoreMetadata } from 'features/nodes/types/types';
 import {
   refinerModelChanged,
   setNegativeStylePromptSDXL,
@@ -13,7 +14,7 @@ import {
 } from 'features/sdxl/store/sdxlSlice';
 import { useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
-import { ImageDTO, UnsafeImageMetadata } from 'services/api/types';
+import { ImageDTO } from 'services/api/types';
 import { initialImageSelected, modelSelected } from '../store/actions';
 import {
   setCfgScale,
@@ -317,7 +318,7 @@ export const useRecallParameters = () => {
   );

   const recallAllParameters = useCallback(
-    (metadata: UnsafeImageMetadata['metadata'] | undefined) => {
+    (metadata: CoreMetadata | undefined) => {
       if (!metadata) {
         allParameterNotSetToast();
         return;

View File

@@ -29,11 +29,13 @@ export const $projectId = atom<string | undefined>();
  * @example
  * const { get, post, del } = $client.get();
  */
-export const $client = computed([$authToken, $baseUrl, $projectId], (authToken, baseUrl, projectId) =>
+export const $client = computed(
+  [$authToken, $baseUrl, $projectId],
+  (authToken, baseUrl, projectId) =>
     createClient<paths>({
       headers: {
         ...(authToken ? { Authorization: `Bearer ${authToken}` } : {}),
-        ...(projectId ? { "project-id": projectId } : {})
+        ...(projectId ? { 'project-id': projectId } : {}),
       },
       // do not include `api/v1` in the base url for this client
       baseUrl: `${baseUrl ?? ''}`,

View File

@@ -19,7 +19,7 @@ export const boardsApi = api.injectEndpoints({
      */
     listBoards: build.query<OffsetPaginatedResults_BoardDTO_, ListBoardsArg>({
       query: (arg) => ({ url: 'boards/', params: arg }),
-      providesTags: (result, error, arg) => {
+      providesTags: (result) => {
         // any list of boards
         const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
@@ -42,7 +42,7 @@
         url: 'boards/',
         params: { all: true },
       }),
-      providesTags: (result, error, arg) => {
+      providesTags: (result) => {
         // any list of boards
         const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];

View File

@@ -6,7 +6,8 @@ import {
   IMAGE_CATEGORIES,
   IMAGE_LIMIT,
 } from 'features/gallery/store/types';
-import { keyBy } from 'lodash';
+import { getMetadataAndWorkflowFromImageBlob } from 'features/nodes/util/getMetadataAndWorkflowFromImageBlob';
+import { keyBy } from 'lodash-es';
 import { ApiFullTagDescription, LIST_TAG, api } from '..';
 import { components, paths } from '../schema';
 import {
@@ -26,6 +27,7 @@
   imagesSelectors,
 } from '../util';
 import { boardsApi } from './boards';
+import { ImageMetadataAndWorkflow } from 'features/nodes/types/types';

 export const imagesApi = api.injectEndpoints({
   endpoints: (build) => ({
@@ -113,6 +115,20 @@ export const imagesApi = api.injectEndpoints({
       ],
       keepUnusedDataFor: 86400, // 24 hours
     }),
+    getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, string>({
+      query: (image_name) => ({
+        url: `images/i/${image_name}/full`,
+        responseHandler: async (res) => {
+          return await res.blob();
+        },
+      }),
+      providesTags: (result, error, image_name) => [
+        { type: 'ImageMetadataFromFile', id: image_name },
+      ],
+      transformResponse: (response: Blob) =>
+        getMetadataAndWorkflowFromImageBlob(response),
+      keepUnusedDataFor: 86400, // 24 hours
+    }),
     clearIntermediates: build.mutation<number, void>({
       query: () => ({ url: `images/clear-intermediates`, method: 'POST' }),
       invalidatesTags: ['IntermediatesCount'],
@@ -357,7 +373,7 @@ export const imagesApi = api.injectEndpoints({
       ],
       async onQueryStarted(
         { imageDTO, session_id },
-        { dispatch, queryFulfilled, getState }
+        { dispatch, queryFulfilled }
       ) {
         /**
          * Cache changes for `changeImageSessionId`:
@@ -432,7 +448,9 @@ export const imagesApi = api.injectEndpoints({
             data.updated_image_names.includes(i.image_name)
           );

-          if (!updatedImages[0]) return;
+          if (!updatedImages[0]) {
+            return;
+          }

           // assume all images are on the same board/category
           const categories = getCategories(updatedImages[0]);
@@ -544,7 +562,9 @@ export const imagesApi = api.injectEndpoints({
             data.updated_image_names.includes(i.image_name)
           );
-          if (!updatedImages[0]) return;
+          if (!updatedImages[0]) {
+            return;
+          }

           // assume all images are on the same board/category
           const categories = getCategories(updatedImages[0]);
           const boardId = updatedImages[0].board_id;
@@ -645,17 +665,7 @@ export const imagesApi = api.injectEndpoints({
           },
         };
       },
-      async onQueryStarted(
-        {
-          file,
-          image_category,
-          is_intermediate,
-          postUploadAction,
-          session_id,
-          board_id,
-        },
-        { dispatch, queryFulfilled }
-      ) {
+      async onQueryStarted(_, { dispatch, queryFulfilled }) {
         try {
           /**
            * NOTE: PESSIMISTIC UPDATE
@@ -712,7 +722,7 @@ export const imagesApi = api.injectEndpoints({
     deleteBoard: build.mutation<DeleteBoardResult, string>({
       query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE' }),
-      invalidatesTags: (result, error, board_id) => [
+      invalidatesTags: () => [
         { type: 'Board', id: LIST_TAG },
         // invalidate the 'No Board' cache
         {
@@ -732,7 +742,7 @@ export const imagesApi = api.injectEndpoints({
         { type: 'BoardImagesTotal', id: 'none' },
         { type: 'BoardAssetsTotal', id: 'none' },
       ],
-      async onQueryStarted(board_id, { dispatch, queryFulfilled, getState }) {
+      async onQueryStarted(board_id, { dispatch, queryFulfilled }) {
         /**
          * Cache changes for deleteBoard:
          * - Update every image in the 'getImageDTO' cache that has the board_id
@@ -802,7 +812,7 @@ export const imagesApi = api.injectEndpoints({
         method: 'DELETE',
         params: { include_images: true },
       }),
-      invalidatesTags: (result, error, board_id) => [
+      invalidatesTags: () => [
         { type: 'Board', id: LIST_TAG },
         {
           type: 'ImageList',
@@ -821,7 +831,7 @@ export const imagesApi = api.injectEndpoints({
         { type: 'BoardImagesTotal', id: 'none' },
         { type: 'BoardAssetsTotal', id: 'none' },
       ],
-      async onQueryStarted(board_id, { dispatch, queryFulfilled, getState }) {
+      async onQueryStarted(board_id, { dispatch, queryFulfilled }) {
         /**
          * Cache changes for deleteBoardAndImages:
          * - ~~Remove every image in the 'getImageDTO' cache that has the board_id~~
@@ -1253,9 +1263,8 @@ export const imagesApi = api.injectEndpoints({
         ];

         result?.removed_image_names.forEach((image_name) => {
-          const board_id = imageDTOs.find(
-            (i) => i.image_name === image_name
-          )?.board_id;
+          const board_id = imageDTOs.find((i) => i.image_name === image_name)
+            ?.board_id;

           if (!board_id || touchedBoardIds.includes(board_id)) {
             return;
@@ -1385,4 +1394,5 @@ export const {
   useDeleteBoardMutation,
   useStarImagesMutation,
   useUnstarImagesMutation,
+  useGetImageMetadataFromFileQuery,
 } = imagesApi;
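The new `getImageMetadataFromFile` endpoint above is what lets an uploaded image carry its workflow along: it fetches the full image as a `Blob` via a custom `responseHandler` and parses the embedded metadata/workflow in `transformResponse`. A minimal, self-contained RTK Query sketch of that pattern, with a hypothetical `parseImageBlob` standing in for the project's `getMetadataAndWorkflowFromImageBlob`:

```ts
import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';

type ImageMetadataAndWorkflow = { metadata?: unknown; workflow?: unknown };

// Hypothetical stand-in for the real PNG-chunk parser.
const parseImageBlob = async (blob: Blob): Promise<ImageMetadataAndWorkflow> => {
  // The real helper reads the metadata/workflow text chunks out of the PNG
  // bytes; here we just acknowledge the blob and return an empty result.
  console.debug(`parsing ${blob.size} bytes`);
  return {};
};

export const exampleApi = createApi({
  baseQuery: fetchBaseQuery({ baseUrl: '/api/v1/' }),
  tagTypes: ['ImageMetadataFromFile'],
  endpoints: (build) => ({
    getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, string>({
      query: (imageName) => ({
        url: `images/i/${imageName}/full`,
        // Return the raw bytes instead of trying to JSON-parse the body.
        responseHandler: (res) => res.blob(),
      }),
      providesTags: (_result, _error, imageName) => [
        { type: 'ImageMetadataFromFile', id: imageName },
      ],
      transformResponse: (response: Blob) => parseImageBlob(response),
      keepUnusedDataFor: 86400, // cache for 24 hours
    }),
  }),
});
```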

View File

@@ -178,7 +178,7 @@ export const modelsApi = api.injectEndpoints({
         const query = queryString.stringify(params, { arrayFormat: 'none' });
         return `models/?${query}`;
       },
-      providesTags: (result, error, arg) => {
+      providesTags: (result) => {
         const tags: ApiFullTagDescription[] = [
           { type: 'OnnxModel', id: LIST_TAG },
         ];
@@ -194,11 +194,7 @@ export const modelsApi = api.injectEndpoints({
         return tags;
       },
-      transformResponse: (
-        response: { models: OnnxModelConfig[] },
-        meta,
-        arg
-      ) => {
+      transformResponse: (response: { models: OnnxModelConfig[] }) => {
         const entities = createModelEntities<OnnxModelConfigEntity>(
           response.models
         );
@@ -221,7 +217,7 @@ export const modelsApi = api.injectEndpoints({
         const query = queryString.stringify(params, { arrayFormat: 'none' });
         return `models/?${query}`;
       },
-      providesTags: (result, error, arg) => {
+      providesTags: (result) => {
         const tags: ApiFullTagDescription[] = [
           { type: 'MainModel', id: LIST_TAG },
         ];
@@ -237,11 +233,7 @@ export const modelsApi = api.injectEndpoints({
         return tags;
       },
-      transformResponse: (
-        response: { models: MainModelConfig[] },
-        meta,
-        arg
-      ) => {
+      transformResponse: (response: { models: MainModelConfig[] }) => {
         const entities = createModelEntities<MainModelConfigEntity>(
           response.models
         );
@@ -361,7 +353,7 @@ export const modelsApi = api.injectEndpoints({
     }),
     getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
       query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
-      providesTags: (result, error, arg) => {
+      providesTags: (result) => {
         const tags: ApiFullTagDescription[] = [
           { type: 'LoRAModel', id: LIST_TAG },
         ];
@@ -377,11 +369,7 @@ export const modelsApi = api.injectEndpoints({
         return tags;
       },
-      transformResponse: (
-        response: { models: LoRAModelConfig[] },
-        meta,
-        arg
-      ) => {
+      transformResponse: (response: { models: LoRAModelConfig[] }) => {
         const entities = createModelEntities<LoRAModelConfigEntity>(
           response.models
         );
@@ -421,7 +409,7 @@ export const modelsApi = api.injectEndpoints({
       void
     >({
       query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
-      providesTags: (result, error, arg) => {
+      providesTags: (result) => {
         const tags: ApiFullTagDescription[] = [
           { type: 'ControlNetModel', id: LIST_TAG },
         ];
@@ -437,11 +425,7 @@ export const modelsApi = api.injectEndpoints({
         return tags;
       },
-      transformResponse: (
-        response: { models: ControlNetModelConfig[] },
-        meta,
-        arg
-      ) => {
+      transformResponse: (response: { models: ControlNetModelConfig[] }) => {
         const entities = createModelEntities<ControlNetModelConfigEntity>(
           response.models
         );
@@ -453,7 +437,7 @@ export const modelsApi = api.injectEndpoints({
     }),
     getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
       query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
-      providesTags: (result, error, arg) => {
+      providesTags: (result) => {
         const tags: ApiFullTagDescription[] = [
           { type: 'VaeModel', id: LIST_TAG },
         ];
@@ -469,11 +453,7 @@ export const modelsApi = api.injectEndpoints({
         return tags;
       },
-      transformResponse: (
-        response: { models: VaeModelConfig[] },
-        meta,
-        arg
-      ) => {
+      transformResponse: (response: { models: VaeModelConfig[] }) => {
         const entities = createModelEntities<VaeModelConfigEntity>(
           response.models
         );
@@ -488,7 +468,7 @@ export const modelsApi = api.injectEndpoints({
       void
     >({
       query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
-      providesTags: (result, error, arg) => {
+      providesTags: (result) => {
         const tags: ApiFullTagDescription[] = [
           { type: 'TextualInversionModel', id: LIST_TAG },
         ];
@@ -504,11 +484,9 @@ export const modelsApi = api.injectEndpoints({
         return tags;
       },
-      transformResponse: (
-        response: { models: TextualInversionModelConfig[] },
-        meta,
-        arg
-      ) => {
+      transformResponse: (response: {
+        models: TextualInversionModelConfig[];
+      }) => {
         const entities = createModelEntities<TextualInversionModelConfigEntity>(
           response.models
         );
@@ -525,7 +503,7 @@ export const modelsApi = api.injectEndpoints({
           url: `/models/search?${folderQueryStr}`,
         };
       },
-      providesTags: (result, error, arg) => {
+      providesTags: (result) => {
         const tags: ApiFullTagDescription[] = [
           { type: 'ScannedModels', id: LIST_TAG },
         ];

View File

@@ -16,6 +16,7 @@ export const tagTypes = [
   'ImageNameList',
   'ImageList',
   'ImageMetadata',
+  'ImageMetadataFromFile',
   'Model',
 ];
 export type ApiFullTagDescription = FullTagDescription<
@@ -39,7 +40,7 @@ const dynamicBaseQuery: BaseQueryFn<
       headers.set('Authorization', `Bearer ${authToken}`);
     }
     if (projectId) {
-      headers.set("project-id", projectId)
+      headers.set('project-id', projectId);
     }

     return headers;
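Both HTTP clients touched in this diff now attach the same optional headers: a bearer token when an auth token is present and a `project-id` header when a project is selected. A small sketch of that conditional header construction, independent of the store wiring (the helper name is illustrative):

```ts
// Build fetch headers from optional auth/project values (sketch only).
const buildHeaders = (authToken?: string, projectId?: string): Headers => {
  const headers = new Headers();
  if (authToken) {
    headers.set('Authorization', `Bearer ${authToken}`);
  }
  if (projectId) {
    headers.set('project-id', projectId);
  }
  return headers;
};

// Usage:
// fetch('/api/v1/images/', { headers: buildHeaders(token, 'my-project') });
```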

File diff suppressed because one or more lines are too long

View File

@@ -1,14 +1,16 @@
 import { createAsyncThunk } from '@reduxjs/toolkit';

 function getCircularReplacer() {
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
   const ancestors: Record<string, any>[] = [];
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
   return function (key: string, value: any) {
     if (typeof value !== 'object' || value === null) {
       return value;
     }
-    // `this` is the object that value is contained in,
-    // i.e., its direct parent.
-    // @ts-ignore
+    // `this` is the object that value is contained in, i.e., its direct parent.
+    // eslint-disable-next-line @typescript-eslint/ban-ts-comment
+    // @ts-ignore don't think it's possible to not have TS complain about this...
     while (ancestors.length > 0 && ancestors.at(-1) !== this) {
       ancestors.pop();
     }
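`getCircularReplacer` is the usual ancestor-stack replacer for `JSON.stringify`: it tracks the chain of parent objects via `this` and short-circuits any value that is already on that stack, so serialization cannot recurse forever. A self-contained version with a usage example; the `'[Circular]'` placeholder is illustrative, and the function in this diff differs mainly in its lint annotations:

```ts
function getCircularReplacer() {
  const ancestors: object[] = [];
  return function (this: unknown, _key: string, value: unknown) {
    if (typeof value !== 'object' || value === null) {
      return value;
    }
    // `this` is the object that value is contained in, i.e. its direct parent.
    while (ancestors.length > 0 && ancestors.at(-1) !== this) {
      ancestors.pop();
    }
    if (ancestors.includes(value)) {
      // Already on the ancestor stack: this is a circular reference.
      return '[Circular]';
    }
    ancestors.push(value);
    return value;
  };
}

// Usage: self-references are replaced instead of throwing a TypeError.
const node: { name: string; parent?: unknown } = { name: 'a' };
node.parent = node;
console.log(JSON.stringify(node, getCircularReplacer()));
// -> {"name":"a","parent":"[Circular]"}
```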

View File

@@ -73,7 +73,7 @@ export const sessionInvoked = createAsyncThunk<
 >('api/sessionInvoked', async (arg, { rejectWithValue }) => {
   const { session_id } = arg;
   const { PUT } = $client.get();
-  const { data, error, response } = await PUT(
+  const { error, response } = await PUT(
     '/api/v1/sessions/{session_id}/invoke',
     {
       params: { query: { all: true }, path: { session_id } },
@@ -85,6 +85,7 @@ export const sessionInvoked = createAsyncThunk<
     return rejectWithValue({
       arg,
       status: response.status,
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
       error: (error as any).body.detail,
     });
   }
@@ -124,14 +125,11 @@ export const sessionCanceled = createAsyncThunk<
 >('api/sessionCanceled', async (arg, { rejectWithValue }) => {
   const { session_id } = arg;
   const { DELETE } = $client.get();
-  const { data, error, response } = await DELETE(
-    '/api/v1/sessions/{session_id}/invoke',
-    {
+  const { data, error } = await DELETE('/api/v1/sessions/{session_id}/invoke', {
     params: {
       path: { session_id },
     },
-  }
-  );
+  });

   if (error) {
     return rejectWithValue({ arg, error });
@@ -164,7 +162,7 @@ export const listedSessions = createAsyncThunk<
 >('api/listSessions', async (arg, { rejectWithValue }) => {
   const { params } = arg;
   const { GET } = $client.get();
-  const { data, error, response } = await GET('/api/v1/sessions/', {
+  const { data, error } = await GET('/api/v1/sessions/', {
     params,
   });

View File

@@ -26,15 +26,21 @@ export const getIsImageInDateRange = (
   for (let index = 0; index < totalCachedImageDtos.length; index++) {
     const image = totalCachedImageDtos[index];
-    if (image?.starred) cachedStarredImages.push(image);
-    if (!image?.starred) cachedUnstarredImages.push(image);
+    if (image?.starred) {
+      cachedStarredImages.push(image);
+    }
+    if (!image?.starred) {
+      cachedUnstarredImages.push(image);
+    }
   }

   if (imageDTO.starred) {
     const lastStarredImage =
       cachedStarredImages[cachedStarredImages.length - 1];
     // if starring or already starred, want to look in list of starred images
-    if (!lastStarredImage) return true; // no starred images showing, so always show this one
+    if (!lastStarredImage) {
+      return true;
+    } // no starred images showing, so always show this one
     const createdDate = new Date(imageDTO.created_at);
     const oldestDate = new Date(lastStarredImage.created_at);
     return createdDate >= oldestDate;
@@ -42,7 +48,9 @@ export const getIsImageInDateRange = (
     const lastUnstarredImage =
       cachedUnstarredImages[cachedUnstarredImages.length - 1];
     // if unstarring or already unstarred, want to look in list of unstarred images
-    if (!lastUnstarredImage) return false; // no unstarred images showing, so don't show this one
+    if (!lastUnstarredImage) {
+      return false;
+    } // no unstarred images showing, so don't show this one
     const createdDate = new Date(imageDTO.created_at);
     const oldestDate = new Date(lastUnstarredImage.created_at);
     return createdDate >= oldestDate;

Some files were not shown because too many files have changed in this diff.