Merge branch 'main' into tiled-upscaling-graph

This commit is contained in:
skunkworxdark 2023-12-05 15:32:49 +00:00
commit 5816320645
67 changed files with 1181 additions and 869 deletions

View File

@ -120,7 +120,7 @@ Generate an image with a given prompt, record the seed of the image, and then
use the `prompt2prompt` syntax to substitute words in the original prompt for use the `prompt2prompt` syntax to substitute words in the original prompt for
words in a new prompt. This works for `img2img` as well. words in a new prompt. This works for `img2img` as well.
For example, consider the prompt `a cat.swap(dog) playing with a ball in the forest`. Normally, because of the word words interact with each other when doing a stable diffusion image generation, these two prompts would generate different compositions: For example, consider the prompt `a cat.swap(dog) playing with a ball in the forest`. Normally, because the words interact with each other when doing a stable diffusion image generation, these two prompts would generate different compositions:
- `a cat playing with a ball in the forest` - `a cat playing with a ball in the forest`
- `a dog playing with a ball in the forest` - `a dog playing with a ball in the forest`

View File

@ -1,13 +1,15 @@
# List of Default Nodes # List of Default Nodes
The table below contains a list of the default nodes shipped with InvokeAI and their descriptions. The table below contains a list of the default nodes shipped with InvokeAI and
their descriptions.
| Node <img width=160 align="right"> | Function | | Node <img width=160 align="right"> | Function |
|: ---------------------------------- | :--------------------------------------------------------------------------------------| | :------------------------------------------------------------ | :--------------------------------------------------------------------------------------------------------------------------------------------------- |
| Add Integers | Adds two numbers | | Add Integers | Adds two numbers |
| Boolean Primitive Collection | A collection of boolean primitive values | | Boolean Primitive Collection | A collection of boolean primitive values |
| Boolean Primitive | A boolean primitive value | | Boolean Primitive | A boolean primitive value |
| Canny Processor | Canny edge detection for ControlNet | | Canny Processor | Canny edge detection for ControlNet |
| CenterPadCrop | Pad or crop an image's sides from the center by specified pixels. Positive values are outside of the image. |
| CLIP Skip | Skip layers in clip text_encoder model. | | CLIP Skip | Skip layers in clip text_encoder model. |
| Collect | Collects values into a collection | | Collect | Collects values into a collection |
| Color Correct | Shifts the colors of a target image to match the reference image, optionally using a mask to only color-correct certain regions of the target image. | | Color Correct | Shifts the colors of a target image to match the reference image, optionally using a mask to only color-correct certain regions of the target image. |
@ -74,7 +76,7 @@ The table below contains a list of the default nodes shipped with InvokeAI and t
| Noise | Generates latent noise. | | Noise | Generates latent noise. |
| Normal BAE Processor | Applies NormalBae processing to image | | Normal BAE Processor | Applies NormalBae processing to image |
| ONNX Latents to Image | Generates an image from latents. | | ONNX Latents to Image | Generates an image from latents. |
|ONNX Prompt (Raw) | A node to process inputs and produce outputs. May use dependency injection in __init__ to receive providers.| | ONNX Prompt (Raw) | A node to process inputs and produce outputs. May use dependency injection in **init** to receive providers. |
| ONNX Text to Latents | Generates latents from conditionings. | | ONNX Text to Latents | Generates latents from conditionings. |
| ONNX Model Loader | Loads a main model, outputting its submodels. | | ONNX Model Loader | Loads a main model, outputting its submodels. |
| OpenCV Inpaint | Simple inpaint using opencv. | | OpenCV Inpaint | Simple inpaint using opencv. |

View File

@ -1,7 +1,11 @@
import typing import typing
from enum import Enum from enum import Enum
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path from pathlib import Path
from platform import python_version
from typing import Optional
import torch
from fastapi import Body from fastapi import Body
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -40,6 +44,24 @@ class AppVersion(BaseModel):
version: str = Field(description="App version") version: str = Field(description="App version")
class AppDependencyVersions(BaseModel):
"""App depencency Versions Response"""
accelerate: str = Field(description="accelerate version")
compel: str = Field(description="compel version")
cuda: Optional[str] = Field(description="CUDA version")
diffusers: str = Field(description="diffusers version")
numpy: str = Field(description="Numpy version")
opencv: str = Field(description="OpenCV version")
onnx: str = Field(description="ONNX version")
pillow: str = Field(description="Pillow (PIL) version")
python: str = Field(description="Python version")
torch: str = Field(description="PyTorch version")
torchvision: str = Field(description="PyTorch Vision version")
transformers: str = Field(description="transformers version")
xformers: Optional[str] = Field(description="xformers version")
class AppConfig(BaseModel): class AppConfig(BaseModel):
"""App Config Response""" """App Config Response"""
@ -54,6 +76,29 @@ async def get_version() -> AppVersion:
return AppVersion(version=__version__) return AppVersion(version=__version__)
@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=AppDependencyVersions)
async def get_app_deps() -> AppDependencyVersions:
try:
xformers = version("xformers")
except PackageNotFoundError:
xformers = None
return AppDependencyVersions(
accelerate=version("accelerate"),
compel=version("compel"),
cuda=torch.version.cuda,
diffusers=version("diffusers"),
numpy=version("numpy"),
opencv=version("opencv-python"),
onnx=version("onnx"),
pillow=version("pillow"),
python=python_version(),
torch=torch.version.__version__,
torchvision=version("torchvision"),
transformers=version("transformers"),
xformers=xformers,
)
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig) @app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
async def get_config() -> AppConfig: async def get_config() -> AppConfig:
infill_methods = ["tile", "lama", "cv2"] infill_methods = ["tile", "lama", "cv2"]

View File

@ -141,7 +141,7 @@ async def del_model_record(
status_code=201, status_code=201,
) )
async def add_model_record( async def add_model_record(
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")] config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
) -> AnyModelConfig: ) -> AnyModelConfig:
""" """
Add a model using the configuration information appropriate for its type. Add a model using the configuration information appropriate for its type.

View File

@ -100,6 +100,61 @@ class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
) )
@invocation(
invocation_type="img_pad_crop",
title="Center Pad or Crop Image",
category="image",
tags=["image", "pad", "crop"],
version="1.0.0",
)
class CenterPadCropInvocation(BaseInvocation):
"""Pad or crop an image's sides from the center by specified pixels. Positive values are outside of the image."""
image: ImageField = InputField(description="The image to crop")
left: int = InputField(
default=0,
description="Number of pixels to pad/crop from the left (negative values crop inwards, positive values pad outwards)",
)
right: int = InputField(
default=0,
description="Number of pixels to pad/crop from the right (negative values crop inwards, positive values pad outwards)",
)
top: int = InputField(
default=0,
description="Number of pixels to pad/crop from the top (negative values crop inwards, positive values pad outwards)",
)
bottom: int = InputField(
default=0,
description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
# Calculate and create new image dimensions
new_width = image.width + self.right + self.left
new_height = image.height + self.top + self.bottom
image_crop = Image.new(mode="RGBA", size=(new_width, new_height), color=(0, 0, 0, 0))
# Paste new image onto input
image_crop.paste(image, (self.left, self.top))
image_dto = context.services.images.create(
image=image_crop,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.1.0") @invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.1.0")
class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata): class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Pastes an image into another image.""" """Pastes an image into another image."""

View File

@ -221,7 +221,7 @@ def get_scheduler(
title="Denoise Latents", title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents", category="latents",
version="1.4.0", version="1.5.0",
) )
class DenoiseLatentsInvocation(BaseInvocation): class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images""" """Denoises noisy latents to decodable images"""
@ -279,6 +279,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
input=Input.Connection, input=Input.Connection,
ui_order=7, ui_order=7,
) )
cfg_rescale_multiplier: float = InputField(
default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
latents: Optional[LatentsField] = InputField( latents: Optional[LatentsField] = InputField(
default=None, default=None,
description=FieldDescriptions.latents, description=FieldDescriptions.latents,
@ -338,6 +341,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
unconditioned_embeddings=uc, unconditioned_embeddings=uc,
text_embeddings=c, text_embeddings=c,
guidance_scale=self.cfg_scale, guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
extra=extra_conditioning_info, extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings( postprocessing_settings=PostprocessingSettings(
threshold=0.0, # threshold, threshold=0.0, # threshold,
@ -1190,12 +1194,12 @@ class CropLatentsCoreInvocation(BaseInvocation):
description=FieldDescriptions.latents, description=FieldDescriptions.latents,
input=Input.Connection, input=Input.Connection,
) )
x_offset: int = InputField( x: int = InputField(
ge=0, ge=0,
multiple_of=LATENT_SCALE_FACTOR, multiple_of=LATENT_SCALE_FACTOR,
description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
) )
y_offset: int = InputField( y: int = InputField(
ge=0, ge=0,
multiple_of=LATENT_SCALE_FACTOR, multiple_of=LATENT_SCALE_FACTOR,
description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.", description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
@ -1214,8 +1218,8 @@ class CropLatentsCoreInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name) latents = context.services.latents.get(self.latents.latents_name)
x1 = self.x_offset // LATENT_SCALE_FACTOR x1 = self.x // LATENT_SCALE_FACTOR
y1 = self.y_offset // LATENT_SCALE_FACTOR y1 = self.y // LATENT_SCALE_FACTOR
x2 = x1 + (self.width // LATENT_SCALE_FACTOR) x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
y2 = y1 + (self.height // LATENT_SCALE_FACTOR) y2 = y1 + (self.height // LATENT_SCALE_FACTOR)

View File

@ -127,6 +127,9 @@ class CoreMetadataInvocation(BaseInvocation):
seed: Optional[int] = InputField(default=None, description="The seed used for noise generation") seed: Optional[int] = InputField(default=None, description="The seed used for noise generation")
rand_device: Optional[str] = InputField(default=None, description="The device used for random number generation") rand_device: Optional[str] = InputField(default=None, description="The device used for random number generation")
cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter") cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter")
cfg_rescale_multiplier: Optional[float] = InputField(
default=None, description=FieldDescriptions.cfg_rescale_multiplier
)
steps: Optional[int] = InputField(default=None, description="The number of steps used for inference") steps: Optional[int] = InputField(default=None, description="The number of steps used for inference")
scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference") scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference")
seamless_x: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the X axis") seamless_x: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the X axis")

View File

@ -1,5 +1,3 @@
from typing import Literal
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from pydantic import BaseModel from pydantic import BaseModel
@ -7,7 +5,6 @@ from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
Input,
InputField, InputField,
InvocationContext, InvocationContext,
OutputField, OutputField,
@ -18,13 +15,7 @@ from invokeai.app.invocations.baseinvocation import (
) )
from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.backend.tiles.tiles import ( from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
calc_tiles_even_split,
calc_tiles_min_overlap,
calc_tiles_with_overlap,
merge_tiles_with_linear_blending,
merge_tiles_with_seam_blending,
)
from invokeai.backend.tiles.utils import Tile from invokeai.backend.tiles.utils import Tile
@ -65,92 +56,12 @@ class CalculateImageTilesInvocation(BaseInvocation):
return CalculateImageTilesOutput(tiles=tiles) return CalculateImageTilesOutput(tiles=tiles)
@invocation(
"calculate_image_tiles_Even_Split",
title="Calculate Image Tiles Even Split",
tags=["tiles"],
category="tiles",
version="1.0.0",
)
class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
image_height: int = InputField(
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
)
num_tiles_x: int = InputField(
default=2,
ge=1,
description="Number of tiles to divide image into on the x axis",
)
num_tiles_y: int = InputField(
default=2,
ge=1,
description="Number of tiles to divide image into on the y axis",
)
overlap: float = InputField(
default=0.25,
ge=0,
lt=1,
description="Overlap amount of tile size (0-1)",
)
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
tiles = calc_tiles_even_split(
image_height=self.image_height,
image_width=self.image_width,
num_tiles_x=self.num_tiles_x,
num_tiles_y=self.num_tiles_y,
overlap=self.overlap,
)
return CalculateImageTilesOutput(tiles=tiles)
@invocation(
"calculate_image_tiles_min_overlap",
title="Calculate Image Tiles Minimum Overlap",
tags=["tiles"],
category="tiles",
version="1.0.0",
)
class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
image_height: int = InputField(
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
)
tile_width: int = InputField(ge=1, default=576, description="The tile width, in pixels.")
tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.")
min_overlap: int = InputField(
default=128,
ge=0,
description="minimum tile overlap size (must be a multiple of 8)",
)
round_to_8: bool = InputField(
default=False,
description="Round outputs down to the nearest 8 (for pulling from a large noise field)",
)
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
tiles = calc_tiles_min_overlap(
image_height=self.image_height,
image_width=self.image_width,
tile_height=self.tile_height,
tile_width=self.tile_width,
min_overlap=self.min_overlap,
round_to_8=self.round_to_8,
)
return CalculateImageTilesOutput(tiles=tiles)
@invocation_output("tile_to_properties_output") @invocation_output("tile_to_properties_output")
class TileToPropertiesOutput(BaseInvocationOutput): class TileToPropertiesOutput(BaseInvocationOutput):
coords_top: int = OutputField(description="Top coordinate of the tile relative to its parent image.")
coords_bottom: int = OutputField(description="Bottom coordinate of the tile relative to its parent image.")
coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.") coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.")
coords_right: int = OutputField(description="Right coordinate of the tile relative to its parent image.") coords_right: int = OutputField(description="Right coordinate of the tile relative to its parent image.")
coords_top: int = OutputField(description="Top coordinate of the tile relative to its parent image.")
coords_bottom: int = OutputField(description="Bottom coordinate of the tile relative to its parent image.")
# HACK: The width and height fields are 'meta' fields that can easily be calculated from the other fields on this # HACK: The width and height fields are 'meta' fields that can easily be calculated from the other fields on this
# object. Including redundant fields that can cheaply/easily be re-calculated goes against conventional API design # object. Including redundant fields that can cheaply/easily be re-calculated goes against conventional API design
@ -174,10 +85,10 @@ class TileToPropertiesInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> TileToPropertiesOutput: def invoke(self, context: InvocationContext) -> TileToPropertiesOutput:
return TileToPropertiesOutput( return TileToPropertiesOutput(
coords_top=self.tile.coords.top,
coords_bottom=self.tile.coords.bottom,
coords_left=self.tile.coords.left, coords_left=self.tile.coords.left,
coords_right=self.tile.coords.right, coords_right=self.tile.coords.right,
coords_top=self.tile.coords.top,
coords_bottom=self.tile.coords.bottom,
width=self.tile.coords.right - self.tile.coords.left, width=self.tile.coords.right - self.tile.coords.left,
height=self.tile.coords.bottom - self.tile.coords.top, height=self.tile.coords.bottom - self.tile.coords.top,
overlap_top=self.tile.overlap.top, overlap_top=self.tile.overlap.top,
@ -211,22 +122,13 @@ class PairTileImageInvocation(BaseInvocation):
) )
BLEND_MODES = Literal["Linear", "Seam"] @invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.0.0")
@invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.1.0")
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow): class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
"""Merge multiple tile images into a single image.""" """Merge multiple tile images into a single image."""
# Inputs # Inputs
tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.") tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.")
blend_mode: BLEND_MODES = InputField(
default="Seam",
description="blending type Linear or Seam",
input=Input.Direct,
)
blend_amount: int = InputField( blend_amount: int = InputField(
default=32,
ge=0, ge=0,
description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.", description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
) )
@ -256,16 +158,10 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
channels = tile_np_images[0].shape[-1] channels = tile_np_images[0].shape[-1]
dtype = tile_np_images[0].dtype dtype = tile_np_images[0].dtype
np_image = np.zeros(shape=(height, width, channels), dtype=dtype) np_image = np.zeros(shape=(height, width, channels), dtype=dtype)
if self.blend_mode == "Linear":
merge_tiles_with_linear_blending( merge_tiles_with_linear_blending(
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
) )
else:
merge_tiles_with_seam_blending(
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
)
# Convert into a PIL image and save
pil_image = Image.fromarray(np_image) pil_image = Image.fromarray(np_image)
image_dto = context.services.images.create( image_dto = context.services.images.create(

View File

@ -5,6 +5,8 @@ from typing import Union
import torch import torch
from invokeai.app.services.invoker import Invoker
from .latents_storage_base import LatentsStorageBase from .latents_storage_base import LatentsStorageBase
@ -17,6 +19,10 @@ class DiskLatentsStorage(LatentsStorageBase):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True) self.__output_folder.mkdir(parents=True, exist_ok=True)
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
self._delete_all_latents()
def get(self, name: str) -> torch.Tensor: def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name) latent_path = self.get_path(name)
return torch.load(latent_path) return torch.load(latent_path)
@ -32,3 +38,21 @@ class DiskLatentsStorage(LatentsStorageBase):
def get_path(self, name: str) -> Path: def get_path(self, name: str) -> Path:
return self.__output_folder / name return self.__output_folder / name
def _delete_all_latents(self) -> None:
"""
Deletes all latents from disk.
Must be called after we have access to `self._invoker` (e.g. in `start()`).
"""
deleted_latents_count = 0
freed_space = 0
for latents_file in Path(self.__output_folder).glob("*"):
if latents_file.is_file():
freed_space += latents_file.stat().st_size
deleted_latents_count += 1
latents_file.unlink()
if deleted_latents_count > 0:
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
self._invoker.services.logger.info(
f"Deleted {deleted_latents_count} latents files (freed {freed_space_in_mb}MB)"
)

