Merge branch 'main' into feat/taesd

This commit is contained in:
Kevin Turner 2023-08-31 20:12:00 -07:00
commit bc1bce18b0
194 changed files with 6664 additions and 4050 deletions

4
.github/CODEOWNERS vendored
View File

@ -2,7 +2,7 @@
/.github/workflows/ @lstein @blessedcoolant
# documentation
/docs/ @lstein @blessedcoolant @hipsterusername
/docs/ @lstein @blessedcoolant @hipsterusername @Millu
/mkdocs.yml @lstein @blessedcoolant
# nodes
@ -22,7 +22,7 @@
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
# generation, model management, postprocessing
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising @ryanjdick
# front ends
/invokeai/frontend/CLI @lstein

View File

@ -29,12 +29,13 @@ The first set of things we need to do when creating a new Invocation are -
- Create a new class that derives from a predefined parent class called
`BaseInvocation`.
- The name of every Invocation must end with the word `Invocation` in order for
it to be recognized as an Invocation.
- Every Invocation must have a `docstring` that describes what this Invocation
does.
- Every Invocation must have a unique `type` field defined which becomes its
identifier.
- While not strictly required, we suggest every invocation class name ends in
"Invocation", e.g. "CropImageInvocation".
- Every Invocation must use the `@invocation` decorator to provide its unique
invocation type. You may also provide its title, tags and category using the
decorator.
- Invocations are strictly typed. We make use of the native
[typing](https://docs.python.org/3/library/typing.html) library and the
installed [pydantic](https://pydantic-docs.helpmanual.io/) library for
@ -43,12 +44,11 @@ The first set of things we need to do when creating a new Invocation are -
So let us do that.
```python
from typing import Literal
from .baseinvocation import BaseInvocation
from .baseinvocation import BaseInvocation, invocation
@invocation('resize')
class ResizeInvocation(BaseInvocation):
'''Resizes an image'''
type: Literal['resize'] = 'resize'
```
That's great.
@ -62,8 +62,10 @@ our Invocation takes.
### **Inputs**
Every Invocation input is a pydantic `Field` and like everything else should be
strictly typed and defined.
Every Invocation input must be defined using the `InputField` function. This is
a wrapper around the pydantic `Field` function, which handles a few extra things
and provides type hints. Like everything else, this should be strictly typed and
defined.
So let us create these inputs for our Invocation. First up, the `image` input we
need. Generally, we can use standard variable types in Python but InvokeAI
@ -76,55 +78,51 @@ create your own custom field types later in this guide. For now, let's go ahead
and use it.
```python
from typing import Literal, Union
from pydantic import Field
from .baseinvocation import BaseInvocation
from ..models.image import ImageField
from .baseinvocation import BaseInvocation, InputField, invocation
from .primitives import ImageField
@invocation('resize')
class ResizeInvocation(BaseInvocation):
'''Resizes an image'''
type: Literal['resize'] = 'resize'
# Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None)
image: ImageField = InputField(description="The input image")
```
Let us break down our input code.
```python
image: Union[ImageField, None] = Field(description="The input image", default=None)
image: ImageField = InputField(description="The input image")
```
| Part | Value | Description |
| --------- | ---------------------------------------------------- | -------------------------------------------------------------------------------------------------- |
| Name | `image` | The variable that will hold our image |
| Type Hint | `Union[ImageField, None]` | The types for our field. Indicates that the image can either be an `ImageField` type or `None` |
| Field | `Field(description="The input image", default=None)` | The image variable is a field which needs a description and a default value that we set to `None`. |
| Part | Value | Description |
| --------- | ------------------------------------------- | ------------------------------------------------------------------------------- |
| Name | `image` | The variable that will hold our image |
| Type Hint | `ImageField`                                 | The type of our field. Indicates that the image must be an `ImageField` type.    |
| Field | `InputField(description="The input image")` | The image variable is an `InputField` which needs a description. |
Great. Now let us create our other inputs for `width` and `height`
```python
from typing import Literal, Union
from pydantic import Field
from .baseinvocation import BaseInvocation
from ..models.image import ImageField
from .baseinvocation import BaseInvocation, InputField, invocation
from .primitives import ImageField
@invocation('resize')
class ResizeInvocation(BaseInvocation):
'''Resizes an image'''
type: Literal['resize'] = 'resize'
# Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None)
width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
image: ImageField = InputField(description="The input image")
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
```
As you might have noticed, we added two new parameters to the field type for
`width` and `height` called `ge` and `le`. These basically stand for _greater
than or equal to_ and _less than or equal to_. There are various other
parameters for `Field` that you can find in the **pydantic** documentation.
As you might have noticed, we added two new arguments to the `InputField`
definition for `width` and `height`, called `ge` and `le`. They stand for
_greater than or equal to_ and _less than or equal to_.
These impose constraints on those fields, and will raise an exception if the
values do not meet the constraints. Field constraints are provided by
**pydantic**, so anything you see in the **pydantic docs** will work.
**Note:** _Any time it is possible to define constraints for our field, we
should do it so the frontend has more information on how to parse this field._
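For instance, here is a minimal sketch of an extra constraint (assuming your
version of `InputField` passes `multiple_of` through to the underlying pydantic
`Field`, as it does for `ge` and `le`):
```python
# Hypothetical: additionally constrain width to multiples of 8.
width: int = InputField(default=512, ge=64, le=2048, multiple_of=8, description="Width of the new image")
```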
@ -141,20 +139,17 @@ that are provided by it by InvokeAI.
Let us create this function first.
```python
from typing import Literal, Union
from pydantic import Field
from .baseinvocation import BaseInvocation, InvocationContext
from ..models.image import ImageField
from .baseinvocation import BaseInvocation, InputField, invocation
from .primitives import ImageField
@invocation('resize')
class ResizeInvocation(BaseInvocation):
'''Resizes an image'''
type: Literal['resize'] = 'resize'
# Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None)
width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
image: ImageField = InputField(description="The input image")
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
def invoke(self, context: InvocationContext):
pass
@ -173,21 +168,18 @@ all the necessary info related to image outputs. So let us use that.
We will cover how to create your own output types later in this guide.
```python
from typing import Literal, Union
from pydantic import Field
from .baseinvocation import BaseInvocation, InvocationContext
from ..models.image import ImageField
from .baseinvocation import BaseInvocation, InputField, invocation
from .primitives import ImageField
from .image import ImageOutput
@invocation('resize')
class ResizeInvocation(BaseInvocation):
'''Resizes an image'''
type: Literal['resize'] = 'resize'
# Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None)
width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
image: ImageField = InputField(description="The input image")
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
def invoke(self, context: InvocationContext) -> ImageOutput:
pass
@ -195,39 +187,34 @@ class ResizeInvocation(BaseInvocation):
Perfect. Now that we have our Invocation setup, let us do what we want to do.
- We will first load the image. Generally we do this using the `PIL` library but
we can use one of the services provided by InvokeAI to load the image.
- We will first load the image using one of the services provided by InvokeAI.
- We will resize the image using `PIL` to the dimensions given by our inputs.
- We will output this image in the format we set above.
So let's do that.
```python
from typing import Literal, Union
from pydantic import Field
from .baseinvocation import BaseInvocation, InvocationContext
from ..models.image import ImageField, ResourceOrigin, ImageCategory
from .baseinvocation import BaseInvocation, InputField, invocation
from .primitives import ImageField
from .image import ImageOutput
@invocation("resize")
class ResizeInvocation(BaseInvocation):
'''Resizes an image'''
type: Literal['resize'] = 'resize'
"""Resizes an image"""
# Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None)
width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
image: ImageField = InputField(description="The input image")
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
def invoke(self, context: InvocationContext) -> ImageOutput:
# Load the image using InvokeAI's predefined Image Service.
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
# Load the image using InvokeAI's predefined Image Service. Returns the PIL image.
image = context.services.images.get_pil_image(self.image.image_name)
# Resizing the image
# Because we used the above service, we already have a PIL image. So we can simply resize.
resized_image = image.resize((self.width, self.height))
# Preparing the image for output using InvokeAI's predefined Image Service.
# Save the image using InvokeAI's predefined Image Service. Returns the image DTO.
output_image = context.services.images.create(
image=resized_image,
image_origin=ResourceOrigin.INTERNAL,
@ -241,7 +228,6 @@ class ResizeInvocation(BaseInvocation):
return ImageOutput(
image=ImageField(
image_name=output_image.image_name,
image_origin=output_image.image_origin,
),
width=output_image.width,
height=output_image.height,
@ -253,6 +239,20 @@ certain way that the images need to be dispatched in order to be stored and read
correctly. In 99% of the cases when dealing with an image output, you can simply
copy-paste the template above.
### Customization
We can use the `@invocation` decorator to provide some additional info to the
UI, like a custom title, tags and category.
```python
@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations")
class ResizeInvocation(BaseInvocation):
"""Resizes an image"""
image: ImageField = InputField(description="The input image")
...
```
That's it. You made your own **Resize Invocation**.
## Result
@ -271,10 +271,57 @@ new Invocation ready to be used.
![resize node editor](../assets/contributing/resize_node_editor.png)
## Contributing Nodes
Once you've created a Node, the next step is to share it with the community! The best way to do this is to submit a Pull Request to add the Node to the [Community Nodes](nodes/communityNodes) list. If you're not sure how to do that, take a look at our [contributing nodes overview](contributingNodes).
Once you've created a Node, the next step is to share it with the community! The
best way to do this is to submit a Pull Request to add the Node to the
[Community Nodes](nodes/communityNodes) list. If you're not sure how to do that,
take a look at our [contributing nodes overview](contributingNodes).
## Advanced
-->
### Custom Output Types
Like with custom inputs, sometimes you might find yourself needing custom
outputs that InvokeAI does not provide. We can easily set one up.
Now that you are familiar with Invocations and Inputs, let us use that knowledge
to create an output that has an `image` field, a `color` field and a `string`
field.
- An invocation output is a class that derives from the parent class of
`BaseInvocationOutput`.
- All invocation outputs must use the `@invocation_output` decorator to provide
their unique output type.
- Output fields must use the provided `OutputField` function. This is very
similar to the `InputField` function described earlier - it's a wrapper around
`pydantic`'s `Field()`.
- It is not mandatory but we recommend using names ending with `Output` for
output types.
- It is not mandatory but we highly recommend adding a `docstring` to describe
what your output type is for.
Now that we know the basic rules for creating a new output type, let us go ahead
and make it.
```python
from .baseinvocation import BaseInvocationOutput, OutputField, invocation_output
from .primitives import ImageField, ColorField
@invocation_output('image_color_string_output')
class ImageColorStringOutput(BaseInvocationOutput):
'''An output type containing an image, a color and a string'''
image: ImageField = OutputField(description="The image")
color: ColorField = OutputField(description="The color")
text: str = OutputField(description="The string")
```
That's all there is to it.
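As a quick usage sketch, here is a hypothetical invocation that returns this
output type (the invocation type, class name and field values below are made up
for illustration):
```python
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation

@invocation('image_color_string')
class ImageColorStringInvocation(BaseInvocation):
    '''Passes an image through alongside a color and a string'''

    image: ImageField = InputField(description="The input image")

    def invoke(self, context: InvocationContext) -> ImageColorStringOutput:
        # Hypothetical constant values, for illustration only.
        return ImageColorStringOutput(
            image=self.image,
            color=ColorField(r=0, g=0, b=0, a=255),
            text="hello",
        )
```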
<!-- TODO: DANGER - we probably do not want people to create their own field types, because this requires a lot of work on the frontend to accommodate.
### Custom Input Fields
Now that you know how to create your own Invocations, let us dive into slightly
@ -329,172 +376,6 @@ like this.
color: ColorField = Field(default=ColorField(r=0, g=0, b=0, a=0), description='Background color of an image')
```
**Extra Config**
All input fields also take an additional `Config` class that you can use to do
various advanced things, like setting required parameters.
Let us do that for our _ColorField_ and make all the values required, because we
did not define any defaults for our fields.
```python
class ColorField(BaseModel):
'''A field that holds the rgba values of a color'''
r: int = Field(ge=0, le=255, description="The red channel")
g: int = Field(ge=0, le=255, description="The green channel")
b: int = Field(ge=0, le=255, description="The blue channel")
a: int = Field(ge=0, le=255, description="The alpha channel")
class Config:
schema_extra = {"required": ["r", "g", "b", "a"]}
```
Now it becomes mandatory for the user to supply all the values required by our
input field.
We will discuss the `Config` class in extra detail later in this guide and how
you can use it to make your Invocations more robust.
### Custom Output Types
Like with custom inputs, sometimes you might find yourself needing custom
outputs that InvokeAI does not provide. We can easily set one up.
Now that you are familiar with Invocations and Inputs, let us use that knowledge
to put together a custom output type for an Invocation that returns _width_,
_height_ and _background_color_ that we need to create a blank image.
- A custom output type is a class that derives from the parent class of
`BaseInvocationOutput`.
- It is not mandatory but we recommend using names ending with `Output` for
output types. So we'll call our class `BlankImageOutput`
- It is not mandatory but we highly recommend adding a `docstring` to describe
what your output type is for.
- Like Invocations, each output type should have a `type` variable that is
**unique**
Now that we know the basic rules for creating a new output type, let us go ahead
and make it.
```python
from typing import Literal
from pydantic import Field
from .baseinvocation import BaseInvocationOutput
class BlankImageOutput(BaseInvocationOutput):
'''Base output type for creating a blank image'''
type: Literal['blank_image_output'] = 'blank_image_output'
# Outputs
width: int = Field(description='Width of blank image')
height: int = Field(description='Height of blank image')
bg_color: ColorField = Field(description='Background color of blank image')
class Config:
schema_extra = {"required": ["type", "width", "height", "bg_color"]}
```
All set. We now have an output type that requires what we need to create a
blank image. And as you may have noticed, we even used the `Config` class to
ensure the fields are required.
### Custom Configuration
As you might have noticed when making inputs and outputs, we used a class called
`Config` from _pydantic_ to further customize them. Because our inputs and
outputs essentially inherit from _pydantic_'s `BaseModel` class, all
[configuration options](https://docs.pydantic.dev/latest/usage/schema/#schema-customization)
that are valid for _pydantic_ classes are also valid for our inputs and outputs.
You can do the same for your Invocations too but InvokeAI makes our life a
little bit easier on that end.
InvokeAI provides a custom configuration class called `InvocationConfig`
specifically for configuring Invocations. This is exactly the same as the raw
`Config` class from _pydantic_ with some extra stuff on top to help facilitate
parsing of the schema in the frontend UI.
At the moment, this `InvocationConfig` class supports the following
`ui`-related features.
| Config Option | Field Type | Example |
| ------------- | ------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------- |
| type_hints | `Dict[str, Literal["integer", "float", "boolean", "string", "enum", "image", "latents", "model", "control"]]` | `type_hint: "model"` provides type hints related to the model like displaying a list of available models |
| tags | `List[str]` | `tags: ['resize', 'image']` will classify your invocation under the tags of resize and image. |
| title         | `str`                                                                                                           | `title: 'Resize Image'` will rename your invocation to this custom title rather than inferring it from the name of the Invocation class. |
So let us update our `ResizeInvocation` with some extra configuration and see
how that works.
```python
from typing import Literal, Union
from pydantic import Field
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from ..models.image import ImageField, ResourceOrigin, ImageCategory
from .image import ImageOutput
class ResizeInvocation(BaseInvocation):
'''Resizes an image'''
type: Literal['resize'] = 'resize'
# Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None)
width: int = Field(default=512, ge=64, le=2048, description="Width of the new image")
height: int = Field(default=512, ge=64, le=2048, description="Height of the new image")
class Config(InvocationConfig):
    schema_extra = {
        "ui": {
            "tags": ["resize", "image"],
            "title": "My Custom Resize"
        }
    }
def invoke(self, context: InvocationContext) -> ImageOutput:
# Load the image using InvokeAI's predefined Image Service.
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
# Resizing the image
# Because we used the above service, we already have a PIL image. So we can simply resize.
resized_image = image.resize((self.width, self.height))
# Preparing the image for output using InvokeAI's predefined Image Service.
output_image = context.services.images.create(
image=resized_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
# Returning the Image
return ImageOutput(
image=ImageField(
image_name=output_image.image_name,
image_origin=output_image.image_origin,
),
width=output_image.width,
height=output_image.height,
)
```
We have now customized our code to let the frontend know that our Invocation is
tagged with `resize` and `image`. So when the user searches for these particular
words, our Invocation will show up too.
We also set a custom title for our Invocation. So instead of being called
`Resize`, it will be called `My Custom Resize`.
As simple as that.
As time goes by, InvokeAI will further improve and add more customizability for
Invocation configuration. We will have more documentation regarding this at a
later time.
# **[TODO]**
### Custom Components For Frontend
Every backend input type should have a corresponding frontend component so the
@ -513,282 +394,4 @@ Let us create a new component for our custom color field we created above. When
we use a color field, let us say we want the UI to display a color picker for
the user to pick from rather than entering values. That is what we will build
now.
---
<!-- # OLD -- TO BE DELETED OR MOVED LATER
---
## Creating a new invocation
To create a new invocation, either find the appropriate module file in
`/ldm/invoke/app/invocations` to add your invocation to, or create a new one in
that folder. All invocations in that folder will be discovered and made
available to the CLI and API automatically. Invocations make use of
[typing](https://docs.python.org/3/library/typing.html) and
[pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration
into the CLI and API.
An invocation looks like this:
```py
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
# fmt: off
type: Literal["upscale"] = "upscale"
# Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2, 4] = Field(default=2, description="The upscale level")
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["upscaling", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
strength=0.0, # GFPGAN strength
save_original=False,
image_callback=None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
```
Each portion is important to implement correctly.
### Class definition and type
```py
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
type: Literal['upscale'] = 'upscale'
```
All invocations must derive from `BaseInvocation`. They should have a docstring
that declares what they do in a single, short line. They should also have a
`type` with a type hint that's `Literal["command_name"]`, where `command_name`
is what the user will type on the CLI or use in the API to create this
invocation. The `command_name` must be unique. The `type` must be assigned to
the value of the literal in the type hint.
### Inputs
```py
# Inputs
image: Union[ImageField,None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2,4] = Field(default=2, description="The upscale level")
```
Inputs consist of three parts: a name, a type hint, and a `Field` with default,
description, and validation information. For example:
| Part | Value | Description |
| --------- | ------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- |
| Name | `strength` | This field is referred to as `strength` |
| Type Hint | `float` | This field must be of type `float` |
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this
field to be parsed with `None` as a value, which enables linking to previous
invocations. All fields should either provide a default value or allow `None` as
a value, so that they can be overwritten with a linked output from another
invocation.
The special type `ImageField` is also used here. All images are passed as
`ImageField`, which protects them from pydantic validation errors (since images
only ever come from links).
Finally, note that for all linking, the `type` of the linked fields must match.
If the `name` also matches, then the field can be **automatically linked** to a
previous invocation by name and type matching, as in the sketch below.
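For example, `ImageOutput.image` could auto-link to the `image` input of this
hypothetical downstream invocation, because both the name and the `ImageField`
type match:
```py
# Hypothetical downstream invocation (old-style API): its `image` input matches
# ImageOutput's `image` field by name and type, so it can be linked automatically.
class BlurInvocation(BaseInvocation):
    """Blurs an image."""
    type: Literal["blur"] = "blur"

    image: Union[ImageField, None] = Field(description="The input image", default=None)
    radius: float = Field(default=2.0, ge=0, description="The blur radius")
```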
### Config
```py
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["upscaling", "image"],
},
}
```
This is an optional configuration for the invocation. It inherits from
pydantic's model `Config` class, and is used primarily to customize the
autogenerated OpenAPI schema.
The UI relies on the OpenAPI schema in two ways:
- An API client & Typescript types are generated from it. This happens at build
time.
- The node editor parses the schema into a template used by the UI to create the
node editor UI. This parsing happens at runtime.
In this example, a `ui` key has been added to the `schema_extra` dict to provide
some tags for the UI, to facilitate filtering nodes.
See the Schema Generation section below for more information.
### Invoke Function
```py
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
strength=0.0, # GFPGAN strength
save_original=False,
image_callback=None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
```
The `invoke` function is the last portion of an invocation. It is provided an
`InvocationContext` which contains services to perform work as well as a
`session_id` for use as needed. It should return a class with output values that
derives from `BaseInvocationOutput`.
Before being called, the invocation will have all of its fields set from
defaults, inputs, and finally links (overriding in that order).
Assume that this invocation may be running simultaneously with other
invocations, may be running on another machine, or in other interesting
scenarios. If you need functionality, please provide it as a service in the
`InvocationServices` class, and make sure it can be overridden.
### Outputs
```py
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
```
Output classes look like an invocation class without the invoke method. Prefer
to use an existing output class if available, and prefer to name inputs the same
as outputs when possible, to promote automatic invocation linking.
## Schema Generation
Invocation, output and related classes are used to generate an OpenAPI schema.
### Required Properties
The schema generation treats all properties with default values as optional. This
makes sense internally, but when using these classes via the generated
schema, we end up with e.g. the `ImageOutput` class having its `image` property
marked as optional.
We know that this property will always be present, so the additional logic
needed to always check if the property exists adds a lot of extraneous cruft.
To fix this, we can leverage `pydantic`'s
[schema customisation](https://docs.pydantic.dev/usage/schema/#schema-customization)
to mark properties that we know will always be present as required.
Here's that `ImageOutput` class, without the needed schema customisation:
```python
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
```
The OpenAPI schema that results from this `ImageOutput` will have the `type`,
`image`, `width` and `height` properties marked as optional, even though we know
they will always have a value. Here it is again with the customisation added:
```python
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
# Add schema customization
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
```
With the customization in place, the schema will now show these properties as
required, obviating the need for extensive null checks in client code.
See this `pydantic` issue for discussion on this solution:
<https://github.com/pydantic/pydantic/discussions/4577> -->
-->

View File

@ -22,16 +22,14 @@ To use a community node graph, download the `.json` node graph file and load
![b920b710-1882-49a0-8d02-82dff2cca907](https://github.com/invoke-ai/InvokeAI/assets/25252829/7660c1ed-bf7d-4d0a-947f-1fc1679557ba)
![71a91805-fda5-481c-b380-264665703133](https://github.com/invoke-ai/InvokeAI/assets/25252829/f8f6a2ee-2b68-4482-87da-b90221d5c3e2)
<hr>
### Ideal Size
**Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.
**Node Link:** https://github.com/JPPhoto/ideal-size-node
<hr>
--------------------------------
### Retroize
**Description:** Retroize is a collection of nodes for InvokeAI to "Retroize" images. Any image can be given a fresh coat of retro paint with these nodes, either from your gallery or from within the graph itself. It includes nodes to pixelize, quantize, palettize, and ditherize images; as well as to retrieve palettes from existing images.
@ -55,9 +53,50 @@ Generated Prompt: An enchanted weapon will be usable by any character regardless
![9acf5aef-7254-40dd-95b3-8eac431dfab0 (1)](https://github.com/mickr777/InvokeAI/assets/115216705/8496ba09-bcdd-4ff7-8076-ff213b6a1e4c)
--------------------------------
### Load Video Frame
**Description:** This is a video frame image provider + indexer/video creation nodes for hooking up to iterators and ranges and ControlNets and such for InvokeAI node experimentation. Think animation + ControlNet outputs.
**Node Link:** https://github.com/helix4u/load_video_frame
**Example Node Graph:** https://github.com/helix4u/load_video_frame/blob/main/Example_Workflow.json
**Output Example:**
![Example animation](https://github.com/helix4u/load_video_frame/blob/main/testmp4_embed_converted.gif)
[Full mp4 of Example Output test.mp4](https://github.com/helix4u/load_video_frame/blob/main/test.mp4)
--------------------------------
### Oobabooga
**Description:** Asks a local LLM running in Oobabooga's Text-Generation-Webui to write a prompt based on the user input.
**Link:** https://github.com/sammyf/oobabooga-node
**Example:**
"describe a new mystical creature in its natural environment"
*can return*
"The mystical creature I am describing to you is called the "Glimmerwing". It is a majestic, iridescent being that inhabits the depths of the most enchanted forests and glimmering lakes. Its body is covered in shimmering scales that reflect every color of the rainbow, and it has delicate, translucent wings that sparkle like diamonds in the sunlight. The Glimmerwing's home is a crystal-clear lake, surrounded by towering trees with leaves that shimmer like jewels. In this serene environment, the Glimmerwing spends its days swimming gracefully through the water, chasing schools of glittering fish and playing with the gentle ripples of the lake's surface.
As the sun sets, the Glimmerwing perches on a branch of one of the trees, spreading its wings to catch the last rays of light. The creature's scales glow softly, casting a rainbow of colors across the forest floor. The Glimmerwing sings a haunting melody, its voice echoing through the stillness of the night air. Its song is said to have the power to heal the sick and bring peace to troubled souls. Those who are lucky enough to hear the Glimmerwing's song are forever changed by its beauty and grace."
![glimmerwing_small](https://github.com/sammyf/oobabooga-node/assets/42468608/cecdd820-93dd-4c35-abbf-607e001fb2ed)
**Requirement**
a Text-Generation-Webui instance (might work remotely too, but I never tried it) and obviously InvokeAI 3.x
**Note**
This node works best with SDXL models, especially as the style can be described independently of the LLM's output.
--------------------------------
### Example Node Template
**Description:** This node allows you to do super cool things with InvokeAI.

View File

@ -46,6 +46,7 @@ if [[ $(python -c 'from importlib.util import find_spec; print(find_spec("build"
pip install --user build
fi
rm -r ../build
python -m build --wheel --outdir dist/ ../.
# ----------------------

View File

@ -2,15 +2,18 @@
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
import re
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
ClassVar,
Literal,
Mapping,
Optional,
Type,
@ -20,8 +23,8 @@ from typing import (
get_type_hints,
)
from pydantic import BaseModel, Field
from pydantic.fields import Undefined
from pydantic import BaseModel, Field, validator
from pydantic.fields import Undefined, ModelField
from pydantic.typing import NoArgAnyCallable
if TYPE_CHECKING:
@ -141,9 +144,11 @@ class UIType(str, Enum):
# endregion
# region Misc
FilePath = "FilePath"
Enum = "enum"
Scheduler = "Scheduler"
WorkflowField = "WorkflowField"
IsIntermediate = "IsIntermediate"
MetadataField = "MetadataField"
# endregion
@ -365,12 +370,12 @@ def OutputField(
class UIConfigBase(BaseModel):
"""
Provides additional node configuration to the UI.
This is used internally by the @tags and @title decorator logic. You probably want to use those
decorators, though you may add this class to a node definition to specify the title and tags.
This is used internally by the @invocation decorator logic. Do not use this directly.
"""
tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI")
title: Optional[str] = Field(default=None, description="The display name of the node")
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
title: Optional[str] = Field(default=None, description="The node's display name")
category: Optional[str] = Field(default=None, description="The node's category")
class InvocationContext:
@ -383,10 +388,11 @@ class InvocationContext:
class BaseInvocationOutput(BaseModel):
"""Base class for all invocation outputs"""
"""
Base class for all invocation outputs.
# All outputs must include a type name like this:
# type: Literal['your_output_name'] # noqa f821
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
"""
@classmethod
def get_all_subclasses_tuple(cls):
@ -422,12 +428,12 @@ class MissingInputException(Exception):
class BaseInvocation(ABC, BaseModel):
"""A node to process inputs and produce outputs.
May use dependency injection in __init__ to receive providers.
"""
A node to process inputs and produce outputs.
May use dependency injection in __init__ to receive providers.
# All invocations must include a type name like this:
# type: Literal['your_output_name'] # noqa f821
All invocations must use the `@invocation` decorator to provide their unique type.
"""
@classmethod
def get_all_subclasses(cls):
@ -466,6 +472,8 @@ class BaseInvocation(ABC, BaseModel):
schema["title"] = uiconfig.title
if uiconfig and hasattr(uiconfig, "tags"):
schema["tags"] = uiconfig.tags
if uiconfig and hasattr(uiconfig, "category"):
schema["category"] = uiconfig.category
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = list()
schema["required"].extend(["type", "id"])
@ -505,37 +513,110 @@ class BaseInvocation(ABC, BaseModel):
raise MissingInputException(self.__fields__["type"].default, field_name)
return self.invoke(context)
id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = InputField(
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
id: str = Field(
description="The id of this instance of an invocation. Must be unique among all instances of invocations."
)
is_intermediate: bool = InputField(
default=False, description="Whether or not this is an intermediate invocation.", ui_type=UIType.IsIntermediate
)
workflow: Optional[str] = InputField(
default=None,
description="The workflow to save with the image",
ui_type=UIType.WorkflowField,
)
@validator("workflow", pre=True)
def validate_workflow_is_json(cls, v):
if v is None:
return None
try:
json.loads(v)
except json.decoder.JSONDecodeError:
raise ValueError("Workflow must be valid JSON")
return v
UIConfig: ClassVar[Type[UIConfigBase]]
T = TypeVar("T", bound=BaseInvocation)
GenericBaseInvocation = TypeVar("GenericBaseInvocation", bound=BaseInvocation)
def title(title: str) -> Callable[[Type[T]], Type[T]]:
"""Adds a title to the invocation. Use this to override the default title generation, which is based on the class name."""
def invocation(
invocation_type: str, title: Optional[str] = None, tags: Optional[list[str]] = None, category: Optional[str] = None
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
"""
Adds metadata to an invocation.
def wrapper(cls: Type[T]) -> Type[T]:
:param str invocation_type: The type of the invocation. Must be unique among all invocations.
:param Optional[str] title: Adds a title to the invocation. Use if the auto-generated title isn't quite right. Defaults to None.
:param Optional[list[str]] tags: Adds tags to the invocation. Invocations may be searched for by their tags. Defaults to None.
:param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
"""
def wrapper(cls: Type[GenericBaseInvocation]) -> Type[GenericBaseInvocation]:
# Validate invocation types on creation of invocation classes
# TODO: ensure unique?
if re.compile(r"^\S+$").match(invocation_type) is None:
raise ValueError(f'"invocation_type" must consist of non-whitespace characters, got "{invocation_type}"')
# Add OpenAPI schema extras
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
cls.UIConfig.title = title
if title is not None:
cls.UIConfig.title = title
if tags is not None:
cls.UIConfig.tags = tags
if category is not None:
cls.UIConfig.category = category
# Add the invocation type to the pydantic model of the invocation
invocation_type_annotation = Literal[invocation_type] # type: ignore
invocation_type_field = ModelField.infer(
name="type",
value=invocation_type,
annotation=invocation_type_annotation,
class_validators=None,
config=cls.__config__,
)
cls.__fields__.update({"type": invocation_type_field})
cls.__annotations__.update({"type": invocation_type_annotation})
return cls
return wrapper
def tags(*tags: str) -> Callable[[Type[T]], Type[T]]:
"""Adds tags to the invocation. Use this to improve the streamline finding the invocation in the UI."""
GenericBaseInvocationOutput = TypeVar("GenericBaseInvocationOutput", bound=BaseInvocationOutput)
def invocation_output(
output_type: str,
) -> Callable[[Type[GenericBaseInvocationOutput]], Type[GenericBaseInvocationOutput]]:
"""
Adds metadata to an invocation output.
:param str output_type: The type of the invocation output. Must be unique among all invocation outputs.
"""
def wrapper(cls: Type[GenericBaseInvocationOutput]) -> Type[GenericBaseInvocationOutput]:
# Validate output types on creation of invocation output classes
# TODO: ensure unique?
if re.compile(r"^\S+$").match(output_type) is None:
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
# Add the output type to the pydantic model of the invocation output
output_type_annotation = Literal[output_type] # type: ignore
output_type_field = ModelField.infer(
name="type",
value=output_type,
annotation=output_type_annotation,
class_validators=None,
config=cls.__config__,
)
cls.__fields__.update({"type": output_type_field})
cls.__annotations__.update({"type": output_type_annotation})
def wrapper(cls: Type[T]) -> Type[T]:
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
cls.UIConfig.tags = list(tags)
return cls
return wrapper

View File

@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import Literal
import numpy as np
from pydantic import validator
@ -8,17 +7,13 @@ from pydantic import validator
from invokeai.app.invocations.primitives import IntegerCollectionOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
@title("Integer Range")
@tags("collection", "integer", "range")
@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="collections")
class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step"""
type: Literal["range"] = "range"
# Inputs
start: int = InputField(default=0, description="The start of the range")
stop: int = InputField(default=10, description="The stop of the range")
step: int = InputField(default=1, description="The step of the range")
@ -33,14 +28,15 @@ class RangeInvocation(BaseInvocation):
return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
@title("Integer Range of Size")
@tags("range", "integer", "size", "collection")
@invocation(
"range_of_size",
title="Integer Range of Size",
tags=["collection", "integer", "size", "range"],
category="collections",
)
class RangeOfSizeInvocation(BaseInvocation):
"""Creates a range from start to start + size with step"""
type: Literal["range_of_size"] = "range_of_size"
# Inputs
start: int = InputField(default=0, description="The start of the range")
size: int = InputField(default=1, description="The number of values")
step: int = InputField(default=1, description="The step of the range")
@ -49,14 +45,15 @@ class RangeOfSizeInvocation(BaseInvocation):
return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
@title("Random Range")
@tags("range", "integer", "random", "collection")
@invocation(
"random_range",
title="Random Range",
tags=["range", "integer", "random", "collection"],
category="collections",
)
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""
type: Literal["random_range"] = "random_range"
# Inputs
low: int = InputField(default=0, description="The inclusive low value")
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
size: int = InputField(default=1, description="The number of values to generate")

View File

@ -1,6 +1,6 @@
import re
from dataclasses import dataclass
from typing import List, Literal, Union
from typing import List, Union
import torch
from compel import Compel, ReturnedEmbeddingsType
@ -26,8 +26,8 @@ from .baseinvocation import (
InvocationContext,
OutputField,
UIComponent,
tags,
title,
invocation,
invocation_output,
)
from .model import ClipField
@ -44,13 +44,10 @@ class ConditioningFieldData:
# PerpNeg = "perp_neg"
@title("Compel Prompt")
@tags("prompt", "compel")
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning")
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
type: Literal["compel"] = "compel"
prompt: str = InputField(
default="",
description=FieldDescriptions.compel_prompt,
@ -116,16 +113,15 @@ class CompelInvocation(BaseInvocation):
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True,
truncate_long_prompts=False,
)
conjunction = Compel.parse_prompt_string(self.prompt)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
log_tokenization_for_conjunction(conjunction, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
@ -231,7 +227,7 @@ class SDXLPromptInvocationBase:
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
)
@ -240,8 +236,7 @@ class SDXLPromptInvocationBase:
if context.services.configuration.log_tokenization:
# TODO: better logging for and syntax
for prompt_obj in conjunction.prompts:
log_tokenization_for_prompt_object(prompt_obj, tokenizer)
log_tokenization_for_conjunction(conjunction, tokenizer)
# TODO: ask for optimizations? to not run text_encoder twice
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@ -267,13 +262,15 @@ class SDXLPromptInvocationBase:
return c, c_pooled, ec
@title("SDXL Compel Prompt")
@tags("sdxl", "compel", "prompt")
@invocation(
"sdxl_compel_prompt",
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt"
prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea)
original_width: int = InputField(default=1024, description="")
@ -305,6 +302,29 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
add_time_ids = torch.tensor([original_size + crop_coords + target_size])
# [1, 77, 768], [1, 154, 1280]
if c1.shape[1] < c2.shape[1]:
c1 = torch.cat(
[
c1,
torch.zeros(
(c1.shape[0], c2.shape[1] - c1.shape[1], c1.shape[2]), device=c1.device, dtype=c1.dtype
),
],
dim=1,
)
elif c1.shape[1] > c2.shape[1]:
c2 = torch.cat(
[
c2,
torch.zeros(
(c2.shape[0], c1.shape[1] - c2.shape[1], c2.shape[2]), device=c2.device, dtype=c2.dtype
),
],
dim=1,
)
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
@ -326,13 +346,15 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
)
@title("SDXL Refiner Compel Prompt")
@tags("sdxl", "compel", "prompt")
@invocation(
"sdxl_refiner_compel_prompt",
title="SDXL Refiner Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
)
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
style: str = InputField(
default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea
) # TODO: ?
@ -374,20 +396,17 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
)
@invocation_output("clip_skip_output")
class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output"""
type: Literal["clip_skip_output"] = "clip_skip_output"
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@title("CLIP Skip")
@tags("clipskip", "clip", "skip")
@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning")
class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""
type: Literal["clip_skip"] = "clip_skip"
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)

View File

@ -40,8 +40,8 @@ from .baseinvocation import (
InvocationContext,
OutputField,
UIType,
tags,
title,
invocation,
invocation_output,
)
@ -87,23 +87,18 @@ class ControlField(BaseModel):
return v
@invocation_output("control_output")
class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info"""
type: Literal["control_output"] = "control_output"
# Outputs
control: ControlField = OutputField(description=FieldDescriptions.control)
@title("ControlNet")
@tags("controlnet")
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
type: Literal["controlnet"] = "controlnet"
# Inputs
image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField(
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
@ -134,12 +129,10 @@ class ControlNetInvocation(BaseInvocation):
)
@invocation("image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet")
class ImageProcessorInvocation(BaseInvocation):
"""Base class for invocations that preprocess images for ControlNet"""
type: Literal["image_processor"] = "image_processor"
# Inputs
image: ImageField = InputField(description="The image to process")
def run_processor(self, image):
@ -151,11 +144,6 @@ class ImageProcessorInvocation(BaseInvocation):
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
# FIXME: what happened to image metadata?
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
# currently can't see processed image in node UI without a showImage node,
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
image_dto = context.services.images.create(
@ -165,6 +153,7 @@ class ImageProcessorInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
"""Builds an ImageOutput and its ImageField"""
@ -179,14 +168,15 @@ class ImageProcessorInvocation(BaseInvocation):
)
@title("Canny Processor")
@tags("controlnet", "canny")
@invocation(
"canny_image_processor",
title="Canny Processor",
tags=["controlnet", "canny"],
category="controlnet",
)
class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet"""
type: Literal["canny_image_processor"] = "canny_image_processor"
# Input
low_threshold: int = InputField(
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
)
@ -200,14 +190,15 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
return processed_image
@title("HED (softedge) Processor")
@tags("controlnet", "hed", "softedge")
@invocation(
"hed_image_processor",
title="HED (softedge) Processor",
tags=["controlnet", "hed", "softedge"],
category="controlnet",
)
class HedImageProcessorInvocation(ImageProcessorInvocation):
"""Applies HED edge detection to image"""
type: Literal["hed_image_processor"] = "hed_image_processor"
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
# safe not supported in controlnet_aux v0.0.3
@ -227,14 +218,15 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
return processed_image
@title("Lineart Processor")
@tags("controlnet", "lineart")
@invocation(
"lineart_image_processor",
title="Lineart Processor",
tags=["controlnet", "lineart"],
category="controlnet",
)
class LineartImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art processing to image"""
type: Literal["lineart_image_processor"] = "lineart_image_processor"
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
@ -247,14 +239,15 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
return processed_image
@title("Lineart Anime Processor")
@tags("controlnet", "lineart", "anime")
@invocation(
"lineart_anime_image_processor",
title="Lineart Anime Processor",
tags=["controlnet", "lineart", "anime"],
category="controlnet",
)
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art anime processing to image"""
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
@ -268,14 +261,15 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
return processed_image
@title("Openpose Processor")
@tags("controlnet", "openpose", "pose")
@invocation(
"openpose_image_processor",
title="Openpose Processor",
tags=["controlnet", "openpose", "pose"],
category="controlnet",
)
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Openpose processing to image"""
type: Literal["openpose_image_processor"] = "openpose_image_processor"
# Inputs
hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
@ -291,14 +285,15 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
return processed_image
@title("Midas (Depth) Processor")
@tags("controlnet", "midas", "depth")
@invocation(
"midas_depth_image_processor",
title="Midas Depth Processor",
tags=["controlnet", "midas"],
category="controlnet",
)
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image"""
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
# Inputs
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
# depth_and_normal not supported in controlnet_aux v0.0.3
@ -316,14 +311,15 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
return processed_image
@title("Normal BAE Processor")
@tags("controlnet", "normal", "bae")
@invocation(
"normalbae_image_processor",
title="Normal BAE Processor",
tags=["controlnet"],
category="controlnet",
)
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image"""
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
@ -335,14 +331,10 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
return processed_image
@title("MLSD Processor")
@tags("controlnet", "mlsd")
@invocation("mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet")
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image"""
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
@ -360,14 +352,10 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
return processed_image
@title("PIDI Processor")
@tags("controlnet", "pidi")
@invocation("pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet")
class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image"""
type: Literal["pidi_image_processor"] = "pidi_image_processor"
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
@ -385,14 +373,15 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
return processed_image
@title("Content Shuffle Processor")
@tags("controlnet", "contentshuffle")
@invocation(
"content_shuffle_image_processor",
title="Content Shuffle Processor",
tags=["controlnet", "contentshuffle"],
category="controlnet",
)
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image"""
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
# Inputs
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
@ -413,27 +402,30 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
@title("Zoe (Depth) Processor")
@tags("controlnet", "zoe", "depth")
@invocation(
"zoe_depth_image_processor",
title="Zoe (Depth) Processor",
tags=["controlnet", "zoe", "depth"],
category="controlnet",
)
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
def run_processor(self, image):
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
return processed_image
@title("Mediapipe Face Processor")
@tags("controlnet", "mediapipe", "face")
@invocation(
"mediapipe_face_processor",
title="Mediapipe Face Processor",
tags=["controlnet", "mediapipe", "face"],
category="controlnet",
)
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image"""
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
# Inputs
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
@ -447,14 +439,15 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
return processed_image
@title("Leres (Depth) Processor")
@tags("controlnet", "leres", "depth")
@invocation(
"leres_image_processor",
title="Leres (Depth) Processor",
tags=["controlnet", "leres", "depth"],
category="controlnet",
)
class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image"""
type: Literal["leres_image_processor"] = "leres_image_processor"
# Inputs
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
boost: bool = InputField(default=False, description="Whether to use boost mode")
@ -474,14 +467,15 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
return processed_image
@title("Tile Resample Processor")
@tags("controlnet", "tile")
@invocation(
"tile_image_processor",
title="Tile Resample Processor",
tags=["controlnet", "tile"],
category="controlnet",
)
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
"""Tile resampler processor"""
type: Literal["tile_image_processor"] = "tile_image_processor"
# Inputs
# res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
@ -512,13 +506,15 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
return processed_image
@title("Segment Anything Processor")
@tags("controlnet", "segmentanything")
@invocation(
"segment_anything_processor",
title="Segment Anything Processor",
tags=["controlnet", "segmentanything"],
category="controlnet",
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image"""
type: Literal["segment_anything_processor"] = "segment_anything_processor"
def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(

View File

@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
import cv2 as cv
import numpy
@ -8,17 +7,18 @@ from PIL import Image, ImageOps
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
@title("OpenCV Inpaint")
@tags("opencv", "inpaint")
@invocation(
"cv_inpaint",
title="OpenCV Inpaint",
tags=["opencv", "inpaint"],
category="inpaint",
)
class CvInpaintInvocation(BaseInvocation):
"""Simple inpaint using opencv."""
type: Literal["cv_inpaint"] = "cv_inpaint"
# Inputs
image: ImageField = InputField(description="The image to inpaint")
mask: ImageField = InputField(description="The mask to use when inpainting")
@ -45,6 +45,7 @@ class CvInpaintInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -13,18 +13,13 @@ from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
@title("Show Image")
@tags("image")
@invocation("show_image", title="Show Image", tags=["image"], category="image")
class ShowImageInvocation(BaseInvocation):
"""Displays a provided image, and passes it forward in the pipeline."""
"""Displays a provided image using the OS image viewer, and passes it forward in the pipeline."""
# Metadata
type: Literal["show_image"] = "show_image"
# Inputs
image: ImageField = InputField(description="The image to show")
def invoke(self, context: InvocationContext) -> ImageOutput:
@ -41,15 +36,10 @@ class ShowImageInvocation(BaseInvocation):
)
@title("Blank Image")
@tags("image")
@invocation("blank_image", title="Blank Image", tags=["image"], category="image")
class BlankImageInvocation(BaseInvocation):
"""Creates a blank image and forwards it to the pipeline"""
# Metadata
type: Literal["blank_image"] = "blank_image"
# Inputs
width: int = InputField(default=512, description="The width of the image")
height: int = InputField(default=512, description="The height of the image")
mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image")
@ -65,6 +55,7 @@ class BlankImageInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -74,15 +65,10 @@ class BlankImageInvocation(BaseInvocation):
)
@title("Crop Image")
@tags("image", "crop")
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image")
class ImageCropInvocation(BaseInvocation):
"""Crops an image to a specified box. The box can be outside of the image."""
# Metadata
type: Literal["img_crop"] = "img_crop"
# Inputs
image: ImageField = InputField(description="The image to crop")
x: int = InputField(default=0, description="The left x coordinate of the crop rectangle")
y: int = InputField(default=0, description="The top y coordinate of the crop rectangle")
@ -102,6 +88,7 @@ class ImageCropInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -111,15 +98,10 @@ class ImageCropInvocation(BaseInvocation):
)
@title("Paste Image")
@tags("image", "paste")
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image")
class ImagePasteInvocation(BaseInvocation):
"""Pastes an image into another image."""
# Metadata
type: Literal["img_paste"] = "img_paste"
# Inputs
base_image: ImageField = InputField(description="The base image")
image: ImageField = InputField(description="The image to paste")
mask: Optional[ImageField] = InputField(
@ -154,6 +136,7 @@ class ImagePasteInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -163,15 +146,10 @@ class ImagePasteInvocation(BaseInvocation):
)
@title("Mask from Alpha")
@tags("image", "mask")
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image")
class MaskFromAlphaInvocation(BaseInvocation):
"""Extracts the alpha channel of an image as a mask."""
# Metadata
type: Literal["tomask"] = "tomask"
# Inputs
image: ImageField = InputField(description="The image to create the mask from")
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
@ -189,6 +167,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -198,15 +177,10 @@ class MaskFromAlphaInvocation(BaseInvocation):
)
@title("Multiply Images")
@tags("image", "multiply")
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image")
class ImageMultiplyInvocation(BaseInvocation):
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
# Metadata
type: Literal["img_mul"] = "img_mul"
# Inputs
image1: ImageField = InputField(description="The first image to multiply")
image2: ImageField = InputField(description="The second image to multiply")
@ -223,6 +197,7 @@ class ImageMultiplyInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -235,15 +210,10 @@ class ImageMultiplyInvocation(BaseInvocation):
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
@title("Extract Image Channel")
@tags("image", "channel")
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image")
class ImageChannelInvocation(BaseInvocation):
"""Gets a channel from an image."""
# Metadata
type: Literal["img_chan"] = "img_chan"
# Inputs
image: ImageField = InputField(description="The image to get the channel from")
channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
@ -259,6 +229,7 @@ class ImageChannelInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -271,15 +242,10 @@ class ImageChannelInvocation(BaseInvocation):
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
@title("Convert Image Mode")
@tags("image", "convert")
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image")
class ImageConvertInvocation(BaseInvocation):
"""Converts an image to a different mode."""
# Metadata
type: Literal["img_conv"] = "img_conv"
# Inputs
image: ImageField = InputField(description="The image to convert")
mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
@ -295,6 +261,7 @@ class ImageConvertInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -304,15 +271,10 @@ class ImageConvertInvocation(BaseInvocation):
)
@title("Blur Image")
@tags("image", "blur")
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image")
class ImageBlurInvocation(BaseInvocation):
"""Blurs an image"""
# Metadata
type: Literal["img_blur"] = "img_blur"
# Inputs
image: ImageField = InputField(description="The image to blur")
radius: float = InputField(default=8.0, ge=0, description="The blur radius")
# Metadata
@ -333,6 +295,7 @@ class ImageBlurInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -362,19 +325,17 @@ PIL_RESAMPLING_MAP = {
}
@title("Resize Image")
@tags("image", "resize")
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image")
class ImageResizeInvocation(BaseInvocation):
"""Resizes an image to specific dimensions"""
# Metadata
type: Literal["img_resize"] = "img_resize"
# Inputs
image: ImageField = InputField(description="The image to resize")
width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
metadata: Optional[CoreMetadata] = InputField(
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -393,6 +354,8 @@ class ImageResizeInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(
@ -402,15 +365,10 @@ class ImageResizeInvocation(BaseInvocation):
)
@title("Scale Image")
@tags("image", "scale")
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image")
class ImageScaleInvocation(BaseInvocation):
"""Scales an image by a factor"""
# Metadata
type: Literal["img_scale"] = "img_scale"
# Inputs
image: ImageField = InputField(description="The image to scale")
scale_factor: float = InputField(
default=2.0,
@ -438,6 +396,7 @@ class ImageScaleInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -447,15 +406,10 @@ class ImageScaleInvocation(BaseInvocation):
)
@title("Lerp Image")
@tags("image", "lerp")
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image")
class ImageLerpInvocation(BaseInvocation):
"""Linear interpolation of all pixels of an image"""
# Metadata
type: Literal["img_lerp"] = "img_lerp"
# Inputs
image: ImageField = InputField(description="The image to lerp")
min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
@ -475,6 +429,7 @@ class ImageLerpInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -484,15 +439,10 @@ class ImageLerpInvocation(BaseInvocation):
)
@title("Inverse Lerp Image")
@tags("image", "ilerp")
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image")
class ImageInverseLerpInvocation(BaseInvocation):
"""Inverse linear interpolation of all pixels of an image"""
# Metadata
type: Literal["img_ilerp"] = "img_ilerp"
# Inputs
image: ImageField = InputField(description="The image to lerp")
min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
@ -512,6 +462,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -521,15 +472,10 @@ class ImageInverseLerpInvocation(BaseInvocation):
)
@title("Blur NSFW Image")
@tags("image", "nsfw")
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image")
class ImageNSFWBlurInvocation(BaseInvocation):
"""Add blur to NSFW-flagged images"""
# Metadata
type: Literal["img_nsfw"] = "img_nsfw"
# Inputs
image: ImageField = InputField(description="The image to check")
metadata: Optional[CoreMetadata] = InputField(
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
@ -555,6 +501,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(
@ -570,15 +517,10 @@ class ImageNSFWBlurInvocation(BaseInvocation):
return caution.resize((caution.width // 2, caution.height // 2))
@title("Add Invisible Watermark")
@tags("image", "watermark")
@invocation("img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image")
class ImageWatermarkInvocation(BaseInvocation):
"""Add an invisible watermark to an image"""
# Metadata
type: Literal["img_watermark"] = "img_watermark"
# Inputs
image: ImageField = InputField(description="The image to check")
text: str = InputField(default="InvokeAI", description="Watermark text")
metadata: Optional[CoreMetadata] = InputField(
@ -596,6 +538,7 @@ class ImageWatermarkInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(
@ -605,14 +548,10 @@ class ImageWatermarkInvocation(BaseInvocation):
)
@title("Mask Edge")
@tags("image", "mask", "inpaint")
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image")
class MaskEdgeInvocation(BaseInvocation):
"""Applies an edge mask to an image"""
type: Literal["mask_edge"] = "mask_edge"
# Inputs
image: ImageField = InputField(description="The image to apply the mask to")
edge_size: int = InputField(description="The size of the edge")
edge_blur: int = InputField(description="The amount of blur on the edge")
@ -644,6 +583,7 @@ class MaskEdgeInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -653,14 +593,10 @@ class MaskEdgeInvocation(BaseInvocation):
)
@title("Combine Mask")
@tags("image", "mask", "multiply")
@invocation("mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image")
class MaskCombineInvocation(BaseInvocation):
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
type: Literal["mask_combine"] = "mask_combine"
# Inputs
mask1: ImageField = InputField(description="The first mask to combine")
mask2: ImageField = InputField(description="The second mask to combine")
@ -677,6 +613,7 @@ class MaskCombineInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -686,17 +623,13 @@ class MaskCombineInvocation(BaseInvocation):
)
@title("Color Correct")
@tags("image", "color")
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image")
class ColorCorrectInvocation(BaseInvocation):
"""
Shifts the colors of a target image to match the reference image, optionally
using a mask to only color-correct certain regions of the target image.
"""
type: Literal["color_correct"] = "color_correct"
# Inputs
image: ImageField = InputField(description="The image to color-correct")
reference: ImageField = InputField(description="Reference image for color-correction")
mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
@ -785,6 +718,7 @@ class ColorCorrectInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -794,14 +728,10 @@ class ColorCorrectInvocation(BaseInvocation):
)
@title("Image Hue Adjustment")
@tags("image", "hue", "hsl")
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image")
class ImageHueAdjustmentInvocation(BaseInvocation):
"""Adjusts the Hue of an image."""
type: Literal["img_hue_adjust"] = "img_hue_adjust"
# Inputs
image: ImageField = InputField(description="The image to adjust")
hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
@ -827,6 +757,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
workflow=self.workflow,
)
return ImageOutput(
@ -838,14 +769,15 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
)
@title("Image Luminosity Adjustment")
@tags("image", "luminosity", "hsl")
@invocation(
"img_luminosity_adjust",
title="Adjust Image Luminosity",
tags=["image", "luminosity", "hsl"],
category="image",
)
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
"""Adjusts the Luminosity (Value) of an image."""
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
# Inputs
image: ImageField = InputField(description="The image to adjust")
luminosity: float = InputField(
default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
@ -877,6 +809,7 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
workflow=self.workflow,
)
return ImageOutput(
@ -888,14 +821,15 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
)
@title("Image Saturation Adjustment")
@tags("image", "saturation", "hsl")
@invocation(
"img_saturation_adjust",
title="Adjust Image Saturation",
tags=["image", "saturation", "hsl"],
category="image",
)
class ImageSaturationAdjustmentInvocation(BaseInvocation):
"""Adjusts the Saturation of an image."""
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
# Inputs
image: ImageField = InputField(description="The image to adjust")
saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
@ -925,6 +859,7 @@ class ImageSaturationAdjustmentInvocation(BaseInvocation):
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -12,7 +12,7 @@ from invokeai.backend.image_util.lama import LaMA
from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
def infill_methods() -> list[str]:
@ -116,14 +116,10 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
return si
@title("Solid Color Infill")
@tags("image", "inpaint")
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint")
class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba"
# Inputs
image: ImageField = InputField(description="The image to infill")
color: ColorField = InputField(
default=ColorField(r=127, g=127, b=127, a=255),
@ -145,6 +141,7 @@ class InfillColorInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -154,14 +151,10 @@ class InfillColorInvocation(BaseInvocation):
)
@title("Tile Infill")
@tags("image", "inpaint")
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint")
class InfillTileInvocation(BaseInvocation):
"""Infills transparent areas of an image with tiles of the image"""
type: Literal["infill_tile"] = "infill_tile"
# Input
image: ImageField = InputField(description="The image to infill")
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
seed: int = InputField(
@ -184,6 +177,7 @@ class InfillTileInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -193,14 +187,10 @@ class InfillTileInvocation(BaseInvocation):
)
@title("PatchMatch Infill")
@tags("image", "inpaint")
@invocation("infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint")
class InfillPatchMatchInvocation(BaseInvocation):
"""Infills transparent areas of an image using the PatchMatch algorithm"""
type: Literal["infill_patchmatch"] = "infill_patchmatch"
# Inputs
image: ImageField = InputField(description="The image to infill")
def invoke(self, context: InvocationContext) -> ImageOutput:
@ -218,6 +208,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -227,14 +218,10 @@ class InfillPatchMatchInvocation(BaseInvocation):
)
@title("LaMa Infill")
@tags("image", "inpaint")
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint")
class LaMaInfillInvocation(BaseInvocation):
"""Infills transparent areas of an image using the LaMa model"""
type: Literal["infill_lama"] = "infill_lama"
# Inputs
image: ImageField = InputField(description="The image to infill")
def invoke(self, context: InvocationContext) -> ImageOutput:

View File

@ -23,6 +23,8 @@ from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import (
DenoiseMaskField,
DenoiseMaskOutput,
ImageField,
ImageOutput,
LatentsField,
@ -34,13 +36,15 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
tags,
title,
invocation,
invocation_output,
)
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
@ -48,6 +52,7 @@ from .model import ModelInfo, UNetField, VaeField
from ..models.image import ImageCategory, ResourceOrigin
from ...backend.model_management import BaseModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.seamless import set_seamless
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,
@ -66,6 +71,84 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@invocation_output("scheduler_output")
class SchedulerOutput(BaseInvocationOutput):
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents")
class SchedulerInvocation(BaseInvocation):
"""Selects a scheduler."""
scheduler: SAMPLER_NAME_VALUES = InputField(
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
)
def invoke(self, context: InvocationContext) -> SchedulerOutput:
return SchedulerOutput(scheduler=self.scheduler)
@invocation("create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents")
class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="The mask to use when denoising", ui_order=2)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32, ui_order=4)
def prep_mask_tensor(self, mask_image):
if mask_image.mode != "L":
mask_image = mask_image.convert("L")
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0)
# if shape is not None:
# mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
return mask_tensor
@torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
if self.image is not None:
image = context.services.images.get_pil_image(self.image.image_name)
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image.dim() == 3:
image = image.unsqueeze(0)
else:
image = None
mask = self.prep_mask_tensor(
context.services.images.get_pil_image(self.mask.image_name),
)
if image is not None:
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
context=context,
)
img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents"
context.services.latents.save(masked_latents_name, masked_latents)
else:
masked_latents_name = None
mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
context.services.latents.save(mask_name, mask)
return DenoiseMaskOutput(
denoise_mask=DenoiseMaskField(
mask_name=mask_name,
masked_latents_name=masked_latents_name,
),
)
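The masked-latents branch above binarizes the resized mask with `torch.where(img_mask < 0.5, 0.0, 1.0)` before multiplying, so fully masked pixels are zeroed out prior to VAE encoding. A tiny standalone check of that tensor logic:

```python
import torch

img_mask = torch.tensor([[0.0, 0.3, 0.7, 1.0]])  # grayscale mask values in [0, 1]
keep = torch.where(img_mask < 0.5, 0.0, 1.0)     # tensor([[0., 0., 1., 1.]])
image = torch.tensor([[0.2, 0.4, 0.6, 0.8]])
masked_image = image * keep                      # tensor([[0.0, 0.0, 0.6, 0.8]])
```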
def get_scheduler(
context: InvocationContext,
scheduler_info: ModelInfo,
@ -100,14 +183,15 @@ def get_scheduler(
return scheduler
@title("Denoise Latents")
@tags("latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l")
@invocation(
"denoise_latents",
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
type: Literal["denoise_latents"] = "denoise_latents"
# Inputs
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
)
@ -128,10 +212,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
control: Union[ControlField, list[ControlField]] = InputField(
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
)
latents: Optional[LatentsField] = InputField(
description=FieldDescriptions.latents, input=Input.Connection, ui_order=4
)
mask: Optional[ImageField] = InputField(
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.mask,
)
@ -311,52 +393,46 @@ class DenoiseLatentsInvocation(BaseInvocation):
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
num_inference_steps = steps
if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(num_inference_steps, device="cpu")
scheduler.set_timesteps(steps, device="cpu")
timesteps = scheduler.timesteps.to(device=device)
else:
scheduler.set_timesteps(num_inference_steps, device=device)
scheduler.set_timesteps(steps, device=device)
timesteps = scheduler.timesteps
# apply denoising_start
# skip greater order timesteps
_timesteps = timesteps[:: scheduler.order]
# get start timestep index
t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, timesteps)))
timesteps = timesteps[t_start_idx:]
if scheduler.order == 2 and t_start_idx > 0:
timesteps = timesteps[1:]
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))
# save start timestep to apply noise
init_timestep = timesteps[:1]
# apply denoising_end
# get end timestep index
t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, timesteps)))
if scheduler.order == 2 and t_end_idx > 0:
t_end_idx += 1
timesteps = timesteps[:t_end_idx]
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))
# calculate step count based on scheduler order
num_inference_steps = len(timesteps)
if scheduler.order == 2:
num_inference_steps += num_inference_steps % 2
num_inference_steps = num_inference_steps // 2
# apply order to indexes
t_start_idx *= scheduler.order
t_end_idx *= scheduler.order
init_timestep = timesteps[t_start_idx : t_start_idx + 1]
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order
return num_inference_steps, timesteps, init_timestep
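The rewritten windowing computes the start and end indices on the de-duplicated timestep list (`timesteps[::scheduler.order]`) and only then scales both indices by `scheduler.order`, which keeps second-order schedulers from slicing mid-pair. A worked example with illustrative numbers (a hand-made duplicated-timestep list, not a real scheduler):

```python
import torch

# A 2nd-order scheduler asked for 4 steps typically repeats each timestep:
timesteps = torch.tensor([800, 800, 600, 600, 400, 400, 200, 200])
order = 2
unique = timesteps[::order]  # tensor([800, 600, 400, 200])

# denoising_start=0.5 on a 1000-step training schedule -> t_start_val = 500
t_start_idx = len([t for t in unique if t >= 500])            # 2
# denoising_end=1.0 -> t_end_val = 0
t_end_idx = len([t for t in unique[t_start_idx:] if t >= 0])  # 2

t_start_idx *= order  # 4
t_end_idx *= order    # 4
init_timestep = timesteps[t_start_idx : t_start_idx + 1]    # tensor([400])
window = timesteps[t_start_idx : t_start_idx + t_end_idx]   # tensor([400, 400, 200, 200])
num_inference_steps = len(window) // order                  # 2
```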
def prep_mask_tensor(self, mask, context, latents):
if mask is None:
return None
def prep_inpaint_mask(self, context, latents):
if self.denoise_mask is None:
return None, None
mask_image = context.services.images.get_pil_image(mask.image_name)
if mask_image.mode != "L":
# FIXME: why do we get passed an RGB image here? We can only use single-channel.
mask_image = mask_image.convert("L")
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0)
mask_tensor = tv_resize(mask_tensor, latents.shape[-2:], T.InterpolationMode.BILINEAR)
return 1 - mask_tensor
mask = context.services.latents.get(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
if self.denoise_mask.masked_latents_name is not None:
masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name)
else:
masked_latents = None
return 1 - mask, masked_latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
@ -371,13 +447,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
latents = context.services.latents.get(self.latents.latents_name)
if seed is None:
seed = self.latents.seed
else:
if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
elif noise is not None:
latents = torch.zeros_like(noise)
else:
raise Exception("'latents' or 'noise' must be provided!")
if seed is None:
seed = 0
mask = self.prep_mask_tensor(self.mask, context, latents)
mask, masked_latents = self.prep_inpaint_mask(context, latents)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@ -402,12 +484,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
), set_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet:
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
if mask is not None:
mask = mask.to(device=unet.device, dtype=unet.dtype)
if masked_latents is not None:
masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
@ -444,6 +528,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
noise=noise,
seed=seed,
mask=mask,
masked_latents=masked_latents,
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
@ -459,14 +544,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
@title("Latents to Image")
@tags("latents", "image", "vae", "l2i")
@invocation("l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents")
class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
type: Literal["l2i"] = "l2i"
# Inputs
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
@ -492,7 +573,7 @@ class LatentsToImageInvocation(BaseInvocation):
context=context,
)
with vae_info as vae:
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)
@ -556,6 +637,7 @@ class LatentsToImageInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(
@ -568,14 +650,10 @@ class LatentsToImageInvocation(BaseInvocation):
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
@title("Resize Latents")
@tags("latents", "resize")
@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents")
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
type: Literal["lresize"] = "lresize"
# Inputs
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
@ -616,14 +694,10 @@ class ResizeLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@title("Scale Latents")
@tags("latents", "resize")
@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents")
class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor."""
type: Literal["lscale"] = "lscale"
# Inputs
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
@ -656,14 +730,10 @@ class ScaleLatentsInvocation(BaseInvocation):
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@title("Image to Latents")
@tags("latents", "image", "vae", "i2l")
@invocation("i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents")
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
type: Literal["i2l"] = "i2l"
# Inputs
image: ImageField = InputField(
description="The image to encode",
)
@ -674,26 +744,11 @@ class ImageToLatentsInvocation(BaseInvocation):
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image(self.image.image_name)
# vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
context=context,
)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
@staticmethod
def vae_encode(vae_info, upcast, tiled, image_tensor):
with vae_info as vae:
orig_dtype = vae.dtype
if self.fp32:
if upcast:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
@ -719,14 +774,14 @@ class ImageToLatentsInvocation(BaseInvocation):
# latents = latents.half()
try:
if self.tiled:
if tiled:
vae.enable_tiling()
else:
vae.disable_tiling()
except AttributeError as err:
# FIXME: This is a TEMPORARY measure until AutoencoderTiny gets tiling support from https://github.com/huggingface/diffusers/pull/4627
if err.name.endswith("_tiling"):
InvokeAILogger.getLogger(self.__class__.__name__).debug(
InvokeAILogger.getLogger(ImageToLatentsInvocation.__name__).debug(
"ignoring tiling error for %s", vae.__class__, exc_info=err
)
else:
@ -735,35 +790,50 @@ class ImageToLatentsInvocation(BaseInvocation):
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
latents = self._encode_to_tensor(vae, image_tensor)
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype)
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.services.images.get_pil_image(self.image.image_name)
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
context=context,
)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
name = f"{context.graph_execution_state_id}__{self.id}"
latents = latents.to("cpu")
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents, seed=None)
@singledispatchmethod
def _encode_to_tensor(self, vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
@staticmethod
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
return latents
@_encode_to_tensor.register
def _(self, vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
@staticmethod
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
return vae.encode(image_tensor).latents
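The two encoder overloads above exist because `AutoencoderKL.encode` returns a latent distribution that must be sampled, while `AutoencoderTiny.encode` (the TAESD autoencoder) returns latents directly. A minimal sketch of the same dispatch idea using plain `functools.singledispatch`, assuming a diffusers version that exports both classes:

```python
from functools import singledispatch

import torch
from diffusers import AutoencoderKL, AutoencoderTiny


@singledispatch
def encode_to_tensor(vae, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
    raise NotImplementedError(f"unsupported VAE type {type(vae).__name__}")


@encode_to_tensor.register
def _(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
    # KL autoencoder: sample from the posterior distribution
    return vae.encode(image_tensor).latent_dist.sample()


@encode_to_tensor.register
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
    # Tiny autoencoder: encode() already returns latents
    return vae.encode(image_tensor).latents
```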
@title("Blend Latents")
@tags("latents", "blend")
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents")
class BlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha. Latents must have same size."""
type: Literal["lblend"] = "lblend"
# Inputs
latents_a: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,

View File

@ -1,22 +1,16 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
import numpy as np
from invokeai.app.invocations.primitives import IntegerOutput
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
@title("Add Integers")
@tags("math")
@invocation("add", title="Add Integers", tags=["math", "add"], category="math")
class AddInvocation(BaseInvocation):
"""Adds two numbers"""
type: Literal["add"] = "add"
# Inputs
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
@ -24,14 +18,10 @@ class AddInvocation(BaseInvocation):
return IntegerOutput(value=self.a + self.b)
@title("Subtract Integers")
@tags("math")
@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math")
class SubtractInvocation(BaseInvocation):
"""Subtracts two numbers"""
type: Literal["sub"] = "sub"
# Inputs
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
@ -39,14 +29,10 @@ class SubtractInvocation(BaseInvocation):
return IntegerOutput(value=self.a - self.b)
@title("Multiply Integers")
@tags("math")
@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math")
class MultiplyInvocation(BaseInvocation):
"""Multiplies two numbers"""
type: Literal["mul"] = "mul"
# Inputs
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
@ -54,14 +40,10 @@ class MultiplyInvocation(BaseInvocation):
return IntegerOutput(value=self.a * self.b)
@title("Divide Integers")
@tags("math")
@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math")
class DivideInvocation(BaseInvocation):
"""Divides two numbers"""
type: Literal["div"] = "div"
# Inputs
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
@ -69,14 +51,10 @@ class DivideInvocation(BaseInvocation):
return IntegerOutput(value=int(self.a / self.b))
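Note that `int(self.a / self.b)` truncates toward zero, which differs from floor division when the operands have opposite signs:

```python
int(-7 / 2)  # -3 (true division, then truncation toward zero)
-7 // 2      # -4 (floor division)
```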
@title("Random Integer")
@tags("math")
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math")
class RandomIntInvocation(BaseInvocation):
"""Outputs a single random integer."""
type: Literal["rand_int"] = "rand_int"
# Inputs
low: int = InputField(default=0, description="The inclusive low value")
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional
from typing import Optional
from pydantic import Field
@ -8,8 +8,8 @@ from invokeai.app.invocations.baseinvocation import (
InputField,
InvocationContext,
OutputField,
tags,
title,
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
@ -91,21 +91,17 @@ class ImageMetadata(BaseModelExcludeNull):
graph: Optional[dict] = Field(default=None, description="The graph that created the image")
@invocation_output("metadata_accumulator_output")
class MetadataAccumulatorOutput(BaseInvocationOutput):
"""The output of the MetadataAccumulator node"""
type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
@title("Metadata Accumulator")
@tags("metadata")
@invocation("metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata")
class MetadataAccumulatorInvocation(BaseInvocation):
"""Outputs a Core Metadata Object"""
type: Literal["metadata_accumulator"] = "metadata_accumulator"
generation_mode: str = InputField(
description="The generation mode that output this image",
)

View File

@ -1,5 +1,5 @@
import copy
from typing import List, Literal, Optional
from typing import List, Optional
from pydantic import BaseModel, Field
@ -8,13 +8,13 @@ from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
tags,
title,
invocation,
invocation_output,
)
@ -33,6 +33,7 @@ class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes ("x" and "y") to which to apply seamless')
class ClipField(BaseModel):
@ -45,13 +46,13 @@ class ClipField(BaseModel):
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes ("x" and "y") to which to apply seamless')
@invocation_output("model_loader_output")
class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
type: Literal["model_loader_output"] = "model_loader_output"
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@ -72,14 +73,10 @@ class LoRAModelField(BaseModel):
base_model: BaseModelType = Field(description="Base model")
@title("Main Model")
@tags("model")
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model")
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
type: Literal["main_model_loader"] = "main_model_loader"
# Inputs
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
# TODO: precision?
@ -168,25 +165,18 @@ class MainModelLoaderInvocation(BaseInvocation):
)
@invocation_output("lora_loader_output")
class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
# fmt: off
type: Literal["lora_loader_output"] = "lora_loader_output"
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
# fmt: on
@title("LoRA")
@tags("lora", "model")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model")
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["lora_loader"] = "lora_loader"
# Inputs
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
@ -245,25 +235,19 @@ class LoraLoaderInvocation(BaseInvocation):
return output
@invocation_output("sdxl_lora_loader_output")
class SDXLLoraLoaderOutput(BaseInvocationOutput):
"""SDXL LoRA Loader Output"""
# fmt: off
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
# fmt: on
@title("SDXL LoRA")
@tags("sdxl", "lora", "model")
@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model")
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = Field(
@ -347,23 +331,17 @@ class VAEModelField(BaseModel):
base_model: BaseModelType = Field(description="Base model")
@invocation_output("vae_loader_output")
class VaeLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
"""VAE output"""
type: Literal["vae_loader_output"] = "vae_loader_output"
# Outputs
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@title("VAE")
@tags("vae", "model")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model")
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
type: Literal["vae_loader"] = "vae_loader"
# Inputs
vae_model: VAEModelField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
)
@ -388,3 +366,44 @@ class VaeLoaderInvocation(BaseInvocation):
)
)
)
@invocation_output("seamless_output")
class SeamlessModeOutput(BaseInvocationOutput):
"""Modified Seamless Model output"""
unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model")
class SeamlessModeInvocation(BaseInvocation):
"""Applies the seamless transformation to the Model UNet and VAE."""
unet: Optional[UNetField] = InputField(
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
)
vae: Optional[VaeField] = InputField(
default=None, description=FieldDescriptions.vae_model, input=Input.Connection, title="VAE"
)
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
unet = copy.deepcopy(self.unet)
vae = copy.deepcopy(self.vae)
seamless_axes_list = []
if self.seamless_x:
seamless_axes_list.append("x")
if self.seamless_y:
seamless_axes_list.append("y")
if unet is not None:
unet.seamless_axes = seamless_axes_list
if vae is not None:
vae.seamless_axes = seamless_axes_list
return SeamlessModeOutput(unet=unet, vae=vae)
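Seamless tiling is conventionally implemented by wrapping convolution padding (circular padding) along the requested axes for the duration of the run. The sketch below illustrates that idea only; it is not the actual `invokeai.backend.model_management.seamless.set_seamless` implementation, and the mapping of `"x"` to width and `"y"` to height is an assumption:

```python
from contextlib import contextmanager

import torch.nn as nn
import torch.nn.functional as F


@contextmanager
def seamless_sketch(model: nn.Module, seamless_axes: list[str]):
    """Illustrative only: wrap Conv2d padding circularly on the given axes."""
    wrap_x, wrap_y = "x" in seamless_axes, "y" in seamless_axes
    patched = []

    def make_forward(conv: nn.Conv2d):
        def forward(x):
            pad_y, pad_x = conv.padding  # (height, width); assumes "x" = width
            # Pad each axis separately so it can wrap around independently;
            # the non-seamless axis keeps ordinary zero padding.
            x = F.pad(x, (pad_x, pad_x, 0, 0), mode="circular" if wrap_x else "constant")
            x = F.pad(x, (0, 0, pad_y, pad_y), mode="circular" if wrap_y else "constant")
            return F.conv2d(x, conv.weight, conv.bias, conv.stride, 0, conv.dilation, conv.groups)
        return forward

    try:
        for module in model.modules():
            if isinstance(module, nn.Conv2d) and not isinstance(module.padding, str):
                patched.append((module, module.forward))
                module.forward = make_forward(module)
        yield model
    finally:
        for module, original in patched:
            module.forward = original
```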

View File

@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from typing import Literal
import torch
from pydantic import validator
@ -16,8 +15,8 @@ from .baseinvocation import (
InputField,
InvocationContext,
OutputField,
tags,
title,
invocation,
invocation_output,
)
"""
@ -62,12 +61,10 @@ Nodes
"""
@invocation_output("noise_output")
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
type: Literal["noise_output"] = "noise_output"
# Inputs
noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise)
width: int = OutputField(description=FieldDescriptions.width)
height: int = OutputField(description=FieldDescriptions.height)
@ -81,14 +78,10 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
)
@title("Noise")
@tags("latents", "noise")
@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents")
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = InputField(
ge=0,
le=SEED_MAX,

View File

@ -31,8 +31,8 @@ from .baseinvocation import (
OutputField,
UIComponent,
UIType,
tags,
title,
invocation,
invocation_output,
)
from .controlnet_image_processors import ControlField
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
@ -56,11 +56,8 @@ ORT_TO_NP_TYPE = {
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
@title("ONNX Prompt (Raw)")
@tags("onnx", "prompt")
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning")
class ONNXPromptInvocation(BaseInvocation):
type: Literal["prompt_onnx"] = "prompt_onnx"
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@ -141,14 +138,15 @@ class ONNXPromptInvocation(BaseInvocation):
# Text to image
@title("ONNX Text to Latents")
@tags("latents", "inference", "txt2img", "onnx")
@invocation(
"t2l_onnx",
title="ONNX Text to Latents",
tags=["latents", "inference", "txt2img", "onnx"],
category="latents",
)
class ONNXTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
type: Literal["t2l_onnx"] = "t2l_onnx"
# Inputs
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond,
input=Input.Connection,
@ -316,14 +314,15 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
# Latent to image
@title("ONNX Latents to Image")
@tags("latents", "image", "vae", "onnx")
@invocation(
"l2i_onnx",
title="ONNX Latents to Image",
tags=["latents", "image", "vae", "onnx"],
category="image",
)
class ONNXLatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
type: Literal["l2i_onnx"] = "l2i_onnx"
# Inputs
latents: LatentsField = InputField(
description=FieldDescriptions.denoised_latents,
input=Input.Connection,
@ -376,6 +375,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(
@ -385,17 +385,14 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
)
@invocation_output("model_loader_output_onnx")
class ONNXModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
# fmt: off
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
# fmt: on
class OnnxModelField(BaseModel):
@ -406,14 +403,10 @@ class OnnxModelField(BaseModel):
model_type: ModelType = Field(description="Model Type")
@title("ONNX Main Model")
@tags("onnx", "model")
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model")
class OnnxModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
type: Literal["onnx_model_loader"] = "onnx_model_loader"
# Inputs
model: OnnxModelField = InputField(
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
)

View File

@ -42,17 +42,13 @@ from matplotlib.ticker import MaxNLocator
from invokeai.app.invocations.primitives import FloatCollectionOutput
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
@title("Float Range")
@tags("math", "range")
@invocation("float_range", title="Float Range", tags=["math", "range"], category="math")
class FloatLinearRangeInvocation(BaseInvocation):
"""Creates a range"""
type: Literal["float_range"] = "float_range"
# Inputs
start: float = InputField(default=5, description="The first value of the range")
stop: float = InputField(default=10, description="The last value of the range")
steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)")
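The `invoke` body is elided by this hunk, but the field descriptions imply an endpoint-inclusive interpolation; a minimal sketch of that behavior using numpy (function name hypothetical):

```python
import numpy as np

def float_linear_range(start: float, stop: float, steps: int) -> list[float]:
    # np.linspace includes both endpoints, matching "(including start and stop)"
    return list(np.linspace(start, stop, steps))

assert float_linear_range(5, 10, 3) == [5.0, 7.5, 10.0]
```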
@ -100,14 +96,10 @@ EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
# actually I think for now we could just use CollectionOutput (which is list[Any])
@title("Step Param Easing")
@tags("step", "easing")
@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step")
class StepParamEasingInvocation(BaseInvocation):
"""Experimental per-step parameter easing for denoising steps"""
type: Literal["step_param_easing"] = "step_param_easing"
# Inputs
easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use")
num_steps: int = InputField(default=20, description="number of denoising steps")
start_value: float = InputField(default=0.0, description="easing starting value")

View File

@ -1,6 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional, Tuple
from typing import Optional, Tuple
import torch
from pydantic import BaseModel, Field
@ -15,8 +15,8 @@ from .baseinvocation import (
OutputField,
UIComponent,
UIType,
tags,
title,
invocation,
invocation_output,
)
"""
@ -29,44 +29,39 @@ Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color
# region Boolean
@invocation_output("boolean_output")
class BooleanOutput(BaseInvocationOutput):
"""Base class for nodes that output a single boolean"""
type: Literal["boolean_output"] = "boolean_output"
value: bool = OutputField(description="The output boolean")
@invocation_output("boolean_collection_output")
class BooleanCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of booleans"""
type: Literal["boolean_collection_output"] = "boolean_collection_output"
# Outputs
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
@title("Boolean Primitive")
@tags("primitives", "boolean")
@invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives")
class BooleanInvocation(BaseInvocation):
"""A boolean primitive value"""
type: Literal["boolean"] = "boolean"
# Inputs
value: bool = InputField(default=False, description="The boolean value")
def invoke(self, context: InvocationContext) -> BooleanOutput:
return BooleanOutput(value=self.value)
@title("Boolean Primitive Collection")
@tags("primitives", "boolean", "collection")
@invocation(
"boolean_collection",
title="Boolean Collection Primitive",
tags=["primitives", "boolean", "collection"],
category="primitives",
)
class BooleanCollectionInvocation(BaseInvocation):
"""A collection of boolean primitive values"""
type: Literal["boolean_collection"] = "boolean_collection"
# Inputs
collection: list[bool] = InputField(
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
)
@ -80,46 +75,41 @@ class BooleanCollectionInvocation(BaseInvocation):
# region Integer
@invocation_output("integer_output")
class IntegerOutput(BaseInvocationOutput):
"""Base class for nodes that output a single integer"""
type: Literal["integer_output"] = "integer_output"
value: int = OutputField(description="The output integer")
@invocation_output("integer_collection_output")
class IntegerCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of integers"""
type: Literal["integer_collection_output"] = "integer_collection_output"
# Outputs
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
@title("Integer Primitive")
@tags("primitives", "integer")
@invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives")
class IntegerInvocation(BaseInvocation):
"""An integer primitive value"""
type: Literal["integer"] = "integer"
# Inputs
value: int = InputField(default=0, description="The integer value")
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(value=self.value)
@title("Integer Primitive Collection")
@tags("primitives", "integer", "collection")
@invocation(
"integer_collection",
title="Integer Collection Primitive",
tags=["primitives", "integer", "collection"],
category="primitives",
)
class IntegerCollectionInvocation(BaseInvocation):
"""A collection of integer primitive values"""
type: Literal["integer_collection"] = "integer_collection"
# Inputs
collection: list[int] = InputField(
default=0, description="The collection of integer values", ui_type=UIType.IntegerCollection
default_factory=list, description="The collection of integer values", ui_type=UIType.IntegerCollection
)
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
@ -131,44 +121,39 @@ class IntegerCollectionInvocation(BaseInvocation):
# region Float
@invocation_output("float_output")
class FloatOutput(BaseInvocationOutput):
"""Base class for nodes that output a single float"""
type: Literal["float_output"] = "float_output"
value: float = OutputField(description="The output float")
@invocation_output("float_collection_output")
class FloatCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of floats"""
type: Literal["float_collection_output"] = "float_collection_output"
# Outputs
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
@title("Float Primitive")
@tags("primitives", "float")
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives")
class FloatInvocation(BaseInvocation):
"""A float primitive value"""
type: Literal["float"] = "float"
# Inputs
value: float = InputField(default=0.0, description="The float value")
def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(value=self.value)
@title("Float Primitive Collection")
@tags("primitives", "float", "collection")
@invocation(
"float_collection",
title="Float Collection Primitive",
tags=["primitives", "float", "collection"],
category="primitives",
)
class FloatCollectionInvocation(BaseInvocation):
"""A collection of float primitive values"""
type: Literal["float_collection"] = "float_collection"
# Inputs
collection: list[float] = InputField(
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
)
@ -182,44 +167,39 @@ class FloatCollectionInvocation(BaseInvocation):
# region String
@invocation_output("string_output")
class StringOutput(BaseInvocationOutput):
"""Base class for nodes that output a single string"""
type: Literal["string_output"] = "string_output"
value: str = OutputField(description="The output string")
@invocation_output("string_collection_output")
class StringCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of strings"""
type: Literal["string_collection_output"] = "string_collection_output"
# Outputs
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
@title("String Primitive")
@tags("primitives", "string")
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives")
class StringInvocation(BaseInvocation):
"""A string primitive value"""
type: Literal["string"] = "string"
# Inputs
value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(value=self.value)
@title("String Primitive Collection")
@tags("primitives", "string", "collection")
@invocation(
"string_collection",
title="String Collection Primitive",
tags=["primitives", "string", "collection"],
category="primitives",
)
class StringCollectionInvocation(BaseInvocation):
"""A collection of string primitive values"""
type: Literal["string_collection"] = "string_collection"
# Inputs
collection: list[str] = InputField(
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
)
@ -239,33 +219,26 @@ class ImageField(BaseModel):
image_name: str = Field(description="The name of the image")
@invocation_output("image_output")
class ImageOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image"""
type: Literal["image_output"] = "image_output"
image: ImageField = OutputField(description="The output image")
width: int = OutputField(description="The width of the image in pixels")
height: int = OutputField(description="The height of the image in pixels")
@invocation_output("image_collection_output")
class ImageCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of images"""
type: Literal["image_collection_output"] = "image_collection_output"
# Outputs
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
@title("Image Primitive")
@tags("primitives", "image")
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives")
class ImageInvocation(BaseInvocation):
"""An image primitive value"""
# Metadata
type: Literal["image"] = "image"
# Inputs
image: ImageField = InputField(description="The image to load")
def invoke(self, context: InvocationContext) -> ImageOutput:
@ -278,22 +251,42 @@ class ImageInvocation(BaseInvocation):
)
@title("Image Primitive Collection")
@tags("primitives", "image", "collection")
@invocation(
"image_collection",
title="Image Collection Primitive",
tags=["primitives", "image", "collection"],
category="primitives",
)
class ImageCollectionInvocation(BaseInvocation):
"""A collection of image primitive values"""
type: Literal["image_collection"] = "image_collection"
# Inputs
collection: list[ImageField] = InputField(
default=0, description="The collection of image values", ui_type=UIType.ImageCollection
default_factory=list, description="The collection of image values", ui_type=UIType.ImageCollection
)
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.collection)
# endregion
# region DenoiseMask
class DenoiseMaskField(BaseModel):
"""An inpaint mask field"""
mask_name: str = Field(description="The name of the mask image")
masked_latents_name: Optional[str] = Field(description="The name of the masked image latents")
@invocation_output("denoise_mask_output")
class DenoiseMaskOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image"""
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
# endregion
# region Latents
@ -306,11 +299,10 @@ class LatentsField(BaseModel):
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
@invocation_output("latents_output")
class LatentsOutput(BaseInvocationOutput):
"""Base class for nodes that output a single latents tensor"""
type: Literal["latents_output"] = "latents_output"
latents: LatentsField = OutputField(
description=FieldDescriptions.latents,
)
@ -318,25 +310,20 @@ class LatentsOutput(BaseInvocationOutput):
height: int = OutputField(description=FieldDescriptions.height)
@invocation_output("latents_collection_output")
class LatentsCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of latents tensors"""
type: Literal["latents_collection_output"] = "latents_collection_output"
collection: list[LatentsField] = OutputField(
description=FieldDescriptions.latents,
ui_type=UIType.LatentsCollection,
)
@title("Latents Primitive")
@tags("primitives", "latents")
@invocation("latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives")
class LatentsInvocation(BaseInvocation):
"""A latents tensor primitive value"""
type: Literal["latents"] = "latents"
# Inputs
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
def invoke(self, context: InvocationContext) -> LatentsOutput:
@ -345,14 +332,15 @@ class LatentsInvocation(BaseInvocation):
return build_latents_output(self.latents.latents_name, latents)
@title("Latents Primitive Collection")
@tags("primitives", "latents", "collection")
@invocation(
"latents_collection",
title="Latents Collection Primitive",
tags=["primitives", "latents", "collection"],
category="primitives",
)
class LatentsCollectionInvocation(BaseInvocation):
"""A collection of latents tensor primitive values"""
type: Literal["latents_collection"] = "latents_collection"
# Inputs
collection: list[LatentsField] = InputField(
description="The collection of latents tensors", ui_type=UIType.LatentsCollection
)
@ -386,30 +374,24 @@ class ColorField(BaseModel):
return (self.r, self.g, self.b, self.a)
@invocation_output("color_output")
class ColorOutput(BaseInvocationOutput):
"""Base class for nodes that output a single color"""
type: Literal["color_output"] = "color_output"
color: ColorField = OutputField(description="The output color")
@invocation_output("color_collection_output")
class ColorCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of colors"""
type: Literal["color_collection_output"] = "color_collection_output"
# Outputs
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
@title("Color Primitive")
@tags("primitives", "color")
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives")
class ColorInvocation(BaseInvocation):
"""A color primitive value"""
type: Literal["color"] = "color"
# Inputs
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
def invoke(self, context: InvocationContext) -> ColorOutput:
@ -427,49 +409,51 @@ class ConditioningField(BaseModel):
conditioning_name: str = Field(description="The name of conditioning tensor")
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""
type: Literal["conditioning_output"] = "conditioning_output"
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
@invocation_output("conditioning_collection_output")
class ConditioningCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of conditioning tensors"""
type: Literal["conditioning_collection_output"] = "conditioning_collection_output"
# Outputs
collection: list[ConditioningField] = OutputField(
description="The output conditioning tensors",
ui_type=UIType.ConditioningCollection,
)
@title("Conditioning Primitive")
@tags("primitives", "conditioning")
@invocation(
"conditioning",
title="Conditioning Primitive",
tags=["primitives", "conditioning"],
category="primitives",
)
class ConditioningInvocation(BaseInvocation):
"""A conditioning tensor primitive value"""
type: Literal["conditioning"] = "conditioning"
conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
def invoke(self, context: InvocationContext) -> ConditioningOutput:
return ConditioningOutput(conditioning=self.conditioning)
@title("Conditioning Primitive Collection")
@tags("primitives", "conditioning", "collection")
@invocation(
"conditioning_collection",
title="Conditioning Collection Primitive",
tags=["primitives", "conditioning", "collection"],
category="primitives",
)
class ConditioningCollectionInvocation(BaseInvocation):
"""A collection of conditioning tensor primitive values"""
type: Literal["conditioning_collection"] = "conditioning_collection"
# Inputs
collection: list[ConditioningField] = InputField(
default=0, description="The collection of conditioning tensors", ui_type=UIType.ConditioningCollection
default_factory=list,
description="The collection of conditioning tensors",
ui_type=UIType.ConditioningCollection,
)
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:

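Several collection inputs in this file move from `default=0` to `default_factory=list`. Besides correcting the wrong-typed default, `default_factory` is the idiomatic pydantic way to give each instance its own fresh list; a minimal standalone sketch (pydantic v1 style, class name hypothetical):

```python
from pydantic import BaseModel, Field

class Collector(BaseModel):
    # default_factory builds a new list per instance and keeps the declared
    # type (list[int]) consistent with the default value
    items: list[int] = Field(default_factory=list)

a, b = Collector(), Collector()
a.items.append(1)
assert b.items == []  # each instance gets its own list
```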
View File

@ -1,5 +1,5 @@
from os.path import exists
from typing import Literal, Optional, Union
from typing import Optional, Union
import numpy as np
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
@ -7,17 +7,13 @@ from pydantic import validator
from invokeai.app.invocations.primitives import StringCollectionOutput
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, UIType, tags, title
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
@title("Dynamic Prompt")
@tags("prompt", "collection")
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt")
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
type: Literal["dynamic_prompt"] = "dynamic_prompt"
# Inputs
prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea)
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
@ -33,15 +29,11 @@ class DynamicPromptInvocation(BaseInvocation):
return StringCollectionOutput(collection=prompts)
@title("Prompts from File")
@tags("prompt", "file")
@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt")
class PromptsFromFileInvocation(BaseInvocation):
"""Loads prompts from a text file"""
type: Literal["prompt_from_file"] = "prompt_from_file"
# Inputs
file_path: str = InputField(description="Path to prompt text file", ui_type=UIType.FilePath)
file_path: str = InputField(description="Path to prompt text file")
pre_prompt: Optional[str] = InputField(
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
)

View File

@ -1,5 +1,3 @@
from typing import Literal
from ...backend.model_management import ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
@ -10,41 +8,35 @@ from .baseinvocation import (
InvocationContext,
OutputField,
UIType,
tags,
title,
invocation,
invocation_output,
)
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
@invocation_output("sdxl_model_loader_output")
class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL base model loader output"""
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation_output("sdxl_refiner_model_loader_output")
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output"""
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@title("SDXL Main Model")
@tags("model", "sdxl")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model")
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
type: Literal["sdxl_model_loader"] = "sdxl_model_loader"
# Inputs
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
)
@ -122,14 +114,15 @@ class SDXLModelLoaderInvocation(BaseInvocation):
)
@title("SDXL Refiner Model")
@tags("model", "sdxl", "refiner")
@invocation(
"sdxl_refiner_model_loader",
title="SDXL Refiner Model",
tags=["model", "sdxl", "refiner"],
category="model",
)
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
# Inputs
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model,
input=Input.Direct,

View File

@ -11,7 +11,7 @@ from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
# TODO: Populate this from disk?
# TODO: Use model manager to load?
@ -23,14 +23,10 @@ ESRGAN_MODELS = Literal[
]
@title("Upscale (RealESRGAN)")
@tags("esrgan", "upscale")
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan")
class ESRGANInvocation(BaseInvocation):
"""Upscales an image using RealESRGAN."""
type: Literal["esrgan"] = "esrgan"
# Inputs
image: ImageField = InputField(description="The input image")
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
@ -110,6 +106,7 @@ class ESRGANInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -6,3 +6,4 @@ from .invokeai_config import ( # noqa F401
InvokeAIAppConfig,
get_invokeai_config,
)
from .base import PagingArgumentParser # noqa F401

View File

@ -3,7 +3,7 @@
import copy
import itertools
import uuid
from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin, get_type_hints
from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import BaseModel, root_validator, validator
@ -14,11 +14,13 @@ from ..invocations import * # noqa: F401 F403
from ..invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
invocation_output,
)
# in 3.10 this would be "from types import NoneType"
@ -148,24 +150,16 @@ class NodeAlreadyExecutedError(Exception):
# TODO: Create and use an Empty output?
@invocation_output("graph_output")
class GraphInvocationOutput(BaseInvocationOutput):
type: Literal["graph_output"] = "graph_output"
class Config:
schema_extra = {
"required": [
"type",
"image",
]
}
pass
# TODO: Fill this out and move to invocations
@invocation("graph")
class GraphInvocation(BaseInvocation):
"""Execute a graph"""
type: Literal["graph"] = "graph"
# TODO: figure out how to create a default here
graph: "Graph" = Field(description="The graph to run", default=None)
@ -174,22 +168,20 @@ class GraphInvocation(BaseInvocation):
return GraphInvocationOutput()
@invocation_output("iterate_output")
class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output."""
type: Literal["iterate_output"] = "iterate_output"
item: Any = OutputField(
description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem
)
# TODO: Fill this out and move to invocations
@invocation("iterate")
class IterateInvocation(BaseInvocation):
"""Iterates over a list of items"""
type: Literal["iterate"] = "iterate"
collection: list[Any] = InputField(
description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection
)
@ -200,19 +192,17 @@ class IterateInvocation(BaseInvocation):
return IterateInvocationOutput(item=self.collection[self.index])
@invocation_output("collect_output")
class CollectInvocationOutput(BaseInvocationOutput):
type: Literal["collect_output"] = "collect_output"
collection: list[Any] = OutputField(
description="The collection of input items", title="Collection", ui_type=UIType.Collection
)
@invocation("collect")
class CollectInvocation(BaseInvocation):
"""Collects values into a collection"""
type: Literal["collect"] = "collect"
item: Any = InputField(
description="The item to collect (all inputs must be of the same type)",
ui_type=UIType.CollectionItem,

View File

@ -60,7 +60,7 @@ class ImageFileStorageBase(ABC):
image: PILImageType,
image_name: str,
metadata: Optional[dict] = None,
graph: Optional[dict] = None,
workflow: Optional[str] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -110,7 +110,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image: PILImageType,
image_name: str,
metadata: Optional[dict] = None,
graph: Optional[dict] = None,
workflow: Optional[str] = None,
thumbnail_size: int = 256,
) -> None:
try:
@ -119,12 +119,23 @@ class DiskImageFileStorage(ImageFileStorageBase):
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None:
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
if graph is not None:
pnginfo.add_text("invokeai_graph", json.dumps(graph))
if metadata is not None or workflow is not None:
if metadata is not None:
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
if workflow is not None:
pnginfo.add_text("invokeai_workflow", workflow)
else:
# For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back
# TODO: retain non-invokeai metadata on save...
original_metadata = image.info.get("invokeai_metadata", None)
if original_metadata is not None:
pnginfo.add_text("invokeai_metadata", original_metadata)
original_workflow = image.info.get("invokeai_workflow", None)
if original_workflow is not None:
pnginfo.add_text("invokeai_workflow", original_workflow)
image.save(image_path, "PNG", pnginfo=pnginfo)
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size)
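Because the workflow is stored as a PNG text chunk, it can be read back from a saved file with PIL alone; a minimal sketch (file name hypothetical):

```python
import json
from PIL import Image

with Image.open("outputs/example.png") as im:
    workflow = im.info.get("invokeai_workflow")  # raw workflow string, or None
    metadata = im.info.get("invokeai_metadata")  # JSON-encoded string, or None

if metadata is not None:
    metadata = json.loads(metadata)
```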

View File

@ -54,6 +54,7 @@ class ImageServiceABC(ABC):
board_id: Optional[str] = None,
is_intermediate: bool = False,
metadata: Optional[dict] = None,
workflow: Optional[str] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@ -177,6 +178,7 @@ class ImageService(ImageServiceABC):
board_id: Optional[str] = None,
is_intermediate: bool = False,
metadata: Optional[dict] = None,
workflow: Optional[str] = None,
) -> ImageDTO:
if image_origin not in ResourceOrigin:
raise InvalidOriginException
@ -186,16 +188,16 @@ class ImageService(ImageServiceABC):
image_name = self._services.names.create_image_name()
graph = None
if session_id is not None:
session_raw = self._services.graph_execution_manager.get_raw(session_id)
if session_raw is not None:
try:
graph = get_metadata_graph_from_raw_session(session_raw)
except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None
# TODO: Do we want to store the graph in the image at all? I don't think so...
# graph = None
# if session_id is not None:
# session_raw = self._services.graph_execution_manager.get_raw(session_id)
# if session_raw is not None:
# try:
# graph = get_metadata_graph_from_raw_session(session_raw)
# except Exception as e:
# self._services.logger.warn(f"Failed to parse session graph: {e}")
# graph = None
(width, height) = image.size
@ -217,7 +219,7 @@ class ImageService(ImageServiceABC):
)
if board_id is not None:
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, graph=graph)
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
image_dto = self.get_dto(image_name)
return image_dto

View File

@ -53,7 +53,7 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
- `starred`: change whether the image is starred
"""
image_category: Optional[ImageCategory] = Field(description="The image's new category.")
image_category: Optional[ImageCategory] = Field(default=None, description="The image's new category.")
"""The image's new category."""
session_id: Optional[StrictStr] = Field(
default=None,

View File

@ -20,7 +20,8 @@ def _conv_forward_asymmetric(self, input, weight, bias):
def configure_model_padding(model, seamless, seamless_axes):
"""
Modifies the 2D convolution layers to use a circular padding mode based on the `seamless` and `seamless_axes` options.
Modifies the 2D convolution layers to use a circular padding mode based on
the `seamless` and `seamless_axes` options.
"""
# TODO: get an explicit interface for this in diffusers: https://github.com/huggingface/diffusers/issues/556
for m in model.modules():

View File

@ -492,10 +492,10 @@ def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths:
loras = paths.get("lora_dir", "loras")
controlnets = paths.get("controlnet_dir", "controlnets")
return ModelPaths(
models=root / models,
embeddings=root / embeddings,
loras=root / loras,
controlnets=root / controlnets,
models=root / models if models else None,
embeddings=root / embeddings if embeddings else None,
loras=root / loras if loras else None,
controlnets=root / controlnets if controlnets else None,
)

View File

@ -0,0 +1,102 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import List, Union
import torch.nn as nn
from diffusers.models import AutoencoderKL, UNet2DConditionModel
def _conv_forward_asymmetric(self, input, weight, bias):
"""
Patch for Conv2d._conv_forward that supports asymmetric padding
"""
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
return nn.functional.conv2d(
working,
weight,
bias,
self.stride,
nn.modules.utils._pair(0),
self.dilation,
self.groups,
)
@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
try:
to_restore = []
for m_name, m in model.named_modules():
if isinstance(model, UNet2DConditionModel):
if ".attentions." in m_name:
continue
if ".resnets." in m_name:
if ".conv2" in m_name:
continue
if ".conv_shortcut" in m_name:
continue
"""
if isinstance(model, UNet2DConditionModel):
if False and ".upsamplers." in m_name:
continue
if False and ".downsamplers." in m_name:
continue
if True and ".resnets." in m_name:
if True and ".conv1" in m_name:
if False and "down_blocks" in m_name:
continue
if False and "mid_block" in m_name:
continue
if False and "up_blocks" in m_name:
continue
if True and ".conv2" in m_name:
continue
if True and ".conv_shortcut" in m_name:
continue
if True and ".attentions." in m_name:
continue
if False and m_name in ["conv_in", "conv_out"]:
continue
"""
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
yield
finally:
for module, orig_conv_forward in to_restore:
module._conv_forward = orig_conv_forward
if hasattr(m, "asymmetric_padding_mode"):
del m.asymmetric_padding_mode
if hasattr(m, "asymmetric_padding"):
del m.asymmetric_padding
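A usage sketch for the context manager above: the patch is reversible, so normal padding is restored on exit. The VAE construction and `latents` tensor here are illustrative only:

```python
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# tile horizontally only: Conv2d layers get circular padding on the x axis
with set_seamless(vae, ["x"]):
    image = vae.decode(latents).sample  # `latents` assumed prepared elsewhere
# outside the block, the original _conv_forward implementations are back
```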

View File

@ -144,7 +144,7 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = Tr
w, h = trim_to_multiple_of(*image.size, multiple_of=multiple_of)
transformation = T.Compose(
[
T.Resize((h, w), T.InterpolationMode.LANCZOS),
T.Resize((h, w), T.InterpolationMode.LANCZOS, antialias=True),
T.ToTensor(),
]
)
@ -358,6 +358,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
seed: Optional[int] = None,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if init_timestep.shape[0] == 0:
@ -376,28 +377,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents = self.scheduler.add_noise(latents, noise, batched_t)
if mask is not None:
# if no noise provided, noisify unmasked area based on seed (or 0 as fallback)
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
if is_inpainting_model(self.unet):
# You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
# (that's why there's a mask!) but it seems to really want that blanked out.
# masked_latents = latents * torch.where(mask < 0.5, 1, 0) TODO: inpaint/outpaint/infill
if masked_latents is None:
raise Exception("Source image required for inpaint mask when inpaint model used!")
# TODO: we should probably pass this in so we don't have to try/finally around setting it.
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(self._unet_forward, mask, orig_latents)
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
self._unet_forward, mask, masked_latents
)
else:
# if no noise provided, noisify unmasked area based on seed (or 0 as fallback)
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))
try:
@ -557,12 +558,22 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
# TODO: issue to diffusers?
# undo the internal counter increment done by scheduler.step, so the timestep can be resolved as it was before the call
# this is needed to be able to call scheduler.add_noise with the current timestep
if self.scheduler.order == 2:
self.scheduler._index_counter[timestep.item()] -= 1
# TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent.
# But the way things are now, scheduler runs _after_ that, so there was
# no way to use it to apply an operation that happens after the last scheduler.step.
for guidance in additional_guidance:
step_output = guidance(step_output, timestep, conditioning_data)
# restore internal counter
if self.scheduler.order == 2:
self.scheduler._index_counter[timestep.item()] += 1
return step_output
def _unet_forward(

View File

@ -265,7 +265,7 @@ class InvokeAICrossAttentionMixin:
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
slice_size = math.floor(2 ** 30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_dim1(q, k, v, slice_size)
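The slicing branch budgets roughly 2**30 attention elements per chunk; a worked example of the arithmetic with a hypothetical shape:

```python
import math

# q of shape (batch*heads, tokens, dim), e.g. 16 heads over 9216 latent tokens
b, n = 16, 9216
slice_size = math.floor(2 ** 30 / (b * n))
print(slice_size)  # 7281, so dim 1 is processed in two slices (7281 + 1935 tokens)
```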
def einsum_op_mps_v2(self, q, k, v):

View File

@ -215,10 +215,7 @@ class InvokeAIDiffuserComponent:
dim=0,
),
}
(
encoder_hidden_states,
encoder_attention_mask,
) = self._concat_conditionings_for_batch(
(encoder_hidden_states, encoder_attention_mask,) = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.text_embeddings.embeds,
)
@ -280,10 +277,7 @@ class InvokeAIDiffuserComponent:
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
if wants_cross_attention_control:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
(unconditioned_next_x, conditioned_next_x,) = self._apply_cross_attention_controlled_conditioning(
sample,
timestep,
conditioning_data,
@ -291,10 +285,7 @@ class InvokeAIDiffuserComponent:
**kwargs,
)
elif self.sequential_guidance:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning_sequentially(
sample,
timestep,
conditioning_data,
@ -302,10 +293,7 @@ class InvokeAIDiffuserComponent:
)
else:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning(
sample,
timestep,
conditioning_data,

View File

@ -395,7 +395,7 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
@ -413,7 +413,7 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img

View File

@ -399,7 +399,7 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
@ -417,7 +417,7 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img

View File

@ -562,14 +562,18 @@ def rgb2ycbcr(img, only_y=True):
if only_y:
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
else:
rlt = np.matmul(
img,
[
[65.481, -37.797, 112.0],
[128.553, -74.203, -93.786],
[24.966, 112.0, -18.214],
],
) / 255.0 + [16, 128, 128]
rlt = (
np.matmul(
img,
[
[65.481, -37.797, 112.0],
[128.553, -74.203, -93.786],
[24.966, 112.0, -18.214],
],
)
/ 255.0
+ [16, 128, 128]
)
if in_img_type == np.uint8:
rlt = rlt.round()
else:
@ -588,14 +592,18 @@ def ycbcr2rgb(img):
if in_img_type != np.uint8:
img *= 255.0
# convert
rlt = np.matmul(
img,
[
[0.00456621, 0.00456621, 0.00456621],
[0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0],
],
) * 255.0 + [-222.921, 135.576, -276.836]
rlt = (
np.matmul(
img,
[
[0.00456621, 0.00456621, 0.00456621],
[0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0],
],
)
* 255.0
+ [-222.921, 135.576, -276.836]
)
if in_img_type == np.uint8:
rlt = rlt.round()
else:
@ -618,14 +626,18 @@ def bgr2ycbcr(img, only_y=True):
if only_y:
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
else:
rlt = np.matmul(
img,
[
[24.966, 112.0, -18.214],
[128.553, -74.203, -93.786],
[65.481, -37.797, 112.0],
],
) / 255.0 + [16, 128, 128]
rlt = (
np.matmul(
img,
[
[24.966, 112.0, -18.214],
[128.553, -74.203, -93.786],
[65.481, -37.797, 112.0],
],
)
/ 255.0
+ [16, 128, 128]
)
if in_img_type == np.uint8:
rlt = rlt.round()
else:
@ -716,11 +728,11 @@ def ssim(img1, img2):
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_sq = mu1 ** 2
mu2_sq = mu2 ** 2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
@ -737,8 +749,8 @@ def ssim(img1, img2):
# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
absx2 = absx ** 2
absx3 = absx ** 3
return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
) * (((absx > 1) * (absx <= 2)).type_as(absx))

View File

@ -475,10 +475,7 @@ class TextualInversionDataset(Dataset):
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
(
h,
w,
) = (
(h, w,) = (
img.shape[0],
img.shape[1],
)

View File

@ -1,11 +1,11 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import diffusers
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlnetMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from diffusers.models.embeddings import (
TextImageProjection,
TextImageTimeEmbedding,
@ -14,16 +14,9 @@ from diffusers.models.embeddings import (
Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, get_down_block
from diffusers.models.unet_2d_condition import UNet2DConditionModel
import diffusers
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from torch import nn
from invokeai.backend.util.logging import InvokeAILogger
@ -45,7 +38,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0):
The frequency shift to apply to the time embedding.
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", \
"CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
@ -147,7 +141,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# when this library was created...
# The incorrect naming was only discovered much ...
# later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
@ -155,17 +151,20 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
# Check inputs
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
f"Must provide the same number of `block_out_channels` as `down_block_types`. \
`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
f"Must provide the same number of `only_cross_attention` as `down_block_types`. \
`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
f"Must provide the same number of `num_attention_heads` as `down_block_types`. \
`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, int):
@ -202,7 +201,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# they are set to `cross_attention_dim` here as this is exactly the required dimension ...
# for the currently only use
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
@ -250,8 +250,10 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`.
# To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension...
# for the currently only use
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
@ -673,12 +675,14 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which \
requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which \
requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
@ -761,3 +765,64 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
diffusers.ControlNetModel = ControlNetModel
diffusers.models.controlnet.ControlNetModel = ControlNetModel
# patch LoRACompatibleConv to use original Conv2D forward function
# this is needed to make the seamless patch work
# NOTE: with this patch, torch.compile crashes on torch 2.0 (already fixed in nightly)
# https://github.com/huggingface/diffusers/pull/4315
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/lora.py#L96C18-L96C18
def new_LoRACompatibleConv_forward(self, x):
if self.lora_layer is None:
return super(diffusers.models.lora.LoRACompatibleConv, self).forward(x)
else:
return super(diffusers.models.lora.LoRACompatibleConv, self).forward(x) + self.lora_layer(x)
diffusers.models.lora.LoRACompatibleConv.forward = new_LoRACompatibleConv_forward
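The replacement forward calls two-argument `super(...)` so dispatch starts above `LoRACompatibleConv` and reaches the plain `nn.Conv2d` implementation. A minimal, self-contained illustration of that pattern (classes hypothetical):

```python
class Base:
    def forward(self):
        return "base"

class Wrapper(Base):
    def forward(self):
        return "wrapped " + super().forward()

def patched_forward(self):
    # explicit two-argument super: dispatches to Base.forward,
    # skipping whatever Wrapper.forward would have added
    return super(Wrapper, self).forward()

Wrapper.forward = patched_forward
assert Wrapper().forward() == "base"
```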
try:
import xformers
xformers_available = True
except Exception:
xformers_available = False
if xformers_available:
# TODO: remove when fixed in diffusers
_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention
def new_memory_efficient_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias=None,
p: float = 0.0,
scale: Optional[float] = None,
*,
op=None,
):
# diffusers does not align the attention-bias shape to a multiple of 8, which xformers requires
if attn_bias is not None and type(attn_bias) is torch.Tensor:
orig_size = attn_bias.shape[-1]
new_size = ((orig_size + 7) // 8) * 8
aligned_attn_bias = torch.zeros(
(attn_bias.shape[0], attn_bias.shape[1], new_size),
device=attn_bias.device,
dtype=attn_bias.dtype,
)
aligned_attn_bias[:, :, :orig_size] = attn_bias
attn_bias = aligned_attn_bias[:, :, :orig_size]
return _xformers_memory_efficient_attention(
query=query,
key=key,
value=value,
attn_bias=attn_bias,
p=p,
scale=scale,
op=op,
)
xformers.ops.memory_efficient_attention = new_memory_efficient_attention
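The alignment trick allocates a zero-padded bias whose last dimension is rounded up to a multiple of 8, then slices it back to the original width: the caller-visible shape is unchanged while the underlying storage satisfies xformers' alignment requirement. A minimal numeric sketch:

```python
import torch

orig_size = 77  # e.g. a CLIP token length
new_size = ((orig_size + 7) // 8) * 8  # -> 80
aligned = torch.zeros(1, 1, new_size)
aligned[:, :, :orig_size] = torch.randn(1, 1, orig_size)
bias = aligned[:, :, :orig_size]  # same logical shape, 8-aligned storage
assert bias.shape[-1] == orig_size and aligned.shape[-1] % 8 == 0
```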

View File

@ -203,7 +203,7 @@ class ChunkedSlicedAttnProcessor:
if attn.upcast_attention:
out_item_size = 4
chunk_size = 2**29
chunk_size = 2 ** 29
out_size = query.shape[1] * key.shape[1] * out_item_size
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))

View File

@ -207,7 +207,7 @@ def parallel_data_prefetch(
return gather_res
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])

View File

@ -4,14 +4,14 @@ sd-1/main/stable-diffusion-v1-5:
repo_id: runwayml/stable-diffusion-v1-5
recommended: True
default: True
sd-1/main/stable-diffusion-inpainting:
sd-1/main/stable-diffusion-v1-5-inpainting:
description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
repo_id: runwayml/stable-diffusion-inpainting
recommended: True
sd-2/main/stable-diffusion-2-1:
description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
repo_id: stabilityai/stable-diffusion-2-1
recommended: True
recommended: False
sd-2/main/stable-diffusion-2-inpainting:
description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
repo_id: stabilityai/stable-diffusion-2-inpainting
@ -19,19 +19,19 @@ sd-2/main/stable-diffusion-2-inpainting:
sdxl/main/stable-diffusion-xl-base-1-0:
description: Stable Diffusion XL base model (12 GB)
repo_id: stabilityai/stable-diffusion-xl-base-1.0
recommended: False
recommended: True
sdxl-refiner/main/stable-diffusion-xl-refiner-1-0:
description: Stable Diffusion XL refiner model (12 GB)
repo_id: stabilityai/stable-diffusion-xl-refiner-1.0
recommended: false
recommended: False
sdxl/vae/sdxl-1-0-vae-fix:
description: Fine tuned version of the SDXL-1.0 VAE
repo_id: madebyollin/sdxl-vae-fp16-fix
recommended: true
recommended: True
sd-1/main/Analog-Diffusion:
description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
repo_id: wavymulder/Analog-Diffusion
recommended: false
recommended: False
sd-1/main/Deliberate:
description: Versatile model that produces detailed images up to 768px (4.27 GB)
repo_id: XpucT/Deliberate

View File

@ -60,7 +60,7 @@ class Config:
thumbnail_path = None
def find_and_load(self):
"""find the yaml config file and load"""
"""Find the yaml config file and load"""
root = app_config.root_path
if not self.confirm_and_load(os.path.abspath(root)):
print("\r\nSpecify custom database and outputs paths:")
@ -70,7 +70,7 @@ class Config:
self.thumbnail_path = os.path.join(self.outputs_path, "thumbnails")
def confirm_and_load(self, invoke_root):
"""Validates a yaml path exists, confirms the user wants to use it and loads config."""
"""Validate a yaml path exists, confirms the user wants to use it and loads config."""
yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
if os.path.exists(yaml_path):
db_dir, outdir = self.load_paths_from_yaml(yaml_path)
@ -337,33 +337,24 @@ class InvokeAIMetadataParser:
def map_scheduler(self, old_scheduler):
"""Convert the legacy sampler names to matching 3.0 schedulers"""
# this was more elegant as a match statement, but that's not available in python 3.9
if old_scheduler is None:
return None
match (old_scheduler):
case "ddim":
return "ddim"
case "plms":
return "pnmd"
case "k_lms":
return "lms"
case "k_dpm_2":
return "kdpm_2"
case "k_dpm_2_a":
return "kdpm_2_a"
case "dpmpp_2":
return "dpmpp_2s"
case "k_dpmpp_2":
return "dpmpp_2m"
case "k_dpmpp_2_a":
return None # invalid; in 2.3.x, selecting this sampler would just fall back to the last run, or plms in a new session
case "k_euler":
return "euler"
case "k_euler_a":
return "euler_a"
case "k_heun":
return "heun"
return None
scheduler_map = dict(
ddim="ddim",
plms="pnmd",
k_lms="lms",
k_dpm_2="kdpm_2",
k_dpm_2_a="kdpm_2_a",
dpmpp_2="dpmpp_2s",
k_dpmpp_2="dpmpp_2m",
k_dpmpp_2_a=None, # invalid; in 2.3.x, selecting this sampler would just fall back to the last run, or plms in a new session
k_euler="euler",
k_euler_a="euler_a",
k_heun="heun",
)
return scheduler_map.get(old_scheduler)
def split_prompt(self, raw_prompt: str):
"""Split the unified prompt strings by extracting all negative prompt blocks out into the negative prompt."""
@ -524,27 +515,27 @@ class MediaImportProcessor:
"5) Create/add to board named 'IMPORT' with a the original file app_version appended (.e.g IMPORT_2.2.5)."
)
input_option = input("Specify desired board option: ")
match (input_option):
case "1":
if len(board_names) < 1:
print("\r\nThere are no existing board names to choose from. Select another option!")
continue
board_name = self.select_item_from_list(
board_names, "board name", True, "Cancel, go back and choose a different board option."
)
if board_name is not None:
# This was more elegant as a match statement, but not supported in python 3.9
if input_option == "1":
if len(board_names) < 1:
print("\r\nThere are no existing board names to choose from. Select another option!")
continue
board_name = self.select_item_from_list(
board_names, "board name", True, "Cancel, go back and choose a different board option."
)
if board_name is not None:
return board_name
elif input_option == "2":
while True:
board_name = input("Specify new/existing board name: ")
if board_name:
return board_name
case "2":
while True:
board_name = input("Specify new/existing board name: ")
if board_name:
return board_name
case "3":
return "IMPORT"
case "4":
return f"IMPORT_{timestamp_string}"
case "5":
return "IMPORT_APPVERSION"
elif input_option == "3":
return "IMPORT"
elif input_option == "4":
return f"IMPORT_{timestamp_string}"
elif input_option == "5":
return "IMPORT_APPVERSION"
def select_item_from_list(self, items, entity_name, allow_cancel, cancel_string):
"""A general function to render a list of items to select in the console, prompt the user for a selection and ensure a valid entry is selected."""


@ -7,5 +7,4 @@ stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*
src/services/api/schema.d.ts


@ -7,8 +7,7 @@ index.html
.yarn/
.yalc/
*.scss
src/services/api/
src/services/fixtures/*
src/services/api/schema.d.ts
docs/
static/
src/theme/css/overlayscrollbars.css



@ -1,4 +1,4 @@
@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-cyrillic-ext-wght-normal-848492d3.woff2) format("woff2-variations");unicode-range:U+0460-052F,U+1C80-1C88,U+20B4,U+2DE0-2DFF,U+A640-A69F,U+FE2E-FE2F}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-cyrillic-wght-normal-262a1054.woff2) format("woff2-variations");unicode-range:U+0301,U+0400-045F,U+0490-0491,U+04B0-04B1,U+2116}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-greek-ext-wght-normal-fe977ddb.woff2) format("woff2-variations");unicode-range:U+1F00-1FFF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-greek-wght-normal-89b4a3fe.woff2) format("woff2-variations");unicode-range:U+0370-03FF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-vietnamese-wght-normal-ac4e131c.woff2) format("woff2-variations");unicode-range:U+0102-0103,U+0110-0111,U+0128-0129,U+0168-0169,U+01A0-01A1,U+01AF-01B0,U+0300-0301,U+0303-0304,U+0308-0309,U+0323,U+0329,U+1EA0-1EF9,U+20AB}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-latin-ext-wght-normal-45606f83.woff2) format("woff2-variations");unicode-range:U+0100-02AF,U+0300-0301,U+0303-0304,U+0308-0309,U+0323,U+0329,U+1E00-1EFF,U+2020,U+20A0-20AB,U+20AD-20CF,U+2113,U+2C60-2C7F,U+A720-A7FF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-latin-wght-normal-450f3ba4.woff2) format("woff2-variations");unicode-range:U+0000-00FF,U+0131,U+0152-0153,U+02BB-02BC,U+02C6,U+02DA,U+02DC,U+0300-0301,U+0303-0304,U+0308-0309,U+0323,U+0329,U+2000-206F,U+2074,U+20AC,U+2122,U+2191,U+2193,U+2212,U+2215,U+FEFF,U+FFFD}/*!
@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-cyrillic-ext-wght-normal-848492d3.woff2) format("woff2-variations");unicode-range:U+0460-052F,U+1C80-1C88,U+20B4,U+2DE0-2DFF,U+A640-A69F,U+FE2E-FE2F}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-cyrillic-wght-normal-262a1054.woff2) format("woff2-variations");unicode-range:U+0301,U+0400-045F,U+0490-0491,U+04B0-04B1,U+2116}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-greek-ext-wght-normal-fe977ddb.woff2) format("woff2-variations");unicode-range:U+1F00-1FFF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-greek-wght-normal-89b4a3fe.woff2) format("woff2-variations");unicode-range:U+0370-03FF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-vietnamese-wght-normal-ac4e131c.woff2) format("woff2-variations");unicode-range:U+0102-0103,U+0110-0111,U+0128-0129,U+0168-0169,U+01A0-01A1,U+01AF-01B0,U+0300-0301,U+0303-0304,U+0308-0309,U+0323,U+0329,U+1EA0-1EF9,U+20AB}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-latin-ext-wght-normal-45606f83.woff2) format("woff2-variations");unicode-range:U+0100-02AF,U+0304,U+0308,U+0329,U+1E00-1E9F,U+1EF2-1EFF,U+2020,U+20A0-20AB,U+20AD-20CF,U+2113,U+2C60-2C7F,U+A720-A7FF}@font-face{font-family:Inter Variable;font-style:normal;font-display:swap;font-weight:100 900;src:url(./inter-latin-wght-normal-450f3ba4.woff2) format("woff2-variations");unicode-range:U+0000-00FF,U+0131,U+0152-0153,U+02BB-02BC,U+02C6,U+02DA,U+02DC,U+0304,U+0308,U+0329,U+2000-206F,U+2074,U+20AC,U+2122,U+2191,U+2193,U+2212,U+2215,U+FEFF,U+FFFD}/*!
* OverlayScrollbars
* Version: 2.2.1
*



@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-2c171c8f.js"></script>
<script type="module" crossorigin src="./assets/index-08cda350.js"></script>
</head>
<body dir="ltr">


@ -19,7 +19,7 @@
"toggleAutoscroll": "Toggle autoscroll",
"toggleLogViewer": "Toggle Log Viewer",
"showGallery": "Show Gallery",
"showOptionsPanel": "Show Options Panel",
"showOptionsPanel": "Show Side Panel",
"menu": "Menu"
},
"common": {
@ -52,7 +52,7 @@
"img2img": "Image To Image",
"unifiedCanvas": "Unified Canvas",
"linear": "Linear",
"nodes": "Node Editor",
"nodes": "Workflow Editor",
"batch": "Batch Manager",
"modelManager": "Model Manager",
"postprocessing": "Post Processing",
@ -95,7 +95,6 @@
"statusModelConverted": "Model Converted",
"statusMergingModels": "Merging Models",
"statusMergedModels": "Models Merged",
"pinOptionsPanel": "Pin Options Panel",
"loading": "Loading",
"loadingInvokeAI": "Loading Invoke AI",
"random": "Random",
@ -116,7 +115,6 @@
"maintainAspectRatio": "Maintain Aspect Ratio",
"autoSwitchNewImages": "Auto-Switch to New Images",
"singleColumnLayout": "Single Column Layout",
"pinGallery": "Pin Gallery",
"allImagesLoaded": "All Images Loaded",
"loadMore": "Load More",
"noImagesInGallery": "No Images to Display",
@ -133,6 +131,7 @@
"generalHotkeys": "General Hotkeys",
"galleryHotkeys": "Gallery Hotkeys",
"unifiedCanvasHotkeys": "Unified Canvas Hotkeys",
"nodesHotkeys": "Nodes Hotkeys",
"invoke": {
"title": "Invoke",
"desc": "Generate an image"
@ -332,6 +331,10 @@
"acceptStagingImage": {
"title": "Accept Staging Image",
"desc": "Accept Current Staging Area Image"
},
"addNodes": {
"title": "Add Nodes",
"desc": "Opens the add node menu"
}
},
"modelManager": {
@ -503,13 +506,15 @@
"hiresStrength": "High Res Strength",
"imageFit": "Fit Initial Image To Output Size",
"codeformerFidelity": "Fidelity",
"compositingSettingsHeader": "Compositing Settings",
"maskAdjustmentsHeader": "Mask Adjustments",
"maskBlur": "Mask Blur",
"maskBlurMethod": "Mask Blur Method",
"seamSize": "Seam Size",
"seamBlur": "Seam Blur",
"seamStrength": "Seam Strength",
"seamSteps": "Seam Steps",
"maskBlur": "Blur",
"maskBlurMethod": "Blur Method",
"coherencePassHeader": "Coherence Pass",
"coherenceSteps": "Steps",
"coherenceStrength": "Strength",
"seamLowThreshold": "Low",
"seamHighThreshold": "High",
"scaleBeforeProcessing": "Scale Before Processing",
"scaledWidth": "Scaled W",
"scaledHeight": "Scaled H",
@ -565,10 +570,11 @@
"useSlidersForAll": "Use Sliders For All Options",
"showProgressInViewer": "Show Progress Images in Viewer",
"antialiasProgressImages": "Antialias Progress Images",
"autoChangeDimensions": "Update W/H To Model Defaults On Change",
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
"resetComplete": "Web UI has been reset. Refresh the page to reload.",
"resetComplete": "Web UI has been reset.",
"consoleLogLevel": "Log Level",
"shouldLogToConsole": "Console Logging",
"developer": "Developer",
@ -708,14 +714,16 @@
"ui": {
"showProgressImages": "Show Progress Images",
"hideProgressImages": "Hide Progress Images",
"swapSizes": "Swap Sizes"
"swapSizes": "Swap Sizes",
"lockRatio": "Lock Ratio"
},
"nodes": {
"reloadSchema": "Reload Schema",
"saveGraph": "Save Graph",
"loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)",
"clearGraph": "Clear Graph",
"clearGraphDesc": "Are you sure you want to clear all nodes?",
"reloadNodeTemplates": "Reload Node Templates",
"downloadWorkflow": "Download Workflow JSON",
"loadWorkflow": "Load Workflow",
"resetWorkflow": "Reset Workflow",
"resetWorkflowDesc": "Are you sure you want to reset this workflow?",
"resetWorkflowDesc2": "Resetting the workflow will clear all nodes, edges and workflow details.",
"zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out",
"fitViewportNodes": "Fit View",


@ -74,6 +74,7 @@
"@nanostores/react": "^0.7.1",
"@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5",
"@stevebel/png": "^1.5.1",
"dateformat": "^5.0.3",
"formik": "^2.4.3",
"framer-motion": "^10.16.1",
@ -110,6 +111,7 @@
"roarr": "^7.15.1",
"serialize-error": "^11.0.1",
"socket.io-client": "^4.7.2",
"type-fest": "^4.2.0",
"use-debounce": "^9.0.4",
"use-image": "^1.1.1",
"uuid": "^9.0.0",


@ -506,12 +506,13 @@
"hiresStrength": "High Res Strength",
"imageFit": "Fit Initial Image To Output Size",
"codeformerFidelity": "Fidelity",
"compositingSettingsHeader": "Compositing Settings",
"maskAdjustmentsHeader": "Mask Adjustments",
"maskBlur": "Mask Blur",
"maskBlurMethod": "Mask Blur Method",
"maskBlur": "Blur",
"maskBlurMethod": "Blur Method",
"coherencePassHeader": "Coherence Pass",
"coherenceSteps": "Coherence Pass Steps",
"coherenceStrength": "Coherence Pass Strength",
"coherenceSteps": "Steps",
"coherenceStrength": "Strength",
"seamLowThreshold": "Low",
"seamHighThreshold": "High",
"scaleBeforeProcessing": "Scale Before Processing",
@ -569,6 +570,7 @@
"useSlidersForAll": "Use Sliders For All Options",
"showProgressInViewer": "Show Progress Images in Viewer",
"antialiasProgressImages": "Antialias Progress Images",
"autoChangeDimensions": "Update W/H To Model Defaults On Change",
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
@ -712,11 +714,12 @@
"ui": {
"showProgressImages": "Show Progress Images",
"hideProgressImages": "Hide Progress Images",
"swapSizes": "Swap Sizes"
"swapSizes": "Swap Sizes",
"lockRatio": "Lock Ratio"
},
"nodes": {
"reloadNodeTemplates": "Reload Node Templates",
"saveWorkflow": "Save Workflow",
"downloadWorkflow": "Download Workflow JSON",
"loadWorkflow": "Load Workflow",
"resetWorkflow": "Reset Workflow",
"resetWorkflowDesc": "Are you sure you want to reset this workflow?",


@ -14,6 +14,7 @@ import i18n from 'i18n';
import { size } from 'lodash-es';
import { ReactNode, memo, useCallback, useEffect } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
import { usePreselectedImage } from '../../features/parameters/hooks/usePreselectedImage';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster';
@ -23,13 +24,22 @@ const DEFAULT_CONFIG = {};
interface Props {
config?: PartialAppConfig;
headerComponent?: ReactNode;
selectedImage?: {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
}
const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
const App = ({
config = DEFAULT_CONFIG,
headerComponent,
selectedImage,
}: Props) => {
const language = useAppSelector(languageSelector);
const logger = useLogger('system');
const dispatch = useAppDispatch();
const { handlePreselectedImage } = usePreselectedImage();
const handleReset = useCallback(() => {
localStorage.clear();
location.reload();
@ -51,6 +61,10 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
dispatch(appStarted());
}, [dispatch]);
useEffect(() => {
handlePreselectedImage(selectedImage);
}, [handlePreselectedImage, selectedImage]);
return (
<ErrorBoundary
onReset={handleReset}


@ -26,6 +26,10 @@ interface Props extends PropsWithChildren {
headerComponent?: ReactNode;
middleware?: Middleware[];
projectId?: string;
selectedImage?: {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
}
const InvokeAIUI = ({
@ -35,6 +39,7 @@ const InvokeAIUI = ({
headerComponent,
middleware,
projectId,
selectedImage,
}: Props) => {
useEffect(() => {
// configure API client token
@ -81,7 +86,11 @@ const InvokeAIUI = ({
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<AppDndContext>
<App config={config} headerComponent={headerComponent} />
<App
config={config}
headerComponent={headerComponent}
selectedImage={selectedImage}
/>
</AppDndContext>
</ThemeLocaleProvider>
</React.Suspense>


@ -15,7 +15,9 @@ import { addDeleteBoardAndImagesFulfilledListener } from './listeners/boardAndIm
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasImageToControlNetListener } from './listeners/canvasImageToControlNet';
import { addCanvasMaskSavedToGalleryListener } from './listeners/canvasMaskSavedToGallery';
import { addCanvasMaskToControlNetListener } from './listeners/canvasMaskToControlNet';
import { addCanvasMergedListener } from './listeners/canvasMerged';
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
@ -41,6 +43,8 @@ import {
addImageUploadedFulfilledListener,
addImageUploadedRejectedListener,
} from './listeners/imageUploaded';
import { addImagesStarredListener } from './listeners/imagesStarred';
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addModelSelectedListener } from './listeners/modelSelected';
import { addModelsLoadedListener } from './listeners/modelsLoaded';
@ -80,8 +84,6 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addImagesStarredListener } from './listeners/imagesStarred';
import { addImagesUnstarredListener } from './listeners/imagesUnstarred';
export const listenerMiddleware = createListenerMiddleware();
@ -137,6 +139,8 @@ addSessionReadyToInvokeListener();
// Canvas actions
addCanvasSavedToGalleryListener();
addCanvasMaskSavedToGalleryListener();
addCanvasImageToControlNetListener();
addCanvasMaskToControlNetListener();
addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener();
addCanvasMergedListener();


@ -0,0 +1,58 @@
import { logger } from 'app/logging/logger';
import { canvasImageToControlNet } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { addToast } from 'features/system/store/systemSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
export const addCanvasImageToControlNetListener = () => {
startAppListening({
actionCreator: canvasImageToControlNet,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const blob = await getBaseLayerBlob(state);
if (!blob) {
log.error('Problem getting base layer blob');
dispatch(
addToast({
title: 'Problem Saving Canvas',
description: 'Unable to export base layer',
status: 'error',
})
);
return;
}
const { autoAddBoardId } = state.gallery;
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'savedCanvas.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
toastOptions: { title: 'Canvas Sent to ControlNet & Assets' },
},
})
).unwrap();
const { image_name } = imageDTO;
dispatch(
controlNetImageChanged({
controlNetId: action.payload.controlNet.controlNetId,
controlImage: image_name,
})
);
},
});
};


@ -0,0 +1,70 @@
import { logger } from 'app/logging/logger';
import { canvasMaskToControlNet } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { addToast } from 'features/system/store/systemSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
export const addCanvasMaskToControlNetListener = () => {
startAppListening({
actionCreator: canvasMaskToControlNet,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState,
state.canvas.boundingBoxCoordinates,
state.canvas.boundingBoxDimensions,
state.canvas.isMaskEnabled,
state.canvas.shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) {
return;
}
const { maskBlob } = canvasBlobsAndImageData;
if (!maskBlob) {
log.error('Problem getting mask layer blob');
dispatch(
addToast({
title: 'Problem Importing Mask',
description: 'Unable to export mask',
status: 'error',
})
);
return;
}
const { autoAddBoardId } = state.gallery;
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([maskBlob], 'canvasMaskImage.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
toastOptions: { title: 'Mask Sent to ControlNet & Assets' },
},
})
).unwrap();
const { image_name } = imageDTO;
dispatch(
controlNetImageChanged({
controlNetId: action.payload.controlNet.controlNetId,
controlImage: image_name,
})
);
},
});
};


@ -1,9 +1,12 @@
import { logger } from 'app/logging/logger';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions';
import {
modelChanged,
setHeight,
setWidth,
vaeSelected,
} from 'features/parameters/store/generationSlice';
import { zMainOrOnnxModel } from 'features/parameters/types/parameterSchemas';
@ -74,6 +77,22 @@ export const addModelSelectedListener = () => {
}
}
// Update Width / Height / Bounding Box Dimensions on Model Change
if (
state.generation.model?.base_model !== newModel.base_model &&
state.ui.shouldAutoChangeDimensions
) {
if (['sdxl', 'sdxl-refiner'].includes(newModel.base_model)) {
dispatch(setWidth(1024));
dispatch(setHeight(1024));
dispatch(setBoundingBoxDimensions({ width: 1024, height: 1024 }));
} else {
dispatch(setWidth(512));
dispatch(setHeight(512));
dispatch(setBoundingBoxDimensions({ width: 512, height: 512 }));
}
}
dispatch(modelChanged(newModel));
},
});


@ -6,11 +6,11 @@ import {
configureStore,
} from '@reduxjs/toolkit';
import canvasReducer from 'features/canvas/store/canvasSlice';
import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
import loraReducer from 'features/lora/store/loraSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import generationReducer from 'features/parameters/store/generationSlice';


@ -86,8 +86,8 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
<Box
sx={{
p: 2,
pt: 3,
p: 4,
pb: 4,
borderBottomRadius: 'base',
bg: 'base.150',
_dark: {


@ -1,10 +1,12 @@
import { Box } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { AnimatePresence, motion } from 'framer-motion';
import {
KeyboardEvent,
ReactNode,
@ -18,8 +20,6 @@ import { useTranslation } from 'react-i18next';
import { useUploadImageMutation } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
import ImageUploadOverlay from './ImageUploadOverlay';
import { AnimatePresence, motion } from 'framer-motion';
import { stateSelector } from 'app/store/store';
const selector = createSelector(
[stateSelector, activeTabNameSelector],


@ -0,0 +1,56 @@
import { Box } from '@chakra-ui/react';
import { memo, useMemo } from 'react';
type Props = {
isSelected: boolean;
isHovered: boolean;
};
const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
const shadow = useMemo(() => {
if (isSelected && isHovered) {
return 'nodeHoveredSelected.light';
}
if (isSelected) {
return 'nodeSelected.light';
}
if (isHovered) {
return 'nodeHovered.light';
}
return undefined;
}, [isHovered, isSelected]);
const shadowDark = useMemo(() => {
if (isSelected && isHovered) {
return 'nodeHoveredSelected.dark';
}
if (isSelected) {
return 'nodeSelected.dark';
}
if (isHovered) {
return 'nodeHovered.dark';
}
return undefined;
}, [isHovered, isSelected]);
return (
<Box
className="selection-box"
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
bottom: 0,
insetInlineStart: 0,
borderRadius: 'base',
opacity: isSelected || isHovered ? 1 : 0.5,
transitionProperty: 'common',
transitionDuration: '0.1s',
pointerEvents: 'none',
shadow,
_dark: {
shadow: shadowDark,
},
}}
/>
);
};
export default memo(SelectionOverlay);


@ -1,4 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { ImageDTO } from 'services/api/types';
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
@ -20,3 +21,11 @@ export const canvasMerged = createAction('canvas/canvasMerged');
export const stagingAreaImageSaved = createAction<{ imageDTO: ImageDTO }>(
'canvas/stagingAreaImageSaved'
);
export const canvasMaskToControlNet = createAction<{
controlNet: ControlNetConfig;
}>('canvas/canvasMaskToControlNet');
export const canvasImageToControlNet = createAction<{
controlNet: ControlNetConfig;
}>('canvas/canvasImageToControlNet');


@ -17,11 +17,13 @@ import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useToggle } from 'react-use';
import { v4 as uuidv4 } from 'uuid';
import ControlNetImagePreview from './ControlNetImagePreview';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
@ -36,6 +38,8 @@ const ControlNet = (props: ControlNetProps) => {
const { controlNetId } = controlNet;
const dispatch = useAppDispatch();
const activeTabName = useAppSelector(activeTabNameSelector);
const selector = createSelector(
stateSelector,
({ controlNet }) => {
@ -108,6 +112,9 @@ const ControlNet = (props: ControlNetProps) => {
>
<ParamControlNetModel controlNet={controlNet} />
</Box>
{activeTabName === 'unifiedCanvas' && (
<ControlNetCanvasImageImports controlNet={controlNet} />
)}
<IAIIconButton
size="sm"
tooltip="Duplicate"
@ -167,6 +174,7 @@ const ControlNet = (props: ControlNetProps) => {
/>
)}
</Flex>
<Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
<Flex


@ -5,13 +5,21 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import { setHeight, setWidth } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useMemo, useState } from 'react';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { FaRulerVertical, FaSave, FaUndo } from 'react-icons/fa';
import {
useAddImageToBoardMutation,
useChangeImageIsIntermediateMutation,
useGetImageDTOQuery,
useRemoveImageFromBoardMutation,
} from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon';
import {
@ -26,11 +34,13 @@ type Props = {
const selector = createSelector(
stateSelector,
({ controlNet }) => {
({ controlNet, gallery }) => {
const { pendingControlImages } = controlNet;
const { autoAddBoardId } = gallery;
return {
pendingControlImages,
autoAddBoardId,
};
},
defaultSelectorOptions
@ -47,7 +57,8 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector);
const { pendingControlImages, autoAddBoardId } = useAppSelector(selector);
const activeTabName = useAppSelector(activeTabNameSelector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
@ -59,9 +70,57 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
processedControlImageName ?? skipToken
);
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();
const [addToBoard] = useAddImageToBoardMutation();
const [removeFromBoard] = useRemoveImageFromBoardMutation();
const handleResetControlImage = useCallback(() => {
dispatch(controlNetImageChanged({ controlNetId, controlImage: null }));
}, [controlNetId, dispatch]);
const handleSaveControlImage = useCallback(async () => {
if (!processedControlImage) {
return;
}
await changeIsIntermediate({
imageDTO: processedControlImage,
is_intermediate: false,
}).unwrap();
if (autoAddBoardId !== 'none') {
addToBoard({
imageDTO: processedControlImage,
board_id: autoAddBoardId,
});
} else {
removeFromBoard({ imageDTO: processedControlImage });
}
}, [
processedControlImage,
changeIsIntermediate,
autoAddBoardId,
addToBoard,
removeFromBoard,
]);
const handleSetControlImageToDimensions = useCallback(() => {
if (!processedControlImage) {
return;
}
if (activeTabName === 'unifiedCanvas') {
dispatch(
setBoundingBoxDimensions({
width: processedControlImage.width,
height: processedControlImage.height,
})
);
} else {
dispatch(setWidth(processedControlImage.width));
dispatch(setHeight(processedControlImage.height));
}
}, [processedControlImage, activeTabName, dispatch]);
const handleMouseEnter = useCallback(() => {
setIsMouseOverImage(true);
}, []);
@ -121,13 +180,7 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
imageDTO={controlImage}
isDropDisabled={shouldShowProcessedImage || !isEnabled}
postUploadAction={postUploadAction}
>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <FaUndo /> : undefined}
tooltip="Reset Control Image"
/>
</IAIDndImage>
/>
<Box
sx={{
@ -148,14 +201,29 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
imageDTO={processedControlImage}
isUploadDisabled={true}
isDropDisabled={!isEnabled}
>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <FaUndo /> : undefined}
tooltip="Reset Control Image"
/>
</IAIDndImage>
/>
</Box>
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <FaUndo /> : undefined}
tooltip="Reset Control Image"
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={controlImage ? <FaSave size={16} /> : undefined}
tooltip="Save Control Image"
styleOverrides={{ marginTop: 6 }}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <FaRulerVertical size={16} /> : undefined}
tooltip="Set Control Image Dimensions To W/H"
styleOverrides={{ marginTop: 12 }}
/>
</>
{pendingControlImages.includes(controlNetId) && (
<Flex
sx={{


@ -0,0 +1,54 @@
import { Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import {
canvasImageToControlNet,
canvasMaskToControlNet,
} from 'features/canvas/store/actions';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { FaImage, FaMask } from 'react-icons/fa';
type ControlNetCanvasImageImportsProps = {
controlNet: ControlNetConfig;
};
const ControlNetCanvasImageImports = (
props: ControlNetCanvasImageImportsProps
) => {
const { controlNet } = props;
const dispatch = useAppDispatch();
const handleImportImageFromCanvas = useCallback(() => {
dispatch(canvasImageToControlNet({ controlNet }));
}, [controlNet, dispatch]);
const handleImportMaskFromCanvas = useCallback(() => {
dispatch(canvasMaskToControlNet({ controlNet }));
}, [controlNet, dispatch]);
return (
<Flex
sx={{
gap: 2,
}}
>
<IAIIconButton
size="sm"
icon={<FaImage />}
tooltip="Import Image From Canvas"
aria-label="Import Image From Canvas"
onClick={handleImportImageFromCanvas}
/>
<IAIIconButton
size="sm"
icon={<FaMask />}
tooltip="Import Mask From Canvas"
aria-label="Import Mask From Canvas"
onClick={handleImportMaskFromCanvas}
/>
</Flex>
);
};
export default memo(ControlNetCanvasImageImports);


@ -4,11 +4,11 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { memo } from 'react';
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
import { memo } from 'react';
const selector = createSelector(
stateSelector,


@ -15,6 +15,7 @@ import { BoardDTO } from 'services/api/types';
import { menuListMotionProps } from 'theme/components/menu';
import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems';
import NoBoardContextMenuItems from './NoBoardContextMenuItems';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
type Props = {
board?: BoardDTO;
@ -33,12 +34,16 @@ const BoardContextMenu = ({
const selector = useMemo(
() =>
createSelector(stateSelector, ({ gallery, system }) => {
const isAutoAdd = gallery.autoAddBoardId === board_id;
const isProcessing = system.isProcessing;
const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick;
return { isAutoAdd, isProcessing, autoAssignBoardOnClick };
}),
createSelector(
stateSelector,
({ gallery, system }) => {
const isAutoAdd = gallery.autoAddBoardId === board_id;
const isProcessing = system.isProcessing;
const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick;
return { isAutoAdd, isProcessing, autoAssignBoardOnClick };
},
defaultSelectorOptions
),
[board_id]
);


@ -9,20 +9,24 @@ import {
MenuButton,
MenuList,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppToaster } from 'app/components/Toaster';
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
setActiveTab,
setShouldShowImageDetails,
setShouldShowProgressInViewer,
} from 'features/ui/store/uiSlice';
@ -37,12 +41,12 @@ import {
FaSeedling,
FaShareAlt,
} from 'react-icons/fa';
import { MdDeviceHub } from 'react-icons/md';
import {
useGetImageDTOQuery,
useGetImageMetadataQuery,
useGetImageMetadataFromFileQuery,
} from 'services/api/endpoints/images';
import { menuListMotionProps } from 'theme/components/menu';
import { useDebounce } from 'use-debounce';
import { sentImageToImg2Img } from '../../store/actions';
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
@ -101,22 +105,36 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
lastSelectedImage,
500
);
const { currentData: imageDTO } = useGetImageDTOQuery(
lastSelectedImage?.image_name ?? skipToken
);
const { currentData: metadataData } = useGetImageMetadataQuery(
debounceState.isPending()
? skipToken
: debouncedMetadataQueryArg?.image_name ?? skipToken
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
lastSelectedImage?.image_name ?? skipToken,
{
selectFromResult: (res) => ({
isLoading: res.isFetching,
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
}
);
const metadata = metadataData?.metadata;
const handleLoadWorkflow = useCallback(() => {
if (!workflow) {
return;
}
dispatch(workflowLoaded(workflow));
dispatch(setActiveTab('nodes'));
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded',
status: 'success',
})
)
);
}, [dispatch, workflow]);
const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(metadata);
@ -153,6 +171,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys('p', handleUsePrompt, [imageDTO]);
useHotkeys('w', handleLoadWorkflow, [workflow]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(imageDTO));
@ -259,22 +279,31 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIIconButton
isLoading={isLoading}
icon={<MdDeviceHub />}
tooltip={`${t('nodes.loadWorkflow')} (W)`}
aria-label={`${t('nodes.loadWorkflow')} (W)`}
isDisabled={!workflow}
onClick={handleLoadWorkflow}
/>
<IAIIconButton
isLoading={isLoading}
icon={<FaQuoteRight />}
tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!metadata?.positive_prompt}
onClick={handleUsePrompt}
/>
<IAIIconButton
isLoading={isLoading}
icon={<FaSeedling />}
tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!metadata?.seed}
onClick={handleUseSeed}
/>
<IAIIconButton
isLoading={isLoading}
icon={<FaAsterisk />}
tooltip={`${t('parameters.useAll')} (A)`}
aria-label={`${t('parameters.useAll')} (A)`}


@ -1,5 +1,4 @@
import { Flex, MenuItem, Text } from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { Flex, MenuItem, Spinner } from '@chakra-ui/react';
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
@ -8,9 +7,12 @@ import {
isModalOpenChanged,
} from 'features/changeBoardModal/store/slice';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react';
@ -26,14 +28,13 @@ import {
FaShare,
FaTrash,
} from 'react-icons/fa';
import { MdStar, MdStarBorder } from 'react-icons/md';
import { MdDeviceHub, MdStar, MdStarBorder } from 'react-icons/md';
import {
useGetImageMetadataQuery,
useGetImageMetadataFromFileQuery,
useStarImagesMutation,
useUnstarImagesMutation,
} from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
type SingleSelectionMenuItemsProps = {
@ -50,15 +51,15 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
imageDTO.image_name,
500
);
const { currentData } = useGetImageMetadataQuery(
debounceState.isPending()
? skipToken
: debouncedMetadataQueryArg ?? skipToken
{
selectFromResult: (res) => ({
isLoading: res.isFetching,
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
}
);
const [starImages] = useStarImagesMutation();
@ -67,8 +68,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard();
const metadata = currentData?.metadata;
const handleDelete = useCallback(() => {
if (!imageDTO) {
return;
@ -99,6 +98,22 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]);
const handleLoadWorkflow = useCallback(() => {
if (!workflow) {
return;
}
dispatch(workflowLoaded(workflow));
dispatch(setActiveTab('nodes'));
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded',
status: 'success',
})
)
);
}, [dispatch, workflow]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(imageDTO));
@ -118,7 +133,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
}, [dispatch, imageDTO, t, toaster]);
const handleUseAllParameters = useCallback(() => {
console.log(metadata);
recallAllParameters(metadata);
}, [metadata, recallAllParameters]);
@ -169,27 +183,34 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
{t('parameters.downloadImage')}
</MenuItem>
<MenuItem
icon={<FaQuoteRight />}
icon={isLoading ? <SpinnerIcon /> : <MdDeviceHub />}
onClickCapture={handleLoadWorkflow}
isDisabled={isLoading || !workflow}
>
{t('nodes.loadWorkflow')}
</MenuItem>
<MenuItem
icon={isLoading ? <SpinnerIcon /> : <FaQuoteRight />}
onClickCapture={handleRecallPrompt}
isDisabled={
metadata?.positive_prompt === undefined &&
metadata?.negative_prompt === undefined
isLoading ||
(metadata?.positive_prompt === undefined &&
metadata?.negative_prompt === undefined)
}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
icon={<FaSeedling />}
icon={isLoading ? <SpinnerIcon /> : <FaSeedling />}
onClickCapture={handleRecallSeed}
isDisabled={metadata?.seed === undefined}
isDisabled={isLoading || metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem
icon={<FaAsterisk />}
icon={isLoading ? <SpinnerIcon /> : <FaAsterisk />}
onClickCapture={handleUseAllParameters}
isDisabled={!metadata}
isDisabled={isLoading || !metadata}
>
{t('parameters.useAll')}
</MenuItem>
@ -228,20 +249,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
>
{t('gallery.deleteImage')}
</MenuItem>
{metadata?.created_by && (
<Flex
sx={{
padding: '5px 10px',
marginTop: '5px',
}}
>
<Text fontSize="xs" fontWeight="bold">
Created by {metadata?.created_by}
</Text>
</Flex>
)}
</>
);
};
export default memo(SingleSelectionMenuItems);
const SpinnerIcon = () => (
<Flex w="14px" alignItems="center" justifyContent="center">
<Spinner size="xs" />
</Flex>
);


@ -39,7 +39,7 @@ const ImageGalleryContent = () => {
const { galleryView } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { isOpen: isBoardListOpen, onToggle: onToggleBoardList } =
useDisclosure();
useDisclosure({ defaultIsOpen: true });
const handleClickImages = useCallback(() => {
dispatch(galleryViewChanged('images'));


@ -8,7 +8,7 @@ import {
ImageDraggableData,
TypesafeDraggableData,
} from 'features/dnd/types';
import { useMultiselect } from 'features/gallery/hooks/useMultiselect.ts';
import { useMultiselect } from 'features/gallery/hooks/useMultiselect';
import { MouseEvent, memo, useCallback, useMemo, useState } from 'react';
import { FaTrash } from 'react-icons/fa';
import { MdStar, MdStarBorder } from 'react-icons/md';


@ -2,7 +2,7 @@ import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
import { isString } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
import { FaCopy, FaSave } from 'react-icons/fa';
import { FaCopy, FaDownload } from 'react-icons/fa';
type Props = {
label: string;
@ -23,7 +23,7 @@ const DataViewer = (props: Props) => {
navigator.clipboard.writeText(dataString);
}, [dataString]);
const handleSave = useCallback(() => {
const handleDownload = useCallback(() => {
const blob = new Blob([dataString]);
const a = document.createElement('a');
a.href = URL.createObjectURL(blob);
@ -73,13 +73,13 @@ const DataViewer = (props: Props) => {
</Box>
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
{withDownload && (
<Tooltip label={`Save ${label} JSON`}>
<Tooltip label={`Download ${label} JSON`}>
<IconButton
aria-label={`Save ${label} JSON`}
icon={<FaSave />}
aria-label={`Download ${label} JSON`}
icon={<FaDownload />}
variant="ghost"
opacity={0.7}
onClick={handleSave}
onClick={handleDownload}
/>
</Tooltip>
)}


@ -1,10 +1,10 @@
import { CoreMetadata } from 'features/nodes/types/types';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useCallback } from 'react';
import { UnsafeImageMetadata } from 'services/api/types';
import ImageMetadataItem from './ImageMetadataItem';
type Props = {
metadata?: UnsafeImageMetadata['metadata'];
metadata?: CoreMetadata;
};
const ImageMetadataActions = (props: Props) => {
@ -94,14 +94,14 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallNegativePrompt}
/>
)}
{metadata.seed !== undefined && (
{metadata.seed !== undefined && metadata.seed !== null && (
<ImageMetadataItem
label="Seed"
value={metadata.seed}
onClick={handleRecallSeed}
/>
)}
{metadata.model !== undefined && (
{metadata.model !== undefined && metadata.model !== null && (
<ImageMetadataItem
label="Model"
value={metadata.model.model_name}
@ -150,7 +150,7 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallSteps}
/>
)}
{metadata.cfg_scale !== undefined && (
{metadata.cfg_scale !== undefined && metadata.cfg_scale !== null && (
<ImageMetadataItem
label="CFG scale"
value={metadata.cfg_scale}


@ -9,14 +9,12 @@ import {
Tabs,
Text,
} from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { useGetImageMetadataFromFileQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
import ImageMetadataActions from './ImageMetadataActions';
import DataViewer from './DataViewer';
import ImageMetadataActions from './ImageMetadataActions';
type ImageMetadataViewerProps = {
image: ImageDTO;
@ -29,19 +27,16 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
// dispatch(setShouldShowImageDetails(false));
// });
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
const { metadata, workflow } = useGetImageMetadataFromFileQuery(
image.image_name,
500
{
selectFromResult: (res) => ({
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
}
);
const { currentData } = useGetImageMetadataQuery(
debounceState.isPending()
? skipToken
: debouncedMetadataQueryArg ?? skipToken
);
const metadata = currentData?.metadata;
const graph = currentData?.graph;
return (
<Flex
layerStyle="first"
@ -71,17 +66,17 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
>
<TabList>
<Tab>Core Metadata</Tab>
<Tab>Metadata</Tab>
<Tab>Image Details</Tab>
<Tab>Graph</Tab>
<Tab>Workflow</Tab>
</TabList>
<TabPanels>
<TabPanel>
{metadata ? (
<DataViewer data={metadata} label="Core Metadata" />
<DataViewer data={metadata} label="Metadata" />
) : (
<IAINoContentFallback label="No core metadata found" />
<IAINoContentFallback label="No metadata found" />
)}
</TabPanel>
<TabPanel>
@ -92,10 +87,10 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
)}
</TabPanel>
<TabPanel>
{graph ? (
<DataViewer data={graph} label="Graph" />
{workflow ? (
<DataViewer data={workflow} label="Workflow" />
) : (
<IAINoContentFallback label="No graph found" />
<IAINoContentFallback label="No workflow found" />
)}
</TabPanel>
</TabPanels>


@ -1,13 +1,15 @@
import { Flex, Image, Text } from '@chakra-ui/react';
import { useState, PropsWithChildren, memo } from 'react';
import { useSelector } from 'react-redux';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { Flex, Image, Text } from '@chakra-ui/react';
import { motion } from 'framer-motion';
import { NodeProps } from 'reactflow';
import NodeWrapper from '../common/NodeWrapper';
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { PropsWithChildren, memo } from 'react';
import { useSelector } from 'react-redux';
import { NodeProps } from 'reactflow';
import NodeWrapper from '../common/NodeWrapper';
import { stateSelector } from 'app/store/store';
const selector = createSelector(stateSelector, ({ system, gallery }) => {
const imageDTO = gallery.selection[gallery.selection.length - 1];
@ -54,44 +56,90 @@ const CurrentImageNode = (props: NodeProps) => {
export default memo(CurrentImageNode);
const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => (
<NodeWrapper
nodeId={props.nodeProps.data.id}
selected={props.nodeProps.selected}
width={384}
>
<Flex
className={DRAG_HANDLE_CLASSNAME}
sx={{
flexDirection: 'column',
}}
const Wrapper = (props: PropsWithChildren<{ nodeProps: NodeProps }>) => {
const [isHovering, setIsHovering] = useState(false);
const handleMouseEnter = () => {
setIsHovering(true);
};
const handleMouseLeave = () => {
setIsHovering(false);
};
return (
<NodeWrapper
nodeId={props.nodeProps.id}
selected={props.nodeProps.selected}
width={384}
>
<Flex
layerStyle="nodeHeader"
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
className={DRAG_HANDLE_CLASSNAME}
sx={{
borderTopRadius: 'base',
alignItems: 'center',
justifyContent: 'center',
h: 8,
position: 'relative',
flexDirection: 'column',
}}
>
<Text
<Flex
layerStyle="nodeHeader"
sx={{
fontSize: 'sm',
fontWeight: 600,
color: 'base.700',
_dark: { color: 'base.200' },
borderTopRadius: 'base',
alignItems: 'center',
justifyContent: 'center',
h: 8,
}}
>
Current Image
</Text>
<Text
sx={{
fontSize: 'sm',
fontWeight: 600,
color: 'base.700',
_dark: { color: 'base.200' },
}}
>
Current Image
</Text>
</Flex>
<Flex
layerStyle="nodeBody"
sx={{
w: 'full',
h: 'full',
borderBottomRadius: 'base',
p: 2,
}}
>
{props.children}
{isHovering && (
<motion.div
key="nextPrevButtons"
initial={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
exit={{
opacity: 0,
transition: { duration: 0.1 },
}}
style={{
position: 'absolute',
top: 40,
left: -2,
right: -2,
bottom: 0,
pointerEvents: 'none',
}}
>
<NextPrevImageButtons />
</motion.div>
)}
</Flex>
</Flex>
<Flex
layerStyle="nodeBody"
sx={{ w: 'full', h: 'full', borderBottomRadius: 'base', p: 2 }}
>
{props.children}
</Flex>
</Flex>
</NodeWrapper>
);
</NodeWrapper>
);
};


@ -0,0 +1,41 @@
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEmbedWorkflow } from 'features/nodes/hooks/useEmbedWorkflow';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { nodeEmbedWorkflowChanged } from 'features/nodes/store/nodesSlice';
import { ChangeEvent, memo, useCallback } from 'react';
const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const embedWorkflow = useEmbedWorkflow(nodeId);
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
nodeEmbedWorkflowChanged({
nodeId,
embedWorkflow: e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Embed Workflow</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChange}
isChecked={embedWorkflow}
/>
</FormControl>
);
};
export default memo(EmbedWorkflowCheckbox);


@ -41,7 +41,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
flexDirection: 'column',
w: 'full',
h: 'full',
py: 1,
py: 2,
gap: 1,
borderBottomRadius: withFooter ? 0 : 'base',
}}


@ -1,16 +1,8 @@
import {
Checkbox,
Flex,
FormControl,
FormLabel,
Spacer,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { useIsIntermediate } from 'features/nodes/hooks/useIsIntermediate';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { Flex } from '@chakra-ui/react';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { ChangeEvent, memo, useCallback } from 'react';
import { memo } from 'react';
import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox';
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
type Props = {
nodeId: string;
@ -27,48 +19,13 @@ const InvocationNodeFooter = ({ nodeId }: Props) => {
px: 2,
py: 0,
h: 6,
justifyContent: 'space-between',
}}
>
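{/* Both footer checkboxes render only for image-producing nodes */}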
<Spacer />
<SaveImageCheckbox nodeId={nodeId} />
<EmbedWorkflowCheckbox nodeId={nodeId} />
<SaveToGalleryCheckbox nodeId={nodeId} />
</Flex>
);
};
export default memo(InvocationNodeFooter);
const SaveImageCheckbox = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const is_intermediate = useIsIntermediate(nodeId);
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId,
fieldName: 'is_intermediate',
value: !e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChangeIsIntermediate}
isChecked={!is_intermediate}
/>
</FormControl>
);
});
SaveImageCheckbox.displayName = 'SaveImageCheckbox';

View File

@ -1,7 +1,5 @@
import {
Flex,
FormControl,
FormLabel,
Icon,
Modal,
ModalBody,
@ -14,16 +12,14 @@ import {
Tooltip,
useDisclosure,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { isInvocationNodeData } from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { memo, useMemo } from 'react';
import { FaInfoCircle } from 'react-icons/fa';
import NotesTextarea from './NotesTextarea';
interface Props {
nodeId: string;
@ -80,13 +76,29 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
const data = useNodeData(nodeId);
const nodeTemplate = useNodeTemplate(nodeId);
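// Compose the tooltip title from the user-assigned label and/or the template
// title, e.g. "My Label (Resize Image)" when both are set.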
const title = useMemo(() => {
if (data?.label && nodeTemplate?.title) {
return `${data.label} (${nodeTemplate.title})`;
}
if (data?.label && !nodeTemplate) {
return data.label;
}
if (!data?.label && nodeTemplate) {
return nodeTemplate.title;
}
return 'Unknown Node';
}, [data, nodeTemplate]);
if (!isInvocationNodeData(data)) {
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
}
return (
<Flex sx={{ flexDir: 'column' }}>
<Text sx={{ fontWeight: 600 }}>{nodeTemplate?.title}</Text>
<Text sx={{ fontWeight: 600 }}>{title}</Text>
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
{nodeTemplate?.description}
</Text>
@ -96,29 +108,3 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
});
TooltipContent.displayName = 'TooltipContent';
const NotesTextarea = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const data = useNodeData(nodeId);
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
},
[dispatch, nodeId]
);
if (!isInvocationNodeData(data)) {
return null;
}
return (
<FormControl>
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data?.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
);
});
NotesTextarea.displayName = 'NodesTextarea';

View File

@ -0,0 +1,33 @@
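// Extracted from the notes modal into its own file: a controlled textarea that
// dispatches nodeNotesChanged on every edit.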
import { FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAITextarea from 'common/components/IAITextarea';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { isInvocationNodeData } from 'features/nodes/types/types';
import { ChangeEvent, memo, useCallback } from 'react';
const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const data = useNodeData(nodeId);
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));
},
[dispatch, nodeId]
);
if (!isInvocationNodeData(data)) {
return null;
}
return (
<FormControl>
<FormLabel>Notes</FormLabel>
<IAITextarea
value={data?.notes}
onChange={handleNotesChanged}
rows={10}
/>
</FormControl>
);
};
export default memo(NotesTextarea);

View File

@ -0,0 +1,41 @@
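// Replaces the old inline SaveImageCheckbox from InvocationNodeFooter, now as
// a standalone component dispatching a dedicated nodeIsIntermediateChanged
// action instead of a generic boolean field update.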
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { useIsIntermediate } from 'features/nodes/hooks/useIsIntermediate';
import { nodeIsIntermediateChanged } from 'features/nodes/store/nodesSlice';
import { ChangeEvent, memo, useCallback } from 'react';
const SaveToGalleryCheckbox = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const isIntermediate = useIsIntermediate(nodeId);
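// Note the inverted polarity: the box is checked when the image is NOT
// intermediate, i.e. when it will be saved to the gallery.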
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
nodeIsIntermediateChanged({
nodeId,
isIntermediate: !e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save to Gallery</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChange}
isChecked={!isIntermediate}
/>
</FormControl>
);
};
export default memo(SaveToGalleryCheckbox);

View File

@ -0,0 +1,167 @@
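// New component: the renameable variant of a field title. FieldTitle (edited
// below) becomes display-only; this one wraps Chakra's Editable so the user
// can relabel a field in place, mirroring the label in local state.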
import {
Editable,
EditableInput,
EditablePreview,
Flex,
Tooltip,
forwardRef,
useEditableControls,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
import FieldTooltipContent from './FieldTooltipContent';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
interface Props {
nodeId: string;
fieldName: string;
kind: 'input' | 'output';
isMissingInput?: boolean;
withTooltip?: boolean;
}
const EditableFieldTitle = forwardRef((props: Props, ref) => {
const {
nodeId,
fieldName,
kind,
isMissingInput = false,
withTooltip = false,
} = props;
const label = useFieldLabel(nodeId, fieldName);
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind);
const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(
label || fieldTemplateTitle || 'Unknown Field'
);
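// Ignore no-op renames; if the input is cleared, fall back to the template
// title (or a placeholder) locally, then persist the label to the store.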
const handleSubmit = useCallback(
async (newTitle: string) => {
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
return;
}
setLocalTitle(newTitle || fieldTemplateTitle || 'Unknown Field');
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
},
[label, fieldTemplateTitle, dispatch, nodeId, fieldName]
);
const handleChange = useCallback((newTitle: string) => {
setLocalTitle(newTitle);
}, []);
useEffect(() => {
// Another component may change the title; sync local title with global state
setLocalTitle(label || fieldTemplateTitle || 'Unknown Field');
}, [label, fieldTemplateTitle]);
return (
<Tooltip
label={
withTooltip ? (
<FieldTooltipContent
nodeId={nodeId}
fieldName={fieldName}
kind="input"
/>
) : undefined
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
placement="top"
hasArrow
>
<Flex
ref={ref}
sx={{
position: 'relative',
overflow: 'hidden',
alignItems: 'center',
justifyContent: 'flex-start',
gap: 1,
h: 'full',
}}
>
<Editable
value={localTitle}
onChange={handleChange}
onSubmit={handleSubmit}
as={Flex}
sx={{
position: 'relative',
alignItems: 'center',
h: 'full',
}}
>
<EditablePreview
sx={{
p: 0,
fontWeight: isMissingInput ? 600 : 400,
textAlign: 'left',
_hover: {
fontWeight: '600 !important',
},
}}
noOfLines={1}
/>
<EditableInput
className="nodrag"
sx={{
p: 0,
w: 'full',
fontWeight: 600,
color: 'base.900',
_dark: {
color: 'base.100',
},
_focusVisible: {
p: 0,
textAlign: 'left',
boxShadow: 'none',
},
}}
/>
<EditableControls />
</Editable>
</Flex>
</Tooltip>
);
});
export default memo(EditableFieldTitle);
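// An invisible, full-size click target: clicking anywhere on the title preview
// switches the Editable into edit mode via getEditButtonProps.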
const EditableControls = memo(() => {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
const { onClick } = getEditButtonProps();
if (!onClick) {
return;
}
onClick(e);
e.preventDefault();
},
[getEditButtonProps]
);
if (isEditing) {
return null;
}
return (
<Flex
onClick={handleClick}
position="absolute"
w="full"
h="full"
top={0}
insetInlineStart={0}
cursor="text"
/>
);
});
EditableControls.displayName = 'EditableControls';

View File

@ -1,16 +1,7 @@
import {
Editable,
EditableInput,
EditablePreview,
Flex,
forwardRef,
useEditableControls,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { Flex, Text, forwardRef } from '@chakra-ui/react';
import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel';
import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import { MouseEvent, memo, useCallback, useEffect, useState } from 'react';
import { memo } from 'react';
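// FieldTitle is now read-only: the in-place rename behavior lives in
// EditableFieldTitle, leaving this as a plain Text of the label or template title.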
interface Props {
nodeId: string;
@ -24,31 +15,6 @@ const FieldTitle = forwardRef((props: Props, ref) => {
const label = useFieldLabel(nodeId, fieldName);
const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, kind);
const dispatch = useAppDispatch();
const [localTitle, setLocalTitle] = useState(
label || fieldTemplateTitle || 'Unknown Field'
);
const handleSubmit = useCallback(
async (newTitle: string) => {
if (newTitle && (newTitle === label || newTitle === fieldTemplateTitle)) {
return;
}
setLocalTitle(newTitle || fieldTemplateTitle || 'Unknown Field');
dispatch(fieldLabelChanged({ nodeId, fieldName, label: newTitle }));
},
[label, fieldTemplateTitle, dispatch, nodeId, fieldName]
);
const handleChange = useCallback((newTitle: string) => {
setLocalTitle(newTitle);
}, []);
useEffect(() => {
// Another component may change the title; sync local title with global state
setLocalTitle(label || fieldTemplateTitle || 'Unknown Field');
}, [label, fieldTemplateTitle]);
return (
<Flex
ref={ref}
@ -62,82 +28,11 @@ const FieldTitle = forwardRef((props: Props, ref) => {
w: 'full',
}}
>
<Editable
value={localTitle}
onChange={handleChange}
onSubmit={handleSubmit}
as={Flex}
sx={{
position: 'relative',
alignItems: 'center',
h: 'full',
w: 'full',
}}
>
<EditablePreview
sx={{
p: 0,
fontWeight: isMissingInput ? 600 : 400,
textAlign: 'left',
_hover: {
fontWeight: '600 !important',
},
}}
noOfLines={1}
/>
<EditableInput
className="nodrag"
sx={{
p: 0,
fontWeight: 600,
color: 'base.900',
_dark: {
color: 'base.100',
},
_focusVisible: {
p: 0,
textAlign: 'left',
boxShadow: 'none',
},
}}
/>
<EditableControls />
</Editable>
<Text sx={{ fontWeight: isMissingInput ? 600 : 400 }}>
{label || fieldTemplateTitle}
</Text>
</Flex>
);
});
export default memo(FieldTitle);
const EditableControls = memo(() => {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
const { onClick } = getEditButtonProps();
if (!onClick) {
return;
}
onClick(e);
e.preventDefault();
},
[getEditButtonProps]
);
if (isEditing) {
return null;
}
return (
<Flex
onClick={handleClick}
position="absolute"
w="full"
h="full"
top={0}
insetInlineStart={0}
cursor="text"
/>
);
});
EditableControls.displayName = 'EditableControls';

View File

@ -34,6 +34,8 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => {
}
return 'Unknown Field';
} else {
return fieldTemplate?.title || 'Unknown Field';
}
}, [field, fieldTemplate]);

View File

@ -1,16 +1,11 @@
import { Box, Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import SelectionOverlay from 'common/components/SelectionOverlay';
import { Box, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { useDoesInputHaveValue } from 'features/nodes/hooks/useDoesInputHaveValue';
import { useFieldInputKind } from 'features/nodes/hooks/useFieldInputKind';
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { useIsMouseOverField } from 'features/nodes/hooks/useIsMouseOverField';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { PropsWithChildren, memo, useMemo } from 'react';
import EditableFieldTitle from './EditableFieldTitle';
import FieldContextMenu from './FieldContextMenu';
import FieldHandle from './FieldHandle';
import FieldTitle from './FieldTitle';
import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer';
interface Props {
@ -21,7 +16,6 @@ interface Props {
const InputField = ({ nodeId, fieldName }: Props) => {
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName);
const input = useFieldInputKind(nodeId, fieldName);
const {
isConnected,
@ -51,11 +45,7 @@ const InputField = ({ nodeId, fieldName }: Props) => {
if (fieldTemplate?.fieldKind !== 'input') {
return (
<InputFieldWrapper
nodeId={nodeId}
fieldName={fieldName}
shouldDim={shouldDim}
>
<InputFieldWrapper shouldDim={shouldDim}>
<FormControl
sx={{ color: 'error.400', textAlign: 'left', fontSize: 'sm' }}
>
@ -66,19 +56,14 @@ const InputField = ({ nodeId, fieldName }: Props) => {
}
return (
<InputFieldWrapper
nodeId={nodeId}
fieldName={fieldName}
shouldDim={shouldDim}
>
<InputFieldWrapper shouldDim={shouldDim}>
<FormControl
as={Flex}
isInvalid={isMissingInput}
isDisabled={isConnected}
sx={{
alignItems: 'stretch',
justifyContent: 'space-between',
ps: 2,
ps: fieldTemplate.input === 'direct' ? 0 : 2,
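// Assumption: 'direct' inputs render no connection handle, so they skip the
// start padding that connectable fields need.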
gap: 2,
h: 'full',
w: 'full',
@ -86,42 +71,27 @@ const InputField = ({ nodeId, fieldName }: Props) => {
>
<FieldContextMenu nodeId={nodeId} fieldName={fieldName} kind="input">
{(ref) => (
<Tooltip
label={
<FieldTooltipContent
nodeId={nodeId}
fieldName={fieldName}
kind="input"
/>
}
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
placement="top"
hasArrow
<FormLabel
sx={{
display: 'flex',
alignItems: 'center',
mb: 0,
px: 1,
gap: 2,
}}
>
<FormLabel
sx={{
mb: 0,
width: input === 'connection' ? 'auto' : '25%',
flexShrink: 0,
flexGrow: 0,
}}
>
<FieldTitle
ref={ref}
nodeId={nodeId}
fieldName={fieldName}
kind="input"
isMissingInput={isMissingInput}
/>
</FormLabel>
</Tooltip>
<EditableFieldTitle
ref={ref}
nodeId={nodeId}
fieldName={fieldName}
kind="input"
isMissingInput={isMissingInput}
withTooltip
/>
</FormLabel>
)}
</FieldContextMenu>
<Box
sx={{
width: input === 'connection' ? 'auto' : '75%',
}}
>
<Box>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
</Box>
</FormControl>
@ -143,19 +113,12 @@ export default memo(InputField);
type InputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean;
nodeId: string;
fieldName: string;
}>;
const InputFieldWrapper = memo(
({ shouldDim, nodeId, fieldName, children }: InputFieldWrapperProps) => {
const { isMouseOverField, handleMouseOver, handleMouseOut } =
useIsMouseOverField(nodeId, fieldName);
({ shouldDim, children }: InputFieldWrapperProps) => {
return (
<Flex
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
sx={{
position: 'relative',
minH: 8,
@ -169,7 +132,6 @@ const InputFieldWrapper = memo(
}}
>
{children}
<SelectionOverlay isSelected={false} isHovered={isMouseOverField} />
</Flex>
);
}

View File

@ -10,6 +10,7 @@ import ColorInputField from './inputs/ColorInputField';
import ConditioningInputField from './inputs/ConditioningInputField';
import ControlInputField from './inputs/ControlInputField';
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
import DenoiseMaskInputField from './inputs/DenoiseMaskInputField';
import EnumInputField from './inputs/EnumInputField';
import ImageCollectionInputField from './inputs/ImageCollectionInputField';
import ImageInputField from './inputs/ImageInputField';
@ -105,6 +106,19 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
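// New branch for the DenoiseMaskField type introduced in this commit; the
// component it renders is currently a stub (see DenoiseMaskInputField below).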
if (
field?.type === 'DenoiseMaskField' &&
fieldTemplate?.type === 'DenoiseMaskField'
) {
return (
<DenoiseMaskInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (
field?.type === 'ConditioningField' &&
fieldTemplate?.type === 'ConditioningField'

View File

@ -1,13 +1,20 @@
import { Flex, FormControl, FormLabel, Icon, Tooltip } from '@chakra-ui/react';
import {
Flex,
FormControl,
FormLabel,
Icon,
Spacer,
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import SelectionOverlay from 'common/components/SelectionOverlay';
import { useIsMouseOverField } from 'features/nodes/hooks/useIsMouseOverField';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { workflowExposedFieldRemoved } from 'features/nodes/store/nodesSlice';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { memo, useCallback } from 'react';
import { FaInfoCircle, FaTrash } from 'react-icons/fa';
import FieldTitle from './FieldTitle';
import EditableFieldTitle from './EditableFieldTitle';
import FieldTooltipContent from './FieldTooltipContent';
import InputFieldRenderer from './InputFieldRenderer';
@ -18,8 +25,8 @@ type Props = {
const LinearViewField = ({ nodeId, fieldName }: Props) => {
const dispatch = useAppDispatch();
const { isMouseOverField, handleMouseOut, handleMouseOver } =
useIsMouseOverField(nodeId, fieldName);
const { isMouseOverNode, handleMouseOut, handleMouseOver } =
useMouseOverNode(nodeId);
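// Hover tracking moves from the individual field (useIsMouseOverField) to the
// node as a whole (useMouseOverNode); the selection overlay below follows it.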
const handleRemoveField = useCallback(() => {
dispatch(workflowExposedFieldRemoved({ nodeId, fieldName }));
@ -27,8 +34,8 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
return (
<Flex
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
onMouseEnter={handleMouseOver}
onMouseLeave={handleMouseOut}
layerStyle="second"
sx={{
position: 'relative',
@ -42,11 +49,15 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
sx={{
display: 'flex',
alignItems: 'center',
justifyContent: 'space-between',
mb: 0,
}}
>
<FieldTitle nodeId={nodeId} fieldName={fieldName} kind="input" />
<EditableFieldTitle
nodeId={nodeId}
fieldName={fieldName}
kind="input"
/>
<Spacer />
<Tooltip
label={
<FieldTooltipContent
@ -74,7 +85,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => {
</FormLabel>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
</FormControl>
<SelectionOverlay isSelected={false} isHovered={isMouseOverField} />
<NodeSelectionOverlay isSelected={false} isHovered={isMouseOverNode} />
</Flex>
);
};

View File

@ -92,6 +92,7 @@ const ControlNetModelInputFieldComponent = (
error={!selectedModel}
data={data}
onChange={handleValueChanged}
sx={{ width: '100%' }}
/>
);
};

View File

@ -0,0 +1,17 @@
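// Stub component for the new DenoiseMaskField type: it renders nothing,
// presumably because denoise masks are provided via connections rather than
// direct input (assumption).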
import {
DenoiseMaskInputFieldTemplate,
DenoiseMaskInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const DenoiseMaskInputFieldComponent = (
_props: FieldComponentProps<
DenoiseMaskInputFieldValue,
DenoiseMaskInputFieldTemplate
>
) => {
return null;
};
export default memo(DenoiseMaskInputFieldComponent);

View File

@ -101,8 +101,10 @@ const LoRAModelInputFieldComponent = (
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())
}
error={!selectedLoRAModel}
onChange={handleChange}
sx={{
width: '100%',
'.mantine-Select-dropdown': {
width: '16rem !important',
},

View File

@ -134,6 +134,7 @@ const MainModelInputFieldComponent = (
disabled={data.length === 0}
onChange={handleChangeModel}
sx={{
width: '100%',
'.mantine-Select-dropdown': {
width: '16rem !important',
},

View File

@ -1,12 +1,12 @@
import { Box, Flex } from '@chakra-ui/react';
import { Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
FieldComponentProps,
SDXLRefinerModelInputFieldTemplate,
SDXLRefinerModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
@ -101,20 +101,17 @@ const RefinerModelInputFieldComponent = (
value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data}
error={data.length === 0}
error={!selectedModel}
disabled={data.length === 0}
onChange={handleChangeModel}
sx={{
width: '100%',
'.mantine-Select-dropdown': {
width: '16rem !important',
},
}}
/>
{isSyncModelEnabled && (
<Box mt={7}>
<SyncModelsButton className="nodrag" iconMode />
</Box>
)}
{isSyncModelEnabled && <SyncModelsButton className="nodrag" iconMode />}
</Flex>
);
};

View File

@ -128,10 +128,11 @@ const ModelInputFieldComponent = (
value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data}
error={data.length === 0}
error={!selectedModel}
disabled={data.length === 0}
onChange={handleChangeModel}
sx={{
width: '100%',
'.mantine-Select-dropdown': {
width: '16rem !important',
},

Some files were not shown because too many files have changed in this diff.