mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge remote-tracking branch 'origin/main' into feat/workflow-saving
commit
81d2d5abae
@ -8,7 +8,7 @@ To use a node, add the node to the `nodes` folder found in your InvokeAI install
|
|||||||
|
|
||||||
The suggested method is to use `git clone` to clone the repository the node is found in. This allows for easy updates of the node in the future.
|
The suggested method is to use `git clone` to clone the repository the node is found in. This allows for easy updates of the node in the future.
|
||||||
|
|
||||||
If you'd prefer, you can also just download the `.py` file from the linked repository and add it to the `nodes` folder.
|
If you'd prefer, you can also just download the whole node folder from the linked repository and add it to the `nodes` folder.
|
||||||
|
|
||||||
To use a community workflow, download the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
|
To use a community workflow, download the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
|
||||||
|
|
||||||
@ -26,6 +26,7 @@ To use a community workflow, download the the `.json` node graph file and load i
|
|||||||
+ [Image Picker](#image-picker)
|
+ [Image Picker](#image-picker)
|
||||||
+ [Load Video Frame](#load-video-frame)
|
+ [Load Video Frame](#load-video-frame)
|
||||||
+ [Make 3D](#make-3d)
|
+ [Make 3D](#make-3d)
|
||||||
|
+ [Match Histogram](#match-histogram)
|
||||||
+ [Oobabooga](#oobabooga)
|
+ [Oobabooga](#oobabooga)
|
||||||
+ [Prompt Tools](#prompt-tools)
|
+ [Prompt Tools](#prompt-tools)
|
||||||
+ [Remote Image](#remote-image)
|
+ [Remote Image](#remote-image)
|
||||||
@ -208,6 +209,23 @@ This includes 15 Nodes:
|
|||||||
<img src="https://gitlab.com/srcrr/shift3d/-/raw/main/example-1.png" width="300" />
|
<img src="https://gitlab.com/srcrr/shift3d/-/raw/main/example-1.png" width="300" />
|
||||||
<img src="https://gitlab.com/srcrr/shift3d/-/raw/main/example-2.png" width="300" />
|
<img src="https://gitlab.com/srcrr/shift3d/-/raw/main/example-2.png" width="300" />
|
||||||
|
|
||||||
|
--------------------------------
|
||||||
|
### Match Histogram
|
||||||
|
|
||||||
|
**Description:** An InvokeAI node to match a histogram from one image to another. This is similar to the `color correct` node in the main InvokeAI, but it works in the YCbCr colourspace and can handle images of different sizes. It also does not require a mask input.
|
||||||
|
- Option to only transfer luminance channel.
|
||||||
|
- Option to save output as grayscale
|
||||||
|
|
||||||
|
A good use case for this node is to normalize the colors of an image that has been through the tiled scaling workflow of my XYGrid Nodes.
|
||||||
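To make the idea concrete, here is a minimal pure-Python sketch of histogram matching on a single 8-bit channel. It is not the node's actual implementation: the function name `match_channel` and the CDF-lookup approach are assumptions chosen for illustration. A real node would presumably convert both images to YCbCr first and apply something like this per channel, or to the luminance channel only.

```python
from bisect import bisect_left
from itertools import accumulate


def match_channel(source, reference, levels=256):
    """Remap `source` values so their distribution approximates `reference`.

    Both inputs are flat sequences of integers in [0, levels). They may have
    different lengths, which is why images of different sizes are not a problem.
    """

    def cdf(values):
        # Normalised cumulative histogram: cdf[v] = fraction of values <= v.
        hist = [0] * levels
        for v in values:
            hist[v] += 1
        total = len(values)
        return [count / total for count in accumulate(hist)]

    src_cdf = cdf(source)
    ref_cdf = cdf(reference)

    # For each source level, pick the smallest reference level whose CDF
    # reaches the source CDF at that level.
    lookup = [min(bisect_left(ref_cdf, p), levels - 1) for p in src_cdf]
    return [lookup[v] for v in source]


# Hypothetical toy data: a dark 4-pixel channel matched to a brighter 2-pixel one.
matched = match_channel([0, 0, 1, 1], [10, 20])
```

Because only the normalised cumulative histograms are compared, the source and reference never need to share dimensions.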
|
|
||||||
|
See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/main/README.md
|
||||||
|
|
||||||
|
**Node Link:** https://github.com/skunkworxdark/match_histogram
|
||||||
|
|
||||||
|
**Output Examples**
|
||||||
|
|
||||||
|
<img src="https://github.com/skunkworxdark/match_histogram/assets/21961335/ed12f329-a0ef-444a-9bae-129ed60d6097" width="300" />
|
||||||
|
|
||||||
--------------------------------
|
--------------------------------
|
||||||
### Oobabooga
|
### Oobabooga
|
||||||
|
|
||||||
@ -237,22 +255,30 @@ This node works best with SDXL models, especially as the style can be described
|
|||||||
--------------------------------
|
--------------------------------
|
||||||
### Prompt Tools
|
### Prompt Tools
|
||||||
|
|
||||||
**Description:** A set of InvokeAI nodes that add general prompt manipulation tools. These were written to accompany the PromptsFromFile node and other prompt generation nodes.
|
**Description:** A set of InvokeAI nodes that add general prompt (string) manipulation tools. Designed to accompany the `Prompts From File` node and other prompt generation nodes.
|
||||||
|
|
||||||
|
1. `Prompt To File` - Saves a prompt or collection of prompts to a file, one per line. There is an append/overwrite option.
|
||||||
|
2. `PTFields Collect` - Converts image generation fields into a JSON-format string that can be passed to `Prompt To File`.
|
||||||
|
3. `PTFields Expand` - Takes a JSON string and converts it to individual generation parameters. This can be fed from the Prompts to file node.
|
||||||
|
4. `Prompt Strength` - Formats prompt with strength like the weighted format of compel
|
||||||
|
5. `Prompt Strength Combine` - Combines weighted prompts for .and()/.blend()
|
||||||
|
6. `CSV To Index String` - Gets a string from a CSV by index. Includes a Random index option
|
||||||
|
|
||||||
|
The following nodes are now included in v3.2 of Invoke and are no longer in this set of tools.<br>
|
||||||
|
- `Prompt Join` -> `String Join`
|
||||||
|
- `Prompt Join Three` -> `String Join Three`
|
||||||
|
- `Prompt Replace` -> `String Replace`
|
||||||
|
- `Prompt Split Neg` -> `String Split Neg`
|
||||||
|
|
||||||
1. PromptJoin - Joins two prompts into one.
|
|
||||||
2. PromptReplace - performs a search and replace on a prompt. With the option of using regex.
|
|
||||||
3. PromptSplitNeg - splits a prompt into positive and negative using the old V2 method of [] for negative.
|
|
||||||
4. PromptToFile - saves a prompt or collection of prompts to a file. one per line. There is an append/overwrite option.
|
|
||||||
5. PTFieldsCollect - Converts image generation fields into a Json format string that can be passed to Prompt to file.
|
|
||||||
6. PTFieldsExpand - Takes Json string and converts it to individual generation parameters This can be fed from the Prompts to file node.
|
|
||||||
7. PromptJoinThree - Joins 3 prompts together.
|
|
||||||
8. PromptStrength - This takes a string and a float and outputs another string in the format of (string)strength, like the weighted format of compel.
|
|
||||||
9. PromptStrengthCombine - This takes a collection of prompt strength strings and outputs a string in the .and() or .blend() format that can be fed into a proper prompt node.
|
|
||||||
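The `PromptStrength` and `PromptSplitNeg` entries above describe simple string transformations. The following is a rough sketch of what they amount to, based only on those one-line descriptions. The function names, the exact output formatting, and the regular expression are assumptions, not the nodes' real code.

```python
import re


def prompt_strength(prompt: str, strength: float) -> str:
    """Format a prompt in the `(string)strength` style described above."""
    return f"({prompt}){strength}"


def split_neg(prompt: str) -> tuple[str, str]:
    """Split a prompt into (positive, negative) using `[...]` for negatives.

    Assumption: brackets are not nested, and multiple negative fragments are
    joined with a single space.
    """
    negatives = re.findall(r"\[([^\[\]]*)\]", prompt)
    positive = re.sub(r"\[[^\[\]]*\]", "", prompt)
    # Collapse the whitespace left behind by the removed bracket groups.
    positive = " ".join(positive.split())
    negative = " ".join(n.strip() for n in negatives)
    return positive, negative


weighted = prompt_strength("a cat", 1.2)
positive, negative = split_neg("a castle [blurry] at dusk [low quality]")
```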
|
|
||||||
See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/main/README.md
|
See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/main/README.md
|
||||||
|
|
||||||
**Node Link:** https://github.com/skunkworxdark/Prompt-tools-nodes
|
**Node Link:** https://github.com/skunkworxdark/Prompt-tools-nodes
|
||||||
|
|
||||||
|
**Workflow Examples**
|
||||||
|
|
||||||
|
<img src="https://github.com/skunkworxdark/prompt-tools/blob/main/images/CSVToIndexStringNode.png" width="300" />
|
||||||
|
|
||||||
--------------------------------
|
--------------------------------
|
||||||
### Remote Image
|
### Remote Image
|
||||||
|
|
||||||
@ -339,15 +365,27 @@ Highlights/Midtones/Shadows (with LUT blur enabled):
|
|||||||
--------------------------------
|
--------------------------------
|
||||||
### XY Image to Grid and Images to Grids nodes
|
### XY Image to Grid and Images to Grids nodes
|
||||||
|
|
||||||
**Description:** Image to grid nodes and supporting tools.
|
**Description:** These nodes add the following to InvokeAI:
|
||||||
|
- Generate grids of images from multiple input images
|
||||||
|
- Create XY grid images with labels from parameters
|
||||||
|
- Split images into overlapping tiles for processing (for super-resolution workflows)
|
||||||
|
- Recombine image tiles into a single output image blending the seams
|
||||||
|
|
||||||
1. "Images To Grids" node - Takes a collection of images and creates a grid(s) of images. If there are more images than the size of a single grid then multiple grids will be created until it runs out of images.
|
The nodes include:
|
||||||
2. "XYImage To Grid" node - Converts a collection of XYImages into a labeled Grid of images. The XYImages collection has to be built using the supporting nodes. See example node setups for more details.
|
1. `Images To Grids` - Combine multiple images into a grid of images
|
||||||
|
2. `XYImage To Grid` - Takes X & Y params and creates a labeled image grid.
|
||||||
|
3. `XYImage Tiles` - Super-resolution (embiggen) style tiled resizing
|
||||||
|
4. `Image To XYImages` - Takes an image and cuts it up into a number of columns and rows.
|
||||||
|
5. Multiple supporting nodes - Helper nodes for data wrangling and building `XYImage` collections
|
||||||
|
|
||||||
See full docs here: https://github.com/skunkworxdark/XYGrid_nodes/edit/main/README.md
|
See full docs here: https://github.com/skunkworxdark/XYGrid_nodes/edit/main/README.md
|
||||||
|
|
||||||
**Node Link:** https://github.com/skunkworxdark/XYGrid_nodes
|
**Node Link:** https://github.com/skunkworxdark/XYGrid_nodes
|
||||||
|
|
||||||
|
**Output Examples**
|
||||||
|
|
||||||
|
<img src="https://github.com/skunkworxdark/XYGrid_nodes/blob/main/images/collage.png" width="300" />
|
||||||
|
|
||||||
--------------------------------
|
--------------------------------
|
||||||
### Example Node Template
|
### Example Node Template
|
||||||
|
|
||||||
|
@ -1,13 +1,15 @@
|
|||||||
# List of Default Nodes
|
# List of Default Nodes
|
||||||
|
|
||||||
The table below contains a list of the default nodes shipped with InvokeAI and their descriptions.
|
The table below contains a list of the default nodes shipped with InvokeAI and
|
||||||
|
their descriptions.
|
||||||
|
|
||||||
| Node <img width=160 align="right"> | Function |
|
| Node <img width=160 align="right"> | Function |
|
||||||
|: ---------------------------------- | :--------------------------------------------------------------------------------------|
|
| :------------------------------------------------------------ | :--------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| Add Integers | Adds two numbers |
|
| Add Integers | Adds two numbers |
|
||||||
| Boolean Primitive Collection | A collection of boolean primitive values |
|
| Boolean Primitive Collection | A collection of boolean primitive values |
|
||||||
| Boolean Primitive | A boolean primitive value |
|
| Boolean Primitive | A boolean primitive value |
|
||||||
| Canny Processor | Canny edge detection for ControlNet |
|
| Canny Processor | Canny edge detection for ControlNet |
|
||||||
|
| CenterPadCrop | Pad or crop an image's sides from the center by specified pixels. Positive values are outside of the image. |
|
||||||
| CLIP Skip | Skip layers in clip text_encoder model. |
|
| CLIP Skip | Skip layers in clip text_encoder model. |
|
||||||
| Collect | Collects values into a collection |
|
| Collect | Collects values into a collection |
|
||||||
| Color Correct | Shifts the colors of a target image to match the reference image, optionally using a mask to only color-correct certain regions of the target image. |
|
| Color Correct | Shifts the colors of a target image to match the reference image, optionally using a mask to only color-correct certain regions of the target image. |
|
||||||
@ -74,7 +76,7 @@ The table below contains a list of the default nodes shipped with InvokeAI and t
|
|||||||
| Noise | Generates latent noise. |
|
| Noise | Generates latent noise. |
|
||||||
| Normal BAE Processor | Applies NormalBae processing to image |
|
| Normal BAE Processor | Applies NormalBae processing to image |
|
||||||
| ONNX Latents to Image | Generates an image from latents. |
|
| ONNX Latents to Image | Generates an image from latents. |
|
||||||
|ONNX Prompt (Raw) | A node to process inputs and produce outputs. May use dependency injection in __init__ to receive providers.|
|
| ONNX Prompt (Raw) | A node to process inputs and produce outputs. May use dependency injection in **init** to receive providers. |
|
||||||
| ONNX Text to Latents | Generates latents from conditionings. |
|
| ONNX Text to Latents | Generates latents from conditionings. |
|
||||||
| ONNX Model Loader | Loads a main model, outputting its submodels. |
|
| ONNX Model Loader | Loads a main model, outputting its submodels. |
|
||||||
| OpenCV Inpaint | Simple inpaint using opencv. |
|
| OpenCV Inpaint | Simple inpaint using opencv. |
|
||||||
|
@ -119,6 +119,61 @@ class ImageCropInvocation(BaseInvocation, WithMetadata):
|
|||||||
category="image",
|
category="image",
|
||||||
version="1.2.0",
|
version="1.2.0",
|
||||||
)
|
)
|
||||||
class CenterPadCropInvocation(BaseInvocation):
    """Pad or crop an image's sides from the center by specified pixels. Positive values are outside of the image."""

    image: ImageField = InputField(description="The image to crop")
    left: int = InputField(
        default=0,
        description="Number of pixels to pad/crop from the left (negative values crop inwards, positive values pad outwards)",
    )
    right: int = InputField(
        default=0,
        description="Number of pixels to pad/crop from the right (negative values crop inwards, positive values pad outwards)",
    )
    top: int = InputField(
        default=0,
        description="Number of pixels to pad/crop from the top (negative values crop inwards, positive values pad outwards)",
    )
    bottom: int = InputField(
        default=0,
        description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)

        # Calculate the new image dimensions and create a transparent canvas of that size
        new_width = image.width + self.right + self.left
        new_height = image.height + self.top + self.bottom
        image_crop = Image.new(mode="RGBA", size=(new_width, new_height), color=(0, 0, 0, 0))

        # Paste the input image onto the new canvas
        image_crop.paste(image, (self.left, self.top))

        image_dto = context.services.images.create(
            image=image_crop,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
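The pad/crop arithmetic in `invoke` can be checked in isolation. The sketch below reproduces only the geometry, with no image library involved. The helper name `pad_crop_geometry` and the size validation are illustrative additions, not part of the commit.

```python
def pad_crop_geometry(width, height, left=0, right=0, top=0, bottom=0):
    """Return ((new_width, new_height), (paste_x, paste_y)) for a pad/crop.

    Positive values pad outwards and negative values crop inwards, matching
    the field descriptions above. The paste offset is where the original
    image's top-left corner lands on the new canvas.
    """
    new_width = width + left + right
    new_height = height + top + bottom
    # Assumption: reject degenerate canvases early. The node above does not
    # perform this check itself.
    if new_width <= 0 or new_height <= 0:
        raise ValueError("crop amounts exceed the image size")
    return (new_width, new_height), (left, top)


# Pad 16 px on the left and right, crop 32 px from the top and bottom.
size, offset = pad_crop_geometry(512, 512, left=16, right=16, top=-32, bottom=-32)
```

A negative paste offset means part of the original image falls outside the new canvas, which is how the crop case is expressed with the same code path as padding.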


@invocation(
    invocation_type="img_pad_crop",
    title="Center Pad or Crop Image",
    category="image",
    tags=["image", "pad", "crop"],
    version="1.0.0",
)
class ImagePasteInvocation(BaseInvocation, WithMetadata):
|
class ImagePasteInvocation(BaseInvocation, WithMetadata):
|
||||||
"""Pastes an image into another image."""
|
"""Pastes an image into another image."""
|
||||||
|
|
||||||
|
@ -78,6 +78,12 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
|||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
|
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
|
||||||
|
|
||||||
|
# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
|
||||||
|
# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale
|
||||||
|
# factor is hard-coded to a literal '8' rather than using this constant.
|
||||||
|
# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
|
||||||
|
LATENT_SCALE_FACTOR = 8
|
||||||
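As a small illustration of the 8:1 ratio the comment describes, the sketch below converts between image-space and latent-space sizes. The helper names are hypothetical, and the constant is redefined locally so the snippet stands alone.

```python
LATENT_SCALE_FACTOR = 8  # image:latent ratio, as in the constant above


def image_to_latent_size(width: int, height: int, scale: int = LATENT_SCALE_FACTOR):
    """Convert an image size in pixels to the matching latent size."""
    if width % scale or height % scale:
        raise ValueError(f"image dimensions must be multiples of {scale}")
    return width // scale, height // scale


def latent_to_image_size(width: int, height: int, scale: int = LATENT_SCALE_FACTOR):
    """Convert a latent size back to the image size in pixels."""
    return width * scale, height * scale


latent_size = image_to_latent_size(1024, 768)
image_size = latent_to_image_size(*latent_size)
```

Routing every conversion through one constant is what allows a future model with a different scale factor to be supported without hunting for literal `8`s.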
|
|
||||||
|
|
||||||
@invocation_output("scheduler_output")
|
@invocation_output("scheduler_output")
|
||||||
class SchedulerOutput(BaseInvocationOutput):
|
class SchedulerOutput(BaseInvocationOutput):
|
||||||
@ -214,7 +220,7 @@ def get_scheduler(
|
|||||||
title="Denoise Latents",
|
title="Denoise Latents",
|
||||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.4.0",
|
version="1.5.0",
|
||||||
)
|
)
|
||||||
class DenoiseLatentsInvocation(BaseInvocation):
|
class DenoiseLatentsInvocation(BaseInvocation):
|
||||||
"""Denoises noisy latents to decodable images"""
|
"""Denoises noisy latents to decodable images"""
|
||||||
@ -272,6 +278,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
ui_order=7,
|
ui_order=7,
|
||||||
)
|
)
|
||||||
|
cfg_rescale_multiplier: float = InputField(
|
||||||
|
default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
||||||
|
)
|
||||||
latents: Optional[LatentsField] = InputField(
|
latents: Optional[LatentsField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description=FieldDescriptions.latents,
|
description=FieldDescriptions.latents,
|
||||||
@ -331,6 +340,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
unconditioned_embeddings=uc,
|
unconditioned_embeddings=uc,
|
||||||
text_embeddings=c,
|
text_embeddings=c,
|
||||||
guidance_scale=self.cfg_scale,
|
guidance_scale=self.cfg_scale,
|
||||||
|
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
extra=extra_conditioning_info,
|
extra=extra_conditioning_info,
|
||||||
postprocessing_settings=PostprocessingSettings(
|
postprocessing_settings=PostprocessingSettings(
|
||||||
threshold=0.0, # threshold,
|
threshold=0.0, # threshold,
|
||||||
@ -389,9 +399,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
exit_stack: ExitStack,
|
exit_stack: ExitStack,
|
||||||
do_classifier_free_guidance: bool = True,
|
do_classifier_free_guidance: bool = True,
|
||||||
) -> List[ControlNetData]:
|
) -> List[ControlNetData]:
|
||||||
# assuming fixed dimensional scaling of 8:1 for image:latents
|
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
|
||||||
control_height_resize = latents_shape[2] * 8
|
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
|
||||||
control_width_resize = latents_shape[3] * 8
|
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
|
||||||
if control_input is None:
|
if control_input is None:
|
||||||
control_list = None
|
control_list = None
|
||||||
elif isinstance(control_input, list) and len(control_input) == 0:
|
elif isinstance(control_input, list) and len(control_input) == 0:
|
||||||
@ -904,12 +914,12 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
width: int = InputField(
|
width: int = InputField(
|
||||||
ge=64,
|
ge=64,
|
||||||
multiple_of=8,
|
multiple_of=LATENT_SCALE_FACTOR,
|
||||||
description=FieldDescriptions.width,
|
description=FieldDescriptions.width,
|
||||||
)
|
)
|
||||||
height: int = InputField(
|
height: int = InputField(
|
||||||
ge=64,
|
ge=64,
|
||||||
multiple_of=8,
|
multiple_of=LATENT_SCALE_FACTOR,
|
||||||
description=FieldDescriptions.height,
|
description=FieldDescriptions.height,
|
||||||
)
|
)
|
||||||
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||||
@ -923,7 +933,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
resized_latents = torch.nn.functional.interpolate(
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
latents.to(device),
|
latents.to(device),
|
||||||
size=(self.height // 8, self.width // 8),
|
size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
|
||||||
mode=self.mode,
|
mode=self.mode,
|
||||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||||
)
|
)
|
||||||
@ -1161,3 +1171,60 @@ class BlendLatentsInvocation(BaseInvocation):
|
|||||||
# context.services.latents.set(name, resized_latents)
|
# context.services.latents.set(name, resized_latents)
|
||||||
context.services.latents.save(name, blended_latents)
|
context.services.latents.save(name, blended_latents)
|
||||||
return build_latents_output(latents_name=name, latents=blended_latents)
|
return build_latents_output(latents_name=name, latents=blended_latents)
|
||||||


# The Crop Latents node was copied from @skunkworxdark's implementation here:
# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80
@invocation(
    "crop_latents",
    title="Crop Latents",
    tags=["latents", "crop"],
    category="latents",
    version="1.0.0",
)
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
# Currently, if the class names conflict then 'GET /openapi.json' fails.
class CropLatentsCoreInvocation(BaseInvocation):
    """Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be
    divisible by the latent scale factor of 8.
    """

    latents: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    x: int = InputField(
        ge=0,
        multiple_of=LATENT_SCALE_FACTOR,
        description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
    )
    y: int = InputField(
        ge=0,
        multiple_of=LATENT_SCALE_FACTOR,
        description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
    )
    width: int = InputField(
        ge=1,
        multiple_of=LATENT_SCALE_FACTOR,
        description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
    )
    height: int = InputField(
        ge=1,
        multiple_of=LATENT_SCALE_FACTOR,
        description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
    )

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        x1 = self.x // LATENT_SCALE_FACTOR
        y1 = self.y // LATENT_SCALE_FACTOR
        x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
        y2 = y1 + (self.height // LATENT_SCALE_FACTOR)

        cropped_latents = latents[..., y1:y2, x1:x2]

        name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.save(name, cropped_latents)

        return build_latents_output(latents_name=name, latents=cropped_latents)
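The conversion from an image-space crop box to latent-space indices can be tried without a tensor library. This sketch uses a nested list as a stand-in for the last two dimensions of a latents tensor. The helper name `crop_box_to_latent_slices` is hypothetical, and the explicit divisibility check mirrors the `multiple_of` constraints on the fields.

```python
LATENT_SCALE_FACTOR = 8


def crop_box_to_latent_slices(x, y, width, height, scale=LATENT_SCALE_FACTOR):
    """Convert an image-space box to (row_slice, col_slice) in latent space."""
    if any(v % scale for v in (x, y, width, height)):
        raise ValueError(f"box coordinates and size must be multiples of {scale}")
    x1 = x // scale
    y1 = y // scale
    x2 = x1 + width // scale
    y2 = y1 + height // scale
    return slice(y1, y2), slice(x1, x2)


# A 4x4 "latent" grid stands in for a 32x32 px image (32 / 8 = 4).
grid = [[row * 4 + col for col in range(4)] for row in range(4)]

rows, cols = crop_box_to_latent_slices(x=8, y=16, width=16, height=8)
cropped = [line[cols] for line in grid[rows]]
```

With a real tensor the same slices would be applied as `latents[..., rows, cols]`, which is the indexing used in the node.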
|
@ -112,7 +112,7 @@ GENERATION_MODES = Literal[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.0.1")
|
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.2.0")
|
||||||
class CoreMetadataInvocation(BaseInvocation):
|
class CoreMetadataInvocation(BaseInvocation):
|
||||||
"""Collects core generation metadata into a MetadataField"""
|
"""Collects core generation metadata into a MetadataField"""
|
||||||
|
|
||||||
@ -127,6 +127,9 @@ class CoreMetadataInvocation(BaseInvocation):
|
|||||||
seed: Optional[int] = InputField(default=None, description="The seed used for noise generation")
|
seed: Optional[int] = InputField(default=None, description="The seed used for noise generation")
|
||||||
rand_device: Optional[str] = InputField(default=None, description="The device used for random number generation")
|
rand_device: Optional[str] = InputField(default=None, description="The device used for random number generation")
|
||||||
cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter")
|
cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter")
|
||||||
|
cfg_rescale_multiplier: Optional[float] = InputField(
|
||||||
|
default=None, description=FieldDescriptions.cfg_rescale_multiplier
|
||||||
|
)
|
||||||
steps: Optional[int] = InputField(default=None, description="The number of steps used for inference")
|
steps: Optional[int] = InputField(default=None, description="The number of steps used for inference")
|
||||||
scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference")
|
scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference")
|
||||||
seamless_x: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the X axis")
|
seamless_x: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the X axis")
|
||||||
|
@ -44,7 +44,7 @@ class DynamicPromptInvocation(BaseInvocation):
|
|||||||
title="Prompts from File",
|
title="Prompts from File",
|
||||||
tags=["prompt", "file"],
|
tags=["prompt", "file"],
|
||||||
category="prompt",
|
category="prompt",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class PromptsFromFileInvocation(BaseInvocation):
|
class PromptsFromFileInvocation(BaseInvocation):
|
||||||
"""Loads prompts from a text file"""
|
"""Loads prompts from a text file"""
|
||||||
@ -82,7 +82,7 @@ class PromptsFromFileInvocation(BaseInvocation):
|
|||||||
end_line = start_line + max_prompts
|
end_line = start_line + max_prompts
|
||||||
if max_prompts <= 0:
|
if max_prompts <= 0:
|
||||||
end_line = np.iinfo(np.int32).max
|
end_line = np.iinfo(np.int32).max
|
||||||
with open(file_path) as f:
|
with open(file_path, encoding="utf-8") as f:
|
||||||
for i, line in enumerate(f):
|
for i, line in enumerate(f):
|
||||||
if i >= start_line and i < end_line:
|
if i >= start_line and i < end_line:
|
||||||
prompts.append((pre_prompt or "") + line.strip() + (post_prompt or ""))
|
prompts.append((pre_prompt or "") + line.strip() + (post_prompt or ""))
|
||||||
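The line-window logic in this hunk, together with the explicit `encoding="utf-8"` the commit adds, can be sketched as a standalone function. This is a simplified illustration rather than the node's code: `read_prompts` is a hypothetical name, `start_line` is treated as zero-based here, and `itertools.islice` replaces the manual index comparison.

```python
from itertools import islice
from typing import Optional


def read_prompts(
    file_path: str,
    start_line: int = 0,
    max_prompts: int = 0,
    pre_prompt: Optional[str] = None,
    post_prompt: Optional[str] = None,
) -> list[str]:
    """Read up to `max_prompts` lines starting at `start_line` (zero-based).

    A `max_prompts` of 0 or less means "read to the end of the file".
    """
    stop = start_line + max_prompts if max_prompts > 0 else None
    # Decoding explicitly as UTF-8 avoids depending on the platform default,
    # which matters for prompts containing non-ASCII characters.
    with open(file_path, encoding="utf-8") as f:
        return [
            (pre_prompt or "") + line.strip() + (post_prompt or "")
            for line in islice(f, start_line, stop)
        ]
```

On platforms where the default encoding is not UTF-8, omitting the `encoding` argument can garble or reject prompt files that contain accented or non-Latin characters.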
|
180
invokeai/app/invocations/tiles.py
Normal file
@ -0,0 +1,180 @@
import numpy as np
from PIL import Image
from pydantic import BaseModel

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    InputField,
    InvocationContext,
    OutputField,
    WithMetadata,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.utils import Tile


class TileWithImage(BaseModel):
    tile: Tile
    image: ImageField


@invocation_output("calculate_image_tiles_output")
class CalculateImageTilesOutput(BaseInvocationOutput):
    tiles: list[Tile] = OutputField(description="The tiles coordinates that cover a particular image shape.")


@invocation("calculate_image_tiles", title="Calculate Image Tiles", tags=["tiles"], category="tiles", version="1.0.0")
class CalculateImageTilesInvocation(BaseInvocation):
    """Calculate the coordinates and overlaps of tiles that cover a target image shape."""

    image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
    image_height: int = InputField(
        ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
    )
    tile_width: int = InputField(ge=1, default=576, description="The tile width, in pixels.")
    tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.")
    overlap: int = InputField(
        ge=0,
        default=128,
        description="The target overlap, in pixels, between adjacent tiles. Adjacent tiles will overlap by at least this amount",
    )

    def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
        tiles = calc_tiles_with_overlap(
            image_height=self.image_height,
            image_width=self.image_width,
            tile_height=self.tile_height,
            tile_width=self.tile_width,
            overlap=self.overlap,
        )
        return CalculateImageTilesOutput(tiles=tiles)
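The body of `calc_tiles_with_overlap` is not shown in this diff. As a rough illustration of the idea of "tiles that overlap by at least a target amount", here is one possible approach: step by `tile_size - overlap` and pin the last tile flush against the image edge. This is an assumed algorithm for illustration only, and the function names below are hypothetical.

```python
def tile_starts(image_size: int, tile_size: int, overlap: int) -> list[int]:
    """Start offsets along one axis so neighbours overlap by >= `overlap`."""
    tile_size = min(tile_size, image_size)
    if tile_size == image_size:
        return [0]  # a single tile already covers the axis
    if not 0 <= overlap < tile_size:
        raise ValueError("overlap must be >= 0 and smaller than the tile size")
    stride = tile_size - overlap
    starts = list(range(0, image_size - tile_size, stride))
    # Pin the final tile to the far edge. It may overlap its neighbour by
    # more than `overlap`, never less.
    starts.append(image_size - tile_size)
    return starts


def calc_tiles(image_width, image_height, tile_width, tile_height, overlap):
    """Return tiles as (left, top, right, bottom) tuples in row-major order."""
    tile_width = min(tile_width, image_width)
    tile_height = min(tile_height, image_height)
    return [
        (left, top, left + tile_width, top + tile_height)
        for top in tile_starts(image_height, tile_height, overlap)
        for left in tile_starts(image_width, tile_width, overlap)
    ]


# The node's default values: a 1024x1024 image, 576 px tiles, 128 px overlap.
tiles = calc_tiles(1024, 1024, 576, 576, 128)
```

With the node's defaults this yields a 2x2 grid whose tiles overlap by exactly 128 px. Other sizes may produce a larger overlap on the last row or column.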


@invocation_output("tile_to_properties_output")
class TileToPropertiesOutput(BaseInvocationOutput):
    coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.")
    coords_right: int = OutputField(description="Right coordinate of the tile relative to its parent image.")
    coords_top: int = OutputField(description="Top coordinate of the tile relative to its parent image.")
    coords_bottom: int = OutputField(description="Bottom coordinate of the tile relative to its parent image.")

    # HACK: The width and height fields are 'meta' fields that can easily be calculated from the other fields on this
    # object. Including redundant fields that can cheaply/easily be re-calculated goes against conventional API design
    # principles. These fields are included, because 1) they are often useful in tiled workflows, and 2) they are
    # difficult to calculate in a workflow (even though it's just a couple of subtraction nodes the graph gets
    # surprisingly complicated).
    width: int = OutputField(description="The width of the tile. Equal to coords_right - coords_left.")
    height: int = OutputField(description="The height of the tile. Equal to coords_bottom - coords_top.")

    overlap_top: int = OutputField(description="Overlap between this tile and its top neighbor.")
    overlap_bottom: int = OutputField(description="Overlap between this tile and its bottom neighbor.")
    overlap_left: int = OutputField(description="Overlap between this tile and its left neighbor.")
    overlap_right: int = OutputField(description="Overlap between this tile and its right neighbor.")


@invocation("tile_to_properties", title="Tile to Properties", tags=["tiles"], category="tiles", version="1.0.0")
class TileToPropertiesInvocation(BaseInvocation):
    """Split a Tile into its individual properties."""

    tile: Tile = InputField(description="The tile to split into properties.")

    def invoke(self, context: InvocationContext) -> TileToPropertiesOutput:
        return TileToPropertiesOutput(
            coords_left=self.tile.coords.left,
            coords_right=self.tile.coords.right,
            coords_top=self.tile.coords.top,
            coords_bottom=self.tile.coords.bottom,
            width=self.tile.coords.right - self.tile.coords.left,
            height=self.tile.coords.bottom - self.tile.coords.top,
            overlap_top=self.tile.overlap.top,
            overlap_bottom=self.tile.overlap.bottom,
            overlap_left=self.tile.overlap.left,
            overlap_right=self.tile.overlap.right,
        )
|
|
||||||
|
|
||||||
|
@invocation_output("pair_tile_image_output")
class PairTileImageOutput(BaseInvocationOutput):
    tile_with_image: TileWithImage = OutputField(description="A tile description with its corresponding image.")


@invocation("pair_tile_image", title="Pair Tile with Image", tags=["tiles"], category="tiles", version="1.0.0")
class PairTileImageInvocation(BaseInvocation):
    """Pair an image with its tile properties."""

    # TODO(ryand): The only reason that PairTileImage is needed is because the iterate/collect nodes don't preserve
    # order. Can this be fixed?

    image: ImageField = InputField(description="The tile image.")
    tile: Tile = InputField(description="The tile properties.")

    def invoke(self, context: InvocationContext) -> PairTileImageOutput:
        return PairTileImageOutput(
            tile_with_image=TileWithImage(
                tile=self.tile,
                image=self.image,
            )
        )


@invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.1.0")
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
    """Merge multiple tile images into a single image."""

    # Inputs
    tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.")
    blend_amount: int = InputField(
        ge=0,
        description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        images = [twi.image for twi in self.tiles_with_images]
        tiles = [twi.tile for twi in self.tiles_with_images]

        # Infer the output image dimensions from the max/min tile limits.
        height = 0
        width = 0
        for tile in tiles:
            height = max(height, tile.coords.bottom)
            width = max(width, tile.coords.right)

        # Get all tile images for processing.
        # TODO(ryand): It pains me that we spend time PNG decoding each tile from disk when they almost certainly
        # existed in memory at an earlier point in the graph.
        tile_np_images: list[np.ndarray] = []
        for image in images:
            pil_image = context.services.images.get_pil_image(image.image_name)
            pil_image = pil_image.convert("RGB")
            tile_np_images.append(np.array(pil_image))

        # Prepare the output image buffer.
        # Check the first tile to determine how many image channels are expected in the output.
        channels = tile_np_images[0].shape[-1]
        dtype = tile_np_images[0].dtype
        np_image = np.zeros(shape=(height, width, channels), dtype=dtype)

        merge_tiles_with_linear_blending(
            dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
        )
        pil_image = Image.fromarray(np_image)

        image_dto = context.services.images.create(
            image=pil_image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
            metadata=self.metadata,
            workflow=context.workflow,
        )
        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
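Reviewer note: the tile nodes above rely on two small pieces of arithmetic, the derived `width`/`height` fields and the output canvas size inferred from the largest `right`/`bottom` tile edges. The following is a minimal standalone sketch of that arithmetic only. `Box`, `tile_size`, and `canvas_size` are hypothetical names used for illustration and are not part of the diff; `Box` stands in for the `TBLR` model.

```python
from dataclasses import dataclass


@dataclass
class Box:
    """Hypothetical stand-in for the TBLR model (top, bottom, left, right in px)."""

    top: int
    bottom: int
    left: int
    right: int


def tile_size(coords: Box) -> tuple[int, int]:
    """Return (width, height), derived as right - left and bottom - top."""
    return coords.right - coords.left, coords.bottom - coords.top


def canvas_size(all_coords: list[Box]) -> tuple[int, int]:
    """Return (width, height) of the merged image: the largest right and bottom edges."""
    width = max((c.right for c in all_coords), default=0)
    height = max((c.bottom for c in all_coords), default=0)
    return width, height


boxes = [Box(0, 512, 0, 512), Box(0, 512, 448, 960), Box(256, 768, 0, 512)]
print(tile_size(boxes[1]))  # (512, 512)
print(canvas_size(boxes))   # (960, 768)
```

Because the canvas is inferred from the tile coordinates alone, the merge node does not need the original image dimensions as a separate input.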
@ -2,6 +2,7 @@ class FieldDescriptions:
|
|||||||
denoising_start = "When to start denoising, expressed as a percentage of total steps"
|
denoising_start = "When to start denoising, expressed as a percentage of total steps"
|
||||||
denoising_end = "When to stop denoising, expressed as a percentage of total steps"
|
denoising_end = "When to stop denoising, expressed as a percentage of total steps"
|
||||||
cfg_scale = "Classifier-Free Guidance scale"
|
cfg_scale = "Classifier-Free Guidance scale"
|
||||||
|
cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR"
|
||||||
scheduler = "Scheduler to use during inference"
|
scheduler = "Scheduler to use during inference"
|
||||||
positive_cond = "Positive conditioning tensor"
|
positive_cond = "Positive conditioning tensor"
|
||||||
negative_cond = "Negative conditioning tensor"
|
negative_cond = "Negative conditioning tensor"
|
||||||
|
@ -54,6 +54,44 @@ class ImageProjModel(torch.nn.Module):
|
|||||||
return clip_extra_context_tokens
|
return clip_extra_context_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class MLPProjModel(torch.nn.Module):
|
||||||
|
"""SD model with image prompt"""
|
||||||
|
|
||||||
|
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.proj = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
|
||||||
|
torch.nn.GELU(),
|
||||||
|
torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
|
||||||
|
torch.nn.LayerNorm(cross_attention_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
|
||||||
|
"""Initialize an MLPProjModel from a state_dict.
|
||||||
|
|
||||||
|
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict (dict[str, torch.Tensor]): The state_dict of model weights.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MLPProjModel
|
||||||
|
"""
|
||||||
|
cross_attention_dim = state_dict["proj.3.weight"].shape[0]
|
||||||
|
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
|
||||||
|
|
||||||
|
model = cls(cross_attention_dim, clip_embeddings_dim)
|
||||||
|
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def forward(self, image_embeds):
|
||||||
|
clip_extra_context_tokens = self.proj(image_embeds)
|
||||||
|
return clip_extra_context_tokens
|
||||||
|
|
||||||
|
|
||||||
class IPAdapter:
|
class IPAdapter:
|
||||||
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
|
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
|
||||||
|
|
||||||
@ -130,6 +168,13 @@ class IPAdapterPlus(IPAdapter):
|
|||||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterFull(IPAdapterPlus):
|
||||||
|
"""IP-Adapter Plus with full features."""
|
||||||
|
|
||||||
|
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
|
||||||
|
return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterPlusXL(IPAdapterPlus):
|
class IPAdapterPlusXL(IPAdapterPlus):
|
||||||
"""IP-Adapter Plus for SDXL."""
|
"""IP-Adapter Plus for SDXL."""
|
||||||
|
|
||||||
@ -149,11 +194,9 @@ def build_ip_adapter(
|
|||||||
) -> Union[IPAdapter, IPAdapterPlus]:
|
) -> Union[IPAdapter, IPAdapterPlus]:
|
||||||
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
|
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
|
||||||
|
|
||||||
# Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it
|
if "proj.weight" in state_dict["image_proj"]: # IPAdapter (with ImageProjModel).
|
||||||
# contains.
|
return IPAdapter(state_dict, device=device, dtype=dtype)
|
||||||
is_plus = "proj.weight" not in state_dict["image_proj"]
|
elif "proj_in.weight" in state_dict["image_proj"]: # IPAdapterPlus or IPAdapterPlusXL (with Resampler).
|
||||||
|
|
||||||
if is_plus:
|
|
||||||
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||||
if cross_attention_dim == 768:
|
if cross_attention_dim == 768:
|
||||||
# SD1 IP-Adapter Plus
|
# SD1 IP-Adapter Plus
|
||||||
@ -163,5 +206,7 @@ def build_ip_adapter(
|
|||||||
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
|
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
|
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
|
||||||
|
elif "proj.0.weight" in state_dict["image_proj"]: # IPAdapterFull (with MLPProjModel).
|
||||||
|
return IPAdapterFull(state_dict, device=device, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
return IPAdapter(state_dict, device=device, dtype=dtype)
|
raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")
|
||||||
|
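Reviewer note: the new `build_ip_adapter` logic picks the adapter class from the weight names found under `state_dict["image_proj"]`. Below is a small sketch of that dispatch on key names only, with no model loading. `classify_ip_adapter` is a hypothetical helper written for illustration; the further split between Plus and Plus XL by cross-attention dimension is deliberately left out because only part of it is visible in this hunk.

```python
def classify_ip_adapter(image_proj_keys: set[str]) -> str:
    """Map the weight names of the image projection model to an adapter family.

    Mirrors the order of checks in the diff:
    - "proj.weight"     -> IPAdapter (ImageProjModel)
    - "proj_in.weight"  -> IPAdapterPlus or IPAdapterPlusXL (Resampler)
    - "proj.0.weight"   -> IPAdapterFull (MLPProjModel)
    """
    if "proj.weight" in image_proj_keys:
        return "IPAdapter"
    if "proj_in.weight" in image_proj_keys:
        return "IPAdapterPlus or IPAdapterPlusXL"
    if "proj.0.weight" in image_proj_keys:
        return "IPAdapterFull"
    # Unknown layouts fail loudly instead of silently falling back to IPAdapter.
    raise ValueError("unrecognized IP-Adapter model architecture")


print(classify_ip_adapter({"proj.weight", "proj.bias", "norm.weight"}))
print(classify_ip_adapter({"proj.0.weight", "proj.2.weight", "proj.3.weight"}))
```

The behavioural change worth noting is the final branch: the old code treated anything without `proj.weight` as a Plus model, while the new code raises on layouts it does not recognize.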
@ -607,10 +607,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
if isinstance(guidance_scale, list):
|
if isinstance(guidance_scale, list):
|
||||||
guidance_scale = guidance_scale[step_index]
|
guidance_scale = guidance_scale[step_index]
|
||||||
|
|
||||||
noise_pred = self.invokeai_diffuser._combine(
|
noise_pred = self.invokeai_diffuser._combine(uc_noise_pred, c_noise_pred, guidance_scale)
|
||||||
uc_noise_pred,
|
guidance_rescale_multiplier = conditioning_data.guidance_rescale_multiplier
|
||||||
|
if guidance_rescale_multiplier > 0:
|
||||||
|
noise_pred = self._rescale_cfg(
|
||||||
|
noise_pred,
|
||||||
c_noise_pred,
|
c_noise_pred,
|
||||||
guidance_scale,
|
guidance_rescale_multiplier,
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
@ -634,6 +637,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
return step_output
|
return step_output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _rescale_cfg(total_noise_pred, pos_noise_pred, multiplier=0.7):
|
||||||
|
"""Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
|
||||||
|
ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
|
||||||
|
ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
|
||||||
|
|
||||||
|
x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
|
||||||
|
x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
|
||||||
|
return x_final
|
||||||
|
|
||||||
def _unet_forward(
|
def _unet_forward(
|
||||||
self,
|
self,
|
||||||
latents,
|
latents,
|
||||||
|
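Reviewer note: `_rescale_cfg` scales the combined CFG prediction so that its standard deviation matches that of the positive (conditioned) prediction, then mixes the rescaled and original predictions with `multiplier`. The sketch below reproduces that arithmetic on a flat list of floats using only the standard library. It is an illustration under simplifying assumptions: the real method works on 4D tensors and takes the standard deviation per sample over dims `(1, 2, 3)`, and `rescale_cfg` here is a hypothetical standalone function, not the pipeline method.

```python
from statistics import stdev


def rescale_cfg(total_noise_pred: list[float], pos_noise_pred: list[float], multiplier: float = 0.7) -> list[float]:
    """Rescale a CFG prediction toward the std of the positive prediction, then blend."""
    std_pos = stdev(pos_noise_pred)
    std_cfg = stdev(total_noise_pred)
    factor = std_pos / std_cfg
    # x_final = multiplier * x_rescaled + (1 - multiplier) * x_cfg
    return [multiplier * (x * factor) + (1.0 - multiplier) * x for x in total_noise_pred]


cfg = [-4.0, 0.0, 4.0]
pos = [-1.0, 0.0, 1.0]
print(rescale_cfg(cfg, pos, multiplier=1.0))  # [-1.0, 0.0, 1.0]
print(rescale_cfg(cfg, pos, multiplier=0.0))  # [-4.0, 0.0, 4.0]
```

With `multiplier=0` the prediction is unchanged, which is consistent with the pipeline skipping the call when `guidance_rescale_multiplier` is not greater than 0.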
@ -67,13 +67,17 @@ class IPAdapterConditioningInfo:
|
|||||||
class ConditioningData:
|
class ConditioningData:
|
||||||
unconditioned_embeddings: BasicConditioningInfo
|
unconditioned_embeddings: BasicConditioningInfo
|
||||||
text_embeddings: BasicConditioningInfo
|
text_embeddings: BasicConditioningInfo
|
||||||
guidance_scale: Union[float, List[float]]
|
|
||||||
"""
|
"""
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||||
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||||
"""
|
"""
|
||||||
|
guidance_scale: Union[float, List[float]]
|
||||||
|
""" For models trained using zero-terminal SNR ("ztsnr"), a guidance_rescale_multiplier of 0.7 is suggested.
|
||||||
|
See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||||
|
"""
|
||||||
|
guidance_rescale_multiplier: float = 0
|
||||||
extra: Optional[ExtraConditioningInfo] = None
|
extra: Optional[ExtraConditioningInfo] = None
|
||||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||||
"""
|
"""
|
||||||
|
0
invokeai/backend/tiles/__init__.py
Normal file
201
invokeai/backend/tiles/tiles.py
Normal file
@ -0,0 +1,201 @@
import math
from typing import Union

import numpy as np

from invokeai.backend.tiles.utils import TBLR, Tile, paste


def calc_tiles_with_overlap(
    image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int = 0
) -> list[Tile]:
    """Calculate the tile coordinates for a given image shape under a simple tiling scheme with overlaps.

    Args:
        image_height (int): The image height in px.
        image_width (int): The image width in px.
        tile_height (int): The tile height in px. All tiles will have this height.
        tile_width (int): The tile width in px. All tiles will have this width.
        overlap (int, optional): The target overlap between adjacent tiles. If the tiles do not evenly cover the image
            shape, then the last row/column of tiles will overlap more than this. Defaults to 0.

    Returns:
        list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
    """
    assert image_height >= tile_height
    assert image_width >= tile_width
    assert overlap < tile_height
    assert overlap < tile_width

    non_overlap_per_tile_height = tile_height - overlap
    non_overlap_per_tile_width = tile_width - overlap

    num_tiles_y = math.ceil((image_height - overlap) / non_overlap_per_tile_height)
    num_tiles_x = math.ceil((image_width - overlap) / non_overlap_per_tile_width)

    # tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
    tiles: list[Tile] = []

    # Calculate tile coordinates. (Ignore overlap values for now.)
    for tile_idx_y in range(num_tiles_y):
        for tile_idx_x in range(num_tiles_x):
            tile = Tile(
                coords=TBLR(
                    top=tile_idx_y * non_overlap_per_tile_height,
                    bottom=tile_idx_y * non_overlap_per_tile_height + tile_height,
                    left=tile_idx_x * non_overlap_per_tile_width,
                    right=tile_idx_x * non_overlap_per_tile_width + tile_width,
                ),
                overlap=TBLR(top=0, bottom=0, left=0, right=0),
            )

            if tile.coords.bottom > image_height:
                # If this tile would go off the bottom of the image, shift it so that it is aligned with the bottom
                # of the image.
                tile.coords.bottom = image_height
                tile.coords.top = image_height - tile_height

            if tile.coords.right > image_width:
                # If this tile would go off the right edge of the image, shift it so that it is aligned with the
                # right edge of the image.
                tile.coords.right = image_width
                tile.coords.left = image_width - tile_width

            tiles.append(tile)

    def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
        # Valid indices are 0..num_tiles - 1, so the upper bounds are exclusive.
        if idx_y < 0 or idx_y >= num_tiles_y or idx_x < 0 or idx_x >= num_tiles_x:
            return None
        return tiles[idx_y * num_tiles_x + idx_x]

    # Iterate over tiles again and calculate overlaps.
    for tile_idx_y in range(num_tiles_y):
        for tile_idx_x in range(num_tiles_x):
            cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
            top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
            left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)

            assert cur_tile is not None

            # Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
            if top_neighbor_tile is not None:
                cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
                top_neighbor_tile.overlap.bottom = cur_tile.overlap.top

            # Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
            if left_neighbor_tile is not None:
                cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
                left_neighbor_tile.overlap.right = cur_tile.overlap.left

    return tiles


def merge_tiles_with_linear_blending(
    dst_image: np.ndarray, tiles: list[Tile], tile_images: list[np.ndarray], blend_amount: int
):
    """Merge a set of image tiles into `dst_image` with linear blending between the tiles.

    We expect every tile edge to either:
    1) have an overlap of 0, because it is aligned with the image edge, or
    2) have an overlap >= blend_amount.
    If neither of these conditions is satisfied, we raise an exception.

    The linear blending is centered at the halfway point of the overlap between adjacent tiles.

    Args:
        dst_image (np.ndarray): The destination image. Shape: (H, W, C).
        tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`.
        tile_images (list[np.ndarray]): The tile images to merge into `dst_image`.
        blend_amount (int): The amount of blending (in px) between adjacent overlapping tiles.
    """
    # Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to
    # iterate over tiles left-to-right, top-to-bottom.
    tiles_and_images = list(zip(tiles, tile_images, strict=True))
    tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.left)
    tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.top)

    # Organize tiles into rows.
    tile_and_image_rows: list[list[tuple[Tile, np.ndarray]]] = []
    cur_tile_and_image_row: list[tuple[Tile, np.ndarray]] = []
    first_tile_in_cur_row, _ = tiles_and_images[0]
    for tile_and_image in tiles_and_images:
        tile, _ = tile_and_image
        if not (
            tile.coords.top == first_tile_in_cur_row.coords.top
            and tile.coords.bottom == first_tile_in_cur_row.coords.bottom
        ):
            # Store the previous row, and start a new one.
            tile_and_image_rows.append(cur_tile_and_image_row)
            cur_tile_and_image_row = []
            first_tile_in_cur_row, _ = tile_and_image

        cur_tile_and_image_row.append(tile_and_image)
    tile_and_image_rows.append(cur_tile_and_image_row)

    # Prepare 1D linear gradients for blending.
    gradient_left_x = np.linspace(start=0.0, stop=1.0, num=blend_amount)
    gradient_top_y = np.linspace(start=0.0, stop=1.0, num=blend_amount)
    # Convert shape: (blend_amount, ) -> (blend_amount, 1). The extra dimension enables the gradient to be applied
    # to a 2D image via broadcasting. Note that no additional dimension is needed on gradient_left_x for
    # broadcasting to work correctly.
    gradient_top_y = np.expand_dims(gradient_top_y, axis=1)

    for tile_and_image_row in tile_and_image_rows:
        first_tile_in_row, _ = tile_and_image_row[0]
        row_height = first_tile_in_row.coords.bottom - first_tile_in_row.coords.top
        row_image = np.zeros((row_height, dst_image.shape[1], dst_image.shape[2]), dtype=dst_image.dtype)

        # Blend the tiles in the row horizontally.
        for tile, tile_image in tile_and_image_row:
            # We expect the tiles to be ordered left-to-right. For each tile, we construct a mask that applies linear
            # blending to the left of the current tile. The inverse linear blending is automatically applied to the
            # right of the tiles that have already been pasted by the paste(...) operation.
            tile_height, tile_width, _ = tile_image.shape
            mask = np.ones(shape=(tile_height, tile_width), dtype=np.float64)

            # Left blending:
            if tile.overlap.left > 0:
                assert tile.overlap.left >= blend_amount
                # Center the blending gradient in the middle of the overlap.
                blend_start_left = tile.overlap.left // 2 - blend_amount // 2
                # The region left of the blending region is masked completely.
                mask[:, :blend_start_left] = 0.0
                # Apply the blend gradient to the mask.
                mask[:, blend_start_left : blend_start_left + blend_amount] = gradient_left_x
                # For visual debugging:
                # tile_image[:, blend_start_left : blend_start_left + blend_amount] = 0

            paste(
                dst_image=row_image,
                src_image=tile_image,
                box=TBLR(
                    top=0, bottom=tile.coords.bottom - tile.coords.top, left=tile.coords.left, right=tile.coords.right
                ),
                mask=mask,
            )

        # Blend the row into the dst_image vertically.
        # We construct a mask that applies linear blending to the top of the current row. The inverse linear blending is
        # automatically applied to the bottom of the tiles that have already been pasted by the paste(...) operation.
        mask = np.ones(shape=(row_image.shape[0], row_image.shape[1]), dtype=np.float64)
        # Top blending:
        # (See comments under 'Left blending' for an explanation of the logic.)
        # We assume that the entire row has the same vertical overlaps as the first_tile_in_row.
        if first_tile_in_row.overlap.top > 0:
            assert first_tile_in_row.overlap.top >= blend_amount
            blend_start_top = first_tile_in_row.overlap.top // 2 - blend_amount // 2
            mask[:blend_start_top, :] = 0.0
            mask[blend_start_top : blend_start_top + blend_amount, :] = gradient_top_y
            # For visual debugging:
            # row_image[blend_start_top : blend_start_top + blend_amount, :] = 0
        paste(
            dst_image=dst_image,
            src_image=row_image,
            box=TBLR(
                top=first_tile_in_row.coords.top,
                bottom=first_tile_in_row.coords.bottom,
                left=0,
                right=row_image.shape[1],
            ),
            mask=mask,
        )
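Reviewer note: the tile count and the "shift the last tile back inside the image" rule in `calc_tiles_with_overlap` are easier to check along a single axis. The sketch below applies the same formulas to one dimension. `tile_spans` is a hypothetical helper written for this note and is not part of `tiles.py`.

```python
import math


def tile_spans(image_size: int, tile_size: int, overlap: int = 0) -> list[tuple[int, int]]:
    """Return (start, end) spans of fixed-size tiles covering one image axis."""
    assert image_size >= tile_size
    assert overlap < tile_size

    stride = tile_size - overlap  # non-overlapping part of each tile
    num_tiles = math.ceil((image_size - overlap) / stride)

    spans = []
    for idx in range(num_tiles):
        start = idx * stride
        end = start + tile_size
        if end > image_size:
            # Shift the tile so it is aligned with the far edge of the image.
            end = image_size
            start = image_size - tile_size
        spans.append((start, end))
    return spans


spans = tile_spans(image_size=1000, tile_size=512, overlap=64)
print(spans)  # [(0, 512), (448, 960), (488, 1000)]

# Overlap between each tile and its predecessor, as in the second pass of the function.
overlaps = [max(0, prev_end - cur_start) for (_, prev_end), (cur_start, _) in zip(spans, spans[1:])]
print(overlaps)  # [64, 472]
```

The last overlap (472 px) is much larger than the requested 64 px, which is the case the docstring describes: when the tiles do not evenly cover the image, the final row or column overlaps more than the target.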
47
invokeai/backend/tiles/utils.py
Normal file
@ -0,0 +1,47 @@
from typing import Optional

import numpy as np
from pydantic import BaseModel, Field


class TBLR(BaseModel):
    top: int
    bottom: int
    left: int
    right: int

    def __eq__(self, other):
        return (
            self.top == other.top
            and self.bottom == other.bottom
            and self.left == other.left
            and self.right == other.right
        )


class Tile(BaseModel):
    coords: TBLR = Field(description="The coordinates of this tile relative to its parent image.")
    overlap: TBLR = Field(description="The amount of overlap with adjacent tiles on each side of this tile.")

    def __eq__(self, other):
        return self.coords == other.coords and self.overlap == other.overlap


def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optional[np.ndarray] = None):
    """Paste a source image into a destination image.

    Args:
        dst_image (np.ndarray): The destination image to paste into. Shape: (H, W, C).
        src_image (np.ndarray): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'.
        box (TBLR): Box defining the region in the 'dst_image' where 'src_image' will be pasted.
        mask (Optional[np.ndarray]): A mask that defines the blending between 'src_image' and 'dst_image'.
            Range: [0.0, 1.0], Shape: (H, W). The output is calculated per-pixel according to
            `src * mask + dst * (1 - mask)`.
    """

    if mask is None:
        dst_image[box.top : box.bottom, box.left : box.right] = src_image
    else:
        mask = np.expand_dims(mask, -1)
        dst_image_box = dst_image[box.top : box.bottom, box.left : box.right]
        dst_image[box.top : box.bottom, box.left : box.right] = src_image * mask + dst_image_box * (1.0 - mask)
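Reviewer note: the masked branch of `paste` computes `src * mask + dst * (1 - mask)` per pixel, and the merge code feeds it a mask that is 0 before the blend region, a linear ramp inside it, and 1 after it. The sketch below works through one pixel row with plain lists so the numbers can be checked by hand. `linspace`, `left_blend_mask`, and `blend_row` are hypothetical helpers for illustration; the real code uses `np.linspace` and NumPy broadcasting.

```python
def linspace(start: float, stop: float, num: int) -> list[float]:
    """Evenly spaced values including both endpoints (list-based stand-in for np.linspace)."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def left_blend_mask(tile_width: int, overlap_left: int, blend_amount: int) -> list[float]:
    """Build the 1D mask used for left blending, centered in the overlap region."""
    mask = [1.0] * tile_width
    if overlap_left > 0:
        assert overlap_left >= blend_amount
        blend_start = overlap_left // 2 - blend_amount // 2
        # Everything left of the blend region keeps the already-pasted neighbor.
        for i in range(blend_start):
            mask[i] = 0.0
        mask[blend_start : blend_start + blend_amount] = linspace(0.0, 1.0, blend_amount)
    return mask


def blend_row(dst: list[float], src: list[float], mask: list[float]) -> list[float]:
    """Per-pixel src * mask + dst * (1 - mask), as in the masked branch of paste."""
    return [s * m + d * (1.0 - m) for s, d, m in zip(src, dst, mask)]


mask = left_blend_mask(tile_width=8, overlap_left=4, blend_amount=3)
print(mask)  # [0.0, 0.0, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0]
print(blend_row(dst=[10.0] * 8, src=[20.0] * 8, mask=mask))
# [10.0, 10.0, 15.0, 20.0, 20.0, 20.0, 20.0, 20.0]
```

Since the neighbor tile is already in `dst` when the current tile is pasted, the `(1 - mask)` term supplies the inverse ramp for the neighbor without a second mask.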
@ -75,6 +75,7 @@
|
|||||||
"framer-motion": "^10.16.4",
|
"framer-motion": "^10.16.4",
|
||||||
"i18next": "^23.6.0",
|
"i18next": "^23.6.0",
|
||||||
"i18next-http-backend": "^2.3.1",
|
"i18next-http-backend": "^2.3.1",
|
||||||
|
"idb-keyval": "^6.2.1",
|
||||||
"konva": "^9.2.3",
|
"konva": "^9.2.3",
|
||||||
"lodash-es": "^4.17.21",
|
"lodash-es": "^4.17.21",
|
||||||
"nanostores": "^0.9.4",
|
"nanostores": "^0.9.4",
|
||||||
|
@ -803,8 +803,7 @@
|
|||||||
"canny": "Canny",
|
"canny": "Canny",
|
||||||
"hedDescription": "Ganzheitlich verschachtelte Kantenerkennung",
|
"hedDescription": "Ganzheitlich verschachtelte Kantenerkennung",
|
||||||
"scribble": "Scribble",
|
"scribble": "Scribble",
|
||||||
"maxFaces": "Maximal Anzahl Gesichter",
|
"maxFaces": "Maximal Anzahl Gesichter"
|
||||||
"unstarImage": "Markierung aufheben"
|
|
||||||
},
|
},
|
||||||
"queue": {
|
"queue": {
|
||||||
"status": "Status",
|
"status": "Status",
|
||||||
|
@ -252,7 +252,6 @@
|
|||||||
"setControlImageDimensions": "Set Control Image Dimensions To W/H",
|
"setControlImageDimensions": "Set Control Image Dimensions To W/H",
|
||||||
"showAdvanced": "Show Advanced",
|
"showAdvanced": "Show Advanced",
|
||||||
"toggleControlNet": "Toggle this ControlNet",
|
"toggleControlNet": "Toggle this ControlNet",
|
||||||
"unstarImage": "Unstar Image",
|
|
||||||
"w": "W",
|
"w": "W",
|
||||||
"weight": "Weight",
|
"weight": "Weight",
|
||||||
"enableIPAdapter": "Enable IP Adapter",
|
"enableIPAdapter": "Enable IP Adapter",
|
||||||
@ -387,6 +386,8 @@
|
|||||||
"showGenerations": "Show Generations",
|
"showGenerations": "Show Generations",
|
||||||
"showUploads": "Show Uploads",
|
"showUploads": "Show Uploads",
|
||||||
"singleColumnLayout": "Single Column Layout",
|
"singleColumnLayout": "Single Column Layout",
|
||||||
|
"starImage": "Star Image",
|
||||||
|
"unstarImage": "Unstar Image",
|
||||||
"unableToLoad": "Unable to load Gallery",
|
"unableToLoad": "Unable to load Gallery",
|
||||||
"uploads": "Uploads",
|
"uploads": "Uploads",
|
||||||
"deleteSelection": "Delete Selection",
|
"deleteSelection": "Delete Selection",
|
||||||
@ -608,6 +609,7 @@
|
|||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"cfgScale": "CFG scale",
|
"cfgScale": "CFG scale",
|
||||||
|
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
||||||
"createdBy": "Created By",
|
"createdBy": "Created By",
|
||||||
"fit": "Image to image fit",
|
"fit": "Image to image fit",
|
||||||
"generationMode": "Generation Mode",
|
"generationMode": "Generation Mode",
|
||||||
@ -986,6 +988,7 @@
|
|||||||
"unsupportedAnyOfLength": "too many union members ({{count}})",
|
"unsupportedAnyOfLength": "too many union members ({{count}})",
|
||||||
"unsupportedMismatchedUnion": "mismatched CollectionOrScalar type with base types {{firstType}} and {{secondType}}",
|
"unsupportedMismatchedUnion": "mismatched CollectionOrScalar type with base types {{firstType}} and {{secondType}}",
|
||||||
"unableToParseFieldType": "unable to parse field type",
|
"unableToParseFieldType": "unable to parse field type",
|
||||||
|
"unableToExtractEnumOptions": "unable to extract enum options",
|
||||||
"uNetField": "UNet",
|
"uNetField": "UNet",
|
||||||
"uNetFieldDescription": "UNet submodel.",
|
"uNetFieldDescription": "UNet submodel.",
|
||||||
"unhandledInputProperty": "Unhandled input property",
|
"unhandledInputProperty": "Unhandled input property",
|
||||||
@ -1041,6 +1044,8 @@
|
|||||||
"setType": "Set cancel type"
|
"setType": "Set cancel type"
|
||||||
},
|
},
|
||||||
"cfgScale": "CFG Scale",
|
"cfgScale": "CFG Scale",
|
||||||
|
"cfgRescaleMultiplier": "CFG Rescale Multiplier",
|
||||||
|
"cfgRescale": "CFG Rescale",
|
||||||
"clipSkip": "CLIP Skip",
|
"clipSkip": "CLIP Skip",
|
||||||
"clipSkipWithLayerCount": "CLIP Skip {{layerCount}}",
|
"clipSkipWithLayerCount": "CLIP Skip {{layerCount}}",
|
||||||
"closeViewer": "Close Viewer",
|
"closeViewer": "Close Viewer",
|
||||||
@ -1482,6 +1487,12 @@
|
|||||||
"Controls how much your prompt influences the generation process."
|
"Controls how much your prompt influences the generation process."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
"paramCFGRescaleMultiplier": {
|
||||||
|
"heading": "CFG Rescale Multiplier",
|
||||||
|
"paragraphs": [
|
||||||
|
"Rescale multiplier for CFG guidance, used for models trained using zero-terminal SNR (ztsnr). Suggested value 0.7."
|
||||||
|
]
|
||||||
|
},
|
||||||
"paramDenoisingStrength": {
|
"paramDenoisingStrength": {
|
||||||
"heading": "Denoising Strength",
|
"heading": "Denoising Strength",
|
||||||
"paragraphs": [
|
"paragraphs": [
|
||||||
|
@ -1137,8 +1137,7 @@
|
|||||||
"openPose": "Openpose",
|
"openPose": "Openpose",
|
||||||
"controlAdapter_other": "Control Adapters",
|
"controlAdapter_other": "Control Adapters",
|
||||||
"lineartAnime": "Lineart Anime",
|
"lineartAnime": "Lineart Anime",
|
||||||
"canny": "Canny",
|
"canny": "Canny"
|
||||||
"unstarImage": "取消收藏图像"
|
|
||||||
},
|
},
|
||||||
"queue": {
|
"queue": {
|
||||||
"status": "状态",
|
"status": "状态",
|
||||||
|
@ -21,6 +21,7 @@ import GlobalHotkeys from './GlobalHotkeys';
|
|||||||
import PreselectedImage from './PreselectedImage';
|
import PreselectedImage from './PreselectedImage';
|
||||||
import Toaster from './Toaster';
|
import Toaster from './Toaster';
|
||||||
import { useSocketIO } from 'app/hooks/useSocketIO';
|
import { useSocketIO } from 'app/hooks/useSocketIO';
|
||||||
|
import { useClearStorage } from 'common/hooks/useClearStorage';
|
||||||
|
|
||||||
const DEFAULT_CONFIG = {};
|
const DEFAULT_CONFIG = {};
|
||||||
|
|
||||||
@ -36,15 +37,16 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
|
|||||||
const language = useAppSelector(languageSelector);
|
const language = useAppSelector(languageSelector);
|
||||||
const logger = useLogger('system');
|
const logger = useLogger('system');
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const clearStorage = useClearStorage();
|
||||||
|
|
||||||
// singleton!
|
// singleton!
|
||||||
useSocketIO();
|
useSocketIO();
|
||||||
|
|
||||||
const handleReset = useCallback(() => {
|
const handleReset = useCallback(() => {
|
||||||
localStorage.clear();
|
clearStorage();
|
||||||
location.reload();
|
location.reload();
|
||||||
return false;
|
return false;
|
||||||
}, []);
|
}, [clearStorage]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
i18n.changeLanguage(language);
|
i18n.changeLanguage(language);
|
||||||
|
@ -7,21 +7,23 @@ import { $headerComponent } from 'app/store/nanostores/headerComponent';
|
|||||||
import { $isDebugging } from 'app/store/nanostores/isDebugging';
|
import { $isDebugging } from 'app/store/nanostores/isDebugging';
|
||||||
import { $projectId } from 'app/store/nanostores/projectId';
|
import { $projectId } from 'app/store/nanostores/projectId';
|
||||||
import { $queueId, DEFAULT_QUEUE_ID } from 'app/store/nanostores/queueId';
|
import { $queueId, DEFAULT_QUEUE_ID } from 'app/store/nanostores/queueId';
|
||||||
import { store } from 'app/store/store';
|
import { $store } from 'app/store/nanostores/store';
|
||||||
|
import { createStore } from 'app/store/store';
|
||||||
import { PartialAppConfig } from 'app/types/invokeai';
|
import { PartialAppConfig } from 'app/types/invokeai';
|
||||||
|
import Loading from 'common/components/Loading/Loading';
|
||||||
|
import AppDndContext from 'features/dnd/components/AppDndContext';
|
||||||
|
import 'i18n';
|
||||||
import React, {
|
import React, {
|
||||||
PropsWithChildren,
|
PropsWithChildren,
|
||||||
ReactNode,
|
ReactNode,
|
||||||
lazy,
|
lazy,
|
||||||
memo,
|
memo,
|
||||||
useEffect,
|
useEffect,
|
||||||
|
useMemo,
|
||||||
} from 'react';
|
} from 'react';
|
||||||
import { Provider } from 'react-redux';
|
import { Provider } from 'react-redux';
|
||||||
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||||
import { ManagerOptions, SocketOptions } from 'socket.io-client';
|
import { ManagerOptions, SocketOptions } from 'socket.io-client';
|
||||||
import Loading from 'common/components/Loading/Loading';
|
|
||||||
import AppDndContext from 'features/dnd/components/AppDndContext';
|
|
||||||
import 'i18n';
|
|
||||||
|
|
||||||
const App = lazy(() => import('./App'));
|
const App = lazy(() => import('./App'));
|
||||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||||
@ -137,6 +139,14 @@ const InvokeAIUI = ({
|
|||||||
};
|
};
|
||||||
}, [isDebugging]);
|
}, [isDebugging]);
|
||||||
|
|
||||||
|
const store = useMemo(() => {
|
||||||
|
return createStore(projectId);
|
||||||
|
}, [projectId]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
$store.set(store);
|
||||||
|
}, [store]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<React.StrictMode>
|
<React.StrictMode>
|
||||||
<Provider store={store}>
|
<Provider store={store}>
|
||||||
|
@ -9,9 +9,9 @@ import { TOAST_OPTIONS, theme as invokeAITheme } from 'theme/theme';
|
|||||||
|
|
||||||
import '@fontsource-variable/inter';
|
import '@fontsource-variable/inter';
|
||||||
import { MantineProvider } from '@mantine/core';
|
import { MantineProvider } from '@mantine/core';
|
||||||
|
import { useMantineTheme } from 'mantine-theme/theme';
|
||||||
import 'overlayscrollbars/overlayscrollbars.css';
|
import 'overlayscrollbars/overlayscrollbars.css';
|
||||||
import 'theme/css/overlayscrollbars.css';
|
import 'theme/css/overlayscrollbars.css';
|
||||||
import { useMantineTheme } from 'mantine-theme/theme';
|
|
||||||
|
|
||||||
type ThemeLocaleProviderProps = {
|
type ThemeLocaleProviderProps = {
|
||||||
children: ReactNode;
|
children: ReactNode;
|
||||||
|
@ -1,8 +1 @@
|
|||||||
export const LOCALSTORAGE_KEYS = [
|
export const STORAGE_PREFIX = '@@invokeai-';
|
||||||
'chakra-ui-color-mode',
|
|
||||||
'i18nextLng',
|
|
||||||
'ROARR_FILTER',
|
|
||||||
'ROARR_LOG',
|
|
||||||
];
|
|
||||||
|
|
||||||
export const LOCALSTORAGE_PREFIX = '@@invokeai-';
|
|
||||||
|
@ -23,16 +23,16 @@ import systemReducer from 'features/system/store/systemSlice';
|
|||||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||||
import uiReducer from 'features/ui/store/uiSlice';
|
import uiReducer from 'features/ui/store/uiSlice';
|
||||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||||
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
import { Driver, rememberEnhancer, rememberReducer } from 'redux-remember';
|
||||||
import { api } from 'services/api';
|
import { api } from 'services/api';
|
||||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
import { STORAGE_PREFIX } from './constants';
|
||||||
import { serialize } from './enhancers/reduxRemember/serialize';
|
import { serialize } from './enhancers/reduxRemember/serialize';
|
||||||
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
||||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||||
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||||
import { $store } from './nanostores/store';
|
import { createStore as createIDBKeyValStore, get, set } from 'idb-keyval';
|
||||||
|
|
||||||
const allReducers = {
|
const allReducers = {
|
||||||
canvas: canvasReducer,
|
canvas: canvasReducer,
|
||||||
@ -74,16 +74,28 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
|||||||
'modelmanager',
|
'modelmanager',
|
||||||
];
|
];
|
||||||
|
|
||||||
export const store = configureStore({
|
// Create a custom idb-keyval store (just needed to customize the name)
|
||||||
|
export const idbKeyValStore = createIDBKeyValStore('invoke', 'invoke-store');
|
||||||
|
|
||||||
|
// Create redux-remember driver, wrapping idb-keyval
|
||||||
|
const idbKeyValDriver: Driver = {
|
||||||
|
getItem: (key) => get(key, idbKeyValStore),
|
||||||
|
setItem: (key, value) => set(key, value, idbKeyValStore),
|
||||||
|
};
|
||||||
|
|
||||||
|
export const createStore = (uniqueStoreKey?: string) =>
|
||||||
|
configureStore({
|
||||||
reducer: rememberedRootReducer,
|
reducer: rememberedRootReducer,
|
||||||
enhancers: (existingEnhancers) => {
|
enhancers: (existingEnhancers) => {
|
||||||
return existingEnhancers
|
return existingEnhancers
|
||||||
.concat(
|
.concat(
|
||||||
rememberEnhancer(window.localStorage, rememberedKeys, {
|
rememberEnhancer(idbKeyValDriver, rememberedKeys, {
|
||||||
persistDebounce: 300,
|
persistDebounce: 300,
|
||||||
serialize,
|
serialize,
|
||||||
unserialize,
|
unserialize,
|
||||||
prefix: LOCALSTORAGE_PREFIX,
|
prefix: uniqueStoreKey
|
||||||
|
? `${STORAGE_PREFIX}${uniqueStoreKey}-`
|
||||||
|
: STORAGE_PREFIX,
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
.concat(autoBatchEnhancer());
|
.concat(autoBatchEnhancer());
|
||||||
@ -121,10 +133,11 @@ export const store = configureStore({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
export type AppGetState = typeof store.getState;
|
export type AppGetState = ReturnType<
|
||||||
export type RootState = ReturnType<typeof store.getState>;
|
ReturnType<typeof createStore>['getState']
|
||||||
|
>;
|
||||||
|
export type RootState = ReturnType<ReturnType<typeof createStore>['getState']>;
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
|
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
|
||||||
export type AppDispatch = typeof store.dispatch;
|
export type AppDispatch = ReturnType<typeof createStore>['dispatch'];
|
||||||
export const stateSelector = (state: RootState) => state;
|
export const stateSelector = (state: RootState) => state;
|
||||||
$store.set(store);
|
|
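The storage changes in this file can be sketched in isolation. The block below is a minimal illustration, not the project's code: `createMemoryDriver` and `buildStoragePrefix` are hypothetical names, and the in-memory `Map` stands in for the idb-keyval store so the example runs without IndexedDB. The prefix logic mirrors the ternary passed as `prefix` to `rememberEnhancer` above.

```typescript
// Sketch only: an in-memory stand-in for the idb-keyval-backed driver.
// `createMemoryDriver` and `buildStoragePrefix` are hypothetical names.
const STORAGE_PREFIX = '@@invokeai-';

// Minimal driver shape as used in the diff: async getItem/setItem.
interface AsyncDriver {
  getItem: (key: string) => Promise<unknown>;
  setItem: (key: string, value: unknown) => Promise<void>;
}

const createMemoryDriver = (): AsyncDriver => {
  const backing = new Map<string, unknown>();
  return {
    getItem: async (key) => backing.get(key),
    setItem: async (key, value) => {
      backing.set(key, value);
    },
  };
};

// Mirrors the `prefix` ternary: scope persisted keys per store key when given.
const buildStoragePrefix = (uniqueStoreKey?: string): string =>
  uniqueStoreKey ? `${STORAGE_PREFIX}${uniqueStoreKey}-` : STORAGE_PREFIX;

// Usage: persist a slice under a project-scoped key.
const demoDriver = createMemoryDriver();
void demoDriver.setItem(`${buildStoragePrefix('demo')}ui`, { tab: 'txt2img' });
```

Scoping the prefix by `uniqueStoreKey` appears intended to keep persisted state for different projects apart inside one IndexedDB database.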
||||||
|
@ -25,6 +25,7 @@ export type Feature =
|
|||||||
| 'lora'
|
| 'lora'
|
||||||
| 'noiseUseCPU'
|
| 'noiseUseCPU'
|
||||||
| 'paramCFGScale'
|
| 'paramCFGScale'
|
||||||
|
| 'paramCFGRescaleMultiplier'
|
||||||
| 'paramDenoisingStrength'
|
| 'paramDenoisingStrength'
|
||||||
| 'paramIterations'
|
| 'paramIterations'
|
||||||
| 'paramModel'
|
| 'paramModel'
|
||||||
|
12
invokeai/frontend/web/src/common/hooks/useClearStorage.ts
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
import { idbKeyValStore } from 'app/store/store';
|
||||||
|
import { clear } from 'idb-keyval';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
|
||||||
|
export const useClearStorage = () => {
|
||||||
|
const clearStorage = useCallback(() => {
|
||||||
|
clear(idbKeyValStore);
|
||||||
|
localStorage.clear();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return clearStorage;
|
||||||
|
};
|
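idb-keyval's `clear` returns a promise, while `localStorage.clear()` is synchronous. If a caller needs every store emptied before continuing (for example before reloading the page), one option is to await each step. The sketch below is an assumption about how that could look, with a hypothetical `clearAllStorage` helper and in-memory stand-ins rather than real storage.

```typescript
// Sketch: each entry stands in for a real clearer, e.g. idb-keyval's
// `clear(store)` (returns a Promise) or `localStorage.clear()` (synchronous).
type Clearer = () => Promise<void> | void;

// Hypothetical helper: runs the clearers in order, waiting for each to finish.
const clearAllStorage = async (clearers: Clearer[]): Promise<void> => {
  for (const clearOne of clearers) {
    await clearOne();
  }
};

// Usage with in-memory stand-ins; a caller could reload the page in `.then`.
const clearedLog: string[] = [];
const clearingDone = clearAllStorage([
  async () => {
    clearedLog.push('indexeddb');
  },
  () => {
    clearedLog.push('localStorage');
  },
]);
```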
@ -5,14 +5,19 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
|
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
||||||
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
|
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
|
||||||
|
import { useControlAdapterControlImage } from 'features/controlAdapters/hooks/useControlAdapterControlImage';
|
||||||
|
import { useControlAdapterProcessedControlImage } from 'features/controlAdapters/hooks/useControlAdapterProcessedControlImage';
|
||||||
|
import { useControlAdapterProcessorType } from 'features/controlAdapters/hooks/useControlAdapterProcessorType';
|
||||||
|
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import {
|
import {
|
||||||
TypesafeDraggableData,
|
TypesafeDraggableData,
|
||||||
TypesafeDroppableData,
|
TypesafeDroppableData,
|
||||||
} from 'features/dnd/types';
|
} from 'features/dnd/types';
|
||||||
import { setHeight, setWidth } from 'features/parameters/store/generationSlice';
|
import { setHeight, setWidth } from 'features/parameters/store/generationSlice';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaRulerVertical, FaSave, FaUndo } from 'react-icons/fa';
|
import { FaRulerVertical, FaSave, FaUndo } from 'react-icons/fa';
|
||||||
import {
|
import {
|
||||||
@ -22,11 +27,6 @@ import {
|
|||||||
useRemoveImageFromBoardMutation,
|
useRemoveImageFromBoardMutation,
|
||||||
} from 'services/api/endpoints/images';
|
} from 'services/api/endpoints/images';
|
||||||
import { PostUploadAction } from 'services/api/types';
|
import { PostUploadAction } from 'services/api/types';
|
||||||
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
|
||||||
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
|
||||||
import { useControlAdapterControlImage } from 'features/controlAdapters/hooks/useControlAdapterControlImage';
|
|
||||||
import { useControlAdapterProcessedControlImage } from 'features/controlAdapters/hooks/useControlAdapterProcessedControlImage';
|
|
||||||
import { useControlAdapterProcessorType } from 'features/controlAdapters/hooks/useControlAdapterProcessorType';
|
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
id: string;
|
id: string;
|
||||||
@ -35,13 +35,15 @@ type Props = {
|
|||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
({ controlAdapters, gallery }) => {
|
({ controlAdapters, gallery, system }) => {
|
||||||
const { pendingControlImages } = controlAdapters;
|
const { pendingControlImages } = controlAdapters;
|
||||||
const { autoAddBoardId } = gallery;
|
const { autoAddBoardId } = gallery;
|
||||||
|
const { isConnected } = system;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
pendingControlImages,
|
pendingControlImages,
|
||||||
autoAddBoardId,
|
autoAddBoardId,
|
||||||
|
isConnected,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
@ -55,18 +57,19 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { pendingControlImages, autoAddBoardId } = useAppSelector(selector);
|
const { pendingControlImages, autoAddBoardId, isConnected } =
|
||||||
|
useAppSelector(selector);
|
||||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||||
|
|
||||||
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
||||||
|
|
||||||
const { currentData: controlImage } = useGetImageDTOQuery(
|
const { currentData: controlImage, isError: isErrorControlImage } =
|
||||||
controlImageName ?? skipToken
|
useGetImageDTOQuery(controlImageName ?? skipToken);
|
||||||
);
|
|
||||||
|
|
||||||
const { currentData: processedControlImage } = useGetImageDTOQuery(
|
const {
|
||||||
processedControlImageName ?? skipToken
|
currentData: processedControlImage,
|
||||||
);
|
isError: isErrorProcessedControlImage,
|
||||||
|
} = useGetImageDTOQuery(processedControlImageName ?? skipToken);
|
||||||
|
|
||||||
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();
|
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();
|
||||||
const [addToBoard] = useAddImageToBoardMutation();
|
const [addToBoard] = useAddImageToBoardMutation();
|
||||||
@ -158,6 +161,17 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
|
|||||||
!pendingControlImages.includes(id) &&
|
!pendingControlImages.includes(id) &&
|
||||||
processorType !== 'none';
|
processorType !== 'none';
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (isConnected && (isErrorControlImage || isErrorProcessedControlImage)) {
|
||||||
|
handleResetControlImage();
|
||||||
|
}
|
||||||
|
}, [
|
||||||
|
handleResetControlImage,
|
||||||
|
isConnected,
|
||||||
|
isErrorControlImage,
|
||||||
|
isErrorProcessedControlImage,
|
||||||
|
]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
onMouseEnter={handleMouseEnter}
|
onMouseEnter={handleMouseEnter}
|
||||||
|
@ -231,14 +231,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
icon={customStarUi ? customStarUi.off.icon : <MdStar />}
|
icon={customStarUi ? customStarUi.off.icon : <MdStar />}
|
||||||
onClickCapture={handleUnstarImage}
|
onClickCapture={handleUnstarImage}
|
||||||
>
|
>
|
||||||
{customStarUi ? customStarUi.off.text : t('controlnet.unstarImage')}
|
{customStarUi ? customStarUi.off.text : t('gallery.unstarImage')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
) : (
|
) : (
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={customStarUi ? customStarUi.on.icon : <MdStarBorder />}
|
icon={customStarUi ? customStarUi.on.icon : <MdStarBorder />}
|
||||||
onClickCapture={handleStarImage}
|
onClickCapture={handleStarImage}
|
||||||
>
|
>
|
||||||
{customStarUi ? customStarUi.on.text : `Star Image`}
|
{customStarUi ? customStarUi.on.text : t('gallery.starImage')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
)}
|
)}
|
||||||
<MenuItem
|
<MenuItem
|
||||||
|
@ -29,6 +29,7 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
recallNegativePrompt,
|
recallNegativePrompt,
|
||||||
recallSeed,
|
recallSeed,
|
||||||
recallCfgScale,
|
recallCfgScale,
|
||||||
|
recallCfgRescaleMultiplier,
|
||||||
recallModel,
|
recallModel,
|
||||||
recallScheduler,
|
recallScheduler,
|
||||||
recallVaeModel,
|
recallVaeModel,
|
||||||
@ -85,6 +86,10 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
recallCfgScale(metadata?.cfg_scale);
|
recallCfgScale(metadata?.cfg_scale);
|
||||||
}, [metadata?.cfg_scale, recallCfgScale]);
|
}, [metadata?.cfg_scale, recallCfgScale]);
|
||||||
|
|
||||||
|
const handleRecallCfgRescaleMultiplier = useCallback(() => {
|
||||||
|
recallCfgRescaleMultiplier(metadata?.cfg_rescale_multiplier);
|
||||||
|
}, [metadata?.cfg_rescale_multiplier, recallCfgRescaleMultiplier]);
|
||||||
|
|
||||||
const handleRecallStrength = useCallback(() => {
|
const handleRecallStrength = useCallback(() => {
|
||||||
recallStrength(metadata?.strength);
|
recallStrength(metadata?.strength);
|
||||||
}, [metadata?.strength, recallStrength]);
|
}, [metadata?.strength, recallStrength]);
|
||||||
@ -243,6 +248,14 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
onClick={handleRecallCfgScale}
|
onClick={handleRecallCfgScale}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
{metadata.cfg_rescale_multiplier !== undefined &&
|
||||||
|
metadata.cfg_rescale_multiplier !== null && (
|
||||||
|
<ImageMetadataItem
|
||||||
|
label={t('metadata.cfgRescaleMultiplier')}
|
||||||
|
value={metadata.cfg_rescale_multiplier}
|
||||||
|
onClick={handleRecallCfgRescaleMultiplier}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
{metadata.strength && (
|
{metadata.strength && (
|
||||||
<ImageMetadataItem
|
<ImageMetadataItem
|
||||||
label={t('metadata.strength')}
|
label={t('metadata.strength')}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import { Flex, Text } from '@chakra-ui/react';
|
import { Flex, Text } from '@chakra-ui/react';
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
||||||
import {
|
import {
|
||||||
@ -13,7 +13,7 @@ import {
|
|||||||
ImageFieldInputTemplate,
|
ImageFieldInputTemplate,
|
||||||
} from 'features/nodes/types/field';
|
} from 'features/nodes/types/field';
|
||||||
import { FieldComponentProps } from './types';
|
import { FieldComponentProps } from './types';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaUndo } from 'react-icons/fa';
|
import { FaUndo } from 'react-icons/fa';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
@ -24,8 +24,8 @@ const ImageFieldInputComponent = (
|
|||||||
) => {
|
) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const isConnected = useAppSelector((state) => state.system.isConnected);
|
||||||
const { currentData: imageDTO } = useGetImageDTOQuery(
|
const { currentData: imageDTO, isError } = useGetImageDTOQuery(
|
||||||
field.value?.image_name ?? skipToken
|
field.value?.image_name ?? skipToken
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -67,6 +67,12 @@ const ImageFieldInputComponent = (
|
|||||||
[nodeId, field.name]
|
[nodeId, field.name]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (isConnected && isError) {
|
||||||
|
handleReset();
|
||||||
|
}
|
||||||
|
}, [handleReset, isConnected, isError]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
className="nodrag"
|
className="nodrag"
|
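Both image components in this commit reset their image only when the socket is connected and the image query has failed. The condition can be written as a small pure function; `shouldResetImage` is a hypothetical name, since the diff inlines the check in `useEffect`. One reading is that an error while disconnected may only mean the server is unreachable, so the value is left alone; that rationale is an assumption and is not stated in the commit.

```typescript
// Hypothetical helper; the diff inlines this condition inside useEffect.
// Returns true only when connected AND at least one image query errored.
const shouldResetImage = (
  isConnected: boolean,
  ...queryErrors: boolean[]
): boolean => isConnected && queryErrors.some(Boolean);

// Usage: control image loaded fine, processed image query failed.
const resetNeeded = shouldResetImage(true, false, true);
```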
||||||
|
@ -43,10 +43,10 @@ export class NodeUpdateError extends Error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* FieldTypeParseError
|
* FieldParseError
|
||||||
* Raised when a field cannot be parsed from a field schema.
|
* Raised when a field cannot be parsed from a field schema.
|
||||||
*/
|
*/
|
||||||
export class FieldTypeParseError extends Error {
|
export class FieldParseError extends Error {
|
||||||
/**
|
/**
|
||||||
* Create FieldTypeParseError
|
* Create FieldParseError
|
||||||
* @param {String} message
|
* @param {String} message
|
||||||
@ -56,18 +56,3 @@ export class FieldTypeParseError extends Error {
|
|||||||
this.name = this.constructor.name;
|
this.name = this.constructor.name;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* UnsupportedFieldTypeError
|
|
||||||
* Raised when an unsupported field type is parsed.
|
|
||||||
*/
|
|
||||||
export class UnsupportedFieldTypeError extends Error {
|
|
||||||
/**
|
|
||||||
* Create UnsupportedFieldTypeError
|
|
||||||
* @param {String} message
|
|
||||||
*/
|
|
||||||
constructor(message: string) {
|
|
||||||
super(message);
|
|
||||||
this.name = this.constructor.name;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -51,6 +51,7 @@ export const zCoreMetadata = z
|
|||||||
seed: z.number().int().nullish().catch(null),
|
seed: z.number().int().nullish().catch(null),
|
||||||
rand_device: z.string().nullish().catch(null),
|
rand_device: z.string().nullish().catch(null),
|
||||||
cfg_scale: z.number().nullish().catch(null),
|
cfg_scale: z.number().nullish().catch(null),
|
||||||
|
cfg_rescale_multiplier: z.number().nullish().catch(null),
|
||||||
steps: z.number().int().nullish().catch(null),
|
steps: z.number().int().nullish().catch(null),
|
||||||
scheduler: z.string().nullish().catch(null),
|
scheduler: z.string().nullish().catch(null),
|
||||||
clip_skip: z.number().int().nullish().catch(null),
|
clip_skip: z.number().int().nullish().catch(null),
|
||||||
|
@ -43,6 +43,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
@ -316,6 +317,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'img2img',
|
generation_mode: 'img2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
width: !isUsingScaledDimensions
|
width: !isUsingScaledDimensions
|
||||||
? width
|
? width
|
||||||
: scaledBoundingBoxDimensions.width,
|
: scaledBoundingBoxDimensions.width,
|
||||||
|
@ -45,6 +45,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
@ -327,6 +328,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'img2img',
|
generation_mode: 'img2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
width: !isUsingScaledDimensions
|
width: !isUsingScaledDimensions
|
||||||
? width
|
? width
|
||||||
: scaledBoundingBoxDimensions.width,
|
: scaledBoundingBoxDimensions.width,
|
||||||
|
@ -43,6 +43,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
@ -306,6 +307,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'txt2img',
|
generation_mode: 'txt2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
width: !isUsingScaledDimensions
|
width: !isUsingScaledDimensions
|
||||||
? width
|
? width
|
||||||
: scaledBoundingBoxDimensions.width,
|
: scaledBoundingBoxDimensions.width,
|
||||||
|
@ -41,6 +41,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
@ -294,6 +295,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'txt2img',
|
generation_mode: 'txt2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
width: !isUsingScaledDimensions
|
width: !isUsingScaledDimensions
|
||||||
? width
|
? width
|
||||||
: scaledBoundingBoxDimensions.width,
|
: scaledBoundingBoxDimensions.width,
|
||||||
|
@ -41,6 +41,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
@ -316,6 +317,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'img2img',
|
generation_mode: 'img2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
|
@ -43,6 +43,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
@ -336,6 +337,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'sdxl_img2img',
|
generation_mode: 'sdxl_img2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
|
@ -34,6 +34,7 @@ export const buildLinearSDXLTextToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
@ -230,6 +231,7 @@ export const buildLinearSDXLTextToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'sdxl_txt2img',
|
generation_mode: 'sdxl_txt2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
|
@ -38,6 +38,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
negativePrompt,
|
negativePrompt,
|
||||||
model,
|
model,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
|
cfgRescaleMultiplier: cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
width,
|
width,
|
||||||
@ -84,6 +85,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
id: DENOISE_LATENTS,
|
id: DENOISE_LATENTS,
|
||||||
is_intermediate,
|
is_intermediate,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
denoising_start: 0,
|
denoising_start: 0,
|
||||||
@ -239,6 +241,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
{
|
{
|
||||||
generation_mode: 'txt2img',
|
generation_mode: 'txt2img',
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
|
@ -23,7 +23,12 @@ import {
|
|||||||
VAEModelFieldInputTemplate,
|
VAEModelFieldInputTemplate,
|
||||||
isStatefulFieldType,
|
isStatefulFieldType,
|
||||||
} from 'features/nodes/types/field';
|
} from 'features/nodes/types/field';
|
||||||
import { InvocationFieldSchema } from 'features/nodes/types/openapi';
|
import {
|
||||||
|
InvocationFieldSchema,
|
||||||
|
isSchemaObject,
|
||||||
|
} from 'features/nodes/types/openapi';
|
||||||
|
import { t } from 'i18next';
|
||||||
|
import { FieldParseError } from 'features/nodes/types/error';
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
type FieldInputTemplateBuilder<T extends FieldInputTemplate = any> = // valid `any`!
|
type FieldInputTemplateBuilder<T extends FieldInputTemplate = any> = // valid `any`!
|
||||||
@ -321,7 +326,28 @@ const buildImageFieldInputTemplate: FieldInputTemplateBuilder<
|
|||||||
const buildEnumFieldInputTemplate: FieldInputTemplateBuilder<
|
const buildEnumFieldInputTemplate: FieldInputTemplateBuilder<
|
||||||
EnumFieldInputTemplate
|
EnumFieldInputTemplate
|
||||||
> = ({ schemaObject, baseField, isCollection, isCollectionOrScalar }) => {
|
> = ({ schemaObject, baseField, isCollection, isCollectionOrScalar }) => {
|
||||||
const options = schemaObject.enum ?? [];
|
let options: EnumFieldInputTemplate['options'] = [];
|
||||||
|
if (schemaObject.anyOf) {
|
||||||
|
const filteredAnyOf = schemaObject.anyOf.filter((i) => {
|
||||||
|
if (isSchemaObject(i)) {
|
||||||
|
if (i.type === 'null') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
const firstAnyOf = filteredAnyOf[0];
|
||||||
|
if (filteredAnyOf.length !== 1 || !isSchemaObject(firstAnyOf)) {
|
||||||
|
options = [];
|
||||||
|
} else {
|
||||||
|
options = firstAnyOf.enum ?? [];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
options = schemaObject.enum ?? [];
|
||||||
|
}
|
||||||
|
if (options.length === 0) {
|
||||||
|
throw new FieldParseError(t('nodes.unableToExtractEnumOptions'));
|
||||||
|
}
|
||||||
const template: EnumFieldInputTemplate = {
|
const template: EnumFieldInputTemplate = {
|
||||||
...baseField,
|
...baseField,
|
||||||
type: {
|
type: {
|
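The new enum handling can be summarised as: if the schema has `anyOf`, drop the `null` member, require exactly one member to remain, and read its `enum`; otherwise read `enum` directly; fail if no options result. The sketch below restates that with a trimmed-down, hypothetical `MiniSchema` type in place of the OpenAPI types, and a plain `Error` in place of `FieldParseError`. It also omits the `isSchemaObject` check that separates schema objects from `$ref` objects in the real code.

```typescript
// Hypothetical, trimmed-down schema shape; the real code uses OpenAPI types.
type MiniSchema = {
  type?: string;
  enum?: string[];
  anyOf?: MiniSchema[];
};

const extractEnumOptions = (schema: MiniSchema): string[] => {
  let options: string[] = [];
  if (schema.anyOf) {
    // Drop the `null` member that an optional enum adds to `anyOf`.
    const nonNull = schema.anyOf.filter((s) => s.type !== 'null');
    const first = nonNull[0];
    // Only a single remaining member is treated as the enum.
    options = nonNull.length === 1 && first ? (first.enum ?? []) : [];
  } else {
    options = schema.enum ?? [];
  }
  if (options.length === 0) {
    // The diff throws FieldParseError with a translated message here.
    throw new Error('Unable to extract enum options');
  }
  return options;
};

// Usage: an optional enum expressed as anyOf [enum, null].
const schedulerOptions = extractEnumOptions({
  anyOf: [{ type: 'string', enum: ['euler', 'ddim'] }, { type: 'null' }],
});
```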
||||||
|
@ -1,10 +1,4 @@
|
|||||||
import { t } from 'i18next';
|
import { FieldParseError } from 'features/nodes/types/error';
|
||||||
import { isArray } from 'lodash-es';
|
|
||||||
import { OpenAPIV3_1 } from 'openapi-types';
|
|
||||||
import {
|
|
||||||
FieldTypeParseError,
|
|
||||||
UnsupportedFieldTypeError,
|
|
||||||
} from 'features/nodes/types/error';
|
|
||||||
import { FieldType } from 'features/nodes/types/field';
|
import { FieldType } from 'features/nodes/types/field';
|
||||||
import {
|
import {
|
||||||
OpenAPIV3_1SchemaOrRef,
|
OpenAPIV3_1SchemaOrRef,
|
||||||
@ -14,6 +8,9 @@ import {
|
|||||||
isRefObject,
|
isRefObject,
|
||||||
isSchemaObject,
|
isSchemaObject,
|
||||||
} from 'features/nodes/types/openapi';
|
} from 'features/nodes/types/openapi';
|
||||||
|
import { t } from 'i18next';
|
||||||
|
import { isArray } from 'lodash-es';
|
||||||
|
import { OpenAPIV3_1 } from 'openapi-types';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Transforms an invocation output ref object to field type.
|
* Transforms an invocation output ref object to field type.
|
||||||
@ -70,7 +67,7 @@ export const parseFieldType = (
|
|||||||
// This is a single ref type
|
// This is a single ref type
|
||||||
const name = refObjectToSchemaName(allOf[0]);
|
const name = refObjectToSchemaName(allOf[0]);
|
||||||
if (!name) {
|
if (!name) {
|
||||||
throw new FieldTypeParseError(
|
throw new FieldParseError(
|
||||||
t('nodes.unableToExtractSchemaNameFromRef')
|
t('nodes.unableToExtractSchemaNameFromRef')
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -95,7 +92,7 @@ export const parseFieldType = (
|
|||||||
if (isRefObject(filteredAnyOf[0])) {
|
if (isRefObject(filteredAnyOf[0])) {
|
||||||
const name = refObjectToSchemaName(filteredAnyOf[0]);
|
const name = refObjectToSchemaName(filteredAnyOf[0]);
|
||||||
if (!name) {
|
if (!name) {
|
||||||
throw new FieldTypeParseError(
|
throw new FieldParseError(
|
||||||
t('nodes.unableToExtractSchemaNameFromRef')
|
t('nodes.unableToExtractSchemaNameFromRef')
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -120,7 +117,7 @@ export const parseFieldType = (
|
|||||||
|
|
||||||
if (filteredAnyOf.length !== 2) {
|
if (filteredAnyOf.length !== 2) {
|
||||||
// This is a union of more than 2 types, which we don't support
|
// This is a union of more than 2 types, which we don't support
|
||||||
throw new UnsupportedFieldTypeError(
|
throw new FieldParseError(
|
||||||
t('nodes.unsupportedAnyOfLength', {
|
t('nodes.unsupportedAnyOfLength', {
|
||||||
count: filteredAnyOf.length,
|
count: filteredAnyOf.length,
|
||||||
})
|
})
|
||||||
@ -167,7 +164,7 @@ export const parseFieldType = (
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
throw new UnsupportedFieldTypeError(
|
throw new FieldParseError(
|
||||||
t('nodes.unsupportedMismatchedUnion', {
|
t('nodes.unsupportedMismatchedUnion', {
|
||||||
firstType,
|
firstType,
|
||||||
secondType,
|
secondType,
|
||||||
@ -186,7 +183,7 @@ export const parseFieldType = (
|
|||||||
if (isSchemaObject(schemaObject.items)) {
|
if (isSchemaObject(schemaObject.items)) {
|
||||||
const itemType = schemaObject.items.type;
|
const itemType = schemaObject.items.type;
|
||||||
if (!itemType || isArray(itemType)) {
|
if (!itemType || isArray(itemType)) {
|
||||||
throw new UnsupportedFieldTypeError(
|
throw new FieldParseError(
|
||||||
t('nodes.unsupportedArrayItemType', {
|
t('nodes.unsupportedArrayItemType', {
|
||||||
type: itemType,
|
type: itemType,
|
||||||
})
|
})
|
||||||
@ -196,7 +193,7 @@ export const parseFieldType = (
|
|||||||
const name = OPENAPI_TO_FIELD_TYPE_MAP[itemType];
|
const name = OPENAPI_TO_FIELD_TYPE_MAP[itemType];
|
||||||
if (!name) {
|
if (!name) {
|
||||||
// it's 'null', 'object', or 'array' - skip
|
// it's 'null', 'object', or 'array' - skip
|
||||||
throw new UnsupportedFieldTypeError(
|
throw new FieldParseError(
|
||||||
t('nodes.unsupportedArrayItemType', {
|
t('nodes.unsupportedArrayItemType', {
|
||||||
type: itemType,
|
type: itemType,
|
||||||
})
|
})
|
||||||
@ -212,7 +209,7 @@ export const parseFieldType = (
|
|||||||
// This is a ref object, extract the type name
|
// This is a ref object, extract the type name
|
||||||
const name = refObjectToSchemaName(schemaObject.items);
|
const name = refObjectToSchemaName(schemaObject.items);
|
||||||
if (!name) {
|
if (!name) {
|
||||||
throw new FieldTypeParseError(
|
throw new FieldParseError(
|
||||||
t('nodes.unableToExtractSchemaNameFromRef')
|
t('nodes.unableToExtractSchemaNameFromRef')
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -226,7 +223,7 @@ export const parseFieldType = (
|
|||||||
const name = OPENAPI_TO_FIELD_TYPE_MAP[schemaObject.type];
|
const name = OPENAPI_TO_FIELD_TYPE_MAP[schemaObject.type];
|
||||||
if (!name) {
|
if (!name) {
|
||||||
// it's 'null', 'object', or 'array' - skip
|
// it's 'null', 'object', or 'array' - skip
|
||||||
throw new UnsupportedFieldTypeError(
|
throw new FieldParseError(
|
||||||
t('nodes.unsupportedArrayItemType', {
|
t('nodes.unsupportedArrayItemType', {
|
||||||
type: schemaObject.type,
|
type: schemaObject.type,
|
||||||
})
|
})
|
||||||
@ -242,9 +239,7 @@ export const parseFieldType = (
|
|||||||
} else if (isRefObject(schemaObject)) {
|
} else if (isRefObject(schemaObject)) {
|
||||||
const name = refObjectToSchemaName(schemaObject);
|
const name = refObjectToSchemaName(schemaObject);
|
||||||
if (!name) {
|
if (!name) {
|
||||||
throw new FieldTypeParseError(
|
throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef'));
|
||||||
t('nodes.unableToExtractSchemaNameFromRef')
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
name,
|
name,
|
||||||
@ -252,5 +247,5 @@ export const parseFieldType = (
|
|||||||
isCollectionOrScalar: false,
|
isCollectionOrScalar: false,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
throw new FieldTypeParseError(t('nodes.unableToParseFieldType'));
|
throw new FieldParseError(t('nodes.unableToParseFieldType'));
|
||||||
};
|
};
|
||||||
|
@ -1,12 +1,6 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { parseify } from 'common/util/serialize';
|
import { parseify } from 'common/util/serialize';
|
||||||
import { t } from 'i18next';
|
import { FieldParseError } from 'features/nodes/types/error';
|
||||||
import { reduce } from 'lodash-es';
|
|
||||||
import { OpenAPIV3_1 } from 'openapi-types';
|
|
||||||
import {
|
|
||||||
FieldTypeParseError,
|
|
||||||
UnsupportedFieldTypeError,
|
|
||||||
} from 'features/nodes/types/error';
|
|
||||||
import {
|
import {
|
||||||
FieldInputTemplate,
|
FieldInputTemplate,
|
||||||
FieldOutputTemplate,
|
FieldOutputTemplate,
|
||||||
@ -18,6 +12,9 @@ import {
|
|||||||
isInvocationOutputSchemaObject,
|
isInvocationOutputSchemaObject,
|
||||||
isInvocationSchemaObject,
|
isInvocationSchemaObject,
|
||||||
} from 'features/nodes/types/openapi';
|
} from 'features/nodes/types/openapi';
|
||||||
|
import { t } from 'i18next';
|
||||||
|
import { reduce } from 'lodash-es';
|
||||||
|
import { OpenAPIV3_1 } from 'openapi-types';
|
||||||
import { buildFieldInputTemplate } from './buildFieldInputTemplate';
|
import { buildFieldInputTemplate } from './buildFieldInputTemplate';
|
||||||
import { buildFieldOutputTemplate } from './buildFieldOutputTemplate';
|
import { buildFieldOutputTemplate } from './buildFieldOutputTemplate';
|
||||||
import { parseFieldType } from './parseFieldType';
|
import { parseFieldType } from './parseFieldType';
|
||||||
@ -126,10 +123,7 @@ export const parseSchema = (
|
|||||||
|
|
||||||
inputsAccumulator[propertyName] = fieldInputTemplate;
|
inputsAccumulator[propertyName] = fieldInputTemplate;
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
if (
|
if (e instanceof FieldParseError) {
|
||||||
e instanceof FieldTypeParseError ||
|
|
||||||
e instanceof UnsupportedFieldTypeError
|
|
||||||
) {
|
|
||||||
logger('nodes').warn(
|
logger('nodes').warn(
|
||||||
{
|
{
|
||||||
node: type,
|
node: type,
|
||||||
@ -218,10 +212,7 @@ export const parseSchema = (
|
|||||||
|
|
||||||
outputsAccumulator[propertyName] = fieldOutputTemplate;
|
outputsAccumulator[propertyName] = fieldOutputTemplate;
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
if (
|
if (e instanceof FieldParseError) {
|
||||||
e instanceof FieldTypeParseError ||
|
|
||||||
e instanceof UnsupportedFieldTypeError
|
|
||||||
) {
|
|
||||||
logger('nodes').warn(
|
logger('nodes').warn(
|
||||||
{
|
{
|
||||||
node: type,
|
node: type,
|
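The `parseSchema` hunks above collapse two `instanceof` checks into one, since every parse failure is now a `FieldParseError`. The following is a rough sketch of that pattern in isolation. The `FieldParseError` class body, `parseFields` and its `warn` callback are assumptions made for illustration (the real class lives in `features/nodes/types/error` and is not shown in this diff). Rethrowing unexpected errors is also an assumption, because the diff does not show what the real `else` branch does.

```typescript
// Assumed shape of the error class; the real definition is not part of this diff.
class FieldParseError extends Error {
  constructor(message: string) {
    super(message);
    this.name = 'FieldParseError';
    // Keeps `instanceof` reliable when compiling to ES5 targets.
    Object.setPrototypeOf(this, FieldParseError.prototype);
  }
}

// Hypothetical helper: parse every field, skipping those that fail with FieldParseError.
function parseFields(
  raw: Record<string, unknown>,
  parse: (value: unknown) => string,
  warn: (message: string) => void
): Record<string, string> {
  const parsed: Record<string, string> = {};
  for (const name of Object.keys(raw)) {
    try {
      parsed[name] = parse(raw[name]);
    } catch (e) {
      if (e instanceof FieldParseError) {
        // A single check covers what used to be two separate error classes.
        warn(`Skipping field "${name}": ${e.message}`);
      } else {
        // Assumption: unexpected errors are not swallowed in this sketch.
        throw e;
      }
    }
  }
  return parsed;
}
```

With a single error class, callers only need one branch to tell an expected parse failure from a programming error.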
||||||
|
@ -9,21 +9,41 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import { ParamCpuNoiseToggle } from 'features/parameters/components/Parameters/Noise/ParamCpuNoise';
|
import { ParamCpuNoiseToggle } from 'features/parameters/components/Parameters/Noise/ParamCpuNoise';
|
||||||
import ParamSeamless from 'features/parameters/components/Parameters/Seamless/ParamSeamless';
|
import ParamSeamless from 'features/parameters/components/Parameters/Seamless/ParamSeamless';
|
||||||
import ParamClipSkip from './ParamClipSkip';
|
import ParamClipSkip from './ParamClipSkip';
|
||||||
|
import ParamCFGRescaleMultiplier from './ParamCFGRescaleMultiplier';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
(state: RootState) => {
|
(state: RootState) => {
|
||||||
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
|
const {
|
||||||
state.generation;
|
clipSkip,
|
||||||
|
model,
|
||||||
|
seamlessXAxis,
|
||||||
|
seamlessYAxis,
|
||||||
|
shouldUseCpuNoise,
|
||||||
|
cfgRescaleMultiplier,
|
||||||
|
} = state.generation;
|
||||||
|
|
||||||
return { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise };
|
return {
|
||||||
|
clipSkip,
|
||||||
|
model,
|
||||||
|
seamlessXAxis,
|
||||||
|
seamlessYAxis,
|
||||||
|
shouldUseCpuNoise,
|
||||||
|
cfgRescaleMultiplier,
|
||||||
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
export default function ParamAdvancedCollapse() {
|
export default function ParamAdvancedCollapse() {
|
||||||
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } =
|
const {
|
||||||
useAppSelector(selector);
|
clipSkip,
|
||||||
|
model,
|
||||||
|
seamlessXAxis,
|
||||||
|
seamlessYAxis,
|
||||||
|
shouldUseCpuNoise,
|
||||||
|
cfgRescaleMultiplier,
|
||||||
|
} = useAppSelector(selector);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const activeLabel = useMemo(() => {
|
const activeLabel = useMemo(() => {
|
||||||
const activeLabel: string[] = [];
|
const activeLabel: string[] = [];
|
||||||
@ -46,8 +66,20 @@ export default function ParamAdvancedCollapse() {
|
|||||||
activeLabel.push(t('parameters.seamlessY'));
|
activeLabel.push(t('parameters.seamlessY'));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (cfgRescaleMultiplier) {
|
||||||
|
activeLabel.push(t('parameters.cfgRescale'));
|
||||||
|
}
|
||||||
|
|
||||||
return activeLabel.join(', ');
|
return activeLabel.join(', ');
|
||||||
}, [clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise, t]);
|
}, [
|
||||||
|
cfgRescaleMultiplier,
|
||||||
|
clipSkip,
|
||||||
|
model,
|
||||||
|
seamlessXAxis,
|
||||||
|
seamlessYAxis,
|
||||||
|
shouldUseCpuNoise,
|
||||||
|
t,
|
||||||
|
]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICollapse label={t('common.advanced')} activeLabel={activeLabel}>
|
<IAICollapse label={t('common.advanced')} activeLabel={activeLabel}>
|
||||||
@ -61,6 +93,8 @@ export default function ParamAdvancedCollapse() {
|
|||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
<ParamCpuNoiseToggle />
|
<ParamCpuNoiseToggle />
|
||||||
|
<Divider />
|
||||||
|
<ParamCFGRescaleMultiplier />
|
||||||
</Flex>
|
</Flex>
|
||||||
</IAICollapse>
|
</IAICollapse>
|
||||||
);
|
);
|
||||||
|
@ -0,0 +1,60 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
|
||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { setCfgRescaleMultiplier } from 'features/parameters/store/generationSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[stateSelector],
|
||||||
|
({ generation, hotkeys }) => {
|
||||||
|
const { cfgRescaleMultiplier } = generation;
|
||||||
|
const { shift } = hotkeys;
|
||||||
|
|
||||||
|
return {
|
||||||
|
cfgRescaleMultiplier,
|
||||||
|
shift,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamCFGRescaleMultiplier = () => {
|
||||||
|
const { cfgRescaleMultiplier, shift } = useAppSelector(selector);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: number) => dispatch(setCfgRescaleMultiplier(v)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleReset = useCallback(
|
||||||
|
() => dispatch(setCfgRescaleMultiplier(0)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIInformationalPopover feature="paramCFGRescaleMultiplier">
|
||||||
|
<IAISlider
|
||||||
|
label={t('parameters.cfgRescaleMultiplier')}
|
||||||
|
step={shift ? 0.01 : 0.05}
|
||||||
|
min={0}
|
||||||
|
max={0.99}
|
||||||
|
onChange={handleChange}
|
||||||
|
handleReset={handleReset}
|
||||||
|
value={cfgRescaleMultiplier}
|
||||||
|
sliderNumberInputProps={{ max: 0.99 }}
|
||||||
|
withInput
|
||||||
|
withReset
|
||||||
|
withSliderMarks
|
||||||
|
isInteger={false}
|
||||||
|
/>
|
||||||
|
</IAIInformationalPopover>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamCFGRescaleMultiplier);
|
@ -1,7 +1,7 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
@ -9,25 +9,30 @@ import {
|
|||||||
TypesafeDraggableData,
|
TypesafeDraggableData,
|
||||||
TypesafeDroppableData,
|
TypesafeDroppableData,
|
||||||
} from 'features/dnd/types';
|
} from 'features/dnd/types';
|
||||||
import { memo, useMemo } from 'react';
|
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||||
|
import { memo, useEffect, useMemo } from 'react';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
(state) => {
|
(state) => {
|
||||||
const { initialImage } = state.generation;
|
const { initialImage } = state.generation;
|
||||||
|
const { isConnected } = state.system;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
initialImage,
|
initialImage,
|
||||||
isResetButtonDisabled: !initialImage,
|
isResetButtonDisabled: !initialImage,
|
||||||
|
isConnected,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const InitialImage = () => {
|
const InitialImage = () => {
|
||||||
const { initialImage } = useAppSelector(selector);
|
const dispatch = useAppDispatch();
|
||||||
|
const { initialImage, isConnected } = useAppSelector(selector);
|
||||||
|
|
||||||
const { currentData: imageDTO } = useGetImageDTOQuery(
|
const { currentData: imageDTO, isError } = useGetImageDTOQuery(
|
||||||
initialImage?.imageName ?? skipToken
|
initialImage?.imageName ?? skipToken
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -49,6 +54,13 @@ const InitialImage = () => {
|
|||||||
[]
|
[]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (isError && isConnected) {
|
||||||
|
// The image doesn't exist, reset init image
|
||||||
|
dispatch(clearInitialImage());
|
||||||
|
}
|
||||||
|
}, [dispatch, isConnected, isError]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIDndImage
|
<IAIDndImage
|
||||||
imageDTO={imageDTO}
|
imageDTO={imageDTO}
|
||||||
|
@ -57,6 +57,7 @@ import {
|
|||||||
modelSelected,
|
modelSelected,
|
||||||
} from 'features/parameters/store/actions';
|
} from 'features/parameters/store/actions';
|
||||||
import {
|
import {
|
||||||
|
setCfgRescaleMultiplier,
|
||||||
setCfgScale,
|
setCfgScale,
|
||||||
setHeight,
|
setHeight,
|
||||||
setHrfEnabled,
|
setHrfEnabled,
|
||||||
@ -94,6 +95,7 @@ import {
|
|||||||
isParameterStrength,
|
isParameterStrength,
|
||||||
isParameterVAEModel,
|
isParameterVAEModel,
|
||||||
isParameterWidth,
|
isParameterWidth,
|
||||||
|
isParameterCFGRescaleMultiplier,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
@ -282,6 +284,21 @@ export const useRecallParameters = () => {
|
|||||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Recall CFG rescale multiplier with toast
|
||||||
|
*/
|
||||||
|
const recallCfgRescaleMultiplier = useCallback(
|
||||||
|
(cfgRescaleMultiplier: unknown) => {
|
||||||
|
if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) {
|
||||||
|
parameterNotSetToast();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier));
|
||||||
|
parameterSetToast();
|
||||||
|
},
|
||||||
|
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||||
|
);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Recall model with toast
|
* Recall model with toast
|
||||||
*/
|
*/
|
||||||
@ -799,6 +816,7 @@ export const useRecallParameters = () => {
|
|||||||
|
|
||||||
const {
|
const {
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
|
cfg_rescale_multiplier,
|
||||||
height,
|
height,
|
||||||
model,
|
model,
|
||||||
positive_prompt,
|
positive_prompt,
|
||||||
@ -831,6 +849,10 @@ export const useRecallParameters = () => {
|
|||||||
dispatch(setCfgScale(cfg_scale));
|
dispatch(setCfgScale(cfg_scale));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
|
||||||
|
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
||||||
|
}
|
||||||
|
|
||||||
if (isParameterModel(model)) {
|
if (isParameterModel(model)) {
|
||||||
dispatch(modelSelected(model));
|
dispatch(modelSelected(model));
|
||||||
}
|
}
|
||||||
@ -985,6 +1007,7 @@ export const useRecallParameters = () => {
|
|||||||
recallSDXLNegativeStylePrompt,
|
recallSDXLNegativeStylePrompt,
|
||||||
recallSeed,
|
recallSeed,
|
||||||
recallCfgScale,
|
recallCfgScale,
|
||||||
|
recallCfgRescaleMultiplier,
|
||||||
recallModel,
|
recallModel,
|
||||||
recallScheduler,
|
recallScheduler,
|
||||||
recallVaeModel,
|
recallVaeModel,
|
||||||
|
@ -24,6 +24,7 @@ import {
|
|||||||
ParameterVAEModel,
|
ParameterVAEModel,
|
||||||
ParameterWidth,
|
ParameterWidth,
|
||||||
zParameterModel,
|
zParameterModel,
|
||||||
|
ParameterCFGRescaleMultiplier,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
|
||||||
export interface GenerationState {
|
export interface GenerationState {
|
||||||
@ -31,6 +32,7 @@ export interface GenerationState {
|
|||||||
hrfStrength: ParameterStrength;
|
hrfStrength: ParameterStrength;
|
||||||
hrfMethod: ParameterHRFMethod;
|
hrfMethod: ParameterHRFMethod;
|
||||||
cfgScale: ParameterCFGScale;
|
cfgScale: ParameterCFGScale;
|
||||||
|
cfgRescaleMultiplier: ParameterCFGRescaleMultiplier;
|
||||||
height: ParameterHeight;
|
height: ParameterHeight;
|
||||||
img2imgStrength: ParameterStrength;
|
img2imgStrength: ParameterStrength;
|
||||||
infillMethod: string;
|
infillMethod: string;
|
||||||
@ -76,6 +78,7 @@ export const initialGenerationState: GenerationState = {
|
|||||||
hrfEnabled: false,
|
hrfEnabled: false,
|
||||||
hrfMethod: 'ESRGAN',
|
hrfMethod: 'ESRGAN',
|
||||||
cfgScale: 7.5,
|
cfgScale: 7.5,
|
||||||
|
cfgRescaleMultiplier: 0,
|
||||||
height: 512,
|
height: 512,
|
||||||
img2imgStrength: 0.75,
|
img2imgStrength: 0.75,
|
||||||
infillMethod: 'patchmatch',
|
infillMethod: 'patchmatch',
|
||||||
@ -145,9 +148,15 @@ export const generationSlice = createSlice({
|
|||||||
state.steps
|
state.steps
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
setCfgScale: (state, action: PayloadAction<number>) => {
|
setCfgScale: (state, action: PayloadAction<ParameterCFGScale>) => {
|
||||||
state.cfgScale = action.payload;
|
state.cfgScale = action.payload;
|
||||||
},
|
},
|
||||||
|
setCfgRescaleMultiplier: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<ParameterCFGRescaleMultiplier>
|
||||||
|
) => {
|
||||||
|
state.cfgRescaleMultiplier = action.payload;
|
||||||
|
},
|
||||||
setThreshold: (state, action: PayloadAction<number>) => {
|
setThreshold: (state, action: PayloadAction<number>) => {
|
||||||
state.threshold = action.payload;
|
state.threshold = action.payload;
|
||||||
},
|
},
|
||||||
@ -336,6 +345,7 @@ export const {
|
|||||||
resetParametersState,
|
resetParametersState,
|
||||||
resetSeed,
|
resetSeed,
|
||||||
setCfgScale,
|
setCfgScale,
|
||||||
|
setCfgRescaleMultiplier,
|
||||||
setWidth,
|
setWidth,
|
||||||
setHeight,
|
setHeight,
|
||||||
toggleSize,
|
toggleSize,
|
||||||
|
@ -77,6 +77,17 @@ export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale =>
|
|||||||
zParameterCFGScale.safeParse(val).success;
|
zParameterCFGScale.safeParse(val).success;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|
||||||
|
// #region CFG Rescale Multiplier
|
||||||
|
export const zParameterCFGRescaleMultiplier = z.number().gte(0).lt(1);
|
||||||
|
export type ParameterCFGRescaleMultiplier = z.infer<
|
||||||
|
typeof zParameterCFGRescaleMultiplier
|
||||||
|
>;
|
||||||
|
export const isParameterCFGRescaleMultiplier = (
|
||||||
|
val: unknown
|
||||||
|
): val is ParameterCFGRescaleMultiplier =>
|
||||||
|
zParameterCFGRescaleMultiplier.safeParse(val).success;
|
||||||
|
// #endregion
|
||||||
|
|
||||||
// #region Scheduler
|
// #region Scheduler
|
||||||
export const zParameterScheduler = zSchedulerField;
|
export const zParameterScheduler = zSchedulerField;
|
||||||
export type ParameterScheduler = z.infer<typeof zParameterScheduler>;
|
export type ParameterScheduler = z.infer<typeof zParameterScheduler>;
|
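The schema added above, `z.number().gte(0).lt(1)`, accepts numbers in the half-open range [0, 1), which lines up with the slider's `max={0.99}` elsewhere in this commit. The following is a dependency-free sketch of an equivalent guard, so the range can be checked without zod. `CfgRescaleMultiplier` and `isCfgRescaleMultiplier` are hypothetical names, not the project's exports.

```typescript
// Hypothetical alias; the real type is inferred from the zod schema.
type CfgRescaleMultiplier = number;

// Mirrors z.number().gte(0).lt(1): a number in the half-open range [0, 1).
// NaN fails both comparisons, so it is rejected without an explicit check.
const isCfgRescaleMultiplier = (val: unknown): val is CfgRescaleMultiplier =>
  typeof val === 'number' && val >= 0 && val < 1;
```

The upper bound is exclusive, so `0.99` passes and `1` does not. Non-numeric input such as the string `'0.5'` is rejected rather than coerced.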
||||||
|
@ -14,11 +14,11 @@ import {
|
|||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { VALID_LOG_LEVELS } from 'app/logging/logger';
|
import { VALID_LOG_LEVELS } from 'app/logging/logger';
|
||||||
import { LOCALSTORAGE_KEYS, LOCALSTORAGE_PREFIX } from 'app/store/constants';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { useClearStorage } from 'common/hooks/useClearStorage';
|
||||||
import {
|
import {
|
||||||
consoleLogLevelChanged,
|
consoleLogLevelChanged,
|
||||||
setEnableImageDebugging,
|
setEnableImageDebugging,
|
||||||
@ -164,20 +164,14 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
|
|||||||
shouldEnableInformationalPopovers,
|
shouldEnableInformationalPopovers,
|
||||||
} = useAppSelector(selector);
|
} = useAppSelector(selector);
|
||||||
|
|
||||||
|
const clearStorage = useClearStorage();
|
||||||
|
|
||||||
const handleClickResetWebUI = useCallback(() => {
|
const handleClickResetWebUI = useCallback(() => {
|
||||||
// Only remove our keys
|
clearStorage();
|
||||||
Object.keys(window.localStorage).forEach((key) => {
|
|
||||||
if (
|
|
||||||
LOCALSTORAGE_KEYS.includes(key) ||
|
|
||||||
key.startsWith(LOCALSTORAGE_PREFIX)
|
|
||||||
) {
|
|
||||||
localStorage.removeItem(key);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
onSettingsModalClose();
|
onSettingsModalClose();
|
||||||
onRefreshModalOpen();
|
onRefreshModalOpen();
|
||||||
setInterval(() => setCountdown((prev) => prev - 1), 1000);
|
setInterval(() => setCountdown((prev) => prev - 1), 1000);
|
||||||
}, [onSettingsModalClose, onRefreshModalOpen]);
|
}, [clearStorage, onSettingsModalClose, onRefreshModalOpen]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (countdown <= 0) {
|
if (countdown <= 0) {
|
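The inline loop removed from `handleClickResetWebUI` above deleted only the app's own keys: exact matches from `LOCALSTORAGE_KEYS` and anything starting with `LOCALSTORAGE_PREFIX`. The following sketch shows that filter in isolation. It says nothing about how `useClearStorage` is implemented, which this diff does not show. `KeyValueStore`, `MemoryStore`, `clearOwnKeys` and the key names in the usage note are hypothetical, chosen so the example runs outside a browser.

```typescript
// Minimal storage surface needed by the sketch (hypothetical interface).
interface KeyValueStore {
  keys(): string[];
  removeItem(key: string): void;
}

// In-memory stand-in for window.localStorage so the sketch runs anywhere.
class MemoryStore implements KeyValueStore {
  private data: Record<string, string> = {};

  setItem(key: string, value: string): void {
    this.data[key] = value;
  }

  has(key: string): boolean {
    return Object.prototype.hasOwnProperty.call(this.data, key);
  }

  keys(): string[] {
    // Returns a snapshot, so removing items while iterating is safe.
    return Object.keys(this.data);
  }

  removeItem(key: string): void {
    delete this.data[key];
  }
}

// Remove keys that match exactly or start with the prefix; return what was removed.
function clearOwnKeys(
  store: KeyValueStore,
  exactKeys: string[],
  prefix: string
): string[] {
  const removed: string[] = [];
  for (const key of store.keys()) {
    if (exactKeys.indexOf(key) !== -1 || key.indexOf(prefix) === 0) {
      store.removeItem(key);
      removed.push(key);
    }
  }
  return removed;
}
```

Calling `clearOwnKeys(store, ['exact-key'], 'app-')` on a store holding `app-settings`, `exact-key` and `unrelated` removes the first two and leaves `unrelated` in place.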
||||||
|
@ -1,11 +1,8 @@
|
|||||||
import WorkflowLibraryList from 'features/workflowLibrary/components/WorkflowLibraryList';
|
import WorkflowLibraryList from 'features/workflowLibrary/components/WorkflowLibraryList';
|
||||||
import WorkflowLibraryListWrapper from 'features/workflowLibrary/components/WorkflowLibraryListWrapper';
|
import WorkflowLibraryListWrapper from 'features/workflowLibrary/components/WorkflowLibraryListWrapper';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
const WorkflowLibraryContent = () => {
|
const WorkflowLibraryContent = () => {
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<WorkflowLibraryListWrapper>
|
<WorkflowLibraryListWrapper>
|
||||||
<WorkflowLibraryList />
|
<WorkflowLibraryList />
|
||||||
|
441
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
File diff suppressed because one or more lines are too long
@ -4158,6 +4158,11 @@ i18next@^23.6.0:
|
|||||||
dependencies:
|
dependencies:
|
||||||
"@babel/runtime" "^7.22.5"
|
"@babel/runtime" "^7.22.5"
|
||||||
|
|
||||||
|
idb-keyval@^6.2.1:
|
||||||
|
version "6.2.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/idb-keyval/-/idb-keyval-6.2.1.tgz#94516d625346d16f56f3b33855da11bfded2db33"
|
||||||
|
integrity sha512-8Sb3veuYCyrZL+VBt9LJfZjLUPWVvqn8tG28VqYNFCo43KHcKuq+b4EiXGeuaLAQWL2YmyDgMp2aSpH9JHsEQg==
|
||||||
|
|
||||||
ieee754@^1.1.13:
|
ieee754@^1.1.13:
|
||||||
version "1.2.1"
|
version "1.2.1"
|
||||||
resolved "https://registry.yarnpkg.com/ieee754/-/ieee754-1.2.1.tgz#8eb7a10a63fff25d15a57b001586d177d1b0d352"
|
resolved "https://registry.yarnpkg.com/ieee754/-/ieee754-1.2.1.tgz#8eb7a10a63fff25d15a57b001586d177d1b0d352"
|
||||||
|
@ -37,6 +37,14 @@ def build_dummy_sd15_unet_input(torch_device):
|
|||||||
"unet_model_id": "runwayml/stable-diffusion-v1-5",
|
"unet_model_id": "runwayml/stable-diffusion-v1-5",
|
||||||
"unet_model_name": "stable-diffusion-v1-5",
|
"unet_model_name": "stable-diffusion-v1-5",
|
||||||
},
|
},
|
||||||
|
# SD1.5, IPAdapterFull
|
||||||
|
{
|
||||||
|
"ip_adapter_model_id": "InvokeAI/ip-adapter-full-face_sd15",
|
||||||
|
"ip_adapter_model_name": "ip-adapter-full-face_sd15",
|
||||||
|
"base_model": BaseModelType.StableDiffusion1,
|
||||||
|
"unet_model_id": "runwayml/stable-diffusion-v1-5",
|
||||||
|
"unet_model_name": "stable-diffusion-v1-5",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
|
224
tests/backend/tiles/test_tiles.py
Normal file
@ -0,0 +1,224 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
|
||||||
|
from invokeai.backend.tiles.utils import TBLR, Tile
|
||||||
|
|
||||||
|
####################################
|
||||||
|
# Test calc_tiles_with_overlap(...)
|
||||||
|
####################################
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_tiles_with_overlap_single_tile():
|
||||||
|
"""Test calc_tiles_with_overlap() behavior when a single tile covers the image."""
|
||||||
|
tiles = calc_tiles_with_overlap(image_height=512, image_width=1024, tile_height=512, tile_width=1024, overlap=64)
|
||||||
|
|
||||||
|
expected_tiles = [
|
||||||
|
Tile(coords=TBLR(top=0, bottom=512, left=0, right=1024), overlap=TBLR(top=0, bottom=0, left=0, right=0))
|
||||||
|
]
|
||||||
|
|
||||||
|
assert tiles == expected_tiles
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_tiles_with_overlap_evenly_divisible():
|
||||||
|
"""Test calc_tiles_with_overlap() behavior when the image is evenly covered by multiple tiles."""
|
||||||
|
# Parameters chosen so that image is evenly covered by 2 rows, 3 columns of tiles.
|
||||||
|
tiles = calc_tiles_with_overlap(image_height=576, image_width=1600, tile_height=320, tile_width=576, overlap=64)
|
||||||
|
|
||||||
|
expected_tiles = [
|
||||||
|
# Row 0
|
||||||
|
Tile(coords=TBLR(top=0, bottom=320, left=0, right=576), overlap=TBLR(top=0, bottom=64, left=0, right=64)),
|
||||||
|
Tile(coords=TBLR(top=0, bottom=320, left=512, right=1088), overlap=TBLR(top=0, bottom=64, left=64, right=64)),
|
||||||
|
Tile(coords=TBLR(top=0, bottom=320, left=1024, right=1600), overlap=TBLR(top=0, bottom=64, left=64, right=0)),
|
||||||
|
# Row 1
|
||||||
|
        Tile(coords=TBLR(top=256, bottom=576, left=0, right=576), overlap=TBLR(top=64, bottom=0, left=0, right=64)),
        Tile(coords=TBLR(top=256, bottom=576, left=512, right=1088), overlap=TBLR(top=64, bottom=0, left=64, right=64)),
        Tile(coords=TBLR(top=256, bottom=576, left=1024, right=1600), overlap=TBLR(top=64, bottom=0, left=64, right=0)),
    ]

    assert tiles == expected_tiles


def test_calc_tiles_with_overlap_not_evenly_divisible():
    """Test calc_tiles_with_overlap() behavior when the image requires 'uneven' overlaps to achieve proper coverage."""
    # Parameters chosen so that image is covered by 2 rows and 3 columns of tiles, with uneven overlaps.
    tiles = calc_tiles_with_overlap(image_height=400, image_width=1200, tile_height=256, tile_width=512, overlap=64)

    expected_tiles = [
        # Row 0
        Tile(coords=TBLR(top=0, bottom=256, left=0, right=512), overlap=TBLR(top=0, bottom=112, left=0, right=64)),
        Tile(coords=TBLR(top=0, bottom=256, left=448, right=960), overlap=TBLR(top=0, bottom=112, left=64, right=272)),
        Tile(coords=TBLR(top=0, bottom=256, left=688, right=1200), overlap=TBLR(top=0, bottom=112, left=272, right=0)),
        # Row 1
        Tile(coords=TBLR(top=144, bottom=400, left=0, right=512), overlap=TBLR(top=112, bottom=0, left=0, right=64)),
        Tile(
            coords=TBLR(top=144, bottom=400, left=448, right=960), overlap=TBLR(top=112, bottom=0, left=64, right=272)
        ),
        Tile(
            coords=TBLR(top=144, bottom=400, left=688, right=1200), overlap=TBLR(top=112, bottom=0, left=272, right=0)
        ),
    ]

    assert tiles == expected_tiles


@pytest.mark.parametrize(
    ["image_height", "image_width", "tile_height", "tile_width", "overlap", "raises"],
    [
        (128, 128, 128, 128, 127, False),  # OK
        (128, 128, 128, 128, 0, False),  # OK
        (128, 128, 64, 64, 0, False),  # OK
        (128, 128, 129, 128, 0, True),  # tile_height exceeds image_height.
        (128, 128, 128, 129, 0, True),  # tile_width exceeds image_width.
        (128, 128, 64, 128, 64, True),  # overlap equals tile_height.
        (128, 128, 128, 64, 64, True),  # overlap equals tile_width.
    ],
)
def test_calc_tiles_with_overlap_input_validation(
    image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int, raises: bool
):
    """Test that calc_tiles_with_overlap() raises an exception if the inputs are invalid."""
    if raises:
        with pytest.raises(AssertionError):
            calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap)
    else:
        calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap)


#############################################
# Test merge_tiles_with_linear_blending(...)
#############################################


@pytest.mark.parametrize("blend_amount", [0, 32])
def test_merge_tiles_with_linear_blending_horizontal(blend_amount: int):
    """Test merge_tiles_with_linear_blending(...) behavior when merging horizontally."""
    # Initialize 2 tiles side-by-side.
    tiles = [
        Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=64)),
        Tile(coords=TBLR(top=0, bottom=512, left=448, right=960), overlap=TBLR(top=0, bottom=0, left=64, right=0)),
    ]

    dst_image = np.zeros((512, 960, 3), dtype=np.uint8)

    # Prepare tile_images that match tiles. Pixel values are set based on the tile index.
    tile_images = [
        np.zeros((512, 512, 3)) + 64,
        np.zeros((512, 512, 3)) + 128,
    ]

    # Calculate expected output.
    expected_output = np.zeros((512, 960, 3), dtype=np.uint8)
    expected_output[:, : 480 - (blend_amount // 2), :] = 64
    if blend_amount > 0:
        gradient = np.linspace(start=64, stop=128, num=blend_amount, dtype=np.uint8).reshape((1, blend_amount, 1))
        expected_output[:, 480 - (blend_amount // 2) : 480 + (blend_amount // 2), :] = gradient
    expected_output[:, 480 + (blend_amount // 2) :, :] = 128

    merge_tiles_with_linear_blending(
        dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=blend_amount
    )

    np.testing.assert_array_equal(dst_image, expected_output, strict=True)


@pytest.mark.parametrize("blend_amount", [0, 32])
def test_merge_tiles_with_linear_blending_vertical(blend_amount: int):
    """Test merge_tiles_with_linear_blending(...) behavior when merging vertically."""
    # Initialize 2 tiles stacked vertically.
    tiles = [
        Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=64, left=0, right=0)),
        Tile(coords=TBLR(top=448, bottom=960, left=0, right=512), overlap=TBLR(top=64, bottom=0, left=0, right=0)),
    ]

    dst_image = np.zeros((960, 512, 3), dtype=np.uint8)

    # Prepare tile_images that match tiles. Pixel values are set based on the tile index.
    tile_images = [
        np.zeros((512, 512, 3)) + 64,
        np.zeros((512, 512, 3)) + 128,
    ]

    # Calculate expected output.
    expected_output = np.zeros((960, 512, 3), dtype=np.uint8)
    expected_output[: 480 - (blend_amount // 2), :, :] = 64
    if blend_amount > 0:
        gradient = np.linspace(start=64, stop=128, num=blend_amount, dtype=np.uint8).reshape((blend_amount, 1, 1))
        expected_output[480 - (blend_amount // 2) : 480 + (blend_amount // 2), :, :] = gradient
    expected_output[480 + (blend_amount // 2) :, :, :] = 128

    merge_tiles_with_linear_blending(
        dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=blend_amount
    )

    np.testing.assert_array_equal(dst_image, expected_output, strict=True)


def test_merge_tiles_with_linear_blending_blend_amount_exceeds_vertical_overlap():
    """Test that merge_tiles_with_linear_blending(...) raises an exception if 'blend_amount' exceeds the overlap between
    any vertically adjacent tiles.
    """
    # Initialize 2 tiles stacked vertically.
    tiles = [
        Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=64, left=0, right=0)),
        Tile(coords=TBLR(top=448, bottom=960, left=0, right=512), overlap=TBLR(top=64, bottom=0, left=0, right=0)),
    ]

    dst_image = np.zeros((960, 512, 3), dtype=np.uint8)

    # Prepare tile_images that match tiles.
    tile_images = [np.zeros((512, 512, 3)), np.zeros((512, 512, 3))]

    # blend_amount=128 exceeds overlap of 64, so should raise exception.
    with pytest.raises(AssertionError):
        merge_tiles_with_linear_blending(dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=128)


def test_merge_tiles_with_linear_blending_blend_amount_exceeds_horizontal_overlap():
    """Test that merge_tiles_with_linear_blending(...) raises an exception if 'blend_amount' exceeds the overlap between
    any horizontally adjacent tiles.
    """
    # Initialize 2 tiles side-by-side.
    tiles = [
        Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=64)),
        Tile(coords=TBLR(top=0, bottom=512, left=448, right=960), overlap=TBLR(top=0, bottom=0, left=64, right=0)),
    ]

    dst_image = np.zeros((512, 960, 3), dtype=np.uint8)

    # Prepare tile_images that match tiles.
    tile_images = [np.zeros((512, 512, 3)), np.zeros((512, 512, 3))]

    # blend_amount=128 exceeds overlap of 64, so should raise exception.
    with pytest.raises(AssertionError):
        merge_tiles_with_linear_blending(dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=128)


def test_merge_tiles_with_linear_blending_tiles_overflow_dst_image():
    """Test that merge_tiles_with_linear_blending(...) raises an exception if any of the tiles overflows the
    dst_image.
    """
    tiles = [Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=0))]

    dst_image = np.zeros((256, 512, 3), dtype=np.uint8)

    # Prepare tile_images that match tiles.
    tile_images = [np.zeros((512, 512, 3))]

    with pytest.raises(ValueError):
        merge_tiles_with_linear_blending(dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=0)


def test_merge_tiles_with_linear_blending_mismatched_list_lengths():
    """Test that merge_tiles_with_linear_blending(...) raises an exception if the lengths of 'tiles' and 'tile_images'
    do not match.
    """
    tiles = [Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=0))]

    dst_image = np.zeros((256, 512, 3), dtype=np.uint8)

    # tile_images is longer than tiles, so should cause an exception.
    tile_images = [np.zeros((512, 512, 3)), np.zeros((512, 512, 3))]

    with pytest.raises(ValueError):
        merge_tiles_with_linear_blending(dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=0)
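The tile coordinates expected by the `calc_tiles_with_overlap()` tests above can be reproduced with a small one-axis calculation. The following is a minimal sketch, not the InvokeAI implementation: `tile_starts` is a hypothetical helper, and it raises `ValueError` where the tests above expect an `AssertionError` from the real function. It assumes tiles advance by `tile_size - overlap` and that the last tile is shifted back so it ends exactly at the image edge, which is what would produce the 'uneven' overlaps (272 and 112) seen in the test.

```python
import math
from typing import List


def tile_starts(image_size: int, tile_size: int, overlap: int) -> List[int]:
    """Return the start offset of each tile along one axis (hypothetical helper)."""
    if tile_size > image_size:
        raise ValueError("tile_size must not exceed image_size")
    if overlap >= tile_size:
        raise ValueError("overlap must be smaller than tile_size")

    # Each tile after the first advances by this many pixels.
    stride = tile_size - overlap
    # Number of tiles needed to cover the axis with at least `overlap` shared pixels.
    num_tiles = math.ceil((image_size - overlap) / stride)

    starts = []
    for i in range(num_tiles):
        start = i * stride
        # Assumption: a tile that would run past the edge is shifted back to end at
        # the edge, which increases its overlap with the previous tile.
        if start + tile_size > image_size:
            start = image_size - tile_size
        starts.append(start)
    return starts
```

For the 400x1200 image in the test above, `tile_starts(1200, 512, 64)` gives `[0, 448, 688]` and `tile_starts(400, 256, 64)` gives `[0, 144]`, matching the `left` and `top` values in `expected_tiles`.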
101
tests/backend/tiles/test_utils.py
Normal file
@ -0,0 +1,101 @@
import numpy as np
import pytest

from invokeai.backend.tiles.utils import TBLR, paste


def test_paste_no_mask_success():
    """Test successful paste with mask=None."""
    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)

    # Create src_image with a pattern that can be used to validate that it was pasted correctly.
    src_image = np.zeros((3, 3, 3), dtype=np.uint8)
    src_image[0, :, 0] = 1  # Row of 1s in channel 0.
    src_image[:, 0, 1] = 2  # Column of 2s in channel 1.

    # Paste in bottom-center of dst_image.
    box = TBLR(top=2, bottom=5, left=1, right=4)

    # Construct expected output image.
    expected_output = np.zeros((5, 5, 3), dtype=np.uint8)
    expected_output[2, 1:4, 0] = 1
    expected_output[2:5, 1, 1] = 2

    paste(dst_image=dst_image, src_image=src_image, box=box)

    np.testing.assert_array_equal(dst_image, expected_output, strict=True)


def test_paste_with_mask_success():
    """Test successful paste with a mask."""
    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)

    # Create src_image with a pattern that can be used to validate that it was pasted correctly.
    src_image = np.zeros((3, 3, 3), dtype=np.uint8)
    src_image[0, :, 0] = 64  # Row of 64s in channel 0.
    src_image[:, 0, 1] = 128  # Column of 128s in channel 1.

    # Paste in bottom-center of dst_image.
    box = TBLR(top=2, bottom=5, left=1, right=4)

    # Create a mask that blends the top-left corner of 'src_image' at 50%, and ignores the rest of src_image.
    mask = np.zeros((3, 3))
    mask[0, 0] = 0.5

    # Construct expected output image.
    expected_output = np.zeros((5, 5, 3), dtype=np.uint8)
    expected_output[2, 1, 0] = 32
    expected_output[2, 1, 1] = 64

    paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)

    np.testing.assert_array_equal(dst_image, expected_output, strict=True)


@pytest.mark.parametrize("use_mask", [True, False])
def test_paste_box_overflows_dst_image(use_mask: bool):
    """Test that an exception is raised if 'box' overflows the 'dst_image'."""
    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
    src_image = np.zeros((3, 3, 3), dtype=np.uint8)
    mask = None
    if use_mask:
        mask = np.zeros((3, 3))

    # Construct box that overflows bottom of dst_image.
    top = 3
    left = 0
    box = TBLR(top=top, bottom=top + src_image.shape[0], left=left, right=left + src_image.shape[1])

    with pytest.raises(ValueError):
        paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)


@pytest.mark.parametrize("use_mask", [True, False])
def test_paste_src_image_does_not_match_box(use_mask: bool):
    """Test that an exception is raised if the 'src_image' shape does not match the 'box' dimensions."""
    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
    src_image = np.zeros((3, 3, 3), dtype=np.uint8)
    mask = None
    if use_mask:
        mask = np.zeros((3, 3))

    # Construct box that is smaller than src_image.
    box = TBLR(top=0, bottom=src_image.shape[0] - 1, left=0, right=src_image.shape[1])

    with pytest.raises(ValueError):
        paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)


def test_paste_mask_does_not_match_src_image():
    """Test that an exception is raised if the 'mask' shape is different than the 'src_image' shape."""
    dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
    src_image = np.zeros((3, 3, 3), dtype=np.uint8)

    # Construct mask that is smaller than the src_image.
    mask = np.zeros((src_image.shape[0] - 1, src_image.shape[1]))

    # Construct box that matches src_image shape.
    box = TBLR(top=0, bottom=src_image.shape[0], left=0, right=src_image.shape[1])

    with pytest.raises(ValueError):
        paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)
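The behaviour these tests pin down for `paste()` can be summarised in a few lines of NumPy. This is a sketch under assumptions, not the code in `invokeai.backend.tiles.utils`: `paste_sketch` is a hypothetical name, the box is passed as plain integers rather than a `TBLR`, and the mask is assumed to be a 2-D float array in [0, 1] applied as `dst * (1 - mask) + src * mask`.

```python
from typing import Optional

import numpy as np


def paste_sketch(
    dst: np.ndarray,
    src: np.ndarray,
    top: int,
    bottom: int,
    left: int,
    right: int,
    mask: Optional[np.ndarray] = None,
) -> None:
    """Paste src into dst[top:bottom, left:right] in place (hypothetical helper)."""
    if top < 0 or left < 0 or bottom > dst.shape[0] or right > dst.shape[1]:
        raise ValueError("box overflows dst")
    if src.shape[:2] != (bottom - top, right - left):
        raise ValueError("src shape does not match box")

    if mask is None:
        # No mask: overwrite the region with src.
        dst[top:bottom, left:right] = src
        return

    if mask.shape != src.shape[:2]:
        raise ValueError("mask shape does not match src")

    # Add a channel axis so the 2-D mask broadcasts over the colour channels.
    weight = mask[:, :, np.newaxis]
    region = dst[top:bottom, left:right]
    blended = region * (1.0 - weight) + src * weight
    dst[top:bottom, left:right] = blended.astype(dst.dtype)
```

With the inputs from `test_paste_with_mask_success` (a zero `dst`, a 50% mask at the top-left pixel of `src`), this sketch writes 32 and 64 into channels 0 and 1 of `dst[2, 1]`, the same values the test expects from the real function.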