View File

@ -5,6 +5,8 @@ from typing import Dict, Optional
import torch import torch
from invokeai.app.services.invoker import Invoker
from .latents_storage_base import LatentsStorageBase from .latents_storage_base import LatentsStorageBase
@ -23,6 +25,18 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
self.__cache_ids = Queue() self.__cache_ids = Queue()
self.__max_cache_size = max_cache_size self.__max_cache_size = max_cache_size
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
start_op = getattr(self.__underlying_storage, "start", None)
if callable(start_op):
start_op(invoker)
def stop(self, invoker: Invoker) -> None:
self._invoker = invoker
stop_op = getattr(self.__underlying_storage, "stop", None)
if callable(stop_op):
stop_op(invoker)
def get(self, name: str) -> torch.Tensor: def get(self, name: str) -> torch.Tensor:
cache_item = self.__get_cache(name) cache_item = self.__get_cache(name)
if cache_item is not None: if cache_item is not None:

View File

@ -42,6 +42,7 @@ class SqliteSessionQueue(SessionQueueBase):
self._set_in_progress_to_canceled() self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID) prune_result = self.prune(DEFAULT_QUEUE_ID)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event) local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, db: SqliteDatabase) -> None: def __init__(self, db: SqliteDatabase) -> None:

View File

@ -207,10 +207,12 @@ class IterateInvocationOutput(BaseInvocationOutput):
item: Any = OutputField( item: Any = OutputField(
description="The item being iterated over", title="Collection Item", ui_type=UIType._CollectionItem description="The item being iterated over", title="Collection Item", ui_type=UIType._CollectionItem
) )
index: int = OutputField(description="The index of the item", title="Index")
total: int = OutputField(description="The total number of items", title="Total")
# TODO: Fill this out and move to invocations # TODO: Fill this out and move to invocations
@invocation("iterate", version="1.0.0") @invocation("iterate", version="1.1.0")
class IterateInvocation(BaseInvocation): class IterateInvocation(BaseInvocation):
"""Iterates over a list of items""" """Iterates over a list of items"""
@ -221,7 +223,7 @@ class IterateInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IterateInvocationOutput: def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
"""Produces the outputs as values""" """Produces the outputs as values"""
return IterateInvocationOutput(item=self.collection[self.index]) return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection))
@invocation_output("collect_output") @invocation_output("collect_output")

View File

@ -1,6 +1,7 @@
import sqlite3 import sqlite3
import threading import threading
from logging import Logger from logging import Logger
from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
@ -8,25 +9,20 @@ sqlite_memory = ":memory:"
class SqliteDatabase: class SqliteDatabase:
conn: sqlite3.Connection
lock: threading.RLock
_logger: Logger
_config: InvokeAIAppConfig
def __init__(self, config: InvokeAIAppConfig, logger: Logger): def __init__(self, config: InvokeAIAppConfig, logger: Logger):
self._logger = logger self._logger = logger
self._config = config self._config = config
if self._config.use_memory_db: if self._config.use_memory_db:
location = sqlite_memory self.db_path = sqlite_memory
logger.info("Using in-memory database") logger.info("Using in-memory database")
else: else:
db_path = self._config.db_path db_path = self._config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True) db_path.parent.mkdir(parents=True, exist_ok=True)
location = str(db_path) self.db_path = str(db_path)
self._logger.info(f"Using database at {location}") self._logger.info(f"Using database at {self.db_path}")
self.conn = sqlite3.connect(location, check_same_thread=False) self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
self.lock = threading.RLock() self.lock = threading.RLock()
self.conn.row_factory = sqlite3.Row self.conn.row_factory = sqlite3.Row
@ -37,10 +33,16 @@ class SqliteDatabase:
def clean(self) -> None: def clean(self) -> None:
try: try:
if self.db_path == sqlite_memory:
return
initial_db_size = Path(self.db_path).stat().st_size
self.lock.acquire() self.lock.acquire()
self.conn.execute("VACUUM;") self.conn.execute("VACUUM;")
self.conn.commit() self.conn.commit()
self._logger.info("Cleaned database") final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e: except Exception as e:
self._logger.error(f"Error cleaning database: {e}") self._logger.error(f"Error cleaning database: {e}")
raise e raise e

View File

@ -2,6 +2,7 @@ class FieldDescriptions:
denoising_start = "When to start denoising, expressed a percentage of total steps" denoising_start = "When to start denoising, expressed a percentage of total steps"
denoising_end = "When to stop denoising, expressed a percentage of total steps" denoising_end = "When to stop denoising, expressed a percentage of total steps"
cfg_scale = "Classifier-Free Guidance scale" cfg_scale = "Classifier-Free Guidance scale"
cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR"
scheduler = "Scheduler to use during inference" scheduler = "Scheduler to use during inference"
positive_cond = "Positive conditioning tensor" positive_cond = "Positive conditioning tensor"
negative_cond = "Negative conditioning tensor" negative_cond = "Negative conditioning tensor"

View File

