mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
nodes phase 5: workflow saving and loading (#4353)
## What type of PR is this? (check all applicable) - [ ] Refactor - [x] Feature - [ ] Bug Fix - [ ] Optimization - [ ] Documentation Update - [ ] Community Node Submission ## Description - Workflows are saved to image files directly - Image-outputting nodes have an `Embed Workflow` checkbox which, if enabled, saves the workflow - `BaseInvocation` now has an `workflow: Optional[str]` field, so all nodes automatically have the field (but again only image-outputting nodes display this in UI) - If this field is enabled, when the graph is created, the workflow is stringified and set in this field - Nodes should add `workflow=self.workflow` when they save their output image to have the workflow written to the image - Uploads now have their metadata retained so that you can upload somebody else's image and have access to that workflow - Graphs are no longer saved to images, workflows replace them ### TODO - Images created in the linear UI do not have a workflow saved yet. Need to write a function to build a workflow around the linear UI graph when using linear tabs. Unfortunately it will not have the nice positioning and size data the node editor gives you when you save a workflow... we'll have to figure out how to handle this. ## Related Tickets & Documents <!-- For pull requests that relate or close an issue, please include them below. For example having the text: "closes #1234" would connect the current pull request to issue 1234. And when we merge the pull request, Github will automatically close the issue. --> - Related Issue # - Closes # ## QA Instructions, Screenshots, Recordings <!-- Please provide steps on how to test changes, any hardware or software specifications as well as any other pertinent information. -->
This commit is contained in:
commit
2bd3cf28ea
@ -29,12 +29,13 @@ The first set of things we need to do when creating a new Invocation are -
|
|||||||
|
|
||||||
- Create a new class that derives from a predefined parent class called
|
- Create a new class that derives from a predefined parent class called
|
||||||
`BaseInvocation`.
|
`BaseInvocation`.
|
||||||
- The name of every Invocation must end with the word `Invocation` in order for
|
|
||||||
it to be recognized as an Invocation.
|
|
||||||
- Every Invocation must have a `docstring` that describes what this Invocation
|
- Every Invocation must have a `docstring` that describes what this Invocation
|
||||||
does.
|
does.
|
||||||
- Every Invocation must have a unique `type` field defined which becomes its
|
- While not strictly required, we suggest every invocation class name ends in
|
||||||
indentifier.
|
"Invocation", eg "CropImageInvocation".
|
||||||
|
- Every Invocation must use the `@invocation` decorator to provide its unique
|
||||||
|
invocation type. You may also provide its title, tags and category using the
|
||||||
|
decorator.
|
||||||
- Invocations are strictly typed. We make use of the native
|
- Invocations are strictly typed. We make use of the native
|
||||||
[typing](https://docs.python.org/3/library/typing.html) library and the
|
[typing](https://docs.python.org/3/library/typing.html) library and the
|
||||||
installed [pydantic](https://pydantic-docs.helpmanual.io/) library for
|
installed [pydantic](https://pydantic-docs.helpmanual.io/) library for
|
||||||
@ -43,12 +44,11 @@ The first set of things we need to do when creating a new Invocation are -
|
|||||||
So let us do that.
|
So let us do that.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from typing import Literal
|
from .baseinvocation import BaseInvocation, invocation
|
||||||
from .baseinvocation import BaseInvocation
|
|
||||||
|
|
||||||
|
@invocation('resize')
|
||||||
class ResizeInvocation(BaseInvocation):
|
class ResizeInvocation(BaseInvocation):
|
||||||
'''Resizes an image'''
|
'''Resizes an image'''
|
||||||
type: Literal['resize'] = 'resize'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
That's great.
|
That's great.
|
||||||
@ -62,8 +62,10 @@ our Invocation takes.
|
|||||||
|
|
||||||
### **Inputs**
|
### **Inputs**
|
||||||
|
|
||||||
Every Invocation input is a pydantic `Field` and like everything else should be
|
Every Invocation input must be defined using the `InputField` function. This is
|
||||||
strictly typed and defined.
|
a wrapper around the pydantic `Field` function, which handles a few extra things
|
||||||
|
and provides type hints. Like everything else, this should be strictly typed and
|
||||||
|
defined.
|
||||||
|
|
||||||
So let us create these inputs for our Invocation. First up, the `image` input we
|
So let us create these inputs for our Invocation. First up, the `image` input we
|
||||||
need. Generally, we can use standard variable types in Python but InvokeAI
|
need. Generally, we can use standard variable types in Python but InvokeAI
|
||||||
@ -76,55 +78,51 @@ create your own custom field types later in this guide. For now, let's go ahead
|
|||||||
and use it.
|
and use it.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from typing import Literal, Union
|
from .baseinvocation import BaseInvocation, InputField, invocation
|
||||||
from pydantic import Field
|
from .primitives import ImageField
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation
|
|
||||||
from ..models.image import ImageField
|
|
||||||
|
|
||||||
|
@invocation('resize')
|
||||||
class ResizeInvocation(BaseInvocation):
|
class ResizeInvocation(BaseInvocation):
|
||||||
'''Resizes an image'''
|
|
||||||
type: Literal['resize'] = 'resize'
|
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
image: ImageField = InputField(description="The input image")
|
||||||
```
|
```
|
||||||
|
|
||||||
Let us break down our input code.
|
Let us break down our input code.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
image: ImageField = InputField(description="The input image")
|
||||||
```
|
```
|
||||||
|
|
||||||
| Part | Value | Description |
|
| Part | Value | Description |
|
||||||
| --------- | ---------------------------------------------------- | -------------------------------------------------------------------------------------------------- |
|
| --------- | ------------------------------------------- | ------------------------------------------------------------------------------- |
|
||||||
| Name | `image` | The variable that will hold our image |
|
| Name | `image` | The variable that will hold our image |
|
||||||
| Type Hint | `Union[ImageField, None]` | The types for our field. Indicates that the image can either be an `ImageField` type or `None` |
|
| Type Hint | `ImageField` | The types for our field. Indicates that the image must be an `ImageField` type. |
|
||||||
| Field | `Field(description="The input image", default=None)` | The image variable is a field which needs a description and a default value that we set to `None`. |
|
| Field | `InputField(description="The input image")` | The image variable is an `InputField` which needs a description. |
|
||||||
|
|
||||||
Great. Now let us create our other inputs for `width` and `height`
|
Great. Now let us create our other inputs for `width` and `height`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from typing import Literal, Union
|
from .baseinvocation import BaseInvocation, InputField, invocation
|
||||||
from pydantic import Field
|
from .primitives import ImageField
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation
|
|
||||||
from ..models.image import ImageField
|
|
||||||
|
|
||||||
|
@invocation('resize')
|
||||||
class ResizeInvocation(BaseInvocation):
|
class ResizeInvocation(BaseInvocation):
|
||||||
'''Resizes an image'''
|
'''Resizes an image'''
|
||||||
type: Literal['resize'] = 'resize'
|
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
image: ImageField = InputField(description="The input image")
|
||||||
width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
|
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||||
height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
|
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||||
```
|
```
|
||||||
|
|
||||||
As you might have noticed, we added two new parameters to the field type for
|
As you might have noticed, we added two new arguments to the `InputField`
|
||||||
`width` and `height` called `gt` and `le`. These basically stand for _greater
|
definition for `width` and `height`, called `gt` and `le`. They stand for
|
||||||
than or equal to_ and _less than or equal to_. There are various other param
|
_greater than or equal to_ and _less than or equal to_.
|
||||||
types for field that you can find on the **pydantic** documentation.
|
|
||||||
|
These impose contraints on those fields, and will raise an exception if the
|
||||||
|
values do not meet the constraints. Field constraints are provided by
|
||||||
|
**pydantic**, so anything you see in the **pydantic docs** will work.
|
||||||
|
|
||||||
**Note:** _Any time it is possible to define constraints for our field, we
|
**Note:** _Any time it is possible to define constraints for our field, we
|
||||||
should do it so the frontend has more information on how to parse this field._
|
should do it so the frontend has more information on how to parse this field._
|
||||||
@ -141,20 +139,17 @@ that are provided by it by InvokeAI.
|
|||||||
Let us create this function first.
|
Let us create this function first.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from typing import Literal, Union
|
from .baseinvocation import BaseInvocation, InputField, invocation
|
||||||
from pydantic import Field
|
from .primitives import ImageField
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
|
||||||
from ..models.image import ImageField
|
|
||||||
|
|
||||||
|
@invocation('resize')
|
||||||
class ResizeInvocation(BaseInvocation):
|
class ResizeInvocation(BaseInvocation):
|
||||||
'''Resizes an image'''
|
'''Resizes an image'''
|
||||||
type: Literal['resize'] = 'resize'
|
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
image: ImageField = InputField(description="The input image")
|
||||||
width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
|
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||||
height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
|
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext):
|
def invoke(self, context: InvocationContext):
|
||||||
pass
|
pass
|
||||||
@ -173,21 +168,18 @@ all the necessary info related to image outputs. So let us use that.
|
|||||||
We will cover how to create your own output types later in this guide.
|
We will cover how to create your own output types later in this guide.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from typing import Literal, Union
|
from .baseinvocation import BaseInvocation, InputField, invocation
|
||||||
from pydantic import Field
|
from .primitives import ImageField
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
|
||||||
from ..models.image import ImageField
|
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
|
@invocation('resize')
|
||||||
class ResizeInvocation(BaseInvocation):
|
class ResizeInvocation(BaseInvocation):
|
||||||
'''Resizes an image'''
|
'''Resizes an image'''
|
||||||
type: Literal['resize'] = 'resize'
|
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
image: ImageField = InputField(description="The input image")
|
||||||
width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
|
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||||
height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
|
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
pass
|
pass
|
||||||
@ -195,39 +187,34 @@ class ResizeInvocation(BaseInvocation):
|
|||||||
|
|
||||||
Perfect. Now that we have our Invocation setup, let us do what we want to do.
|
Perfect. Now that we have our Invocation setup, let us do what we want to do.
|
||||||
|
|
||||||
- We will first load the image. Generally we do this using the `PIL` library but
|
- We will first load the image using one of the services provided by InvokeAI to
|
||||||
we can use one of the services provided by InvokeAI to load the image.
|
load the image.
|
||||||
- We will resize the image using `PIL` to our input data.
|
- We will resize the image using `PIL` to our input data.
|
||||||
- We will output this image in the format we set above.
|
- We will output this image in the format we set above.
|
||||||
|
|
||||||
So let's do that.
|
So let's do that.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from typing import Literal, Union
|
from .baseinvocation import BaseInvocation, InputField, invocation
|
||||||
from pydantic import Field
|
from .primitives import ImageField
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
|
||||||
from ..models.image import ImageField, ResourceOrigin, ImageCategory
|
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
|
@invocation("resize")
|
||||||
class ResizeInvocation(BaseInvocation):
|
class ResizeInvocation(BaseInvocation):
|
||||||
'''Resizes an image'''
|
"""Resizes an image"""
|
||||||
type: Literal['resize'] = 'resize'
|
|
||||||
|
|
||||||
# Inputs
|
image: ImageField = InputField(description="The input image")
|
||||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||||
width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
|
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||||
height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
# Load the image using InvokeAI's predefined Image Service.
|
# Load the image using InvokeAI's predefined Image Service. Returns the PIL image.
|
||||||
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
# Resizing the image
|
# Resizing the image
|
||||||
# Because we used the above service, we already have a PIL image. So we can simply resize.
|
|
||||||
resized_image = image.resize((self.width, self.height))
|
resized_image = image.resize((self.width, self.height))
|
||||||
|
|
||||||
# Preparing the image for output using InvokeAI's predefined Image Service.
|
# Save the image using InvokeAI's predefined Image Service. Returns the prepared PIL image.
|
||||||
output_image = context.services.images.create(
|
output_image = context.services.images.create(
|
||||||
image=resized_image,
|
image=resized_image,
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
@ -241,7 +228,6 @@ class ResizeInvocation(BaseInvocation):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=output_image.image_name,
|
image_name=output_image.image_name,
|
||||||
image_origin=output_image.image_origin,
|
|
||||||
),
|
),
|
||||||
width=output_image.width,
|
width=output_image.width,
|
||||||
height=output_image.height,
|
height=output_image.height,
|
||||||
@ -253,6 +239,20 @@ certain way that the images need to be dispatched in order to be stored and read
|
|||||||
correctly. In 99% of the cases when dealing with an image output, you can simply
|
correctly. In 99% of the cases when dealing with an image output, you can simply
|
||||||
copy-paste the template above.
|
copy-paste the template above.
|
||||||
|
|
||||||
|
### Customization
|
||||||
|
|
||||||
|
We can use the `@invocation` decorator to provide some additional info to the
|
||||||
|
UI, like a custom title, tags and category.
|
||||||
|
|
||||||
|
```python
|
||||||
|
@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations")
|
||||||
|
class ResizeInvocation(BaseInvocation):
|
||||||
|
"""Resizes an image"""
|
||||||
|
|
||||||
|
image: ImageField = InputField(description="The input image")
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
That's it. You made your own **Resize Invocation**.
|
That's it. You made your own **Resize Invocation**.
|
||||||
|
|
||||||
## Result
|
## Result
|
||||||
@ -271,10 +271,57 @@ new Invocation ready to be used.
|
|||||||

|

|
||||||
|
|
||||||
## Contributing Nodes
|
## Contributing Nodes
|
||||||
Once you've created a Node, the next step is to share it with the community! The best way to do this is to submit a Pull Request to add the Node to the [Community Nodes](nodes/communityNodes) list. If you're not sure how to do that, take a look a at our [contributing nodes overview](contributingNodes).
|
|
||||||
|
Once you've created a Node, the next step is to share it with the community! The
|
||||||
|
best way to do this is to submit a Pull Request to add the Node to the
|
||||||
|
[Community Nodes](nodes/communityNodes) list. If you're not sure how to do that,
|
||||||
|
take a look a at our [contributing nodes overview](contributingNodes).
|
||||||
|
|
||||||
## Advanced
|
## Advanced
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
### Custom Output Types
|
||||||
|
|
||||||
|
Like with custom inputs, sometimes you might find yourself needing custom
|
||||||
|
outputs that InvokeAI does not provide. We can easily set one up.
|
||||||
|
|
||||||
|
Now that you are familiar with Invocations and Inputs, let us use that knowledge
|
||||||
|
to create an output that has an `image` field, a `color` field and a `string`
|
||||||
|
field.
|
||||||
|
|
||||||
|
- An invocation output is a class that derives from the parent class of
|
||||||
|
`BaseInvocationOutput`.
|
||||||
|
- All invocation outputs must use the `@invocation_output` decorator to provide
|
||||||
|
their unique output type.
|
||||||
|
- Output fields must use the provided `OutputField` function. This is very
|
||||||
|
similar to the `InputField` function described earlier - it's a wrapper around
|
||||||
|
`pydantic`'s `Field()`.
|
||||||
|
- It is not mandatory but we recommend using names ending with `Output` for
|
||||||
|
output types.
|
||||||
|
- It is not mandatory but we highly recommend adding a `docstring` to describe
|
||||||
|
what your output type is for.
|
||||||
|
|
||||||
|
Now that we know the basic rules for creating a new output type, let us go ahead
|
||||||
|
and make it.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .baseinvocation import BaseInvocationOutput, OutputField, invocation_output
|
||||||
|
from .primitives import ImageField, ColorField
|
||||||
|
|
||||||
|
@invocation_output('image_color_string_output')
|
||||||
|
class ImageColorStringOutput(BaseInvocationOutput):
|
||||||
|
'''Base class for nodes that output a single image'''
|
||||||
|
|
||||||
|
image: ImageField = OutputField(description="The image")
|
||||||
|
color: ColorField = OutputField(description="The color")
|
||||||
|
text: str = OutputField(description="The string")
|
||||||
|
```
|
||||||
|
|
||||||
|
That's all there is to it.
|
||||||
|
|
||||||
|
<!-- TODO: DANGER - we probably do not want people to create their own field types, because this requires a lot of work on the frontend to accomodate.
|
||||||
|
|
||||||
### Custom Input Fields
|
### Custom Input Fields
|
||||||
|
|
||||||
Now that you know how to create your own Invocations, let us dive into slightly
|
Now that you know how to create your own Invocations, let us dive into slightly
|
||||||
@ -329,172 +376,6 @@ like this.
|
|||||||
color: ColorField = Field(default=ColorField(r=0, g=0, b=0, a=0), description='Background color of an image')
|
color: ColorField = Field(default=ColorField(r=0, g=0, b=0, a=0), description='Background color of an image')
|
||||||
```
|
```
|
||||||
|
|
||||||
**Extra Config**
|
|
||||||
|
|
||||||
All input fields also take an additional `Config` class that you can use to do
|
|
||||||
various advanced things like setting required parameters and etc.
|
|
||||||
|
|
||||||
Let us do that for our _ColorField_ and enforce all the values because we did
|
|
||||||
not define any defaults for our fields.
|
|
||||||
|
|
||||||
```python
|
|
||||||
class ColorField(BaseModel):
|
|
||||||
'''A field that holds the rgba values of a color'''
|
|
||||||
r: int = Field(ge=0, le=255, description="The red channel")
|
|
||||||
g: int = Field(ge=0, le=255, description="The green channel")
|
|
||||||
b: int = Field(ge=0, le=255, description="The blue channel")
|
|
||||||
a: int = Field(ge=0, le=255, description="The alpha channel")
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {"required": ["r", "g", "b", "a"]}
|
|
||||||
```
|
|
||||||
|
|
||||||
Now it becomes mandatory for the user to supply all the values required by our
|
|
||||||
input field.
|
|
||||||
|
|
||||||
We will discuss the `Config` class in extra detail later in this guide and how
|
|
||||||
you can use it to make your Invocations more robust.
|
|
||||||
|
|
||||||
### Custom Output Types
|
|
||||||
|
|
||||||
Like with custom inputs, sometimes you might find yourself needing custom
|
|
||||||
outputs that InvokeAI does not provide. We can easily set one up.
|
|
||||||
|
|
||||||
Now that you are familiar with Invocations and Inputs, let us use that knowledge
|
|
||||||
to put together a custom output type for an Invocation that returns _width_,
|
|
||||||
_height_ and _background_color_ that we need to create a blank image.
|
|
||||||
|
|
||||||
- A custom output type is a class that derives from the parent class of
|
|
||||||
`BaseInvocationOutput`.
|
|
||||||
- It is not mandatory but we recommend using names ending with `Output` for
|
|
||||||
output types. So we'll call our class `BlankImageOutput`
|
|
||||||
- It is not mandatory but we highly recommend adding a `docstring` to describe
|
|
||||||
what your output type is for.
|
|
||||||
- Like Invocations, each output type should have a `type` variable that is
|
|
||||||
**unique**
|
|
||||||
|
|
||||||
Now that we know the basic rules for creating a new output type, let us go ahead
|
|
||||||
and make it.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from typing import Literal
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocationOutput
|
|
||||||
|
|
||||||
class BlankImageOutput(BaseInvocationOutput):
|
|
||||||
'''Base output type for creating a blank image'''
|
|
||||||
type: Literal['blank_image_output'] = 'blank_image_output'
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
width: int = Field(description='Width of blank image')
|
|
||||||
height: int = Field(description='Height of blank image')
|
|
||||||
bg_color: ColorField = Field(description='Background color of blank image')
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {"required": ["type", "width", "height", "bg_color"]}
|
|
||||||
```
|
|
||||||
|
|
||||||
All set. We now have an output type that requires what we need to create a
|
|
||||||
blank_image. And if you noticed it, we even used the `Config` class to ensure
|
|
||||||
the fields are required.
|
|
||||||
|
|
||||||
### Custom Configuration
|
|
||||||
|
|
||||||
As you might have noticed when making inputs and outputs, we used a class called
|
|
||||||
`Config` from _pydantic_ to further customize them. Because our inputs and
|
|
||||||
outputs essentially inherit from _pydantic_'s `BaseModel` class, all
|
|
||||||
[configuration options](https://docs.pydantic.dev/latest/usage/schema/#schema-customization)
|
|
||||||
that are valid for _pydantic_ classes are also valid for our inputs and outputs.
|
|
||||||
You can do the same for your Invocations too but InvokeAI makes our life a
|
|
||||||
little bit easier on that end.
|
|
||||||
|
|
||||||
InvokeAI provides a custom configuration class called `InvocationConfig`
|
|
||||||
particularly for configuring Invocations. This is exactly the same as the raw
|
|
||||||
`Config` class from _pydantic_ with some extra stuff on top to help faciliate
|
|
||||||
parsing of the scheme in the frontend UI.
|
|
||||||
|
|
||||||
At the current moment, tihs `InvocationConfig` class is further improved with
|
|
||||||
the following features related the `ui`.
|
|
||||||
|
|
||||||
| Config Option | Field Type | Example |
|
|
||||||
| ------------- | ------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------- |
|
|
||||||
| type_hints | `Dict[str, Literal["integer", "float", "boolean", "string", "enum", "image", "latents", "model", "control"]]` | `type_hint: "model"` provides type hints related to the model like displaying a list of available models |
|
|
||||||
| tags | `List[str]` | `tags: ['resize', 'image']` will classify your invocation under the tags of resize and image. |
|
|
||||||
| title | `str` | `title: 'Resize Image` will rename your to this custom title rather than infer from the name of the Invocation class. |
|
|
||||||
|
|
||||||
So let us update your `ResizeInvocation` with some extra configuration and see
|
|
||||||
how that works.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from typing import Literal, Union
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
|
||||||
from ..models.image import ImageField, ResourceOrigin, ImageCategory
|
|
||||||
from .image import ImageOutput
|
|
||||||
|
|
||||||
class ResizeInvocation(BaseInvocation):
|
|
||||||
'''Resizes an image'''
|
|
||||||
type: Literal['resize'] = 'resize'
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
|
||||||
width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
|
|
||||||
height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra: {
|
|
||||||
ui: {
|
|
||||||
tags: ['resize', 'image'],
|
|
||||||
title: ['My Custom Resize']
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
# Load the image using InvokeAI's predefined Image Service.
|
|
||||||
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
|
|
||||||
|
|
||||||
# Resizing the image
|
|
||||||
# Because we used the above service, we already have a PIL image. So we can simply resize.
|
|
||||||
resized_image = image.resize((self.width, self.height))
|
|
||||||
|
|
||||||
# Preparing the image for output using InvokeAI's predefined Image Service.
|
|
||||||
output_image = context.services.images.create(
|
|
||||||
image=resized_image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Returning the Image
|
|
||||||
return ImageOutput(
|
|
||||||
image=ImageField(
|
|
||||||
image_name=output_image.image_name,
|
|
||||||
image_origin=output_image.image_origin,
|
|
||||||
),
|
|
||||||
width=output_image.width,
|
|
||||||
height=output_image.height,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
We now customized our code to let the frontend know that our Invocation falls
|
|
||||||
under `resize` and `image` categories. So when the user searches for these
|
|
||||||
particular words, our Invocation will show up too.
|
|
||||||
|
|
||||||
We also set a custom title for our Invocation. So instead of being called
|
|
||||||
`Resize`, it will be called `My Custom Resize`.
|
|
||||||
|
|
||||||
As simple as that.
|
|
||||||
|
|
||||||
As time goes by, InvokeAI will further improve and add more customizability for
|
|
||||||
Invocation configuration. We will have more documentation regarding this at a
|
|
||||||
later time.
|
|
||||||
|
|
||||||
# **[TODO]**
|
|
||||||
|
|
||||||
### Custom Components For Frontend
|
### Custom Components For Frontend
|
||||||
|
|
||||||
Every backend input type should have a corresponding frontend component so the
|
Every backend input type should have a corresponding frontend component so the
|
||||||
@ -513,282 +394,4 @@ Let us create a new component for our custom color field we created above. When
|
|||||||
we use a color field, let us say we want the UI to display a color picker for
|
we use a color field, let us say we want the UI to display a color picker for
|
||||||
the user to pick from rather than entering values. That is what we will build
|
the user to pick from rather than entering values. That is what we will build
|
||||||
now.
|
now.
|
||||||
|
-->
|
||||||
---
|
|
||||||
|
|
||||||
<!-- # OLD -- TO BE DELETED OR MOVED LATER
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Creating a new invocation
|
|
||||||
|
|
||||||
To create a new invocation, either find the appropriate module file in
|
|
||||||
`/ldm/invoke/app/invocations` to add your invocation to, or create a new one in
|
|
||||||
that folder. All invocations in that folder will be discovered and made
|
|
||||||
available to the CLI and API automatically. Invocations make use of
|
|
||||||
[typing](https://docs.python.org/3/library/typing.html) and
|
|
||||||
[pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration
|
|
||||||
into the CLI and API.
|
|
||||||
|
|
||||||
An invocation looks like this:
|
|
||||||
|
|
||||||
```py
|
|
||||||
class UpscaleInvocation(BaseInvocation):
|
|
||||||
"""Upscales an image."""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["upscale"] = "upscale"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
|
||||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
|
||||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["upscaling", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
image = context.services.images.get_pil_image(
|
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
|
||||||
image_list=[[image, 0]],
|
|
||||||
upscale=(self.level, self.strength),
|
|
||||||
strength=0.0, # GFPGAN strength
|
|
||||||
save_original=False,
|
|
||||||
image_callback=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now
|
|
||||||
# TODO: can this return multiple results?
|
|
||||||
image_dto = context.services.images.create(
|
|
||||||
image=results[0][0],
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
|
||||||
image=ImageField(
|
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
Each portion is important to implement correctly.
|
|
||||||
|
|
||||||
### Class definition and type
|
|
||||||
|
|
||||||
```py
|
|
||||||
class UpscaleInvocation(BaseInvocation):
|
|
||||||
"""Upscales an image."""
|
|
||||||
type: Literal['upscale'] = 'upscale'
|
|
||||||
```
|
|
||||||
|
|
||||||
All invocations must derive from `BaseInvocation`. They should have a docstring
|
|
||||||
that declares what they do in a single, short line. They should also have a
|
|
||||||
`type` with a type hint that's `Literal["command_name"]`, where `command_name`
|
|
||||||
is what the user will type on the CLI or use in the API to create this
|
|
||||||
invocation. The `command_name` must be unique. The `type` must be assigned to
|
|
||||||
the value of the literal in the type hint.
|
|
||||||
|
|
||||||
### Inputs
|
|
||||||
|
|
||||||
```py
|
|
||||||
# Inputs
|
|
||||||
image: Union[ImageField,None] = Field(description="The input image")
|
|
||||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
|
||||||
level: Literal[2,4] = Field(default=2, description="The upscale level")
|
|
||||||
```
|
|
||||||
|
|
||||||
Inputs consist of three parts: a name, a type hint, and a `Field` with default,
|
|
||||||
description, and validation information. For example:
|
|
||||||
|
|
||||||
| Part | Value | Description |
|
|
||||||
| --------- | ------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- |
|
|
||||||
| Name | `strength` | This field is referred to as `strength` |
|
|
||||||
| Type Hint | `float` | This field must be of type `float` |
|
|
||||||
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
|
|
||||||
|
|
||||||
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this
|
|
||||||
field to be parsed with `None` as a value, which enables linking to previous
|
|
||||||
invocations. All fields should either provide a default value or allow `None` as
|
|
||||||
a value, so that they can be overwritten with a linked output from another
|
|
||||||
invocation.
|
|
||||||
|
|
||||||
The special type `ImageField` is also used here. All images are passed as
|
|
||||||
`ImageField`, which protects them from pydantic validation errors (since images
|
|
||||||
only ever come from links).
|
|
||||||
|
|
||||||
Finally, note that for all linking, the `type` of the linked fields must match.
|
|
||||||
If the `name` also matches, then the field can be **automatically linked** to a
|
|
||||||
previous invocation by name and matching.
|
|
||||||
|
|
||||||
### Config
|
|
||||||
|
|
||||||
```py
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["upscaling", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
This is an optional configuration for the invocation. It inherits from
|
|
||||||
pydantic's model `Config` class, and it used primarily to customize the
|
|
||||||
autogenerated OpenAPI schema.
|
|
||||||
|
|
||||||
The UI relies on the OpenAPI schema in two ways:
|
|
||||||
|
|
||||||
- An API client & Typescript types are generated from it. This happens at build
|
|
||||||
time.
|
|
||||||
- The node editor parses the schema into a template used by the UI to create the
|
|
||||||
node editor UI. This parsing happens at runtime.
|
|
||||||
|
|
||||||
In this example, a `ui` key has been added to the `schema_extra` dict to provide
|
|
||||||
some tags for the UI, to facilitate filtering nodes.
|
|
||||||
|
|
||||||
See the Schema Generation section below for more information.
|
|
||||||
|
|
||||||
### Invoke Function
|
|
||||||
|
|
||||||
```py
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
image = context.services.images.get_pil_image(
|
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
|
||||||
image_list=[[image, 0]],
|
|
||||||
upscale=(self.level, self.strength),
|
|
||||||
strength=0.0, # GFPGAN strength
|
|
||||||
save_original=False,
|
|
||||||
image_callback=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now
|
|
||||||
# TODO: can this return multiple results?
|
|
||||||
image_dto = context.services.images.create(
|
|
||||||
image=results[0][0],
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
|
||||||
image=ImageField(
|
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
The `invoke` function is the last portion of an invocation. It is provided an
|
|
||||||
`InvocationContext` which contains services to perform work as well as a
|
|
||||||
`session_id` for use as needed. It should return a class with output values that
|
|
||||||
derives from `BaseInvocationOutput`.
|
|
||||||
|
|
||||||
Before being called, the invocation will have all of its fields set from
|
|
||||||
defaults, inputs, and finally links (overriding in that order).
|
|
||||||
|
|
||||||
Assume that this invocation may be running simultaneously with other
|
|
||||||
invocations, may be running on another machine, or in other interesting
|
|
||||||
scenarios. If you need functionality, please provide it as a service in the
|
|
||||||
`InvocationServices` class, and make sure it can be overridden.
|
|
||||||
|
|
||||||
### Outputs
|
|
||||||
|
|
||||||
```py
|
|
||||||
class ImageOutput(BaseInvocationOutput):
|
|
||||||
"""Base class for invocations that output an image"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["image_output"] = "image_output"
|
|
||||||
image: ImageField = Field(default=None, description="The output image")
|
|
||||||
width: int = Field(description="The width of the image in pixels")
|
|
||||||
height: int = Field(description="The height of the image in pixels")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
|
||||||
```
|
|
||||||
|
|
||||||
Output classes look like an invocation class without the invoke method. Prefer
|
|
||||||
to use an existing output class if available, and prefer to name inputs the same
|
|
||||||
as outputs when possible, to promote automatic invocation linking.
|
|
||||||
|
|
||||||
## Schema Generation
|
|
||||||
|
|
||||||
Invocation, output and related classes are used to generate an OpenAPI schema.
|
|
||||||
|
|
||||||
### Required Properties
|
|
||||||
|
|
||||||
The schema generation treat all properties with default values as optional. This
|
|
||||||
makes sense internally, but when when using these classes via the generated
|
|
||||||
schema, we end up with e.g. the `ImageOutput` class having its `image` property
|
|
||||||
marked as optional.
|
|
||||||
|
|
||||||
We know that this property will always be present, so the additional logic
|
|
||||||
needed to always check if the property exists adds a lot of extraneous cruft.
|
|
||||||
|
|
||||||
To fix this, we can leverage `pydantic`'s
|
|
||||||
[schema customisation](https://docs.pydantic.dev/usage/schema/#schema-customization)
|
|
||||||
to mark properties that we know will always be present as required.
|
|
||||||
|
|
||||||
Here's that `ImageOutput` class, without the needed schema customisation:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class ImageOutput(BaseInvocationOutput):
|
|
||||||
"""Base class for invocations that output an image"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["image_output"] = "image_output"
|
|
||||||
image: ImageField = Field(default=None, description="The output image")
|
|
||||||
width: int = Field(description="The width of the image in pixels")
|
|
||||||
height: int = Field(description="The height of the image in pixels")
|
|
||||||
# fmt: on
|
|
||||||
```
|
|
||||||
|
|
||||||
The OpenAPI schema that results from this `ImageOutput` will have the `type`,
|
|
||||||
`image`, `width` and `height` properties marked as optional, even though we know
|
|
||||||
they will always have a value.
|
|
||||||
|
|
||||||
```python
|
|
||||||
class ImageOutput(BaseInvocationOutput):
|
|
||||||
"""Base class for invocations that output an image"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["image_output"] = "image_output"
|
|
||||||
image: ImageField = Field(default=None, description="The output image")
|
|
||||||
width: int = Field(description="The width of the image in pixels")
|
|
||||||
height: int = Field(description="The height of the image in pixels")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
# Add schema customization
|
|
||||||
class Config:
|
|
||||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
|
||||||
```
|
|
||||||
|
|
||||||
With the customization in place, the schema will now show these properties as
|
|
||||||
required, obviating the need for extensive null checks in client code.
|
|
||||||
|
|
||||||
See this `pydantic` issue for discussion on this solution:
|
|
||||||
<https://github.com/pydantic/pydantic/discussions/4577> -->
|
|
||||||
|
|
||||||
|
@ -2,15 +2,18 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
|
import re
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
AbstractSet,
|
AbstractSet,
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
|
Literal,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Type,
|
Type,
|
||||||
@ -20,8 +23,8 @@ from typing import (
|
|||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, validator
|
||||||
from pydantic.fields import Undefined
|
from pydantic.fields import Undefined, ModelField
|
||||||
from pydantic.typing import NoArgAnyCallable
|
from pydantic.typing import NoArgAnyCallable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -141,9 +144,11 @@ class UIType(str, Enum):
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Misc
|
# region Misc
|
||||||
FilePath = "FilePath"
|
|
||||||
Enum = "enum"
|
Enum = "enum"
|
||||||
Scheduler = "Scheduler"
|
Scheduler = "Scheduler"
|
||||||
|
WorkflowField = "WorkflowField"
|
||||||
|
IsIntermediate = "IsIntermediate"
|
||||||
|
MetadataField = "MetadataField"
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
@ -365,12 +370,12 @@ def OutputField(
|
|||||||
class UIConfigBase(BaseModel):
|
class UIConfigBase(BaseModel):
|
||||||
"""
|
"""
|
||||||
Provides additional node configuration to the UI.
|
Provides additional node configuration to the UI.
|
||||||
This is used internally by the @tags and @title decorator logic. You probably want to use those
|
This is used internally by the @invocation decorator logic. Do not use this directly.
|
||||||
decorators, though you may add this class to a node definition to specify the title and tags.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
|
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
|
||||||
title: Optional[str] = Field(default=None, description="The display name of the node")
|
title: Optional[str] = Field(default=None, description="The node's display name")
|
||||||
|
category: Optional[str] = Field(default=None, description="The node's category")
|
||||||
|
|
||||||
|
|
||||||
class InvocationContext:
|
class InvocationContext:
|
||||||
@ -383,10 +388,11 @@ class InvocationContext:
|
|||||||
|
|
||||||
|
|
||||||
class BaseInvocationOutput(BaseModel):
|
class BaseInvocationOutput(BaseModel):
|
||||||
"""Base class for all invocation outputs"""
|
"""
|
||||||
|
Base class for all invocation outputs.
|
||||||
|
|
||||||
# All outputs must include a type name like this:
|
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
|
||||||
# type: Literal['your_output_name'] # noqa f821
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_all_subclasses_tuple(cls):
|
def get_all_subclasses_tuple(cls):
|
||||||
@ -422,12 +428,12 @@ class MissingInputException(Exception):
|
|||||||
|
|
||||||
|
|
||||||
class BaseInvocation(ABC, BaseModel):
|
class BaseInvocation(ABC, BaseModel):
|
||||||
"""A node to process inputs and produce outputs.
|
|
||||||
May use dependency injection in __init__ to receive providers.
|
|
||||||
"""
|
"""
|
||||||
|
A node to process inputs and produce outputs.
|
||||||
|
May use dependency injection in __init__ to receive providers.
|
||||||
|
|
||||||
# All invocations must include a type name like this:
|
All invocations must use the `@invocation` decorator to provide their unique type.
|
||||||
# type: Literal['your_output_name'] # noqa f821
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_all_subclasses(cls):
|
def get_all_subclasses(cls):
|
||||||
@ -466,6 +472,8 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
schema["title"] = uiconfig.title
|
schema["title"] = uiconfig.title
|
||||||
if uiconfig and hasattr(uiconfig, "tags"):
|
if uiconfig and hasattr(uiconfig, "tags"):
|
||||||
schema["tags"] = uiconfig.tags
|
schema["tags"] = uiconfig.tags
|
||||||
|
if uiconfig and hasattr(uiconfig, "category"):
|
||||||
|
schema["category"] = uiconfig.category
|
||||||
if "required" not in schema or not isinstance(schema["required"], list):
|
if "required" not in schema or not isinstance(schema["required"], list):
|
||||||
schema["required"] = list()
|
schema["required"] = list()
|
||||||
schema["required"].extend(["type", "id"])
|
schema["required"].extend(["type", "id"])
|
||||||
@ -505,37 +513,110 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
raise MissingInputException(self.__fields__["type"].default, field_name)
|
raise MissingInputException(self.__fields__["type"].default, field_name)
|
||||||
return self.invoke(context)
|
return self.invoke(context)
|
||||||
|
|
||||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
id: str = Field(
|
||||||
is_intermediate: bool = InputField(
|
description="The id of this instance of an invocation. Must be unique among all instances of invocations."
|
||||||
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
|
|
||||||
)
|
)
|
||||||
|
is_intermediate: bool = InputField(
|
||||||
|
default=False, description="Whether or not this is an intermediate invocation.", ui_type=UIType.IsIntermediate
|
||||||
|
)
|
||||||
|
workflow: Optional[str] = InputField(
|
||||||
|
default=None,
|
||||||
|
description="The workflow to save with the image",
|
||||||
|
ui_type=UIType.WorkflowField,
|
||||||
|
)
|
||||||
|
|
||||||
|
@validator("workflow", pre=True)
|
||||||
|
def validate_workflow_is_json(cls, v):
|
||||||
|
if v is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
json.loads(v)
|
||||||
|
except json.decoder.JSONDecodeError:
|
||||||
|
raise ValueError("Workflow must be valid JSON")
|
||||||
|
return v
|
||||||
|
|
||||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
UIConfig: ClassVar[Type[UIConfigBase]]
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseInvocation)
|
GenericBaseInvocation = TypeVar("GenericBaseInvocation", bound=BaseInvocation)
|
||||||
|
|
||||||
|
|
||||||
def title(title: str) -> Callable[[Type[T]], Type[T]]:
|
def invocation(
|
||||||
"""Adds a title to the invocation. Use this to override the default title generation, which is based on the class name."""
|
invocation_type: str, title: Optional[str] = None, tags: Optional[list[str]] = None, category: Optional[str] = None
|
||||||
|
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
|
||||||
|
"""
|
||||||
|
Adds metadata to an invocation.
|
||||||
|
|
||||||
def wrapper(cls: Type[T]) -> Type[T]:
|
:param str invocation_type: The type of the invocation. Must be unique among all invocations.
|
||||||
|
:param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None.
|
||||||
|
:param Optional[list[str]] tags: Adds tags to the invocation. Invocations may be searched for by their tags. Defaults to None.
|
||||||
|
:param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapper(cls: Type[GenericBaseInvocation]) -> Type[GenericBaseInvocation]:
|
||||||
|
# Validate invocation types on creation of invocation classes
|
||||||
|
# TODO: ensure unique?
|
||||||
|
if re.compile(r"^\S+$").match(invocation_type) is None:
|
||||||
|
raise ValueError(f'"invocation_type" must consist of non-whitespace characters, got "{invocation_type}"')
|
||||||
|
|
||||||
|
# Add OpenAPI schema extras
|
||||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
uiconf_name = cls.__qualname__ + ".UIConfig"
|
||||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
||||||
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
||||||
cls.UIConfig.title = title
|
if title is not None:
|
||||||
|
cls.UIConfig.title = title
|
||||||
|
if tags is not None:
|
||||||
|
cls.UIConfig.tags = tags
|
||||||
|
if category is not None:
|
||||||
|
cls.UIConfig.category = category
|
||||||
|
|
||||||
|
# Add the invocation type to the pydantic model of the invocation
|
||||||
|
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||||
|
invocation_type_field = ModelField.infer(
|
||||||
|
name="type",
|
||||||
|
value=invocation_type,
|
||||||
|
annotation=invocation_type_annotation,
|
||||||
|
class_validators=None,
|
||||||
|
config=cls.__config__,
|
||||||
|
)
|
||||||
|
cls.__fields__.update({"type": invocation_type_field})
|
||||||
|
cls.__annotations__.update({"type": invocation_type_annotation})
|
||||||
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
|
GenericBaseInvocationOutput = TypeVar("GenericBaseInvocationOutput", bound=BaseInvocationOutput)
|
||||||
"""Adds tags to the invocation. Use this to improve the streamline finding the invocation in the UI."""
|
|
||||||
|
|
||||||
|
def invocation_output(
|
||||||
|
output_type: str,
|
||||||
|
) -> Callable[[Type[GenericBaseInvocationOutput]], Type[GenericBaseInvocationOutput]]:
|
||||||
|
"""
|
||||||
|
Adds metadata to an invocation output.
|
||||||
|
|
||||||
|
:param str output_type: The type of the invocation output. Must be unique among all invocation outputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapper(cls: Type[GenericBaseInvocationOutput]) -> Type[GenericBaseInvocationOutput]:
|
||||||
|
# Validate output types on creation of invocation output classes
|
||||||
|
# TODO: ensure unique?
|
||||||
|
if re.compile(r"^\S+$").match(output_type) is None:
|
||||||
|
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
|
||||||
|
|
||||||
|
# Add the output type to the pydantic model of the invocation output
|
||||||
|
output_type_annotation = Literal[output_type] # type: ignore
|
||||||
|
output_type_field = ModelField.infer(
|
||||||
|
name="type",
|
||||||
|
value=output_type,
|
||||||
|
annotation=output_type_annotation,
|
||||||
|
class_validators=None,
|
||||||
|
config=cls.__config__,
|
||||||
|
)
|
||||||
|
cls.__fields__.update({"type": output_type_field})
|
||||||
|
cls.__annotations__.update({"type": output_type_annotation})
|
||||||
|
|
||||||
def wrapper(cls: Type[T]) -> Type[T]:
|
|
||||||
uiconf_name = cls.__qualname__ + ".UIConfig"
|
|
||||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
|
|
||||||
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
|
|
||||||
cls.UIConfig.tags = list(tags)
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import validator
|
from pydantic import validator
|
||||||
@ -8,17 +7,13 @@ from pydantic import validator
|
|||||||
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
@title("Integer Range")
|
@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="collections")
|
||||||
@tags("collection", "integer", "range")
|
|
||||||
class RangeInvocation(BaseInvocation):
|
class RangeInvocation(BaseInvocation):
|
||||||
"""Creates a range of numbers from start to stop with step"""
|
"""Creates a range of numbers from start to stop with step"""
|
||||||
|
|
||||||
type: Literal["range"] = "range"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
start: int = InputField(default=0, description="The start of the range")
|
start: int = InputField(default=0, description="The start of the range")
|
||||||
stop: int = InputField(default=10, description="The stop of the range")
|
stop: int = InputField(default=10, description="The stop of the range")
|
||||||
step: int = InputField(default=1, description="The step of the range")
|
step: int = InputField(default=1, description="The step of the range")
|
||||||
@ -33,14 +28,15 @@ class RangeInvocation(BaseInvocation):
|
|||||||
return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
||||||
|
|
||||||
|
|
||||||
@title("Integer Range of Size")
|
@invocation(
|
||||||
@tags("range", "integer", "size", "collection")
|
"range_of_size",
|
||||||
|
title="Integer Range of Size",
|
||||||
|
tags=["collection", "integer", "size", "range"],
|
||||||
|
category="collections",
|
||||||
|
)
|
||||||
class RangeOfSizeInvocation(BaseInvocation):
|
class RangeOfSizeInvocation(BaseInvocation):
|
||||||
"""Creates a range from start to start + size with step"""
|
"""Creates a range from start to start + size with step"""
|
||||||
|
|
||||||
type: Literal["range_of_size"] = "range_of_size"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
start: int = InputField(default=0, description="The start of the range")
|
start: int = InputField(default=0, description="The start of the range")
|
||||||
size: int = InputField(default=1, description="The number of values")
|
size: int = InputField(default=1, description="The number of values")
|
||||||
step: int = InputField(default=1, description="The step of the range")
|
step: int = InputField(default=1, description="The step of the range")
|
||||||
@ -49,14 +45,15 @@ class RangeOfSizeInvocation(BaseInvocation):
|
|||||||
return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
|
return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
|
||||||
|
|
||||||
|
|
||||||
@title("Random Range")
|
@invocation(
|
||||||
@tags("range", "integer", "random", "collection")
|
"random_range",
|
||||||
|
title="Random Range",
|
||||||
|
tags=["range", "integer", "random", "collection"],
|
||||||
|
category="collections",
|
||||||
|
)
|
||||||
class RandomRangeInvocation(BaseInvocation):
|
class RandomRangeInvocation(BaseInvocation):
|
||||||
"""Creates a collection of random numbers"""
|
"""Creates a collection of random numbers"""
|
||||||
|
|
||||||
type: Literal["random_range"] = "random_range"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
low: int = InputField(default=0, description="The inclusive low value")
|
low: int = InputField(default=0, description="The inclusive low value")
|
||||||
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||||
size: int = InputField(default=1, description="The number of values to generate")
|
size: int = InputField(default=1, description="The number of values to generate")
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Literal, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compel import Compel, ReturnedEmbeddingsType
|
from compel import Compel, ReturnedEmbeddingsType
|
||||||
@ -26,8 +26,8 @@ from .baseinvocation import (
|
|||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
UIComponent,
|
UIComponent,
|
||||||
tags,
|
invocation,
|
||||||
title,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from .model import ClipField
|
from .model import ClipField
|
||||||
|
|
||||||
@ -44,13 +44,10 @@ class ConditioningFieldData:
|
|||||||
# PerpNeg = "perp_neg"
|
# PerpNeg = "perp_neg"
|
||||||
|
|
||||||
|
|
||||||
@title("Compel Prompt")
|
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning")
|
||||||
@tags("prompt", "compel")
|
|
||||||
class CompelInvocation(BaseInvocation):
|
class CompelInvocation(BaseInvocation):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
type: Literal["compel"] = "compel"
|
|
||||||
|
|
||||||
prompt: str = InputField(
|
prompt: str = InputField(
|
||||||
default="",
|
default="",
|
||||||
description=FieldDescriptions.compel_prompt,
|
description=FieldDescriptions.compel_prompt,
|
||||||
@ -265,13 +262,15 @@ class SDXLPromptInvocationBase:
|
|||||||
return c, c_pooled, ec
|
return c, c_pooled, ec
|
||||||
|
|
||||||
|
|
||||||
@title("SDXL Compel Prompt")
|
@invocation(
|
||||||
@tags("sdxl", "compel", "prompt")
|
"sdxl_compel_prompt",
|
||||||
|
title="SDXL Prompt",
|
||||||
|
tags=["sdxl", "compel", "prompt"],
|
||||||
|
category="conditioning",
|
||||||
|
)
|
||||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
|
|
||||||
|
|
||||||
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
||||||
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
|
||||||
original_width: int = InputField(default=1024, description="")
|
original_width: int = InputField(default=1024, description="")
|
||||||
@ -324,13 +323,15 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("SDXL Refiner Compel Prompt")
|
@invocation(
|
||||||
@tags("sdxl", "compel", "prompt")
|
"sdxl_refiner_compel_prompt",
|
||||||
|
title="SDXL Refiner Prompt",
|
||||||
|
tags=["sdxl", "compel", "prompt"],
|
||||||
|
category="conditioning",
|
||||||
|
)
|
||||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
|
|
||||||
|
|
||||||
style: str = InputField(
|
style: str = InputField(
|
||||||
default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
|
default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
|
||||||
) # TODO: ?
|
) # TODO: ?
|
||||||
@ -372,20 +373,17 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("clip_skip_output")
|
||||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||||
"""Clip skip node output"""
|
"""Clip skip node output"""
|
||||||
|
|
||||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
|
||||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
|
|
||||||
|
|
||||||
@title("CLIP Skip")
|
@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning")
|
||||||
@tags("clipskip", "clip", "skip")
|
|
||||||
class ClipSkipInvocation(BaseInvocation):
|
class ClipSkipInvocation(BaseInvocation):
|
||||||
"""Skip layers in clip text_encoder model."""
|
"""Skip layers in clip text_encoder model."""
|
||||||
|
|
||||||
type: Literal["clip_skip"] = "clip_skip"
|
|
||||||
|
|
||||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||||
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
|
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
|
||||||
|
|
||||||
|
@ -40,8 +40,8 @@ from .baseinvocation import (
|
|||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
UIType,
|
UIType,
|
||||||
tags,
|
invocation,
|
||||||
title,
|
invocation_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -87,23 +87,18 @@ class ControlField(BaseModel):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("control_output")
|
||||||
class ControlOutput(BaseInvocationOutput):
|
class ControlOutput(BaseInvocationOutput):
|
||||||
"""node output for ControlNet info"""
|
"""node output for ControlNet info"""
|
||||||
|
|
||||||
type: Literal["control_output"] = "control_output"
|
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||||
|
|
||||||
|
|
||||||
@title("ControlNet")
|
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet")
|
||||||
@tags("controlnet")
|
|
||||||
class ControlNetInvocation(BaseInvocation):
|
class ControlNetInvocation(BaseInvocation):
|
||||||
"""Collects ControlNet info to pass to other nodes"""
|
"""Collects ControlNet info to pass to other nodes"""
|
||||||
|
|
||||||
type: Literal["controlnet"] = "controlnet"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The control image")
|
image: ImageField = InputField(description="The control image")
|
||||||
control_model: ControlNetModelField = InputField(
|
control_model: ControlNetModelField = InputField(
|
||||||
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
|
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
|
||||||
@ -134,12 +129,10 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation("image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet")
|
||||||
class ImageProcessorInvocation(BaseInvocation):
|
class ImageProcessorInvocation(BaseInvocation):
|
||||||
"""Base class for invocations that preprocess images for ControlNet"""
|
"""Base class for invocations that preprocess images for ControlNet"""
|
||||||
|
|
||||||
type: Literal["image_processor"] = "image_processor"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to process")
|
image: ImageField = InputField(description="The image to process")
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
@ -151,11 +144,6 @@ class ImageProcessorInvocation(BaseInvocation):
|
|||||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||||
processed_image = self.run_processor(raw_image)
|
processed_image = self.run_processor(raw_image)
|
||||||
|
|
||||||
# FIXME: what happened to image metadata?
|
|
||||||
# metadata = context.services.metadata.build_metadata(
|
|
||||||
# session_id=context.graph_execution_state_id, node=self
|
|
||||||
# )
|
|
||||||
|
|
||||||
# currently can't see processed image in node UI without a showImage node,
|
# currently can't see processed image in node UI without a showImage node,
|
||||||
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
@ -165,6 +153,7 @@ class ImageProcessorInvocation(BaseInvocation):
|
|||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
"""Builds an ImageOutput and its ImageField"""
|
"""Builds an ImageOutput and its ImageField"""
|
||||||
@ -179,14 +168,15 @@ class ImageProcessorInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Canny Processor")
|
@invocation(
|
||||||
@tags("controlnet", "canny")
|
"canny_image_processor",
|
||||||
|
title="Canny Processor",
|
||||||
|
tags=["controlnet", "canny"],
|
||||||
|
category="controlnet",
|
||||||
|
)
|
||||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Canny edge detection for ControlNet"""
|
"""Canny edge detection for ControlNet"""
|
||||||
|
|
||||||
type: Literal["canny_image_processor"] = "canny_image_processor"
|
|
||||||
|
|
||||||
# Input
|
|
||||||
low_threshold: int = InputField(
|
low_threshold: int = InputField(
|
||||||
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
|
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
|
||||||
)
|
)
|
||||||
@ -200,14 +190,15 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@title("HED (softedge) Processor")
|
@invocation(
|
||||||
@tags("controlnet", "hed", "softedge")
|
"hed_image_processor",
|
||||||
|
title="HED (softedge) Processor",
|
||||||
|
tags=["controlnet", "hed", "softedge"],
|
||||||
|
category="controlnet",
|
||||||
|
)
|
||||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies HED edge detection to image"""
|
"""Applies HED edge detection to image"""
|
||||||
|
|
||||||
type: Literal["hed_image_processor"] = "hed_image_processor"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
# safe not supported in controlnet_aux v0.0.3
|
# safe not supported in controlnet_aux v0.0.3
|
||||||
@ -227,14 +218,15 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@title("Lineart Processor")
|
@invocation(
|
||||||
@tags("controlnet", "lineart")
|
"lineart_image_processor",
|
||||||
|
title="Lineart Processor",
|
||||||
|
tags=["controlnet", "lineart"],
|
||||||
|
category="controlnet",
|
||||||
|
)
|
||||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies line art processing to image"""
|
"""Applies line art processing to image"""
|
||||||
|
|
||||||
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
||||||
@ -247,14 +239,15 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@title("Lineart Anime Processor")
|
@invocation(
|
||||||
@tags("controlnet", "lineart", "anime")
|
"lineart_anime_image_processor",
|
||||||
|
title="Lineart Anime Processor",
|
||||||
|
tags=["controlnet", "lineart", "anime"],
|
||||||
|
category="controlnet",
|
||||||
|
)
|
||||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies line art anime processing to image"""
|
"""Applies line art anime processing to image"""
|
||||||
|
|
||||||
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
@ -268,14 +261,15 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@title("Openpose Processor")
|
@invocation(
|
||||||
@tags("controlnet", "openpose", "pose")
|
"openpose_image_processor",
|
||||||
|
title="Openpose Processor",
|
||||||
|
tags=["controlnet", "openpose", "pose"],
|
||||||
|
category="controlnet",
|
||||||
|
)
|
||||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Openpose processing to image"""
|
"""Applies Openpose processing to image"""
|
||||||
|
|
||||||
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
|
hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
|
||||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
@ -291,14 +285,15 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@title("Midas (Depth) Processor")
|
@invocation(
|
||||||
@tags("controlnet", "midas", "depth")
|
"midas_depth_image_processor",
|
||||||
|
title="Midas Depth Processor",
|
||||||
|
tags=["controlnet", "midas"],
|
||||||
|
category="controlnet",
|
||||||
|
)
|
||||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Midas depth processing to image"""
|
"""Applies Midas depth processing to image"""
|
||||||
|
|
||||||
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||||
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||||
@ -316,14 +311,15 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@title("Normal BAE Processor")
|
@invocation(
|
||||||
@tags("controlnet", "normal", "bae")
|
"normalbae_image_processor",
|
||||||
|
title="Normal BAE Processor",
|
||||||
|
tags=["controlnet"],
|
||||||
|
category="controlnet",
|
||||||
|
)
|
||||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies NormalBae processing to image"""
|
"""Applies NormalBae processing to image"""
|
||||||
|
|
||||||
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
@ -335,14 +331,10 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@title("MLSD Processor")
|
@invocation("mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet")
|
||||||
@tags("controlnet", "mlsd")
|
|
||||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies MLSD processing to image"""
|
"""Applies MLSD processing to image"""
|
||||||
|
|
||||||
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||||
@ -360,14 +352,10 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@title("PIDI Processor")
|
@invocation("pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet")
|
||||||
@tags("controlnet", "pidi")
|
|
||||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies PIDI processing to image"""
|
"""Applies PIDI processing to image"""
|
||||||
|
|
||||||
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||||
@ -385,14 +373,15 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@title("Content Shuffle Processor")
|
@invocation(
|
||||||
@tags("controlnet", "contentshuffle")
|
"content_shuffle_image_processor",
|
||||||
|
title="Content Shuffle Processor",
|
||||||
|
tags=["controlnet", "contentshuffle"],
|
||||||
|
category="controlnet",
|
||||||
|
)
|
||||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies content shuffle processing to image"""
|
"""Applies content shuffle processing to image"""
|
||||||
|
|
||||||
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||||
h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||||
@ -413,27 +402,30 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
|
|
||||||
|
|
||||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||||
@title("Zoe (Depth) Processor")
|
@invocation(
|
||||||
@tags("controlnet", "zoe", "depth")
|
"zoe_depth_image_processor",
|
||||||
|
title="Zoe (Depth) Processor",
|
||||||
|
tags=["controlnet", "zoe", "depth"],
|
||||||
|
category="controlnet",
|
||||||
|
)
|
||||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies Zoe depth processing to image"""
|
"""Applies Zoe depth processing to image"""
|
||||||
|
|
||||||
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
processed_image = zoe_depth_processor(image)
|
processed_image = zoe_depth_processor(image)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@title("Mediapipe Face Processor")
|
@invocation(
|
||||||
@tags("controlnet", "mediapipe", "face")
|
"mediapipe_face_processor",
|
||||||
|
title="Mediapipe Face Processor",
|
||||||
|
tags=["controlnet", "mediapipe", "face"],
|
||||||
|
category="controlnet",
|
||||||
|
)
|
||||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies mediapipe face processing to image"""
|
"""Applies mediapipe face processing to image"""
|
||||||
|
|
||||||
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
|
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
|
||||||
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||||
|
|
||||||
@ -447,14 +439,15 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@title("Leres (Depth) Processor")
|
@invocation(
|
||||||
@tags("controlnet", "leres", "depth")
|
"leres_image_processor",
|
||||||
|
title="Leres (Depth) Processor",
|
||||||
|
tags=["controlnet", "leres", "depth"],
|
||||||
|
category="controlnet",
|
||||||
|
)
|
||||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies leres processing to image"""
|
"""Applies leres processing to image"""
|
||||||
|
|
||||||
type: Literal["leres_image_processor"] = "leres_image_processor"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
|
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
|
||||||
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
|
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
|
||||||
boost: bool = InputField(default=False, description="Whether to use boost mode")
|
boost: bool = InputField(default=False, description="Whether to use boost mode")
|
||||||
@ -474,14 +467,15 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@title("Tile Resample Processor")
|
@invocation(
|
||||||
@tags("controlnet", "tile")
|
"tile_image_processor",
|
||||||
|
title="Tile Resample Processor",
|
||||||
|
tags=["controlnet", "tile"],
|
||||||
|
category="controlnet",
|
||||||
|
)
|
||||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Tile resampler processor"""
|
"""Tile resampler processor"""
|
||||||
|
|
||||||
type: Literal["tile_image_processor"] = "tile_image_processor"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
||||||
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
||||||
|
|
||||||
@ -512,13 +506,15 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@title("Segment Anything Processor")
|
@invocation(
|
||||||
@tags("controlnet", "segmentanything")
|
"segment_anything_processor",
|
||||||
|
title="Segment Anything Processor",
|
||||||
|
tags=["controlnet", "segmentanything"],
|
||||||
|
category="controlnet",
|
||||||
|
)
|
||||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Applies segment anything processing to image"""
|
"""Applies segment anything processing to image"""
|
||||||
|
|
||||||
type: Literal["segment_anything_processor"] = "segment_anything_processor"
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import cv2 as cv
|
import cv2 as cv
|
||||||
import numpy
|
import numpy
|
||||||
@ -8,17 +7,18 @@ from PIL import Image, ImageOps
|
|||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
@title("OpenCV Inpaint")
|
@invocation(
|
||||||
@tags("opencv", "inpaint")
|
"cv_inpaint",
|
||||||
|
title="OpenCV Inpaint",
|
||||||
|
tags=["opencv", "inpaint"],
|
||||||
|
category="inpaint",
|
||||||
|
)
|
||||||
class CvInpaintInvocation(BaseInvocation):
|
class CvInpaintInvocation(BaseInvocation):
|
||||||
"""Simple inpaint using opencv."""
|
"""Simple inpaint using opencv."""
|
||||||
|
|
||||||
type: Literal["cv_inpaint"] = "cv_inpaint"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to inpaint")
|
image: ImageField = InputField(description="The image to inpaint")
|
||||||
mask: ImageField = InputField(description="The mask to use when inpainting")
|
mask: ImageField = InputField(description="The mask to use when inpainting")
|
||||||
|
|
||||||
@ -45,6 +45,7 @@ class CvInpaintInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
|
@ -13,18 +13,13 @@ from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
|||||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||||
|
|
||||||
from ..models.image import ImageCategory, ResourceOrigin
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
|
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
@title("Show Image")
|
@invocation("show_image", title="Show Image", tags=["image"], category="image")
|
||||||
@tags("image")
|
|
||||||
class ShowImageInvocation(BaseInvocation):
|
class ShowImageInvocation(BaseInvocation):
|
||||||
"""Displays a provided image, and passes it forward in the pipeline."""
|
"""Displays a provided image using the OS image viewer, and passes it forward in the pipeline."""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["show_image"] = "show_image"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to show")
|
image: ImageField = InputField(description="The image to show")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@ -41,15 +36,10 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Blank Image")
|
@invocation("blank_image", title="Blank Image", tags=["image"], category="image")
|
||||||
@tags("image")
|
|
||||||
class BlankImageInvocation(BaseInvocation):
|
class BlankImageInvocation(BaseInvocation):
|
||||||
"""Creates a blank image and forwards it to the pipeline"""
|
"""Creates a blank image and forwards it to the pipeline"""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["blank_image"] = "blank_image"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
width: int = InputField(default=512, description="The width of the image")
|
width: int = InputField(default=512, description="The width of the image")
|
||||||
height: int = InputField(default=512, description="The height of the image")
|
height: int = InputField(default=512, description="The height of the image")
|
||||||
mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image")
|
mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image")
|
||||||
@ -65,6 +55,7 @@ class BlankImageInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -74,15 +65,10 @@ class BlankImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Crop Image")
|
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image")
|
||||||
@tags("image", "crop")
|
|
||||||
class ImageCropInvocation(BaseInvocation):
|
class ImageCropInvocation(BaseInvocation):
|
||||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["img_crop"] = "img_crop"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to crop")
|
image: ImageField = InputField(description="The image to crop")
|
||||||
x: int = InputField(default=0, description="The left x coordinate of the crop rectangle")
|
x: int = InputField(default=0, description="The left x coordinate of the crop rectangle")
|
||||||
y: int = InputField(default=0, description="The top y coordinate of the crop rectangle")
|
y: int = InputField(default=0, description="The top y coordinate of the crop rectangle")
|
||||||
@ -102,6 +88,7 @@ class ImageCropInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -111,15 +98,10 @@ class ImageCropInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Paste Image")
|
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image")
|
||||||
@tags("image", "paste")
|
|
||||||
class ImagePasteInvocation(BaseInvocation):
|
class ImagePasteInvocation(BaseInvocation):
|
||||||
"""Pastes an image into another image."""
|
"""Pastes an image into another image."""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["img_paste"] = "img_paste"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
base_image: ImageField = InputField(description="The base image")
|
base_image: ImageField = InputField(description="The base image")
|
||||||
image: ImageField = InputField(description="The image to paste")
|
image: ImageField = InputField(description="The image to paste")
|
||||||
mask: Optional[ImageField] = InputField(
|
mask: Optional[ImageField] = InputField(
|
||||||
@ -154,6 +136,7 @@ class ImagePasteInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -163,15 +146,10 @@ class ImagePasteInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Mask from Alpha")
|
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image")
|
||||||
@tags("image", "mask")
|
|
||||||
class MaskFromAlphaInvocation(BaseInvocation):
|
class MaskFromAlphaInvocation(BaseInvocation):
|
||||||
"""Extracts the alpha channel of an image as a mask."""
|
"""Extracts the alpha channel of an image as a mask."""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["tomask"] = "tomask"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to create the mask from")
|
image: ImageField = InputField(description="The image to create the mask from")
|
||||||
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
|
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
|
||||||
|
|
||||||
@ -189,6 +167,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -198,15 +177,10 @@ class MaskFromAlphaInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Multiply Images")
|
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image")
|
||||||
@tags("image", "multiply")
|
|
||||||
class ImageMultiplyInvocation(BaseInvocation):
|
class ImageMultiplyInvocation(BaseInvocation):
|
||||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["img_mul"] = "img_mul"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image1: ImageField = InputField(description="The first image to multiply")
|
image1: ImageField = InputField(description="The first image to multiply")
|
||||||
image2: ImageField = InputField(description="The second image to multiply")
|
image2: ImageField = InputField(description="The second image to multiply")
|
||||||
|
|
||||||
@ -223,6 +197,7 @@ class ImageMultiplyInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -235,15 +210,10 @@ class ImageMultiplyInvocation(BaseInvocation):
|
|||||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||||
|
|
||||||
|
|
||||||
@title("Extract Image Channel")
|
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image")
|
||||||
@tags("image", "channel")
|
|
||||||
class ImageChannelInvocation(BaseInvocation):
|
class ImageChannelInvocation(BaseInvocation):
|
||||||
"""Gets a channel from an image."""
|
"""Gets a channel from an image."""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["img_chan"] = "img_chan"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to get the channel from")
|
image: ImageField = InputField(description="The image to get the channel from")
|
||||||
channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
|
channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
|
||||||
|
|
||||||
@ -259,6 +229,7 @@ class ImageChannelInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -271,15 +242,10 @@ class ImageChannelInvocation(BaseInvocation):
|
|||||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||||
|
|
||||||
|
|
||||||
@title("Convert Image Mode")
|
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image")
|
||||||
@tags("image", "convert")
|
|
||||||
class ImageConvertInvocation(BaseInvocation):
|
class ImageConvertInvocation(BaseInvocation):
|
||||||
"""Converts an image to a different mode."""
|
"""Converts an image to a different mode."""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["img_conv"] = "img_conv"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to convert")
|
image: ImageField = InputField(description="The image to convert")
|
||||||
mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
|
mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
|
||||||
|
|
||||||
@ -295,6 +261,7 @@ class ImageConvertInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -304,15 +271,10 @@ class ImageConvertInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Blur Image")
|
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image")
|
||||||
@tags("image", "blur")
|
|
||||||
class ImageBlurInvocation(BaseInvocation):
|
class ImageBlurInvocation(BaseInvocation):
|
||||||
"""Blurs an image"""
|
"""Blurs an image"""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["img_blur"] = "img_blur"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to blur")
|
image: ImageField = InputField(description="The image to blur")
|
||||||
radius: float = InputField(default=8.0, ge=0, description="The blur radius")
|
radius: float = InputField(default=8.0, ge=0, description="The blur radius")
|
||||||
# Metadata
|
# Metadata
|
||||||
@ -333,6 +295,7 @@ class ImageBlurInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -362,15 +325,10 @@ PIL_RESAMPLING_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@title("Resize Image")
|
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image")
|
||||||
@tags("image", "resize")
|
|
||||||
class ImageResizeInvocation(BaseInvocation):
|
class ImageResizeInvocation(BaseInvocation):
|
||||||
"""Resizes an image to specific dimensions"""
|
"""Resizes an image to specific dimensions"""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["img_resize"] = "img_resize"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to resize")
|
image: ImageField = InputField(description="The image to resize")
|
||||||
width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
|
width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||||
height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
|
height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||||
@ -397,6 +355,7 @@ class ImageResizeInvocation(BaseInvocation):
|
|||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.dict() if self.metadata else None,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -406,15 +365,10 @@ class ImageResizeInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Scale Image")
|
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image")
|
||||||
@tags("image", "scale")
|
|
||||||
class ImageScaleInvocation(BaseInvocation):
|
class ImageScaleInvocation(BaseInvocation):
|
||||||
"""Scales an image by a factor"""
|
"""Scales an image by a factor"""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["img_scale"] = "img_scale"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to scale")
|
image: ImageField = InputField(description="The image to scale")
|
||||||
scale_factor: float = InputField(
|
scale_factor: float = InputField(
|
||||||
default=2.0,
|
default=2.0,
|
||||||
@ -442,6 +396,7 @@ class ImageScaleInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -451,15 +406,10 @@ class ImageScaleInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Lerp Image")
|
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image")
|
||||||
@tags("image", "lerp")
|
|
||||||
class ImageLerpInvocation(BaseInvocation):
|
class ImageLerpInvocation(BaseInvocation):
|
||||||
"""Linear interpolation of all pixels of an image"""
|
"""Linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["img_lerp"] = "img_lerp"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to lerp")
|
image: ImageField = InputField(description="The image to lerp")
|
||||||
min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
|
min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
|
||||||
max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
|
max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
|
||||||
@ -479,6 +429,7 @@ class ImageLerpInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -488,15 +439,10 @@ class ImageLerpInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Inverse Lerp Image")
|
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image")
|
||||||
@tags("image", "ilerp")
|
|
||||||
class ImageInverseLerpInvocation(BaseInvocation):
|
class ImageInverseLerpInvocation(BaseInvocation):
|
||||||
"""Inverse linear interpolation of all pixels of an image"""
|
"""Inverse linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["img_ilerp"] = "img_ilerp"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to lerp")
|
image: ImageField = InputField(description="The image to lerp")
|
||||||
min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
|
min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
|
||||||
max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
|
max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
|
||||||
@ -516,6 +462,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -525,15 +472,10 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Blur NSFW Image")
|
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image")
|
||||||
@tags("image", "nsfw")
|
|
||||||
class ImageNSFWBlurInvocation(BaseInvocation):
|
class ImageNSFWBlurInvocation(BaseInvocation):
|
||||||
"""Add blur to NSFW-flagged images"""
|
"""Add blur to NSFW-flagged images"""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["img_nsfw"] = "img_nsfw"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to check")
|
image: ImageField = InputField(description="The image to check")
|
||||||
metadata: Optional[CoreMetadata] = InputField(
|
metadata: Optional[CoreMetadata] = InputField(
|
||||||
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
|
||||||
@ -559,6 +501,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
|||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.dict() if self.metadata else None,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -574,15 +517,10 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
|||||||
return caution.resize((caution.width // 2, caution.height // 2))
|
return caution.resize((caution.width // 2, caution.height // 2))
|
||||||
|
|
||||||
|
|
||||||
@title("Add Invisible Watermark")
|
@invocation("img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image")
|
||||||
@tags("image", "watermark")
|
|
||||||
class ImageWatermarkInvocation(BaseInvocation):
|
class ImageWatermarkInvocation(BaseInvocation):
|
||||||
"""Add an invisible watermark to an image"""
|
"""Add an invisible watermark to an image"""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["img_watermark"] = "img_watermark"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to check")
|
image: ImageField = InputField(description="The image to check")
|
||||||
text: str = InputField(default="InvokeAI", description="Watermark text")
|
text: str = InputField(default="InvokeAI", description="Watermark text")
|
||||||
metadata: Optional[CoreMetadata] = InputField(
|
metadata: Optional[CoreMetadata] = InputField(
|
||||||
@ -600,6 +538,7 @@ class ImageWatermarkInvocation(BaseInvocation):
|
|||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.dict() if self.metadata else None,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -609,14 +548,10 @@ class ImageWatermarkInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Mask Edge")
|
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image")
|
||||||
@tags("image", "mask", "inpaint")
|
|
||||||
class MaskEdgeInvocation(BaseInvocation):
|
class MaskEdgeInvocation(BaseInvocation):
|
||||||
"""Applies an edge mask to an image"""
|
"""Applies an edge mask to an image"""
|
||||||
|
|
||||||
type: Literal["mask_edge"] = "mask_edge"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to apply the mask to")
|
image: ImageField = InputField(description="The image to apply the mask to")
|
||||||
edge_size: int = InputField(description="The size of the edge")
|
edge_size: int = InputField(description="The size of the edge")
|
||||||
edge_blur: int = InputField(description="The amount of blur on the edge")
|
edge_blur: int = InputField(description="The amount of blur on the edge")
|
||||||
@ -648,6 +583,7 @@ class MaskEdgeInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -657,14 +593,10 @@ class MaskEdgeInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Combine Mask")
|
@invocation("mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image")
|
||||||
@tags("image", "mask", "multiply")
|
|
||||||
class MaskCombineInvocation(BaseInvocation):
|
class MaskCombineInvocation(BaseInvocation):
|
||||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||||
|
|
||||||
type: Literal["mask_combine"] = "mask_combine"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
mask1: ImageField = InputField(description="The first mask to combine")
|
mask1: ImageField = InputField(description="The first mask to combine")
|
||||||
mask2: ImageField = InputField(description="The second image to combine")
|
mask2: ImageField = InputField(description="The second image to combine")
|
||||||
|
|
||||||
@ -681,6 +613,7 @@ class MaskCombineInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -690,17 +623,13 @@ class MaskCombineInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Color Correct")
|
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image")
|
||||||
@tags("image", "color")
|
|
||||||
class ColorCorrectInvocation(BaseInvocation):
|
class ColorCorrectInvocation(BaseInvocation):
|
||||||
"""
|
"""
|
||||||
Shifts the colors of a target image to match the reference image, optionally
|
Shifts the colors of a target image to match the reference image, optionally
|
||||||
using a mask to only color-correct certain regions of the target image.
|
using a mask to only color-correct certain regions of the target image.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: Literal["color_correct"] = "color_correct"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to color-correct")
|
image: ImageField = InputField(description="The image to color-correct")
|
||||||
reference: ImageField = InputField(description="Reference image for color-correction")
|
reference: ImageField = InputField(description="Reference image for color-correction")
|
||||||
mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
|
mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
|
||||||
@ -789,6 +718,7 @@ class ColorCorrectInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -798,14 +728,10 @@ class ColorCorrectInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Image Hue Adjustment")
|
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image")
|
||||||
@tags("image", "hue", "hsl")
|
|
||||||
class ImageHueAdjustmentInvocation(BaseInvocation):
|
class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||||
"""Adjusts the Hue of an image."""
|
"""Adjusts the Hue of an image."""
|
||||||
|
|
||||||
type: Literal["img_hue_adjust"] = "img_hue_adjust"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to adjust")
|
image: ImageField = InputField(description="The image to adjust")
|
||||||
hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
|
hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
|
||||||
|
|
||||||
@ -831,6 +757,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -842,14 +769,15 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Image Luminosity Adjustment")
|
@invocation(
|
||||||
@tags("image", "luminosity", "hsl")
|
"img_luminosity_adjust",
|
||||||
|
title="Adjust Image Luminosity",
|
||||||
|
tags=["image", "luminosity", "hsl"],
|
||||||
|
category="image",
|
||||||
|
)
|
||||||
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
||||||
"""Adjusts the Luminosity (Value) of an image."""
|
"""Adjusts the Luminosity (Value) of an image."""
|
||||||
|
|
||||||
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to adjust")
|
image: ImageField = InputField(description="The image to adjust")
|
||||||
luminosity: float = InputField(
|
luminosity: float = InputField(
|
||||||
default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
|
default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
|
||||||
@ -881,6 +809,7 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -892,14 +821,15 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Image Saturation Adjustment")
|
@invocation(
|
||||||
@tags("image", "saturation", "hsl")
|
"img_saturation_adjust",
|
||||||
|
title="Adjust Image Saturation",
|
||||||
|
tags=["image", "saturation", "hsl"],
|
||||||
|
category="image",
|
||||||
|
)
|
||||||
class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
||||||
"""Adjusts the Saturation of an image."""
|
"""Adjusts the Saturation of an image."""
|
||||||
|
|
||||||
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to adjust")
|
image: ImageField = InputField(description="The image to adjust")
|
||||||
saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
|
saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
|
||||||
|
|
||||||
@ -929,6 +859,7 @@ class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
|
@ -12,7 +12,7 @@ from invokeai.backend.image_util.lama import LaMA
|
|||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
|
|
||||||
from ..models.image import ImageCategory, ResourceOrigin
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
def infill_methods() -> list[str]:
|
def infill_methods() -> list[str]:
|
||||||
@ -116,14 +116,10 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
|||||||
return si
|
return si
|
||||||
|
|
||||||
|
|
||||||
@title("Solid Color Infill")
|
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint")
|
||||||
@tags("image", "inpaint")
|
|
||||||
class InfillColorInvocation(BaseInvocation):
|
class InfillColorInvocation(BaseInvocation):
|
||||||
"""Infills transparent areas of an image with a solid color"""
|
"""Infills transparent areas of an image with a solid color"""
|
||||||
|
|
||||||
type: Literal["infill_rgba"] = "infill_rgba"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
color: ColorField = InputField(
|
color: ColorField = InputField(
|
||||||
default=ColorField(r=127, g=127, b=127, a=255),
|
default=ColorField(r=127, g=127, b=127, a=255),
|
||||||
@ -145,6 +141,7 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -154,14 +151,10 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Tile Infill")
|
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint")
|
||||||
@tags("image", "inpaint")
|
|
||||||
class InfillTileInvocation(BaseInvocation):
|
class InfillTileInvocation(BaseInvocation):
|
||||||
"""Infills transparent areas of an image with tiles of the image"""
|
"""Infills transparent areas of an image with tiles of the image"""
|
||||||
|
|
||||||
type: Literal["infill_tile"] = "infill_tile"
|
|
||||||
|
|
||||||
# Input
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
|
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
|
||||||
seed: int = InputField(
|
seed: int = InputField(
|
||||||
@ -184,6 +177,7 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -193,14 +187,10 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("PatchMatch Infill")
|
@invocation("infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint")
|
||||||
@tags("image", "inpaint")
|
|
||||||
class InfillPatchMatchInvocation(BaseInvocation):
|
class InfillPatchMatchInvocation(BaseInvocation):
|
||||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||||
|
|
||||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@ -218,6 +208,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -227,14 +218,10 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("LaMa Infill")
|
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint")
|
||||||
@tags("image", "inpaint")
|
|
||||||
class LaMaInfillInvocation(BaseInvocation):
|
class LaMaInfillInvocation(BaseInvocation):
|
||||||
"""Infills transparent areas of an image using the LaMa model"""
|
"""Infills transparent areas of an image using the LaMa model"""
|
||||||
|
|
||||||
type: Literal["infill_lama"] = "infill_lama"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to infill")
|
image: ImageField = InputField(description="The image to infill")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
@ -47,7 +47,18 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import Post
|
|||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||||
from ..models.image import ImageCategory, ResourceOrigin
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, UIType, tags, title
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
InvocationContext,
|
||||||
|
OutputField,
|
||||||
|
UIType,
|
||||||
|
invocation,
|
||||||
|
invocation_output,
|
||||||
|
)
|
||||||
from .compel import ConditioningField
|
from .compel import ConditioningField
|
||||||
from .controlnet_image_processors import ControlField
|
from .controlnet_image_processors import ControlField
|
||||||
from .model import ModelInfo, UNetField, VaeField
|
from .model import ModelInfo, UNetField, VaeField
|
||||||
@ -58,15 +69,27 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
|||||||
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
||||||
|
|
||||||
|
|
||||||
@title("Create Denoise Mask")
|
@invocation_output("scheduler_output")
|
||||||
@tags("mask", "denoise")
|
class SchedulerOutput(BaseInvocationOutput):
|
||||||
|
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents")
|
||||||
|
class SchedulerInvocation(BaseInvocation):
|
||||||
|
"""Selects a scheduler."""
|
||||||
|
|
||||||
|
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||||
|
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> SchedulerOutput:
|
||||||
|
return SchedulerOutput(scheduler=self.scheduler)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation("create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents")
|
||||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||||
"""Creates mask for denoising model run."""
|
"""Creates mask for denoising model run."""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["create_denoise_mask"] = "create_denoise_mask"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
|
vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
|
||||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||||
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
||||||
@ -158,14 +181,15 @@ def get_scheduler(
|
|||||||
return scheduler
|
return scheduler
|
||||||
|
|
||||||
|
|
||||||
@title("Denoise Latents")
|
@invocation(
|
||||||
@tags("latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l")
|
"denoise_latents",
|
||||||
|
title="Denoise Latents",
|
||||||
|
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||||
|
category="latents",
|
||||||
|
)
|
||||||
class DenoiseLatentsInvocation(BaseInvocation):
|
class DenoiseLatentsInvocation(BaseInvocation):
|
||||||
"""Denoises noisy latents to decodable images"""
|
"""Denoises noisy latents to decodable images"""
|
||||||
|
|
||||||
type: Literal["denoise_latents"] = "denoise_latents"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
positive_conditioning: ConditioningField = InputField(
|
positive_conditioning: ConditioningField = InputField(
|
||||||
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
|
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
|
||||||
)
|
)
|
||||||
@ -512,14 +536,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
||||||
|
|
||||||
|
|
||||||
@title("Latents to Image")
|
@invocation("l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents")
|
||||||
@tags("latents", "image", "vae", "l2i")
|
|
||||||
class LatentsToImageInvocation(BaseInvocation):
|
class LatentsToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
|
|
||||||
type: Literal["l2i"] = "l2i"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
latents: LatentsField = InputField(
|
latents: LatentsField = InputField(
|
||||||
description=FieldDescriptions.latents,
|
description=FieldDescriptions.latents,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
@ -600,6 +620,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.dict() if self.metadata else None,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -612,14 +633,10 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||||
|
|
||||||
|
|
||||||
@title("Resize Latents")
|
@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents")
|
||||||
@tags("latents", "resize")
|
|
||||||
class ResizeLatentsInvocation(BaseInvocation):
|
class ResizeLatentsInvocation(BaseInvocation):
|
||||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||||
|
|
||||||
type: Literal["lresize"] = "lresize"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
latents: LatentsField = InputField(
|
latents: LatentsField = InputField(
|
||||||
description=FieldDescriptions.latents,
|
description=FieldDescriptions.latents,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
@ -660,14 +677,10 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
|
|
||||||
|
|
||||||
@title("Scale Latents")
|
@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents")
|
||||||
@tags("latents", "resize")
|
|
||||||
class ScaleLatentsInvocation(BaseInvocation):
|
class ScaleLatentsInvocation(BaseInvocation):
|
||||||
"""Scales latents by a given factor."""
|
"""Scales latents by a given factor."""
|
||||||
|
|
||||||
type: Literal["lscale"] = "lscale"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
latents: LatentsField = InputField(
|
latents: LatentsField = InputField(
|
||||||
description=FieldDescriptions.latents,
|
description=FieldDescriptions.latents,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
@ -700,14 +713,10 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
|
|
||||||
|
|
||||||
@title("Image to Latents")
|
@invocation("i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents")
|
||||||
@tags("latents", "image", "vae", "i2l")
|
|
||||||
class ImageToLatentsInvocation(BaseInvocation):
|
class ImageToLatentsInvocation(BaseInvocation):
|
||||||
"""Encodes an image into latents."""
|
"""Encodes an image into latents."""
|
||||||
|
|
||||||
type: Literal["i2l"] = "i2l"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(
|
image: ImageField = InputField(
|
||||||
description="The image to encode",
|
description="The image to encode",
|
||||||
)
|
)
|
||||||
@ -784,14 +793,10 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||||
|
|
||||||
|
|
||||||
@title("Blend Latents")
|
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents")
|
||||||
@tags("latents", "blend")
|
|
||||||
class BlendLatentsInvocation(BaseInvocation):
|
class BlendLatentsInvocation(BaseInvocation):
|
||||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||||
|
|
||||||
type: Literal["lblend"] = "lblend"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
latents_a: LatentsField = InputField(
|
latents_a: LatentsField = InputField(
|
||||||
description=FieldDescriptions.latents,
|
description=FieldDescriptions.latents,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
|
@ -1,22 +1,16 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import IntegerOutput
|
from invokeai.app.invocations.primitives import IntegerOutput
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
|
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
@title("Add Integers")
|
@invocation("add", title="Add Integers", tags=["math", "add"], category="math")
|
||||||
@tags("math")
|
|
||||||
class AddInvocation(BaseInvocation):
|
class AddInvocation(BaseInvocation):
|
||||||
"""Adds two numbers"""
|
"""Adds two numbers"""
|
||||||
|
|
||||||
type: Literal["add"] = "add"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
|
|
||||||
@ -24,14 +18,10 @@ class AddInvocation(BaseInvocation):
|
|||||||
return IntegerOutput(value=self.a + self.b)
|
return IntegerOutput(value=self.a + self.b)
|
||||||
|
|
||||||
|
|
||||||
@title("Subtract Integers")
|
@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math")
|
||||||
@tags("math")
|
|
||||||
class SubtractInvocation(BaseInvocation):
|
class SubtractInvocation(BaseInvocation):
|
||||||
"""Subtracts two numbers"""
|
"""Subtracts two numbers"""
|
||||||
|
|
||||||
type: Literal["sub"] = "sub"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
|
|
||||||
@ -39,14 +29,10 @@ class SubtractInvocation(BaseInvocation):
|
|||||||
return IntegerOutput(value=self.a - self.b)
|
return IntegerOutput(value=self.a - self.b)
|
||||||
|
|
||||||
|
|
||||||
@title("Multiply Integers")
|
@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math")
|
||||||
@tags("math")
|
|
||||||
class MultiplyInvocation(BaseInvocation):
|
class MultiplyInvocation(BaseInvocation):
|
||||||
"""Multiplies two numbers"""
|
"""Multiplies two numbers"""
|
||||||
|
|
||||||
type: Literal["mul"] = "mul"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
|
|
||||||
@ -54,14 +40,10 @@ class MultiplyInvocation(BaseInvocation):
|
|||||||
return IntegerOutput(value=self.a * self.b)
|
return IntegerOutput(value=self.a * self.b)
|
||||||
|
|
||||||
|
|
||||||
@title("Divide Integers")
|
@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math")
|
||||||
@tags("math")
|
|
||||||
class DivideInvocation(BaseInvocation):
|
class DivideInvocation(BaseInvocation):
|
||||||
"""Divides two numbers"""
|
"""Divides two numbers"""
|
||||||
|
|
||||||
type: Literal["div"] = "div"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
|
|
||||||
@ -69,14 +51,10 @@ class DivideInvocation(BaseInvocation):
|
|||||||
return IntegerOutput(value=int(self.a / self.b))
|
return IntegerOutput(value=int(self.a / self.b))
|
||||||
|
|
||||||
|
|
||||||
@title("Random Integer")
|
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math")
|
||||||
@tags("math")
|
|
||||||
class RandomIntInvocation(BaseInvocation):
|
class RandomIntInvocation(BaseInvocation):
|
||||||
"""Outputs a single random integer."""
|
"""Outputs a single random integer."""
|
||||||
|
|
||||||
type: Literal["rand_int"] = "rand_int"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
low: int = InputField(default=0, description="The inclusive low value")
|
low: int = InputField(default=0, description="The inclusive low value")
|
||||||
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Literal, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
@ -8,8 +8,8 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
InputField,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
tags,
|
invocation,
|
||||||
title,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||||
@ -91,21 +91,17 @@ class ImageMetadata(BaseModelExcludeNull):
|
|||||||
graph: Optional[dict] = Field(default=None, description="The graph that created the image")
|
graph: Optional[dict] = Field(default=None, description="The graph that created the image")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("metadata_accumulator_output")
|
||||||
class MetadataAccumulatorOutput(BaseInvocationOutput):
|
class MetadataAccumulatorOutput(BaseInvocationOutput):
|
||||||
"""The output of the MetadataAccumulator node"""
|
"""The output of the MetadataAccumulator node"""
|
||||||
|
|
||||||
type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
|
|
||||||
|
|
||||||
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
|
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
|
||||||
|
|
||||||
|
|
||||||
@title("Metadata Accumulator")
|
@invocation("metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata")
|
||||||
@tags("metadata")
|
|
||||||
class MetadataAccumulatorInvocation(BaseInvocation):
|
class MetadataAccumulatorInvocation(BaseInvocation):
|
||||||
"""Outputs a Core Metadata Object"""
|
"""Outputs a Core Metadata Object"""
|
||||||
|
|
||||||
type: Literal["metadata_accumulator"] = "metadata_accumulator"
|
|
||||||
|
|
||||||
generation_mode: str = InputField(
|
generation_mode: str = InputField(
|
||||||
description="The generation mode that output this image",
|
description="The generation mode that output this image",
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import copy
|
import copy
|
||||||
from typing import List, Literal, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -13,8 +13,8 @@ from .baseinvocation import (
|
|||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
UIType,
|
UIType,
|
||||||
tags,
|
invocation,
|
||||||
title,
|
invocation_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -49,11 +49,10 @@ class VaeField(BaseModel):
|
|||||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("model_loader_output")
|
||||||
class ModelLoaderOutput(BaseInvocationOutput):
|
class ModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
type: Literal["model_loader_output"] = "model_loader_output"
|
|
||||||
|
|
||||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
@ -74,14 +73,10 @@ class LoRAModelField(BaseModel):
|
|||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
@title("Main Model")
|
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model")
|
||||||
@tags("model")
|
|
||||||
class MainModelLoaderInvocation(BaseInvocation):
|
class MainModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["main_model_loader"] = "main_model_loader"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
@ -170,25 +165,18 @@ class MainModelLoaderInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("lora_loader_output")
|
||||||
class LoraLoaderOutput(BaseInvocationOutput):
|
class LoraLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["lora_loader_output"] = "lora_loader_output"
|
|
||||||
|
|
||||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
@title("LoRA")
|
@invocation("lora_loader", title="LoRA", tags=["model"], category="model")
|
||||||
@tags("lora", "model")
|
|
||||||
class LoraLoaderInvocation(BaseInvocation):
|
class LoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
type: Literal["lora_loader"] = "lora_loader"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
unet: Optional[UNetField] = InputField(
|
unet: Optional[UNetField] = InputField(
|
||||||
@ -247,25 +235,19 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("sdxl_lora_loader_output")
|
||||||
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||||
"""SDXL LoRA Loader Output"""
|
"""SDXL LoRA Loader Output"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
|
|
||||||
|
|
||||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||||
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
@title("SDXL LoRA")
|
@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model")
|
||||||
@tags("sdxl", "lora", "model")
|
|
||||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
|
|
||||||
|
|
||||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight)
|
weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
unet: Optional[UNetField] = Field(
|
unet: Optional[UNetField] = Field(
|
||||||
@ -349,23 +331,17 @@ class VAEModelField(BaseModel):
|
|||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("vae_loader_output")
|
||||||
class VaeLoaderOutput(BaseInvocationOutput):
|
class VaeLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""VAE output"""
|
||||||
|
|
||||||
type: Literal["vae_loader_output"] = "vae_loader_output"
|
|
||||||
|
|
||||||
# Outputs
|
|
||||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@title("VAE")
|
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model")
|
||||||
@tags("vae", "model")
|
|
||||||
class VaeLoaderInvocation(BaseInvocation):
|
class VaeLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||||
|
|
||||||
type: Literal["vae_loader"] = "vae_loader"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
vae_model: VAEModelField = InputField(
|
vae_model: VAEModelField = InputField(
|
||||||
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
|
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
|
||||||
)
|
)
|
||||||
@ -392,24 +368,18 @@ class VaeLoaderInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("seamless_output")
|
||||||
class SeamlessModeOutput(BaseInvocationOutput):
|
class SeamlessModeOutput(BaseInvocationOutput):
|
||||||
"""Modified Seamless Model output"""
|
"""Modified Seamless Model output"""
|
||||||
|
|
||||||
type: Literal["seamless_output"] = "seamless_output"
|
|
||||||
|
|
||||||
# Outputs
|
|
||||||
unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
|
unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@title("Seamless")
|
@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model")
|
||||||
@tags("seamless", "model")
|
|
||||||
class SeamlessModeInvocation(BaseInvocation):
|
class SeamlessModeInvocation(BaseInvocation):
|
||||||
"""Applies the seamless transformation to the Model UNet and VAE."""
|
"""Applies the seamless transformation to the Model UNet and VAE."""
|
||||||
|
|
||||||
type: Literal["seamless"] = "seamless"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
unet: Optional[UNetField] = InputField(
|
unet: Optional[UNetField] = InputField(
|
||||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
||||||
)
|
)
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import validator
|
from pydantic import validator
|
||||||
@ -16,8 +15,8 @@ from .baseinvocation import (
|
|||||||
InputField,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
tags,
|
invocation,
|
||||||
title,
|
invocation_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -62,12 +61,10 @@ Nodes
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("noise_output")
|
||||||
class NoiseOutput(BaseInvocationOutput):
|
class NoiseOutput(BaseInvocationOutput):
|
||||||
"""Invocation noise output"""
|
"""Invocation noise output"""
|
||||||
|
|
||||||
type: Literal["noise_output"] = "noise_output"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
|
noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
|
||||||
width: int = OutputField(description=FieldDescriptions.width)
|
width: int = OutputField(description=FieldDescriptions.width)
|
||||||
height: int = OutputField(description=FieldDescriptions.height)
|
height: int = OutputField(description=FieldDescriptions.height)
|
||||||
@ -81,14 +78,10 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Noise")
|
@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents")
|
||||||
@tags("latents", "noise")
|
|
||||||
class NoiseInvocation(BaseInvocation):
|
class NoiseInvocation(BaseInvocation):
|
||||||
"""Generates latent noise."""
|
"""Generates latent noise."""
|
||||||
|
|
||||||
type: Literal["noise"] = "noise"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
seed: int = InputField(
|
seed: int = InputField(
|
||||||
ge=0,
|
ge=0,
|
||||||
le=SEED_MAX,
|
le=SEED_MAX,
|
||||||
|
@ -31,8 +31,8 @@ from .baseinvocation import (
|
|||||||
OutputField,
|
OutputField,
|
||||||
UIComponent,
|
UIComponent,
|
||||||
UIType,
|
UIType,
|
||||||
tags,
|
invocation,
|
||||||
title,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from .controlnet_image_processors import ControlField
|
from .controlnet_image_processors import ControlField
|
||||||
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
|
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
|
||||||
@ -56,11 +56,8 @@ ORT_TO_NP_TYPE = {
|
|||||||
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
||||||
|
|
||||||
|
|
||||||
@title("ONNX Prompt (Raw)")
|
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning")
|
||||||
@tags("onnx", "prompt")
|
|
||||||
class ONNXPromptInvocation(BaseInvocation):
|
class ONNXPromptInvocation(BaseInvocation):
|
||||||
type: Literal["prompt_onnx"] = "prompt_onnx"
|
|
||||||
|
|
||||||
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
||||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||||
|
|
||||||
@ -141,14 +138,15 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
# Text to image
|
# Text to image
|
||||||
@title("ONNX Text to Latents")
|
@invocation(
|
||||||
@tags("latents", "inference", "txt2img", "onnx")
|
"t2l_onnx",
|
||||||
|
title="ONNX Text to Latents",
|
||||||
|
tags=["latents", "inference", "txt2img", "onnx"],
|
||||||
|
category="latents",
|
||||||
|
)
|
||||||
class ONNXTextToLatentsInvocation(BaseInvocation):
|
class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||||
"""Generates latents from conditionings."""
|
"""Generates latents from conditionings."""
|
||||||
|
|
||||||
type: Literal["t2l_onnx"] = "t2l_onnx"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
positive_conditioning: ConditioningField = InputField(
|
positive_conditioning: ConditioningField = InputField(
|
||||||
description=FieldDescriptions.positive_cond,
|
description=FieldDescriptions.positive_cond,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
@ -316,14 +314,15 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
# Latent to image
|
# Latent to image
|
||||||
@title("ONNX Latents to Image")
|
@invocation(
|
||||||
@tags("latents", "image", "vae", "onnx")
|
"l2i_onnx",
|
||||||
|
title="ONNX Latents to Image",
|
||||||
|
tags=["latents", "image", "vae", "onnx"],
|
||||||
|
category="image",
|
||||||
|
)
|
||||||
class ONNXLatentsToImageInvocation(BaseInvocation):
|
class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
|
|
||||||
type: Literal["l2i_onnx"] = "l2i_onnx"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
latents: LatentsField = InputField(
|
latents: LatentsField = InputField(
|
||||||
description=FieldDescriptions.denoised_latents,
|
description=FieldDescriptions.denoised_latents,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
@ -376,6 +375,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
|||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.dict() if self.metadata else None,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
@ -385,17 +385,14 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("model_loader_output_onnx")
|
||||||
class ONNXModelLoaderOutput(BaseInvocationOutput):
|
class ONNXModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
|
|
||||||
|
|
||||||
unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
|
vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
|
||||||
vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
|
vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class OnnxModelField(BaseModel):
|
class OnnxModelField(BaseModel):
|
||||||
@ -406,14 +403,10 @@ class OnnxModelField(BaseModel):
|
|||||||
model_type: ModelType = Field(description="Model Type")
|
model_type: ModelType = Field(description="Model Type")
|
||||||
|
|
||||||
|
|
||||||
@title("ONNX Main Model")
|
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model")
|
||||||
@tags("onnx", "model")
|
|
||||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["onnx_model_loader"] = "onnx_model_loader"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
model: OnnxModelField = InputField(
|
model: OnnxModelField = InputField(
|
||||||
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
|
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
|
||||||
)
|
)
|
||||||
|
@ -42,17 +42,13 @@ from matplotlib.ticker import MaxNLocator
|
|||||||
|
|
||||||
from invokeai.app.invocations.primitives import FloatCollectionOutput
|
from invokeai.app.invocations.primitives import FloatCollectionOutput
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
@title("Float Range")
|
@invocation("float_range", title="Float Range", tags=["math", "range"], category="math")
|
||||||
@tags("math", "range")
|
|
||||||
class FloatLinearRangeInvocation(BaseInvocation):
|
class FloatLinearRangeInvocation(BaseInvocation):
|
||||||
"""Creates a range"""
|
"""Creates a range"""
|
||||||
|
|
||||||
type: Literal["float_range"] = "float_range"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
start: float = InputField(default=5, description="The first value of the range")
|
start: float = InputField(default=5, description="The first value of the range")
|
||||||
stop: float = InputField(default=10, description="The last value of the range")
|
stop: float = InputField(default=10, description="The last value of the range")
|
||||||
steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
|
steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
|
||||||
@ -100,14 +96,10 @@ EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
|||||||
|
|
||||||
|
|
||||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||||
@title("Step Param Easing")
|
@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step")
|
||||||
@tags("step", "easing")
|
|
||||||
class StepParamEasingInvocation(BaseInvocation):
|
class StepParamEasingInvocation(BaseInvocation):
|
||||||
"""Experimental per-step parameter easing for denoising steps"""
|
"""Experimental per-step parameter easing for denoising steps"""
|
||||||
|
|
||||||
type: Literal["step_param_easing"] = "step_param_easing"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
|
easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
|
||||||
num_steps: int = InputField(default=20, description="number of denoising steps")
|
num_steps: int = InputField(default=20, description="number of denoising steps")
|
||||||
start_value: float = InputField(default=0.0, description="easing starting value")
|
start_value: float = InputField(default=0.0, description="easing starting value")
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Literal, Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@ -15,8 +15,8 @@ from .baseinvocation import (
|
|||||||
OutputField,
|
OutputField,
|
||||||
UIComponent,
|
UIComponent,
|
||||||
UIType,
|
UIType,
|
||||||
tags,
|
invocation,
|
||||||
title,
|
invocation_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -29,44 +29,39 @@ Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color
|
|||||||
# region Boolean
|
# region Boolean
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("boolean_output")
|
||||||
class BooleanOutput(BaseInvocationOutput):
|
class BooleanOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single boolean"""
|
"""Base class for nodes that output a single boolean"""
|
||||||
|
|
||||||
type: Literal["boolean_output"] = "boolean_output"
|
|
||||||
value: bool = OutputField(description="The output boolean")
|
value: bool = OutputField(description="The output boolean")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("boolean_collection_output")
|
||||||
class BooleanCollectionOutput(BaseInvocationOutput):
|
class BooleanCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a collection of booleans"""
|
"""Base class for nodes that output a collection of booleans"""
|
||||||
|
|
||||||
type: Literal["boolean_collection_output"] = "boolean_collection_output"
|
|
||||||
|
|
||||||
# Outputs
|
|
||||||
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
|
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
|
||||||
|
|
||||||
|
|
||||||
@title("Boolean Primitive")
|
@invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives")
|
||||||
@tags("primitives", "boolean")
|
|
||||||
class BooleanInvocation(BaseInvocation):
|
class BooleanInvocation(BaseInvocation):
|
||||||
"""A boolean primitive value"""
|
"""A boolean primitive value"""
|
||||||
|
|
||||||
type: Literal["boolean"] = "boolean"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
value: bool = InputField(default=False, description="The boolean value")
|
value: bool = InputField(default=False, description="The boolean value")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> BooleanOutput:
|
def invoke(self, context: InvocationContext) -> BooleanOutput:
|
||||||
return BooleanOutput(value=self.value)
|
return BooleanOutput(value=self.value)
|
||||||
|
|
||||||
|
|
||||||
@title("Boolean Primitive Collection")
|
@invocation(
|
||||||
@tags("primitives", "boolean", "collection")
|
"boolean_collection",
|
||||||
|
title="Boolean Collection Primitive",
|
||||||
|
tags=["primitives", "boolean", "collection"],
|
||||||
|
category="primitives",
|
||||||
|
)
|
||||||
class BooleanCollectionInvocation(BaseInvocation):
|
class BooleanCollectionInvocation(BaseInvocation):
|
||||||
"""A collection of boolean primitive values"""
|
"""A collection of boolean primitive values"""
|
||||||
|
|
||||||
type: Literal["boolean_collection"] = "boolean_collection"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
collection: list[bool] = InputField(
|
collection: list[bool] = InputField(
|
||||||
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
||||||
)
|
)
|
||||||
@ -80,44 +75,39 @@ class BooleanCollectionInvocation(BaseInvocation):
|
|||||||
# region Integer
|
# region Integer
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("integer_output")
|
||||||
class IntegerOutput(BaseInvocationOutput):
|
class IntegerOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single integer"""
|
"""Base class for nodes that output a single integer"""
|
||||||
|
|
||||||
type: Literal["integer_output"] = "integer_output"
|
|
||||||
value: int = OutputField(description="The output integer")
|
value: int = OutputField(description="The output integer")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("integer_collection_output")
|
||||||
class IntegerCollectionOutput(BaseInvocationOutput):
|
class IntegerCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a collection of integers"""
|
"""Base class for nodes that output a collection of integers"""
|
||||||
|
|
||||||
type: Literal["integer_collection_output"] = "integer_collection_output"
|
|
||||||
|
|
||||||
# Outputs
|
|
||||||
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
|
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
|
||||||
|
|
||||||
|
|
||||||
@title("Integer Primitive")
|
@invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives")
|
||||||
@tags("primitives", "integer")
|
|
||||||
class IntegerInvocation(BaseInvocation):
|
class IntegerInvocation(BaseInvocation):
|
||||||
"""An integer primitive value"""
|
"""An integer primitive value"""
|
||||||
|
|
||||||
type: Literal["integer"] = "integer"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
value: int = InputField(default=0, description="The integer value")
|
value: int = InputField(default=0, description="The integer value")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||||
return IntegerOutput(value=self.value)
|
return IntegerOutput(value=self.value)
|
||||||
|
|
||||||
|
|
||||||
@title("Integer Primitive Collection")
|
@invocation(
|
||||||
@tags("primitives", "integer", "collection")
|
"integer_collection",
|
||||||
|
title="Integer Collection Primitive",
|
||||||
|
tags=["primitives", "integer", "collection"],
|
||||||
|
category="primitives",
|
||||||
|
)
|
||||||
class IntegerCollectionInvocation(BaseInvocation):
|
class IntegerCollectionInvocation(BaseInvocation):
|
||||||
"""A collection of integer primitive values"""
|
"""A collection of integer primitive values"""
|
||||||
|
|
||||||
type: Literal["integer_collection"] = "integer_collection"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
collection: list[int] = InputField(
|
collection: list[int] = InputField(
|
||||||
default=0, description="The collection of integer values", ui_type=UIType.IntegerCollection
|
default=0, description="The collection of integer values", ui_type=UIType.IntegerCollection
|
||||||
)
|
)
|
||||||
@ -131,44 +121,39 @@ class IntegerCollectionInvocation(BaseInvocation):
|
|||||||
# region Float
|
# region Float
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("float_output")
|
||||||
class FloatOutput(BaseInvocationOutput):
|
class FloatOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single float"""
|
"""Base class for nodes that output a single float"""
|
||||||
|
|
||||||
type: Literal["float_output"] = "float_output"
|
|
||||||
value: float = OutputField(description="The output float")
|
value: float = OutputField(description="The output float")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("float_collection_output")
|
||||||
class FloatCollectionOutput(BaseInvocationOutput):
|
class FloatCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a collection of floats"""
|
"""Base class for nodes that output a collection of floats"""
|
||||||
|
|
||||||
type: Literal["float_collection_output"] = "float_collection_output"
|
|
||||||
|
|
||||||
# Outputs
|
|
||||||
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
|
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
|
||||||
|
|
||||||
|
|
||||||
@title("Float Primitive")
|
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives")
|
||||||
@tags("primitives", "float")
|
|
||||||
class FloatInvocation(BaseInvocation):
|
class FloatInvocation(BaseInvocation):
|
||||||
"""A float primitive value"""
|
"""A float primitive value"""
|
||||||
|
|
||||||
type: Literal["float"] = "float"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
value: float = InputField(default=0.0, description="The float value")
|
value: float = InputField(default=0.0, description="The float value")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||||
return FloatOutput(value=self.value)
|
return FloatOutput(value=self.value)
|
||||||
|
|
||||||
|
|
||||||
@title("Float Primitive Collection")
|
@invocation(
|
||||||
@tags("primitives", "float", "collection")
|
"float_collection",
|
||||||
|
title="Float Collection Primitive",
|
||||||
|
tags=["primitives", "float", "collection"],
|
||||||
|
category="primitives",
|
||||||
|
)
|
||||||
class FloatCollectionInvocation(BaseInvocation):
|
class FloatCollectionInvocation(BaseInvocation):
|
||||||
"""A collection of float primitive values"""
|
"""A collection of float primitive values"""
|
||||||
|
|
||||||
type: Literal["float_collection"] = "float_collection"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
collection: list[float] = InputField(
|
collection: list[float] = InputField(
|
||||||
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
|
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
|
||||||
)
|
)
|
||||||
@ -182,44 +167,39 @@ class FloatCollectionInvocation(BaseInvocation):
|
|||||||
# region String
|
# region String
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("string_output")
|
||||||
class StringOutput(BaseInvocationOutput):
|
class StringOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single string"""
|
"""Base class for nodes that output a single string"""
|
||||||
|
|
||||||
type: Literal["string_output"] = "string_output"
|
|
||||||
value: str = OutputField(description="The output string")
|
value: str = OutputField(description="The output string")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("string_collection_output")
|
||||||
class StringCollectionOutput(BaseInvocationOutput):
|
class StringCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a collection of strings"""
|
"""Base class for nodes that output a collection of strings"""
|
||||||
|
|
||||||
type: Literal["string_collection_output"] = "string_collection_output"
|
|
||||||
|
|
||||||
# Outputs
|
|
||||||
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
|
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
|
||||||
|
|
||||||
|
|
||||||
@title("String Primitive")
|
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives")
|
||||||
@tags("primitives", "string")
|
|
||||||
class StringInvocation(BaseInvocation):
|
class StringInvocation(BaseInvocation):
|
||||||
"""A string primitive value"""
|
"""A string primitive value"""
|
||||||
|
|
||||||
type: Literal["string"] = "string"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||||
return StringOutput(value=self.value)
|
return StringOutput(value=self.value)
|
||||||
|
|
||||||
|
|
||||||
@title("String Primitive Collection")
|
@invocation(
|
||||||
@tags("primitives", "string", "collection")
|
"string_collection",
|
||||||
|
title="String Collection Primitive",
|
||||||
|
tags=["primitives", "string", "collection"],
|
||||||
|
category="primitives",
|
||||||
|
)
|
||||||
class StringCollectionInvocation(BaseInvocation):
|
class StringCollectionInvocation(BaseInvocation):
|
||||||
"""A collection of string primitive values"""
|
"""A collection of string primitive values"""
|
||||||
|
|
||||||
type: Literal["string_collection"] = "string_collection"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
collection: list[str] = InputField(
|
collection: list[str] = InputField(
|
||||||
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
|
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
|
||||||
)
|
)
|
||||||
@ -239,33 +219,26 @@ class ImageField(BaseModel):
|
|||||||
image_name: str = Field(description="The name of the image")
|
image_name: str = Field(description="The name of the image")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("image_output")
|
||||||
class ImageOutput(BaseInvocationOutput):
|
class ImageOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single image"""
|
"""Base class for nodes that output a single image"""
|
||||||
|
|
||||||
type: Literal["image_output"] = "image_output"
|
|
||||||
image: ImageField = OutputField(description="The output image")
|
image: ImageField = OutputField(description="The output image")
|
||||||
width: int = OutputField(description="The width of the image in pixels")
|
width: int = OutputField(description="The width of the image in pixels")
|
||||||
height: int = OutputField(description="The height of the image in pixels")
|
height: int = OutputField(description="The height of the image in pixels")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("image_collection_output")
|
||||||
class ImageCollectionOutput(BaseInvocationOutput):
|
class ImageCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a collection of images"""
|
"""Base class for nodes that output a collection of images"""
|
||||||
|
|
||||||
type: Literal["image_collection_output"] = "image_collection_output"
|
|
||||||
|
|
||||||
# Outputs
|
|
||||||
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
|
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
|
||||||
|
|
||||||
|
|
||||||
@title("Image Primitive")
|
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives")
|
||||||
@tags("primitives", "image")
|
|
||||||
class ImageInvocation(BaseInvocation):
|
class ImageInvocation(BaseInvocation):
|
||||||
"""An image primitive value"""
|
"""An image primitive value"""
|
||||||
|
|
||||||
# Metadata
|
|
||||||
type: Literal["image"] = "image"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The image to load")
|
image: ImageField = InputField(description="The image to load")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@ -278,14 +251,15 @@ class ImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Image Primitive Collection")
|
@invocation(
|
||||||
@tags("primitives", "image", "collection")
|
"image_collection",
|
||||||
|
title="Image Collection Primitive",
|
||||||
|
tags=["primitives", "image", "collection"],
|
||||||
|
category="primitives",
|
||||||
|
)
|
||||||
class ImageCollectionInvocation(BaseInvocation):
|
class ImageCollectionInvocation(BaseInvocation):
|
||||||
"""A collection of image primitive values"""
|
"""A collection of image primitive values"""
|
||||||
|
|
||||||
type: Literal["image_collection"] = "image_collection"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
collection: list[ImageField] = InputField(
|
collection: list[ImageField] = InputField(
|
||||||
default=0, description="The collection of image values", ui_type=UIType.ImageCollection
|
default=0, description="The collection of image values", ui_type=UIType.ImageCollection
|
||||||
)
|
)
|
||||||
@ -306,10 +280,10 @@ class DenoiseMaskField(BaseModel):
|
|||||||
masked_latents_name: Optional[str] = Field(description="The name of the masked image latents")
|
masked_latents_name: Optional[str] = Field(description="The name of the masked image latents")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("denoise_mask_output")
|
||||||
class DenoiseMaskOutput(BaseInvocationOutput):
|
class DenoiseMaskOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single image"""
|
"""Base class for nodes that output a single image"""
|
||||||
|
|
||||||
type: Literal["denoise_mask_output"] = "denoise_mask_output"
|
|
||||||
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||||
|
|
||||||
|
|
||||||
@ -325,11 +299,10 @@ class LatentsField(BaseModel):
|
|||||||
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
|
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("latents_output")
|
||||||
class LatentsOutput(BaseInvocationOutput):
|
class LatentsOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single latents tensor"""
|
"""Base class for nodes that output a single latents tensor"""
|
||||||
|
|
||||||
type: Literal["latents_output"] = "latents_output"
|
|
||||||
|
|
||||||
latents: LatentsField = OutputField(
|
latents: LatentsField = OutputField(
|
||||||
description=FieldDescriptions.latents,
|
description=FieldDescriptions.latents,
|
||||||
)
|
)
|
||||||
@ -337,25 +310,20 @@ class LatentsOutput(BaseInvocationOutput):
|
|||||||
height: int = OutputField(description=FieldDescriptions.height)
|
height: int = OutputField(description=FieldDescriptions.height)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("latents_collection_output")
|
||||||
class LatentsCollectionOutput(BaseInvocationOutput):
|
class LatentsCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a collection of latents tensors"""
|
"""Base class for nodes that output a collection of latents tensors"""
|
||||||
|
|
||||||
type: Literal["latents_collection_output"] = "latents_collection_output"
|
|
||||||
|
|
||||||
collection: list[LatentsField] = OutputField(
|
collection: list[LatentsField] = OutputField(
|
||||||
description=FieldDescriptions.latents,
|
description=FieldDescriptions.latents,
|
||||||
ui_type=UIType.LatentsCollection,
|
ui_type=UIType.LatentsCollection,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Latents Primitive")
|
@invocation("latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives")
|
||||||
@tags("primitives", "latents")
|
|
||||||
class LatentsInvocation(BaseInvocation):
|
class LatentsInvocation(BaseInvocation):
|
||||||
"""A latents tensor primitive value"""
|
"""A latents tensor primitive value"""
|
||||||
|
|
||||||
type: Literal["latents"] = "latents"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
@ -364,14 +332,15 @@ class LatentsInvocation(BaseInvocation):
|
|||||||
return build_latents_output(self.latents.latents_name, latents)
|
return build_latents_output(self.latents.latents_name, latents)
|
||||||
|
|
||||||
|
|
||||||
@title("Latents Primitive Collection")
|
@invocation(
|
||||||
@tags("primitives", "latents", "collection")
|
"latents_collection",
|
||||||
|
title="Latents Collection Primitive",
|
||||||
|
tags=["primitives", "latents", "collection"],
|
||||||
|
category="primitives",
|
||||||
|
)
|
||||||
class LatentsCollectionInvocation(BaseInvocation):
|
class LatentsCollectionInvocation(BaseInvocation):
|
||||||
"""A collection of latents tensor primitive values"""
|
"""A collection of latents tensor primitive values"""
|
||||||
|
|
||||||
type: Literal["latents_collection"] = "latents_collection"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
collection: list[LatentsField] = InputField(
|
collection: list[LatentsField] = InputField(
|
||||||
description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
||||||
)
|
)
|
||||||
@ -405,30 +374,24 @@ class ColorField(BaseModel):
|
|||||||
return (self.r, self.g, self.b, self.a)
|
return (self.r, self.g, self.b, self.a)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("color_output")
|
||||||
class ColorOutput(BaseInvocationOutput):
|
class ColorOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single color"""
|
"""Base class for nodes that output a single color"""
|
||||||
|
|
||||||
type: Literal["color_output"] = "color_output"
|
|
||||||
color: ColorField = OutputField(description="The output color")
|
color: ColorField = OutputField(description="The output color")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("color_collection_output")
|
||||||
class ColorCollectionOutput(BaseInvocationOutput):
|
class ColorCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a collection of colors"""
|
"""Base class for nodes that output a collection of colors"""
|
||||||
|
|
||||||
type: Literal["color_collection_output"] = "color_collection_output"
|
|
||||||
|
|
||||||
# Outputs
|
|
||||||
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
|
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
|
||||||
|
|
||||||
|
|
||||||
@title("Color Primitive")
|
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives")
|
||||||
@tags("primitives", "color")
|
|
||||||
class ColorInvocation(BaseInvocation):
|
class ColorInvocation(BaseInvocation):
|
||||||
"""A color primitive value"""
|
"""A color primitive value"""
|
||||||
|
|
||||||
type: Literal["color"] = "color"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
|
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ColorOutput:
|
def invoke(self, context: InvocationContext) -> ColorOutput:
|
||||||
@ -446,47 +409,47 @@ class ConditioningField(BaseModel):
|
|||||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("conditioning_output")
|
||||||
class ConditioningOutput(BaseInvocationOutput):
|
class ConditioningOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single conditioning tensor"""
|
"""Base class for nodes that output a single conditioning tensor"""
|
||||||
|
|
||||||
type: Literal["conditioning_output"] = "conditioning_output"
|
|
||||||
|
|
||||||
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("conditioning_collection_output")
|
||||||
class ConditioningCollectionOutput(BaseInvocationOutput):
|
class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a collection of conditioning tensors"""
|
"""Base class for nodes that output a collection of conditioning tensors"""
|
||||||
|
|
||||||
type: Literal["conditioning_collection_output"] = "conditioning_collection_output"
|
|
||||||
|
|
||||||
# Outputs
|
|
||||||
collection: list[ConditioningField] = OutputField(
|
collection: list[ConditioningField] = OutputField(
|
||||||
description="The output conditioning tensors",
|
description="The output conditioning tensors",
|
||||||
ui_type=UIType.ConditioningCollection,
|
ui_type=UIType.ConditioningCollection,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Conditioning Primitive")
|
@invocation(
|
||||||
@tags("primitives", "conditioning")
|
"conditioning",
|
||||||
|
title="Conditioning Primitive",
|
||||||
|
tags=["primitives", "conditioning"],
|
||||||
|
category="primitives",
|
||||||
|
)
|
||||||
class ConditioningInvocation(BaseInvocation):
|
class ConditioningInvocation(BaseInvocation):
|
||||||
"""A conditioning tensor primitive value"""
|
"""A conditioning tensor primitive value"""
|
||||||
|
|
||||||
type: Literal["conditioning"] = "conditioning"
|
|
||||||
|
|
||||||
conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
|
conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
return ConditioningOutput(conditioning=self.conditioning)
|
return ConditioningOutput(conditioning=self.conditioning)
|
||||||
|
|
||||||
|
|
||||||
@title("Conditioning Primitive Collection")
|
@invocation(
|
||||||
@tags("primitives", "conditioning", "collection")
|
"conditioning_collection",
|
||||||
|
title="Conditioning Collection Primitive",
|
||||||
|
tags=["primitives", "conditioning", "collection"],
|
||||||
|
category="primitives",
|
||||||
|
)
|
||||||
class ConditioningCollectionInvocation(BaseInvocation):
|
class ConditioningCollectionInvocation(BaseInvocation):
|
||||||
"""A collection of conditioning tensor primitive values"""
|
"""A collection of conditioning tensor primitive values"""
|
||||||
|
|
||||||
type: Literal["conditioning_collection"] = "conditioning_collection"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
collection: list[ConditioningField] = InputField(
|
collection: list[ConditioningField] = InputField(
|
||||||
default=0, description="The collection of conditioning tensors", ui_type=UIType.ConditioningCollection
|
default=0, description="The collection of conditioning tensors", ui_type=UIType.ConditioningCollection
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from os.path import exists
|
from os.path import exists
|
||||||
from typing import Literal, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
|
||||||
@ -7,17 +7,13 @@ from pydantic import validator
|
|||||||
|
|
||||||
from invokeai.app.invocations.primitives import StringCollectionOutput
|
from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, UIType, tags, title
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
|
||||||
|
|
||||||
|
|
||||||
@title("Dynamic Prompt")
|
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt")
|
||||||
@tags("prompt", "collection")
|
|
||||||
class DynamicPromptInvocation(BaseInvocation):
|
class DynamicPromptInvocation(BaseInvocation):
|
||||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||||
|
|
||||||
type: Literal["dynamic_prompt"] = "dynamic_prompt"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
|
prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
|
||||||
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
|
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
|
||||||
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
|
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
|
||||||
@ -33,15 +29,11 @@ class DynamicPromptInvocation(BaseInvocation):
|
|||||||
return StringCollectionOutput(collection=prompts)
|
return StringCollectionOutput(collection=prompts)
|
||||||
|
|
||||||
|
|
||||||
@title("Prompts from File")
|
@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt")
|
||||||
@tags("prompt", "file")
|
|
||||||
class PromptsFromFileInvocation(BaseInvocation):
|
class PromptsFromFileInvocation(BaseInvocation):
|
||||||
"""Loads prompts from a text file"""
|
"""Loads prompts from a text file"""
|
||||||
|
|
||||||
type: Literal["prompt_from_file"] = "prompt_from_file"
|
file_path: str = InputField(description="Path to prompt text file")
|
||||||
|
|
||||||
# Inputs
|
|
||||||
file_path: str = InputField(description="Path to prompt text file", ui_type=UIType.FilePath)
|
|
||||||
pre_prompt: Optional[str] = InputField(
|
pre_prompt: Optional[str] = InputField(
|
||||||
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
|
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
from typing import Literal
|
|
||||||
|
|
||||||
from ...backend.model_management import ModelType, SubModelType
|
from ...backend.model_management import ModelType, SubModelType
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -10,41 +8,35 @@ from .baseinvocation import (
|
|||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
UIType,
|
UIType,
|
||||||
tags,
|
invocation,
|
||||||
title,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
|
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("sdxl_model_loader_output")
|
||||||
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""SDXL base model loader output"""
|
"""SDXL base model loader output"""
|
||||||
|
|
||||||
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
|
|
||||||
|
|
||||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("sdxl_refiner_model_loader_output")
|
||||||
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""SDXL refiner model loader output"""
|
"""SDXL refiner model loader output"""
|
||||||
|
|
||||||
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
|
||||||
|
|
||||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@title("SDXL Main Model")
|
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model")
|
||||||
@tags("model", "sdxl")
|
|
||||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl base model, outputting its submodels."""
|
"""Loads an sdxl base model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
model: MainModelField = InputField(
|
model: MainModelField = InputField(
|
||||||
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
||||||
)
|
)
|
||||||
@ -122,14 +114,15 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("SDXL Refiner Model")
|
@invocation(
|
||||||
@tags("model", "sdxl", "refiner")
|
"sdxl_refiner_model_loader",
|
||||||
|
title="SDXL Refiner Model",
|
||||||
|
tags=["model", "sdxl", "refiner"],
|
||||||
|
category="model",
|
||||||
|
)
|
||||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
model: MainModelField = InputField(
|
model: MainModelField = InputField(
|
||||||
description=FieldDescriptions.sdxl_refiner_model,
|
description=FieldDescriptions.sdxl_refiner_model,
|
||||||
input=Input.Direct,
|
input=Input.Direct,
|
||||||
|
@ -11,7 +11,7 @@ from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
|||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
# TODO: Populate this from disk?
|
# TODO: Populate this from disk?
|
||||||
# TODO: Use model manager to load?
|
# TODO: Use model manager to load?
|
||||||
@ -23,14 +23,10 @@ ESRGAN_MODELS = Literal[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@title("Upscale (RealESRGAN)")
|
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan")
|
||||||
@tags("esrgan", "upscale")
|
|
||||||
class ESRGANInvocation(BaseInvocation):
|
class ESRGANInvocation(BaseInvocation):
|
||||||
"""Upscales an image using RealESRGAN."""
|
"""Upscales an image using RealESRGAN."""
|
||||||
|
|
||||||
type: Literal["esrgan"] = "esrgan"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
image: ImageField = InputField(description="The input image")
|
image: ImageField = InputField(description="The input image")
|
||||||
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
||||||
|
|
||||||
@ -110,6 +106,7 @@ class ESRGANInvocation(BaseInvocation):
|
|||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pydantic import BaseModel, root_validator, validator
|
from pydantic import BaseModel, root_validator, validator
|
||||||
@ -14,11 +14,13 @@ from ..invocations import * # noqa: F401 F403
|
|||||||
from ..invocations.baseinvocation import (
|
from ..invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
|
invocation,
|
||||||
Input,
|
Input,
|
||||||
InputField,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
UIType,
|
UIType,
|
||||||
|
invocation_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
# in 3.10 this would be "from types import NoneType"
|
# in 3.10 this would be "from types import NoneType"
|
||||||
@ -148,24 +150,16 @@ class NodeAlreadyExecutedError(Exception):
|
|||||||
|
|
||||||
|
|
||||||
# TODO: Create and use an Empty output?
|
# TODO: Create and use an Empty output?
|
||||||
|
@invocation_output("graph_output")
|
||||||
class GraphInvocationOutput(BaseInvocationOutput):
|
class GraphInvocationOutput(BaseInvocationOutput):
|
||||||
type: Literal["graph_output"] = "graph_output"
|
pass
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {
|
|
||||||
"required": [
|
|
||||||
"type",
|
|
||||||
"image",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fill this out and move to invocations
|
# TODO: Fill this out and move to invocations
|
||||||
|
@invocation("graph")
|
||||||
class GraphInvocation(BaseInvocation):
|
class GraphInvocation(BaseInvocation):
|
||||||
"""Execute a graph"""
|
"""Execute a graph"""
|
||||||
|
|
||||||
type: Literal["graph"] = "graph"
|
|
||||||
|
|
||||||
# TODO: figure out how to create a default here
|
# TODO: figure out how to create a default here
|
||||||
graph: "Graph" = Field(description="The graph to run", default=None)
|
graph: "Graph" = Field(description="The graph to run", default=None)
|
||||||
|
|
||||||
@ -174,22 +168,20 @@ class GraphInvocation(BaseInvocation):
|
|||||||
return GraphInvocationOutput()
|
return GraphInvocationOutput()
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("iterate_output")
|
||||||
class IterateInvocationOutput(BaseInvocationOutput):
|
class IterateInvocationOutput(BaseInvocationOutput):
|
||||||
"""Used to connect iteration outputs. Will be expanded to a specific output."""
|
"""Used to connect iteration outputs. Will be expanded to a specific output."""
|
||||||
|
|
||||||
type: Literal["iterate_output"] = "iterate_output"
|
|
||||||
|
|
||||||
item: Any = OutputField(
|
item: Any = OutputField(
|
||||||
description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem
|
description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fill this out and move to invocations
|
# TODO: Fill this out and move to invocations
|
||||||
|
@invocation("iterate")
|
||||||
class IterateInvocation(BaseInvocation):
|
class IterateInvocation(BaseInvocation):
|
||||||
"""Iterates over a list of items"""
|
"""Iterates over a list of items"""
|
||||||
|
|
||||||
type: Literal["iterate"] = "iterate"
|
|
||||||
|
|
||||||
collection: list[Any] = InputField(
|
collection: list[Any] = InputField(
|
||||||
description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection
|
description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection
|
||||||
)
|
)
|
||||||
@ -200,19 +192,17 @@ class IterateInvocation(BaseInvocation):
|
|||||||
return IterateInvocationOutput(item=self.collection[self.index])
|
return IterateInvocationOutput(item=self.collection[self.index])
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("collect_output")
|
||||||
class CollectInvocationOutput(BaseInvocationOutput):
|
class CollectInvocationOutput(BaseInvocationOutput):
|
||||||
type: Literal["collect_output"] = "collect_output"
|
|
||||||
|
|
||||||
collection: list[Any] = OutputField(
|
collection: list[Any] = OutputField(
|
||||||
description="The collection of input items", title="Collection", ui_type=UIType.Collection
|
description="The collection of input items", title="Collection", ui_type=UIType.Collection
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation("collect")
|
||||||
class CollectInvocation(BaseInvocation):
|
class CollectInvocation(BaseInvocation):
|
||||||
"""Collects values into a collection"""
|
"""Collects values into a collection"""
|
||||||
|
|
||||||
type: Literal["collect"] = "collect"
|
|
||||||
|
|
||||||
item: Any = InputField(
|
item: Any = InputField(
|
||||||
description="The item to collect (all inputs must be of the same type)",
|
description="The item to collect (all inputs must be of the same type)",
|
||||||
ui_type=UIType.CollectionItem,
|
ui_type=UIType.CollectionItem,
|
||||||
|
@ -60,7 +60,7 @@ class ImageFileStorageBase(ABC):
|
|||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
graph: Optional[dict] = None,
|
workflow: Optional[str] = None,
|
||||||
thumbnail_size: int = 256,
|
thumbnail_size: int = 256,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||||
@ -110,7 +110,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
graph: Optional[dict] = None,
|
workflow: Optional[str] = None,
|
||||||
thumbnail_size: int = 256,
|
thumbnail_size: int = 256,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
@ -119,12 +119,23 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
|
|
||||||
pnginfo = PngImagePlugin.PngInfo()
|
pnginfo = PngImagePlugin.PngInfo()
|
||||||
|
|
||||||
if metadata is not None:
|
if metadata is not None or workflow is not None:
|
||||||
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
|
if metadata is not None:
|
||||||
if graph is not None:
|
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
|
||||||
pnginfo.add_text("invokeai_graph", json.dumps(graph))
|
if workflow is not None:
|
||||||
|
pnginfo.add_text("invokeai_workflow", workflow)
|
||||||
|
else:
|
||||||
|
# For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back
|
||||||
|
# TODO: retain non-invokeai metadata on save...
|
||||||
|
original_metadata = image.info.get("invokeai_metadata", None)
|
||||||
|
if original_metadata is not None:
|
||||||
|
pnginfo.add_text("invokeai_metadata", original_metadata)
|
||||||
|
original_workflow = image.info.get("invokeai_workflow", None)
|
||||||
|
if original_workflow is not None:
|
||||||
|
pnginfo.add_text("invokeai_workflow", original_workflow)
|
||||||
|
|
||||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||||
|
|
||||||
thumbnail_name = get_thumbnail_name(image_name)
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
||||||
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||||
|
@ -54,6 +54,7 @@ class ImageServiceABC(ABC):
|
|||||||
board_id: Optional[str] = None,
|
board_id: Optional[str] = None,
|
||||||
is_intermediate: bool = False,
|
is_intermediate: bool = False,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
|
workflow: Optional[str] = None,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Creates an image, storing the file and its metadata."""
|
"""Creates an image, storing the file and its metadata."""
|
||||||
pass
|
pass
|
||||||
@ -177,6 +178,7 @@ class ImageService(ImageServiceABC):
|
|||||||
board_id: Optional[str] = None,
|
board_id: Optional[str] = None,
|
||||||
is_intermediate: bool = False,
|
is_intermediate: bool = False,
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
|
workflow: Optional[str] = None,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
if image_origin not in ResourceOrigin:
|
if image_origin not in ResourceOrigin:
|
||||||
raise InvalidOriginException
|
raise InvalidOriginException
|
||||||
@ -186,16 +188,16 @@ class ImageService(ImageServiceABC):
|
|||||||
|
|
||||||
image_name = self._services.names.create_image_name()
|
image_name = self._services.names.create_image_name()
|
||||||
|
|
||||||
graph = None
|
# TODO: Do we want to store the graph in the image at all? I don't think so...
|
||||||
|
# graph = None
|
||||||
if session_id is not None:
|
# if session_id is not None:
|
||||||
session_raw = self._services.graph_execution_manager.get_raw(session_id)
|
# session_raw = self._services.graph_execution_manager.get_raw(session_id)
|
||||||
if session_raw is not None:
|
# if session_raw is not None:
|
||||||
try:
|
# try:
|
||||||
graph = get_metadata_graph_from_raw_session(session_raw)
|
# graph = get_metadata_graph_from_raw_session(session_raw)
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
# self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||||
graph = None
|
# graph = None
|
||||||
|
|
||||||
(width, height) = image.size
|
(width, height) = image.size
|
||||||
|
|
||||||
@ -217,7 +219,7 @@ class ImageService(ImageServiceABC):
|
|||||||
)
|
)
|
||||||
if board_id is not None:
|
if board_id is not None:
|
||||||
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||||
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, graph=graph)
|
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
|
||||||
image_dto = self.get_dto(image_name)
|
image_dto = self.get_dto(image_name)
|
||||||
|
|
||||||
return image_dto
|
return image_dto
|
||||||
|
@ -53,7 +53,7 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
|||||||
- `starred`: change whether the image is starred
|
- `starred`: change whether the image is starred
|
||||||
"""
|
"""
|
||||||
|
|
||||||
image_category: Optional[ImageCategory] = Field(description="The image's new category.")
|
image_category: Optional[ImageCategory] = Field(default=None, description="The image's new category.")
|
||||||
"""The image's new category."""
|
"""The image's new category."""
|
||||||
session_id: Optional[StrictStr] = Field(
|
session_id: Optional[StrictStr] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
@ -7,5 +7,4 @@ stats.html
|
|||||||
index.html
|
index.html
|
||||||
.yarn/
|
.yarn/
|
||||||
*.scss
|
*.scss
|
||||||
src/services/api/
|
src/services/api/schema.d.ts
|
||||||
src/services/fixtures/*
|
|
||||||
|
@ -7,8 +7,7 @@ index.html
|
|||||||
.yarn/
|
.yarn/
|
||||||
.yalc/
|
.yalc/
|
||||||
*.scss
|
*.scss
|
||||||
src/services/api/
|
src/services/api/schema.d.ts
|
||||||
src/services/fixtures/*
|
|
||||||
docs/
|
docs/
|
||||||
static/
|
static/
|
||||||
src/theme/css/overlayscrollbars.css
|
src/theme/css/overlayscrollbars.css
|
||||||
|
@ -74,6 +74,7 @@
|
|||||||
"@nanostores/react": "^0.7.1",
|
"@nanostores/react": "^0.7.1",
|
||||||
"@reduxjs/toolkit": "^1.9.5",
|
"@reduxjs/toolkit": "^1.9.5",
|
||||||
"@roarr/browser-log-writer": "^1.1.5",
|
"@roarr/browser-log-writer": "^1.1.5",
|
||||||
|
"@stevebel/png": "^1.5.1",
|
||||||
"dateformat": "^5.0.3",
|
"dateformat": "^5.0.3",
|
||||||
"formik": "^2.4.3",
|
"formik": "^2.4.3",
|
||||||
"framer-motion": "^10.16.1",
|
"framer-motion": "^10.16.1",
|
||||||
@ -110,6 +111,7 @@
|
|||||||
"roarr": "^7.15.1",
|
"roarr": "^7.15.1",
|
||||||
"serialize-error": "^11.0.1",
|
"serialize-error": "^11.0.1",
|
||||||
"socket.io-client": "^4.7.2",
|
"socket.io-client": "^4.7.2",
|
||||||
|
"type-fest": "^4.2.0",
|
||||||
"use-debounce": "^9.0.4",
|
"use-debounce": "^9.0.4",
|
||||||
"use-image": "^1.1.1",
|
"use-image": "^1.1.1",
|
||||||
"uuid": "^9.0.0",
|
"uuid": "^9.0.0",
|
||||||
|
@ -719,7 +719,7 @@
|
|||||||
},
|
},
|
||||||
"nodes": {
|
"nodes": {
|
||||||
"reloadNodeTemplates": "Reload Node Templates",
|
"reloadNodeTemplates": "Reload Node Templates",
|
||||||
"saveWorkflow": "Save Workflow",
|
"downloadWorkflow": "Download Workflow JSON",
|
||||||
"loadWorkflow": "Load Workflow",
|
"loadWorkflow": "Load Workflow",
|
||||||
"resetWorkflow": "Reset Workflow",
|
"resetWorkflow": "Reset Workflow",
|
||||||
"resetWorkflowDesc": "Are you sure you want to reset this workflow?",
|
"resetWorkflowDesc": "Are you sure you want to reset this workflow?",
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
import { Box } from '@chakra-ui/react';
|
import { Box } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
import { AnimatePresence, motion } from 'framer-motion';
|
||||||
import {
|
import {
|
||||||
KeyboardEvent,
|
KeyboardEvent,
|
||||||
ReactNode,
|
ReactNode,
|
||||||
@ -18,8 +20,6 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import { useUploadImageMutation } from 'services/api/endpoints/images';
|
import { useUploadImageMutation } from 'services/api/endpoints/images';
|
||||||
import { PostUploadAction } from 'services/api/types';
|
import { PostUploadAction } from 'services/api/types';
|
||||||
import ImageUploadOverlay from './ImageUploadOverlay';
|
import ImageUploadOverlay from './ImageUploadOverlay';
|
||||||
import { AnimatePresence, motion } from 'framer-motion';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector, activeTabNameSelector],
|
[stateSelector, activeTabNameSelector],
|
||||||
|
@ -0,0 +1,56 @@
|
|||||||
|
import { Box } from '@chakra-ui/react';
|
||||||
|
import { memo, useMemo } from 'react';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
isSelected: boolean;
|
||||||
|
isHovered: boolean;
|
||||||
|
};
|
||||||
|
const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
|
||||||
|
const shadow = useMemo(() => {
|
||||||
|
if (isSelected && isHovered) {
|
||||||
|
return 'nodeHoveredSelected.light';
|
||||||
|
}
|
||||||
|
if (isSelected) {
|
||||||
|
return 'nodeSelected.light';
|
||||||
|
}
|
||||||
|
if (isHovered) {
|
||||||
|
return 'nodeHovered.light';
|
||||||
|
}
|
||||||
|
return undefined;
|
||||||
|
}, [isHovered, isSelected]);
|
||||||
|
const shadowDark = useMemo(() => {
|
||||||
|
if (isSelected && isHovered) {
|
||||||
|
return 'nodeHoveredSelected.dark';
|
||||||
|
}
|
||||||
|
if (isSelected) {
|
||||||
|
return 'nodeSelected.dark';
|
||||||
|
}
|
||||||
|
if (isHovered) {
|
||||||
|
return 'nodeHovered.dark';
|
||||||
|
}
|
||||||
|
return undefined;
|
||||||
|
}, [isHovered, isSelected]);
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
className="selection-box"
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
insetInlineEnd: 0,
|
||||||
|
bottom: 0,
|
||||||
|
insetInlineStart: 0,
|
||||||
|
borderRadius: 'base',
|
||||||
|
opacity: isSelected || isHovered ? 1 : 0.5,
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: '0.1s',
|
||||||
|
pointerEvents: 'none',
|
||||||
|
shadow,
|
||||||
|
_dark: {
|
||||||
|
shadow: shadowDark,
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(SelectionOverlay);
|
@ -15,6 +15,7 @@ import { BoardDTO } from 'services/api/types';
|
|||||||
import { menuListMotionProps } from 'theme/components/menu';
|
import { menuListMotionProps } from 'theme/components/menu';
|
||||||
import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems';
|
import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems';
|
||||||
import NoBoardContextMenuItems from './NoBoardContextMenuItems';
|
import NoBoardContextMenuItems from './NoBoardContextMenuItems';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
board?: BoardDTO;
|
board?: BoardDTO;
|
||||||
@ -33,12 +34,16 @@ const BoardContextMenu = ({
|
|||||||
|
|
||||||
const selector = useMemo(
|
const selector = useMemo(
|
||||||
() =>
|
() =>
|
||||||
createSelector(stateSelector, ({ gallery, system }) => {
|
createSelector(
|
||||||
const isAutoAdd = gallery.autoAddBoardId === board_id;
|
stateSelector,
|
||||||
const isProcessing = system.isProcessing;
|
({ gallery, system }) => {
|
||||||
const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick;
|
const isAutoAdd = gallery.autoAddBoardId === board_id;
|
||||||
return { isAutoAdd, isProcessing, autoAssignBoardOnClick };
|
const isProcessing = system.isProcessing;
|
||||||
}),
|
const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick;
|
||||||
|
return { isAutoAdd, isProcessing, autoAssignBoardOnClick };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
[board_id]
|
[board_id]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -9,20 +9,24 @@ import {
|
|||||||
MenuButton,
|
MenuButton,
|
||||||
MenuList,
|
MenuList,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
|
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
|
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
|
||||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||||
|
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||||
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
|
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
|
||||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import {
|
import {
|
||||||
|
setActiveTab,
|
||||||
setShouldShowImageDetails,
|
setShouldShowImageDetails,
|
||||||
setShouldShowProgressInViewer,
|
setShouldShowProgressInViewer,
|
||||||
} from 'features/ui/store/uiSlice';
|
} from 'features/ui/store/uiSlice';
|
||||||
@ -37,12 +41,12 @@ import {
|
|||||||
FaSeedling,
|
FaSeedling,
|
||||||
FaShareAlt,
|
FaShareAlt,
|
||||||
} from 'react-icons/fa';
|
} from 'react-icons/fa';
|
||||||
|
import { MdDeviceHub } from 'react-icons/md';
|
||||||
import {
|
import {
|
||||||
useGetImageDTOQuery,
|
useGetImageDTOQuery,
|
||||||
useGetImageMetadataQuery,
|
useGetImageMetadataFromFileQuery,
|
||||||
} from 'services/api/endpoints/images';
|
} from 'services/api/endpoints/images';
|
||||||
import { menuListMotionProps } from 'theme/components/menu';
|
import { menuListMotionProps } from 'theme/components/menu';
|
||||||
import { useDebounce } from 'use-debounce';
|
|
||||||
import { sentImageToImg2Img } from '../../store/actions';
|
import { sentImageToImg2Img } from '../../store/actions';
|
||||||
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
|
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
|
||||||
|
|
||||||
@ -101,22 +105,36 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
const { recallBothPrompts, recallSeed, recallAllParameters } =
|
const { recallBothPrompts, recallSeed, recallAllParameters } =
|
||||||
useRecallParameters();
|
useRecallParameters();
|
||||||
|
|
||||||
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
|
|
||||||
lastSelectedImage,
|
|
||||||
500
|
|
||||||
);
|
|
||||||
|
|
||||||
const { currentData: imageDTO } = useGetImageDTOQuery(
|
const { currentData: imageDTO } = useGetImageDTOQuery(
|
||||||
lastSelectedImage?.image_name ?? skipToken
|
lastSelectedImage?.image_name ?? skipToken
|
||||||
);
|
);
|
||||||
|
|
||||||
const { currentData: metadataData } = useGetImageMetadataQuery(
|
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||||
debounceState.isPending()
|
lastSelectedImage?.image_name ?? skipToken,
|
||||||
? skipToken
|
{
|
||||||
: debouncedMetadataQueryArg?.image_name ?? skipToken
|
selectFromResult: (res) => ({
|
||||||
|
isLoading: res.isFetching,
|
||||||
|
metadata: res?.currentData?.metadata,
|
||||||
|
workflow: res?.currentData?.workflow,
|
||||||
|
}),
|
||||||
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
const metadata = metadataData?.metadata;
|
const handleLoadWorkflow = useCallback(() => {
|
||||||
|
if (!workflow) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(workflowLoaded(workflow));
|
||||||
|
dispatch(setActiveTab('nodes'));
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: 'Workflow Loaded',
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}, [dispatch, workflow]);
|
||||||
|
|
||||||
const handleClickUseAllParameters = useCallback(() => {
|
const handleClickUseAllParameters = useCallback(() => {
|
||||||
recallAllParameters(metadata);
|
recallAllParameters(metadata);
|
||||||
@ -153,6 +171,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
|
|
||||||
useHotkeys('p', handleUsePrompt, [imageDTO]);
|
useHotkeys('p', handleUsePrompt, [imageDTO]);
|
||||||
|
|
||||||
|
useHotkeys('w', handleLoadWorkflow, [workflow]);
|
||||||
|
|
||||||
const handleSendToImageToImage = useCallback(() => {
|
const handleSendToImageToImage = useCallback(() => {
|
||||||
dispatch(sentImageToImg2Img());
|
dispatch(sentImageToImg2Img());
|
||||||
dispatch(initialImageSelected(imageDTO));
|
dispatch(initialImageSelected(imageDTO));
|
||||||
@ -259,22 +279,31 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
|
|
||||||
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
|
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
|
isLoading={isLoading}
|
||||||
|
icon={<MdDeviceHub />}
|
||||||
|
tooltip={`${t('nodes.loadWorkflow')} (W)`}
|
||||||
|
aria-label={`${t('nodes.loadWorkflow')} (W)`}
|
||||||
|
isDisabled={!workflow}
|
||||||
|
onClick={handleLoadWorkflow}
|
||||||
|
/>
|
||||||
|
<IAIIconButton
|
||||||
|
isLoading={isLoading}
|
||||||
icon={<FaQuoteRight />}
|
icon={<FaQuoteRight />}
|
||||||
tooltip={`${t('parameters.usePrompt')} (P)`}
|
tooltip={`${t('parameters.usePrompt')} (P)`}
|
||||||
aria-label={`${t('parameters.usePrompt')} (P)`}
|
aria-label={`${t('parameters.usePrompt')} (P)`}
|
||||||
isDisabled={!metadata?.positive_prompt}
|
isDisabled={!metadata?.positive_prompt}
|
||||||
onClick={handleUsePrompt}
|
onClick={handleUsePrompt}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
|
isLoading={isLoading}
|
||||||
icon={<FaSeedling />}
|
icon={<FaSeedling />}
|
||||||
tooltip={`${t('parameters.useSeed')} (S)`}
|
tooltip={`${t('parameters.useSeed')} (S)`}
|
||||||
aria-label={`${t('parameters.useSeed')} (S)`}
|
aria-label={`${t('parameters.useSeed')} (S)`}
|
||||||
isDisabled={!metadata?.seed}
|
isDisabled={!metadata?.seed}
|
||||||
onClick={handleUseSeed}
|
onClick={handleUseSeed}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
|
isLoading={isLoading}
|
||||||
icon={<FaAsterisk />}
|
icon={<FaAsterisk />}
|
||||||
tooltip={`${t('parameters.useAll')} (A)`}
|
tooltip={`${t('parameters.useAll')} (A)`}
|
||||||
aria-label={`${t('parameters.useAll')} (A)`}
|
aria-label={`${t('parameters.useAll')} (A)`}
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import { Flex, MenuItem, Text } from '@chakra-ui/react';
|
import { Flex, MenuItem, Spinner } from '@chakra-ui/react';
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
|
||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||||
@ -8,9 +7,12 @@ import {
|
|||||||
isModalOpenChanged,
|
isModalOpenChanged,
|
||||||
} from 'features/changeBoardModal/store/slice';
|
} from 'features/changeBoardModal/store/slice';
|
||||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||||
|
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
||||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
@ -26,14 +28,13 @@ import {
|
|||||||
FaShare,
|
FaShare,
|
||||||
FaTrash,
|
FaTrash,
|
||||||
} from 'react-icons/fa';
|
} from 'react-icons/fa';
|
||||||
import { MdStar, MdStarBorder } from 'react-icons/md';
|
import { MdDeviceHub, MdStar, MdStarBorder } from 'react-icons/md';
|
||||||
import {
|
import {
|
||||||
useGetImageMetadataQuery,
|
useGetImageMetadataFromFileQuery,
|
||||||
useStarImagesMutation,
|
useStarImagesMutation,
|
||||||
useUnstarImagesMutation,
|
useUnstarImagesMutation,
|
||||||
} from 'services/api/endpoints/images';
|
} from 'services/api/endpoints/images';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import { useDebounce } from 'use-debounce';
|
|
||||||
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
||||||
|
|
||||||
type SingleSelectionMenuItemsProps = {
|
type SingleSelectionMenuItemsProps = {
|
||||||
@ -50,15 +51,15 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
|
|
||||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||||
|
|
||||||
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
|
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||||
imageDTO.image_name,
|
imageDTO.image_name,
|
||||||
500
|
{
|
||||||
);
|
selectFromResult: (res) => ({
|
||||||
|
isLoading: res.isFetching,
|
||||||
const { currentData } = useGetImageMetadataQuery(
|
metadata: res?.currentData?.metadata,
|
||||||
debounceState.isPending()
|
workflow: res?.currentData?.workflow,
|
||||||
? skipToken
|
}),
|
||||||
: debouncedMetadataQueryArg ?? skipToken
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
const [starImages] = useStarImagesMutation();
|
const [starImages] = useStarImagesMutation();
|
||||||
@ -67,8 +68,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
const { isClipboardAPIAvailable, copyImageToClipboard } =
|
const { isClipboardAPIAvailable, copyImageToClipboard } =
|
||||||
useCopyImageToClipboard();
|
useCopyImageToClipboard();
|
||||||
|
|
||||||
const metadata = currentData?.metadata;
|
|
||||||
|
|
||||||
const handleDelete = useCallback(() => {
|
const handleDelete = useCallback(() => {
|
||||||
if (!imageDTO) {
|
if (!imageDTO) {
|
||||||
return;
|
return;
|
||||||
@ -99,6 +98,22 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
recallSeed(metadata?.seed);
|
recallSeed(metadata?.seed);
|
||||||
}, [metadata?.seed, recallSeed]);
|
}, [metadata?.seed, recallSeed]);
|
||||||
|
|
||||||
|
const handleLoadWorkflow = useCallback(() => {
|
||||||
|
if (!workflow) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(workflowLoaded(workflow));
|
||||||
|
dispatch(setActiveTab('nodes'));
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: 'Workflow Loaded',
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}, [dispatch, workflow]);
|
||||||
|
|
||||||
const handleSendToImageToImage = useCallback(() => {
|
const handleSendToImageToImage = useCallback(() => {
|
||||||
dispatch(sentImageToImg2Img());
|
dispatch(sentImageToImg2Img());
|
||||||
dispatch(initialImageSelected(imageDTO));
|
dispatch(initialImageSelected(imageDTO));
|
||||||
@ -118,7 +133,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
}, [dispatch, imageDTO, t, toaster]);
|
}, [dispatch, imageDTO, t, toaster]);
|
||||||
|
|
||||||
const handleUseAllParameters = useCallback(() => {
|
const handleUseAllParameters = useCallback(() => {
|
||||||
console.log(metadata);
|
|
||||||
recallAllParameters(metadata);
|
recallAllParameters(metadata);
|
||||||
}, [metadata, recallAllParameters]);
|
}, [metadata, recallAllParameters]);
|
||||||
|
|
||||||
@ -169,27 +183,34 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
{t('parameters.downloadImage')}
|
{t('parameters.downloadImage')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={<FaQuoteRight />}
|
icon={isLoading ? <SpinnerIcon /> : <MdDeviceHub />}
|
||||||
|
onClickCapture={handleLoadWorkflow}
|
||||||
|
isDisabled={isLoading || !workflow}
|
||||||
|
>
|
||||||
|
{t('nodes.loadWorkflow')}
|
||||||
|
</MenuItem>
|
||||||
|
<MenuItem
|
||||||
|
icon={isLoading ? <SpinnerIcon /> : <FaQuoteRight />}
|
||||||
onClickCapture={handleRecallPrompt}
|
onClickCapture={handleRecallPrompt}
|
||||||
isDisabled={
|
isDisabled={
|
||||||
metadata?.positive_prompt === undefined &&
|
isLoading ||
|
||||||
metadata?.negative_prompt === undefined
|
(metadata?.positive_prompt === undefined &&
|
||||||
|
metadata?.negative_prompt === undefined)
|
||||||
}
|
}
|
||||||
>
|
>
|
||||||
{t('parameters.usePrompt')}
|
{t('parameters.usePrompt')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
|
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={<FaSeedling />}
|
icon={isLoading ? <SpinnerIcon /> : <FaSeedling />}
|
||||||
onClickCapture={handleRecallSeed}
|
onClickCapture={handleRecallSeed}
|
||||||
isDisabled={metadata?.seed === undefined}
|
isDisabled={isLoading || metadata?.seed === undefined}
|
||||||
>
|
>
|
||||||
{t('parameters.useSeed')}
|
{t('parameters.useSeed')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={<FaAsterisk />}
|
icon={isLoading ? <SpinnerIcon /> : <FaAsterisk />}
|
||||||
onClickCapture={handleUseAllParameters}
|
onClickCapture={handleUseAllParameters}
|
||||||
isDisabled={!metadata}
|
isDisabled={isLoading || !metadata}
|
||||||
>
|
>
|
||||||
{t('parameters.useAll')}
|
{t('parameters.useAll')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
@ -228,20 +249,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
>
|
>
|
||||||
{t('gallery.deleteImage')}
|
{t('gallery.deleteImage')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
{metadata?.created_by && (
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
padding: '5px 10px',
|
|
||||||
marginTop: '5px',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Text fontSize="xs" fontWeight="bold">
|
|
||||||
Created by {metadata?.created_by}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(SingleSelectionMenuItems);
|
export default memo(SingleSelectionMenuItems);
|
||||||
|
|
||||||
|
const SpinnerIcon = () => (
|
||||||
|
<Flex w="14px" alignItems="center" justifyContent="center">
|
||||||
|
<Spinner size="xs" />
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
@ -2,7 +2,7 @@ import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
|
|||||||
import { isString } from 'lodash-es';
|
import { isString } from 'lodash-es';
|
||||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { FaCopy, FaSave } from 'react-icons/fa';
|
import { FaCopy, FaDownload } from 'react-icons/fa';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
label: string;
|
label: string;
|
||||||
@ -23,7 +23,7 @@ const DataViewer = (props: Props) => {
|
|||||||
navigator.clipboard.writeText(dataString);
|
navigator.clipboard.writeText(dataString);
|
||||||
}, [dataString]);
|
}, [dataString]);
|
||||||
|
|
||||||
const handleSave = useCallback(() => {
|
const handleDownload = useCallback(() => {
|
||||||
const blob = new Blob([dataString]);
|
const blob = new Blob([dataString]);
|
||||||
const a = document.createElement('a');
|
const a = document.createElement('a');
|
||||||
a.href = URL.createObjectURL(blob);
|
a.href = URL.createObjectURL(blob);
|
||||||
@ -73,13 +73,13 @@ const DataViewer = (props: Props) => {
|
|||||||
</Box>
|
</Box>
|
||||||
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
|
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
|
||||||
{withDownload && (
|
{withDownload && (
|
||||||
<Tooltip label={`Save ${label} JSON`}>
|
<Tooltip label={`Download ${label} JSON`}>
|
||||||
<IconButton
|
<IconButton
|
||||||
aria-label={`Save ${label} JSON`}
|
aria-label={`Download ${label} JSON`}
|
||||||
icon={<FaSave />}
|
icon={<FaDownload />}
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
opacity={0.7}
|
opacity={0.7}
|
||||||
onClick={handleSave}
|
onClick={handleDownload}
|
||||||
/>
|
/>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
)}
|
)}
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
|
import { CoreMetadata } from 'features/nodes/types/types';
|
||||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { UnsafeImageMetadata } from 'services/api/types';
|
|
||||||
import ImageMetadataItem from './ImageMetadataItem';
|
import ImageMetadataItem from './ImageMetadataItem';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
metadata?: UnsafeImageMetadata['metadata'];
|
metadata?: CoreMetadata;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ImageMetadataActions = (props: Props) => {
|
const ImageMetadataActions = (props: Props) => {
|
||||||
@ -94,14 +94,14 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
onClick={handleRecallNegativePrompt}
|
onClick={handleRecallNegativePrompt}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{metadata.seed !== undefined && (
|
{metadata.seed !== undefined && metadata.seed !== null && (
|
||||||
<ImageMetadataItem
|
<ImageMetadataItem
|
||||||
label="Seed"
|
label="Seed"
|
||||||
value={metadata.seed}
|
value={metadata.seed}
|
||||||
onClick={handleRecallSeed}
|
onClick={handleRecallSeed}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{metadata.model !== undefined && (
|
{metadata.model !== undefined && metadata.model !== null && (
|
||||||
<ImageMetadataItem
|
<ImageMetadataItem
|
||||||
label="Model"
|
label="Model"
|
||||||
value={metadata.model.model_name}
|
value={metadata.model.model_name}
|
||||||
@ -150,7 +150,7 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
onClick={handleRecallSteps}
|
onClick={handleRecallSteps}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{metadata.cfg_scale !== undefined && (
|
{metadata.cfg_scale !== undefined && metadata.cfg_scale !== null && (
|
||||||
<ImageMetadataItem
|
<ImageMetadataItem
|
||||||
label="CFG scale"
|
label="CFG scale"
|
||||||
value={metadata.cfg_scale}
|
value={metadata.cfg_scale}
|
||||||
|
@ -9,14 +9,12 @@ import {
|
|||||||
Tabs,
|
Tabs,
|
||||||
Text,
|
Text,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
|
||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
|
import { useGetImageMetadataFromFileQuery } from 'services/api/endpoints/images';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import { useDebounce } from 'use-debounce';
|
|
||||||
import ImageMetadataActions from './ImageMetadataActions';
|
|
||||||
import DataViewer from './DataViewer';
|
import DataViewer from './DataViewer';
|
||||||
|
import ImageMetadataActions from './ImageMetadataActions';
|
||||||
|
|
||||||
type ImageMetadataViewerProps = {
|
type ImageMetadataViewerProps = {
|
||||||
image: ImageDTO;
|
image: ImageDTO;
|
||||||
@ -29,19 +27,16 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
|||||||
// dispatch(setShouldShowImageDetails(false));
|
// dispatch(setShouldShowImageDetails(false));
|
||||||
// });
|
// });
|
||||||
|
|
||||||
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
|
const { metadata, workflow } = useGetImageMetadataFromFileQuery(
|
||||||
image.image_name,
|
image.image_name,
|
||||||
500
|
{
|
||||||
|
selectFromResult: (res) => ({
|
||||||
|
metadata: res?.currentData?.metadata,
|
||||||
|
workflow: res?.currentData?.workflow,
|
||||||
|
}),
|
||||||
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
const { currentData } = useGetImageMetadataQuery(
|
|
||||||
debounceState.isPending()
|
|
||||||
? skipToken
|
|
||||||
: debouncedMetadataQueryArg ?? skipToken
|
|
||||||
);
|
|
||||||
const metadata = currentData?.metadata;
|
|
||||||
const graph = currentData?.graph;
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
layerStyle="first"
|
layerStyle="first"
|
||||||
@ -71,17 +66,17 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
|||||||
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
|
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
|
||||||
>
|
>
|
||||||
<TabList>
|
<TabList>
|
||||||
<Tab>Core Metadata</Tab>
|
<Tab>Metadata</Tab>
|
||||||
<Tab>Image Details</Tab>
|
<Tab>Image Details</Tab>
|
||||||
<Tab>Graph</Tab>
|
<Tab>Workflow</Tab>
|
||||||
</TabList>
|
</TabList>
|
||||||
|
|
||||||
<TabPanels>
|
<TabPanels>
|
||||||
<TabPanel>
|
<TabPanel>
|
||||||
{metadata ? (
|
{metadata ? (
|
||||||
<DataViewer data={metadata} label="Core Metadata" />
|
<DataViewer data={metadata} label="Metadata" />
|
||||||
) : (
|
) : (
|
||||||
<IAINoContentFallback label="No core metadata found" />
|
<IAINoContentFallback label="No metadata found" />
|
||||||
)}
|
)}
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
<TabPanel>
|
<TabPanel>
|
||||||
@ -92,10 +87,10 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
|||||||
)}
|
)}
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
<TabPanel>
|
<TabPanel>
|
||||||
{graph ? (
|
{workflow ? (
|
||||||
<DataViewer data={graph} label="Graph" />
|
<DataViewer data={workflow} label="Workflow" />
|
||||||
) : (
|
) : (
|
||||||
<IAINoContentFallback label="No graph found" />
|
<IAINoContentFallback label="No workflow found" />
|
||||||
)}
|
)}
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
</TabPanels>
|
</TabPanels>
|
||||||
|
@ -0,0 +1,41 @@
|
|||||||
|
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { useEmbedWorkflow } from 'features/nodes/hooks/useEmbedWorkflow';
|
||||||
|
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
|
||||||
|
import { nodeEmbedWorkflowChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
|
|
||||||
|
const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const hasImageOutput = useHasImageOutput(nodeId);
|
||||||
|
const embedWorkflow = useEmbedWorkflow(nodeId);
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
dispatch(
|
||||||
|
nodeEmbedWorkflowChanged({
|
||||||
|
nodeId,
|
||||||
|
embedWorkflow: e.target.checked,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!hasImageOutput) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
|
||||||
|
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Embed Workflow</FormLabel>
|
||||||
|
<Checkbox
|
||||||
|
className="nopan"
|
||||||
|
size="sm"
|
||||||
|
onChange={handleChange}
|
||||||
|
isChecked={embedWorkflow}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(EmbedWorkflowCheckbox);
|
@ -41,7 +41,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
|||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
w: 'full',
|
w: 'full',
|
||||||
h: 'full',
|
h: 'full',
|
||||||
py: 1,
|
py: 2,
|
||||||
gap: 1,
|
gap: 1,
|
||||||
borderBottomRadius: withFooter ? 0 : 'base',
|
borderBottomRadius: withFooter ? 0 : 'base',
|
||||||
}}
|
}}
|
||||||
|
@ -1,16 +1,8 @@
|
|||||||
import {
|
import { Flex } from '@chakra-ui/react';
|
||||||
Checkbox,
|
|
||||||
Flex,
|
|
||||||
FormControl,
|
|
||||||
FormLabel,
|
|
||||||
Spacer,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
|
|
||||||
import { useIsIntermediate } from 'features/nodes/hooks/useIsIntermediate';
|
|
||||||
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
|
|
||||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||||
import { ChangeEvent, memo, useCallback } from 'react';
|
import { memo } from 'react';
|
||||||
|
import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox';
|
||||||
|
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -27,48 +19,13 @@ const InvocationNodeFooter = ({ nodeId }: Props) => {
|
|||||||
px: 2,
|
px: 2,
|
||||||
py: 0,
|
py: 0,
|
||||||
h: 6,
|
h: 6,
|
||||||
|
justifyContent: 'space-between',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Spacer />
|
<EmbedWorkflowCheckbox nodeId={nodeId} />
|
||||||
<SaveImageCheckbox nodeId={nodeId} />
|
<SaveToGalleryCheckbox nodeId={nodeId} />
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(InvocationNodeFooter);
|
export default memo(InvocationNodeFooter);
|
||||||
|
|
||||||
const SaveImageCheckbox = memo(({ nodeId }: { nodeId: string }) => {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const hasImageOutput = useHasImageOutput(nodeId);
|
|
||||||
const is_intermediate = useIsIntermediate(nodeId);
|
|
||||||
const handleChangeIsIntermediate = useCallback(
|
|
||||||
(e: ChangeEvent<HTMLInputElement>) => {
|
|
||||||
dispatch(
|
|
||||||
fieldBooleanValueChanged({
|
|
||||||
nodeId,
|
|
||||||
fieldName: 'is_intermediate',
|
|
||||||
value: !e.target.checked,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
},
|
|
||||||
[dispatch, nodeId]
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!hasImageOutput) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
|
|
||||||
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
|
|
||||||
<Checkbox
|
|
||||||
className="nopan"
|
|
||||||
size="sm"
|
|
||||||
onChange={handleChangeIsIntermediate}
|
|
||||||
isChecked={!is_intermediate}
|
|
||||||
/>
|
|
||||||
</FormControl>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
SaveImageCheckbox.displayName = 'SaveImageCheckbox';
|
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
Flex,
|
Flex,
|
||||||
FormControl,
|
|
||||||
FormLabel,
|
|
||||||
Icon,
|
Icon,
|
||||||
Modal,
|
Modal,
|
||||||
ModalBody,
|
ModalBody,
|
||||||
@ -14,16 +12,14 @@ import {
|
|||||||
Tooltip,
|
Tooltip,
|
||||||
useDisclosure,
|
useDisclosure,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import IAITextarea from 'common/components/IAITextarea';
|
|
||||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||||
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
|
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
|
||||||
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
|
|
||||||
import { isInvocationNodeData } from 'features/nodes/types/types';
|
import { isInvocationNodeData } from 'features/nodes/types/types';
|
||||||
import { ChangeEvent, memo, useCallback } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import { FaInfoCircle } from 'react-icons/fa';
|
import { FaInfoCircle } from 'react-icons/fa';
|
||||||
|
import NotesTextarea from './NotesTextarea';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -80,13 +76,29 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
|||||||
const data = useNodeData(nodeId);
|
const data = useNodeData(nodeId);
|
||||||
const nodeTemplate = useNodeTemplate(nodeId);
|
const nodeTemplate = useNodeTemplate(nodeId);
|
||||||
|
|
||||||
|
const title = useMemo(() => {
|
||||||
|
if (data?.label && nodeTemplate?.title) {
|
||||||
|
return `${data.label} (${nodeTemplate.title})`;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data?.label && !nodeTemplate) {
|
||||||
|
return data.label;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!data?.label && nodeTemplate) {
|
||||||
|
return nodeTemplate.title;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 'Unknown Node';
|
||||||
|
}, [data, nodeTemplate]);
|
||||||
|
|
||||||
if (!isInvocationNodeData(data)) {
|
if (!isInvocationNodeData(data)) {
|
||||||
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
|
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex sx={{ flexDir: 'column' }}>
|
<Flex sx={{ flexDir: 'column' }}>
|
||||||
<Text sx={{ fontWeight: 600 }}>{nodeTemplate?.title}</Text>
|
<Text sx={{ fontWeight: 600 }}>{title}</Text>
|
||||||
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
|
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
|
||||||
{nodeTemplate?.description}
|
{nodeTemplate?.description}
|
||||||
</Text>
|
</Text>
|
||||||
@ -96,29 +108,3 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
TooltipContent.displayName = 'TooltipContent';
|
TooltipContent.displayName = 'TooltipContent';
|
||||||
|
|
||||||
const NotesTextarea = memo(({ nodeId }: { nodeId: string }) => {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const data = useNodeData(nodeId);
|
|
||||||
const handleNotesChanged = useCallback(
|
|
||||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
|
||||||
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
|
|
||||||
},
|
|
||||||
[dispatch, nodeId]
|
|
||||||
);
|
|
||||||
if (!isInvocationNodeData(data)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return (
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>Notes</FormLabel>
|
|
||||||
<IAITextarea
|
|
||||||
value={data?.notes}
|
|
||||||
onChange={handleNotesChanged}
|
|
||||||
rows={10}
|
|
||||||
/>
|
|
||||||
</FormControl>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
NotesTextarea.displayName = 'NodesTextarea';
|
|
||||||
|
@ -0,0 +1,33 @@
|
|||||||
|
import { FormControl, FormLabel } from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAITextarea from 'common/components/IAITextarea';
|
||||||
|
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||||
|
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { isInvocationNodeData } from 'features/nodes/types/types';
|
||||||
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
|
|
||||||
|
const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const data = useNodeData(nodeId);
|
||||||
|
const handleNotesChanged = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
|
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
|
||||||
|
},
|
||||||
|
[dispatch, nodeId]
|
||||||
|
);
|
||||||
|
if (!isInvocationNodeData(data)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<FormControl>
|
||||||
|
<FormLabel>Notes</FormLabel>
|
||||||
|
<IAITextarea
|
||||||
|
value={data?.notes}
|
||||||
|
onChange={handleNotesChanged}
|
||||||
|
rows={10}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(NotesTextarea);
|
@ -0,0 +1,41 @@
|
|||||||
|
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
|
||||||
|
import { useIsIntermediate } from 'features/nodes/hooks/useIsIntermediate';
|
||||||
|
import { nodeIsIntermediateChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
|
|
||||||
|
const SaveToGalleryCheckbox = ({ nodeId }: { nodeId: string }) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const hasImageOutput = useHasImageOutput(nodeId);
|
||||||
|
const isIntermediate = useIsIntermediate(nodeId);
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
dispatch(
|
||||||
|
nodeIsIntermediateChanged({
|
||||||
|
nodeId,
|
||||||
|
isIntermediate: !e.target.checked,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!hasImageOutput) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
|
||||||
|
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save to Gallery</FormLabel>
|
||||||
|
<Checkbox
|
||||||
|
className="nopan"
|
||||||
|
size="sm"
|
||||||
|
onChange={handleChange}
|
||||||
|
isChecked={!isIntermediate}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(SaveToGalleryCheckbox);
|
@ -0,0 +1,167 @@
|
|||||||
|
import {
|
||||||
|
Editable,
|
||||||
|
EditableInput,
|
||||||
|
EditablePreview,
|
||||||
|
Flex,
|
||||||
|
Tooltip,
|
||||||
|
forwardRef,
|
||||||
|
useEditableControls,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
|
||||||
|
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
|
||||||
|
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
|
||||||
|
import FieldTooltipContent from './FieldTooltipContent';
|
||||||
|
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
nodeId: string;
|
||||||
|
fieldName: string;
|
||||||
|
kind: 'input' | 'output';
|
||||||
|
isMissingInput?: boolean;
|
||||||
|
withTooltip?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
const EditableFieldTitle = forwardRef((props: Props, ref) => {
|
||||||
|
const {
|
||||||
|
nodeId,
|
||||||
|
fieldName,
|
||||||
|
kind,
|
||||||
|
isMissingInput = false,
|
||||||
|
withTooltip = false,
|
||||||
|
} = props;
|
||||||
|
const label = useFieldLabel(nodeId, fieldName);
|
||||||
|
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind);
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const [localTitle, setLocalTitle] = useState(
|
||||||
|
label || fieldTemplateTitle || 'Unknown Field'
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleSubmit = useCallback(
|
||||||
|
async (newTitle: string) => {
|
||||||
|
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setLocalTitle(newTitle || fieldTemplateTitle || 'Unknown Field');
|
||||||
|
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
|
||||||
|
},
|
||||||
|
[label, fieldTemplateTitle, dispatch, nodeId, fieldName]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleChange = useCallback((newTitle: string) => {
|
||||||
|
setLocalTitle(newTitle);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
// Another component may change the title; sync local title with global state
|
||||||
|
setLocalTitle(label || fieldTemplateTitle || 'Unknown Field');
|
||||||
|
}, [label, fieldTemplateTitle]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Tooltip
|
||||||
|
label={
|
||||||
|
withTooltip ? (
|
||||||
|
<FieldTooltipContent
|
||||||
|
nodeId={nodeId}
|
||||||
|
fieldName={fieldName}
|
||||||
|
kind="input"
|
||||||
|
/>
|
||||||
|
) : undefined
|
||||||
|
}
|
||||||
|
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||||
|
placement="top"
|
||||||
|
hasArrow
|
||||||
|
>
|
||||||
|
<Flex
|
||||||
|
ref={ref}
|
||||||
|
sx={{
|
||||||
|
position: 'relative',
|
||||||
|
overflow: 'hidden',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'flex-start',
|
||||||
|
gap: 1,
|
||||||
|
h: 'full',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Editable
|
||||||
|
value={localTitle}
|
||||||
|
onChange={handleChange}
|
||||||
|
onSubmit={handleSubmit}
|
||||||
|
as={Flex}
|
||||||
|
sx={{
|
||||||
|
position: 'relative',
|
||||||
|
alignItems: 'center',
|
||||||
|
h: 'full',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<EditablePreview
|
||||||
|
sx={{
|
||||||
|
p: 0,
|
||||||
|
fontWeight: isMissingInput ? 600 : 400,
|
||||||
|
textAlign: 'left',
|
||||||
|
_hover: {
|
||||||
|
fontWeight: '600 !important',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
noOfLines={1}
|
||||||
|
/>
|
||||||
|
<EditableInput
|
||||||
|
className="nodrag"
|
||||||
|
sx={{
|
||||||
|
p: 0,
|
||||||
|
w: 'full',
|
||||||
|
fontWeight: 600,
|
||||||
|
color: 'base.900',
|
||||||
|
_dark: {
|
||||||
|
color: 'base.100',
|
||||||
|
},
|
||||||
|
_focusVisible: {
|
||||||
|
p: 0,
|
||||||
|
textAlign: 'left',
|
||||||
|
boxShadow: 'none',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<EditableControls />
|
||||||
|
</Editable>
|
||||||
|
</Flex>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
export default memo(EditableFieldTitle);
|
||||||
|
|
||||||
|
const EditableControls = memo(() => {
|
||||||
|
const { isEditing, getEditButtonProps } = useEditableControls();
|
||||||
|
const handleClick = useCallback(
|
||||||
|
(e: MouseEvent<HTMLDivElement>) => {
|
||||||
|
const { onClick } = getEditButtonProps();
|
||||||
|
if (!onClick) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
onClick(e);
|
||||||
|
e.preventDefault();
|
||||||
|
},
|
||||||
|
[getEditButtonProps]
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isEditing) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
onClick={handleClick}
|
||||||
|
position="absolute"
|
||||||
|
w="full"
|
||||||
|
h="full"
|
||||||
|
top={0}
|
||||||
|
insetInlineStart={0}
|
||||||
|
cursor="text"
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
EditableControls.displayName = 'EditableControls';
|
@ -1,16 +1,7 @@
|
|||||||
import {
|
import { Flex, Text, forwardRef } from '@chakra-ui/react';
|
||||||
Editable,
|
|
||||||
EditableInput,
|
|
||||||
EditablePreview,
|
|
||||||
Flex,
|
|
||||||
forwardRef,
|
|
||||||
useEditableControls,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
|
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
|
||||||
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
|
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
|
||||||
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
|
import { memo } from 'react';
|
||||||
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
|
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -24,31 +15,6 @@ const FieldTitle = forwardRef((props: Props, ref) => {
|
|||||||
const label = useFieldLabel(nodeId, fieldName);
|
const label = useFieldLabel(nodeId, fieldName);
|
||||||
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind);
|
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind);
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const [localTitle, setLocalTitle] = useState(
|
|
||||||
label || fieldTemplateTitle || 'Unknown Field'
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleSubmit = useCallback(
|
|
||||||
async (newTitle: string) => {
|
|
||||||
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
setLocalTitle(newTitle || fieldTemplateTitle || 'Unknown Field');
|
|
||||||
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
|
|
||||||
},
|
|
||||||
[label, fieldTemplateTitle, dispatch, nodeId, fieldName]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleChange = useCallback((newTitle: string) => {
|
|
||||||
setLocalTitle(newTitle);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
// Another component may change the title; sync local title with global state
|
|
||||||
setLocalTitle(label || fieldTemplateTitle || 'Unknown Field');
|
|
||||||
}, [label, fieldTemplateTitle]);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
ref={ref}
|
ref={ref}
|
||||||
@ -62,82 +28,11 @@ const FieldTitle = forwardRef((props: Props, ref) => {
|
|||||||
w: 'full',
|
w: 'full',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Editable
|
<Text sx={{ fontWeight: isMissingInput ? 600 : 400 }}>
|
||||||
value={localTitle}
|
{label || fieldTemplateTitle}
|
||||||
onChange={handleChange}
|
</Text>
|
||||||
onSubmit={handleSubmit}
|
|
||||||
as={Flex}
|
|
||||||
sx={{
|
|
||||||
position: 'relative',
|
|
||||||
alignItems: 'center',
|
|
||||||
h: 'full',
|
|
||||||
w: 'full',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<EditablePreview
|
|
||||||
sx={{
|
|
||||||
p: 0,
|
|
||||||
fontWeight: isMissingInput ? 600 : 400,
|
|
||||||
textAlign: 'left',
|
|
||||||
_hover: {
|
|
||||||
fontWeight: '600 !important',
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
noOfLines={1}
|
|
||||||
/>
|
|
||||||
<EditableInput
|
|
||||||
className="nodrag"
|
|
||||||
sx={{
|
|
||||||
p: 0,
|
|
||||||
fontWeight: 600,
|
|
||||||
color: 'base.900',
|
|
||||||
_dark: {
|
|
||||||
color: 'base.100',
|
|
||||||
},
|
|
||||||
_focusVisible: {
|
|
||||||
p: 0,
|
|
||||||
textAlign: 'left',
|
|
||||||
boxShadow: 'none',
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
<EditableControls />
|
|
||||||
</Editable>
|
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
export default memo(FieldTitle);
|
export default memo(FieldTitle);
|
||||||
|
|
||||||
const EditableControls = memo(() => {
|
|
||||||
const { isEditing, getEditButtonProps } = useEditableControls();
|
|
||||||
const handleClick = useCallback(
|
|
||||||
(e: MouseEvent<HTMLDivElement>) => {
|
|
||||||
const { onClick } = getEditButtonProps();
|
|
||||||
if (!onClick) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
onClick(e);
|
|
||||||
e.preventDefault();
|
|
||||||
},
|
|
||||||
[getEditButtonProps]
|
|
||||||
);
|
|
||||||
|
|
||||||
if (isEditing) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex
|
|
||||||
onClick={handleClick}
|
|
||||||
position="absolute"
|
|
||||||
w="full"
|
|
||||||
h="full"
|
|
||||||
top={0}
|
|
||||||
insetInlineStart={0}
|
|
||||||
cursor="text"
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
EditableControls.displayName = 'EditableControls';
|
|
||||||
|
@ -34,6 +34,8 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return 'Unknown Field';
|
return 'Unknown Field';
|
||||||
|
} else {
|
||||||
|
return fieldTemplate?.title || 'Unknown Field';
|
||||||
}
|
}
|
||||||
}, [field, fieldTemplate]);
|
}, [field, fieldTemplate]);
|
||||||
|
|
||||||
|
@ -1,16 +1,11 @@
|
|||||||
import { Box, Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
import { Box, Flex, FormControl, FormLabel } from '@chakra-ui/react';
|
||||||
import SelectionOverlay from 'common/components/SelectionOverlay';
|
|
||||||
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
||||||
import { useDoesInputHaveValue } from 'features/nodes/hooks/useDoesInputHaveValue';
|
import { useDoesInputHaveValue } from 'features/nodes/hooks/useDoesInputHaveValue';
|
||||||
import { useFieldInputKind } from 'features/nodes/hooks/useFieldInputKind';
|
|
||||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||||
import { useIsMouseOverField } from 'features/nodes/hooks/useIsMouseOverField';
|
|
||||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
|
||||||
import { PropsWithChildren, memo, useMemo } from 'react';
|
import { PropsWithChildren, memo, useMemo } from 'react';
|
||||||
|
import EditableFieldTitle from './EditableFieldTitle';
|
||||||
import FieldContextMenu from './FieldContextMenu';
|
import FieldContextMenu from './FieldContextMenu';
|
||||||
import FieldHandle from './FieldHandle';
|
import FieldHandle from './FieldHandle';
|
||||||
import FieldTitle from './FieldTitle';
|
|
||||||
import FieldTooltipContent from './FieldTooltipContent';
|
|
||||||
import InputFieldRenderer from './InputFieldRenderer';
|
import InputFieldRenderer from './InputFieldRenderer';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
@ -21,7 +16,6 @@ interface Props {
|
|||||||
const InputField = ({ nodeId, fieldName }: Props) => {
|
const InputField = ({ nodeId, fieldName }: Props) => {
|
||||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
||||||
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
|
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
|
||||||
const input = useFieldInputKind(nodeId, fieldName);
|
|
||||||
|
|
||||||
const {
|
const {
|
||||||
isConnected,
|
isConnected,
|
||||||
@ -51,11 +45,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
|||||||
|
|
||||||
if (fieldTemplate?.fieldKind !== 'input') {
|
if (fieldTemplate?.fieldKind !== 'input') {
|
||||||
return (
|
return (
|
||||||
<InputFieldWrapper
|
<InputFieldWrapper shouldDim={shouldDim}>
|
||||||
nodeId={nodeId}
|
|
||||||
fieldName={fieldName}
|
|
||||||
shouldDim={shouldDim}
|
|
||||||
>
|
|
||||||
<FormControl
|
<FormControl
|
||||||
sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }}
|
sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }}
|
||||||
>
|
>
|
||||||
@ -66,19 +56,14 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<InputFieldWrapper
|
<InputFieldWrapper shouldDim={shouldDim}>
|
||||||
nodeId={nodeId}
|
|
||||||
fieldName={fieldName}
|
|
||||||
shouldDim={shouldDim}
|
|
||||||
>
|
|
||||||
<FormControl
|
<FormControl
|
||||||
as={Flex}
|
|
||||||
isInvalid={isMissingInput}
|
isInvalid={isMissingInput}
|
||||||
isDisabled={isConnected}
|
isDisabled={isConnected}
|
||||||
sx={{
|
sx={{
|
||||||
alignItems: 'stretch',
|
alignItems: 'stretch',
|
||||||
justifyContent: 'space-between',
|
justifyContent: 'space-between',
|
||||||
ps: 2,
|
ps: fieldTemplate.input === 'direct' ? 0 : 2,
|
||||||
gap: 2,
|
gap: 2,
|
||||||
h: 'full',
|
h: 'full',
|
||||||
w: 'full',
|
w: 'full',
|
||||||
@ -86,42 +71,27 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
|||||||
>
|
>
|
||||||
<FieldContextMenu nodeId={nodeId} fieldName={fieldName} kind="input">
|
<FieldContextMenu nodeId={nodeId} fieldName={fieldName} kind="input">
|
||||||
{(ref) => (
|
{(ref) => (
|
||||||
<Tooltip
|
<FormLabel
|
||||||
label={
|
sx={{
|
||||||
<FieldTooltipContent
|
display: 'flex',
|
||||||
nodeId={nodeId}
|
alignItems: 'center',
|
||||||
fieldName={fieldName}
|
mb: 0,
|
||||||
kind="input"
|
px: 1,
|
||||||
/>
|
gap: 2,
|
||||||
}
|
}}
|
||||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
|
||||||
placement="top"
|
|
||||||
hasArrow
|
|
||||||
>
|
>
|
||||||
<FormLabel
|
<EditableFieldTitle
|
||||||
sx={{
|
ref={ref}
|
||||||
mb: 0,
|
nodeId={nodeId}
|
||||||
width: input === 'connection' ? 'auto' : '25%',
|
fieldName={fieldName}
|
||||||
flexShrink: 0,
|
kind="input"
|
||||||
flexGrow: 0,
|
isMissingInput={isMissingInput}
|
||||||
}}
|
withTooltip
|
||||||
>
|
/>
|
||||||
<FieldTitle
|
</FormLabel>
|
||||||
ref={ref}
|
|
||||||
nodeId={nodeId}
|
|
||||||
fieldName={fieldName}
|
|
||||||
kind="input"
|
|
||||||
isMissingInput={isMissingInput}
|
|
||||||
/>
|
|
||||||
</FormLabel>
|
|
||||||
</Tooltip>
|
|
||||||
)}
|
)}
|
||||||
</FieldContextMenu>
|
</FieldContextMenu>
|
||||||
<Box
|
<Box>
|
||||||
sx={{
|
|
||||||
width: input === 'connection' ? 'auto' : '75%',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
|
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
|
||||||
</Box>
|
</Box>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
@ -143,19 +113,12 @@ export default memo(InputField);
|
|||||||
|
|
||||||
type InputFieldWrapperProps = PropsWithChildren<{
|
type InputFieldWrapperProps = PropsWithChildren<{
|
||||||
shouldDim: boolean;
|
shouldDim: boolean;
|
||||||
nodeId: string;
|
|
||||||
fieldName: string;
|
|
||||||
}>;
|
}>;
|
||||||
|
|
||||||
const InputFieldWrapper = memo(
|
const InputFieldWrapper = memo(
|
||||||
({ shouldDim, nodeId, fieldName, children }: InputFieldWrapperProps) => {
|
({ shouldDim, children }: InputFieldWrapperProps) => {
|
||||||
const { isMouseOverField, handleMouseOver, handleMouseOut } =
|
|
||||||
useIsMouseOverField(nodeId, fieldName);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
onMouseOver={handleMouseOver}
|
|
||||||
onMouseOut={handleMouseOut}
|
|
||||||
sx={{
|
sx={{
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
minH: 8,
|
minH: 8,
|
||||||
@ -169,7 +132,6 @@ const InputFieldWrapper = memo(
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{children}
|
{children}
|
||||||
<SelectionOverlay isSelected={false} isHovered={isMouseOverField} />
|
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -1,13 +1,20 @@
|
|||||||
import { Flex, FormControl, FormLabel, Icon, Tooltip } from '@chakra-ui/react';
|
import {
|
||||||
|
Flex,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Icon,
|
||||||
|
Spacer,
|
||||||
|
Tooltip,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import SelectionOverlay from 'common/components/SelectionOverlay';
|
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
||||||
import { useIsMouseOverField } from 'features/nodes/hooks/useIsMouseOverField';
|
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||||
import { workflowExposedFieldRemoved } from 'features/nodes/store/nodesSlice';
|
import { workflowExposedFieldRemoved } from 'features/nodes/store/nodesSlice';
|
||||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { FaInfoCircle, FaTrash } from 'react-icons/fa';
|
import { FaInfoCircle, FaTrash } from 'react-icons/fa';
|
||||||
import FieldTitle from './FieldTitle';
|
import EditableFieldTitle from './EditableFieldTitle';
|
||||||
import FieldTooltipContent from './FieldTooltipContent';
|
import FieldTooltipContent from './FieldTooltipContent';
|
||||||
import InputFieldRenderer from './InputFieldRenderer';
|
import InputFieldRenderer from './InputFieldRenderer';
|
||||||
|
|
||||||
@ -18,8 +25,8 @@ type Props = {
|
|||||||
|
|
||||||
const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { isMouseOverField, handleMouseOut, handleMouseOver } =
|
const { isMouseOverNode, handleMouseOut, handleMouseOver } =
|
||||||
useIsMouseOverField(nodeId, fieldName);
|
useMouseOverNode(nodeId);
|
||||||
|
|
||||||
const handleRemoveField = useCallback(() => {
|
const handleRemoveField = useCallback(() => {
|
||||||
dispatch(workflowExposedFieldRemoved({ nodeId, fieldName }));
|
dispatch(workflowExposedFieldRemoved({ nodeId, fieldName }));
|
||||||
@ -27,8 +34,8 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
onMouseOver={handleMouseOver}
|
onMouseEnter={handleMouseOver}
|
||||||
onMouseOut={handleMouseOut}
|
onMouseLeave={handleMouseOut}
|
||||||
layerStyle="second"
|
layerStyle="second"
|
||||||
sx={{
|
sx={{
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
@ -42,11 +49,15 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
|||||||
sx={{
|
sx={{
|
||||||
display: 'flex',
|
display: 'flex',
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
justifyContent: 'space-between',
|
|
||||||
mb: 0,
|
mb: 0,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<FieldTitle nodeId={nodeId} fieldName={fieldName} kind="input" />
|
<EditableFieldTitle
|
||||||
|
nodeId={nodeId}
|
||||||
|
fieldName={fieldName}
|
||||||
|
kind="input"
|
||||||
|
/>
|
||||||
|
<Spacer />
|
||||||
<Tooltip
|
<Tooltip
|
||||||
label={
|
label={
|
||||||
<FieldTooltipContent
|
<FieldTooltipContent
|
||||||
@ -74,7 +85,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
|
|||||||
</FormLabel>
|
</FormLabel>
|
||||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
|
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<SelectionOverlay isSelected={false} isHovered={isMouseOverField} />
|
<NodeSelectionOverlay isSelected={false} isHovered={isMouseOverNode} />
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -92,6 +92,7 @@ const ControlNetModelInputFieldComponent = (
|
|||||||
error={!selectedModel}
|
error={!selectedModel}
|
||||||
data={data}
|
data={data}
|
||||||
onChange={handleValueChanged}
|
onChange={handleValueChanged}
|
||||||
|
sx={{ width: '100%' }}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -101,8 +101,10 @@ const LoRAModelInputFieldComponent = (
|
|||||||
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
|
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||||
}
|
}
|
||||||
|
error={!selectedLoRAModel}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
sx={{
|
sx={{
|
||||||
|
width: '100%',
|
||||||
'.mantine-Select-dropdown': {
|
'.mantine-Select-dropdown': {
|
||||||
width: '16rem !important',
|
width: '16rem !important',
|
||||||
},
|
},
|
||||||
|
@ -134,6 +134,7 @@ const MainModelInputFieldComponent = (
|
|||||||
disabled={data.length === 0}
|
disabled={data.length === 0}
|
||||||
onChange={handleChangeModel}
|
onChange={handleChangeModel}
|
||||||
sx={{
|
sx={{
|
||||||
|
width: '100%',
|
||||||
'.mantine-Select-dropdown': {
|
'.mantine-Select-dropdown': {
|
||||||
width: '16rem !important',
|
width: '16rem !important',
|
||||||
},
|
},
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
import { Box, Flex } from '@chakra-ui/react';
|
import { Flex } from '@chakra-ui/react';
|
||||||
import { SelectItem } from '@mantine/core';
|
import { SelectItem } from '@mantine/core';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||||
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import {
|
import {
|
||||||
|
FieldComponentProps,
|
||||||
SDXLRefinerModelInputFieldTemplate,
|
SDXLRefinerModelInputFieldTemplate,
|
||||||
SDXLRefinerModelInputFieldValue,
|
SDXLRefinerModelInputFieldValue,
|
||||||
FieldComponentProps,
|
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||||
@ -101,20 +101,17 @@ const RefinerModelInputFieldComponent = (
|
|||||||
value={selectedModel?.id}
|
value={selectedModel?.id}
|
||||||
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||||
data={data}
|
data={data}
|
||||||
error={data.length === 0}
|
error={!selectedModel}
|
||||||
disabled={data.length === 0}
|
disabled={data.length === 0}
|
||||||
onChange={handleChangeModel}
|
onChange={handleChangeModel}
|
||||||
sx={{
|
sx={{
|
||||||
|
width: '100%',
|
||||||
'.mantine-Select-dropdown': {
|
'.mantine-Select-dropdown': {
|
||||||
width: '16rem !important',
|
width: '16rem !important',
|
||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
{isSyncModelEnabled && (
|
{isSyncModelEnabled && <SyncModelsButton className="nodrag" iconMode />}
|
||||||
<Box mt={7}>
|
|
||||||
<SyncModelsButton className="nodrag" iconMode />
|
|
||||||
</Box>
|
|
||||||
)}
|
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -128,10 +128,11 @@ const ModelInputFieldComponent = (
|
|||||||
value={selectedModel?.id}
|
value={selectedModel?.id}
|
||||||
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||||
data={data}
|
data={data}
|
||||||
error={data.length === 0}
|
error={!selectedModel}
|
||||||
disabled={data.length === 0}
|
disabled={data.length === 0}
|
||||||
onChange={handleChangeModel}
|
onChange={handleChangeModel}
|
||||||
sx={{
|
sx={{
|
||||||
|
width: '100%',
|
||||||
'.mantine-Select-dropdown': {
|
'.mantine-Select-dropdown': {
|
||||||
width: '16rem !important',
|
width: '16rem !important',
|
||||||
},
|
},
|
||||||
|
@ -4,9 +4,9 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
|
|||||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import {
|
import {
|
||||||
|
FieldComponentProps,
|
||||||
VaeModelInputFieldTemplate,
|
VaeModelInputFieldTemplate,
|
||||||
VaeModelInputFieldValue,
|
VaeModelInputFieldValue,
|
||||||
FieldComponentProps,
|
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
|
import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam';
|
||||||
@ -88,17 +88,15 @@ const VaeModelInputFieldComponent = (
|
|||||||
className="nowheel nodrag"
|
className="nowheel nodrag"
|
||||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||||
tooltip={selectedVaeModel?.description}
|
tooltip={selectedVaeModel?.description}
|
||||||
label={
|
|
||||||
selectedVaeModel?.base_model &&
|
|
||||||
MODEL_TYPE_MAP[selectedVaeModel?.base_model]
|
|
||||||
}
|
|
||||||
value={selectedVaeModel?.id ?? 'default'}
|
value={selectedVaeModel?.id ?? 'default'}
|
||||||
placeholder="Default"
|
placeholder="Default"
|
||||||
data={data}
|
data={data}
|
||||||
onChange={handleChangeModel}
|
onChange={handleChangeModel}
|
||||||
disabled={data.length === 0}
|
disabled={data.length === 0}
|
||||||
|
error={!selectedVaeModel}
|
||||||
clearable
|
clearable
|
||||||
sx={{
|
sx={{
|
||||||
|
width: '100%',
|
||||||
'.mantine-Select-dropdown': {
|
'.mantine-Select-dropdown': {
|
||||||
width: '16rem !important',
|
width: '16rem !important',
|
||||||
},
|
},
|
||||||
|
@ -27,9 +27,11 @@ const NodeTitle = ({ nodeId, title }: Props) => {
|
|||||||
const handleSubmit = useCallback(
|
const handleSubmit = useCallback(
|
||||||
async (newTitle: string) => {
|
async (newTitle: string) => {
|
||||||
dispatch(nodeLabelChanged({ nodeId, label: newTitle }));
|
dispatch(nodeLabelChanged({ nodeId, label: newTitle }));
|
||||||
setLocalTitle(newTitle || title || 'Problem Setting Title');
|
setLocalTitle(
|
||||||
|
newTitle || title || templateTitle || 'Problem Setting Title'
|
||||||
|
);
|
||||||
},
|
},
|
||||||
[nodeId, dispatch, title]
|
[dispatch, nodeId, title, templateTitle]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleChange = useCallback((newTitle: string) => {
|
const handleChange = useCallback((newTitle: string) => {
|
||||||
|
@ -7,6 +7,8 @@ import {
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
||||||
|
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||||
import {
|
import {
|
||||||
DRAG_HANDLE_CLASSNAME,
|
DRAG_HANDLE_CLASSNAME,
|
||||||
NODE_WIDTH,
|
NODE_WIDTH,
|
||||||
@ -23,6 +25,8 @@ type NodeWrapperProps = PropsWithChildren & {
|
|||||||
|
|
||||||
const NodeWrapper = (props: NodeWrapperProps) => {
|
const NodeWrapper = (props: NodeWrapperProps) => {
|
||||||
const { nodeId, width, children, selected } = props;
|
const { nodeId, width, children, selected } = props;
|
||||||
|
const { isMouseOverNode, handleMouseOut, handleMouseOver } =
|
||||||
|
useMouseOverNode(nodeId);
|
||||||
|
|
||||||
const selectIsInProgress = useMemo(
|
const selectIsInProgress = useMemo(
|
||||||
() =>
|
() =>
|
||||||
@ -36,25 +40,16 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
|||||||
|
|
||||||
const isInProgress = useAppSelector(selectIsInProgress);
|
const isInProgress = useAppSelector(selectIsInProgress);
|
||||||
|
|
||||||
const [
|
const [nodeInProgressLight, nodeInProgressDark, shadowsXl, shadowsBase] =
|
||||||
nodeSelectedLight,
|
useToken('shadows', [
|
||||||
nodeSelectedDark,
|
'nodeInProgress.light',
|
||||||
nodeInProgressLight,
|
'nodeInProgress.dark',
|
||||||
nodeInProgressDark,
|
'shadows.xl',
|
||||||
shadowsXl,
|
'shadows.base',
|
||||||
shadowsBase,
|
]);
|
||||||
] = useToken('shadows', [
|
|
||||||
'nodeSelected.light',
|
|
||||||
'nodeSelected.dark',
|
|
||||||
'nodeInProgress.light',
|
|
||||||
'nodeInProgress.dark',
|
|
||||||
'shadows.xl',
|
|
||||||
'shadows.base',
|
|
||||||
]);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const selectedShadow = useColorModeValue(nodeSelectedLight, nodeSelectedDark);
|
|
||||||
const inProgressShadow = useColorModeValue(
|
const inProgressShadow = useColorModeValue(
|
||||||
nodeInProgressLight,
|
nodeInProgressLight,
|
||||||
nodeInProgressDark
|
nodeInProgressDark
|
||||||
@ -69,6 +64,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
|||||||
return (
|
return (
|
||||||
<Box
|
<Box
|
||||||
onClick={handleClick}
|
onClick={handleClick}
|
||||||
|
onMouseEnter={handleMouseOver}
|
||||||
|
onMouseLeave={handleMouseOut}
|
||||||
className={DRAG_HANDLE_CLASSNAME}
|
className={DRAG_HANDLE_CLASSNAME}
|
||||||
sx={{
|
sx={{
|
||||||
h: 'full',
|
h: 'full',
|
||||||
@ -77,11 +74,6 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
|||||||
w: width ?? NODE_WIDTH,
|
w: width ?? NODE_WIDTH,
|
||||||
transitionProperty: 'common',
|
transitionProperty: 'common',
|
||||||
transitionDuration: '0.1s',
|
transitionDuration: '0.1s',
|
||||||
shadow: selected
|
|
||||||
? isInProgress
|
|
||||||
? undefined
|
|
||||||
: selectedShadow
|
|
||||||
: undefined,
|
|
||||||
cursor: 'grab',
|
cursor: 'grab',
|
||||||
opacity,
|
opacity,
|
||||||
}}
|
}}
|
||||||
@ -116,6 +108,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
{children}
|
{children}
|
||||||
|
<NodeSelectionOverlay isSelected={selected} isHovered={isMouseOverNode} />
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -2,12 +2,12 @@ import IAIIconButton from 'common/components/IAIIconButton';
|
|||||||
import { useWorkflow } from 'features/nodes/hooks/useWorkflow';
|
import { useWorkflow } from 'features/nodes/hooks/useWorkflow';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaSave } from 'react-icons/fa';
|
import { FaDownload } from 'react-icons/fa';
|
||||||
|
|
||||||
const SaveWorkflowButton = () => {
|
const DownloadWorkflowButton = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const workflow = useWorkflow();
|
const workflow = useWorkflow();
|
||||||
const handleSave = useCallback(() => {
|
const handleDownload = useCallback(() => {
|
||||||
const blob = new Blob([JSON.stringify(workflow, null, 2)]);
|
const blob = new Blob([JSON.stringify(workflow, null, 2)]);
|
||||||
const a = document.createElement('a');
|
const a = document.createElement('a');
|
||||||
a.href = URL.createObjectURL(blob);
|
a.href = URL.createObjectURL(blob);
|
||||||
@ -18,12 +18,12 @@ const SaveWorkflowButton = () => {
|
|||||||
}, [workflow]);
|
}, [workflow]);
|
||||||
return (
|
return (
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
icon={<FaSave />}
|
icon={<FaDownload />}
|
||||||
tooltip={t('nodes.saveWorkflow')}
|
tooltip={t('nodes.downloadWorkflow')}
|
||||||
aria-label={t('nodes.saveWorkflow')}
|
aria-label={t('nodes.downloadWorkflow')}
|
||||||
onClick={handleSave}
|
onClick={handleDownload}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(SaveWorkflowButton);
|
export default memo(DownloadWorkflowButton);
|
@ -2,7 +2,7 @@ import { Flex } from '@chakra-ui/layout';
|
|||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import LoadWorkflowButton from './LoadWorkflowButton';
|
import LoadWorkflowButton from './LoadWorkflowButton';
|
||||||
import ResetWorkflowButton from './ResetWorkflowButton';
|
import ResetWorkflowButton from './ResetWorkflowButton';
|
||||||
import SaveWorkflowButton from './SaveWorkflowButton';
|
import DownloadWorkflowButton from './DownloadWorkflowButton';
|
||||||
|
|
||||||
const TopCenterPanel = () => {
|
const TopCenterPanel = () => {
|
||||||
return (
|
return (
|
||||||
@ -15,7 +15,7 @@ const TopCenterPanel = () => {
|
|||||||
transform: 'translate(-50%)',
|
transform: 'translate(-50%)',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<SaveWorkflowButton />
|
<DownloadWorkflowButton />
|
||||||
<LoadWorkflowButton />
|
<LoadWorkflowButton />
|
||||||
<ResetWorkflowButton />
|
<ResetWorkflowButton />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -0,0 +1,74 @@
|
|||||||
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import { InvocationTemplate, NodeData } from 'features/nodes/types/types';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import NotesTextarea from '../../flow/nodes/Invocation/NotesTextarea';
|
||||||
|
import NodeTitle from '../../flow/nodes/common/NodeTitle';
|
||||||
|
import ScrollableContent from '../ScrollableContent';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ nodes }) => {
|
||||||
|
const lastSelectedNodeId =
|
||||||
|
nodes.selectedNodes[nodes.selectedNodes.length - 1];
|
||||||
|
|
||||||
|
const lastSelectedNode = nodes.nodes.find(
|
||||||
|
(node) => node.id === lastSelectedNodeId
|
||||||
|
);
|
||||||
|
|
||||||
|
const lastSelectedNodeTemplate = lastSelectedNode
|
||||||
|
? nodes.nodeTemplates[lastSelectedNode.data.type]
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
return {
|
||||||
|
data: lastSelectedNode?.data,
|
||||||
|
template: lastSelectedNodeTemplate,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const InspectorDetailsTab = () => {
|
||||||
|
const { data, template } = useAppSelector(selector);
|
||||||
|
|
||||||
|
if (!template || !data) {
|
||||||
|
return <IAINoContentFallback label="No node selected" icon={null} />;
|
||||||
|
}
|
||||||
|
|
||||||
|
return <Content data={data} template={template} />;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(InspectorDetailsTab);
|
||||||
|
|
||||||
|
const Content = (props: { data: NodeData; template: InvocationTemplate }) => {
|
||||||
|
const { data } = props;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'relative',
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<ScrollableContent>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
flexDir: 'column',
|
||||||
|
position: 'relative',
|
||||||
|
p: 1,
|
||||||
|
gap: 2,
|
||||||
|
w: 'full',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<NodeTitle nodeId={data.id} />
|
||||||
|
<NotesTextarea nodeId={data.id} />
|
||||||
|
</Flex>
|
||||||
|
</ScrollableContent>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
@ -4,12 +4,13 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||||
|
import { isInvocationNode } from 'features/nodes/types/types';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import ImageOutputPreview from './outputs/ImageOutputPreview';
|
import { ImageOutput } from 'services/api/types';
|
||||||
import ScrollableContent from '../ScrollableContent';
|
|
||||||
import { AnyResult } from 'services/events/types';
|
import { AnyResult } from 'services/events/types';
|
||||||
import StringOutputPreview from './outputs/StringOutputPreview';
|
import ScrollableContent from '../ScrollableContent';
|
||||||
import NumberOutputPreview from './outputs/NumberOutputPreview';
|
import ImageOutputPreview from './outputs/ImageOutputPreview';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
@ -21,11 +22,16 @@ const selector = createSelector(
|
|||||||
(node) => node.id === lastSelectedNodeId
|
(node) => node.id === lastSelectedNodeId
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const lastSelectedNodeTemplate = lastSelectedNode
|
||||||
|
? nodes.nodeTemplates[lastSelectedNode.data.type]
|
||||||
|
: undefined;
|
||||||
|
|
||||||
const nes =
|
const nes =
|
||||||
nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
|
nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
|
||||||
|
|
||||||
return {
|
return {
|
||||||
node: lastSelectedNode,
|
node: lastSelectedNode,
|
||||||
|
template: lastSelectedNodeTemplate,
|
||||||
nes,
|
nes,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
@ -33,9 +39,9 @@ const selector = createSelector(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const InspectorOutputsTab = () => {
|
const InspectorOutputsTab = () => {
|
||||||
const { node, nes } = useAppSelector(selector);
|
const { node, template, nes } = useAppSelector(selector);
|
||||||
|
|
||||||
if (!node || !nes) {
|
if (!node || !nes || !isInvocationNode(node)) {
|
||||||
return <IAINoContentFallback label="No node selected" icon={null} />;
|
return <IAINoContentFallback label="No node selected" icon={null} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,33 +69,16 @@ const InspectorOutputsTab = () => {
|
|||||||
w: 'full',
|
w: 'full',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{nes.outputs.map((result, i) => {
|
{template?.outputType === 'image_output' ? (
|
||||||
if (result.type === 'string_output') {
|
nes.outputs.map((result, i) => (
|
||||||
return (
|
<ImageOutputPreview
|
||||||
<StringOutputPreview key={getKey(result, i)} output={result} />
|
key={getKey(result, i)}
|
||||||
);
|
output={result as ImageOutput}
|
||||||
}
|
/>
|
||||||
if (result.type === 'float_output') {
|
))
|
||||||
return (
|
) : (
|
||||||
<NumberOutputPreview key={getKey(result, i)} output={result} />
|
<DataViewer data={nes.outputs} label="Node Outputs" />
|
||||||
);
|
)}
|
||||||
}
|
|
||||||
if (result.type === 'integer_output') {
|
|
||||||
return (
|
|
||||||
<NumberOutputPreview key={getKey(result, i)} output={result} />
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if (result.type === 'image_output') {
|
|
||||||
return (
|
|
||||||
<ImageOutputPreview key={getKey(result, i)} output={result} />
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return (
|
|
||||||
<pre key={getKey(result, i)}>
|
|
||||||
{JSON.stringify(result, null, 2)}
|
|
||||||
</pre>
|
|
||||||
);
|
|
||||||
})}
|
|
||||||
</Flex>
|
</Flex>
|
||||||
</ScrollableContent>
|
</ScrollableContent>
|
||||||
</Box>
|
</Box>
|
||||||
|
@ -10,6 +10,7 @@ import { memo } from 'react';
|
|||||||
import InspectorDataTab from './InspectorDataTab';
|
import InspectorDataTab from './InspectorDataTab';
|
||||||
import InspectorOutputsTab from './InspectorOutputsTab';
|
import InspectorOutputsTab from './InspectorOutputsTab';
|
||||||
import InspectorTemplateTab from './InspectorTemplateTab';
|
import InspectorTemplateTab from './InspectorTemplateTab';
|
||||||
|
// import InspectorDetailsTab from './InspectorDetailsTab';
|
||||||
|
|
||||||
const InspectorPanel = () => {
|
const InspectorPanel = () => {
|
||||||
return (
|
return (
|
||||||
@ -29,12 +30,16 @@ const InspectorPanel = () => {
|
|||||||
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
|
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
|
||||||
>
|
>
|
||||||
<TabList>
|
<TabList>
|
||||||
|
{/* <Tab>Details</Tab> */}
|
||||||
<Tab>Outputs</Tab>
|
<Tab>Outputs</Tab>
|
||||||
<Tab>Data</Tab>
|
<Tab>Data</Tab>
|
||||||
<Tab>Template</Tab>
|
<Tab>Template</Tab>
|
||||||
</TabList>
|
</TabList>
|
||||||
|
|
||||||
<TabPanels>
|
<TabPanels>
|
||||||
|
{/* <TabPanel>
|
||||||
|
<InspectorDetailsTab />
|
||||||
|
</TabPanel> */}
|
||||||
<TabPanel>
|
<TabPanel>
|
||||||
<InspectorOutputsTab />
|
<InspectorOutputsTab />
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
|
@ -1,13 +0,0 @@
|
|||||||
import { Text } from '@chakra-ui/react';
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { FloatOutput, IntegerOutput } from 'services/api/types';
|
|
||||||
|
|
||||||
type Props = {
|
|
||||||
output: IntegerOutput | FloatOutput;
|
|
||||||
};
|
|
||||||
|
|
||||||
const NumberOutputPreview = ({ output }: Props) => {
|
|
||||||
return <Text>{output.value}</Text>;
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(NumberOutputPreview);
|
|
@ -1,13 +0,0 @@
|
|||||||
import { Text } from '@chakra-ui/react';
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { StringOutput } from 'services/api/types';
|
|
||||||
|
|
||||||
type Props = {
|
|
||||||
output: StringOutput;
|
|
||||||
};
|
|
||||||
|
|
||||||
const StringOutputPreview = ({ output }: Props) => {
|
|
||||||
return <Text>{output.value}</Text>;
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(StringOutputPreview);
|
|
@ -22,6 +22,7 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
|
|||||||
}
|
}
|
||||||
return map(nodeTemplate.inputs)
|
return map(nodeTemplate.inputs)
|
||||||
.filter((field) => ['any', 'direct'].includes(field.input))
|
.filter((field) => ['any', 'direct'].includes(field.input))
|
||||||
|
.filter((field) => !field.ui_hidden)
|
||||||
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
|
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
|
||||||
.map((field) => field.name)
|
.map((field) => field.name)
|
||||||
.filter((fieldName) => fieldName !== 'is_intermediate');
|
.filter((fieldName) => fieldName !== 'is_intermediate');
|
||||||
|
@ -143,6 +143,8 @@ export const useBuildNodeData = () => {
|
|||||||
isOpen: true,
|
isOpen: true,
|
||||||
label: '',
|
label: '',
|
||||||
notes: '',
|
notes: '',
|
||||||
|
embedWorkflow: false,
|
||||||
|
isIntermediate: true,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
|
|||||||
}
|
}
|
||||||
return map(nodeTemplate.inputs)
|
return map(nodeTemplate.inputs)
|
||||||
.filter((field) => field.input === 'connection')
|
.filter((field) => field.input === 'connection')
|
||||||
|
.filter((field) => !field.ui_hidden)
|
||||||
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
|
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
|
||||||
.map((field) => field.name)
|
.map((field) => field.name)
|
||||||
.filter((fieldName) => fieldName !== 'is_intermediate');
|
.filter((fieldName) => fieldName !== 'is_intermediate');
|
||||||
|
@ -0,0 +1,27 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
import { isInvocationNode } from '../types/types';
|
||||||
|
|
||||||
|
export const useEmbedWorkflow = (nodeId: string) => {
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ nodes }) => {
|
||||||
|
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||||
|
if (!isInvocationNode(node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return node.data.embedWorkflow;
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const embedWorkflow = useAppSelector(selector);
|
||||||
|
return embedWorkflow;
|
||||||
|
};
|
@ -15,7 +15,7 @@ export const useIsIntermediate = (nodeId: string) => {
|
|||||||
if (!isInvocationNode(node)) {
|
if (!isInvocationNode(node)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return Boolean(node.data.inputs.is_intermediate?.value);
|
return node.data.isIntermediate;
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
),
|
),
|
||||||
|
@ -3,7 +3,7 @@ import { useLogger } from 'app/logging/useLogger';
|
|||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { parseify } from 'common/util/serialize';
|
import { parseify } from 'common/util/serialize';
|
||||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||||
import { zWorkflow } from 'features/nodes/types/types';
|
import { zValidatedWorkflow } from 'features/nodes/types/types';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
@ -24,52 +24,65 @@ export const useLoadWorkflowFromFile = () => {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
const parsedJSON = JSON.parse(String(rawJSON));
|
const parsedJSON = JSON.parse(String(rawJSON));
|
||||||
const result = zWorkflow.safeParse(parsedJSON);
|
const result = zValidatedWorkflow.safeParse(parsedJSON);
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
const message = fromZodError(result.error, {
|
const { message } = fromZodError(result.error, {
|
||||||
prefix: 'Workflow Validation Error',
|
prefix: 'Workflow Validation Error',
|
||||||
}).toString();
|
});
|
||||||
|
|
||||||
logger.error({ error: parseify(result.error) }, message);
|
logger.error({ error: parseify(result.error) }, message);
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
addToast(
|
addToast(
|
||||||
makeToast({
|
makeToast({
|
||||||
title: 'Unable to Validate Workflow',
|
title: 'Unable to Validate Workflow',
|
||||||
description: (
|
|
||||||
<WorkflowValidationErrorContent error={result.error} />
|
|
||||||
),
|
|
||||||
status: 'error',
|
status: 'error',
|
||||||
duration: 5000,
|
duration: 5000,
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
reader.abort();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
dispatch(workflowLoaded(result.data.workflow));
|
||||||
|
|
||||||
dispatch(workflowLoaded(result.data));
|
if (!result.data.warnings.length) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: 'Workflow Loaded',
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
reader.abort();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
addToast(
|
addToast(
|
||||||
makeToast({
|
makeToast({
|
||||||
title: 'Workflow Loaded',
|
title: 'Workflow Loaded with Warnings',
|
||||||
status: 'success',
|
status: 'warning',
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
result.data.warnings.forEach(({ message, ...rest }) => {
|
||||||
|
logger.warn(rest, message);
|
||||||
|
});
|
||||||
|
|
||||||
reader.abort();
|
reader.abort();
|
||||||
} catch (error) {
|
} catch {
|
||||||
// file reader error
|
// file reader error
|
||||||
if (error) {
|
dispatch(
|
||||||
dispatch(
|
addToast(
|
||||||
addToast(
|
makeToast({
|
||||||
makeToast({
|
title: 'Unable to Load Workflow',
|
||||||
title: 'Unable to Load Workflow',
|
status: 'error',
|
||||||
status: 'error',
|
})
|
||||||
})
|
)
|
||||||
)
|
);
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -0,0 +1,31 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import { mouseOverNodeChanged } from '../store/nodesSlice';
|
||||||
|
|
||||||
|
export const useMouseOverNode = (nodeId: string) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ nodes }) => nodes.mouseOverNode === nodeId,
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const isMouseOverNode = useAppSelector(selector);
|
||||||
|
|
||||||
|
const handleMouseOver = useCallback(() => {
|
||||||
|
!isMouseOverNode && dispatch(mouseOverNodeChanged(nodeId));
|
||||||
|
}, [dispatch, nodeId, isMouseOverNode]);
|
||||||
|
|
||||||
|
const handleMouseOut = useCallback(() => {
|
||||||
|
isMouseOverNode && dispatch(mouseOverNodeChanged(null));
|
||||||
|
}, [dispatch, isMouseOverNode]);
|
||||||
|
|
||||||
|
return { isMouseOverNode, handleMouseOver, handleMouseOut };
|
||||||
|
};
|
@ -21,6 +21,7 @@ export const useOutputFieldNames = (nodeId: string) => {
|
|||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
return map(nodeTemplate.outputs)
|
return map(nodeTemplate.outputs)
|
||||||
|
.filter((field) => !field.ui_hidden)
|
||||||
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
|
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
|
||||||
.map((field) => field.name)
|
.map((field) => field.name)
|
||||||
.filter((fieldName) => fieldName !== 'is_intermediate');
|
.filter((fieldName) => fieldName !== 'is_intermediate');
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { cloneDeep, forEach, isEqual, uniqBy } from 'lodash-es';
|
import { cloneDeep, forEach, isEqual, map, uniqBy } from 'lodash-es';
|
||||||
import {
|
import {
|
||||||
addEdge,
|
addEdge,
|
||||||
applyEdgeChanges,
|
applyEdgeChanges,
|
||||||
@ -18,7 +18,7 @@ import {
|
|||||||
Viewport,
|
Viewport,
|
||||||
} from 'reactflow';
|
} from 'reactflow';
|
||||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||||
import { sessionInvoked } from 'services/api/thunks/session';
|
import { sessionCanceled, sessionInvoked } from 'services/api/thunks/session';
|
||||||
import { ImageField } from 'services/api/types';
|
import { ImageField } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
appSocketGeneratorProgress,
|
appSocketGeneratorProgress,
|
||||||
@ -102,6 +102,7 @@ export const initialNodesState: NodesState = {
|
|||||||
nodeExecutionStates: {},
|
nodeExecutionStates: {},
|
||||||
viewport: { x: 0, y: 0, zoom: 1 },
|
viewport: { x: 0, y: 0, zoom: 1 },
|
||||||
mouseOverField: null,
|
mouseOverField: null,
|
||||||
|
mouseOverNode: null,
|
||||||
nodesToCopy: [],
|
nodesToCopy: [],
|
||||||
edgesToCopy: [],
|
edgesToCopy: [],
|
||||||
selectionMode: SelectionMode.Partial,
|
selectionMode: SelectionMode.Partial,
|
||||||
@ -245,6 +246,34 @@ const nodesSlice = createSlice({
|
|||||||
}
|
}
|
||||||
field.label = label;
|
field.label = label;
|
||||||
},
|
},
|
||||||
|
nodeEmbedWorkflowChanged: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<{ nodeId: string; embedWorkflow: boolean }>
|
||||||
|
) => {
|
||||||
|
const { nodeId, embedWorkflow } = action.payload;
|
||||||
|
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||||
|
|
||||||
|
const node = state.nodes?.[nodeIndex];
|
||||||
|
|
||||||
|
if (!isInvocationNode(node)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
node.data.embedWorkflow = embedWorkflow;
|
||||||
|
},
|
||||||
|
nodeIsIntermediateChanged: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<{ nodeId: string; isIntermediate: boolean }>
|
||||||
|
) => {
|
||||||
|
const { nodeId, isIntermediate } = action.payload;
|
||||||
|
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||||
|
|
||||||
|
const node = state.nodes?.[nodeIndex];
|
||||||
|
|
||||||
|
if (!isInvocationNode(node)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
node.data.isIntermediate = isIntermediate;
|
||||||
|
},
|
||||||
nodeIsOpenChanged: (
|
nodeIsOpenChanged: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{ nodeId: string; isOpen: boolean }>
|
action: PayloadAction<{ nodeId: string; isOpen: boolean }>
|
||||||
@ -561,7 +590,7 @@ const nodesSlice = createSlice({
|
|||||||
nodeEditorReset: (state) => {
|
nodeEditorReset: (state) => {
|
||||||
state.nodes = [];
|
state.nodes = [];
|
||||||
state.edges = [];
|
state.edges = [];
|
||||||
state.workflow.exposedFields = [];
|
state.workflow = cloneDeep(initialWorkflow);
|
||||||
},
|
},
|
||||||
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
|
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldValidateGraph = action.payload;
|
state.shouldValidateGraph = action.payload;
|
||||||
@ -637,6 +666,9 @@ const nodesSlice = createSlice({
|
|||||||
) => {
|
) => {
|
||||||
state.mouseOverField = action.payload;
|
state.mouseOverField = action.payload;
|
||||||
},
|
},
|
||||||
|
mouseOverNodeChanged: (state, action: PayloadAction<string | null>) => {
|
||||||
|
state.mouseOverNode = action.payload;
|
||||||
|
},
|
||||||
selectedAll: (state) => {
|
selectedAll: (state) => {
|
||||||
state.nodes = applyNodeChanges(
|
state.nodes = applyNodeChanges(
|
||||||
state.nodes.map((n) => ({ id: n.id, type: 'select', selected: true })),
|
state.nodes.map((n) => ({ id: n.id, type: 'select', selected: true })),
|
||||||
@ -790,6 +822,13 @@ const nodesSlice = createSlice({
|
|||||||
nes.outputs = [];
|
nes.outputs = [];
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
builder.addCase(sessionCanceled.fulfilled, (state) => {
|
||||||
|
map(state.nodeExecutionStates, (nes) => {
|
||||||
|
if (nes.status === NodeStatus.IN_PROGRESS) {
|
||||||
|
nes.status = NodeStatus.PENDING;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -850,6 +889,9 @@ export const {
|
|||||||
addNodePopoverClosed,
|
addNodePopoverClosed,
|
||||||
addNodePopoverToggled,
|
addNodePopoverToggled,
|
||||||
selectionModeChanged,
|
selectionModeChanged,
|
||||||
|
nodeEmbedWorkflowChanged,
|
||||||
|
nodeIsIntermediateChanged,
|
||||||
|
mouseOverNodeChanged,
|
||||||
} = nodesSlice.actions;
|
} = nodesSlice.actions;
|
||||||
|
|
||||||
export default nodesSlice.reducer;
|
export default nodesSlice.reducer;
|
||||||
|
@ -35,6 +35,7 @@ export type NodesState = {
|
|||||||
viewport: Viewport;
|
viewport: Viewport;
|
||||||
isReady: boolean;
|
isReady: boolean;
|
||||||
mouseOverField: FieldIdentifier | null;
|
mouseOverField: FieldIdentifier | null;
|
||||||
|
mouseOverNode: string | null;
|
||||||
nodesToCopy: Node<NodeData>[];
|
nodesToCopy: Node<NodeData>[];
|
||||||
edgesToCopy: Edge<InvocationEdgeExtra>[];
|
edgesToCopy: Edge<InvocationEdgeExtra>[];
|
||||||
isAddNodePopoverOpen: boolean;
|
isAddNodePopoverOpen: boolean;
|
||||||
|
@ -62,7 +62,7 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
|||||||
DenoiseMaskField: {
|
DenoiseMaskField: {
|
||||||
title: 'Denoise Mask',
|
title: 'Denoise Mask',
|
||||||
description: 'Denoise Mask may be passed between nodes',
|
description: 'Denoise Mask may be passed between nodes',
|
||||||
color: 'red.700',
|
color: 'base.500',
|
||||||
},
|
},
|
||||||
LatentsField: {
|
LatentsField: {
|
||||||
title: 'Latents',
|
title: 'Latents',
|
||||||
@ -174,11 +174,6 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
|||||||
title: 'Color Collection',
|
title: 'Color Collection',
|
||||||
description: 'A collection of colors.',
|
description: 'A collection of colors.',
|
||||||
},
|
},
|
||||||
FilePath: {
|
|
||||||
color: 'base.500',
|
|
||||||
title: 'File Path',
|
|
||||||
description: 'A path to a file.',
|
|
||||||
},
|
|
||||||
ONNXModelField: {
|
ONNXModelField: {
|
||||||
color: 'base.500',
|
color: 'base.500',
|
||||||
title: 'ONNX Model',
|
title: 'ONNX Model',
|
||||||
|
@ -1,12 +1,16 @@
|
|||||||
|
import { store } from 'app/store/store';
|
||||||
import {
|
import {
|
||||||
SchedulerParam,
|
SchedulerParam,
|
||||||
zBaseModel,
|
zBaseModel,
|
||||||
zMainOrOnnxModel,
|
zMainOrOnnxModel,
|
||||||
|
zSDXLRefinerModel,
|
||||||
zScheduler,
|
zScheduler,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { keyBy } from 'lodash-es';
|
||||||
import { OpenAPIV3 } from 'openapi-types';
|
import { OpenAPIV3 } from 'openapi-types';
|
||||||
import { RgbaColor } from 'react-colorful';
|
import { RgbaColor } from 'react-colorful';
|
||||||
import { Node } from 'reactflow';
|
import { Node } from 'reactflow';
|
||||||
|
import { JsonObject } from 'type-fest';
|
||||||
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
|
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
AnyInvocationType,
|
AnyInvocationType,
|
||||||
@ -98,7 +102,6 @@ export const zFieldType = z.enum([
|
|||||||
// endregion
|
// endregion
|
||||||
|
|
||||||
// region Misc
|
// region Misc
|
||||||
'FilePath',
|
|
||||||
'enum',
|
'enum',
|
||||||
'Scheduler',
|
'Scheduler',
|
||||||
// endregion
|
// endregion
|
||||||
@ -106,8 +109,17 @@ export const zFieldType = z.enum([
|
|||||||
|
|
||||||
export type FieldType = z.infer<typeof zFieldType>;
|
export type FieldType = z.infer<typeof zFieldType>;
|
||||||
|
|
||||||
|
export const zReservedFieldType = z.enum([
|
||||||
|
'WorkflowField',
|
||||||
|
'IsIntermediate',
|
||||||
|
'MetadataField',
|
||||||
|
]);
|
||||||
|
|
||||||
|
export type ReservedFieldType = z.infer<typeof zReservedFieldType>;
|
||||||
|
|
||||||
export const isFieldType = (value: unknown): value is FieldType =>
|
export const isFieldType = (value: unknown): value is FieldType =>
|
||||||
zFieldType.safeParse(value).success;
|
zFieldType.safeParse(value).success ||
|
||||||
|
zReservedFieldType.safeParse(value).success;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An input field template is generated on each page load from the OpenAPI schema.
|
* An input field template is generated on each page load from the OpenAPI schema.
|
||||||
@ -215,7 +227,7 @@ export type DenoiseMaskFieldValue = z.infer<typeof zDenoiseMaskField>;
|
|||||||
|
|
||||||
export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
|
export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
|
||||||
type: z.literal('integer'),
|
type: z.literal('integer'),
|
||||||
value: z.number().optional(),
|
value: z.number().int().optional(),
|
||||||
});
|
});
|
||||||
export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>;
|
export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>;
|
||||||
|
|
||||||
@ -641,6 +653,11 @@ export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'Scheduler';
|
type: 'Scheduler';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: undefined;
|
||||||
|
type: 'WorkflowField';
|
||||||
|
};
|
||||||
|
|
||||||
export const isInputFieldValue = (
|
export const isInputFieldValue = (
|
||||||
field?: InputFieldValue | OutputFieldValue
|
field?: InputFieldValue | OutputFieldValue
|
||||||
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
|
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
|
||||||
@ -661,6 +678,7 @@ export type TypeHints = {
|
|||||||
export type InvocationSchemaExtra = {
|
export type InvocationSchemaExtra = {
|
||||||
output: OpenAPIV3.ReferenceObject; // the output of the invocation
|
output: OpenAPIV3.ReferenceObject; // the output of the invocation
|
||||||
title: string;
|
title: string;
|
||||||
|
category?: string;
|
||||||
tags?: string[];
|
tags?: string[];
|
||||||
properties: Omit<
|
properties: Omit<
|
||||||
NonNullable<OpenAPIV3.SchemaObject['properties']> &
|
NonNullable<OpenAPIV3.SchemaObject['properties']> &
|
||||||
@ -737,6 +755,48 @@ export const isInvocationFieldSchema = (
|
|||||||
|
|
||||||
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
|
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
|
||||||
|
|
||||||
|
export const zCoreMetadata = z
|
||||||
|
.object({
|
||||||
|
app_version: z.string().nullish(),
|
||||||
|
generation_mode: z.string().nullish(),
|
||||||
|
created_by: z.string().nullish(),
|
||||||
|
positive_prompt: z.string().nullish(),
|
||||||
|
negative_prompt: z.string().nullish(),
|
||||||
|
width: z.number().int().nullish(),
|
||||||
|
height: z.number().int().nullish(),
|
||||||
|
seed: z.number().int().nullish(),
|
||||||
|
rand_device: z.string().nullish(),
|
||||||
|
cfg_scale: z.number().nullish(),
|
||||||
|
steps: z.number().int().nullish(),
|
||||||
|
scheduler: z.string().nullish(),
|
||||||
|
clip_skip: z.number().int().nullish(),
|
||||||
|
model: zMainOrOnnxModel.nullish(),
|
||||||
|
controlnets: z.array(zControlField).nullish(),
|
||||||
|
loras: z
|
||||||
|
.array(
|
||||||
|
z.object({
|
||||||
|
lora: zLoRAModelField,
|
||||||
|
weight: z.number(),
|
||||||
|
})
|
||||||
|
)
|
||||||
|
.nullish(),
|
||||||
|
vae: zVaeModelField.nullish(),
|
||||||
|
strength: z.number().nullish(),
|
||||||
|
init_image: z.string().nullish(),
|
||||||
|
positive_style_prompt: z.string().nullish(),
|
||||||
|
negative_style_prompt: z.string().nullish(),
|
||||||
|
refiner_model: zSDXLRefinerModel.nullish(),
|
||||||
|
refiner_cfg_scale: z.number().nullish(),
|
||||||
|
refiner_steps: z.number().int().nullish(),
|
||||||
|
refiner_scheduler: z.string().nullish(),
|
||||||
|
refiner_positive_aesthetic_store: z.number().nullish(),
|
||||||
|
refiner_negative_aesthetic_store: z.number().nullish(),
|
||||||
|
refiner_start: z.number().nullish(),
|
||||||
|
})
|
||||||
|
.catchall(z.record(z.any()));
|
||||||
|
|
||||||
|
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
|
||||||
|
|
||||||
export const zInvocationNodeData = z.object({
|
export const zInvocationNodeData = z.object({
|
||||||
id: z.string().trim().min(1),
|
id: z.string().trim().min(1),
|
||||||
// no easy way to build this dynamically, and we don't want to anyways, because this will be used
|
// no easy way to build this dynamically, and we don't want to anyways, because this will be used
|
||||||
@ -747,6 +807,8 @@ export const zInvocationNodeData = z.object({
|
|||||||
label: z.string(),
|
label: z.string(),
|
||||||
isOpen: z.boolean(),
|
isOpen: z.boolean(),
|
||||||
notes: z.string(),
|
notes: z.string(),
|
||||||
|
embedWorkflow: z.boolean(),
|
||||||
|
isIntermediate: z.boolean(),
|
||||||
});
|
});
|
||||||
|
|
||||||
// Massage this to get better type safety while developing
|
// Massage this to get better type safety while developing
|
||||||
@ -767,28 +829,38 @@ export const zNotesNodeData = z.object({
|
|||||||
|
|
||||||
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
|
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
|
||||||
|
|
||||||
|
const zPosition = z
|
||||||
|
.object({
|
||||||
|
x: z.number(),
|
||||||
|
y: z.number(),
|
||||||
|
})
|
||||||
|
.default({ x: 0, y: 0 });
|
||||||
|
|
||||||
|
const zDimension = z.number().gt(0).nullish();
|
||||||
|
|
||||||
export const zWorkflowInvocationNode = z.object({
|
export const zWorkflowInvocationNode = z.object({
|
||||||
id: z.string().trim().min(1),
|
id: z.string().trim().min(1),
|
||||||
type: z.literal('invocation'),
|
type: z.literal('invocation'),
|
||||||
data: zInvocationNodeData,
|
data: zInvocationNodeData,
|
||||||
width: z.number().gt(0),
|
width: zDimension,
|
||||||
height: z.number().gt(0),
|
height: zDimension,
|
||||||
position: z.object({
|
position: zPosition,
|
||||||
x: z.number(),
|
|
||||||
y: z.number(),
|
|
||||||
}),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
export type WorkflowInvocationNode = z.infer<typeof zWorkflowInvocationNode>;
|
||||||
|
|
||||||
|
export const isWorkflowInvocationNode = (
|
||||||
|
val: unknown
|
||||||
|
): val is WorkflowInvocationNode =>
|
||||||
|
zWorkflowInvocationNode.safeParse(val).success;
|
||||||
|
|
||||||
export const zWorkflowNotesNode = z.object({
|
export const zWorkflowNotesNode = z.object({
|
||||||
id: z.string().trim().min(1),
|
id: z.string().trim().min(1),
|
||||||
type: z.literal('notes'),
|
type: z.literal('notes'),
|
||||||
data: zNotesNodeData,
|
data: zNotesNodeData,
|
||||||
width: z.number().gt(0),
|
width: zDimension,
|
||||||
height: z.number().gt(0),
|
height: zDimension,
|
||||||
position: z.object({
|
position: zPosition,
|
||||||
x: z.number(),
|
|
||||||
y: z.number(),
|
|
||||||
}),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
export const zWorkflowNode = z.discriminatedUnion('type', [
|
export const zWorkflowNode = z.discriminatedUnion('type', [
|
||||||
@ -798,14 +870,25 @@ export const zWorkflowNode = z.discriminatedUnion('type', [
|
|||||||
|
|
||||||
export type WorkflowNode = z.infer<typeof zWorkflowNode>;
|
export type WorkflowNode = z.infer<typeof zWorkflowNode>;
|
||||||
|
|
||||||
export const zWorkflowEdge = z.object({
|
export const zDefaultWorkflowEdge = z.object({
|
||||||
source: z.string().trim().min(1),
|
source: z.string().trim().min(1),
|
||||||
sourceHandle: z.string().trim().min(1),
|
sourceHandle: z.string().trim().min(1),
|
||||||
target: z.string().trim().min(1),
|
target: z.string().trim().min(1),
|
||||||
targetHandle: z.string().trim().min(1),
|
targetHandle: z.string().trim().min(1),
|
||||||
id: z.string().trim().min(1),
|
id: z.string().trim().min(1),
|
||||||
type: z.enum(['default', 'collapsed']),
|
type: z.literal('default'),
|
||||||
});
|
});
|
||||||
|
export const zCollapsedWorkflowEdge = z.object({
|
||||||
|
source: z.string().trim().min(1),
|
||||||
|
target: z.string().trim().min(1),
|
||||||
|
id: z.string().trim().min(1),
|
||||||
|
type: z.literal('collapsed'),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const zWorkflowEdge = z.union([
|
||||||
|
zDefaultWorkflowEdge,
|
||||||
|
zCollapsedWorkflowEdge,
|
||||||
|
]);
|
||||||
|
|
||||||
export const zFieldIdentifier = z.object({
|
export const zFieldIdentifier = z.object({
|
||||||
nodeId: z.string().trim().min(1),
|
nodeId: z.string().trim().min(1),
|
||||||
@ -828,21 +911,92 @@ export const zSemVer = z.string().refine((val) => {
|
|||||||
|
|
||||||
export type SemVer = z.infer<typeof zSemVer>;
|
export type SemVer = z.infer<typeof zSemVer>;
|
||||||
|
|
||||||
|
export type WorkflowWarning = {
|
||||||
|
message: string;
|
||||||
|
issues: string[];
|
||||||
|
data: JsonObject;
|
||||||
|
};
|
||||||
|
|
||||||
export const zWorkflow = z.object({
|
export const zWorkflow = z.object({
|
||||||
name: z.string(),
|
name: z.string().default(''),
|
||||||
author: z.string(),
|
author: z.string().default(''),
|
||||||
description: z.string(),
|
description: z.string().default(''),
|
||||||
version: z.string(),
|
version: z.string().default(''),
|
||||||
contact: z.string(),
|
contact: z.string().default(''),
|
||||||
tags: z.string(),
|
tags: z.string().default(''),
|
||||||
notes: z.string(),
|
notes: z.string().default(''),
|
||||||
nodes: z.array(zWorkflowNode),
|
nodes: z.array(zWorkflowNode).default([]),
|
||||||
edges: z.array(zWorkflowEdge),
|
edges: z.array(zWorkflowEdge).default([]),
|
||||||
exposedFields: z.array(zFieldIdentifier),
|
exposedFields: z.array(zFieldIdentifier).default([]),
|
||||||
|
meta: z
|
||||||
|
.object({
|
||||||
|
version: zSemVer,
|
||||||
|
})
|
||||||
|
.default({ version: '1.0.0' }),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
|
||||||
|
const nodeTemplates = store.getState().nodes.nodeTemplates;
|
||||||
|
const { nodes, edges } = workflow;
|
||||||
|
const warnings: WorkflowWarning[] = [];
|
||||||
|
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
|
||||||
|
const keyedNodes = keyBy(invocationNodes, 'id');
|
||||||
|
invocationNodes.forEach((node, i) => {
|
||||||
|
const nodeTemplate = nodeTemplates[node.data.type];
|
||||||
|
if (!nodeTemplate) {
|
||||||
|
warnings.push({
|
||||||
|
message: `Node "${node.data.label || node.data.id}" skipped`,
|
||||||
|
issues: [`Unable to find template for type "${node.data.type}"`],
|
||||||
|
data: node,
|
||||||
|
});
|
||||||
|
delete nodes[i];
|
||||||
|
}
|
||||||
|
});
|
||||||
|
edges.forEach((edge, i) => {
|
||||||
|
const sourceNode = keyedNodes[edge.source];
|
||||||
|
const targetNode = keyedNodes[edge.target];
|
||||||
|
const issues: string[] = [];
|
||||||
|
if (!sourceNode) {
|
||||||
|
issues.push(`Output node ${edge.source} does not exist`);
|
||||||
|
} else if (
|
||||||
|
edge.type === 'default' &&
|
||||||
|
!(edge.sourceHandle in sourceNode.data.outputs)
|
||||||
|
) {
|
||||||
|
issues.push(
|
||||||
|
`Output field "${edge.source}.${edge.sourceHandle}" does not exist`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (!targetNode) {
|
||||||
|
issues.push(`Input node ${edge.target} does not exist`);
|
||||||
|
} else if (
|
||||||
|
edge.type === 'default' &&
|
||||||
|
!(edge.targetHandle in targetNode.data.inputs)
|
||||||
|
) {
|
||||||
|
issues.push(
|
||||||
|
`Input field "${edge.target}.${edge.targetHandle}" does not exist`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (issues.length) {
|
||||||
|
delete edges[i];
|
||||||
|
const src = edge.type === 'default' ? edge.sourceHandle : edge.source;
|
||||||
|
const tgt = edge.type === 'default' ? edge.targetHandle : edge.target;
|
||||||
|
warnings.push({
|
||||||
|
message: `Edge "${src} -> ${tgt}" skipped`,
|
||||||
|
issues,
|
||||||
|
data: edge,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return { workflow, warnings };
|
||||||
});
|
});
|
||||||
|
|
||||||
export type Workflow = z.infer<typeof zWorkflow>;
|
export type Workflow = z.infer<typeof zWorkflow>;
|
||||||
|
|
||||||
|
export type ImageMetadataAndWorkflow = {
|
||||||
|
metadata?: CoreMetadata;
|
||||||
|
workflow?: Workflow;
|
||||||
|
};
|
||||||
|
|
||||||
export type CurrentImageNodeData = {
|
export type CurrentImageNodeData = {
|
||||||
id: string;
|
id: string;
|
||||||
type: 'current_image';
|
type: 'current_image';
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { logger } from 'app/logging/logger';
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { NodesState } from '../store/types';
|
import { NodesState } from '../store/types';
|
||||||
import { Workflow, zWorkflowEdge, zWorkflowNode } from '../types/types';
|
import { Workflow, zWorkflowEdge, zWorkflowNode } from '../types/types';
|
||||||
|
import { fromZodError } from 'zod-validation-error';
|
||||||
|
import { parseify } from 'common/util/serialize';
|
||||||
|
|
||||||
export const buildWorkflow = (nodesState: NodesState): Workflow => {
|
export const buildWorkflow = (nodesState: NodesState): Workflow => {
|
||||||
const { workflow: workflowMeta, nodes, edges } = nodesState;
|
const { workflow: workflowMeta, nodes, edges } = nodesState;
|
||||||
@ -14,6 +15,10 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
|
|||||||
nodes.forEach((node) => {
|
nodes.forEach((node) => {
|
||||||
const result = zWorkflowNode.safeParse(node);
|
const result = zWorkflowNode.safeParse(node);
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
|
const { message } = fromZodError(result.error, {
|
||||||
|
prefix: 'Unable to parse node',
|
||||||
|
});
|
||||||
|
logger('nodes').warn({ node: parseify(node) }, message);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
workflow.nodes.push(result.data);
|
workflow.nodes.push(result.data);
|
||||||
@ -22,6 +27,10 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
|
|||||||
edges.forEach((edge) => {
|
edges.forEach((edge) => {
|
||||||
const result = zWorkflowEdge.safeParse(edge);
|
const result = zWorkflowEdge.safeParse(edge);
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
|
const { message } = fromZodError(result.error, {
|
||||||
|
prefix: 'Unable to parse edge',
|
||||||
|
});
|
||||||
|
logger('nodes').warn({ edge: parseify(edge) }, message);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
workflow.edges.push(result.data);
|
workflow.edges.push(result.data);
|
||||||
@ -29,7 +38,3 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
|
|||||||
|
|
||||||
return workflow;
|
return workflow;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const workflowSelector = createSelector(stateSelector, ({ nodes }) =>
|
|
||||||
buildWorkflow(nodes)
|
|
||||||
);
|
|
||||||
|
@ -28,7 +28,6 @@ import {
|
|||||||
UNetInputFieldTemplate,
|
UNetInputFieldTemplate,
|
||||||
VaeInputFieldTemplate,
|
VaeInputFieldTemplate,
|
||||||
VaeModelInputFieldTemplate,
|
VaeModelInputFieldTemplate,
|
||||||
isFieldType,
|
|
||||||
} from '../types/types';
|
} from '../types/types';
|
||||||
|
|
||||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||||
@ -422,9 +421,7 @@ const buildSchedulerInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const getFieldType = (
|
export const getFieldType = (schemaObject: InvocationFieldSchema): string => {
|
||||||
schemaObject: InvocationFieldSchema
|
|
||||||
): FieldType => {
|
|
||||||
let fieldType = '';
|
let fieldType = '';
|
||||||
|
|
||||||
const { ui_type } = schemaObject;
|
const { ui_type } = schemaObject;
|
||||||
@ -460,10 +457,6 @@ export const getFieldType = (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isFieldType(fieldType)) {
|
|
||||||
throw `Field type "${fieldType}" is unknown!`;
|
|
||||||
}
|
|
||||||
|
|
||||||
return fieldType;
|
return fieldType;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -475,12 +468,9 @@ export const getFieldType = (
|
|||||||
export const buildInputFieldTemplate = (
|
export const buildInputFieldTemplate = (
|
||||||
nodeSchema: InvocationSchemaObject,
|
nodeSchema: InvocationSchemaObject,
|
||||||
fieldSchema: InvocationFieldSchema,
|
fieldSchema: InvocationFieldSchema,
|
||||||
name: string
|
name: string,
|
||||||
|
fieldType: FieldType
|
||||||
) => {
|
) => {
|
||||||
// console.log('input', schemaObject);
|
|
||||||
const fieldType = getFieldType(fieldSchema);
|
|
||||||
// console.log('input fieldType', fieldType);
|
|
||||||
|
|
||||||
const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;
|
const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;
|
||||||
|
|
||||||
const extra = {
|
const extra = {
|
||||||
|
@ -0,0 +1,33 @@
|
|||||||
|
import * as png from '@stevebel/png';
|
||||||
|
import {
|
||||||
|
ImageMetadataAndWorkflow,
|
||||||
|
zCoreMetadata,
|
||||||
|
zWorkflow,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { get } from 'lodash-es';
|
||||||
|
|
||||||
|
export const getMetadataAndWorkflowFromImageBlob = async (
|
||||||
|
image: Blob
|
||||||
|
): Promise<ImageMetadataAndWorkflow> => {
|
||||||
|
const data: ImageMetadataAndWorkflow = {};
|
||||||
|
const buffer = await image.arrayBuffer();
|
||||||
|
const text = png.decode(buffer).text;
|
||||||
|
|
||||||
|
const rawMetadata = get(text, 'invokeai_metadata');
|
||||||
|
if (rawMetadata) {
|
||||||
|
const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
|
||||||
|
if (metadataResult.success) {
|
||||||
|
data.metadata = metadataResult.data;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const rawWorkflow = get(text, 'invokeai_workflow');
|
||||||
|
if (rawWorkflow) {
|
||||||
|
const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
|
||||||
|
if (workflowResult.success) {
|
||||||
|
data.workflow = workflowResult.data;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data;
|
||||||
|
};
|
@ -11,10 +11,10 @@ import {
|
|||||||
METADATA_ACCUMULATOR,
|
METADATA_ACCUMULATOR,
|
||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
REFINER_SEAMLESS,
|
|
||||||
SDXL_CANVAS_INPAINT_GRAPH,
|
SDXL_CANVAS_INPAINT_GRAPH,
|
||||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||||
SDXL_MODEL_LOADER,
|
SDXL_MODEL_LOADER,
|
||||||
|
SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
|
||||||
@ -41,7 +41,9 @@ export const addSDXLLoRAsToGraph = (
|
|||||||
// Handle Seamless Plugs
|
// Handle Seamless Plugs
|
||||||
const unetLoaderId = modelLoaderNodeId;
|
const unetLoaderId = modelLoaderNodeId;
|
||||||
let clipLoaderId = modelLoaderNodeId;
|
let clipLoaderId = modelLoaderNodeId;
|
||||||
if ([SEAMLESS, REFINER_SEAMLESS].includes(modelLoaderNodeId)) {
|
if (
|
||||||
|
[SEAMLESS, SDXL_REFINER_INPAINT_CREATE_MASK].includes(modelLoaderNodeId)
|
||||||
|
) {
|
||||||
clipLoaderId = SDXL_MODEL_LOADER;
|
clipLoaderId = SDXL_MODEL_LOADER;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,24 +1,28 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import {
|
import {
|
||||||
|
CreateDenoiseMaskInvocation,
|
||||||
|
ImageDTO,
|
||||||
MetadataAccumulatorInvocation,
|
MetadataAccumulatorInvocation,
|
||||||
SeamlessModeInvocation,
|
SeamlessModeInvocation,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { NonNullableGraph } from '../../types/types';
|
import { NonNullableGraph } from '../../types/types';
|
||||||
import {
|
import {
|
||||||
CANVAS_OUTPUT,
|
CANVAS_OUTPUT,
|
||||||
|
INPAINT_IMAGE_RESIZE_UP,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
MASK_BLUR,
|
MASK_BLUR,
|
||||||
METADATA_ACCUMULATOR,
|
METADATA_ACCUMULATOR,
|
||||||
REFINER_SEAMLESS,
|
|
||||||
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||||
SDXL_CANVAS_INPAINT_GRAPH,
|
SDXL_CANVAS_INPAINT_GRAPH,
|
||||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||||
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||||
SDXL_MODEL_LOADER,
|
SDXL_MODEL_LOADER,
|
||||||
SDXL_REFINER_DENOISE_LATENTS,
|
SDXL_REFINER_DENOISE_LATENTS,
|
||||||
|
SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
SDXL_REFINER_MODEL_LOADER,
|
SDXL_REFINER_MODEL_LOADER,
|
||||||
SDXL_REFINER_NEGATIVE_CONDITIONING,
|
SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||||
SDXL_REFINER_POSITIVE_CONDITIONING,
|
SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||||
|
SDXL_REFINER_SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
|
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
|
||||||
|
|
||||||
@ -26,7 +30,8 @@ export const addSDXLRefinerToGraph = (
|
|||||||
state: RootState,
|
state: RootState,
|
||||||
graph: NonNullableGraph,
|
graph: NonNullableGraph,
|
||||||
baseNodeId: string,
|
baseNodeId: string,
|
||||||
modelLoaderNodeId?: string
|
modelLoaderNodeId?: string,
|
||||||
|
canvasInitImage?: ImageDTO
|
||||||
): void => {
|
): void => {
|
||||||
const {
|
const {
|
||||||
refinerModel,
|
refinerModel,
|
||||||
@ -38,7 +43,12 @@ export const addSDXLRefinerToGraph = (
|
|||||||
refinerStart,
|
refinerStart,
|
||||||
} = state.sdxl;
|
} = state.sdxl;
|
||||||
|
|
||||||
const { seamlessXAxis, seamlessYAxis } = state.generation;
|
const { seamlessXAxis, seamlessYAxis, vaePrecision } = state.generation;
|
||||||
|
const { boundingBoxScaleMethod } = state.canvas;
|
||||||
|
|
||||||
|
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||||
|
boundingBoxScaleMethod
|
||||||
|
);
|
||||||
|
|
||||||
if (!refinerModel) {
|
if (!refinerModel) {
|
||||||
return;
|
return;
|
||||||
@ -108,8 +118,8 @@ export const addSDXLRefinerToGraph = (
|
|||||||
|
|
||||||
// Add Seamless To Refiner
|
// Add Seamless To Refiner
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
graph.nodes[REFINER_SEAMLESS] = {
|
graph.nodes[SDXL_REFINER_SEAMLESS] = {
|
||||||
id: REFINER_SEAMLESS,
|
id: SDXL_REFINER_SEAMLESS,
|
||||||
type: 'seamless',
|
type: 'seamless',
|
||||||
seamless_x: seamlessXAxis,
|
seamless_x: seamlessXAxis,
|
||||||
seamless_y: seamlessYAxis,
|
seamless_y: seamlessYAxis,
|
||||||
@ -122,13 +132,23 @@ export const addSDXLRefinerToGraph = (
|
|||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: REFINER_SEAMLESS,
|
node_id: SDXL_REFINER_SEAMLESS,
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: REFINER_SEAMLESS,
|
node_id: SDXL_REFINER_MODEL_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_REFINER_SEAMLESS,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_REFINER_SEAMLESS,
|
||||||
field: 'unet',
|
field: 'unet',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
@ -203,6 +223,61 @@ export const addSDXLRefinerToGraph = (
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if (
|
||||||
|
graph.id === SDXL_CANVAS_INPAINT_GRAPH ||
|
||||||
|
graph.id === SDXL_CANVAS_OUTPAINT_GRAPH
|
||||||
|
) {
|
||||||
|
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
|
||||||
|
type: 'create_denoise_mask',
|
||||||
|
id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
|
is_intermediate: true,
|
||||||
|
fp32: vaePrecision === 'fp32' ? true : false,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isUsingScaledDimensions) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = {
|
||||||
|
...(graph.nodes[
|
||||||
|
SDXL_REFINER_INPAINT_CREATE_MASK
|
||||||
|
] as CreateDenoiseMaskInvocation),
|
||||||
|
image: canvasInitImage,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
graph.edges.push(
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: MASK_BLUR,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
|
field: 'mask',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
|
field: 'denoise_mask',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_REFINER_DENOISE_LATENTS,
|
||||||
|
field: 'denoise_mask',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
graph.id === SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH ||
|
graph.id === SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH ||
|
||||||
graph.id === SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH
|
graph.id === SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH
|
||||||
@ -213,7 +288,7 @@ export const addSDXLRefinerToGraph = (
|
|||||||
field: 'latents',
|
field: 'latents',
|
||||||
},
|
},
|
||||||
destination: {
|
destination: {
|
||||||
node_id: CANVAS_OUTPUT,
|
node_id: isUsingScaledDimensions ? LATENTS_TO_IMAGE : CANVAS_OUTPUT,
|
||||||
field: 'latents',
|
field: 'latents',
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
@ -229,20 +304,4 @@ export const addSDXLRefinerToGraph = (
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
|
||||||
graph.id === SDXL_CANVAS_INPAINT_GRAPH ||
|
|
||||||
graph.id === SDXL_CANVAS_OUTPAINT_GRAPH
|
|
||||||
) {
|
|
||||||
graph.edges.push({
|
|
||||||
source: {
|
|
||||||
node_id: MASK_BLUR,
|
|
||||||
field: 'image',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: SDXL_REFINER_DENOISE_LATENTS,
|
|
||||||
field: 'mask',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
@ -20,6 +20,7 @@ import {
|
|||||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||||
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||||
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
||||||
|
SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
SDXL_TEXT_TO_IMAGE_GRAPH,
|
SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||||
TEXT_TO_IMAGE_GRAPH,
|
TEXT_TO_IMAGE_GRAPH,
|
||||||
VAE_LOADER,
|
VAE_LOADER,
|
||||||
@ -32,6 +33,7 @@ export const addVAEToGraph = (
|
|||||||
): void => {
|
): void => {
|
||||||
const { vae } = state.generation;
|
const { vae } = state.generation;
|
||||||
const { boundingBoxScaleMethod } = state.canvas;
|
const { boundingBoxScaleMethod } = state.canvas;
|
||||||
|
const { shouldUseSDXLRefiner } = state.sdxl;
|
||||||
|
|
||||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||||
boundingBoxScaleMethod
|
boundingBoxScaleMethod
|
||||||
@ -146,6 +148,24 @@ export const addVAEToGraph = (
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (shouldUseSDXLRefiner) {
|
||||||
|
if (
|
||||||
|
graph.id === SDXL_CANVAS_INPAINT_GRAPH ||
|
||||||
|
graph.id === SDXL_CANVAS_OUTPAINT_GRAPH
|
||||||
|
) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||||
|
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (vae && metadataAccumulator) {
|
if (vae && metadataAccumulator) {
|
||||||
metadataAccumulator.vae = vae;
|
metadataAccumulator.vae = vae;
|
||||||
}
|
}
|
||||||
|
@ -20,10 +20,10 @@ import {
|
|||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
REFINER_SEAMLESS,
|
|
||||||
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||||
SDXL_DENOISE_LATENTS,
|
SDXL_DENOISE_LATENTS,
|
||||||
SDXL_MODEL_LOADER,
|
SDXL_MODEL_LOADER,
|
||||||
|
SDXL_REFINER_SEAMLESS,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
|
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
|
||||||
@ -367,8 +367,15 @@ export const buildCanvasSDXLImageToImageGraph = (
|
|||||||
|
|
||||||
// Add Refiner if enabled
|
// Add Refiner if enabled
|
||||||
if (shouldUseSDXLRefiner) {
|
if (shouldUseSDXLRefiner) {
|
||||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
addSDXLRefinerToGraph(
|
||||||
modelLoaderNodeId = REFINER_SEAMLESS;
|
state,
|
||||||
|
graph,
|
||||||
|
SDXL_DENOISE_LATENTS,
|
||||||
|
modelLoaderNodeId
|
||||||
|
);
|
||||||
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
|
@ -36,10 +36,10 @@ import {
|
|||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
RANDOM_INT,
|
RANDOM_INT,
|
||||||
RANGE_OF_SIZE,
|
RANGE_OF_SIZE,
|
||||||
REFINER_SEAMLESS,
|
|
||||||
SDXL_CANVAS_INPAINT_GRAPH,
|
SDXL_CANVAS_INPAINT_GRAPH,
|
||||||
SDXL_DENOISE_LATENTS,
|
SDXL_DENOISE_LATENTS,
|
||||||
SDXL_MODEL_LOADER,
|
SDXL_MODEL_LOADER,
|
||||||
|
SDXL_REFINER_SEAMLESS,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
|
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
|
||||||
@ -628,9 +628,12 @@ export const buildCanvasSDXLInpaintGraph = (
|
|||||||
state,
|
state,
|
||||||
graph,
|
graph,
|
||||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
modelLoaderNodeId
|
modelLoaderNodeId,
|
||||||
|
canvasInitImage
|
||||||
);
|
);
|
||||||
modelLoaderNodeId = REFINER_SEAMLESS;
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
|
@ -41,10 +41,10 @@ import {
|
|||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
RANDOM_INT,
|
RANDOM_INT,
|
||||||
RANGE_OF_SIZE,
|
RANGE_OF_SIZE,
|
||||||
REFINER_SEAMLESS,
|
|
||||||
SDXL_CANVAS_OUTPAINT_GRAPH,
|
SDXL_CANVAS_OUTPAINT_GRAPH,
|
||||||
SDXL_DENOISE_LATENTS,
|
SDXL_DENOISE_LATENTS,
|
||||||
SDXL_MODEL_LOADER,
|
SDXL_MODEL_LOADER,
|
||||||
|
SDXL_REFINER_SEAMLESS,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
|
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
|
||||||
@ -766,9 +766,12 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
state,
|
state,
|
||||||
graph,
|
graph,
|
||||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
modelLoaderNodeId
|
modelLoaderNodeId,
|
||||||
|
canvasInitImage
|
||||||
);
|
);
|
||||||
modelLoaderNodeId = REFINER_SEAMLESS;
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
|
@ -22,10 +22,10 @@ import {
|
|||||||
NOISE,
|
NOISE,
|
||||||
ONNX_MODEL_LOADER,
|
ONNX_MODEL_LOADER,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
REFINER_SEAMLESS,
|
|
||||||
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||||
SDXL_DENOISE_LATENTS,
|
SDXL_DENOISE_LATENTS,
|
||||||
SDXL_MODEL_LOADER,
|
SDXL_MODEL_LOADER,
|
||||||
|
SDXL_REFINER_SEAMLESS,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
|
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
|
||||||
@ -347,8 +347,15 @@ export const buildCanvasSDXLTextToImageGraph = (
|
|||||||
|
|
||||||
// Add Refiner if enabled
|
// Add Refiner if enabled
|
||||||
if (shouldUseSDXLRefiner) {
|
if (shouldUseSDXLRefiner) {
|
||||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
addSDXLRefinerToGraph(
|
||||||
modelLoaderNodeId = REFINER_SEAMLESS;
|
state,
|
||||||
|
graph,
|
||||||
|
SDXL_DENOISE_LATENTS,
|
||||||
|
modelLoaderNodeId
|
||||||
|
);
|
||||||
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
|
@ -21,11 +21,11 @@ import {
|
|||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
REFINER_SEAMLESS,
|
|
||||||
RESIZE,
|
RESIZE,
|
||||||
SDXL_DENOISE_LATENTS,
|
SDXL_DENOISE_LATENTS,
|
||||||
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
||||||
SDXL_MODEL_LOADER,
|
SDXL_MODEL_LOADER,
|
||||||
|
SDXL_REFINER_SEAMLESS,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
|
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
|
||||||
@ -368,7 +368,9 @@ export const buildLinearSDXLImageToImageGraph = (
|
|||||||
// Add Refiner if enabled
|
// Add Refiner if enabled
|
||||||
if (shouldUseSDXLRefiner) {
|
if (shouldUseSDXLRefiner) {
|
||||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
modelLoaderNodeId = REFINER_SEAMLESS;
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
|
@ -16,9 +16,9 @@ import {
|
|||||||
NEGATIVE_CONDITIONING,
|
NEGATIVE_CONDITIONING,
|
||||||
NOISE,
|
NOISE,
|
||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
REFINER_SEAMLESS,
|
|
||||||
SDXL_DENOISE_LATENTS,
|
SDXL_DENOISE_LATENTS,
|
||||||
SDXL_MODEL_LOADER,
|
SDXL_MODEL_LOADER,
|
||||||
|
SDXL_REFINER_SEAMLESS,
|
||||||
SDXL_TEXT_TO_IMAGE_GRAPH,
|
SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
@ -261,7 +261,9 @@ export const buildLinearSDXLTextToImageGraph = (
|
|||||||
// Add Refiner if enabled
|
// Add Refiner if enabled
|
||||||
if (shouldUseSDXLRefiner) {
|
if (shouldUseSDXLRefiner) {
|
||||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
modelLoaderNodeId = REFINER_SEAMLESS;
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
|
@ -4,6 +4,7 @@ import { cloneDeep, omit, reduce } from 'lodash-es';
|
|||||||
import { Graph } from 'services/api/types';
|
import { Graph } from 'services/api/types';
|
||||||
import { AnyInvocation } from 'services/events/types';
|
import { AnyInvocation } from 'services/events/types';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
import { buildWorkflow } from '../buildWorkflow';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* We need to do special handling for some fields
|
* We need to do special handling for some fields
|
||||||
@ -34,12 +35,13 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
|
|||||||
const { nodes, edges } = nodesState;
|
const { nodes, edges } = nodesState;
|
||||||
|
|
||||||
const filteredNodes = nodes.filter(isInvocationNode);
|
const filteredNodes = nodes.filter(isInvocationNode);
|
||||||
|
const workflowJSON = JSON.stringify(buildWorkflow(nodesState));
|
||||||
|
|
||||||
// Reduce the node editor nodes into invocation graph nodes
|
// Reduce the node editor nodes into invocation graph nodes
|
||||||
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>(
|
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>(
|
||||||
(nodesAccumulator, node) => {
|
(nodesAccumulator, node) => {
|
||||||
const { id, data } = node;
|
const { id, data } = node;
|
||||||
const { type, inputs } = data;
|
const { type, inputs, isIntermediate, embedWorkflow } = data;
|
||||||
|
|
||||||
// Transform each node's inputs to simple key-value pairs
|
// Transform each node's inputs to simple key-value pairs
|
||||||
const transformedInputs = reduce(
|
const transformedInputs = reduce(
|
||||||
@ -58,8 +60,14 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
|
|||||||
type,
|
type,
|
||||||
id,
|
id,
|
||||||
...transformedInputs,
|
...transformedInputs,
|
||||||
|
is_intermediate: isIntermediate,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (embedWorkflow) {
|
||||||
|
// add the workflow to the node
|
||||||
|
Object.assign(graphNode, { workflow: workflowJSON });
|
||||||
|
}
|
||||||
|
|
||||||
// Add it to the nodes object
|
// Add it to the nodes object
|
||||||
Object.assign(nodesAccumulator, {
|
Object.assign(nodesAccumulator, {
|
||||||
[id]: graphNode,
|
[id]: graphNode,
|
||||||
|
@ -56,8 +56,9 @@ export const SDXL_REFINER_POSITIVE_CONDITIONING =
|
|||||||
export const SDXL_REFINER_NEGATIVE_CONDITIONING =
|
export const SDXL_REFINER_NEGATIVE_CONDITIONING =
|
||||||
'sdxl_refiner_negative_conditioning';
|
'sdxl_refiner_negative_conditioning';
|
||||||
export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
|
export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
|
||||||
|
export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
|
||||||
export const SEAMLESS = 'seamless';
|
export const SEAMLESS = 'seamless';
|
||||||
export const REFINER_SEAMLESS = 'refiner_seamless';
|
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
|
||||||
|
|
||||||
// friendly graph ids
|
// friendly graph ids
|
||||||
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
||||||
|
@ -4,10 +4,12 @@ import { reduce } from 'lodash-es';
|
|||||||
import { OpenAPIV3 } from 'openapi-types';
|
import { OpenAPIV3 } from 'openapi-types';
|
||||||
import { AnyInvocationType } from 'services/events/types';
|
import { AnyInvocationType } from 'services/events/types';
|
||||||
import {
|
import {
|
||||||
|
FieldType,
|
||||||
InputFieldTemplate,
|
InputFieldTemplate,
|
||||||
InvocationSchemaObject,
|
InvocationSchemaObject,
|
||||||
InvocationTemplate,
|
InvocationTemplate,
|
||||||
OutputFieldTemplate,
|
OutputFieldTemplate,
|
||||||
|
isFieldType,
|
||||||
isInvocationFieldSchema,
|
isInvocationFieldSchema,
|
||||||
isInvocationOutputSchemaObject,
|
isInvocationOutputSchemaObject,
|
||||||
isInvocationSchemaObject,
|
isInvocationSchemaObject,
|
||||||
@ -16,23 +18,35 @@ import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
|
|||||||
|
|
||||||
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'metadata'];
|
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'metadata'];
|
||||||
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
|
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
|
||||||
|
const RESERVED_FIELD_TYPES = [
|
||||||
|
'WorkflowField',
|
||||||
|
'MetadataField',
|
||||||
|
'IsIntermediate',
|
||||||
|
];
|
||||||
|
|
||||||
const invocationDenylist: AnyInvocationType[] = [
|
const invocationDenylist: AnyInvocationType[] = [
|
||||||
'graph',
|
'graph',
|
||||||
'metadata_accumulator',
|
'metadata_accumulator',
|
||||||
];
|
];
|
||||||
|
|
||||||
const isAllowedInputField = (nodeType: string, fieldName: string) => {
|
const isReservedInputField = (nodeType: string, fieldName: string) => {
|
||||||
if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
|
if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
|
||||||
return false;
|
return true;
|
||||||
}
|
}
|
||||||
if (nodeType === 'collect' && fieldName === 'collection') {
|
if (nodeType === 'collect' && fieldName === 'collection') {
|
||||||
return false;
|
return true;
|
||||||
}
|
}
|
||||||
if (nodeType === 'iterate' && fieldName === 'index') {
|
if (nodeType === 'iterate' && fieldName === 'index') {
|
||||||
return false;
|
return true;
|
||||||
}
|
}
|
||||||
return true;
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
const isReservedFieldType = (fieldType: FieldType) => {
|
||||||
|
if (RESERVED_FIELD_TYPES.includes(fieldType)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
};
|
};
|
||||||
|
|
||||||
const isAllowedOutputField = (nodeType: string, fieldName: string) => {
|
const isAllowedOutputField = (nodeType: string, fieldName: string) => {
|
||||||
@ -54,7 +68,7 @@ export const parseSchema = (
|
|||||||
|
|
||||||
const invocations = filteredSchemas.reduce<
|
const invocations = filteredSchemas.reduce<
|
||||||
Record<string, InvocationTemplate>
|
Record<string, InvocationTemplate>
|
||||||
>((acc, schema) => {
|
>((invocationsAccumulator, schema) => {
|
||||||
const type = schema.properties.type.default;
|
const type = schema.properties.type.default;
|
||||||
const title = schema.title.replace('Invocation', '');
|
const title = schema.title.replace('Invocation', '');
|
||||||
const tags = schema.tags ?? [];
|
const tags = schema.tags ?? [];
|
||||||
@ -62,10 +76,14 @@ export const parseSchema = (
|
|||||||
|
|
||||||
const inputs = reduce(
|
const inputs = reduce(
|
||||||
schema.properties,
|
schema.properties,
|
||||||
(inputsAccumulator, property, propertyName) => {
|
(
|
||||||
if (!isAllowedInputField(type, propertyName)) {
|
inputsAccumulator: Record<string, InputFieldTemplate>,
|
||||||
|
property,
|
||||||
|
propertyName
|
||||||
|
) => {
|
||||||
|
if (isReservedInputField(type, propertyName)) {
|
||||||
logger('nodes').trace(
|
logger('nodes').trace(
|
||||||
{ type, propertyName, property: parseify(property) },
|
{ node: type, fieldName: propertyName, field: parseify(property) },
|
||||||
'Skipped reserved input field'
|
'Skipped reserved input field'
|
||||||
);
|
);
|
||||||
return inputsAccumulator;
|
return inputsAccumulator;
|
||||||
@ -73,37 +91,80 @@ export const parseSchema = (
|
|||||||
|
|
||||||
if (!isInvocationFieldSchema(property)) {
|
if (!isInvocationFieldSchema(property)) {
|
||||||
logger('nodes').warn(
|
logger('nodes').warn(
|
||||||
{ type, propertyName, property: parseify(property) },
|
{ node: type, propertyName, property: parseify(property) },
|
||||||
'Unhandled input property'
|
'Unhandled input property'
|
||||||
);
|
);
|
||||||
return inputsAccumulator;
|
return inputsAccumulator;
|
||||||
}
|
}
|
||||||
|
|
||||||
const field = buildInputFieldTemplate(schema, property, propertyName);
|
const fieldType = getFieldType(property);
|
||||||
|
|
||||||
if (field) {
|
if (!isFieldType(fieldType)) {
|
||||||
inputsAccumulator[propertyName] = field;
|
logger('nodes').warn(
|
||||||
|
{
|
||||||
|
node: type,
|
||||||
|
fieldName: propertyName,
|
||||||
|
fieldType,
|
||||||
|
field: parseify(property),
|
||||||
|
},
|
||||||
|
'Skipping unknown input field type'
|
||||||
|
);
|
||||||
|
return inputsAccumulator;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isReservedFieldType(fieldType)) {
|
||||||
|
logger('nodes').trace(
|
||||||
|
{
|
||||||
|
node: type,
|
||||||
|
fieldName: propertyName,
|
||||||
|
fieldType,
|
||||||
|
field: parseify(property),
|
||||||
|
},
|
||||||
|
'Skipping reserved field type'
|
||||||
|
);
|
||||||
|
return inputsAccumulator;
|
||||||
|
}
|
||||||
|
|
||||||
|
const field = buildInputFieldTemplate(
|
||||||
|
schema,
|
||||||
|
property,
|
||||||
|
propertyName,
|
||||||
|
fieldType
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!field) {
|
||||||
|
logger('nodes').debug(
|
||||||
|
{
|
||||||
|
node: type,
|
||||||
|
fieldName: propertyName,
|
||||||
|
fieldType,
|
||||||
|
field: parseify(property),
|
||||||
|
},
|
||||||
|
'Skipping input field with no template'
|
||||||
|
);
|
||||||
|
return inputsAccumulator;
|
||||||
|
}
|
||||||
|
|
||||||
|
inputsAccumulator[propertyName] = field;
|
||||||
return inputsAccumulator;
|
return inputsAccumulator;
|
||||||
},
|
},
|
||||||
{} as Record<string, InputFieldTemplate>
|
{}
|
||||||
);
|
);
|
||||||
|
|
||||||
const outputSchemaName = schema.output.$ref.split('/').pop();
|
const outputSchemaName = schema.output.$ref.split('/').pop();
|
||||||
|
|
||||||
if (!outputSchemaName) {
|
if (!outputSchemaName) {
|
||||||
logger('nodes').error(
|
logger('nodes').warn(
|
||||||
{ outputRefObject: parseify(schema.output) },
|
{ outputRefObject: parseify(schema.output) },
|
||||||
'No output schema name found in ref object'
|
'No output schema name found in ref object'
|
||||||
);
|
);
|
||||||
throw 'No output schema name found in ref object';
|
return invocationsAccumulator;
|
||||||
}
|
}
|
||||||
|
|
||||||
const outputSchema = openAPI.components?.schemas?.[outputSchemaName];
|
const outputSchema = openAPI.components?.schemas?.[outputSchemaName];
|
||||||
if (!outputSchema) {
|
if (!outputSchema) {
|
||||||
logger('nodes').error({ outputSchemaName }, 'Output schema not found');
|
logger('nodes').warn({ outputSchemaName }, 'Output schema not found');
|
||||||
throw 'Output schema not found';
|
return invocationsAccumulator;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isInvocationOutputSchemaObject(outputSchema)) {
|
if (!isInvocationOutputSchemaObject(outputSchema)) {
|
||||||
@ -111,7 +172,7 @@ export const parseSchema = (
|
|||||||
{ outputSchema: parseify(outputSchema) },
|
{ outputSchema: parseify(outputSchema) },
|
||||||
'Invalid output schema'
|
'Invalid output schema'
|
||||||
);
|
);
|
||||||
throw 'Invalid output schema';
|
return invocationsAccumulator;
|
||||||
}
|
}
|
||||||
|
|
||||||
const outputType = outputSchema.properties.type.default;
|
const outputType = outputSchema.properties.type.default;
|
||||||
@ -136,6 +197,15 @@ export const parseSchema = (
|
|||||||
}
|
}
|
||||||
|
|
||||||
const fieldType = getFieldType(property);
|
const fieldType = getFieldType(property);
|
||||||
|
|
||||||
|
if (!isFieldType(fieldType)) {
|
||||||
|
logger('nodes').warn(
|
||||||
|
{ fieldName: propertyName, fieldType, field: parseify(property) },
|
||||||
|
'Skipping unknown output field type'
|
||||||
|
);
|
||||||
|
return outputsAccumulator;
|
||||||
|
}
|
||||||
|
|
||||||
outputsAccumulator[propertyName] = {
|
outputsAccumulator[propertyName] = {
|
||||||
fieldKind: 'output',
|
fieldKind: 'output',
|
||||||
name: propertyName,
|
name: propertyName,
|
||||||
@ -162,9 +232,9 @@ export const parseSchema = (
|
|||||||
outputType,
|
outputType,
|
||||||
};
|
};
|
||||||
|
|
||||||
Object.assign(acc, { [type]: invocation });
|
Object.assign(invocationsAccumulator, { [type]: invocation });
|
||||||
|
|
||||||
return acc;
|
return invocationsAccumulator;
|
||||||
}, {});
|
}, {});
|
||||||
|
|
||||||
return invocations;
|
return invocations;
|
||||||
|
@ -68,7 +68,7 @@ const ParamControlNetCollapse = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICollapse label="ControlNet" activeLabel={activeLabel}>
|
<IAICollapse label="Control Adapters" activeLabel={activeLabel}>
|
||||||
<Flex sx={{ flexDir: 'column', gap: 2 }}>
|
<Flex sx={{ flexDir: 'column', gap: 2 }}>
|
||||||
<Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
|
import { CoreMetadata } from 'features/nodes/types/types';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { useCallback, useState } from 'react';
|
import { useCallback, useState } from 'react';
|
||||||
import { useAppToaster } from '../../../app/components/Toaster';
|
import { useAppToaster } from '../../../app/components/Toaster';
|
||||||
@ -64,7 +65,7 @@ export const usePreselectedImage = () => {
|
|||||||
if (selectedImage.action === 'useAllParameters') {
|
if (selectedImage.action === 'useAllParameters') {
|
||||||
setImageNameForMetadata(selectedImage?.imageName);
|
setImageNameForMetadata(selectedImage?.imageName);
|
||||||
if (selectedImageMetadata) {
|
if (selectedImageMetadata) {
|
||||||
recallAllParameters(selectedImageMetadata.metadata);
|
recallAllParameters(selectedImageMetadata.metadata as CoreMetadata);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { CoreMetadata } from 'features/nodes/types/types';
|
||||||
import {
|
import {
|
||||||
refinerModelChanged,
|
refinerModelChanged,
|
||||||
setNegativeStylePromptSDXL,
|
setNegativeStylePromptSDXL,
|
||||||
@ -13,7 +14,7 @@ import {
|
|||||||
} from 'features/sdxl/store/sdxlSlice';
|
} from 'features/sdxl/store/sdxlSlice';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { ImageDTO, UnsafeImageMetadata } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||||
import {
|
import {
|
||||||
setCfgScale,
|
setCfgScale,
|
||||||
@ -317,7 +318,7 @@ export const useRecallParameters = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const recallAllParameters = useCallback(
|
const recallAllParameters = useCallback(
|
||||||
(metadata: UnsafeImageMetadata['metadata'] | undefined) => {
|
(metadata: CoreMetadata | undefined) => {
|
||||||
if (!metadata) {
|
if (!metadata) {
|
||||||
allParameterNotSetToast();
|
allParameterNotSetToast();
|
||||||
return;
|
return;
|
||||||
|
@ -29,13 +29,15 @@ export const $projectId = atom<string | undefined>();
|
|||||||
* @example
|
* @example
|
||||||
* const { get, post, del } = $client.get();
|
* const { get, post, del } = $client.get();
|
||||||
*/
|
*/
|
||||||
export const $client = computed([$authToken, $baseUrl, $projectId], (authToken, baseUrl, projectId) =>
|
export const $client = computed(
|
||||||
createClient<paths>({
|
[$authToken, $baseUrl, $projectId],
|
||||||
headers: {
|
(authToken, baseUrl, projectId) =>
|
||||||
...(authToken ? { Authorization: `Bearer ${authToken}` } : {}),
|
createClient<paths>({
|
||||||
...(projectId ? { "project-id": projectId } : {})
|
headers: {
|
||||||
},
|
...(authToken ? { Authorization: `Bearer ${authToken}` } : {}),
|
||||||
// do not include `api/v1` in the base url for this client
|
...(projectId ? { 'project-id': projectId } : {}),
|
||||||
baseUrl: `${baseUrl ?? ''}`,
|
},
|
||||||
})
|
// do not include `api/v1` in the base url for this client
|
||||||
|
baseUrl: `${baseUrl ?? ''}`,
|
||||||
|
})
|
||||||
);
|
);
|
||||||
|
@ -19,7 +19,7 @@ export const boardsApi = api.injectEndpoints({
|
|||||||
*/
|
*/
|
||||||
listBoards: build.query<OffsetPaginatedResults_BoardDTO_, ListBoardsArg>({
|
listBoards: build.query<OffsetPaginatedResults_BoardDTO_, ListBoardsArg>({
|
||||||
query: (arg) => ({ url: 'boards/', params: arg }),
|
query: (arg) => ({ url: 'boards/', params: arg }),
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result) => {
|
||||||
// any list of boards
|
// any list of boards
|
||||||
const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
|
const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ export const boardsApi = api.injectEndpoints({
|
|||||||
url: 'boards/',
|
url: 'boards/',
|
||||||
params: { all: true },
|
params: { all: true },
|
||||||
}),
|
}),
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result) => {
|
||||||
// any list of boards
|
// any list of boards
|
||||||
const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
|
const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
|
||||||
|
|
||||||
|
@ -6,7 +6,8 @@ import {
|
|||||||
IMAGE_CATEGORIES,
|
IMAGE_CATEGORIES,
|
||||||
IMAGE_LIMIT,
|
IMAGE_LIMIT,
|
||||||
} from 'features/gallery/store/types';
|
} from 'features/gallery/store/types';
|
||||||
import { keyBy } from 'lodash';
|
import { getMetadataAndWorkflowFromImageBlob } from 'features/nodes/util/getMetadataAndWorkflowFromImageBlob';
|
||||||
|
import { keyBy } from 'lodash-es';
|
||||||
import { ApiFullTagDescription, LIST_TAG, api } from '..';
|
import { ApiFullTagDescription, LIST_TAG, api } from '..';
|
||||||
import { components, paths } from '../schema';
|
import { components, paths } from '../schema';
|
||||||
import {
|
import {
|
||||||
@ -26,6 +27,7 @@ import {
|
|||||||
imagesSelectors,
|
imagesSelectors,
|
||||||
} from '../util';
|
} from '../util';
|
||||||
import { boardsApi } from './boards';
|
import { boardsApi } from './boards';
|
||||||
|
import { ImageMetadataAndWorkflow } from 'features/nodes/types/types';
|
||||||
|
|
||||||
export const imagesApi = api.injectEndpoints({
|
export const imagesApi = api.injectEndpoints({
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
@ -113,6 +115,20 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
],
|
],
|
||||||
keepUnusedDataFor: 86400, // 24 hours
|
keepUnusedDataFor: 86400, // 24 hours
|
||||||
}),
|
}),
|
||||||
|
getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, string>({
|
||||||
|
query: (image_name) => ({
|
||||||
|
url: `images/i/${image_name}/full`,
|
||||||
|
responseHandler: async (res) => {
|
||||||
|
return await res.blob();
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
providesTags: (result, error, image_name) => [
|
||||||
|
{ type: 'ImageMetadataFromFile', id: image_name },
|
||||||
|
],
|
||||||
|
transformResponse: (response: Blob) =>
|
||||||
|
getMetadataAndWorkflowFromImageBlob(response),
|
||||||
|
keepUnusedDataFor: 86400, // 24 hours
|
||||||
|
}),
|
||||||
clearIntermediates: build.mutation<number, void>({
|
clearIntermediates: build.mutation<number, void>({
|
||||||
query: () => ({ url: `images/clear-intermediates`, method: 'POST' }),
|
query: () => ({ url: `images/clear-intermediates`, method: 'POST' }),
|
||||||
invalidatesTags: ['IntermediatesCount'],
|
invalidatesTags: ['IntermediatesCount'],
|
||||||
@ -357,7 +373,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
],
|
],
|
||||||
async onQueryStarted(
|
async onQueryStarted(
|
||||||
{ imageDTO, session_id },
|
{ imageDTO, session_id },
|
||||||
{ dispatch, queryFulfilled, getState }
|
{ dispatch, queryFulfilled }
|
||||||
) {
|
) {
|
||||||
/**
|
/**
|
||||||
* Cache changes for `changeImageSessionId`:
|
* Cache changes for `changeImageSessionId`:
|
||||||
@ -432,7 +448,9 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
data.updated_image_names.includes(i.image_name)
|
data.updated_image_names.includes(i.image_name)
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!updatedImages[0]) return;
|
if (!updatedImages[0]) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// assume all images are on the same board/category
|
// assume all images are on the same board/category
|
||||||
const categories = getCategories(updatedImages[0]);
|
const categories = getCategories(updatedImages[0]);
|
||||||
@ -544,7 +562,9 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
data.updated_image_names.includes(i.image_name)
|
data.updated_image_names.includes(i.image_name)
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!updatedImages[0]) return;
|
if (!updatedImages[0]) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
// assume all images are on the same board/category
|
// assume all images are on the same board/category
|
||||||
const categories = getCategories(updatedImages[0]);
|
const categories = getCategories(updatedImages[0]);
|
||||||
const boardId = updatedImages[0].board_id;
|
const boardId = updatedImages[0].board_id;
|
||||||
@ -645,17 +665,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
async onQueryStarted(
|
async onQueryStarted(_, { dispatch, queryFulfilled }) {
|
||||||
{
|
|
||||||
file,
|
|
||||||
image_category,
|
|
||||||
is_intermediate,
|
|
||||||
postUploadAction,
|
|
||||||
session_id,
|
|
||||||
board_id,
|
|
||||||
},
|
|
||||||
{ dispatch, queryFulfilled }
|
|
||||||
) {
|
|
||||||
try {
|
try {
|
||||||
/**
|
/**
|
||||||
* NOTE: PESSIMISTIC UPDATE
|
* NOTE: PESSIMISTIC UPDATE
|
||||||
@ -712,7 +722,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
|
|
||||||
deleteBoard: build.mutation<DeleteBoardResult, string>({
|
deleteBoard: build.mutation<DeleteBoardResult, string>({
|
||||||
query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE' }),
|
query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE' }),
|
||||||
invalidatesTags: (result, error, board_id) => [
|
invalidatesTags: () => [
|
||||||
{ type: 'Board', id: LIST_TAG },
|
{ type: 'Board', id: LIST_TAG },
|
||||||
// invalidate the 'No Board' cache
|
// invalidate the 'No Board' cache
|
||||||
{
|
{
|
||||||
@ -732,7 +742,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
{ type: 'BoardImagesTotal', id: 'none' },
|
{ type: 'BoardImagesTotal', id: 'none' },
|
||||||
{ type: 'BoardAssetsTotal', id: 'none' },
|
{ type: 'BoardAssetsTotal', id: 'none' },
|
||||||
],
|
],
|
||||||
async onQueryStarted(board_id, { dispatch, queryFulfilled, getState }) {
|
async onQueryStarted(board_id, { dispatch, queryFulfilled }) {
|
||||||
/**
|
/**
|
||||||
* Cache changes for deleteBoard:
|
* Cache changes for deleteBoard:
|
||||||
* - Update every image in the 'getImageDTO' cache that has the board_id
|
* - Update every image in the 'getImageDTO' cache that has the board_id
|
||||||
@ -802,7 +812,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
params: { include_images: true },
|
params: { include_images: true },
|
||||||
}),
|
}),
|
||||||
invalidatesTags: (result, error, board_id) => [
|
invalidatesTags: () => [
|
||||||
{ type: 'Board', id: LIST_TAG },
|
{ type: 'Board', id: LIST_TAG },
|
||||||
{
|
{
|
||||||
type: 'ImageList',
|
type: 'ImageList',
|
||||||
@ -821,7 +831,7 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
{ type: 'BoardImagesTotal', id: 'none' },
|
{ type: 'BoardImagesTotal', id: 'none' },
|
||||||
{ type: 'BoardAssetsTotal', id: 'none' },
|
{ type: 'BoardAssetsTotal', id: 'none' },
|
||||||
],
|
],
|
||||||
async onQueryStarted(board_id, { dispatch, queryFulfilled, getState }) {
|
async onQueryStarted(board_id, { dispatch, queryFulfilled }) {
|
||||||
/**
|
/**
|
||||||
* Cache changes for deleteBoardAndImages:
|
* Cache changes for deleteBoardAndImages:
|
||||||
* - ~~Remove every image in the 'getImageDTO' cache that has the board_id~~
|
* - ~~Remove every image in the 'getImageDTO' cache that has the board_id~~
|
||||||
@ -1253,9 +1263,8 @@ export const imagesApi = api.injectEndpoints({
|
|||||||
];
|
];
|
||||||
|
|
||||||
result?.removed_image_names.forEach((image_name) => {
|
result?.removed_image_names.forEach((image_name) => {
|
||||||
const board_id = imageDTOs.find(
|
const board_id = imageDTOs.find((i) => i.image_name === image_name)
|
||||||
(i) => i.image_name === image_name
|
?.board_id;
|
||||||
)?.board_id;
|
|
||||||
|
|
||||||
if (!board_id || touchedBoardIds.includes(board_id)) {
|
if (!board_id || touchedBoardIds.includes(board_id)) {
|
||||||
return;
|
return;
|
||||||
@ -1385,4 +1394,5 @@ export const {
|
|||||||
useDeleteBoardMutation,
|
useDeleteBoardMutation,
|
||||||
useStarImagesMutation,
|
useStarImagesMutation,
|
||||||
useUnstarImagesMutation,
|
useUnstarImagesMutation,
|
||||||
|
useGetImageMetadataFromFileQuery,
|
||||||
} = imagesApi;
|
} = imagesApi;
|
||||||
|
@ -178,7 +178,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||||
return `models/?${query}`;
|
return `models/?${query}`;
|
||||||
},
|
},
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result) => {
|
||||||
const tags: ApiFullTagDescription[] = [
|
const tags: ApiFullTagDescription[] = [
|
||||||
{ type: 'OnnxModel', id: LIST_TAG },
|
{ type: 'OnnxModel', id: LIST_TAG },
|
||||||
];
|
];
|
||||||
@ -194,11 +194,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
|
|
||||||
return tags;
|
return tags;
|
||||||
},
|
},
|
||||||
transformResponse: (
|
transformResponse: (response: { models: OnnxModelConfig[] }) => {
|
||||||
response: { models: OnnxModelConfig[] },
|
|
||||||
meta,
|
|
||||||
arg
|
|
||||||
) => {
|
|
||||||
const entities = createModelEntities<OnnxModelConfigEntity>(
|
const entities = createModelEntities<OnnxModelConfigEntity>(
|
||||||
response.models
|
response.models
|
||||||
);
|
);
|
||||||
@ -221,7 +217,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||||
return `models/?${query}`;
|
return `models/?${query}`;
|
||||||
},
|
},
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result) => {
|
||||||
const tags: ApiFullTagDescription[] = [
|
const tags: ApiFullTagDescription[] = [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
];
|
];
|
||||||
@ -237,11 +233,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
|
|
||||||
return tags;
|
return tags;
|
||||||
},
|
},
|
||||||
transformResponse: (
|
transformResponse: (response: { models: MainModelConfig[] }) => {
|
||||||
response: { models: MainModelConfig[] },
|
|
||||||
meta,
|
|
||||||
arg
|
|
||||||
) => {
|
|
||||||
const entities = createModelEntities<MainModelConfigEntity>(
|
const entities = createModelEntities<MainModelConfigEntity>(
|
||||||
response.models
|
response.models
|
||||||
);
|
);
|
||||||
@ -361,7 +353,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
}),
|
}),
|
||||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||||
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result) => {
|
||||||
const tags: ApiFullTagDescription[] = [
|
const tags: ApiFullTagDescription[] = [
|
||||||
{ type: 'LoRAModel', id: LIST_TAG },
|
{ type: 'LoRAModel', id: LIST_TAG },
|
||||||
];
|
];
|
||||||
@ -377,11 +369,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
|
|
||||||
return tags;
|
return tags;
|
||||||
},
|
},
|
||||||
transformResponse: (
|
transformResponse: (response: { models: LoRAModelConfig[] }) => {
|
||||||
response: { models: LoRAModelConfig[] },
|
|
||||||
meta,
|
|
||||||
arg
|
|
||||||
) => {
|
|
||||||
const entities = createModelEntities<LoRAModelConfigEntity>(
|
const entities = createModelEntities<LoRAModelConfigEntity>(
|
||||||
response.models
|
response.models
|
||||||
);
|
);
|
||||||
@ -421,7 +409,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
void
|
void
|
||||||
>({
|
>({
|
||||||
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
|
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result) => {
|
||||||
const tags: ApiFullTagDescription[] = [
|
const tags: ApiFullTagDescription[] = [
|
||||||
{ type: 'ControlNetModel', id: LIST_TAG },
|
{ type: 'ControlNetModel', id: LIST_TAG },
|
||||||
];
|
];
|
||||||
@ -437,11 +425,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
|
|
||||||
return tags;
|
return tags;
|
||||||
},
|
},
|
||||||
transformResponse: (
|
transformResponse: (response: { models: ControlNetModelConfig[] }) => {
|
||||||
response: { models: ControlNetModelConfig[] },
|
|
||||||
meta,
|
|
||||||
arg
|
|
||||||
) => {
|
|
||||||
const entities = createModelEntities<ControlNetModelConfigEntity>(
|
const entities = createModelEntities<ControlNetModelConfigEntity>(
|
||||||
response.models
|
response.models
|
||||||
);
|
);
|
||||||
@ -453,7 +437,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
}),
|
}),
|
||||||
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
|
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
|
||||||
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
|
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result) => {
|
||||||
const tags: ApiFullTagDescription[] = [
|
const tags: ApiFullTagDescription[] = [
|
||||||
{ type: 'VaeModel', id: LIST_TAG },
|
{ type: 'VaeModel', id: LIST_TAG },
|
||||||
];
|
];
|
||||||
@ -469,11 +453,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
|
|
||||||
return tags;
|
return tags;
|
||||||
},
|
},
|
||||||
transformResponse: (
|
transformResponse: (response: { models: VaeModelConfig[] }) => {
|
||||||
response: { models: VaeModelConfig[] },
|
|
||||||
meta,
|
|
||||||
arg
|
|
||||||
) => {
|
|
||||||
const entities = createModelEntities<VaeModelConfigEntity>(
|
const entities = createModelEntities<VaeModelConfigEntity>(
|
||||||
response.models
|
response.models
|
||||||
);
|
);
|
||||||
@ -488,7 +468,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
void
|
void
|
||||||
>({
|
>({
|
||||||
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
|
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result) => {
|
||||||
const tags: ApiFullTagDescription[] = [
|
const tags: ApiFullTagDescription[] = [
|
||||||
{ type: 'TextualInversionModel', id: LIST_TAG },
|
{ type: 'TextualInversionModel', id: LIST_TAG },
|
||||||
];
|
];
|
||||||
@ -504,11 +484,9 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
|
|
||||||
return tags;
|
return tags;
|
||||||
},
|
},
|
||||||
transformResponse: (
|
transformResponse: (response: {
|
||||||
response: { models: TextualInversionModelConfig[] },
|
models: TextualInversionModelConfig[];
|
||||||
meta,
|
}) => {
|
||||||
arg
|
|
||||||
) => {
|
|
||||||
const entities = createModelEntities<TextualInversionModelConfigEntity>(
|
const entities = createModelEntities<TextualInversionModelConfigEntity>(
|
||||||
response.models
|
response.models
|
||||||
);
|
);
|
||||||
@ -525,7 +503,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
url: `/models/search?${folderQueryStr}`,
|
url: `/models/search?${folderQueryStr}`,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result) => {
|
||||||
const tags: ApiFullTagDescription[] = [
|
const tags: ApiFullTagDescription[] = [
|
||||||
{ type: 'ScannedModels', id: LIST_TAG },
|
{ type: 'ScannedModels', id: LIST_TAG },
|
||||||
];
|
];
|
||||||
|
@ -16,6 +16,7 @@ export const tagTypes = [
|
|||||||
'ImageNameList',
|
'ImageNameList',
|
||||||
'ImageList',
|
'ImageList',
|
||||||
'ImageMetadata',
|
'ImageMetadata',
|
||||||
|
'ImageMetadataFromFile',
|
||||||
'Model',
|
'Model',
|
||||||
];
|
];
|
||||||
export type ApiFullTagDescription = FullTagDescription<
|
export type ApiFullTagDescription = FullTagDescription<
|
||||||
@ -39,7 +40,7 @@ const dynamicBaseQuery: BaseQueryFn<
|
|||||||
headers.set('Authorization', `Bearer ${authToken}`);
|
headers.set('Authorization', `Bearer ${authToken}`);
|
||||||
}
|
}
|
||||||
if (projectId) {
|
if (projectId) {
|
||||||
headers.set("project-id", projectId)
|
headers.set('project-id', projectId);
|
||||||
}
|
}
|
||||||
|
|
||||||
return headers;
|
return headers;
|
||||||
|
2154
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
2154
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
File diff suppressed because one or more lines are too long
@ -1,14 +1,16 @@
|
|||||||
import { createAsyncThunk } from '@reduxjs/toolkit';
|
import { createAsyncThunk } from '@reduxjs/toolkit';
|
||||||
|
|
||||||
function getCircularReplacer() {
|
function getCircularReplacer() {
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
const ancestors: Record<string, any>[] = [];
|
const ancestors: Record<string, any>[] = [];
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
return function (key: string, value: any) {
|
return function (key: string, value: any) {
|
||||||
if (typeof value !== 'object' || value === null) {
|
if (typeof value !== 'object' || value === null) {
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
// `this` is the object that value is contained in,
|
// `this` is the object that value is contained in, i.e., its direct parent.
|
||||||
// i.e., its direct parent.
|
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||||
// @ts-ignore
|
// @ts-ignore don't think it's possible to not have TS complain about this...
|
||||||
while (ancestors.length > 0 && ancestors.at(-1) !== this) {
|
while (ancestors.length > 0 && ancestors.at(-1) !== this) {
|
||||||
ancestors.pop();
|
ancestors.pop();
|
||||||
}
|
}
|
||||||
|
@ -73,7 +73,7 @@ export const sessionInvoked = createAsyncThunk<
|
|||||||
>('api/sessionInvoked', async (arg, { rejectWithValue }) => {
|
>('api/sessionInvoked', async (arg, { rejectWithValue }) => {
|
||||||
const { session_id } = arg;
|
const { session_id } = arg;
|
||||||
const { PUT } = $client.get();
|
const { PUT } = $client.get();
|
||||||
const { data, error, response } = await PUT(
|
const { error, response } = await PUT(
|
||||||
'/api/v1/sessions/{session_id}/invoke',
|
'/api/v1/sessions/{session_id}/invoke',
|
||||||
{
|
{
|
||||||
params: { query: { all: true }, path: { session_id } },
|
params: { query: { all: true }, path: { session_id } },
|
||||||
@ -85,6 +85,7 @@ export const sessionInvoked = createAsyncThunk<
|
|||||||
return rejectWithValue({
|
return rejectWithValue({
|
||||||
arg,
|
arg,
|
||||||
status: response.status,
|
status: response.status,
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
error: (error as any).body.detail,
|
error: (error as any).body.detail,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -124,14 +125,11 @@ export const sessionCanceled = createAsyncThunk<
|
|||||||
>('api/sessionCanceled', async (arg, { rejectWithValue }) => {
|
>('api/sessionCanceled', async (arg, { rejectWithValue }) => {
|
||||||
const { session_id } = arg;
|
const { session_id } = arg;
|
||||||
const { DELETE } = $client.get();
|
const { DELETE } = $client.get();
|
||||||
const { data, error, response } = await DELETE(
|
const { data, error } = await DELETE('/api/v1/sessions/{session_id}/invoke', {
|
||||||
'/api/v1/sessions/{session_id}/invoke',
|
params: {
|
||||||
{
|
path: { session_id },
|
||||||
params: {
|
},
|
||||||
path: { session_id },
|
});
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
if (error) {
|
if (error) {
|
||||||
return rejectWithValue({ arg, error });
|
return rejectWithValue({ arg, error });
|
||||||
@ -164,7 +162,7 @@ export const listedSessions = createAsyncThunk<
|
|||||||
>('api/listSessions', async (arg, { rejectWithValue }) => {
|
>('api/listSessions', async (arg, { rejectWithValue }) => {
|
||||||
const { params } = arg;
|
const { params } = arg;
|
||||||
const { GET } = $client.get();
|
const { GET } = $client.get();
|
||||||
const { data, error, response } = await GET('/api/v1/sessions/', {
|
const { data, error } = await GET('/api/v1/sessions/', {
|
||||||
params,
|
params,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -26,15 +26,21 @@ export const getIsImageInDateRange = (
|
|||||||
|
|
||||||
for (let index = 0; index < totalCachedImageDtos.length; index++) {
|
for (let index = 0; index < totalCachedImageDtos.length; index++) {
|
||||||
const image = totalCachedImageDtos[index];
|
const image = totalCachedImageDtos[index];
|
||||||
if (image?.starred) cachedStarredImages.push(image);
|
if (image?.starred) {
|
||||||
if (!image?.starred) cachedUnstarredImages.push(image);
|
cachedStarredImages.push(image);
|
||||||
|
}
|
||||||
|
if (!image?.starred) {
|
||||||
|
cachedUnstarredImages.push(image);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (imageDTO.starred) {
|
if (imageDTO.starred) {
|
||||||
const lastStarredImage =
|
const lastStarredImage =
|
||||||
cachedStarredImages[cachedStarredImages.length - 1];
|
cachedStarredImages[cachedStarredImages.length - 1];
|
||||||
// if starring or already starred, want to look in list of starred images
|
// if starring or already starred, want to look in list of starred images
|
||||||
if (!lastStarredImage) return true; // no starred images showing, so always show this one
|
if (!lastStarredImage) {
|
||||||
|
return true;
|
||||||
|
} // no starred images showing, so always show this one
|
||||||
const createdDate = new Date(imageDTO.created_at);
|
const createdDate = new Date(imageDTO.created_at);
|
||||||
const oldestDate = new Date(lastStarredImage.created_at);
|
const oldestDate = new Date(lastStarredImage.created_at);
|
||||||
return createdDate >= oldestDate;
|
return createdDate >= oldestDate;
|
||||||
@ -42,7 +48,9 @@ export const getIsImageInDateRange = (
|
|||||||
const lastUnstarredImage =
|
const lastUnstarredImage =
|
||||||
cachedUnstarredImages[cachedUnstarredImages.length - 1];
|
cachedUnstarredImages[cachedUnstarredImages.length - 1];
|
||||||
// if unstarring or already unstarred, want to look in list of unstarred images
|
// if unstarring or already unstarred, want to look in list of unstarred images
|
||||||
if (!lastUnstarredImage) return false; // no unstarred images showing, so don't show this one
|
if (!lastUnstarredImage) {
|
||||||
|
return false;
|
||||||
|
} // no unstarred images showing, so don't show this one
|
||||||
const createdDate = new Date(imageDTO.created_at);
|
const createdDate = new Date(imageDTO.created_at);
|
||||||
const oldestDate = new Date(lastUnstarredImage.created_at);
|
const oldestDate = new Date(lastUnstarredImage.created_at);
|
||||||
return createdDate >= oldestDate;
|
return createdDate >= oldestDate;
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user