@ -54,6 +54,44 @@ class ImageProjModel(torch.nn.Module):
return clip_extra_context_tokens return clip_extra_context_tokens
class MLPProjModel(torch.nn.Module):
"""SD model with image prompt"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
torch.nn.GELU(),
torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
torch.nn.LayerNorm(cross_attention_dim),
)
@classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor]):
"""Initialize an MLPProjModel from a state_dict.
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
Args:
state_dict (dict[torch.Tensor]): The state_dict of model weights.
Returns:
MLPProjModel
"""
cross_attention_dim = state_dict["proj.3.weight"].shape[0]
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
model = cls(cross_attention_dim, clip_embeddings_dim)
model.load_state_dict(state_dict)
return model
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class IPAdapter: class IPAdapter:
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf""" """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
@ -130,6 +168,13 @@ class IPAdapterPlus(IPAdapter):
return image_prompt_embeds, uncond_image_prompt_embeds return image_prompt_embeds, uncond_image_prompt_embeds
class IPAdapterFull(IPAdapterPlus):
"""IP-Adapter Plus with full features."""
def _init_image_proj_model(self, state_dict: dict[torch.Tensor]):
return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)
class IPAdapterPlusXL(IPAdapterPlus): class IPAdapterPlusXL(IPAdapterPlus):
"""IP-Adapter Plus for SDXL.""" """IP-Adapter Plus for SDXL."""
@ -149,11 +194,9 @@ def build_ip_adapter(
) -> Union[IPAdapter, IPAdapterPlus]: ) -> Union[IPAdapter, IPAdapterPlus]:
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu") state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
# Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it if "proj.weight" in state_dict["image_proj"]: # IPAdapter (with ImageProjModel).
# contains. return IPAdapter(state_dict, device=device, dtype=dtype)
is_plus = "proj.weight" not in state_dict["image_proj"] elif "proj_in.weight" in state_dict["image_proj"]: # IPAdaterPlus or IPAdapterPlusXL (with Resampler).
if is_plus:
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768: if cross_attention_dim == 768:
# SD1 IP-Adapter Plus # SD1 IP-Adapter Plus
@ -163,5 +206,7 @@ def build_ip_adapter(
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype) return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
else: else:
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.") raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
elif "proj.0.weight" in state_dict["image_proj"]: # IPAdapterFull (with MLPProjModel).
return IPAdapterFull(state_dict, device=device, dtype=dtype)
else: else:
return IPAdapter(state_dict, device=device, dtype=dtype) raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")

View File

@ -192,20 +192,33 @@ class ModelPatcher:
trigger += f"-!pad-{i}" trigger += f"-!pad-{i}"
return f"<{trigger}>" return f"<{trigger}>"
def _get_ti_embedding(model_embeddings, ti):
# for SDXL models, select the embedding that matches the text encoder's dimensions
if ti.embedding_2 is not None:
return (
ti.embedding_2
if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0]
else ti.embedding
)
else:
return ti.embedding
# modify tokenizer # modify tokenizer
new_tokens_added = 0 new_tokens_added = 0
for ti_name, ti in ti_list: for ti_name, ti in ti_list:
for i in range(ti.embedding.shape[0]): ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
for i in range(ti_embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
# modify text_encoder # modify text_encoder
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of) text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
model_embeddings = text_encoder.get_input_embeddings() model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list: for ti_name, _ in ti_list:
ti_tokens = [] ti_tokens = []
for i in range(ti.embedding.shape[0]): for i in range(ti_embedding.shape[0]):
embedding = ti.embedding[i] embedding = ti_embedding[i]
trigger = _get_trigger(ti_name, i) trigger = _get_trigger(ti_name, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger) token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
@ -273,6 +286,7 @@ class ModelPatcher:
class TextualInversionModel: class TextualInversionModel:
embedding: torch.Tensor # [n, 768]|[n, 1280] embedding: torch.Tensor # [n, 768]|[n, 1280]
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
@classmethod @classmethod
def from_checkpoint( def from_checkpoint(
@ -296,8 +310,8 @@ class TextualInversionModel:
if "string_to_param" in state_dict: if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1: if len(state_dict["string_to_param"]) > 1:
print( print(
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first' f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
" token will be used." " token will be used.",
) )
result.embedding = next(iter(state_dict["string_to_param"].values())) result.embedding = next(iter(state_dict["string_to_param"].values()))
@ -306,6 +320,11 @@ class TextualInversionModel:
elif "emb_params" in state_dict: elif "emb_params" in state_dict:
result.embedding = state_dict["emb_params"] result.embedding = state_dict["emb_params"]
# v5(sdxl safetensors file)
elif "clip_g" in state_dict and "clip_l" in state_dict:
result.embedding = state_dict["clip_g"]
result.embedding_2 = state_dict["clip_l"]
# v4(diffusers bin files) # v4(diffusers bin files)
else: else:
result.embedding = next(iter(state_dict.values())) result.embedding = next(iter(state_dict.values()))
@ -342,6 +361,13 @@ class TextualInversionManager(BaseTextualInversionManager):
if token_id in self.pad_tokens: if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id]) new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
# which first removes and then adds back the start and end tokens.
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
if len(new_token_ids) > max_length:
new_token_ids = new_token_ids[0:max_length]
return new_token_ids return new_token_ids
@ -490,24 +516,31 @@ class ONNXModelPatcher:
trigger += f"-!pad-{i}" trigger += f"-!pad-{i}"
return f"<{trigger}>" return f"<{trigger}>"
# modify text_encoder
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
# modify tokenizer # modify tokenizer
new_tokens_added = 0 new_tokens_added = 0
for ti_name, ti in ti_list: for ti_name, ti in ti_list:
for i in range(ti.embedding.shape[0]): if ti.embedding_2 is not None:
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) ti_embedding = (
ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
)
else:
ti_embedding = ti.embedding
# modify text_encoder for i in range(ti_embedding.shape[0]):
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"] new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
embeddings = np.concatenate( embeddings = np.concatenate(
(np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))), (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
axis=0, axis=0,
) )
for ti_name, ti in ti_list: for ti_name, _ in ti_list:
ti_tokens = [] ti_tokens = []
for i in range(ti.embedding.shape[0]): for i in range(ti_embedding.shape[0]):
embedding = ti.embedding[i].detach().numpy() embedding = ti_embedding[i].detach().numpy()
trigger = _get_trigger(ti_name, i) trigger = _get_trigger(ti_name, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger) token_id = ti_tokenizer.convert_tokens_to_ids(trigger)

View File

@ -373,12 +373,16 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1] token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
elif "emb_params" in checkpoint: elif "emb_params" in checkpoint:
token_dim = checkpoint["emb_params"].shape[-1] token_dim = checkpoint["emb_params"].shape[-1]
elif "clip_g" in checkpoint:
token_dim = checkpoint["clip_g"].shape[-1]
else: else:
token_dim = list(checkpoint.values())[0].shape[0] token_dim = list(checkpoint.values())[0].shape[0]
if token_dim == 768: if token_dim == 768:
return BaseModelType.StableDiffusion1 return BaseModelType.StableDiffusion1
elif token_dim == 1024: elif token_dim == 1024:
return BaseModelType.StableDiffusion2 return BaseModelType.StableDiffusion2
elif token_dim == 1280:
return BaseModelType.StableDiffusionXL
else: else:
return None return None

View File

@ -607,10 +607,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if isinstance(guidance_scale, list): if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[step_index] guidance_scale = guidance_scale[step_index]
noise_pred = self.invokeai_diffuser._combine( noise_pred = self.invokeai_diffuser._combine(uc_noise_pred, c_noise_pred, guidance_scale)
uc_noise_pred, guidance_rescale_multiplier = conditioning_data.guidance_rescale_multiplier
if guidance_rescale_multiplier > 0:
noise_pred = self._rescale_cfg(
noise_pred,
c_noise_pred, c_noise_pred,
guidance_scale, guidance_rescale_multiplier,
) )
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
@ -634,6 +637,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return step_output return step_output
@staticmethod
def _rescale_cfg(total_noise_pred, pos_noise_pred, multiplier=0.7):
"""Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
return x_final
def _unet_forward( def _unet_forward(
self, self,
latents, latents,

View File

@ -67,13 +67,17 @@ class IPAdapterConditioningInfo:
class ConditioningData: class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo unconditioned_embeddings: BasicConditioningInfo
text_embeddings: BasicConditioningInfo text_embeddings: BasicConditioningInfo
guidance_scale: Union[float, List[float]]
""" """
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality. images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
""" """
guidance_scale: Union[float, List[float]]
""" for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
"""
guidance_rescale_multiplier: float = 0
extra: Optional[ExtraConditioningInfo] = None extra: Optional[ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict) scheduler_args: dict[str, Any] = field(default_factory=dict)
""" """

View File

@ -1,8 +1,9 @@
import math import math
from typing import Union
import numpy as np import numpy as np
from invokeai.backend.tiles.utils import TBLR, Tile, calc_overlap, paste, seam_blend from invokeai.backend.tiles.utils import TBLR, Tile, paste
def calc_tiles_with_overlap( def calc_tiles_with_overlap(
@ -62,117 +63,31 @@ def calc_tiles_with_overlap(
tiles.append(tile) tiles.append(tile)
return calc_overlap(tiles, num_tiles_x, num_tiles_y) def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x:
return None
return tiles[idx_y * num_tiles_x + idx_x]
# Iterate over tiles again and calculate overlaps.
def calc_tiles_even_split(
image_height: int, image_width: int, num_tiles_x: int, num_tiles_y: int, overlap: float = 0
) -> list[Tile]:
"""Calculate the tile coordinates for a given image shape with the number of tiles requested.
Args:
image_height (int): The image height in px.
image_width (int): The image width in px.
num_x_tiles (int): The number of tile to split the image into on the X-axis.
num_y_tiles (int): The number of tile to split the image into on the Y-axis.
overlap (int, optional): The target overlap amount of the tiles size. Defaults to 0.
Returns:
list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
"""
# Ensure tile size is divisible by 8
if image_width % 8 != 0 or image_height % 8 != 0:
raise ValueError(f"image size (({image_width}, {image_height})) must be divisible by 8")
# Calculate the overlap size based on the percentage and adjust it to be divisible by 8 (rounding up)
overlap_x = 8 * math.ceil(int((image_width / num_tiles_x) * overlap) / 8)
overlap_y = 8 * math.ceil(int((image_height / num_tiles_y) * overlap) / 8)
# Calculate the tile size based on the number of tiles and overlap, and ensure it's divisible by 8 (rounding down)
tile_size_x = 8 * math.floor(((image_width + overlap_x * (num_tiles_x - 1)) // num_tiles_x) / 8)
tile_size_y = 8 * math.floor(((image_height + overlap_y * (num_tiles_y - 1)) // num_tiles_y) / 8)
# tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
tiles: list[Tile] = []
# Calculate tile coordinates. (Ignore overlap values for now.)
for tile_idx_y in range(num_tiles_y): for tile_idx_y in range(num_tiles_y):
# Calculate the top and bottom of the row
top = tile_idx_y * (tile_size_y - overlap_y)
bottom = min(top + tile_size_y, image_height)
# For the last row adjust bottom to be the height of the image
if tile_idx_y == num_tiles_y - 1:
bottom = image_height
for tile_idx_x in range(num_tiles_x): for tile_idx_x in range(num_tiles_x):
# Calculate the left & right coordinate of each tile cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
left = tile_idx_x * (tile_size_x - overlap_x) top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
right = min(left + tile_size_x, image_width) left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)
# For the last tile in the row adjust right to be the width of the image
if tile_idx_x == num_tiles_x - 1:
right = image_width
tile = Tile( assert cur_tile is not None
coords=TBLR(top=top, bottom=bottom, left=left, right=right),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
tiles.append(tile) # Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
if top_neighbor_tile is not None:
cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
top_neighbor_tile.overlap.bottom = cur_tile.overlap.top
return calc_overlap(tiles, num_tiles_x, num_tiles_y) # Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
if left_neighbor_tile is not None:
cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
left_neighbor_tile.overlap.right = cur_tile.overlap.left
return tiles
def calc_tiles_min_overlap(
image_height: int, image_width: int, tile_height: int, tile_width: int, min_overlap: int, round_to_8: bool
) -> list[Tile]:
"""Calculate the tile coordinates for a given image shape under a simple tiling scheme with overlaps.
Args:
image_height (int): The image height in px.
image_width (int): The image width in px.
tile_height (int): The tile height in px. All tiles will have this height.
tile_width (int): The tile width in px. All tiles will have this width.
min_overlap (int): The target minimum overlap between adjacent tiles. If the tiles do not evenly cover the image
shape, then the overlap will be spread between the tiles.
Returns:
list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
"""
assert image_height >= tile_height
assert image_width >= tile_width
assert min_overlap < tile_height
assert min_overlap < tile_width
num_tiles_x = math.ceil((image_width - min_overlap) / (tile_width - min_overlap)) if tile_width < image_width else 1
num_tiles_y = (
math.ceil((image_height - min_overlap) / (tile_height - min_overlap)) if tile_height < image_height else 1
)
# tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
tiles: list[Tile] = []
# Calculate tile coordinates. (Ignore overlap values for now.)
for tile_idx_y in range(num_tiles_y):
top = (tile_idx_y * (image_height - tile_height)) // (num_tiles_y - 1) if num_tiles_y > 1 else 0
if round_to_8:
top = 8 * (top // 8)
bottom = top + tile_height
for tile_idx_x in range(num_tiles_x):
left = (tile_idx_x * (image_width - tile_width)) // (num_tiles_x - 1) if num_tiles_x > 1 else 0
if round_to_8:
left = 8 * (left // 8)
right = left + tile_width
tile = Tile(
coords=TBLR(top=top, bottom=bottom, left=left, right=right),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
tiles.append(tile)
return calc_overlap(tiles, num_tiles_x, num_tiles_y)
def merge_tiles_with_linear_blending( def merge_tiles_with_linear_blending(
@ -284,91 +199,3 @@ def merge_tiles_with_linear_blending(
), ),
mask=mask, mask=mask,
) )
def merge_tiles_with_seam_blending(
dst_image: np.ndarray, tiles: list[Tile], tile_images: list[np.ndarray], blend_amount: int
):
"""Merge a set of image tiles into `dst_image` with seam blending between the tiles.
We expect every tile edge to either:
1) have an overlap of 0, because it is aligned with the image edge, or
2) have an overlap >= blend_amount.
If neither of these conditions are satisfied, we raise an exception.
The seam blending is centered on a seam of least energy of the overlap between adjacent tiles.
Args:
dst_image (np.ndarray): The destination image. Shape: (H, W, C).
tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`.
tile_images (list[np.ndarray]): The tile images to merge into `dst_image`.
blend_amount (int): The amount of blending (in px) between adjacent overlapping tiles.
"""
# Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to
# iterate over tiles left-to-right, top-to-bottom.
tiles_and_images = list(zip(tiles, tile_images, strict=True))
tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.left)
tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.top)
# Organize tiles into rows.
tile_and_image_rows: list[list[tuple[Tile, np.ndarray]]] = []
cur_tile_and_image_row: list[tuple[Tile, np.ndarray]] = []
first_tile_in_cur_row, _ = tiles_and_images[0]
for tile_and_image in tiles_and_images:
tile, _ = tile_and_image
if not (
tile.coords.top == first_tile_in_cur_row.coords.top
and tile.coords.bottom == first_tile_in_cur_row.coords.bottom
):
# Store the previous row, and start a new one.
tile_and_image_rows.append(cur_tile_and_image_row)
cur_tile_and_image_row = []
first_tile_in_cur_row, _ = tile_and_image
cur_tile_and_image_row.append(tile_and_image)
tile_and_image_rows.append(cur_tile_and_image_row)
for tile_and_image_row in tile_and_image_rows:
first_tile_in_row, _ = tile_and_image_row[0]
row_height = first_tile_in_row.coords.bottom - first_tile_in_row.coords.top
row_image = np.zeros((row_height, dst_image.shape[1], dst_image.shape[2]), dtype=dst_image.dtype)
# Blend the tiles in the row horizontally.
for tile, tile_image in tile_and_image_row:
# We expect the tiles to be ordered left-to-right.
# For each tile:
# - extract the overlap regions and pass to seam_blend()
# - apply blended region to the row_image
# - apply the un-blended region to the row_image
tile_height, tile_width, _ = tile_image.shape
overlap_size = tile.overlap.left
# Left blending:
if overlap_size > 0:
assert overlap_size >= blend_amount
overlap_coord_right = tile.coords.left + overlap_size
src_overlap = row_image[:, tile.coords.left : overlap_coord_right]
dst_overlap = tile_image[:, :overlap_size]
blended_overlap = seam_blend(src_overlap, dst_overlap, blend_amount, x_seam=False)
row_image[:, tile.coords.left : overlap_coord_right] = blended_overlap
row_image[:, overlap_coord_right : tile.coords.right] = tile_image[:, overlap_size:]
else:
# no overlap just paste the tile
row_image[:, tile.coords.left : tile.coords.right] = tile_image
# Blend the row into the dst_image
# We assume that the entire row has the same vertical overlaps as the first_tile_in_row.
# Rows are processed in the same way as tiles (extract overlap, blend, apply)
row_overlap_size = first_tile_in_row.overlap.top
if row_overlap_size > 0:
assert row_overlap_size >= blend_amount
overlap_coords_bottom = first_tile_in_row.coords.top + row_overlap_size
src_overlap = dst_image[first_tile_in_row.coords.top : overlap_coords_bottom, :]
dst_overlap = row_image[:row_overlap_size, :]
blended_overlap = seam_blend(src_overlap, dst_overlap, blend_amount, x_seam=True)
dst_image[first_tile_in_row.coords.top : overlap_coords_bottom, :] = blended_overlap
dst_image[overlap_coords_bottom : first_tile_in_row.coords.bottom, :] = row_image[row_overlap_size:, :]
else:
# no overlap just paste the row
row_image[first_tile_in_row.coords.top:first_tile_in_row.coords.bottom, :] = row_image

View File

@ -1,9 +1,6 @@
import math from typing import Optional
from typing import Optional, Union
import cv2
import numpy as np import numpy as np
#from PIL import Image
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -48,130 +45,3 @@ def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optiona
mask = np.expand_dims(mask, -1) mask = np.expand_dims(mask, -1)
dst_image_box = dst_image[box.top : box.bottom, box.left : box.right] dst_image_box = dst_image[box.top : box.bottom, box.left : box.right]
dst_image[box.top : box.bottom, box.left : box.right] = src_image * mask + dst_image_box * (1.0 - mask) dst_image[box.top : box.bottom, box.left : box.right] = src_image * mask + dst_image_box * (1.0 - mask)
def calc_overlap(tiles: list[Tile], num_tiles_x, num_tiles_y) -> list[Tile]:
"""Calculate and update the overlap of a list of tiles.
Args:
tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`.
num_tiles_x: the number of tiles on the x axis.
num_tiles_y: the number of tiles on the y axis.
"""
def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x:
return None
return tiles[idx_y * num_tiles_x + idx_x]
for tile_idx_y in range(num_tiles_y):
for tile_idx_x in range(num_tiles_x):
cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)
assert cur_tile is not None
# Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
if top_neighbor_tile is not None:
cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
top_neighbor_tile.overlap.bottom = cur_tile.overlap.top
# Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
if left_neighbor_tile is not None:
cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
left_neighbor_tile.overlap.right = cur_tile.overlap.left
return tiles
def seam_blend(ia1: np.ndarray, ia2: np.ndarray, blend_amount: int, x_seam: bool,) -> np.ndarray:
"""Blend two overlapping tile sections using a seams to find a path.
It is assumed that input images will be RGB np arrays and are the same size.
Args:
ia1 (torch.Tensor): Image array 1 Shape: (H, W, C).
ia2 (torch.Tensor): Image array 2 Shape: (H, W, C).
x_seam (bool): If the images should be blended on the x axis or not.
blend_amount (int): The size of the blur to use on the seam. Half of this value will be used to avoid the edges of the image.
"""
def shift(arr, num, fill_value=255.0):
result = np.full_like(arr, fill_value)
if num > 0:
result[num:] = arr[:-num]
elif num < 0:
result[:num] = arr[-num:]
else:
result[:] = arr
return result
# Assume RGB and convert to grey
iag1 = np.dot(ia1, [0.2989, 0.5870, 0.1140])
iag2 = np.dot(ia2, [0.2989, 0.5870, 0.1140])
# Calc Difference between the images
ia = iag2 - iag1
# If the seam is on the X-axis rotate the array so we can treat it like a vertical seam
if x_seam:
ia = np.rot90(ia, 1)
# Calc max and min X & Y limits
# gutter is used to avoid the blur hitting the edge of the image
gutter = math.ceil(blend_amount / 2) if blend_amount > 0 else 0
max_y, max_x = ia.shape
max_x -= gutter
min_x = gutter
# Calc the energy in the difference
energy = np.abs(np.gradient(ia, axis=0)) + np.abs(np.gradient(ia, axis=1))
#Find the starting position of the seam
res = np.copy(energy)
for y in range(1, max_y):
row = res[y, :]
rowl = shift(row, -1)
rowr = shift(row, 1)
res[y, :] = res[y - 1, :] + np.min([row, rowl, rowr], axis=0)
# create an array max_y long
lowest_energy_line = np.empty([max_y], dtype="uint16")
lowest_energy_line[max_y - 1] = np.argmin(res[max_y - 1, min_x : max_x - 1])
#Calc the path of the seam
for ypos in range(max_y - 2, -1, -1):
lowest_pos = lowest_energy_line[ypos + 1]
lpos = lowest_pos - 1
rpos = lowest_pos + 1
lpos = np.clip(lpos, min_x, max_x - 1)
rpos = np.clip(rpos, min_x, max_x - 1)
lowest_energy_line[ypos] = np.argmin(energy[ypos, lpos : rpos + 1]) + lpos
# Draw the mask
mask = np.zeros_like(ia)
for ypos in range(0, max_y):
to_fill = lowest_energy_line[ypos]
mask[ypos, :to_fill] = 1
# If the seam is on the X-axis rotate the array back
if x_seam:
mask = np.rot90(mask, 3)
# blur the seam mask if required
if blend_amount > 0:
mask = cv2.blur(mask, (blend_amount, blend_amount))
# copy ia2 over ia1 while applying the seam mask
mask = np.expand_dims(mask, -1)
blended_image = ia1 * mask + ia2 * (1.0 - mask)
# for debugging to see the final blended overlap image
#image = Image.fromarray((mask * 255.0).astype("uint8"))
#i1 = Image.fromarray(ia1.astype("uint8"))
#i2 = Image.fromarray(ia2.astype("uint8"))
#bimage = Image.fromarray(blended_image.astype("uint8"))
#print(f"{ia1.shape}, {ia2.shape}, {mask.shape}, {blended_image.shape}")
#print(f"{i1.size}, {i2.size}, {image.size}, {bimage.size}")
return blended_image

View File

@ -342,9 +342,8 @@ class InvokeAILogger(object): # noqa D102
cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config() cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
) -> logging.Logger: # noqa D102 ) -> logging.Logger: # noqa D102
if name in cls.loggers: if name in cls.loggers:
logger = cls.loggers[name] return cls.loggers[name]
logger.handlers.clear()
else:
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(config.log_level.upper()) # yes, strings work here logger.setLevel(config.log_level.upper()) # yes, strings work here
for ch in cls.get_loggers(config): for ch in cls.get_loggers(config):
@ -358,7 +357,7 @@ class InvokeAILogger(object): # noqa D102
handlers = [] handlers = []
for handler in handler_strs: for handler in handler_strs:
handler_name, *args = handler.split("=", 2) handler_name, *args = handler.split("=", 2)
args = args[0] if len(args) > 0 else None arg = args[0] if len(args) > 0 else None
# console and file get the fancy formatter. # console and file get the fancy formatter.
# syslog gets a simple one # syslog gets a simple one
@ -370,16 +369,16 @@ class InvokeAILogger(object): # noqa D102
handlers.append(ch) handlers.append(ch)
elif handler_name == "syslog": elif handler_name == "syslog":
ch = cls._parse_syslog_args(args) ch = cls._parse_syslog_args(arg)
handlers.append(ch) handlers.append(ch)
elif handler_name == "file": elif handler_name == "file":
ch = cls._parse_file_args(args) ch = cls._parse_file_args(arg)
ch.setFormatter(formatter()) ch.setFormatter(formatter())
handlers.append(ch) handlers.append(ch)
elif handler_name == "http": elif handler_name == "http":
ch = cls._parse_http_args(args) ch = cls._parse_http_args(arg)
handlers.append(ch) handlers.append(ch)
return handlers return handlers

View File

@ -75,6 +75,7 @@
"framer-motion": "^10.16.4", "framer-motion": "^10.16.4",
"i18next": "^23.6.0", "i18next": "^23.6.0",
"i18next-http-backend": "^2.3.1", "i18next-http-backend": "^2.3.1",
"idb-keyval": "^6.2.1",
"konva": "^9.2.3", "konva": "^9.2.3",
"lodash-es": "^4.17.21", "lodash-es": "^4.17.21",
"nanostores": "^0.9.4", "nanostores": "^0.9.4",

View File

@ -803,8 +803,7 @@
"canny": "Canny", "canny": "Canny",
"hedDescription": "Ganzheitlich verschachtelte Kantenerkennung", "hedDescription": "Ganzheitlich verschachtelte Kantenerkennung",
"scribble": "Scribble", "scribble": "Scribble",
"maxFaces": "Maximal Anzahl Gesichter", "maxFaces": "Maximal Anzahl Gesichter"
"unstarImage": "Markierung aufheben"
}, },
"queue": { "queue": {
"status": "Status", "status": "Status",

View File

@ -243,7 +243,6 @@
"setControlImageDimensions": "Set Control Image Dimensions To W/H", "setControlImageDimensions": "Set Control Image Dimensions To W/H",
"showAdvanced": "Show Advanced", "showAdvanced": "Show Advanced",
"toggleControlNet": "Toggle this ControlNet", "toggleControlNet": "Toggle this ControlNet",
"unstarImage": "Unstar Image",
"w": "W", "w": "W",
"weight": "Weight", "weight": "Weight",
"enableIPAdapter": "Enable IP Adapter", "enableIPAdapter": "Enable IP Adapter",
@ -378,6 +377,8 @@
"showGenerations": "Show Generations", "showGenerations": "Show Generations",
"showUploads": "Show Uploads", "showUploads": "Show Uploads",
"singleColumnLayout": "Single Column Layout", "singleColumnLayout": "Single Column Layout",
"starImage": "Star Image",
"unstarImage": "Unstar Image",
"unableToLoad": "Unable to load Gallery", "unableToLoad": "Unable to load Gallery",
"uploads": "Uploads", "uploads": "Uploads",
"deleteSelection": "Delete Selection", "deleteSelection": "Delete Selection",
@ -599,6 +600,7 @@
}, },
"metadata": { "metadata": {
"cfgScale": "CFG scale", "cfgScale": "CFG scale",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"createdBy": "Created By", "createdBy": "Created By",
"fit": "Image to image fit", "fit": "Image to image fit",
"generationMode": "Generation Mode", "generationMode": "Generation Mode",
@ -977,6 +979,7 @@
"unsupportedAnyOfLength": "too many union members ({{count}})", "unsupportedAnyOfLength": "too many union members ({{count}})",
"unsupportedMismatchedUnion": "mismatched CollectionOrScalar type with base types {{firstType}} and {{secondType}}", "unsupportedMismatchedUnion": "mismatched CollectionOrScalar type with base types {{firstType}} and {{secondType}}",
"unableToParseFieldType": "unable to parse field type", "unableToParseFieldType": "unable to parse field type",
"unableToExtractEnumOptions": "unable to extract enum options",
"uNetField": "UNet", "uNetField": "UNet",
"uNetFieldDescription": "UNet submodel.", "uNetFieldDescription": "UNet submodel.",
"unhandledInputProperty": "Unhandled input property", "unhandledInputProperty": "Unhandled input property",
@ -1032,6 +1035,8 @@
"setType": "Set cancel type" "setType": "Set cancel type"
}, },
"cfgScale": "CFG Scale", "cfgScale": "CFG Scale",
"cfgRescaleMultiplier": "CFG Rescale Multiplier",
"cfgRescale": "CFG Rescale",
"clipSkip": "CLIP Skip", "clipSkip": "CLIP Skip",
"clipSkipWithLayerCount": "CLIP Skip {{layerCount}}", "clipSkipWithLayerCount": "CLIP Skip {{layerCount}}",
"closeViewer": "Close Viewer", "closeViewer": "Close Viewer",
@ -1470,6 +1475,12 @@
"Controls how much your prompt influences the generation process." "Controls how much your prompt influences the generation process."
] ]
}, },
"paramCFGRescaleMultiplier": {
"heading": "CFG Rescale Multiplier",
"paragraphs": [
"Rescale multiplier for CFG guidance, used for models trained using zero-terminal SNR (ztsnr). Suggested value 0.7."
]
},
"paramDenoisingStrength": { "paramDenoisingStrength": {
"heading": "Denoising Strength", "heading": "Denoising Strength",
"paragraphs": [ "paragraphs": [

View File

@ -91,7 +91,19 @@
"controlNet": "ControlNet", "controlNet": "ControlNet",
"auto": "Automatico", "auto": "Automatico",
"simple": "Semplice", "simple": "Semplice",
"details": "Dettagli" "details": "Dettagli",
"format": "formato",
"unknown": "Sconosciuto",
"folder": "Cartella",
"error": "Errore",
"installed": "Installato",
"template": "Schema",
"outputs": "Uscite",
"data": "Dati",
"somethingWentWrong": "Qualcosa è andato storto",
"copyError": "$t(gallery.copy) Errore",
"input": "Ingresso",
"notInstalled": "Non $t(common.installed)"
}, },
"gallery": { "gallery": {
"generations": "Generazioni", "generations": "Generazioni",
@ -122,7 +134,14 @@
"preparingDownload": "Preparazione del download", "preparingDownload": "Preparazione del download",
"preparingDownloadFailed": "Problema durante la preparazione del download", "preparingDownloadFailed": "Problema durante la preparazione del download",
"downloadSelection": "Scarica gli elementi selezionati", "downloadSelection": "Scarica gli elementi selezionati",
"noImageSelected": "Nessuna immagine selezionata" "noImageSelected": "Nessuna immagine selezionata",
"deleteSelection": "Elimina la selezione",
"image": "immagine",
"drop": "Rilascia",
"unstarImage": "Rimuovi preferenza immagine",
"dropOrUpload": "$t(gallery.drop) o carica",
"starImage": "Immagine preferita",
"dropToUpload": "$t(gallery.drop) per aggiornare"
}, },
"hotkeys": { "hotkeys": {
"keyboardShortcuts": "Tasti rapidi", "keyboardShortcuts": "Tasti rapidi",
@ -477,7 +496,8 @@
"modelType": "Tipo di modello", "modelType": "Tipo di modello",
"customConfigFileLocation": "Posizione del file di configurazione personalizzato", "customConfigFileLocation": "Posizione del file di configurazione personalizzato",
"vaePrecision": "Precisione VAE", "vaePrecision": "Precisione VAE",
"noModelSelected": "Nessun modello selezionato" "noModelSelected": "Nessun modello selezionato",
"conversionNotSupported": "Conversione non supportata"
}, },
"parameters": { "parameters": {
"images": "Immagini", "images": "Immagini",
@ -838,7 +858,8 @@
"menu": "Menu", "menu": "Menu",
"showGalleryPanel": "Mostra il pannello Galleria", "showGalleryPanel": "Mostra il pannello Galleria",
"loadMore": "Carica altro", "loadMore": "Carica altro",
"mode": "Modalità" "mode": "Modalità",
"resetUI": "$t(accessibility.reset) l'Interfaccia Utente"
}, },
"ui": { "ui": {
"hideProgressImages": "Nascondi avanzamento immagini", "hideProgressImages": "Nascondi avanzamento immagini",
@ -1040,7 +1061,15 @@
"updateAllNodes": "Aggiorna tutti i nodi", "updateAllNodes": "Aggiorna tutti i nodi",
"unableToUpdateNodes_one": "Impossibile aggiornare {{count}} nodo", "unableToUpdateNodes_one": "Impossibile aggiornare {{count}} nodo",
"unableToUpdateNodes_many": "Impossibile aggiornare {{count}} nodi", "unableToUpdateNodes_many": "Impossibile aggiornare {{count}} nodi",
"unableToUpdateNodes_other": "Impossibile aggiornare {{count}} nodi" "unableToUpdateNodes_other": "Impossibile aggiornare {{count}} nodi",
"addLinearView": "Aggiungi alla vista Lineare",
"outputFieldInInput": "Campo di uscita in ingresso",
"unableToMigrateWorkflow": "Impossibile migrare il flusso di lavoro",
"unableToUpdateNode": "Impossibile aggiornare nodo",
"unknownErrorValidatingWorkflow": "Errore sconosciuto durante la convalida del flusso di lavoro",
"collectionFieldType": "{{name}} Raccolta",
"collectionOrScalarFieldType": "{{name}} Raccolta|Scalare",
"nodeVersion": "Versione Nodo"
}, },
"boards": { "boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca", "autoAddBoard": "Aggiungi automaticamente bacheca",
@ -1062,7 +1091,10 @@
"deleteBoardOnly": "Elimina solo la Bacheca", "deleteBoardOnly": "Elimina solo la Bacheca",
"deleteBoard": "Elimina Bacheca", "deleteBoard": "Elimina Bacheca",
"deleteBoardAndImages": "Elimina Bacheca e Immagini", "deleteBoardAndImages": "Elimina Bacheca e Immagini",
"deletedBoardsCannotbeRestored": "Le bacheche eliminate non possono essere ripristinate" "deletedBoardsCannotbeRestored": "Le bacheche eliminate non possono essere ripristinate",
"movingImagesToBoard_one": "Spostare {{count}} immagine nella bacheca:",
"movingImagesToBoard_many": "Spostare {{count}} immagini nella bacheca:",
"movingImagesToBoard_other": "Spostare {{count}} immagini nella bacheca:"
}, },
"controlnet": { "controlnet": {
"contentShuffleDescription": "Rimescola il contenuto di un'immagine", "contentShuffleDescription": "Rimescola il contenuto di un'immagine",
@ -1136,7 +1168,8 @@
"megaControl": "Mega ControlNet", "megaControl": "Mega ControlNet",
"minConfidence": "Confidenza minima", "minConfidence": "Confidenza minima",
"scribble": "Scribble", "scribble": "Scribble",
"amult": "Angolo di illuminazione" "amult": "Angolo di illuminazione",
"coarse": "Approssimativo"
}, },
"queue": { "queue": {
"queueFront": "Aggiungi all'inizio della coda", "queueFront": "Aggiungi all'inizio della coda",
@ -1204,7 +1237,8 @@
"embedding": { "embedding": {
"noMatchingEmbedding": "Nessun Incorporamento corrispondente", "noMatchingEmbedding": "Nessun Incorporamento corrispondente",
"addEmbedding": "Aggiungi Incorporamento", "addEmbedding": "Aggiungi Incorporamento",
"incompatibleModel": "Modello base incompatibile:" "incompatibleModel": "Modello base incompatibile:",
"noEmbeddingsLoaded": "Nessun incorporamento caricato"
}, },
"models": { "models": {
"noMatchingModels": "Nessun modello corrispondente", "noMatchingModels": "Nessun modello corrispondente",
@ -1217,7 +1251,8 @@
"noRefinerModelsInstalled": "Nessun modello SDXL Refiner installato", "noRefinerModelsInstalled": "Nessun modello SDXL Refiner installato",
"noLoRAsInstalled": "Nessun LoRA installato", "noLoRAsInstalled": "Nessun LoRA installato",
"esrganModel": "Modello ESRGAN", "esrganModel": "Modello ESRGAN",
"addLora": "Aggiungi LoRA" "addLora": "Aggiungi LoRA",
"noLoRAsLoaded": "Nessuna LoRA caricata"
}, },
"invocationCache": { "invocationCache": {
"disable": "Disabilita", "disable": "Disabilita",
@ -1233,7 +1268,8 @@
"enable": "Abilita", "enable": "Abilita",
"clear": "Svuota", "clear": "Svuota",
"maxCacheSize": "Dimensione max cache", "maxCacheSize": "Dimensione max cache",
"cacheSize": "Dimensione cache" "cacheSize": "Dimensione cache",
"useCache": "Usa Cache"
}, },
"dynamicPrompts": { "dynamicPrompts": {
"seedBehaviour": { "seedBehaviour": {

View File

@ -72,5 +72,13 @@
}, },
"unifiedCanvas": { "unifiedCanvas": {
"betaPreserveMasked": "마스크 레이어 유지" "betaPreserveMasked": "마스크 레이어 유지"
},
"accessibility": {
"previousImage": "이전 이미지",
"modifyConfig": "Config 수정",
"nextImage": "다음 이미지",
"mode": "모드",
"menu": "메뉴",
"modelSelect": "모델 선택"
} }
} }

View File

@ -99,7 +99,17 @@
"data": "数据", "data": "数据",
"safetensors": "Safetensors", "safetensors": "Safetensors",
"outpaint": "外扩绘制", "outpaint": "外扩绘制",
"details": "详情" "details": "详情",
"format": "格式",
"unknown": "未知",
"folder": "文件夹",
"error": "错误",
"installed": "已安装",
"file": "文件",
"somethingWentWrong": "出了点问题",
"copyError": "$t(gallery.copy) 错误",
"input": "输入",
"notInstalled": "非 $t(common.installed)"
}, },
"gallery": { "gallery": {
"generations": "生成的图像", "generations": "生成的图像",
@ -130,7 +140,12 @@
"preparingDownload": "准备下载", "preparingDownload": "准备下载",
"preparingDownloadFailed": "准备下载时出现问题", "preparingDownloadFailed": "准备下载时出现问题",
"downloadSelection": "下载所选内容", "downloadSelection": "下载所选内容",
"noImageSelected": "无选中的图像" "noImageSelected": "无选中的图像",
"deleteSelection": "删除所选内容",
"image": "图像",
"drop": "弃用",
"dropOrUpload": "$t(gallery.drop) 或上传",
"dropToUpload": "$t(gallery.drop) 以上传"
}, },
"hotkeys": { "hotkeys": {
"keyboardShortcuts": "键盘快捷键", "keyboardShortcuts": "键盘快捷键",
@ -486,7 +501,8 @@
"alpha": "Alpha", "alpha": "Alpha",
"vaePrecision": "VAE 精度", "vaePrecision": "VAE 精度",
"checkpointOrSafetensors": "$t(common.checkpoint) / $t(common.safetensors)", "checkpointOrSafetensors": "$t(common.checkpoint) / $t(common.safetensors)",
"noModelSelected": "无选中的模型" "noModelSelected": "无选中的模型",
"conversionNotSupported": "转换尚未支持"
}, },
"parameters": { "parameters": {
"images": "图像", "images": "图像",
@ -615,7 +631,10 @@
"seamlessX": "无缝 X", "seamlessX": "无缝 X",
"seamlessY": "无缝 Y", "seamlessY": "无缝 Y",
"maskEdge": "遮罩边缘", "maskEdge": "遮罩边缘",
"unmasked": "取消遮罩" "unmasked": "取消遮罩",
"cfgRescaleMultiplier": "CFG 重缩放倍数",
"cfgRescale": "CFG 重缩放",
"useSize": "使用尺寸"
}, },
"settings": { "settings": {
"models": "模型", "models": "模型",
@ -655,7 +674,8 @@
"clearIntermediatesDisabled": "队列为空才能清理中间产物", "clearIntermediatesDisabled": "队列为空才能清理中间产物",
"enableNSFWChecker": "启用成人内容检测器", "enableNSFWChecker": "启用成人内容检测器",
"enableInvisibleWatermark": "启用不可见水印", "enableInvisibleWatermark": "启用不可见水印",
"enableInformationalPopovers": "启用信息弹窗" "enableInformationalPopovers": "启用信息弹窗",
"reloadingIn": "重新加载中"
}, },
"toast": { "toast": {
"tempFoldersEmptied": "临时文件夹已清空", "tempFoldersEmptied": "临时文件夹已清空",
@ -739,7 +759,8 @@
"imageUploadFailed": "图像上传失败", "imageUploadFailed": "图像上传失败",
"problemImportingMask": "导入遮罩时出现问题", "problemImportingMask": "导入遮罩时出现问题",
"baseModelChangedCleared_other": "基础模型已更改, 已清除或禁用 {{count}} 个不兼容的子模型", "baseModelChangedCleared_other": "基础模型已更改, 已清除或禁用 {{count}} 个不兼容的子模型",
"setAsCanvasInitialImage": "设为画布初始图像" "setAsCanvasInitialImage": "设为画布初始图像",
"invalidUpload": "无效的上传"
}, },
"unifiedCanvas": { "unifiedCanvas": {
"layer": "图层", "layer": "图层",
@ -748,7 +769,7 @@
"maskingOptions": "遮罩选项", "maskingOptions": "遮罩选项",
"enableMask": "启用遮罩", "enableMask": "启用遮罩",
"preserveMaskedArea": "保留遮罩区域", "preserveMaskedArea": "保留遮罩区域",
"clearMask": "清除遮罩", "clearMask": "清除遮罩 (Shift+C)",
"brush": "刷子", "brush": "刷子",
"eraser": "橡皮擦", "eraser": "橡皮擦",
"fillBoundingBox": "填充选择区域", "fillBoundingBox": "填充选择区域",
@ -801,7 +822,8 @@
"betaPreserveMasked": "保留遮罩层", "betaPreserveMasked": "保留遮罩层",
"antialiasing": "抗锯齿", "antialiasing": "抗锯齿",
"showResultsOn": "显示结果 (开)", "showResultsOn": "显示结果 (开)",
"showResultsOff": "显示结果 (关)" "showResultsOff": "显示结果 (关)",
"saveMask": "保存 $t(unifiedCanvas.mask)"
}, },
"accessibility": { "accessibility": {
"modelSelect": "模型选择", "modelSelect": "模型选择",
@ -826,7 +848,9 @@
"menu": "菜单", "menu": "菜单",
"showGalleryPanel": "显示图库浮窗", "showGalleryPanel": "显示图库浮窗",
"loadMore": "加载更多", "loadMore": "加载更多",
"mode": "模式" "mode": "模式",
"resetUI": "$t(accessibility.reset) UI",
"createIssue": "创建问题"
}, },
"ui": { "ui": {
"showProgressImages": "显示处理中的图片", "showProgressImages": "显示处理中的图片",
@ -877,7 +901,7 @@
"animatedEdges": "边缘动效", "animatedEdges": "边缘动效",
"nodeTemplate": "节点模板", "nodeTemplate": "节点模板",
"pickOne": "选择一个", "pickOne": "选择一个",
"unableToLoadWorkflow": "无法验证工作流", "unableToLoadWorkflow": "无法加载工作流",
"snapToGrid": "对齐网格", "snapToGrid": "对齐网格",
"noFieldsLinearview": "线性视图中未添加任何字段", "noFieldsLinearview": "线性视图中未添加任何字段",
"nodeSearch": "检索节点", "nodeSearch": "检索节点",
@ -929,7 +953,7 @@
"skippingUnknownOutputType": "跳过未知类型的输出", "skippingUnknownOutputType": "跳过未知类型的输出",
"latentsFieldDescription": "Latents 可以在节点间传递。", "latentsFieldDescription": "Latents 可以在节点间传递。",
"denoiseMaskFieldDescription": "去噪遮罩可以在节点间传递", "denoiseMaskFieldDescription": "去噪遮罩可以在节点间传递",
"missingTemplate": "缺失模板", "missingTemplate": "无效的节点:类型为 {{type}} 的节点 {{node}} 缺失模板(无已安装模板?)",
"outputSchemaNotFound": "未找到输出模式", "outputSchemaNotFound": "未找到输出模式",
"latentsPolymorphicDescription": "Latents 可以在节点间传递。", "latentsPolymorphicDescription": "Latents 可以在节点间传递。",
"colorFieldDescription": "一种 RGBA 颜色。", "colorFieldDescription": "一种 RGBA 颜色。",
@ -957,7 +981,7 @@
"collectionItem": "项目合集", "collectionItem": "项目合集",
"controlCollectionDescription": "节点间传递的控制信息。", "controlCollectionDescription": "节点间传递的控制信息。",
"skippedReservedInput": "跳过保留的输入", "skippedReservedInput": "跳过保留的输入",
"outputFields": "输出", "outputFields": "输出区域",
"edge": "边缘", "edge": "边缘",
"inputNode": "输入节点", "inputNode": "输入节点",
"enumDescription": "枚举 (Enums) 可能是多个选项的一个数值。", "enumDescription": "枚举 (Enums) 可能是多个选项的一个数值。",
@ -992,7 +1016,7 @@
"string": "字符串", "string": "字符串",
"inputFields": "输入", "inputFields": "输入",
"uNetFieldDescription": "UNet 子模型。", "uNetFieldDescription": "UNet 子模型。",
"mismatchedVersion": "不匹配的版本", "mismatchedVersion": "无效的节点:类型为 {{type}} 的节点 {{node}} 版本不匹配(是否尝试更新?)",
"vaeFieldDescription": "Vae 子模型。", "vaeFieldDescription": "Vae 子模型。",
"imageFieldDescription": "图像可以在节点间传递。", "imageFieldDescription": "图像可以在节点间传递。",
"outputNode": "输出节点", "outputNode": "输出节点",
@ -1050,8 +1074,36 @@
"latentsPolymorphic": "Latents 多态", "latentsPolymorphic": "Latents 多态",
"conditioningField": "条件", "conditioningField": "条件",
"latentsField": "Latents", "latentsField": "Latents",
"updateAllNodes": "更新所有节点", "updateAllNodes": "更新节点",
"unableToUpdateNodes_other": "{{count}} 个节点无法完成更新" "unableToUpdateNodes_other": "{{count}} 个节点无法完成更新",
"inputFieldTypeParseError": "无法解析 {{node}} 的输入类型 {{field}}。({{message}})",
"unsupportedArrayItemType": "不支持的数组类型 \"{{type}}\"",
"addLinearView": "添加到线性视图",
"targetNodeFieldDoesNotExist": "无效的边缘:{{node}} 的目标/输入区域 {{field}} 不存在",
"unsupportedMismatchedUnion": "合集或标量类型与基类 {{firstType}} 和 {{secondType}} 不匹配",
"allNodesUpdated": "已更新所有节点",
"sourceNodeDoesNotExist": "无效的边缘:{{node}} 的源/输出节点不存在",
"unableToExtractEnumOptions": "无法提取枚举选项",
"unableToParseFieldType": "无法解析类型",
"outputFieldInInput": "输入中的输出区域",
"unrecognizedWorkflowVersion": "无法识别的工作流架构版本:{{version}}",
"outputFieldTypeParseError": "无法解析 {{node}} 的输出类型 {{field}}。({{message}})",
"sourceNodeFieldDoesNotExist": "无效的边缘:{{node}} 的源/输出区域 {{field}} 不存在",
"unableToGetWorkflowVersion": "无法获取工作流架构版本",
"nodePack": "节点包",
"unableToExtractSchemaNameFromRef": "无法从参考中提取架构名",
"unableToMigrateWorkflow": "无法迁移工作流",
"unknownOutput": "未知输出:{{name}}",
"unableToUpdateNode": "无法更新节点",
"unknownErrorValidatingWorkflow": "验证工作流时出现未知错误",
"collectionFieldType": "{{name}} 合集",
"unknownNodeType": "未知节点类型",
"targetNodeDoesNotExist": "无效的边缘:{{node}} 的目标/输入节点不存在",
"unknownFieldType": "$t(nodes.unknownField) 类型:{{type}}",
"collectionOrScalarFieldType": "{{name}} 合集 | 标量",
"nodeVersion": "节点版本",
"deletedInvalidEdge": "已删除无效的边缘 {{source}} -> {{target}}",
"unknownInput": "未知输入:{{name}}"
}, },
"controlnet": { "controlnet": {
"resize": "直接缩放", "resize": "直接缩放",
@ -1137,8 +1189,7 @@
"openPose": "Openpose", "openPose": "Openpose",
"controlAdapter_other": "Control Adapters", "controlAdapter_other": "Control Adapters",
"lineartAnime": "Lineart Anime", "lineartAnime": "Lineart Anime",
"canny": "Canny", "canny": "Canny"
"unstarImage": "取消收藏图像"
}, },
"queue": { "queue": {
"status": "状态", "status": "状态",
@ -1246,7 +1297,8 @@
"fit": "图生图匹配", "fit": "图生图匹配",
"recallParameters": "召回参数", "recallParameters": "召回参数",
"noRecallParameters": "未找到要召回的参数", "noRecallParameters": "未找到要召回的参数",
"vae": "VAE" "vae": "VAE",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
}, },
"models": { "models": {
"noMatchingModels": "无相匹配的模型", "noMatchingModels": "无相匹配的模型",
@ -1259,7 +1311,8 @@
"noRefinerModelsInstalled": "无已安装的 SDXL Refiner 模型", "noRefinerModelsInstalled": "无已安装的 SDXL Refiner 模型",
"noLoRAsInstalled": "无已安装的 LoRA", "noLoRAsInstalled": "无已安装的 LoRA",
"esrganModel": "ESRGAN 模型", "esrganModel": "ESRGAN 模型",
"addLora": "添加 LoRA" "addLora": "添加 LoRA",
"noLoRAsLoaded": "无已加载的 LoRA"
}, },
"boards": { "boards": {
"autoAddBoard": "自动添加面板", "autoAddBoard": "自动添加面板",
@ -1281,12 +1334,14 @@
"deleteBoardOnly": "仅删除面板", "deleteBoardOnly": "仅删除面板",
"deleteBoard": "删除面板", "deleteBoard": "删除面板",
"deleteBoardAndImages": "删除面板和图像", "deleteBoardAndImages": "删除面板和图像",
"deletedBoardsCannotbeRestored": "已删除的面板无法被恢复" "deletedBoardsCannotbeRestored": "已删除的面板无法被恢复",
"movingImagesToBoard_other": "移动 {{count}} 张图像到面板:"
}, },
"embedding": { "embedding": {
"noMatchingEmbedding": "不匹配的 Embedding", "noMatchingEmbedding": "不匹配的 Embedding",
"addEmbedding": "添加 Embedding", "addEmbedding": "添加 Embedding",
"incompatibleModel": "不兼容的基础模型:" "incompatibleModel": "不兼容的基础模型:",
"noEmbeddingsLoaded": "无已加载的 Embedding"
}, },
"dynamicPrompts": { "dynamicPrompts": {
"seedBehaviour": { "seedBehaviour": {
@ -1515,6 +1570,12 @@
"ControlNet 为生成过程提供引导,为生成具有受控构图、结构、样式的图像提供帮助,具体的功能由所选的模型决定。" "ControlNet 为生成过程提供引导,为生成具有受控构图、结构、样式的图像提供帮助,具体的功能由所选的模型决定。"
], ],
"heading": "ControlNet" "heading": "ControlNet"
},
"paramCFGRescaleMultiplier": {
"heading": "CFG 重缩放倍数",
"paragraphs": [
"CFG 引导的重缩放倍率,用于通过 zero-terminal SNR (ztsnr) 训练的模型。推荐设为 0.7。"
]
} }
}, },
"invocationCache": { "invocationCache": {
@ -1531,7 +1592,8 @@
"enable": "启用", "enable": "启用",
"clear": "清除", "clear": "清除",
"maxCacheSize": "最大缓存大小", "maxCacheSize": "最大缓存大小",
"cacheSize": "缓存大小" "cacheSize": "缓存大小",
"useCache": "使用缓存"
}, },
"hrf": { "hrf": {
"enableHrf": "启用高分辨率修复", "enableHrf": "启用高分辨率修复",

View File

@ -21,6 +21,7 @@ import GlobalHotkeys from './GlobalHotkeys';
import PreselectedImage from './PreselectedImage'; import PreselectedImage from './PreselectedImage';
import Toaster from './Toaster'; import Toaster from './Toaster';
import { useSocketIO } from 'app/hooks/useSocketIO'; import { useSocketIO } from 'app/hooks/useSocketIO';
import { useClearStorage } from 'common/hooks/useClearStorage';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};
@ -36,15 +37,16 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
const language = useAppSelector(languageSelector); const language = useAppSelector(languageSelector);
const logger = useLogger('system'); const logger = useLogger('system');
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const clearStorage = useClearStorage();
// singleton! // singleton!
useSocketIO(); useSocketIO();
const handleReset = useCallback(() => { const handleReset = useCallback(() => {
localStorage.clear(); clearStorage();
location.reload(); location.reload();
return false; return false;
}, []); }, [clearStorage]);
useEffect(() => { useEffect(() => {
i18n.changeLanguage(language); i18n.changeLanguage(language);

View File

@ -7,21 +7,23 @@ import { $headerComponent } from 'app/store/nanostores/headerComponent';
import { $isDebugging } from 'app/store/nanostores/isDebugging'; import { $isDebugging } from 'app/store/nanostores/isDebugging';
import { $projectId } from 'app/store/nanostores/projectId'; import { $projectId } from 'app/store/nanostores/projectId';
import { $queueId, DEFAULT_QUEUE_ID } from 'app/store/nanostores/queueId'; import { $queueId, DEFAULT_QUEUE_ID } from 'app/store/nanostores/queueId';
import { store } from 'app/store/store'; import { $store } from 'app/store/nanostores/store';
import { createStore } from 'app/store/store';
import { PartialAppConfig } from 'app/types/invokeai'; import { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import 'i18n';
import React, { import React, {
PropsWithChildren, PropsWithChildren,
ReactNode, ReactNode,
lazy, lazy,
memo, memo,
useEffect, useEffect,
useMemo,
} from 'react'; } from 'react';
import { Provider } from 'react-redux'; import { Provider } from 'react-redux';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares'; import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
import { ManagerOptions, SocketOptions } from 'socket.io-client'; import { ManagerOptions, SocketOptions } from 'socket.io-client';
import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import 'i18n';
const App = lazy(() => import('./App')); const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider')); const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
@ -137,6 +139,14 @@ const InvokeAIUI = ({
}; };
}, [isDebugging]); }, [isDebugging]);
const store = useMemo(() => {
return createStore(projectId);
}, [projectId]);
useEffect(() => {
$store.set(store);
}, [store]);
return ( return (
<React.StrictMode> <React.StrictMode>
<Provider store={store}> <Provider store={store}>

View File

@ -9,9 +9,9 @@ import { TOAST_OPTIONS, theme as invokeAITheme } from 'theme/theme';
import '@fontsource-variable/inter'; import '@fontsource-variable/inter';
import { MantineProvider } from '@mantine/core'; import { MantineProvider } from '@mantine/core';
import { useMantineTheme } from 'mantine-theme/theme';
import 'overlayscrollbars/overlayscrollbars.css'; import 'overlayscrollbars/overlayscrollbars.css';
import 'theme/css/overlayscrollbars.css'; import 'theme/css/overlayscrollbars.css';
import { useMantineTheme } from 'mantine-theme/theme';
type ThemeLocaleProviderProps = { type ThemeLocaleProviderProps = {
children: ReactNode; children: ReactNode;

View File

@ -3,8 +3,8 @@ import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl'; import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $isDebugging } from 'app/store/nanostores/isDebugging'; import { $isDebugging } from 'app/store/nanostores/isDebugging';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { MapStore, WritableAtom, atom, map } from 'nanostores'; import { MapStore, atom, map } from 'nanostores';
import { useEffect } from 'react'; import { useEffect, useMemo } from 'react';
import { import {
ClientToServerEvents, ClientToServerEvents,
ServerToClientEvents, ServerToClientEvents,
@ -16,57 +16,10 @@ import { ManagerOptions, Socket, SocketOptions, io } from 'socket.io-client';
declare global { declare global {
interface Window { interface Window {
$socketOptions?: MapStore<Partial<ManagerOptions & SocketOptions>>; $socketOptions?: MapStore<Partial<ManagerOptions & SocketOptions>>;
$socketUrl?: WritableAtom<string>;
} }
} }
const makeSocketOptions = (): Partial<ManagerOptions & SocketOptions> => {
const socketOptions: Parameters<typeof io>[0] = {
timeout: 60000,
path: '/ws/socket.io',
autoConnect: false, // achtung! removing this breaks the dynamic middleware
forceNew: true,
};
// if building in package mode, replace socket url with open api base url minus the http protocol
if (['nodes', 'package'].includes(import.meta.env.MODE)) {
const authToken = $authToken.get();
if (authToken) {
// TODO: handle providing jwt to socket.io
socketOptions.auth = { token: authToken };
}
socketOptions.transports = ['websocket', 'polling'];
}
return socketOptions;
};
const makeSocketUrl = (): string => {
const wsProtocol = window.location.protocol === 'https:' ? 'wss' : 'ws';
let socketUrl = `${wsProtocol}://${window.location.host}`;
if (['nodes', 'package'].includes(import.meta.env.MODE)) {
const baseUrl = $baseUrl.get();
if (baseUrl) {
//eslint-disable-next-line
socketUrl = baseUrl.replace(/^https?\:\/\//i, '');
}
}
return socketUrl;
};
const makeSocket = (): Socket<ServerToClientEvents, ClientToServerEvents> => {
const socketOptions = makeSocketOptions();
const socketUrl = $socketUrl.get();
const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(
socketUrl,
{ ...socketOptions, ...$socketOptions.get() }
);
return socket;
};
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({}); export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
export const $socketUrl = atom<string>(makeSocketUrl());
export const $isSocketInitialized = atom<boolean>(false); export const $isSocketInitialized = atom<boolean>(false);
/** /**
@ -74,23 +27,50 @@ export const $isSocketInitialized = atom<boolean>(false);
*/ */
export const useSocketIO = () => { export const useSocketIO = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const socketOptions = useStore($socketOptions);
const socketUrl = useStore($socketUrl);
const baseUrl = useStore($baseUrl); const baseUrl = useStore($baseUrl);
const authToken = useStore($authToken); const authToken = useStore($authToken);
const addlSocketOptions = useStore($socketOptions);
const socketUrl = useMemo(() => {
const wsProtocol = window.location.protocol === 'https:' ? 'wss' : 'ws';
if (baseUrl) {
return baseUrl.replace(/^https?:\/\//i, '');
}
return `${wsProtocol}://${window.location.host}`;
}, [baseUrl]);
const socketOptions = useMemo(() => {
const options: Parameters<typeof io>[0] = {
timeout: 60000,
path: '/ws/socket.io',
autoConnect: false, // achtung! removing this breaks the dynamic middleware
forceNew: true,
};
if (authToken) {
options.auth = { token: authToken };
options.transports = ['websocket', 'polling'];
}
return { ...options, ...addlSocketOptions };
}, [authToken, addlSocketOptions]);
useEffect(() => { useEffect(() => {
if ($isSocketInitialized.get()) { if ($isSocketInitialized.get()) {
// Singleton! // Singleton!
return; return;
} }
const socket = makeSocket();
const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(
socketUrl,
socketOptions
);
setEventListeners({ dispatch, socket }); setEventListeners({ dispatch, socket });
socket.connect(); socket.connect();
if ($isDebugging.get()) { if ($isDebugging.get()) {
window.$socketOptions = $socketOptions; window.$socketOptions = $socketOptions;
window.$socketUrl = $socketUrl;
console.log('Socket initialized', socket); console.log('Socket initialized', socket);
} }
@ -99,11 +79,10 @@ export const useSocketIO = () => {
return () => { return () => {
if ($isDebugging.get()) { if ($isDebugging.get()) {
window.$socketOptions = undefined; window.$socketOptions = undefined;
window.$socketUrl = undefined;
console.log('Socket teardown', socket); console.log('Socket teardown', socket);
} }
socket.disconnect(); socket.disconnect();
$isSocketInitialized.set(false); $isSocketInitialized.set(false);
}; };
}, [dispatch, socketOptions, socketUrl, baseUrl, authToken]); }, [dispatch, socketOptions, socketUrl]);
}; };

View File

@ -1,8 +1 @@
export const LOCALSTORAGE_KEYS = [ export const STORAGE_PREFIX = '@@invokeai-';
'chakra-ui-color-mode',
'i18nextLng',
'ROARR_FILTER',
'ROARR_LOG',
];
export const LOCALSTORAGE_PREFIX = '@@invokeai-';

View File

@ -23,16 +23,16 @@ import systemReducer from 'features/system/store/systemSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice'; import uiReducer from 'features/ui/store/uiSlice';
import dynamicMiddlewares from 'redux-dynamic-middlewares'; import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { rememberEnhancer, rememberReducer } from 'redux-remember'; import { Driver, rememberEnhancer, rememberReducer } from 'redux-remember';
import { api } from 'services/api'; import { api } from 'services/api';
import { LOCALSTORAGE_PREFIX } from './constants'; import { STORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize'; import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize'; import { unserialize } from './enhancers/reduxRemember/unserialize';
import { actionSanitizer } from './middleware/devtools/actionSanitizer'; import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist'; import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer'; import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { listenerMiddleware } from './middleware/listenerMiddleware'; import { listenerMiddleware } from './middleware/listenerMiddleware';
import { $store } from './nanostores/store'; import { createStore as createIDBKeyValStore, get, set } from 'idb-keyval';
const allReducers = { const allReducers = {
canvas: canvasReducer, canvas: canvasReducer,
@ -74,16 +74,28 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'modelmanager', 'modelmanager',
]; ];
export const store = configureStore({ // Create a custom idb-keyval store (just needed to customize the name)
export const idbKeyValStore = createIDBKeyValStore('invoke', 'invoke-store');
// Create redux-remember driver, wrapping idb-keyval
const idbKeyValDriver: Driver = {
getItem: (key) => get(key, idbKeyValStore),
setItem: (key, value) => set(key, value, idbKeyValStore),
};
export const createStore = (uniqueStoreKey?: string) =>
configureStore({
reducer: rememberedRootReducer, reducer: rememberedRootReducer,
enhancers: (existingEnhancers) => { enhancers: (existingEnhancers) => {
return existingEnhancers return existingEnhancers
.concat( .concat(
rememberEnhancer(window.localStorage, rememberedKeys, { rememberEnhancer(idbKeyValDriver, rememberedKeys, {
persistDebounce: 300, persistDebounce: 300,
serialize, serialize,
unserialize, unserialize,
prefix: LOCALSTORAGE_PREFIX, prefix: uniqueStoreKey
? `${STORAGE_PREFIX}${uniqueStoreKey}-`
: STORAGE_PREFIX,
}) })
) )
.concat(autoBatchEnhancer()); .concat(autoBatchEnhancer());
@ -121,10 +133,11 @@ export const store = configureStore({
}, },
}); });
export type AppGetState = typeof store.getState; export type AppGetState = ReturnType<
export type RootState = ReturnType<typeof store.getState>; ReturnType<typeof createStore>['getState']
>;
export type RootState = ReturnType<ReturnType<typeof createStore>['getState']>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>; export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
export type AppDispatch = typeof store.dispatch; export type AppDispatch = ReturnType<typeof createStore>['dispatch'];
export const stateSelector = (state: RootState) => state; export const stateSelector = (state: RootState) => state;
$store.set(store);

View File

@ -25,6 +25,7 @@ export type Feature =
| 'lora' | 'lora'
| 'noiseUseCPU' | 'noiseUseCPU'
| 'paramCFGScale' | 'paramCFGScale'
| 'paramCFGRescaleMultiplier'
| 'paramDenoisingStrength' | 'paramDenoisingStrength'
| 'paramIterations' | 'paramIterations'
| 'paramModel' | 'paramModel'

View File

@ -0,0 +1,12 @@
import { idbKeyValStore } from 'app/store/store';
import { clear } from 'idb-keyval';
import { useCallback } from 'react';
export const useClearStorage = () => {
const clearStorage = useCallback(() => {
clear(idbKeyValStore);
localStorage.clear();
}, []);
return clearStorage;
};

View File

@ -5,14 +5,19 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice'; import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import { useControlAdapterControlImage } from 'features/controlAdapters/hooks/useControlAdapterControlImage';
import { useControlAdapterProcessedControlImage } from 'features/controlAdapters/hooks/useControlAdapterProcessedControlImage';
import { useControlAdapterProcessorType } from 'features/controlAdapters/hooks/useControlAdapterProcessorType';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { import {
TypesafeDraggableData, TypesafeDraggableData,
TypesafeDroppableData, TypesafeDroppableData,
} from 'features/dnd/types'; } from 'features/dnd/types';
import { setHeight, setWidth } from 'features/parameters/store/generationSlice'; import { setHeight, setWidth } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useMemo, useState } from 'react'; import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaRulerVertical, FaSave, FaUndo } from 'react-icons/fa'; import { FaRulerVertical, FaSave, FaUndo } from 'react-icons/fa';
import { import {
@ -22,11 +27,6 @@ import {
useRemoveImageFromBoardMutation, useRemoveImageFromBoardMutation,
} from 'services/api/endpoints/images'; } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types'; import { PostUploadAction } from 'services/api/types';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { useControlAdapterControlImage } from 'features/controlAdapters/hooks/useControlAdapterControlImage';
import { useControlAdapterProcessedControlImage } from 'features/controlAdapters/hooks/useControlAdapterProcessedControlImage';
import { useControlAdapterProcessorType } from 'features/controlAdapters/hooks/useControlAdapterProcessorType';
type Props = { type Props = {
id: string; id: string;
@ -35,13 +35,15 @@ type Props = {
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
({ controlAdapters, gallery }) => { ({ controlAdapters, gallery, system }) => {
const { pendingControlImages } = controlAdapters; const { pendingControlImages } = controlAdapters;
const { autoAddBoardId } = gallery; const { autoAddBoardId } = gallery;
const { isConnected } = system;
return { return {
pendingControlImages, pendingControlImages,
autoAddBoardId, autoAddBoardId,
isConnected,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -55,18 +57,19 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { pendingControlImages, autoAddBoardId } = useAppSelector(selector); const { pendingControlImages, autoAddBoardId, isConnected } =
useAppSelector(selector);
const activeTabName = useAppSelector(activeTabNameSelector); const activeTabName = useAppSelector(activeTabNameSelector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false); const [isMouseOverImage, setIsMouseOverImage] = useState(false);
const { currentData: controlImage } = useGetImageDTOQuery( const { currentData: controlImage, isError: isErrorControlImage } =
controlImageName ?? skipToken useGetImageDTOQuery(controlImageName ?? skipToken);
);
const { currentData: processedControlImage } = useGetImageDTOQuery( const {
processedControlImageName ?? skipToken currentData: processedControlImage,
); isError: isErrorProcessedControlImage,
} = useGetImageDTOQuery(processedControlImageName ?? skipToken);
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation(); const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();
const [addToBoard] = useAddImageToBoardMutation(); const [addToBoard] = useAddImageToBoardMutation();
@ -158,6 +161,17 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
!pendingControlImages.includes(id) && !pendingControlImages.includes(id) &&
processorType !== 'none'; processorType !== 'none';
useEffect(() => {
if (isConnected && (isErrorControlImage || isErrorProcessedControlImage)) {
handleResetControlImage();
}
}, [
handleResetControlImage,
isConnected,
isErrorControlImage,
isErrorProcessedControlImage,
]);
return ( return (
<Flex <Flex
onMouseEnter={handleMouseEnter} onMouseEnter={handleMouseEnter}

View File

@ -73,7 +73,13 @@ const BoardContextMenu = ({
addToast({ addToast({
title: t('gallery.preparingDownload'), title: t('gallery.preparingDownload'),
status: 'success', status: 'success',
...(response.response ? { description: response.response } : {}), ...(response.response
? {
description: response.response,
duration: null,
isClosable: true,
}
: {}),
}) })
); );
} catch { } catch {

View File

@ -59,7 +59,13 @@ const MultipleSelectionMenuItems = () => {
addToast({ addToast({
title: t('gallery.preparingDownload'), title: t('gallery.preparingDownload'),
status: 'success', status: 'success',
...(response.response ? { description: response.response } : {}), ...(response.response
? {
description: response.response,
duration: null,
isClosable: true,
}
: {}),
}) })
); );
} catch { } catch {

View File

@ -234,14 +234,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
icon={customStarUi ? customStarUi.off.icon : <MdStar />} icon={customStarUi ? customStarUi.off.icon : <MdStar />}
onClickCapture={handleUnstarImage} onClickCapture={handleUnstarImage}
> >
{customStarUi ? customStarUi.off.text : t('controlnet.unstarImage')} {customStarUi ? customStarUi.off.text : t('gallery.unstarImage')}
</MenuItem> </MenuItem>
) : ( ) : (
<MenuItem <MenuItem
icon={customStarUi ? customStarUi.on.icon : <MdStarBorder />} icon={customStarUi ? customStarUi.on.icon : <MdStarBorder />}
onClickCapture={handleStarImage} onClickCapture={handleStarImage}
> >
{customStarUi ? customStarUi.on.text : `Star Image`} {customStarUi ? customStarUi.on.text : t('gallery.starImage')}
</MenuItem> </MenuItem>
)} )}
<MenuItem <MenuItem

View File

@ -29,6 +29,7 @@ const ImageMetadataActions = (props: Props) => {
recallNegativePrompt, recallNegativePrompt,
recallSeed, recallSeed,
recallCfgScale, recallCfgScale,
recallCfgRescaleMultiplier,
recallModel, recallModel,
recallScheduler, recallScheduler,
recallVaeModel, recallVaeModel,
@ -85,6 +86,10 @@ const ImageMetadataActions = (props: Props) => {
recallCfgScale(metadata?.cfg_scale); recallCfgScale(metadata?.cfg_scale);
}, [metadata?.cfg_scale, recallCfgScale]); }, [metadata?.cfg_scale, recallCfgScale]);
const handleRecallCfgRescaleMultiplier = useCallback(() => {
recallCfgRescaleMultiplier(metadata?.cfg_rescale_multiplier);
}, [metadata?.cfg_rescale_multiplier, recallCfgRescaleMultiplier]);
const handleRecallStrength = useCallback(() => { const handleRecallStrength = useCallback(() => {
recallStrength(metadata?.strength); recallStrength(metadata?.strength);
}, [metadata?.strength, recallStrength]); }, [metadata?.strength, recallStrength]);
@ -243,6 +248,14 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallCfgScale} onClick={handleRecallCfgScale}
/> />
)} )}
{metadata.cfg_rescale_multiplier !== undefined &&
metadata.cfg_rescale_multiplier !== null && (
<ImageMetadataItem
label={t('metadata.cfgRescaleMultiplier')}
value={metadata.cfg_rescale_multiplier}
onClick={handleRecallCfgRescaleMultiplier}
/>
)}
{metadata.strength && ( {metadata.strength && (
<ImageMetadataItem <ImageMetadataItem
label={t('metadata.strength')} label={t('metadata.strength')}

View File

@ -1,6 +1,6 @@
import { Flex, Text } from '@chakra-ui/react'; import { Flex, Text } from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { import {
@ -13,7 +13,7 @@ import {
ImageFieldInputTemplate, ImageFieldInputTemplate,
} from 'features/nodes/types/field'; } from 'features/nodes/types/field';
import { FieldComponentProps } from './types'; import { FieldComponentProps } from './types';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaUndo } from 'react-icons/fa'; import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
@ -24,8 +24,8 @@ const ImageFieldInputComponent = (
) => { ) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isConnected = useAppSelector((state) => state.system.isConnected);
const { currentData: imageDTO } = useGetImageDTOQuery( const { currentData: imageDTO, isError } = useGetImageDTOQuery(
field.value?.image_name ?? skipToken field.value?.image_name ?? skipToken
); );
@ -67,6 +67,12 @@ const ImageFieldInputComponent = (
[nodeId, field.name] [nodeId, field.name]
); );
useEffect(() => {
if (isConnected && isError) {
handleReset();
}
}, [handleReset, isConnected, isError]);
return ( return (
<Flex <Flex
className="nodrag" className="nodrag"

View File

@ -43,10 +43,10 @@ export class NodeUpdateError extends Error {
} }
/** /**
* FieldTypeParseError * FieldParseError
* Raised when a field cannot be parsed from a field schema. * Raised when a field cannot be parsed from a field schema.
*/ */
export class FieldTypeParseError extends Error { export class FieldParseError extends Error {
/** /**
* Create FieldTypeParseError * Create FieldTypeParseError
* @param {String} message * @param {String} message
@ -56,18 +56,3 @@ export class FieldTypeParseError extends Error {
this.name = this.constructor.name; this.name = this.constructor.name;
} }
} }
/**
* UnsupportedFieldTypeError
* Raised when an unsupported field type is parsed.
*/
export class UnsupportedFieldTypeError extends Error {
/**
* Create UnsupportedFieldTypeError
* @param {String} message
*/
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}

View File

@ -51,6 +51,7 @@ export const zCoreMetadata = z
seed: z.number().int().nullish().catch(null), seed: z.number().int().nullish().catch(null),
rand_device: z.string().nullish().catch(null), rand_device: z.string().nullish().catch(null),
cfg_scale: z.number().nullish().catch(null), cfg_scale: z.number().nullish().catch(null),
cfg_rescale_multiplier: z.number().nullish().catch(null),
steps: z.number().int().nullish().catch(null), steps: z.number().int().nullish().catch(null),
scheduler: z.string().nullish().catch(null), scheduler: z.string().nullish().catch(null),
clip_skip: z.number().int().nullish().catch(null), clip_skip: z.number().int().nullish().catch(null),

View File

@ -43,6 +43,7 @@ export const buildCanvasImageToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -316,6 +317,7 @@ export const buildCanvasImageToImageGraph = (
{ {
generation_mode: 'img2img', generation_mode: 'img2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
width: !isUsingScaledDimensions width: !isUsingScaledDimensions
? width ? width
: scaledBoundingBoxDimensions.width, : scaledBoundingBoxDimensions.width,

View File

@ -45,6 +45,7 @@ export const buildCanvasSDXLImageToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -327,6 +328,7 @@ export const buildCanvasSDXLImageToImageGraph = (
{ {
generation_mode: 'img2img', generation_mode: 'img2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
width: !isUsingScaledDimensions width: !isUsingScaledDimensions
? width ? width
: scaledBoundingBoxDimensions.width, : scaledBoundingBoxDimensions.width,

View File

@ -43,6 +43,7 @@ export const buildCanvasSDXLTextToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -306,6 +307,7 @@ export const buildCanvasSDXLTextToImageGraph = (
{ {
generation_mode: 'txt2img', generation_mode: 'txt2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
width: !isUsingScaledDimensions width: !isUsingScaledDimensions
? width ? width
: scaledBoundingBoxDimensions.width, : scaledBoundingBoxDimensions.width,

View File

@ -41,6 +41,7 @@ export const buildCanvasTextToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -294,6 +295,7 @@ export const buildCanvasTextToImageGraph = (
{ {
generation_mode: 'txt2img', generation_mode: 'txt2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
width: !isUsingScaledDimensions width: !isUsingScaledDimensions
? width ? width
: scaledBoundingBoxDimensions.width, : scaledBoundingBoxDimensions.width,

View File

@ -41,6 +41,7 @@ export const buildLinearImageToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -316,6 +317,7 @@ export const buildLinearImageToImageGraph = (
{ {
generation_mode: 'img2img', generation_mode: 'img2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
height, height,
width, width,
positive_prompt: positivePrompt, positive_prompt: positivePrompt,

View File

@ -43,6 +43,7 @@ export const buildLinearSDXLImageToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -336,6 +337,7 @@ export const buildLinearSDXLImageToImageGraph = (
{ {
generation_mode: 'sdxl_img2img', generation_mode: 'sdxl_img2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
height, height,
width, width,
positive_prompt: positivePrompt, positive_prompt: positivePrompt,

View File

@ -34,6 +34,7 @@ export const buildLinearSDXLTextToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -230,6 +231,7 @@ export const buildLinearSDXLTextToImageGraph = (
{ {
generation_mode: 'sdxl_txt2img', generation_mode: 'sdxl_txt2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
height, height,
width, width,
positive_prompt: positivePrompt, positive_prompt: positivePrompt,

View File

@ -38,6 +38,7 @@ export const buildLinearTextToImageGraph = (
negativePrompt, negativePrompt,
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier,
scheduler, scheduler,
steps, steps,
width, width,
@ -84,6 +85,7 @@ export const buildLinearTextToImageGraph = (
id: DENOISE_LATENTS, id: DENOISE_LATENTS,
is_intermediate, is_intermediate,
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
scheduler, scheduler,
steps, steps,
denoising_start: 0, denoising_start: 0,
@ -239,6 +241,7 @@ export const buildLinearTextToImageGraph = (
{ {
generation_mode: 'txt2img', generation_mode: 'txt2img',
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
height, height,
width, width,
positive_prompt: positivePrompt, positive_prompt: positivePrompt,

View File

@ -23,7 +23,12 @@ import {
VAEModelFieldInputTemplate, VAEModelFieldInputTemplate,
isStatefulFieldType, isStatefulFieldType,
} from 'features/nodes/types/field'; } from 'features/nodes/types/field';
import { InvocationFieldSchema } from 'features/nodes/types/openapi'; import {
InvocationFieldSchema,
isSchemaObject,
} from 'features/nodes/types/openapi';
import { t } from 'i18next';
import { FieldParseError } from 'features/nodes/types/error';
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
type FieldInputTemplateBuilder<T extends FieldInputTemplate = any> = // valid `any`! type FieldInputTemplateBuilder<T extends FieldInputTemplate = any> = // valid `any`!
@ -321,7 +326,28 @@ const buildImageFieldInputTemplate: FieldInputTemplateBuilder<
const buildEnumFieldInputTemplate: FieldInputTemplateBuilder< const buildEnumFieldInputTemplate: FieldInputTemplateBuilder<
EnumFieldInputTemplate EnumFieldInputTemplate
> = ({ schemaObject, baseField, isCollection, isCollectionOrScalar }) => { > = ({ schemaObject, baseField, isCollection, isCollectionOrScalar }) => {
const options = schemaObject.enum ?? []; let options: EnumFieldInputTemplate['options'] = [];
if (schemaObject.anyOf) {
const filteredAnyOf = schemaObject.anyOf.filter((i) => {
if (isSchemaObject(i)) {
if (i.type === 'null') {
return false;
}
}
return true;
});
const firstAnyOf = filteredAnyOf[0];
if (filteredAnyOf.length !== 1 || !isSchemaObject(firstAnyOf)) {
options = [];
} else {
options = firstAnyOf.enum ?? [];
}
} else {
options = schemaObject.enum ?? [];
}
if (options.length === 0) {
throw new FieldParseError(t('nodes.unableToExtractEnumOptions'));
}
const template: EnumFieldInputTemplate = { const template: EnumFieldInputTemplate = {
...baseField, ...baseField,
type: { type: {

View File

@ -1,10 +1,4 @@
import { t } from 'i18next'; import { FieldParseError } from 'features/nodes/types/error';
import { isArray } from 'lodash-es';
import { OpenAPIV3_1 } from 'openapi-types';
import {
FieldTypeParseError,
UnsupportedFieldTypeError,
} from 'features/nodes/types/error';
import { FieldType } from 'features/nodes/types/field'; import { FieldType } from 'features/nodes/types/field';
import { import {
OpenAPIV3_1SchemaOrRef, OpenAPIV3_1SchemaOrRef,
@ -14,6 +8,9 @@ import {
isRefObject, isRefObject,
isSchemaObject, isSchemaObject,
} from 'features/nodes/types/openapi'; } from 'features/nodes/types/openapi';
import { t } from 'i18next';
import { isArray } from 'lodash-es';
import { OpenAPIV3_1 } from 'openapi-types';
/** /**
* Transforms an invocation output ref object to field type. * Transforms an invocation output ref object to field type.
@ -70,7 +67,7 @@ export const parseFieldType = (
// This is a single ref type // This is a single ref type
const name = refObjectToSchemaName(allOf[0]); const name = refObjectToSchemaName(allOf[0]);
if (!name) { if (!name) {
throw new FieldTypeParseError( throw new FieldParseError(
t('nodes.unableToExtractSchemaNameFromRef') t('nodes.unableToExtractSchemaNameFromRef')
); );
} }
@ -95,7 +92,7 @@ export const parseFieldType = (
if (isRefObject(filteredAnyOf[0])) { if (isRefObject(filteredAnyOf[0])) {
const name = refObjectToSchemaName(filteredAnyOf[0]); const name = refObjectToSchemaName(filteredAnyOf[0]);
if (!name) { if (!name) {
throw new FieldTypeParseError( throw new FieldParseError(
t('nodes.unableToExtractSchemaNameFromRef') t('nodes.unableToExtractSchemaNameFromRef')
); );
} }
@ -120,7 +117,7 @@ export const parseFieldType = (
if (filteredAnyOf.length !== 2) { if (filteredAnyOf.length !== 2) {
// This is a union of more than 2 types, which we don't support // This is a union of more than 2 types, which we don't support
throw new UnsupportedFieldTypeError( throw new FieldParseError(
t('nodes.unsupportedAnyOfLength', { t('nodes.unsupportedAnyOfLength', {
count: filteredAnyOf.length, count: filteredAnyOf.length,
}) })
@ -167,7 +164,7 @@ export const parseFieldType = (
}; };
} }
throw new UnsupportedFieldTypeError( throw new FieldParseError(
t('nodes.unsupportedMismatchedUnion', { t('nodes.unsupportedMismatchedUnion', {
firstType, firstType,
secondType, secondType,
@ -186,7 +183,7 @@ export const parseFieldType = (
if (isSchemaObject(schemaObject.items)) { if (isSchemaObject(schemaObject.items)) {
const itemType = schemaObject.items.type; const itemType = schemaObject.items.type;
if (!itemType || isArray(itemType)) { if (!itemType || isArray(itemType)) {
throw new UnsupportedFieldTypeError( throw new FieldParseError(
t('nodes.unsupportedArrayItemType', { t('nodes.unsupportedArrayItemType', {
type: itemType, type: itemType,
}) })
@ -196,7 +193,7 @@ export const parseFieldType = (
const name = OPENAPI_TO_FIELD_TYPE_MAP[itemType]; const name = OPENAPI_TO_FIELD_TYPE_MAP[itemType];
if (!name) { if (!name) {
// it's 'null', 'object', or 'array' - skip // it's 'null', 'object', or 'array' - skip
throw new UnsupportedFieldTypeError( throw new FieldParseError(
t('nodes.unsupportedArrayItemType', { t('nodes.unsupportedArrayItemType', {
type: itemType, type: itemType,
}) })
@ -212,7 +209,7 @@ export const parseFieldType = (
// This is a ref object, extract the type name // This is a ref object, extract the type name
const name = refObjectToSchemaName(schemaObject.items); const name = refObjectToSchemaName(schemaObject.items);
if (!name) { if (!name) {
throw new FieldTypeParseError( throw new FieldParseError(
t('nodes.unableToExtractSchemaNameFromRef') t('nodes.unableToExtractSchemaNameFromRef')
); );
} }
@ -226,7 +223,7 @@ export const parseFieldType = (
const name = OPENAPI_TO_FIELD_TYPE_MAP[schemaObject.type]; const name = OPENAPI_TO_FIELD_TYPE_MAP[schemaObject.type];
if (!name) { if (!name) {
// it's 'null', 'object', or 'array' - skip // it's 'null', 'object', or 'array' - skip
throw new UnsupportedFieldTypeError( throw new FieldParseError(
t('nodes.unsupportedArrayItemType', { t('nodes.unsupportedArrayItemType', {
type: schemaObject.type, type: schemaObject.type,
}) })
@ -242,9 +239,7 @@ export const parseFieldType = (
} else if (isRefObject(schemaObject)) { } else if (isRefObject(schemaObject)) {
const name = refObjectToSchemaName(schemaObject); const name = refObjectToSchemaName(schemaObject);
if (!name) { if (!name) {
throw new FieldTypeParseError( throw new FieldParseError(t('nodes.unableToExtractSchemaNameFromRef'));
t('nodes.unableToExtractSchemaNameFromRef')
);
} }
return { return {
name, name,
@ -252,5 +247,5 @@ export const parseFieldType = (
isCollectionOrScalar: false, isCollectionOrScalar: false,
}; };
} }
throw new FieldTypeParseError(t('nodes.unableToParseFieldType')); throw new FieldParseError(t('nodes.unableToParseFieldType'));
}; };

View File

@ -1,12 +1,6 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { t } from 'i18next'; import { FieldParseError } from 'features/nodes/types/error';
import { reduce } from 'lodash-es';
import { OpenAPIV3_1 } from 'openapi-types';
import {
FieldTypeParseError,
UnsupportedFieldTypeError,
} from 'features/nodes/types/error';
import { import {
FieldInputTemplate, FieldInputTemplate,
FieldOutputTemplate, FieldOutputTemplate,
@ -18,6 +12,9 @@ import {
isInvocationOutputSchemaObject, isInvocationOutputSchemaObject,
isInvocationSchemaObject, isInvocationSchemaObject,
} from 'features/nodes/types/openapi'; } from 'features/nodes/types/openapi';
import { t } from 'i18next';
import { reduce } from 'lodash-es';
import { OpenAPIV3_1 } from 'openapi-types';
import { buildFieldInputTemplate } from './buildFieldInputTemplate'; import { buildFieldInputTemplate } from './buildFieldInputTemplate';
import { buildFieldOutputTemplate } from './buildFieldOutputTemplate'; import { buildFieldOutputTemplate } from './buildFieldOutputTemplate';
import { parseFieldType } from './parseFieldType'; import { parseFieldType } from './parseFieldType';
@ -133,10 +130,7 @@ export const parseSchema = (
inputsAccumulator[propertyName] = fieldInputTemplate; inputsAccumulator[propertyName] = fieldInputTemplate;
} catch (e) { } catch (e) {
if ( if (e instanceof FieldParseError) {
e instanceof FieldTypeParseError ||
e instanceof UnsupportedFieldTypeError
) {
logger('nodes').warn( logger('nodes').warn(
{ {
node: type, node: type,
@ -225,10 +219,7 @@ export const parseSchema = (
outputsAccumulator[propertyName] = fieldOutputTemplate; outputsAccumulator[propertyName] = fieldOutputTemplate;
} catch (e) { } catch (e) {
if ( if (e instanceof FieldParseError) {
e instanceof FieldTypeParseError ||
e instanceof UnsupportedFieldTypeError
) {
logger('nodes').warn( logger('nodes').warn(
{ {
node: type, node: type,

View File

@ -9,21 +9,41 @@ import { useTranslation } from 'react-i18next';
import { ParamCpuNoiseToggle } from 'features/parameters/components/Parameters/Noise/ParamCpuNoise'; import { ParamCpuNoiseToggle } from 'features/parameters/components/Parameters/Noise/ParamCpuNoise';
import ParamSeamless from 'features/parameters/components/Parameters/Seamless/ParamSeamless'; import ParamSeamless from 'features/parameters/components/Parameters/Seamless/ParamSeamless';
import ParamClipSkip from './ParamClipSkip'; import ParamClipSkip from './ParamClipSkip';
import ParamCFGRescaleMultiplier from './ParamCFGRescaleMultiplier';
const selector = createSelector( const selector = createSelector(
stateSelector, stateSelector,
(state: RootState) => { (state: RootState) => {
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } = const {
state.generation; clipSkip,
model,
seamlessXAxis,
seamlessYAxis,
shouldUseCpuNoise,
cfgRescaleMultiplier,
} = state.generation;
return { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise }; return {
clipSkip,
model,
seamlessXAxis,
seamlessYAxis,
shouldUseCpuNoise,
cfgRescaleMultiplier,
};
}, },
defaultSelectorOptions defaultSelectorOptions
); );
export default function ParamAdvancedCollapse() { export default function ParamAdvancedCollapse() {
const { clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise } = const {
useAppSelector(selector); clipSkip,
model,
seamlessXAxis,
seamlessYAxis,
shouldUseCpuNoise,
cfgRescaleMultiplier,
} = useAppSelector(selector);
const { t } = useTranslation(); const { t } = useTranslation();
const activeLabel = useMemo(() => { const activeLabel = useMemo(() => {
const activeLabel: string[] = []; const activeLabel: string[] = [];
@ -46,8 +66,20 @@ export default function ParamAdvancedCollapse() {
activeLabel.push(t('parameters.seamlessY')); activeLabel.push(t('parameters.seamlessY'));
} }
if (cfgRescaleMultiplier) {
activeLabel.push(t('parameters.cfgRescale'));
}
return activeLabel.join(', '); return activeLabel.join(', ');
}, [clipSkip, model, seamlessXAxis, seamlessYAxis, shouldUseCpuNoise, t]); }, [
cfgRescaleMultiplier,
clipSkip,
model,
seamlessXAxis,
seamlessYAxis,
shouldUseCpuNoise,
t,
]);
return ( return (
<IAICollapse label={t('common.advanced')} activeLabel={activeLabel}> <IAICollapse label={t('common.advanced')} activeLabel={activeLabel}>
@ -61,6 +93,8 @@ export default function ParamAdvancedCollapse() {
</> </>
)} )}
<ParamCpuNoiseToggle /> <ParamCpuNoiseToggle />
<Divider />
<ParamCFGRescaleMultiplier />
</Flex> </Flex>
</IAICollapse> </IAICollapse>
); );

View File

@ -0,0 +1,60 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover';
import IAISlider from 'common/components/IAISlider';
import { setCfgRescaleMultiplier } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
[stateSelector],
({ generation, hotkeys }) => {
const { cfgRescaleMultiplier } = generation;
const { shift } = hotkeys;
return {
cfgRescaleMultiplier,
shift,
};
},
defaultSelectorOptions
);
const ParamCFGRescaleMultiplier = () => {
const { cfgRescaleMultiplier, shift } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(
(v: number) => dispatch(setCfgRescaleMultiplier(v)),
[dispatch]
);
const handleReset = useCallback(
() => dispatch(setCfgRescaleMultiplier(0)),
[dispatch]
);
return (
<IAIInformationalPopover feature="paramCFGRescaleMultiplier">
<IAISlider
label={t('parameters.cfgRescaleMultiplier')}
step={shift ? 0.01 : 0.05}
min={0}
max={0.99}
onChange={handleChange}
handleReset={handleReset}
value={cfgRescaleMultiplier}
sliderNumberInputProps={{ max: 0.99 }}
withInput
withReset
withSliderMarks
isInteger={false}
/>
</IAIInformationalPopover>
);
};
export default memo(ParamCFGRescaleMultiplier);

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import { IAINoContentFallback } from 'common/components/IAIImageFallback';
@ -9,25 +9,30 @@ import {
TypesafeDraggableData, TypesafeDraggableData,
TypesafeDroppableData, TypesafeDroppableData,
} from 'features/dnd/types'; } from 'features/dnd/types';
import { memo, useMemo } from 'react'; import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { memo, useEffect, useMemo } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
(state) => { (state) => {
const { initialImage } = state.generation; const { initialImage } = state.generation;
const { isConnected } = state.system;
return { return {
initialImage, initialImage,
isResetButtonDisabled: !initialImage, isResetButtonDisabled: !initialImage,
isConnected,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const InitialImage = () => { const InitialImage = () => {
const { initialImage } = useAppSelector(selector); const dispatch = useAppDispatch();
const { initialImage, isConnected } = useAppSelector(selector);
const { currentData: imageDTO } = useGetImageDTOQuery( const { currentData: imageDTO, isError } = useGetImageDTOQuery(
initialImage?.imageName ?? skipToken initialImage?.imageName ?? skipToken
); );
@ -49,6 +54,13 @@ const InitialImage = () => {
[] []
); );
useEffect(() => {
if (isError && isConnected) {
// The image doesn't exist, reset init image
dispatch(clearInitialImage());
}
}, [dispatch, isConnected, isError]);
return ( return (
<IAIDndImage <IAIDndImage
imageDTO={imageDTO} imageDTO={imageDTO}

View File

@ -57,6 +57,7 @@ import {
modelSelected, modelSelected,
} from 'features/parameters/store/actions'; } from 'features/parameters/store/actions';
import { import {
setCfgRescaleMultiplier,
setCfgScale, setCfgScale,
setHeight, setHeight,
setHrfEnabled, setHrfEnabled,
@ -94,6 +95,7 @@ import {
isParameterStrength, isParameterStrength,
isParameterVAEModel, isParameterVAEModel,
isParameterWidth, isParameterWidth,
isParameterCFGRescaleMultiplier,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
const selector = createSelector( const selector = createSelector(
@ -282,6 +284,21 @@ export const useRecallParameters = () => {
[dispatch, parameterSetToast, parameterNotSetToast] [dispatch, parameterSetToast, parameterNotSetToast]
); );
/**
* Recall CFG rescale multiplier with toast
*/
const recallCfgRescaleMultiplier = useCallback(
(cfgRescaleMultiplier: unknown) => {
if (!isParameterCFGRescaleMultiplier(cfgRescaleMultiplier)) {
parameterNotSetToast();
return;
}
dispatch(setCfgRescaleMultiplier(cfgRescaleMultiplier));
parameterSetToast();
},
[dispatch, parameterSetToast, parameterNotSetToast]
);
/** /**
* Recall model with toast * Recall model with toast
*/ */
@ -799,6 +816,7 @@ export const useRecallParameters = () => {
const { const {
cfg_scale, cfg_scale,
cfg_rescale_multiplier,
height, height,
model, model,
positive_prompt, positive_prompt,
@ -831,6 +849,10 @@ export const useRecallParameters = () => {
dispatch(setCfgScale(cfg_scale)); dispatch(setCfgScale(cfg_scale));
} }
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
if (isParameterModel(model)) { if (isParameterModel(model)) {
dispatch(modelSelected(model)); dispatch(modelSelected(model));
} }
@ -985,6 +1007,7 @@ export const useRecallParameters = () => {
recallSDXLNegativeStylePrompt, recallSDXLNegativeStylePrompt,
recallSeed, recallSeed,
recallCfgScale, recallCfgScale,
recallCfgRescaleMultiplier,
recallModel, recallModel,
recallScheduler, recallScheduler,
recallVaeModel, recallVaeModel,

View File

@ -24,6 +24,7 @@ import {
ParameterVAEModel, ParameterVAEModel,
ParameterWidth, ParameterWidth,
zParameterModel, zParameterModel,
ParameterCFGRescaleMultiplier,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
export interface GenerationState { export interface GenerationState {
@ -31,6 +32,7 @@ export interface GenerationState {
hrfStrength: ParameterStrength; hrfStrength: ParameterStrength;
hrfMethod: ParameterHRFMethod; hrfMethod: ParameterHRFMethod;
cfgScale: ParameterCFGScale; cfgScale: ParameterCFGScale;
cfgRescaleMultiplier: ParameterCFGRescaleMultiplier;
height: ParameterHeight; height: ParameterHeight;
img2imgStrength: ParameterStrength; img2imgStrength: ParameterStrength;
infillMethod: string; infillMethod: string;
@ -76,6 +78,7 @@ export const initialGenerationState: GenerationState = {
hrfEnabled: false, hrfEnabled: false,
hrfMethod: 'ESRGAN', hrfMethod: 'ESRGAN',
cfgScale: 7.5, cfgScale: 7.5,
cfgRescaleMultiplier: 0,
height: 512, height: 512,
img2imgStrength: 0.75, img2imgStrength: 0.75,
infillMethod: 'patchmatch', infillMethod: 'patchmatch',
@ -145,9 +148,15 @@ export const generationSlice = createSlice({
state.steps state.steps
); );
}, },
setCfgScale: (state, action: PayloadAction<number>) => { setCfgScale: (state, action: PayloadAction<ParameterCFGScale>) => {
state.cfgScale = action.payload; state.cfgScale = action.payload;
}, },
setCfgRescaleMultiplier: (
state,
action: PayloadAction<ParameterCFGRescaleMultiplier>
) => {
state.cfgRescaleMultiplier = action.payload;
},
setThreshold: (state, action: PayloadAction<number>) => { setThreshold: (state, action: PayloadAction<number>) => {
state.threshold = action.payload; state.threshold = action.payload;
}, },
@ -336,6 +345,7 @@ export const {
resetParametersState, resetParametersState,
resetSeed, resetSeed,
setCfgScale, setCfgScale,
setCfgRescaleMultiplier,
setWidth, setWidth,
setHeight, setHeight,
toggleSize, toggleSize,

View File

@ -77,6 +77,17 @@ export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale =>
zParameterCFGScale.safeParse(val).success; zParameterCFGScale.safeParse(val).success;
// #endregion // #endregion
// #region CFG Rescale Multiplier
export const zParameterCFGRescaleMultiplier = z.number().gte(0).lt(1);
export type ParameterCFGRescaleMultiplier = z.infer<
typeof zParameterCFGRescaleMultiplier
>;
export const isParameterCFGRescaleMultiplier = (
val: unknown
): val is ParameterCFGRescaleMultiplier =>
zParameterCFGRescaleMultiplier.safeParse(val).success;
// #endregion
// #region Scheduler // #region Scheduler
export const zParameterScheduler = zSchedulerField; export const zParameterScheduler = zSchedulerField;
export type ParameterScheduler = z.infer<typeof zParameterScheduler>; export type ParameterScheduler = z.infer<typeof zParameterScheduler>;

View File

@ -14,11 +14,11 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { VALID_LOG_LEVELS } from 'app/logging/logger'; import { VALID_LOG_LEVELS } from 'app/logging/logger';
import { LOCALSTORAGE_KEYS, LOCALSTORAGE_PREFIX } from 'app/store/constants';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { import {
consoleLogLevelChanged, consoleLogLevelChanged,
setEnableImageDebugging, setEnableImageDebugging,
@ -164,20 +164,14 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
shouldEnableInformationalPopovers, shouldEnableInformationalPopovers,
} = useAppSelector(selector); } = useAppSelector(selector);
const clearStorage = useClearStorage();
const handleClickResetWebUI = useCallback(() => { const handleClickResetWebUI = useCallback(() => {
// Only remove our keys clearStorage();
Object.keys(window.localStorage).forEach((key) => {
if (
LOCALSTORAGE_KEYS.includes(key) ||
key.startsWith(LOCALSTORAGE_PREFIX)
) {
localStorage.removeItem(key);
}
});
onSettingsModalClose(); onSettingsModalClose();
onRefreshModalOpen(); onRefreshModalOpen();
setInterval(() => setCountdown((prev) => prev - 1), 1000); setInterval(() => setCountdown((prev) => prev - 1), 1000);
}, [onSettingsModalClose, onRefreshModalOpen]); }, [clearStorage, onSettingsModalClose, onRefreshModalOpen]);
useEffect(() => { useEffect(() => {
if (countdown <= 0) { if (countdown <= 0) {

File diff suppressed because one or more lines are too long

View File

@ -4158,6 +4158,11 @@ i18next@^23.6.0:
dependencies: dependencies:
"@babel/runtime" "^7.22.5" "@babel/runtime" "^7.22.5"
idb-keyval@^6.2.1:
version "6.2.1"
resolved "https://registry.yarnpkg.com/idb-keyval/-/idb-keyval-6.2.1.tgz#94516d625346d16f56f3b33855da11bfded2db33"
integrity sha512-8Sb3veuYCyrZL+VBt9LJfZjLUPWVvqn8tG28VqYNFCo43KHcKuq+b4EiXGeuaLAQWL2YmyDgMp2aSpH9JHsEQg==
ieee754@^1.1.13: ieee754@^1.1.13:
version "1.2.1" version "1.2.1"
resolved "https://registry.yarnpkg.com/ieee754/-/ieee754-1.2.1.tgz#8eb7a10a63fff25d15a57b001586d177d1b0d352" resolved "https://registry.yarnpkg.com/ieee754/-/ieee754-1.2.1.tgz#8eb7a10a63fff25d15a57b001586d177d1b0d352"

View File

@ -54,7 +54,8 @@ dependencies = [
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids "invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions "matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model "mediapipe", # needed for "mediapipeface" controlnet model
"numpy", # Minimum numpy version of 1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal().
"numpy>=1.24.0",
"npyscreen", "npyscreen",
"omegaconf", "omegaconf",
"onnx", "onnx",

View File

@ -37,6 +37,14 @@ def build_dummy_sd15_unet_input(torch_device):
"unet_model_id": "runwayml/stable-diffusion-v1-5", "unet_model_id": "runwayml/stable-diffusion-v1-5",
"unet_model_name": "stable-diffusion-v1-5", "unet_model_name": "stable-diffusion-v1-5",
}, },
# SD1.5, IPAdapterFull
{
"ip_adapter_model_id": "InvokeAI/ip-adapter-full-face_sd15",
"ip_adapter_model_name": "ip-adapter-full-face_sd15",
"base_model": BaseModelType.StableDiffusion1,
"unet_model_id": "runwayml/stable-diffusion-v1-5",
"unet_model_name": "stable-diffusion-v1-5",
},
], ],
) )
@pytest.mark.slow @pytest.mark.slow

View File

@ -0,0 +1,57 @@
"""
Test interaction of logging with configuration system.
"""
import io
import logging
import re
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import LOG_FORMATTERS, InvokeAILogger
# test formatting
# Would prefer to use the capfd/capsys fixture here, but it is broken
# when used with the logging module: https://github.com/pytest-dev/pytest/issue
def test_formatting():
logger = InvokeAILogger.get_logger()
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(LOG_FORMATTERS["plain"]())
logger.addHandler(handler)
logger.info("test1")
output = stream.getvalue()
assert re.search(r"\[InvokeAI\]::INFO --> test1$", output)
handler.setFormatter(LOG_FORMATTERS["legacy"]())
logger.info("test2")
output = stream.getvalue()
assert re.search(r">> test2$", output)
# test independence of two loggers with different names
def test_independence():
logger1 = InvokeAILogger.get_logger()
logger2 = InvokeAILogger.get_logger("Test")
assert logger1.name == "InvokeAI"
assert logger2.name == "Test"
assert logger1.level == logging.INFO
assert logger2.level == logging.INFO
logger2.setLevel(logging.DEBUG)
assert logger1.level == logging.INFO
assert logger2.level == logging.DEBUG
# test that the logger is returned from two similar get_logger() calls
def test_retrieval():
logger1 = InvokeAILogger.get_logger()
logger2 = InvokeAILogger.get_logger()
logger3 = InvokeAILogger.get_logger("Test")
assert logger1 == logger2
assert logger1 != logger3
# test that the configuration is used to set the initial logging level
def test_config():
config = InvokeAIAppConfig(log_level="debug")
logger1 = InvokeAILogger.get_logger("DebugTest", config=config)
assert logger1.level == logging.DEBUG