mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into feat_compel_and
This commit is contained in:
commit
65feb92286
37
.gitignore
vendored
37
.gitignore
vendored
@ -1,23 +1,8 @@
|
|||||||
# ignore default image save location and model symbolic link
|
|
||||||
.idea/
|
.idea/
|
||||||
embeddings/
|
|
||||||
outputs/
|
|
||||||
models/ldm/stable-diffusion-v1/model.ckpt
|
|
||||||
**/restoration/codeformer/weights
|
|
||||||
|
|
||||||
# ignore user models config
|
|
||||||
configs/models.user.yaml
|
|
||||||
config/models.user.yml
|
|
||||||
invokeai.init
|
|
||||||
.version
|
|
||||||
.last_model
|
|
||||||
|
|
||||||
# ignore the Anaconda/Miniconda installer used while building Docker image
|
# ignore the Anaconda/Miniconda installer used while building Docker image
|
||||||
anaconda.sh
|
anaconda.sh
|
||||||
|
|
||||||
# ignore a directory which serves as a place for initial images
|
|
||||||
inputs/
|
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
@ -189,39 +174,17 @@ cython_debug/
|
|||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
|
|
||||||
src
|
|
||||||
**/__pycache__/
|
**/__pycache__/
|
||||||
outputs
|
|
||||||
|
|
||||||
# Logs and associated folders
|
|
||||||
# created from generated embeddings.
|
|
||||||
logs
|
|
||||||
testtube
|
|
||||||
checkpoints
|
|
||||||
# If it's a Mac
|
# If it's a Mac
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
invokeai/frontend/yarn.lock
|
|
||||||
invokeai/frontend/node_modules
|
|
||||||
|
|
||||||
# Let the frontend manage its own gitignore
|
# Let the frontend manage its own gitignore
|
||||||
!invokeai/frontend/web/*
|
!invokeai/frontend/web/*
|
||||||
|
|
||||||
# Scratch folder
|
# Scratch folder
|
||||||
.scratch/
|
.scratch/
|
||||||
.vscode/
|
.vscode/
|
||||||
gfpgan/
|
|
||||||
models/ldm/stable-diffusion-v1/*.sha256
|
|
||||||
|
|
||||||
|
|
||||||
# GFPGAN model files
|
|
||||||
gfpgan/
|
|
||||||
|
|
||||||
# config file (will be created by installer)
|
|
||||||
configs/models.yaml
|
|
||||||
|
|
||||||
# ignore initfile
|
|
||||||
.invokeai
|
|
||||||
|
|
||||||
# ignore environment.yml and requirements.txt
|
# ignore environment.yml and requirements.txt
|
||||||
# these are links to the real files in environments-and-requirements
|
# these are links to the real files in environments-and-requirements
|
||||||
|
@ -175,22 +175,27 @@ These configuration settings allow you to enable and disable various InvokeAI fe
|
|||||||
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
|
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
|
||||||
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
|
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
|
||||||
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
|
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
|
||||||
| `restore` | `true` | Activate the facial restoration features (DEPRECATED; restoration features will be removed in 3.0.0) |
|
|
||||||
|
|
||||||
### Memory/Performance
|
### Generation
|
||||||
|
|
||||||
These options tune InvokeAI's memory and performance characteristics.
|
These options tune InvokeAI's memory and performance characteristics.
|
||||||
|
|
||||||
| Setting | Default Value | Description |
|
| Setting | Default Value | Description |
|
||||||
|----------|----------------|--------------|
|
|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
| `always_use_cpu` | `false` | Use the CPU to generate images, even if a GPU is available |
|
|
||||||
| `free_gpu_mem` | `false` | Aggressively free up GPU memory after each operation; this will allow you to run in low-VRAM environments with some performance penalties |
|
|
||||||
| `max_cache_size` | `6` | Amount of CPU RAM (in GB) to reserve for caching models in memory; more cache allows you to keep models in memory and switch among them quickly |
|
|
||||||
| `max_vram_cache_size` | `2.75` | Amount of GPU VRAM (in GB) to reserve for caching models in VRAM; more cache speeds up generation but reduces the size of the images that can be generated. This can be set to zero to maximize the amount of memory available for generation. |
|
|
||||||
| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
|
|
||||||
| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
|
| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
|
||||||
| `xformers_enabled` | `true` | If the x-formers memory-efficient attention module is installed, activate it for better memory usage and generation speed|
|
| `attention_type` | `auto` | Select the type of attention to use. One of `auto`,`normal`,`xformers`,`sliced`, or `torch-sdp` |
|
||||||
| `tiled_decode` | `false` | If true, then during the VAE decoding phase the image will be decoded a section at a time, reducing memory consumption at the cost of a performance hit |
|
| `attention_slice_size` | `auto` | When "sliced" attention is selected, set the slice size. One of `auto`, `balanced`, `max` or the integers 1-8|
|
||||||
|
| `force_tiled_decode` | `false` | Force the VAE step to decode in tiles, reducing memory consumption at the cost of performance |
|
||||||
|
|
||||||
|
### Device
|
||||||
|
|
||||||
|
These options configure the generation execution device.
|
||||||
|
|
||||||
|
| Setting | Default Value | Description |
|
||||||
|
|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
|
| `device` | `auto` | Preferred execution device. One of `auto`, `cpu`, `cuda`, `cuda:1`, `mps`. `auto` will choose the device depending on the hardware platform and the installed torch capabilities. |
|
||||||
|
| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
|
||||||
|
|
||||||
|
|
||||||
### Paths
|
### Paths
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ async def get_version() -> AppVersion:
|
|||||||
|
|
||||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||||
async def get_config() -> AppConfig:
|
async def get_config() -> AppConfig:
|
||||||
infill_methods = ["tile"]
|
infill_methods = ["tile", "lama"]
|
||||||
if PatchMatch.patchmatch_available():
|
if PatchMatch.patchmatch_available():
|
||||||
infill_methods.append("patchmatch")
|
infill_methods.append("patchmatch")
|
||||||
|
|
||||||
|
@ -122,6 +122,7 @@ def custom_openapi():
|
|||||||
|
|
||||||
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
|
output_schemas = schema(output_types, ref_prefix="#/components/schemas/")
|
||||||
for schema_key, output_schema in output_schemas["definitions"].items():
|
for schema_key, output_schema in output_schemas["definitions"].items():
|
||||||
|
output_schema["class"] = "output"
|
||||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||||
|
|
||||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||||
@ -130,8 +131,8 @@ def custom_openapi():
|
|||||||
|
|
||||||
# Add Node Editor UI helper schemas
|
# Add Node Editor UI helper schemas
|
||||||
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
|
ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/")
|
||||||
for schema_key, output_schema in ui_config_schemas["definitions"].items():
|
for schema_key, ui_config_schema in ui_config_schemas["definitions"].items():
|
||||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
|
||||||
|
|
||||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||||
for invoker in all_invocations:
|
for invoker in all_invocations:
|
||||||
@ -140,8 +141,8 @@ def custom_openapi():
|
|||||||
output_type_title = output_type_titles[output_type.__name__]
|
output_type_title = output_type_titles[output_type.__name__]
|
||||||
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
invoker_schema = openapi_schema["components"]["schemas"][invoker_name]
|
||||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||||
|
|
||||||
invoker_schema["output"] = outputs_ref
|
invoker_schema["output"] = outputs_ref
|
||||||
|
invoker_schema["class"] = "invocation"
|
||||||
|
|
||||||
from invokeai.backend.model_management.models import get_model_config_enums
|
from invokeai.backend.model_management.models import get_model_config_enums
|
||||||
|
|
||||||
|
@ -71,6 +71,9 @@ class FieldDescriptions:
|
|||||||
safe_mode = "Whether or not to use safe mode"
|
safe_mode = "Whether or not to use safe mode"
|
||||||
scribble_mode = "Whether or not to use scribble mode"
|
scribble_mode = "Whether or not to use scribble mode"
|
||||||
scale_factor = "The factor by which to scale"
|
scale_factor = "The factor by which to scale"
|
||||||
|
blend_alpha = (
|
||||||
|
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
|
||||||
|
)
|
||||||
num_1 = "The first number"
|
num_1 = "The first number"
|
||||||
num_2 = "The second number"
|
num_2 = "The second number"
|
||||||
mask = "The mask to use for the operation"
|
mask = "The mask to use for the operation"
|
||||||
@ -140,6 +143,7 @@ class UIType(str, Enum):
|
|||||||
# region Misc
|
# region Misc
|
||||||
FilePath = "FilePath"
|
FilePath = "FilePath"
|
||||||
Enum = "enum"
|
Enum = "enum"
|
||||||
|
Scheduler = "Scheduler"
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
@ -166,6 +170,7 @@ class _InputField(BaseModel):
|
|||||||
ui_hidden: bool
|
ui_hidden: bool
|
||||||
ui_type: Optional[UIType]
|
ui_type: Optional[UIType]
|
||||||
ui_component: Optional[UIComponent]
|
ui_component: Optional[UIComponent]
|
||||||
|
ui_order: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
class _OutputField(BaseModel):
|
class _OutputField(BaseModel):
|
||||||
@ -178,6 +183,7 @@ class _OutputField(BaseModel):
|
|||||||
|
|
||||||
ui_hidden: bool
|
ui_hidden: bool
|
||||||
ui_type: Optional[UIType]
|
ui_type: Optional[UIType]
|
||||||
|
ui_order: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
def InputField(
|
def InputField(
|
||||||
@ -211,6 +217,7 @@ def InputField(
|
|||||||
ui_type: Optional[UIType] = None,
|
ui_type: Optional[UIType] = None,
|
||||||
ui_component: Optional[UIComponent] = None,
|
ui_component: Optional[UIComponent] = None,
|
||||||
ui_hidden: bool = False,
|
ui_hidden: bool = False,
|
||||||
|
ui_order: Optional[int] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
@ -269,6 +276,7 @@ def InputField(
|
|||||||
ui_type=ui_type,
|
ui_type=ui_type,
|
||||||
ui_component=ui_component,
|
ui_component=ui_component,
|
||||||
ui_hidden=ui_hidden,
|
ui_hidden=ui_hidden,
|
||||||
|
ui_order=ui_order,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -302,6 +310,7 @@ def OutputField(
|
|||||||
repr: bool = True,
|
repr: bool = True,
|
||||||
ui_type: Optional[UIType] = None,
|
ui_type: Optional[UIType] = None,
|
||||||
ui_hidden: bool = False,
|
ui_hidden: bool = False,
|
||||||
|
ui_order: Optional[int] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
@ -348,6 +357,7 @@ def OutputField(
|
|||||||
repr=repr,
|
repr=repr,
|
||||||
ui_type=ui_type,
|
ui_type=ui_type,
|
||||||
ui_hidden=ui_hidden,
|
ui_hidden=ui_hidden,
|
||||||
|
ui_order=ui_order,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -376,7 +386,7 @@ class BaseInvocationOutput(BaseModel):
|
|||||||
"""Base class for all invocation outputs"""
|
"""Base class for all invocation outputs"""
|
||||||
|
|
||||||
# All outputs must include a type name like this:
|
# All outputs must include a type name like this:
|
||||||
# type: Literal['your_output_name']
|
# type: Literal['your_output_name'] # noqa f821
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_all_subclasses_tuple(cls):
|
def get_all_subclasses_tuple(cls):
|
||||||
@ -389,6 +399,13 @@ class BaseInvocationOutput(BaseModel):
|
|||||||
toprocess.extend(next_subclasses)
|
toprocess.extend(next_subclasses)
|
||||||
return tuple(subclasses)
|
return tuple(subclasses)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
@staticmethod
|
||||||
|
def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
|
if "required" not in schema or not isinstance(schema["required"], list):
|
||||||
|
schema["required"] = list()
|
||||||
|
schema["required"].extend(["type"])
|
||||||
|
|
||||||
|
|
||||||
class RequiredConnectionException(Exception):
|
class RequiredConnectionException(Exception):
|
||||||
"""Raised when an field which requires a connection did not receive a value."""
|
"""Raised when an field which requires a connection did not receive a value."""
|
||||||
@ -410,7 +427,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# All invocations must include a type name like this:
|
# All invocations must include a type name like this:
|
||||||
# type: Literal['your_output_name']
|
# type: Literal['your_output_name'] # noqa f821
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_all_subclasses(cls):
|
def get_all_subclasses(cls):
|
||||||
@ -449,6 +466,9 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
schema["title"] = uiconfig.title
|
schema["title"] = uiconfig.title
|
||||||
if uiconfig and hasattr(uiconfig, "tags"):
|
if uiconfig and hasattr(uiconfig, "tags"):
|
||||||
schema["tags"] = uiconfig.tags
|
schema["tags"] = uiconfig.tags
|
||||||
|
if "required" not in schema or not isinstance(schema["required"], list):
|
||||||
|
schema["required"] = list()
|
||||||
|
schema["required"].extend(["type", "id"])
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||||
@ -485,7 +505,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
raise MissingInputException(self.__fields__["type"].default, field_name)
|
raise MissingInputException(self.__fields__["type"].default, field_name)
|
||||||
return self.invoke(context)
|
return self.invoke(context)
|
||||||
|
|
||||||
id: str = InputField(description="The id of this node. Must be unique among all nodes.")
|
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
||||||
is_intermediate: bool = InputField(
|
is_intermediate: bool = InputField(
|
||||||
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
|
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
|
||||||
)
|
)
|
||||||
|
@ -232,7 +232,7 @@ class SDXLPromptInvocationBase:
|
|||||||
dtype_for_device_getter=torch_dtype,
|
dtype_for_device_getter=torch_dtype,
|
||||||
truncate_long_prompts=False, # TODO:
|
truncate_long_prompts=False, # TODO:
|
||||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||||
requires_pooled=True,
|
requires_pooled=get_pooled,
|
||||||
)
|
)
|
||||||
|
|
||||||
conjunction = Compel.parse_prompt_string(prompt)
|
conjunction = Compel.parse_prompt_string(prompt)
|
||||||
|
@ -8,7 +8,7 @@ import numpy
|
|||||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
||||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||||
|
|
||||||
@ -41,6 +41,39 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Blank Image")
|
||||||
|
@tags("image")
|
||||||
|
class BlankImageInvocation(BaseInvocation):
|
||||||
|
"""Creates a blank image and forwards it to the pipeline"""
|
||||||
|
|
||||||
|
# Metadata
|
||||||
|
type: Literal["blank_image"] = "blank_image"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
width: int = InputField(default=512, description="The width of the image")
|
||||||
|
height: int = InputField(default=512, description="The height of the image")
|
||||||
|
mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image")
|
||||||
|
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple())
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=image,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("Crop Image")
|
@title("Crop Image")
|
||||||
@tags("image", "crop")
|
@tags("image", "crop")
|
||||||
class ImageCropInvocation(BaseInvocation):
|
class ImageCropInvocation(BaseInvocation):
|
||||||
|
@ -1,23 +1,25 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
|
|
||||||
|
import math
|
||||||
from typing import Literal, Optional, get_args
|
from typing import Literal, Optional, get_args
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput, ColorField
|
|
||||||
|
|
||||||
|
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
from invokeai.backend.image_util.lama import LaMA
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
|
|
||||||
from ..models.image import ImageCategory, ResourceOrigin
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title
|
||||||
|
|
||||||
|
|
||||||
def infill_methods() -> list[str]:
|
def infill_methods() -> list[str]:
|
||||||
methods = [
|
methods = [
|
||||||
"tile",
|
"tile",
|
||||||
"solid",
|
"solid",
|
||||||
|
"lama",
|
||||||
]
|
]
|
||||||
if PatchMatch.patchmatch_available():
|
if PatchMatch.patchmatch_available():
|
||||||
methods.insert(0, "patchmatch")
|
methods.insert(0, "patchmatch")
|
||||||
@ -28,6 +30,11 @@ INFILL_METHODS = Literal[tuple(infill_methods())]
|
|||||||
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||||
|
|
||||||
|
|
||||||
|
def infill_lama(im: Image.Image) -> Image.Image:
|
||||||
|
lama = LaMA()
|
||||||
|
return lama(im)
|
||||||
|
|
||||||
|
|
||||||
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||||
if im.mode != "RGBA":
|
if im.mode != "RGBA":
|
||||||
return im
|
return im
|
||||||
@ -90,7 +97,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
|||||||
return im
|
return im
|
||||||
|
|
||||||
# Find all invalid tiles and replace with a random valid tile
|
# Find all invalid tiles and replace with a random valid tile
|
||||||
replace_count = (tiles_mask is False).sum()
|
replace_count = (tiles_mask == False).sum() # noqa: E712
|
||||||
rng = np.random.default_rng(seed=seed)
|
rng = np.random.default_rng(seed=seed)
|
||||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
|
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
|
||||||
|
|
||||||
@ -218,3 +225,34 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@title("LaMa Infill")
|
||||||
|
@tags("image", "inpaint")
|
||||||
|
class LaMaInfillInvocation(BaseInvocation):
|
||||||
|
"""Infills transparent areas of an image using the LaMa model"""
|
||||||
|
|
||||||
|
type: Literal["infill_lama"] = "infill_lama"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: ImageField = InputField(description="The image to infill")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
|
infilled = infill_lama(image.copy())
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=infilled,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
@ -4,6 +4,7 @@ from contextlib import ExitStack
|
|||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
@ -106,24 +107,28 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
positive_conditioning: ConditioningField = InputField(
|
positive_conditioning: ConditioningField = InputField(
|
||||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
|
||||||
)
|
)
|
||||||
negative_conditioning: ConditioningField = InputField(
|
negative_conditioning: ConditioningField = InputField(
|
||||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
||||||
)
|
)
|
||||||
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection)
|
noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection, ui_order=3)
|
||||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||||
cfg_scale: Union[float, List[float]] = InputField(
|
cfg_scale: Union[float, List[float]] = InputField(
|
||||||
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float
|
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type=UIType.Float, title="CFG Scale"
|
||||||
)
|
)
|
||||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||||
scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler)
|
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection)
|
default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler
|
||||||
control: Union[ControlField, list[ControlField]] = InputField(
|
)
|
||||||
default=None, description=FieldDescriptions.control, input=Input.Connection
|
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
|
||||||
|
control: Union[ControlField, list[ControlField]] = InputField(
|
||||||
|
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
|
||||||
|
)
|
||||||
|
latents: Optional[LatentsField] = InputField(
|
||||||
|
description=FieldDescriptions.latents, input=Input.Connection, ui_order=4
|
||||||
)
|
)
|
||||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
|
||||||
mask: Optional[ImageField] = InputField(
|
mask: Optional[ImageField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description=FieldDescriptions.mask,
|
description=FieldDescriptions.mask,
|
||||||
@ -453,7 +458,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@title("Latents to Image")
|
@title("Latents to Image")
|
||||||
@tags("latents", "image", "vae")
|
@tags("latents", "image", "vae", "l2i")
|
||||||
class LatentsToImageInvocation(BaseInvocation):
|
class LatentsToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
|
|
||||||
@ -641,7 +646,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@title("Image to Latents")
|
@title("Image to Latents")
|
||||||
@tags("latents", "image", "vae")
|
@tags("latents", "image", "vae", "i2l")
|
||||||
class ImageToLatentsInvocation(BaseInvocation):
|
class ImageToLatentsInvocation(BaseInvocation):
|
||||||
"""Encodes an image into latents."""
|
"""Encodes an image into latents."""
|
||||||
|
|
||||||
@ -720,3 +725,81 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
latents = latents.to("cpu")
|
latents = latents.to("cpu")
|
||||||
context.services.latents.save(name, latents)
|
context.services.latents.save(name, latents)
|
||||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||||
|
|
||||||
|
|
||||||
|
@title("Blend Latents")
|
||||||
|
@tags("latents", "blend")
|
||||||
|
class BlendLatentsInvocation(BaseInvocation):
|
||||||
|
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||||
|
|
||||||
|
type: Literal["lblend"] = "lblend"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
latents_a: LatentsField = InputField(
|
||||||
|
description=FieldDescriptions.latents,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
latents_b: LatentsField = InputField(
|
||||||
|
description=FieldDescriptions.latents,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
latents_a = context.services.latents.get(self.latents_a.latents_name)
|
||||||
|
latents_b = context.services.latents.get(self.latents_b.latents_name)
|
||||||
|
|
||||||
|
if latents_a.shape != latents_b.shape:
|
||||||
|
raise "Latents to blend must be the same size."
|
||||||
|
|
||||||
|
# TODO:
|
||||||
|
device = choose_torch_device()
|
||||||
|
|
||||||
|
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
|
||||||
|
"""
|
||||||
|
Spherical linear interpolation
|
||||||
|
Args:
|
||||||
|
t (float/np.ndarray): Float value between 0.0 and 1.0
|
||||||
|
v0 (np.ndarray): Starting vector
|
||||||
|
v1 (np.ndarray): Final vector
|
||||||
|
DOT_THRESHOLD (float): Threshold for considering the two vectors as
|
||||||
|
colineal. Not recommended to alter this.
|
||||||
|
Returns:
|
||||||
|
v2 (np.ndarray): Interpolation vector between v0 and v1
|
||||||
|
"""
|
||||||
|
inputs_are_torch = False
|
||||||
|
if not isinstance(v0, np.ndarray):
|
||||||
|
inputs_are_torch = True
|
||||||
|
v0 = v0.detach().cpu().numpy()
|
||||||
|
if not isinstance(v1, np.ndarray):
|
||||||
|
inputs_are_torch = True
|
||||||
|
v1 = v1.detach().cpu().numpy()
|
||||||
|
|
||||||
|
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||||
|
if np.abs(dot) > DOT_THRESHOLD:
|
||||||
|
v2 = (1 - t) * v0 + t * v1
|
||||||
|
else:
|
||||||
|
theta_0 = np.arccos(dot)
|
||||||
|
sin_theta_0 = np.sin(theta_0)
|
||||||
|
theta_t = theta_0 * t
|
||||||
|
sin_theta_t = np.sin(theta_t)
|
||||||
|
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||||
|
s1 = sin_theta_t / sin_theta_0
|
||||||
|
v2 = s0 * v0 + s1 * v1
|
||||||
|
|
||||||
|
if inputs_are_torch:
|
||||||
|
v2 = torch.from_numpy(v2).to(device)
|
||||||
|
|
||||||
|
return v2
|
||||||
|
|
||||||
|
# blend
|
||||||
|
blended_latents = slerp(self.alpha, latents_a, latents_b)
|
||||||
|
|
||||||
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
|
blended_latents = blended_latents.to("cpu")
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
|
# context.services.latents.set(name, resized_latents)
|
||||||
|
context.services.latents.save(name, blended_latents)
|
||||||
|
return build_latents_output(latents_name=name, latents=blended_latents)
|
||||||
|
@ -21,7 +21,7 @@ class AddInvocation(BaseInvocation):
|
|||||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||||
return IntegerOutput(a=self.a + self.b)
|
return IntegerOutput(value=self.a + self.b)
|
||||||
|
|
||||||
|
|
||||||
@title("Subtract Integers")
|
@title("Subtract Integers")
|
||||||
@ -36,7 +36,7 @@ class SubtractInvocation(BaseInvocation):
|
|||||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||||
return IntegerOutput(a=self.a - self.b)
|
return IntegerOutput(value=self.a - self.b)
|
||||||
|
|
||||||
|
|
||||||
@title("Multiply Integers")
|
@title("Multiply Integers")
|
||||||
@ -51,7 +51,7 @@ class MultiplyInvocation(BaseInvocation):
|
|||||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||||
return IntegerOutput(a=self.a * self.b)
|
return IntegerOutput(value=self.a * self.b)
|
||||||
|
|
||||||
|
|
||||||
@title("Divide Integers")
|
@title("Divide Integers")
|
||||||
@ -66,7 +66,7 @@ class DivideInvocation(BaseInvocation):
|
|||||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||||
return IntegerOutput(a=int(self.a / self.b))
|
return IntegerOutput(value=int(self.a / self.b))
|
||||||
|
|
||||||
|
|
||||||
@title("Random Integer")
|
@title("Random Integer")
|
||||||
@ -81,4 +81,4 @@ class RandomIntInvocation(BaseInvocation):
|
|||||||
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||||
return IntegerOutput(a=np.random.randint(self.low, self.high))
|
return IntegerOutput(value=np.random.randint(self.low, self.high))
|
||||||
|
@ -72,7 +72,7 @@ class LoRAModelField(BaseModel):
|
|||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
@title("Main Model Loader")
|
@title("Main Model")
|
||||||
@tags("model")
|
@tags("model")
|
||||||
class MainModelLoaderInvocation(BaseInvocation):
|
class MainModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
@ -179,7 +179,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
@title("LoRA Loader")
|
@title("LoRA")
|
||||||
@tags("lora", "model")
|
@tags("lora", "model")
|
||||||
class LoraLoaderInvocation(BaseInvocation):
|
class LoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
@ -257,7 +257,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
@title("SDXL LoRA Loader")
|
@title("SDXL LoRA")
|
||||||
@tags("sdxl", "lora", "model")
|
@tags("sdxl", "lora", "model")
|
||||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
@ -356,7 +356,7 @@ class VaeLoaderOutput(BaseInvocationOutput):
|
|||||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@title("VAE Loader")
|
@title("VAE")
|
||||||
@tags("vae", "model")
|
@tags("vae", "model")
|
||||||
class VaeLoaderInvocation(BaseInvocation):
|
class VaeLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||||
|
@ -169,7 +169,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
ui_type=UIType.Float,
|
ui_type=UIType.Float,
|
||||||
)
|
)
|
||||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||||
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
|
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
|
||||||
)
|
)
|
||||||
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
|
||||||
unet: UNetField = InputField(
|
unet: UNetField = InputField(
|
||||||
@ -406,7 +406,7 @@ class OnnxModelField(BaseModel):
|
|||||||
model_type: ModelType = Field(description="Model Type")
|
model_type: ModelType = Field(description="Model Type")
|
||||||
|
|
||||||
|
|
||||||
@title("ONNX Model Loader")
|
@title("ONNX Main Model")
|
||||||
@tags("onnx", "model")
|
@tags("onnx", "model")
|
||||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
from typing import Literal, Optional, Tuple
|
from typing import Literal, Optional, Tuple
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
import torch
|
import torch
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -33,7 +33,7 @@ class BooleanOutput(BaseInvocationOutput):
|
|||||||
"""Base class for nodes that output a single boolean"""
|
"""Base class for nodes that output a single boolean"""
|
||||||
|
|
||||||
type: Literal["boolean_output"] = "boolean_output"
|
type: Literal["boolean_output"] = "boolean_output"
|
||||||
a: bool = OutputField(description="The output boolean")
|
value: bool = OutputField(description="The output boolean")
|
||||||
|
|
||||||
|
|
||||||
class BooleanCollectionOutput(BaseInvocationOutput):
|
class BooleanCollectionOutput(BaseInvocationOutput):
|
||||||
@ -42,9 +42,7 @@ class BooleanCollectionOutput(BaseInvocationOutput):
|
|||||||
type: Literal["boolean_collection_output"] = "boolean_collection_output"
|
type: Literal["boolean_collection_output"] = "boolean_collection_output"
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
collection: list[bool] = OutputField(
|
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
|
||||||
default_factory=list, description="The output boolean collection", ui_type=UIType.BooleanCollection
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@title("Boolean Primitive")
|
@title("Boolean Primitive")
|
||||||
@ -55,10 +53,10 @@ class BooleanInvocation(BaseInvocation):
|
|||||||
type: Literal["boolean"] = "boolean"
|
type: Literal["boolean"] = "boolean"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
a: bool = InputField(default=False, description="The boolean value")
|
value: bool = InputField(default=False, description="The boolean value")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> BooleanOutput:
|
def invoke(self, context: InvocationContext) -> BooleanOutput:
|
||||||
return BooleanOutput(a=self.a)
|
return BooleanOutput(value=self.value)
|
||||||
|
|
||||||
|
|
||||||
@title("Boolean Primitive Collection")
|
@title("Boolean Primitive Collection")
|
||||||
@ -70,7 +68,7 @@ class BooleanCollectionInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
collection: list[bool] = InputField(
|
collection: list[bool] = InputField(
|
||||||
default=False, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
||||||
@ -86,7 +84,7 @@ class IntegerOutput(BaseInvocationOutput):
|
|||||||
"""Base class for nodes that output a single integer"""
|
"""Base class for nodes that output a single integer"""
|
||||||
|
|
||||||
type: Literal["integer_output"] = "integer_output"
|
type: Literal["integer_output"] = "integer_output"
|
||||||
a: int = OutputField(description="The output integer")
|
value: int = OutputField(description="The output integer")
|
||||||
|
|
||||||
|
|
||||||
class IntegerCollectionOutput(BaseInvocationOutput):
|
class IntegerCollectionOutput(BaseInvocationOutput):
|
||||||
@ -95,9 +93,7 @@ class IntegerCollectionOutput(BaseInvocationOutput):
|
|||||||
type: Literal["integer_collection_output"] = "integer_collection_output"
|
type: Literal["integer_collection_output"] = "integer_collection_output"
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
collection: list[int] = OutputField(
|
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
|
||||||
default_factory=list, description="The int collection", ui_type=UIType.IntegerCollection
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@title("Integer Primitive")
|
@title("Integer Primitive")
|
||||||
@ -108,10 +104,10 @@ class IntegerInvocation(BaseInvocation):
|
|||||||
type: Literal["integer"] = "integer"
|
type: Literal["integer"] = "integer"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
a: int = InputField(default=0, description="The integer value")
|
value: int = InputField(default=0, description="The integer value")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||||
return IntegerOutput(a=self.a)
|
return IntegerOutput(value=self.value)
|
||||||
|
|
||||||
|
|
||||||
@title("Integer Primitive Collection")
|
@title("Integer Primitive Collection")
|
||||||
@ -139,7 +135,7 @@ class FloatOutput(BaseInvocationOutput):
|
|||||||
"""Base class for nodes that output a single float"""
|
"""Base class for nodes that output a single float"""
|
||||||
|
|
||||||
type: Literal["float_output"] = "float_output"
|
type: Literal["float_output"] = "float_output"
|
||||||
a: float = OutputField(description="The output float")
|
value: float = OutputField(description="The output float")
|
||||||
|
|
||||||
|
|
||||||
class FloatCollectionOutput(BaseInvocationOutput):
|
class FloatCollectionOutput(BaseInvocationOutput):
|
||||||
@ -148,9 +144,7 @@ class FloatCollectionOutput(BaseInvocationOutput):
|
|||||||
type: Literal["float_collection_output"] = "float_collection_output"
|
type: Literal["float_collection_output"] = "float_collection_output"
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
collection: list[float] = OutputField(
|
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
|
||||||
default_factory=list, description="The float collection", ui_type=UIType.FloatCollection
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@title("Float Primitive")
|
@title("Float Primitive")
|
||||||
@ -161,10 +155,10 @@ class FloatInvocation(BaseInvocation):
|
|||||||
type: Literal["float"] = "float"
|
type: Literal["float"] = "float"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
param: float = InputField(default=0.0, description="The float value")
|
value: float = InputField(default=0.0, description="The float value")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||||
return FloatOutput(a=self.param)
|
return FloatOutput(value=self.value)
|
||||||
|
|
||||||
|
|
||||||
@title("Float Primitive Collection")
|
@title("Float Primitive Collection")
|
||||||
@ -176,7 +170,7 @@ class FloatCollectionInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
collection: list[float] = InputField(
|
collection: list[float] = InputField(
|
||||||
default=0, description="The collection of float values", ui_type=UIType.FloatCollection
|
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
@ -192,7 +186,7 @@ class StringOutput(BaseInvocationOutput):
|
|||||||
"""Base class for nodes that output a single string"""
|
"""Base class for nodes that output a single string"""
|
||||||
|
|
||||||
type: Literal["string_output"] = "string_output"
|
type: Literal["string_output"] = "string_output"
|
||||||
text: str = OutputField(description="The output string")
|
value: str = OutputField(description="The output string")
|
||||||
|
|
||||||
|
|
||||||
class StringCollectionOutput(BaseInvocationOutput):
|
class StringCollectionOutput(BaseInvocationOutput):
|
||||||
@ -201,9 +195,7 @@ class StringCollectionOutput(BaseInvocationOutput):
|
|||||||
type: Literal["string_collection_output"] = "string_collection_output"
|
type: Literal["string_collection_output"] = "string_collection_output"
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
collection: list[str] = OutputField(
|
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
|
||||||
default_factory=list, description="The output strings", ui_type=UIType.StringCollection
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@title("String Primitive")
|
@title("String Primitive")
|
||||||
@ -214,10 +206,10 @@ class StringInvocation(BaseInvocation):
|
|||||||
type: Literal["string"] = "string"
|
type: Literal["string"] = "string"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
text: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||||
return StringOutput(text=self.text)
|
return StringOutput(value=self.value)
|
||||||
|
|
||||||
|
|
||||||
@title("String Primitive Collection")
|
@title("String Primitive Collection")
|
||||||
@ -229,7 +221,7 @@ class StringCollectionInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
collection: list[str] = InputField(
|
collection: list[str] = InputField(
|
||||||
default=0, description="The collection of string values", ui_type=UIType.StringCollection
|
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||||
@ -262,9 +254,7 @@ class ImageCollectionOutput(BaseInvocationOutput):
|
|||||||
type: Literal["image_collection_output"] = "image_collection_output"
|
type: Literal["image_collection_output"] = "image_collection_output"
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
collection: list[ImageField] = OutputField(
|
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
|
||||||
default_factory=list, description="The output images", ui_type=UIType.ImageCollection
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@title("Image Primitive")
|
@title("Image Primitive")
|
||||||
@ -334,7 +324,6 @@ class LatentsCollectionOutput(BaseInvocationOutput):
|
|||||||
type: Literal["latents_collection_output"] = "latents_collection_output"
|
type: Literal["latents_collection_output"] = "latents_collection_output"
|
||||||
|
|
||||||
collection: list[LatentsField] = OutputField(
|
collection: list[LatentsField] = OutputField(
|
||||||
default_factory=list,
|
|
||||||
description=FieldDescriptions.latents,
|
description=FieldDescriptions.latents,
|
||||||
ui_type=UIType.LatentsCollection,
|
ui_type=UIType.LatentsCollection,
|
||||||
)
|
)
|
||||||
@ -365,7 +354,7 @@ class LatentsCollectionInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
collection: list[LatentsField] = InputField(
|
collection: list[LatentsField] = InputField(
|
||||||
default=0, description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
||||||
@ -410,9 +399,7 @@ class ColorCollectionOutput(BaseInvocationOutput):
|
|||||||
type: Literal["color_collection_output"] = "color_collection_output"
|
type: Literal["color_collection_output"] = "color_collection_output"
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
collection: list[ColorField] = OutputField(
|
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
|
||||||
default_factory=list, description="The output colors", ui_type=UIType.ColorCollection
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@title("Color Primitive")
|
@title("Color Primitive")
|
||||||
@ -455,7 +442,6 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
collection: list[ConditioningField] = OutputField(
|
collection: list[ConditioningField] = OutputField(
|
||||||
default_factory=list,
|
|
||||||
description="The output conditioning tensors",
|
description="The output conditioning tensors",
|
||||||
ui_type=UIType.ConditioningCollection,
|
ui_type=UIType.ConditioningCollection,
|
||||||
)
|
)
|
||||||
|
@ -37,7 +37,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
|||||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@title("SDXL Main Model Loader")
|
@title("SDXL Main Model")
|
||||||
@tags("model", "sdxl")
|
@tags("model", "sdxl")
|
||||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl base model, outputting its submodels."""
|
"""Loads an sdxl base model, outputting its submodels."""
|
||||||
@ -122,7 +122,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@title("SDXL Refiner Model Loader")
|
@title("SDXL Refiner Model")
|
||||||
@tags("model", "sdxl", "refiner")
|
@tags("model", "sdxl", "refiner")
|
||||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||||
|
8
invokeai/app/services/config/__init__.py
Normal file
8
invokeai/app/services/config/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
"""
|
||||||
|
Init file for InvokeAI configure package
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .invokeai_config import ( # noqa F401
|
||||||
|
InvokeAIAppConfig,
|
||||||
|
get_invokeai_config,
|
||||||
|
)
|
239
invokeai/app/services/config/base.py
Normal file
239
invokeai/app/services/config/base.py
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||||
|
|
||||||
|
"""
|
||||||
|
Base class for the InvokeAI configuration system.
|
||||||
|
It defines a type of pydantic BaseSettings object that
|
||||||
|
is able to read and write from an omegaconf-based config file,
|
||||||
|
with overriding of settings from environment variables and/or
|
||||||
|
the command line.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import pydoc
|
||||||
|
import sys
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from omegaconf import OmegaConf, DictConfig, ListConfig
|
||||||
|
from pathlib import Path
|
||||||
|
from pydantic import BaseSettings
|
||||||
|
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
|
||||||
|
|
||||||
|
|
||||||
|
class PagingArgumentParser(argparse.ArgumentParser):
|
||||||
|
"""
|
||||||
|
A custom ArgumentParser that uses pydoc to page its output.
|
||||||
|
It also supports reading defaults from an init file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def print_help(self, file=None):
|
||||||
|
text = self.format_help()
|
||||||
|
pydoc.pager(text)
|
||||||
|
|
||||||
|
|
||||||
|
class InvokeAISettings(BaseSettings):
|
||||||
|
"""
|
||||||
|
Runtime configuration settings in which default values are
|
||||||
|
read from an omegaconf .yaml file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
initconf: ClassVar[DictConfig] = None
|
||||||
|
argparse_groups: ClassVar[Dict] = {}
|
||||||
|
|
||||||
|
def parse_args(self, argv: list = sys.argv[1:]):
|
||||||
|
parser = self.get_parser()
|
||||||
|
opt = parser.parse_args(argv)
|
||||||
|
for name in self.__fields__:
|
||||||
|
if name not in self._excluded():
|
||||||
|
value = getattr(opt, name)
|
||||||
|
if isinstance(value, ListConfig):
|
||||||
|
value = list(value)
|
||||||
|
elif isinstance(value, DictConfig):
|
||||||
|
value = dict(value)
|
||||||
|
setattr(self, name, value)
|
||||||
|
|
||||||
|
def to_yaml(self) -> str:
|
||||||
|
"""
|
||||||
|
Return a YAML string representing our settings. This can be used
|
||||||
|
as the contents of `invokeai.yaml` to restore settings later.
|
||||||
|
"""
|
||||||
|
cls = self.__class__
|
||||||
|
type = get_args(get_type_hints(cls)["type"])[0]
|
||||||
|
field_dict = dict({type: dict()})
|
||||||
|
for name, field in self.__fields__.items():
|
||||||
|
if name in cls._excluded_from_yaml():
|
||||||
|
continue
|
||||||
|
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||||
|
value = getattr(self, name)
|
||||||
|
if category not in field_dict[type]:
|
||||||
|
field_dict[type][category] = dict()
|
||||||
|
# keep paths as strings to make it easier to read
|
||||||
|
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
||||||
|
conf = OmegaConf.create(field_dict)
|
||||||
|
return OmegaConf.to_yaml(conf)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_parser_arguments(cls, parser):
|
||||||
|
if "type" in get_type_hints(cls):
|
||||||
|
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||||
|
else:
|
||||||
|
settings_stanza = "Uncategorized"
|
||||||
|
|
||||||
|
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
|
||||||
|
|
||||||
|
initconf = (
|
||||||
|
cls.initconf.get(settings_stanza)
|
||||||
|
if cls.initconf and settings_stanza in cls.initconf
|
||||||
|
else OmegaConf.create()
|
||||||
|
)
|
||||||
|
|
||||||
|
# create an upcase version of the environment in
|
||||||
|
# order to achieve case-insensitive environment
|
||||||
|
# variables (the way Windows does)
|
||||||
|
upcase_environ = dict()
|
||||||
|
for key, value in os.environ.items():
|
||||||
|
upcase_environ[key.upper()] = value
|
||||||
|
|
||||||
|
fields = cls.__fields__
|
||||||
|
cls.argparse_groups = {}
|
||||||
|
|
||||||
|
for name, field in fields.items():
|
||||||
|
if name not in cls._excluded():
|
||||||
|
current_default = field.default
|
||||||
|
|
||||||
|
category = field.field_info.extra.get("category", "Uncategorized")
|
||||||
|
env_name = env_prefix + "_" + name
|
||||||
|
if category in initconf and name in initconf.get(category):
|
||||||
|
field.default = initconf.get(category).get(name)
|
||||||
|
if env_name.upper() in upcase_environ:
|
||||||
|
field.default = upcase_environ[env_name.upper()]
|
||||||
|
cls.add_field_argument(parser, name, field)
|
||||||
|
|
||||||
|
field.default = current_default
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def cmd_name(self, command_field: str = "type") -> str:
|
||||||
|
hints = get_type_hints(self)
|
||||||
|
if command_field in hints:
|
||||||
|
return get_args(hints[command_field])[0]
|
||||||
|
else:
|
||||||
|
return "Uncategorized"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_parser(cls) -> ArgumentParser:
|
||||||
|
parser = PagingArgumentParser(
|
||||||
|
prog=cls.cmd_name(),
|
||||||
|
description=cls.__doc__,
|
||||||
|
)
|
||||||
|
cls.add_parser_arguments(parser)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_subparser(cls, parser: argparse.ArgumentParser):
|
||||||
|
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _excluded(self) -> List[str]:
|
||||||
|
# internal fields that shouldn't be exposed as command line options
|
||||||
|
return ["type", "initconf"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _excluded_from_yaml(self) -> List[str]:
|
||||||
|
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||||
|
return [
|
||||||
|
"type",
|
||||||
|
"initconf",
|
||||||
|
"version",
|
||||||
|
"from_file",
|
||||||
|
"model",
|
||||||
|
"root",
|
||||||
|
"max_cache_size",
|
||||||
|
"max_vram_cache_size",
|
||||||
|
"always_use_cpu",
|
||||||
|
"free_gpu_mem",
|
||||||
|
"xformers_enabled",
|
||||||
|
"tiled_decode",
|
||||||
|
]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file_encoding = "utf-8"
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
case_sensitive = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||||
|
field_type = get_type_hints(cls).get(name)
|
||||||
|
default = (
|
||||||
|
default_override
|
||||||
|
if default_override is not None
|
||||||
|
else field.default
|
||||||
|
if field.default_factory is None
|
||||||
|
else field.default_factory()
|
||||||
|
)
|
||||||
|
if category := field.field_info.extra.get("category"):
|
||||||
|
if category not in cls.argparse_groups:
|
||||||
|
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||||
|
argparse_group = cls.argparse_groups[category]
|
||||||
|
else:
|
||||||
|
argparse_group = command_parser
|
||||||
|
|
||||||
|
if get_origin(field_type) == Literal:
|
||||||
|
allowed_values = get_args(field.type_)
|
||||||
|
allowed_types = set()
|
||||||
|
for val in allowed_values:
|
||||||
|
allowed_types.add(type(val))
|
||||||
|
allowed_types_list = list(allowed_types)
|
||||||
|
field_type = allowed_types_list[0] if len(allowed_types) == 1 else int_or_float_or_str
|
||||||
|
|
||||||
|
argparse_group.add_argument(
|
||||||
|
f"--{name}",
|
||||||
|
dest=name,
|
||||||
|
type=field_type,
|
||||||
|
default=default,
|
||||||
|
choices=allowed_values,
|
||||||
|
help=field.field_info.description,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif get_origin(field_type) == Union:
|
||||||
|
argparse_group.add_argument(
|
||||||
|
f"--{name}",
|
||||||
|
dest=name,
|
||||||
|
type=int_or_float_or_str,
|
||||||
|
default=default,
|
||||||
|
help=field.field_info.description,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif get_origin(field_type) == list:
|
||||||
|
argparse_group.add_argument(
|
||||||
|
f"--{name}",
|
||||||
|
dest=name,
|
||||||
|
nargs="*",
|
||||||
|
type=field.type_,
|
||||||
|
default=default,
|
||||||
|
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||||
|
help=field.field_info.description,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
argparse_group.add_argument(
|
||||||
|
f"--{name}",
|
||||||
|
dest=name,
|
||||||
|
type=field.type_,
|
||||||
|
default=default,
|
||||||
|
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||||
|
help=field.field_info.description,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def int_or_float_or_str(value: str) -> Union[int, float, str]:
|
||||||
|
"""
|
||||||
|
Workaround for argparse type checking.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except Exception as e: # noqa F841
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
return float(value)
|
||||||
|
except Exception as e: # noqa F841
|
||||||
|
pass
|
||||||
|
return str(value)
|
@ -10,37 +10,49 @@ categories returned by `invokeai --help`. The file looks like this:
|
|||||||
[file: invokeai.yaml]
|
[file: invokeai.yaml]
|
||||||
|
|
||||||
InvokeAI:
|
InvokeAI:
|
||||||
Paths:
|
|
||||||
root: /home/lstein/invokeai-main
|
|
||||||
conf_path: configs/models.yaml
|
|
||||||
legacy_conf_dir: configs/stable-diffusion
|
|
||||||
outdir: outputs
|
|
||||||
autoimport_dir: null
|
|
||||||
Models:
|
|
||||||
model: stable-diffusion-1.5
|
|
||||||
embeddings: true
|
|
||||||
Memory/Performance:
|
|
||||||
xformers_enabled: false
|
|
||||||
sequential_guidance: false
|
|
||||||
precision: float16
|
|
||||||
max_cache_size: 6
|
|
||||||
max_vram_cache_size: 0.5
|
|
||||||
always_use_cpu: false
|
|
||||||
free_gpu_mem: false
|
|
||||||
Features:
|
|
||||||
esrgan: true
|
|
||||||
patchmatch: true
|
|
||||||
internet_available: true
|
|
||||||
log_tokenization: false
|
|
||||||
Web Server:
|
Web Server:
|
||||||
host: 127.0.0.1
|
host: 127.0.0.1
|
||||||
port: 8081
|
port: 9090
|
||||||
allow_origins: []
|
allow_origins: []
|
||||||
allow_credentials: true
|
allow_credentials: true
|
||||||
allow_methods:
|
allow_methods:
|
||||||
- '*'
|
- '*'
|
||||||
allow_headers:
|
allow_headers:
|
||||||
- '*'
|
- '*'
|
||||||
|
Features:
|
||||||
|
esrgan: true
|
||||||
|
internet_available: true
|
||||||
|
log_tokenization: false
|
||||||
|
patchmatch: true
|
||||||
|
ignore_missing_core_models: false
|
||||||
|
Paths:
|
||||||
|
autoimport_dir: autoimport
|
||||||
|
lora_dir: null
|
||||||
|
embedding_dir: null
|
||||||
|
controlnet_dir: null
|
||||||
|
conf_path: configs/models.yaml
|
||||||
|
models_dir: models
|
||||||
|
legacy_conf_dir: configs/stable-diffusion
|
||||||
|
db_dir: databases
|
||||||
|
outdir: /home/lstein/invokeai-main/outputs
|
||||||
|
use_memory_db: false
|
||||||
|
Logging:
|
||||||
|
log_handlers:
|
||||||
|
- console
|
||||||
|
log_format: plain
|
||||||
|
log_level: info
|
||||||
|
Model Cache:
|
||||||
|
ram: 13.5
|
||||||
|
vram: 0.25
|
||||||
|
lazy_offload: true
|
||||||
|
Device:
|
||||||
|
device: auto
|
||||||
|
precision: auto
|
||||||
|
Generation:
|
||||||
|
sequential_guidance: false
|
||||||
|
attention_type: xformers
|
||||||
|
attention_slice_size: auto
|
||||||
|
force_tiled_decode: false
|
||||||
|
|
||||||
The default name of the configuration file is `invokeai.yaml`, located
|
The default name of the configuration file is `invokeai.yaml`, located
|
||||||
in INVOKEAI_ROOT. You can replace supersede this by providing any
|
in INVOKEAI_ROOT. You can replace supersede this by providing any
|
||||||
@ -54,24 +66,23 @@ InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
|
|||||||
at initialization time. You may pass a list of strings in the optional
|
at initialization time. You may pass a list of strings in the optional
|
||||||
`argv` argument to use instead of the system argv:
|
`argv` argument to use instead of the system argv:
|
||||||
|
|
||||||
conf.parse_args(argv=['--xformers_enabled'])
|
conf.parse_args(argv=['--log_tokenization'])
|
||||||
|
|
||||||
It is also possible to set a value at initialization time. However, if
|
It is also possible to set a value at initialization time. However, if
|
||||||
you call parse_args() it may be overwritten.
|
you call parse_args() it may be overwritten.
|
||||||
|
|
||||||
conf = InvokeAIAppConfig(xformers_enabled=True)
|
conf = InvokeAIAppConfig(log_tokenization=True)
|
||||||
conf.parse_args(argv=['--no-xformers'])
|
conf.parse_args(argv=['--no-log_tokenization'])
|
||||||
conf.xformers_enabled
|
conf.log_tokenization
|
||||||
# False
|
# False
|
||||||
|
|
||||||
|
|
||||||
To avoid this, use `get_config()` to retrieve the application-wide
|
To avoid this, use `get_config()` to retrieve the application-wide
|
||||||
configuration object. This will retain any properties set at object
|
configuration object. This will retain any properties set at object
|
||||||
creation time:
|
creation time:
|
||||||
|
|
||||||
conf = InvokeAIAppConfig.get_config(xformers_enabled=True)
|
conf = InvokeAIAppConfig.get_config(log_tokenization=True)
|
||||||
conf.parse_args(argv=['--no-xformers'])
|
conf.parse_args(argv=['--no-log_tokenization'])
|
||||||
conf.xformers_enabled
|
conf.log_tokenization
|
||||||
# True
|
# True
|
||||||
|
|
||||||
Any setting can be overwritten by setting an environment variable of
|
Any setting can be overwritten by setting an environment variable of
|
||||||
@ -93,7 +104,7 @@ Typical usage at the top level file:
|
|||||||
# get global configuration and print its cache size
|
# get global configuration and print its cache size
|
||||||
conf = InvokeAIAppConfig.get_config()
|
conf = InvokeAIAppConfig.get_config()
|
||||||
conf.parse_args()
|
conf.parse_args()
|
||||||
print(conf.max_cache_size)
|
print(conf.ram_cache_size)
|
||||||
|
|
||||||
Typical usage in a backend module:
|
Typical usage in a backend module:
|
||||||
|
|
||||||
@ -101,8 +112,7 @@ Typical usage in a backend module:
|
|||||||
|
|
||||||
# get global configuration and print its cache size value
|
# get global configuration and print its cache size value
|
||||||
conf = InvokeAIAppConfig.get_config()
|
conf = InvokeAIAppConfig.get_config()
|
||||||
print(conf.max_cache_size)
|
print(conf.ram_cache_size)
|
||||||
|
|
||||||
|
|
||||||
Computed properties:
|
Computed properties:
|
||||||
|
|
||||||
@ -159,15 +169,13 @@ two configs are kept in separate sections of the config file:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import argparse
|
|
||||||
import pydoc
|
|
||||||
import os
|
import os
|
||||||
import sys
|
from omegaconf import OmegaConf, DictConfig
|
||||||
from argparse import ArgumentParser
|
|
||||||
from omegaconf import OmegaConf, DictConfig, ListConfig
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydantic import BaseSettings, Field, parse_obj_as
|
from pydantic import Field, parse_obj_as
|
||||||
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
|
from typing import ClassVar, Dict, List, Literal, Union, Optional, get_type_hints
|
||||||
|
|
||||||
|
from .base import InvokeAISettings
|
||||||
|
|
||||||
INIT_FILE = Path("invokeai.yaml")
|
INIT_FILE = Path("invokeai.yaml")
|
||||||
DB_FILE = Path("invokeai.db")
|
DB_FILE = Path("invokeai.db")
|
||||||
@ -175,195 +183,6 @@ LEGACY_INIT_FILE = Path("invokeai.init")
|
|||||||
DEFAULT_MAX_VRAM = 0.5
|
DEFAULT_MAX_VRAM = 0.5
|
||||||
|
|
||||||
|
|
||||||
class InvokeAISettings(BaseSettings):
|
|
||||||
"""
|
|
||||||
Runtime configuration settings in which default values are
|
|
||||||
read from an omegaconf .yaml file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
initconf: ClassVar[DictConfig] = None
|
|
||||||
argparse_groups: ClassVar[Dict] = {}
|
|
||||||
|
|
||||||
def parse_args(self, argv: list = sys.argv[1:]):
|
|
||||||
parser = self.get_parser()
|
|
||||||
opt = parser.parse_args(argv)
|
|
||||||
for name in self.__fields__:
|
|
||||||
if name not in self._excluded():
|
|
||||||
value = getattr(opt, name)
|
|
||||||
if isinstance(value, ListConfig):
|
|
||||||
value = list(value)
|
|
||||||
elif isinstance(value, DictConfig):
|
|
||||||
value = dict(value)
|
|
||||||
setattr(self, name, value)
|
|
||||||
|
|
||||||
def to_yaml(self) -> str:
|
|
||||||
"""
|
|
||||||
Return a YAML string representing our settings. This can be used
|
|
||||||
as the contents of `invokeai.yaml` to restore settings later.
|
|
||||||
"""
|
|
||||||
cls = self.__class__
|
|
||||||
type = get_args(get_type_hints(cls)["type"])[0]
|
|
||||||
field_dict = dict({type: dict()})
|
|
||||||
for name, field in self.__fields__.items():
|
|
||||||
if name in cls._excluded_from_yaml():
|
|
||||||
continue
|
|
||||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
|
||||||
value = getattr(self, name)
|
|
||||||
if category not in field_dict[type]:
|
|
||||||
field_dict[type][category] = dict()
|
|
||||||
# keep paths as strings to make it easier to read
|
|
||||||
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
|
||||||
conf = OmegaConf.create(field_dict)
|
|
||||||
return OmegaConf.to_yaml(conf)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def add_parser_arguments(cls, parser):
|
|
||||||
if "type" in get_type_hints(cls):
|
|
||||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
|
||||||
else:
|
|
||||||
settings_stanza = "Uncategorized"
|
|
||||||
|
|
||||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
|
|
||||||
|
|
||||||
initconf = (
|
|
||||||
cls.initconf.get(settings_stanza)
|
|
||||||
if cls.initconf and settings_stanza in cls.initconf
|
|
||||||
else OmegaConf.create()
|
|
||||||
)
|
|
||||||
|
|
||||||
# create an upcase version of the environment in
|
|
||||||
# order to achieve case-insensitive environment
|
|
||||||
# variables (the way Windows does)
|
|
||||||
upcase_environ = dict()
|
|
||||||
for key, value in os.environ.items():
|
|
||||||
upcase_environ[key.upper()] = value
|
|
||||||
|
|
||||||
fields = cls.__fields__
|
|
||||||
cls.argparse_groups = {}
|
|
||||||
|
|
||||||
for name, field in fields.items():
|
|
||||||
if name not in cls._excluded():
|
|
||||||
current_default = field.default
|
|
||||||
|
|
||||||
category = field.field_info.extra.get("category", "Uncategorized")
|
|
||||||
env_name = env_prefix + "_" + name
|
|
||||||
if category in initconf and name in initconf.get(category):
|
|
||||||
field.default = initconf.get(category).get(name)
|
|
||||||
if env_name.upper() in upcase_environ:
|
|
||||||
field.default = upcase_environ[env_name.upper()]
|
|
||||||
cls.add_field_argument(parser, name, field)
|
|
||||||
|
|
||||||
field.default = current_default
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def cmd_name(self, command_field: str = "type") -> str:
|
|
||||||
hints = get_type_hints(self)
|
|
||||||
if command_field in hints:
|
|
||||||
return get_args(hints[command_field])[0]
|
|
||||||
else:
|
|
||||||
return "Uncategorized"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_parser(cls) -> ArgumentParser:
|
|
||||||
parser = PagingArgumentParser(
|
|
||||||
prog=cls.cmd_name(),
|
|
||||||
description=cls.__doc__,
|
|
||||||
)
|
|
||||||
cls.add_parser_arguments(parser)
|
|
||||||
return parser
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def add_subparser(cls, parser: argparse.ArgumentParser):
|
|
||||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _excluded(self) -> List[str]:
|
|
||||||
# internal fields that shouldn't be exposed as command line options
|
|
||||||
return ["type", "initconf"]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _excluded_from_yaml(self) -> List[str]:
|
|
||||||
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
|
||||||
return [
|
|
||||||
"type",
|
|
||||||
"initconf",
|
|
||||||
"version",
|
|
||||||
"from_file",
|
|
||||||
"model",
|
|
||||||
"root",
|
|
||||||
]
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
env_file_encoding = "utf-8"
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
case_sensitive = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
|
||||||
field_type = get_type_hints(cls).get(name)
|
|
||||||
default = (
|
|
||||||
default_override
|
|
||||||
if default_override is not None
|
|
||||||
else field.default
|
|
||||||
if field.default_factory is None
|
|
||||||
else field.default_factory()
|
|
||||||
)
|
|
||||||
if category := field.field_info.extra.get("category"):
|
|
||||||
if category not in cls.argparse_groups:
|
|
||||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
|
||||||
argparse_group = cls.argparse_groups[category]
|
|
||||||
else:
|
|
||||||
argparse_group = command_parser
|
|
||||||
|
|
||||||
if get_origin(field_type) == Literal:
|
|
||||||
allowed_values = get_args(field.type_)
|
|
||||||
allowed_types = set()
|
|
||||||
for val in allowed_values:
|
|
||||||
allowed_types.add(type(val))
|
|
||||||
allowed_types_list = list(allowed_types)
|
|
||||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
|
||||||
|
|
||||||
argparse_group.add_argument(
|
|
||||||
f"--{name}",
|
|
||||||
dest=name,
|
|
||||||
type=field_type,
|
|
||||||
default=default,
|
|
||||||
choices=allowed_values,
|
|
||||||
help=field.field_info.description,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif get_origin(field_type) == list:
|
|
||||||
argparse_group.add_argument(
|
|
||||||
f"--{name}",
|
|
||||||
dest=name,
|
|
||||||
nargs="*",
|
|
||||||
type=field.type_,
|
|
||||||
default=default,
|
|
||||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
|
||||||
help=field.field_info.description,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
argparse_group.add_argument(
|
|
||||||
f"--{name}",
|
|
||||||
dest=name,
|
|
||||||
type=field.type_,
|
|
||||||
default=default,
|
|
||||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
|
||||||
help=field.field_info.description,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _find_root() -> Path:
|
|
||||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
|
||||||
if os.environ.get("INVOKEAI_ROOT"):
|
|
||||||
root = Path(os.environ["INVOKEAI_ROOT"])
|
|
||||||
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
|
|
||||||
root = (venv.parent).resolve()
|
|
||||||
else:
|
|
||||||
root = Path("~/invokeai").expanduser().resolve()
|
|
||||||
return root
|
|
||||||
|
|
||||||
|
|
||||||
class InvokeAIAppConfig(InvokeAISettings):
|
class InvokeAIAppConfig(InvokeAISettings):
|
||||||
"""
|
"""
|
||||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||||
@ -378,6 +197,8 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["InvokeAI"] = "InvokeAI"
|
type: Literal["InvokeAI"] = "InvokeAI"
|
||||||
|
|
||||||
|
# WEB
|
||||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
||||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
|
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
|
||||||
@ -385,20 +206,14 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
|
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
|
||||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
|
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
|
||||||
|
|
||||||
|
# FEATURES
|
||||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
|
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
|
||||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
||||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||||
|
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
|
||||||
|
|
||||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
# PATHS
|
||||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
|
||||||
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
|
||||||
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
|
||||||
precision : Literal['auto', 'float16', 'float32', 'autocast'] = Field(default='auto', description='Floating point precision', category='Memory/Performance')
|
|
||||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
|
||||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
|
||||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
|
||||||
|
|
||||||
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
|
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
|
||||||
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
|
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
|
||||||
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
||||||
@ -409,16 +224,41 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||||
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
||||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
|
||||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
||||||
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
|
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||||
|
|
||||||
|
# LOGGING
|
||||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
||||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||||
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
||||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
||||||
|
|
||||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
||||||
|
|
||||||
|
# CACHE
|
||||||
|
ram : Union[float, Literal["auto"]] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
|
||||||
|
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
|
||||||
|
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
||||||
|
|
||||||
|
# DEVICE
|
||||||
|
device : Literal[tuple(["auto", "cpu", "cuda", "cuda:1", "mps"])] = Field(default="auto", description="Generation device", category="Device", )
|
||||||
|
precision: Literal[tuple(["auto", "float16", "float32", "autocast"])] = Field(default="auto", description="Floating point precision", category="Device", )
|
||||||
|
|
||||||
|
# GENERATION
|
||||||
|
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category="Generation", )
|
||||||
|
attention_type : Literal[tuple(["auto", "normal", "xformers", "sliced", "torch-sdp"])] = Field(default="auto", description="Attention type", category="Generation", )
|
||||||
|
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
|
||||||
|
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
|
||||||
|
|
||||||
|
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||||
|
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||||
|
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||||
|
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
|
||||||
|
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
|
||||||
|
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||||
|
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||||
|
|
||||||
|
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -541,11 +381,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
"""Return true if precision set to float32"""
|
"""Return true if precision set to float32"""
|
||||||
return self.precision == "float32"
|
return self.precision == "float32"
|
||||||
|
|
||||||
@property
|
|
||||||
def disable_xformers(self) -> bool:
|
|
||||||
"""Return true if xformers_enabled is false"""
|
|
||||||
return not self.xformers_enabled
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def try_patchmatch(self) -> bool:
|
def try_patchmatch(self) -> bool:
|
||||||
"""Return true if patchmatch true"""
|
"""Return true if patchmatch true"""
|
||||||
@ -561,6 +396,27 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
"""invisible watermark node is always active and disabled from Web UIe"""
|
"""invisible watermark node is always active and disabled from Web UIe"""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ram_cache_size(self) -> float:
|
||||||
|
return self.max_cache_size or self.ram
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vram_cache_size(self) -> float:
|
||||||
|
return self.max_vram_cache_size or self.vram
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_cpu(self) -> bool:
|
||||||
|
return self.always_use_cpu or self.device == "cpu"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def disable_xformers(self) -> bool:
|
||||||
|
"""
|
||||||
|
Return true if enable_xformers is false (reversed logic)
|
||||||
|
and attention type is not set to xformers.
|
||||||
|
"""
|
||||||
|
disabled_in_config = not self.xformers_enabled
|
||||||
|
return disabled_in_config and self.attention_type != "xformers"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_root() -> Path:
|
def find_root() -> Path:
|
||||||
"""
|
"""
|
||||||
@ -570,19 +426,19 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
return _find_root()
|
return _find_root()
|
||||||
|
|
||||||
|
|
||||||
class PagingArgumentParser(argparse.ArgumentParser):
|
|
||||||
"""
|
|
||||||
A custom ArgumentParser that uses pydoc to page its output.
|
|
||||||
It also supports reading defaults from an init file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def print_help(self, file=None):
|
|
||||||
text = self.format_help()
|
|
||||||
pydoc.pager(text)
|
|
||||||
|
|
||||||
|
|
||||||
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
|
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
|
||||||
"""
|
"""
|
||||||
Legacy function which returns InvokeAIAppConfig.get_config()
|
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||||
"""
|
"""
|
||||||
return InvokeAIAppConfig.get_config(**kwargs)
|
return InvokeAIAppConfig.get_config(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_root() -> Path:
|
||||||
|
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||||
|
if os.environ.get("INVOKEAI_ROOT"):
|
||||||
|
root = Path(os.environ["INVOKEAI_ROOT"])
|
||||||
|
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
|
||||||
|
root = (venv.parent).resolve()
|
||||||
|
else:
|
||||||
|
root = Path("~/invokeai").expanduser().resolve()
|
||||||
|
return root
|
@ -17,9 +17,9 @@ def create_text_to_image() -> LibraryGraph:
|
|||||||
description="Converts text to an image",
|
description="Converts text to an image",
|
||||||
graph=Graph(
|
graph=Graph(
|
||||||
nodes={
|
nodes={
|
||||||
"width": IntegerInvocation(id="width", a=512),
|
"width": IntegerInvocation(id="width", value=512),
|
||||||
"height": IntegerInvocation(id="height", a=512),
|
"height": IntegerInvocation(id="height", value=512),
|
||||||
"seed": IntegerInvocation(id="seed", a=-1),
|
"seed": IntegerInvocation(id="seed", value=-1),
|
||||||
"3": NoiseInvocation(id="3"),
|
"3": NoiseInvocation(id="3"),
|
||||||
"4": CompelInvocation(id="4"),
|
"4": CompelInvocation(id="4"),
|
||||||
"5": CompelInvocation(id="5"),
|
"5": CompelInvocation(id="5"),
|
||||||
@ -29,15 +29,15 @@ def create_text_to_image() -> LibraryGraph:
|
|||||||
},
|
},
|
||||||
edges=[
|
edges=[
|
||||||
Edge(
|
Edge(
|
||||||
source=EdgeConnection(node_id="width", field="a"),
|
source=EdgeConnection(node_id="width", field="value"),
|
||||||
destination=EdgeConnection(node_id="3", field="width"),
|
destination=EdgeConnection(node_id="3", field="width"),
|
||||||
),
|
),
|
||||||
Edge(
|
Edge(
|
||||||
source=EdgeConnection(node_id="height", field="a"),
|
source=EdgeConnection(node_id="height", field="value"),
|
||||||
destination=EdgeConnection(node_id="3", field="height"),
|
destination=EdgeConnection(node_id="3", field="height"),
|
||||||
),
|
),
|
||||||
Edge(
|
Edge(
|
||||||
source=EdgeConnection(node_id="seed", field="a"),
|
source=EdgeConnection(node_id="seed", field="value"),
|
||||||
destination=EdgeConnection(node_id="3", field="seed"),
|
destination=EdgeConnection(node_id="3", field="seed"),
|
||||||
),
|
),
|
||||||
Edge(
|
Edge(
|
||||||
@ -65,9 +65,9 @@ def create_text_to_image() -> LibraryGraph:
|
|||||||
exposed_inputs=[
|
exposed_inputs=[
|
||||||
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
|
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
|
||||||
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
|
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
|
||||||
ExposedNodeInput(node_path="width", field="a", alias="width"),
|
ExposedNodeInput(node_path="width", field="value", alias="width"),
|
||||||
ExposedNodeInput(node_path="height", field="a", alias="height"),
|
ExposedNodeInput(node_path="height", field="value", alias="height"),
|
||||||
ExposedNodeInput(node_path="seed", field="a", alias="seed"),
|
ExposedNodeInput(node_path="seed", field="value", alias="seed"),
|
||||||
],
|
],
|
||||||
exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
|
exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
|
||||||
)
|
)
|
||||||
|
@ -49,9 +49,36 @@ from invokeai.backend.model_management.model_cache import CacheStats
|
|||||||
GIG = 1073741824
|
GIG = 1073741824
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeStats:
|
||||||
|
"""Class for tracking execution stats of an invocation node"""
|
||||||
|
|
||||||
|
calls: int = 0
|
||||||
|
time_used: float = 0.0 # seconds
|
||||||
|
max_vram: float = 0.0 # GB
|
||||||
|
cache_hits: int = 0
|
||||||
|
cache_misses: int = 0
|
||||||
|
cache_high_watermark: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeLog:
|
||||||
|
"""Class for tracking node usage"""
|
||||||
|
|
||||||
|
# {node_type => NodeStats}
|
||||||
|
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class InvocationStatsServiceBase(ABC):
|
class InvocationStatsServiceBase(ABC):
|
||||||
"Abstract base class for recording node memory/time performance statistics"
|
"Abstract base class for recording node memory/time performance statistics"
|
||||||
|
|
||||||
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||||
|
# {graph_id => NodeLog}
|
||||||
|
_stats: Dict[str, NodeLog]
|
||||||
|
_cache_stats: Dict[str, CacheStats]
|
||||||
|
ram_used: float
|
||||||
|
ram_changed: float
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||||
"""
|
"""
|
||||||
@ -94,8 +121,6 @@ class InvocationStatsServiceBase(ABC):
|
|||||||
invocation_type: str,
|
invocation_type: str,
|
||||||
time_used: float,
|
time_used: float,
|
||||||
vram_used: float,
|
vram_used: float,
|
||||||
ram_used: float,
|
|
||||||
ram_changed: float,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Add timing information on execution of a node. Usually
|
Add timing information on execution of a node. Usually
|
||||||
@ -104,8 +129,6 @@ class InvocationStatsServiceBase(ABC):
|
|||||||
:param invocation_type: String literal type of the node
|
:param invocation_type: String literal type of the node
|
||||||
:param time_used: Time used by node's exection (sec)
|
:param time_used: Time used by node's exection (sec)
|
||||||
:param vram_used: Maximum VRAM used during exection (GB)
|
:param vram_used: Maximum VRAM used during exection (GB)
|
||||||
:param ram_used: Current RAM available (GB)
|
|
||||||
:param ram_changed: Change in RAM usage over course of the run (GB)
|
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -116,25 +139,19 @@ class InvocationStatsServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_mem_stats(
|
||||||
|
self,
|
||||||
|
ram_used: float,
|
||||||
|
ram_changed: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update the collector with RAM memory usage info.
|
||||||
|
|
||||||
@dataclass
|
:param ram_used: How much RAM is currently in use.
|
||||||
class NodeStats:
|
:param ram_changed: How much RAM changed since last generation.
|
||||||
"""Class for tracking execution stats of an invocation node"""
|
"""
|
||||||
|
pass
|
||||||
calls: int = 0
|
|
||||||
time_used: float = 0.0 # seconds
|
|
||||||
max_vram: float = 0.0 # GB
|
|
||||||
cache_hits: int = 0
|
|
||||||
cache_misses: int = 0
|
|
||||||
cache_high_watermark: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class NodeLog:
|
|
||||||
"""Class for tracking node usage"""
|
|
||||||
|
|
||||||
# {node_type => NodeStats}
|
|
||||||
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class InvocationStatsService(InvocationStatsServiceBase):
|
class InvocationStatsService(InvocationStatsServiceBase):
|
||||||
@ -152,12 +169,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
class StatsContext:
|
class StatsContext:
|
||||||
"""Context manager for collecting statistics."""
|
"""Context manager for collecting statistics."""
|
||||||
|
|
||||||
invocation: BaseInvocation = None
|
invocation: BaseInvocation
|
||||||
collector: "InvocationStatsServiceBase" = None
|
collector: "InvocationStatsServiceBase"
|
||||||
graph_id: str = None
|
graph_id: str
|
||||||
start_time: int = 0
|
start_time: float
|
||||||
ram_used: int = 0
|
ram_used: int
|
||||||
model_manager: ModelManagerService = None
|
model_manager: ModelManagerService
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -170,7 +187,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
self.invocation = invocation
|
self.invocation = invocation
|
||||||
self.collector = collector
|
self.collector = collector
|
||||||
self.graph_id = graph_id
|
self.graph_id = graph_id
|
||||||
self.start_time = 0
|
self.start_time = 0.0
|
||||||
self.ram_used = 0
|
self.ram_used = 0
|
||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
|
|
||||||
@ -191,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
)
|
)
|
||||||
self.collector.update_invocation_stats(
|
self.collector.update_invocation_stats(
|
||||||
graph_id=self.graph_id,
|
graph_id=self.graph_id,
|
||||||
invocation_type=self.invocation.type,
|
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
|
||||||
time_used=time.time() - self.start_time,
|
time_used=time.time() - self.start_time,
|
||||||
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
|
||||||
)
|
)
|
||||||
@ -202,11 +219,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
model_manager: ModelManagerService,
|
model_manager: ModelManagerService,
|
||||||
) -> StatsContext:
|
) -> StatsContext:
|
||||||
"""
|
|
||||||
Return a context object that will capture the statistics.
|
|
||||||
:param invocation: BaseInvocation object from the current graph.
|
|
||||||
:param graph_execution_state: GraphExecutionState object from the current session.
|
|
||||||
"""
|
|
||||||
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||||
self._stats[graph_execution_state_id] = NodeLog()
|
self._stats[graph_execution_state_id] = NodeLog()
|
||||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||||
@ -217,7 +229,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
self._stats = {}
|
self._stats = {}
|
||||||
|
|
||||||
def reset_stats(self, graph_execution_id: str):
|
def reset_stats(self, graph_execution_id: str):
|
||||||
"""Zero the statistics for the indicated graph."""
|
|
||||||
try:
|
try:
|
||||||
self._stats.pop(graph_execution_id)
|
self._stats.pop(graph_execution_id)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -228,12 +239,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
ram_used: float,
|
ram_used: float,
|
||||||
ram_changed: float,
|
ram_changed: float,
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Update the collector with RAM memory usage info.
|
|
||||||
|
|
||||||
:param ram_used: How much RAM is currently in use.
|
|
||||||
:param ram_changed: How much RAM changed since last generation.
|
|
||||||
"""
|
|
||||||
self.ram_used = ram_used
|
self.ram_used = ram_used
|
||||||
self.ram_changed = ram_changed
|
self.ram_changed = ram_changed
|
||||||
|
|
||||||
@ -244,16 +249,6 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
time_used: float,
|
time_used: float,
|
||||||
vram_used: float,
|
vram_used: float,
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Add timing information on execution of a node. Usually
|
|
||||||
used internally.
|
|
||||||
:param graph_id: ID of the graph that is currently executing
|
|
||||||
:param invocation_type: String literal type of the node
|
|
||||||
:param time_used: Time used by node's exection (sec)
|
|
||||||
:param vram_used: Maximum VRAM used during exection (GB)
|
|
||||||
:param ram_used: Current RAM available (GB)
|
|
||||||
:param ram_changed: Change in RAM usage over course of the run (GB)
|
|
||||||
"""
|
|
||||||
if not self._stats[graph_id].nodes.get(invocation_type):
|
if not self._stats[graph_id].nodes.get(invocation_type):
|
||||||
self._stats[graph_id].nodes[invocation_type] = NodeStats()
|
self._stats[graph_id].nodes[invocation_type] = NodeStats()
|
||||||
stats = self._stats[graph_id].nodes[invocation_type]
|
stats = self._stats[graph_id].nodes[invocation_type]
|
||||||
@ -262,14 +257,15 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
stats.max_vram = max(stats.max_vram, vram_used)
|
stats.max_vram = max(stats.max_vram, vram_used)
|
||||||
|
|
||||||
def log_stats(self):
|
def log_stats(self):
|
||||||
"""
|
|
||||||
Send the statistics to the system logger at the info level.
|
|
||||||
Stats will only be printed when the execution of the graph
|
|
||||||
is complete.
|
|
||||||
"""
|
|
||||||
completed = set()
|
completed = set()
|
||||||
|
errored = set()
|
||||||
for graph_id, node_log in self._stats.items():
|
for graph_id, node_log in self._stats.items():
|
||||||
|
try:
|
||||||
current_graph_state = self.graph_execution_manager.get(graph_id)
|
current_graph_state = self.graph_execution_manager.get(graph_id)
|
||||||
|
except Exception:
|
||||||
|
errored.add(graph_id)
|
||||||
|
continue
|
||||||
|
|
||||||
if not current_graph_state.is_complete():
|
if not current_graph_state.is_complete():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -302,3 +298,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
for graph_id in completed:
|
for graph_id in completed:
|
||||||
del self._stats[graph_id]
|
del self._stats[graph_id]
|
||||||
del self._cache_stats[graph_id]
|
del self._cache_stats[graph_id]
|
||||||
|
|
||||||
|
for graph_id in errored:
|
||||||
|
del self._stats[graph_id]
|
||||||
|
del self._cache_stats[graph_id]
|
||||||
|
@ -330,8 +330,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
# configuration value. If present, then the
|
# configuration value. If present, then the
|
||||||
# cache size is set to 2.5 GB times
|
# cache size is set to 2.5 GB times
|
||||||
# the number of max_loaded_models. Otherwise
|
# the number of max_loaded_models. Otherwise
|
||||||
# use new `max_cache_size` config setting
|
# use new `ram_cache_size` config setting
|
||||||
max_cache_size = config.max_cache_size if hasattr(config, "max_cache_size") else config.max_loaded_models * 2.5
|
max_cache_size = config.ram_cache_size
|
||||||
|
|
||||||
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
||||||
|
|
||||||
|
56
invokeai/backend/image_util/lama.py
Normal file
56
invokeai/backend/image_util/lama.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import gc
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from invokeai.app.services.config import get_invokeai_config
|
||||||
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
|
|
||||||
|
|
||||||
|
def norm_img(np_img):
|
||||||
|
if len(np_img.shape) == 2:
|
||||||
|
np_img = np_img[:, :, np.newaxis]
|
||||||
|
np_img = np.transpose(np_img, (2, 0, 1))
|
||||||
|
np_img = np_img.astype("float32") / 255
|
||||||
|
return np_img
|
||||||
|
|
||||||
|
|
||||||
|
def load_jit_model(url_or_path, device):
|
||||||
|
model_path = url_or_path
|
||||||
|
print(f"Loading model from: {model_path}")
|
||||||
|
model = torch.jit.load(model_path, map_location="cpu").to(device)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class LaMA:
|
||||||
|
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
||||||
|
device = choose_torch_device()
|
||||||
|
model_location = get_invokeai_config().models_path / "core/misc/lama/lama.pt"
|
||||||
|
model = load_jit_model(model_location, device)
|
||||||
|
|
||||||
|
image = np.asarray(input_image.convert("RGB"))
|
||||||
|
image = norm_img(image)
|
||||||
|
|
||||||
|
mask = input_image.split()[-1]
|
||||||
|
mask = np.asarray(mask)
|
||||||
|
mask = np.invert(mask)
|
||||||
|
mask = norm_img(mask)
|
||||||
|
|
||||||
|
mask = (mask > 0) * 1
|
||||||
|
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||||
|
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
infilled_image = model(image, mask)
|
||||||
|
|
||||||
|
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||||
|
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
|
||||||
|
infilled_image = Image.fromarray(infilled_image)
|
||||||
|
|
||||||
|
del model
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return infilled_image
|
@ -21,6 +21,7 @@ from argparse import Namespace
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import get_terminal_size
|
from shutil import get_terminal_size
|
||||||
|
from typing import get_type_hints, get_args, Any
|
||||||
from urllib import request
|
from urllib import request
|
||||||
|
|
||||||
import npyscreen
|
import npyscreen
|
||||||
@ -50,6 +51,7 @@ from invokeai.frontend.install.model_install import addModelsForm, process_and_e
|
|||||||
# TO DO - Move all the frontend code into invokeai.frontend.install
|
# TO DO - Move all the frontend code into invokeai.frontend.install
|
||||||
from invokeai.frontend.install.widgets import (
|
from invokeai.frontend.install.widgets import (
|
||||||
SingleSelectColumns,
|
SingleSelectColumns,
|
||||||
|
MultiSelectColumns,
|
||||||
CenteredButtonPress,
|
CenteredButtonPress,
|
||||||
FileBox,
|
FileBox,
|
||||||
set_min_terminal_size,
|
set_min_terminal_size,
|
||||||
@ -71,6 +73,10 @@ warnings.filterwarnings("ignore")
|
|||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
|
||||||
|
def get_literal_fields(field) -> list[Any]:
|
||||||
|
return get_args(get_type_hints(InvokeAIAppConfig).get(field))
|
||||||
|
|
||||||
|
|
||||||
# --------------------------globals-----------------------
|
# --------------------------globals-----------------------
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
@ -80,7 +86,11 @@ Model_dir = "models"
|
|||||||
Default_config_file = config.model_conf_path
|
Default_config_file = config.model_conf_path
|
||||||
SD_Configs = config.legacy_conf_path
|
SD_Configs = config.legacy_conf_path
|
||||||
|
|
||||||
PRECISION_CHOICES = ["auto", "float16", "float32"]
|
PRECISION_CHOICES = get_literal_fields("precision")
|
||||||
|
DEVICE_CHOICES = get_literal_fields("device")
|
||||||
|
ATTENTION_CHOICES = get_literal_fields("attention_type")
|
||||||
|
ATTENTION_SLICE_CHOICES = get_literal_fields("attention_slice_size")
|
||||||
|
GENERATION_OPT_CHOICES = ["sequential_guidance", "force_tiled_decode", "lazy_offload"]
|
||||||
GB = 1073741824 # GB in bytes
|
GB = 1073741824 # GB in bytes
|
||||||
HAS_CUDA = torch.cuda.is_available()
|
HAS_CUDA = torch.cuda.is_available()
|
||||||
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
|
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
|
||||||
@ -311,6 +321,7 @@ class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
|
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
|
||||||
Use cursor arrows to make a checkbox selection, and space to toggle.
|
Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||||
"""
|
"""
|
||||||
|
self.nextrely -= 1
|
||||||
for i in textwrap.wrap(label, width=window_width - 6):
|
for i in textwrap.wrap(label, width=window_width - 6):
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.FixedText,
|
npyscreen.FixedText,
|
||||||
@ -337,76 +348,129 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
|||||||
use_two_lines=False,
|
use_two_lines=False,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.nextrely += 1
|
|
||||||
self.add_widget_intelligent(
|
# old settings for defaults
|
||||||
npyscreen.TitleFixedText,
|
|
||||||
name="GPU Management",
|
|
||||||
begin_entry_at=0,
|
|
||||||
editable=False,
|
|
||||||
color="CONTROL",
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
self.nextrely -= 1
|
|
||||||
self.free_gpu_mem = self.add_widget_intelligent(
|
|
||||||
npyscreen.Checkbox,
|
|
||||||
name="Free GPU memory after each generation",
|
|
||||||
value=old_opts.free_gpu_mem,
|
|
||||||
max_width=45,
|
|
||||||
relx=5,
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
self.nextrely -= 1
|
|
||||||
self.xformers_enabled = self.add_widget_intelligent(
|
|
||||||
npyscreen.Checkbox,
|
|
||||||
name="Enable xformers support",
|
|
||||||
value=old_opts.xformers_enabled,
|
|
||||||
max_width=30,
|
|
||||||
relx=50,
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
self.nextrely -= 1
|
|
||||||
self.always_use_cpu = self.add_widget_intelligent(
|
|
||||||
npyscreen.Checkbox,
|
|
||||||
name="Force CPU to be used on GPU systems",
|
|
||||||
value=old_opts.always_use_cpu,
|
|
||||||
relx=80,
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
precision = old_opts.precision or ("float32" if program_opts.full_precision else "auto")
|
precision = old_opts.precision or ("float32" if program_opts.full_precision else "auto")
|
||||||
|
device = old_opts.device
|
||||||
|
attention_type = old_opts.attention_type
|
||||||
|
attention_slice_size = old_opts.attention_slice_size
|
||||||
|
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.TitleFixedText,
|
npyscreen.TitleFixedText,
|
||||||
name="Floating Point Precision",
|
name="Image Generation Options:",
|
||||||
|
editable=False,
|
||||||
|
color="CONTROL",
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.nextrely -= 2
|
||||||
|
self.generation_options = self.add_widget_intelligent(
|
||||||
|
MultiSelectColumns,
|
||||||
|
columns=3,
|
||||||
|
values=GENERATION_OPT_CHOICES,
|
||||||
|
value=[GENERATION_OPT_CHOICES.index(x) for x in GENERATION_OPT_CHOICES if getattr(old_opts, x)],
|
||||||
|
relx=30,
|
||||||
|
max_height=2,
|
||||||
|
max_width=80,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleFixedText,
|
||||||
|
name="Floating Point Precision:",
|
||||||
begin_entry_at=0,
|
begin_entry_at=0,
|
||||||
editable=False,
|
editable=False,
|
||||||
color="CONTROL",
|
color="CONTROL",
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.nextrely -= 1
|
self.nextrely -= 2
|
||||||
self.precision = self.add_widget_intelligent(
|
self.precision = self.add_widget_intelligent(
|
||||||
SingleSelectColumns,
|
SingleSelectColumns,
|
||||||
columns=3,
|
columns=len(PRECISION_CHOICES),
|
||||||
name="Precision",
|
name="Precision",
|
||||||
values=PRECISION_CHOICES,
|
values=PRECISION_CHOICES,
|
||||||
value=PRECISION_CHOICES.index(precision),
|
value=PRECISION_CHOICES.index(precision),
|
||||||
begin_entry_at=3,
|
begin_entry_at=3,
|
||||||
max_height=2,
|
max_height=2,
|
||||||
|
relx=30,
|
||||||
|
max_width=56,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleFixedText,
|
||||||
|
name="Generation Device:",
|
||||||
|
begin_entry_at=0,
|
||||||
|
editable=False,
|
||||||
|
color="CONTROL",
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.nextrely -= 2
|
||||||
|
self.device = self.add_widget_intelligent(
|
||||||
|
SingleSelectColumns,
|
||||||
|
columns=len(DEVICE_CHOICES),
|
||||||
|
values=DEVICE_CHOICES,
|
||||||
|
value=DEVICE_CHOICES.index(device),
|
||||||
|
begin_entry_at=3,
|
||||||
|
relx=30,
|
||||||
|
max_height=2,
|
||||||
|
max_width=60,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleFixedText,
|
||||||
|
name="Attention Type:",
|
||||||
|
begin_entry_at=0,
|
||||||
|
editable=False,
|
||||||
|
color="CONTROL",
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.nextrely -= 2
|
||||||
|
self.attention_type = self.add_widget_intelligent(
|
||||||
|
SingleSelectColumns,
|
||||||
|
columns=len(ATTENTION_CHOICES),
|
||||||
|
values=ATTENTION_CHOICES,
|
||||||
|
value=ATTENTION_CHOICES.index(attention_type),
|
||||||
|
begin_entry_at=3,
|
||||||
|
max_height=2,
|
||||||
|
relx=30,
|
||||||
max_width=80,
|
max_width=80,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.nextrely += 1
|
self.attention_type.on_changed = self.show_hide_slice_sizes
|
||||||
|
self.attention_slice_label = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleFixedText,
|
||||||
|
name="Attention Slice Size:",
|
||||||
|
relx=5,
|
||||||
|
editable=False,
|
||||||
|
hidden=attention_type != "sliced",
|
||||||
|
color="CONTROL",
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
self.nextrely -= 2
|
||||||
|
self.attention_slice_size = self.add_widget_intelligent(
|
||||||
|
SingleSelectColumns,
|
||||||
|
columns=len(ATTENTION_SLICE_CHOICES),
|
||||||
|
values=ATTENTION_SLICE_CHOICES,
|
||||||
|
value=ATTENTION_SLICE_CHOICES.index(attention_slice_size),
|
||||||
|
relx=30,
|
||||||
|
hidden=attention_type != "sliced",
|
||||||
|
max_height=2,
|
||||||
|
max_width=110,
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
|
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.TitleFixedText,
|
npyscreen.TitleFixedText,
|
||||||
name="RAM cache size (GB). Make this at least large enough to hold a single full model.",
|
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model.",
|
||||||
begin_entry_at=0,
|
begin_entry_at=0,
|
||||||
editable=False,
|
editable=False,
|
||||||
color="CONTROL",
|
color="CONTROL",
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.nextrely -= 1
|
self.nextrely -= 1
|
||||||
self.max_cache_size = self.add_widget_intelligent(
|
self.ram = self.add_widget_intelligent(
|
||||||
npyscreen.Slider,
|
npyscreen.Slider,
|
||||||
value=clip(old_opts.max_cache_size, range=(3.0, MAX_RAM), step=0.5),
|
value=clip(old_opts.ram_cache_size, range=(3.0, MAX_RAM), step=0.5),
|
||||||
out_of=round(MAX_RAM),
|
out_of=round(MAX_RAM),
|
||||||
lowest=0.0,
|
lowest=0.0,
|
||||||
step=0.5,
|
step=0.5,
|
||||||
@ -417,16 +481,16 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
|||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.TitleFixedText,
|
npyscreen.TitleFixedText,
|
||||||
name="VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.",
|
name="Model VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.",
|
||||||
begin_entry_at=0,
|
begin_entry_at=0,
|
||||||
editable=False,
|
editable=False,
|
||||||
color="CONTROL",
|
color="CONTROL",
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.nextrely -= 1
|
self.nextrely -= 1
|
||||||
self.max_vram_cache_size = self.add_widget_intelligent(
|
self.vram = self.add_widget_intelligent(
|
||||||
npyscreen.Slider,
|
npyscreen.Slider,
|
||||||
value=clip(old_opts.max_vram_cache_size, range=(0, MAX_VRAM), step=0.25),
|
value=clip(old_opts.vram_cache_size, range=(0, MAX_VRAM), step=0.25),
|
||||||
out_of=round(MAX_VRAM * 2) / 2,
|
out_of=round(MAX_VRAM * 2) / 2,
|
||||||
lowest=0.0,
|
lowest=0.0,
|
||||||
relx=8,
|
relx=8,
|
||||||
@ -434,7 +498,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
|||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.max_vram_cache_size = DummyWidgetValue.zero
|
self.vram_cache_size = DummyWidgetValue.zero
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
self.outdir = self.add_widget_intelligent(
|
self.outdir = self.add_widget_intelligent(
|
||||||
FileBox,
|
FileBox,
|
||||||
@ -490,6 +554,11 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
|||||||
when_pressed_function=self.on_ok,
|
when_pressed_function=self.on_ok,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def show_hide_slice_sizes(self, value):
|
||||||
|
show = ATTENTION_CHOICES[value[0]] == "sliced"
|
||||||
|
self.attention_slice_label.hidden = not show
|
||||||
|
self.attention_slice_size.hidden = not show
|
||||||
|
|
||||||
def on_ok(self):
|
def on_ok(self):
|
||||||
options = self.marshall_arguments()
|
options = self.marshall_arguments()
|
||||||
if self.validate_field_values(options):
|
if self.validate_field_values(options):
|
||||||
@ -523,12 +592,9 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
|||||||
new_opts = Namespace()
|
new_opts = Namespace()
|
||||||
|
|
||||||
for attr in [
|
for attr in [
|
||||||
|
"ram",
|
||||||
|
"vram",
|
||||||
"outdir",
|
"outdir",
|
||||||
"free_gpu_mem",
|
|
||||||
"max_cache_size",
|
|
||||||
"max_vram_cache_size",
|
|
||||||
"xformers_enabled",
|
|
||||||
"always_use_cpu",
|
|
||||||
]:
|
]:
|
||||||
setattr(new_opts, attr, getattr(self, attr).value)
|
setattr(new_opts, attr, getattr(self, attr).value)
|
||||||
|
|
||||||
@ -541,6 +607,12 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
|||||||
new_opts.hf_token = self.hf_token.value
|
new_opts.hf_token = self.hf_token.value
|
||||||
new_opts.license_acceptance = self.license_acceptance.value
|
new_opts.license_acceptance = self.license_acceptance.value
|
||||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||||
|
new_opts.device = DEVICE_CHOICES[self.device.value[0]]
|
||||||
|
new_opts.attention_type = ATTENTION_CHOICES[self.attention_type.value[0]]
|
||||||
|
new_opts.attention_slice_size = ATTENTION_SLICE_CHOICES[self.attention_slice_size.value[0]]
|
||||||
|
generation_options = [GENERATION_OPT_CHOICES[x] for x in self.generation_options.value]
|
||||||
|
for v in GENERATION_OPT_CHOICES:
|
||||||
|
setattr(new_opts, v, v in generation_options)
|
||||||
|
|
||||||
return new_opts
|
return new_opts
|
||||||
|
|
||||||
|
@ -20,11 +20,36 @@
|
|||||||
import re
|
import re
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional, Union
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
from diffusers.models import (
|
||||||
|
AutoencoderKL,
|
||||||
|
ControlNetModel,
|
||||||
|
PriorTransformer,
|
||||||
|
UNet2DConditionModel,
|
||||||
|
)
|
||||||
|
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||||
|
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
||||||
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||||
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
|
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
||||||
|
from diffusers.schedulers import (
|
||||||
|
DDIMScheduler,
|
||||||
|
DDPMScheduler,
|
||||||
|
DPMSolverMultistepScheduler,
|
||||||
|
EulerAncestralDiscreteScheduler,
|
||||||
|
EulerDiscreteScheduler,
|
||||||
|
HeunDiscreteScheduler,
|
||||||
|
LMSDiscreteScheduler,
|
||||||
|
PNDMScheduler,
|
||||||
|
UnCLIPScheduler,
|
||||||
|
)
|
||||||
|
from diffusers.utils import is_accelerate_available, is_omegaconf_available
|
||||||
|
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
||||||
|
from picklescan.scanner import scan_file_path
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoFeatureExtractor,
|
AutoFeatureExtractor,
|
||||||
BertTokenizerFast,
|
BertTokenizerFast,
|
||||||
@ -37,35 +62,8 @@ from transformers import (
|
|||||||
CLIPVisionModelWithProjection,
|
CLIPVisionModelWithProjection,
|
||||||
)
|
)
|
||||||
|
|
||||||
from diffusers.models import (
|
|
||||||
AutoencoderKL,
|
|
||||||
ControlNetModel,
|
|
||||||
PriorTransformer,
|
|
||||||
UNet2DConditionModel,
|
|
||||||
)
|
|
||||||
from diffusers.schedulers import (
|
|
||||||
DDIMScheduler,
|
|
||||||
DDPMScheduler,
|
|
||||||
DPMSolverMultistepScheduler,
|
|
||||||
EulerAncestralDiscreteScheduler,
|
|
||||||
EulerDiscreteScheduler,
|
|
||||||
HeunDiscreteScheduler,
|
|
||||||
LMSDiscreteScheduler,
|
|
||||||
PNDMScheduler,
|
|
||||||
UnCLIPScheduler,
|
|
||||||
)
|
|
||||||
from diffusers.utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available
|
|
||||||
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
|
||||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
|
||||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
|
||||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
|
||||||
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
|
||||||
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from picklescan.scanner import scan_file_path
|
|
||||||
from .models import BaseModelType, ModelVariantType
|
from .models import BaseModelType, ModelVariantType
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1221,9 +1219,6 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
|
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
|
||||||
|
|
||||||
if from_safetensors:
|
if from_safetensors:
|
||||||
if not is_safetensors_available():
|
|
||||||
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
|
|
||||||
|
|
||||||
from safetensors.torch import load_file as safe_load
|
from safetensors.torch import load_file as safe_load
|
||||||
|
|
||||||
checkpoint = safe_load(checkpoint_path, device="cpu")
|
checkpoint = safe_load(checkpoint_path, device="cpu")
|
||||||
@ -1662,9 +1657,6 @@ def download_controlnet_from_original_ckpt(
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
if from_safetensors:
|
if from_safetensors:
|
||||||
if not is_safetensors_available():
|
|
||||||
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
|
|
||||||
|
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
|
|
||||||
checkpoint = {}
|
checkpoint = {}
|
||||||
@ -1741,7 +1733,7 @@ def convert_ckpt_to_diffusers(
|
|||||||
|
|
||||||
pipe.save_pretrained(
|
pipe.save_pretrained(
|
||||||
dump_path,
|
dump_path,
|
||||||
safe_serialization=use_safetensors and is_safetensors_available(),
|
safe_serialization=use_safetensors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1757,7 +1749,4 @@ def convert_controlnet_to_diffusers(
|
|||||||
"""
|
"""
|
||||||
pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs)
|
pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs)
|
||||||
|
|
||||||
pipe.save_pretrained(
|
pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||||
dump_path,
|
|
||||||
safe_serialization=is_safetensors_available(),
|
|
||||||
)
|
|
||||||
|
@ -341,7 +341,8 @@ class ModelManager(object):
|
|||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.cache = ModelCache(
|
self.cache = ModelCache(
|
||||||
max_cache_size=max_cache_size,
|
max_cache_size=max_cache_size,
|
||||||
max_vram_cache_size=self.app_config.max_vram_cache_size,
|
max_vram_cache_size=self.app_config.vram_cache_size,
|
||||||
|
lazy_offloading=self.app_config.lazy_offload,
|
||||||
execution_device=device_type,
|
execution_device=device_type,
|
||||||
precision=precision,
|
precision=precision,
|
||||||
sequential_offload=sequential_offload,
|
sequential_offload=sequential_offload,
|
||||||
|
@ -5,7 +5,6 @@ from typing import Optional
|
|||||||
|
|
||||||
import safetensors
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
from diffusers.utils import is_safetensors_available
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
@ -175,5 +174,5 @@ def _convert_vae_ckpt_and_cache(
|
|||||||
vae_config=config,
|
vae_config=config,
|
||||||
image_size=image_size,
|
image_size=image_size,
|
||||||
)
|
)
|
||||||
vae_model.save_pretrained(output_path, safe_serialization=is_safetensors_available())
|
vae_model.save_pretrained(output_path, safe_serialization=True)
|
||||||
return output_path
|
return output_path
|
||||||
|
@ -33,7 +33,7 @@ from .diffusion import (
|
|||||||
PostprocessingSettings,
|
PostprocessingSettings,
|
||||||
BasicConditioningInfo,
|
BasicConditioningInfo,
|
||||||
)
|
)
|
||||||
from ..util import normalize_device
|
from ..util import normalize_device, auto_detect_slice_size
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -291,6 +291,24 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
if xformers is available, use it, otherwise use sliced attention.
|
if xformers is available, use it, otherwise use sliced attention.
|
||||||
"""
|
"""
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
if config.attention_type == "xformers":
|
||||||
|
self.enable_xformers_memory_efficient_attention()
|
||||||
|
return
|
||||||
|
elif config.attention_type == "sliced":
|
||||||
|
slice_size = config.attention_slice_size
|
||||||
|
if slice_size == "auto":
|
||||||
|
slice_size = auto_detect_slice_size(latents)
|
||||||
|
elif slice_size == "balanced":
|
||||||
|
slice_size = "auto"
|
||||||
|
self.enable_attention_slicing(slice_size=slice_size)
|
||||||
|
return
|
||||||
|
elif config.attention_type == "normal":
|
||||||
|
self.disable_attention_slicing()
|
||||||
|
return
|
||||||
|
elif config.attention_type == "torch-sdp":
|
||||||
|
raise Exception("torch-sdp attention slicing not yet implemented")
|
||||||
|
|
||||||
|
# the remainder if this code is called when attention_type=='auto'
|
||||||
if self.unet.device.type == "cuda":
|
if self.unet.device.type == "cuda":
|
||||||
if is_xformers_available() and not config.disable_xformers:
|
if is_xformers_available() and not config.disable_xformers:
|
||||||
self.enable_xformers_memory_efficient_attention()
|
self.enable_xformers_memory_efficient_attention()
|
||||||
|
@ -11,4 +11,11 @@ from .devices import ( # noqa: F401
|
|||||||
torch_dtype,
|
torch_dtype,
|
||||||
)
|
)
|
||||||
from .log import write_log # noqa: F401
|
from .log import write_log # noqa: F401
|
||||||
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir # noqa: F401
|
from .util import ( # noqa: F401
|
||||||
|
ask_user,
|
||||||
|
download_with_resume,
|
||||||
|
instantiate_from_config,
|
||||||
|
url_attachment_name,
|
||||||
|
Chdir,
|
||||||
|
)
|
||||||
|
from .attention import auto_detect_slice_size # noqa: F401
|
||||||
|
32
invokeai/backend/util/attention.py
Normal file
32
invokeai/backend/util/attention.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team
|
||||||
|
"""
|
||||||
|
Utility routine used for autodetection of optimal slice size
|
||||||
|
for attention mechanism.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
|
||||||
|
def auto_detect_slice_size(latents: torch.Tensor) -> str:
|
||||||
|
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
|
||||||
|
max_size_required_for_baddbmm = (
|
||||||
|
16
|
||||||
|
* latents.size(dim=2)
|
||||||
|
* latents.size(dim=3)
|
||||||
|
* latents.size(dim=2)
|
||||||
|
* latents.size(dim=3)
|
||||||
|
* bytes_per_element_needed_for_baddbmm_duplication
|
||||||
|
)
|
||||||
|
if latents.device.type in {"cpu", "mps"}:
|
||||||
|
mem_free = psutil.virtual_memory().free
|
||||||
|
elif latents.device.type == "cuda":
|
||||||
|
mem_free, _ = torch.cuda.mem_get_info(latents.device)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unrecognized device {latents.device}")
|
||||||
|
|
||||||
|
if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):
|
||||||
|
return "max"
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
return "max"
|
||||||
|
else:
|
||||||
|
return "balanced"
|
@ -17,13 +17,17 @@ config = InvokeAIAppConfig.get_config()
|
|||||||
|
|
||||||
def choose_torch_device() -> torch.device:
|
def choose_torch_device() -> torch.device:
|
||||||
"""Convenience routine for guessing which GPU device to run model on"""
|
"""Convenience routine for guessing which GPU device to run model on"""
|
||||||
if config.always_use_cpu:
|
if config.use_cpu: # legacy setting - force CPU
|
||||||
return CPU_DEVICE
|
return CPU_DEVICE
|
||||||
|
elif config.device == "auto":
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
return torch.device("cuda")
|
return torch.device("cuda")
|
||||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
return torch.device("mps")
|
return torch.device("mps")
|
||||||
|
else:
|
||||||
return CPU_DEVICE
|
return CPU_DEVICE
|
||||||
|
else:
|
||||||
|
return torch.device(config.device)
|
||||||
|
|
||||||
|
|
||||||
def choose_precision(device: torch.device) -> str:
|
def choose_precision(device: torch.device) -> str:
|
||||||
|
@ -17,8 +17,8 @@ from shutil import get_terminal_size
|
|||||||
from curses import BUTTON2_CLICKED, BUTTON3_CLICKED
|
from curses import BUTTON2_CLICKED, BUTTON3_CLICKED
|
||||||
|
|
||||||
# minimum size for UIs
|
# minimum size for UIs
|
||||||
MIN_COLS = 130
|
MIN_COLS = 150
|
||||||
MIN_LINES = 38
|
MIN_LINES = 40
|
||||||
|
|
||||||
|
|
||||||
class WindowTooSmallException(Exception):
|
class WindowTooSmallException(Exception):
|
||||||
@ -277,6 +277,9 @@ class SingleSelectColumns(SelectColumnBase, SingleSelectWithChanged):
|
|||||||
def h_cursor_line_right(self, ch):
|
def h_cursor_line_right(self, ch):
|
||||||
self.h_exit_down("bye bye")
|
self.h_exit_down("bye bye")
|
||||||
|
|
||||||
|
def h_cursor_line_left(self, ch):
|
||||||
|
self.h_exit_up("bye bye")
|
||||||
|
|
||||||
|
|
||||||
class TextBoxInner(npyscreen.MultiLineEdit):
|
class TextBoxInner(npyscreen.MultiLineEdit):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
@ -324,55 +327,6 @@ class TextBoxInner(npyscreen.MultiLineEdit):
|
|||||||
if bstate & (BUTTON2_CLICKED | BUTTON3_CLICKED):
|
if bstate & (BUTTON2_CLICKED | BUTTON3_CLICKED):
|
||||||
self.h_paste()
|
self.h_paste()
|
||||||
|
|
||||||
# def update(self, clear=True):
|
|
||||||
# if clear:
|
|
||||||
# self.clear()
|
|
||||||
|
|
||||||
# HEIGHT = self.height
|
|
||||||
# WIDTH = self.width
|
|
||||||
# # draw box.
|
|
||||||
# self.parent.curses_pad.hline(self.rely, self.relx, curses.ACS_HLINE, WIDTH)
|
|
||||||
# self.parent.curses_pad.hline(
|
|
||||||
# self.rely + HEIGHT, self.relx, curses.ACS_HLINE, WIDTH
|
|
||||||
# )
|
|
||||||
# self.parent.curses_pad.vline(
|
|
||||||
# self.rely, self.relx, curses.ACS_VLINE, self.height
|
|
||||||
# )
|
|
||||||
# self.parent.curses_pad.vline(
|
|
||||||
# self.rely, self.relx + WIDTH, curses.ACS_VLINE, HEIGHT
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # draw corners
|
|
||||||
# self.parent.curses_pad.addch(
|
|
||||||
# self.rely,
|
|
||||||
# self.relx,
|
|
||||||
# curses.ACS_ULCORNER,
|
|
||||||
# )
|
|
||||||
# self.parent.curses_pad.addch(
|
|
||||||
# self.rely,
|
|
||||||
# self.relx + WIDTH,
|
|
||||||
# curses.ACS_URCORNER,
|
|
||||||
# )
|
|
||||||
# self.parent.curses_pad.addch(
|
|
||||||
# self.rely + HEIGHT,
|
|
||||||
# self.relx,
|
|
||||||
# curses.ACS_LLCORNER,
|
|
||||||
# )
|
|
||||||
# self.parent.curses_pad.addch(
|
|
||||||
# self.rely + HEIGHT,
|
|
||||||
# self.relx + WIDTH,
|
|
||||||
# curses.ACS_LRCORNER,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # fool our superclass into thinking drawing area is smaller - this is really hacky but it seems to work
|
|
||||||
# (relx, rely, height, width) = (self.relx, self.rely, self.height, self.width)
|
|
||||||
# self.relx += 1
|
|
||||||
# self.rely += 1
|
|
||||||
# self.height -= 1
|
|
||||||
# self.width -= 1
|
|
||||||
# super().update(clear=False)
|
|
||||||
# (self.relx, self.rely, self.height, self.width) = (relx, rely, height, width)
|
|
||||||
|
|
||||||
|
|
||||||
class TextBox(npyscreen.BoxTitle):
|
class TextBox(npyscreen.BoxTitle):
|
||||||
_contained_widget = TextBoxInner
|
_contained_widget = TextBoxInner
|
||||||
|
@ -9,8 +9,8 @@ module.exports = {
|
|||||||
'plugin:@typescript-eslint/recommended',
|
'plugin:@typescript-eslint/recommended',
|
||||||
'plugin:react/recommended',
|
'plugin:react/recommended',
|
||||||
'plugin:react-hooks/recommended',
|
'plugin:react-hooks/recommended',
|
||||||
'plugin:prettier/recommended',
|
|
||||||
'plugin:react/jsx-runtime',
|
'plugin:react/jsx-runtime',
|
||||||
|
'prettier',
|
||||||
],
|
],
|
||||||
parser: '@typescript-eslint/parser',
|
parser: '@typescript-eslint/parser',
|
||||||
parserOptions: {
|
parserOptions: {
|
||||||
@ -23,6 +23,11 @@ module.exports = {
|
|||||||
plugins: ['react', '@typescript-eslint', 'eslint-plugin-react-hooks'],
|
plugins: ['react', '@typescript-eslint', 'eslint-plugin-react-hooks'],
|
||||||
root: true,
|
root: true,
|
||||||
rules: {
|
rules: {
|
||||||
|
curly: 'error',
|
||||||
|
'react/jsx-curly-brace-presence': [
|
||||||
|
'error',
|
||||||
|
{ props: 'never', children: 'never' },
|
||||||
|
],
|
||||||
'react-hooks/exhaustive-deps': 'error',
|
'react-hooks/exhaustive-deps': 'error',
|
||||||
'no-var': 'error',
|
'no-var': 'error',
|
||||||
'brace-style': 'error',
|
'brace-style': 'error',
|
||||||
@ -34,7 +39,6 @@ module.exports = {
|
|||||||
'warn',
|
'warn',
|
||||||
{ varsIgnorePattern: '^_', argsIgnorePattern: '^_' },
|
{ varsIgnorePattern: '^_', argsIgnorePattern: '^_' },
|
||||||
],
|
],
|
||||||
'prettier/prettier': ['error', { endOfLine: 'auto' }],
|
|
||||||
'@typescript-eslint/ban-ts-comment': 'warn',
|
'@typescript-eslint/ban-ts-comment': 'warn',
|
||||||
'@typescript-eslint/no-explicit-any': 'warn',
|
'@typescript-eslint/no-explicit-any': 'warn',
|
||||||
'@typescript-eslint/no-empty-interface': [
|
'@typescript-eslint/no-empty-interface': [
|
||||||
|
@ -29,12 +29,13 @@
|
|||||||
"lint:eslint": "eslint --max-warnings=0 .",
|
"lint:eslint": "eslint --max-warnings=0 .",
|
||||||
"lint:prettier": "prettier --check .",
|
"lint:prettier": "prettier --check .",
|
||||||
"lint:tsc": "tsc --noEmit",
|
"lint:tsc": "tsc --noEmit",
|
||||||
"lint": "yarn run lint:eslint && yarn run lint:prettier && yarn run lint:tsc && yarn run lint:madge",
|
"lint": "concurrently -g -n eslint,prettier,tsc,madge -c cyan,green,magenta,yellow \"yarn run lint:eslint\" \"yarn run lint:prettier\" \"yarn run lint:tsc\" \"yarn run lint:madge\"",
|
||||||
"fix": "eslint --fix . && prettier --loglevel warn --write . && tsc --noEmit",
|
"fix": "eslint --fix . && prettier --loglevel warn --write . && tsc --noEmit",
|
||||||
"lint-staged": "lint-staged",
|
"lint-staged": "lint-staged",
|
||||||
"postinstall": "patch-package && yarn run theme",
|
"postinstall": "patch-package && yarn run theme",
|
||||||
"theme": "chakra-cli tokens src/theme/theme.ts",
|
"theme": "chakra-cli tokens src/theme/theme.ts",
|
||||||
"theme:watch": "chakra-cli tokens src/theme/theme.ts --watch"
|
"theme:watch": "chakra-cli tokens src/theme/theme.ts --watch",
|
||||||
|
"up": "yarn upgrade-interactive --latest"
|
||||||
},
|
},
|
||||||
"madge": {
|
"madge": {
|
||||||
"detectiveOptions": {
|
"detectiveOptions": {
|
||||||
@ -54,7 +55,7 @@
|
|||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@chakra-ui/anatomy": "^2.2.0",
|
"@chakra-ui/anatomy": "^2.2.0",
|
||||||
"@chakra-ui/icons": "^2.0.19",
|
"@chakra-ui/icons": "^2.1.0",
|
||||||
"@chakra-ui/react": "^2.8.0",
|
"@chakra-ui/react": "^2.8.0",
|
||||||
"@chakra-ui/styled-system": "^2.9.1",
|
"@chakra-ui/styled-system": "^2.9.1",
|
||||||
"@chakra-ui/theme-tools": "^2.1.0",
|
"@chakra-ui/theme-tools": "^2.1.0",
|
||||||
@ -65,55 +66,55 @@
|
|||||||
"@emotion/react": "^11.11.1",
|
"@emotion/react": "^11.11.1",
|
||||||
"@emotion/styled": "^11.11.0",
|
"@emotion/styled": "^11.11.0",
|
||||||
"@floating-ui/react-dom": "^2.0.1",
|
"@floating-ui/react-dom": "^2.0.1",
|
||||||
"@fontsource-variable/inter": "^5.0.3",
|
"@fontsource-variable/inter": "^5.0.8",
|
||||||
"@fontsource/inter": "^5.0.3",
|
"@fontsource/inter": "^5.0.8",
|
||||||
"@mantine/core": "^6.0.14",
|
"@mantine/core": "^6.0.19",
|
||||||
"@mantine/form": "^6.0.15",
|
"@mantine/form": "^6.0.19",
|
||||||
"@mantine/hooks": "^6.0.14",
|
"@mantine/hooks": "^6.0.19",
|
||||||
"@nanostores/react": "^0.7.1",
|
"@nanostores/react": "^0.7.1",
|
||||||
"@reduxjs/toolkit": "^1.9.5",
|
"@reduxjs/toolkit": "^1.9.5",
|
||||||
"@roarr/browser-log-writer": "^1.1.5",
|
"@roarr/browser-log-writer": "^1.1.5",
|
||||||
"chakra-ui-contextmenu": "^1.0.5",
|
|
||||||
"dateformat": "^5.0.3",
|
"dateformat": "^5.0.3",
|
||||||
"downshift": "^7.6.0",
|
"formik": "^2.4.3",
|
||||||
"formik": "^2.4.2",
|
"framer-motion": "^10.16.1",
|
||||||
"framer-motion": "^10.12.17",
|
|
||||||
"fuse.js": "^6.6.2",
|
"fuse.js": "^6.6.2",
|
||||||
"i18next": "^23.2.3",
|
"i18next": "^23.4.4",
|
||||||
"i18next-browser-languagedetector": "^7.0.2",
|
"i18next-browser-languagedetector": "^7.0.2",
|
||||||
"i18next-http-backend": "^2.2.1",
|
"i18next-http-backend": "^2.2.1",
|
||||||
"konva": "^9.2.0",
|
"konva": "^9.2.0",
|
||||||
"lodash-es": "^4.17.21",
|
"lodash-es": "^4.17.21",
|
||||||
"nanostores": "^0.9.2",
|
"nanostores": "^0.9.2",
|
||||||
"openapi-fetch": "^0.6.1",
|
"new-github-issue-url": "^1.0.0",
|
||||||
|
"openapi-fetch": "^0.7.4",
|
||||||
"overlayscrollbars": "^2.2.0",
|
"overlayscrollbars": "^2.2.0",
|
||||||
"overlayscrollbars-react": "^0.5.0",
|
"overlayscrollbars-react": "^0.5.0",
|
||||||
"patch-package": "^7.0.0",
|
"patch-package": "^8.0.0",
|
||||||
"query-string": "^8.1.0",
|
"query-string": "^8.1.0",
|
||||||
"re-resizable": "^6.9.9",
|
|
||||||
"react": "^18.2.0",
|
"react": "^18.2.0",
|
||||||
"react-colorful": "^5.6.1",
|
"react-colorful": "^5.6.1",
|
||||||
"react-dom": "^18.2.0",
|
"react-dom": "^18.2.0",
|
||||||
"react-dropzone": "^14.2.3",
|
"react-dropzone": "^14.2.3",
|
||||||
"react-hotkeys-hook": "4.4.0",
|
"react-error-boundary": "^4.0.11",
|
||||||
"react-i18next": "^13.0.1",
|
"react-hotkeys-hook": "4.4.1",
|
||||||
|
"react-i18next": "^13.1.2",
|
||||||
"react-icons": "^4.10.1",
|
"react-icons": "^4.10.1",
|
||||||
"react-konva": "^18.2.10",
|
"react-konva": "^18.2.10",
|
||||||
"react-redux": "^8.1.1",
|
"react-redux": "^8.1.2",
|
||||||
"react-resizable-panels": "^0.0.52",
|
"react-resizable-panels": "^0.0.55",
|
||||||
"react-use": "^17.4.0",
|
"react-use": "^17.4.0",
|
||||||
"react-virtuoso": "^4.3.11",
|
"react-virtuoso": "^4.5.0",
|
||||||
"react-zoom-pan-pinch": "^3.0.8",
|
"react-zoom-pan-pinch": "^3.0.8",
|
||||||
"reactflow": "^11.7.4",
|
"reactflow": "^11.8.3",
|
||||||
"redux-dynamic-middlewares": "^2.2.0",
|
"redux-dynamic-middlewares": "^2.2.0",
|
||||||
"redux-remember": "^3.3.1",
|
"redux-remember": "^4.0.1",
|
||||||
"roarr": "^7.15.0",
|
"roarr": "^7.15.1",
|
||||||
"serialize-error": "^11.0.0",
|
"serialize-error": "^11.0.1",
|
||||||
"socket.io-client": "^4.7.0",
|
"socket.io-client": "^4.7.2",
|
||||||
"use-debounce": "^9.0.4",
|
"use-debounce": "^9.0.4",
|
||||||
"use-image": "^1.1.1",
|
"use-image": "^1.1.1",
|
||||||
"uuid": "^9.0.0",
|
"uuid": "^9.0.0",
|
||||||
"zod": "^3.21.4"
|
"zod": "^3.22.2",
|
||||||
|
"zod-validation-error": "^1.5.0"
|
||||||
},
|
},
|
||||||
"peerDependencies": {
|
"peerDependencies": {
|
||||||
"@chakra-ui/cli": "^2.4.0",
|
"@chakra-ui/cli": "^2.4.0",
|
||||||
@ -126,38 +127,36 @@
|
|||||||
"@chakra-ui/cli": "^2.4.1",
|
"@chakra-ui/cli": "^2.4.1",
|
||||||
"@types/dateformat": "^5.0.0",
|
"@types/dateformat": "^5.0.0",
|
||||||
"@types/lodash-es": "^4.14.194",
|
"@types/lodash-es": "^4.14.194",
|
||||||
"@types/node": "^20.3.1",
|
"@types/node": "^20.5.1",
|
||||||
"@types/react": "^18.2.14",
|
"@types/react": "^18.2.20",
|
||||||
"@types/react-dom": "^18.2.6",
|
"@types/react-dom": "^18.2.6",
|
||||||
"@types/react-redux": "^7.1.25",
|
"@types/react-redux": "^7.1.25",
|
||||||
"@types/react-transition-group": "^4.4.6",
|
"@types/react-transition-group": "^4.4.6",
|
||||||
"@types/uuid": "^9.0.2",
|
"@types/uuid": "^9.0.2",
|
||||||
"@typescript-eslint/eslint-plugin": "^5.60.0",
|
"@typescript-eslint/eslint-plugin": "^6.4.1",
|
||||||
"@typescript-eslint/parser": "^5.60.0",
|
"@typescript-eslint/parser": "^6.4.1",
|
||||||
"@vitejs/plugin-react-swc": "^3.3.2",
|
"@vitejs/plugin-react-swc": "^3.3.2",
|
||||||
"axios": "^1.4.0",
|
"axios": "^1.4.0",
|
||||||
"babel-plugin-transform-imports": "^2.0.0",
|
"babel-plugin-transform-imports": "^2.0.0",
|
||||||
"concurrently": "^8.2.0",
|
"concurrently": "^8.2.0",
|
||||||
"eslint": "^8.43.0",
|
"eslint": "^8.47.0",
|
||||||
"eslint-config-prettier": "^8.8.0",
|
"eslint-config-prettier": "^9.0.0",
|
||||||
"eslint-plugin-prettier": "^4.2.1",
|
"eslint-plugin-prettier": "^5.0.0",
|
||||||
"eslint-plugin-react": "^7.32.2",
|
"eslint-plugin-react": "^7.33.2",
|
||||||
"eslint-plugin-react-hooks": "^4.6.0",
|
"eslint-plugin-react-hooks": "^4.6.0",
|
||||||
"form-data": "^4.0.0",
|
"form-data": "^4.0.0",
|
||||||
"husky": "^8.0.3",
|
"husky": "^8.0.3",
|
||||||
"lint-staged": "^13.2.2",
|
"lint-staged": "^14.0.1",
|
||||||
"madge": "^6.1.0",
|
"madge": "^6.1.0",
|
||||||
"openapi-types": "^12.1.3",
|
"openapi-types": "^12.1.3",
|
||||||
"openapi-typescript": "^6.2.8",
|
"openapi-typescript": "^6.5.2",
|
||||||
"openapi-typescript-codegen": "^0.24.0",
|
|
||||||
"postinstall-postinstall": "^2.1.0",
|
"postinstall-postinstall": "^2.1.0",
|
||||||
"prettier": "^2.8.8",
|
"prettier": "^3.0.2",
|
||||||
"rollup-plugin-visualizer": "^5.9.2",
|
"rollup-plugin-visualizer": "^5.9.2",
|
||||||
"terser": "^5.18.1",
|
|
||||||
"ts-toolbelt": "^9.6.0",
|
"ts-toolbelt": "^9.6.0",
|
||||||
"vite": "^4.3.9",
|
"vite": "^4.4.9",
|
||||||
"vite-plugin-css-injected-by-js": "^3.1.1",
|
"vite-plugin-css-injected-by-js": "^3.3.0",
|
||||||
"vite-plugin-dts": "^2.3.0",
|
"vite-plugin-dts": "^3.5.2",
|
||||||
"vite-plugin-eslint": "^1.8.1",
|
"vite-plugin-eslint": "^1.8.1",
|
||||||
"vite-tsconfig-paths": "^4.2.0",
|
"vite-tsconfig-paths": "^4.2.0",
|
||||||
"yarn": "^1.22.19"
|
"yarn": "^1.22.19"
|
||||||
|
@ -19,7 +19,7 @@
|
|||||||
"toggleAutoscroll": "Toggle autoscroll",
|
"toggleAutoscroll": "Toggle autoscroll",
|
||||||
"toggleLogViewer": "Toggle Log Viewer",
|
"toggleLogViewer": "Toggle Log Viewer",
|
||||||
"showGallery": "Show Gallery",
|
"showGallery": "Show Gallery",
|
||||||
"showOptionsPanel": "Show Options Panel",
|
"showOptionsPanel": "Show Side Panel",
|
||||||
"menu": "Menu"
|
"menu": "Menu"
|
||||||
},
|
},
|
||||||
"common": {
|
"common": {
|
||||||
@ -52,7 +52,7 @@
|
|||||||
"img2img": "Image To Image",
|
"img2img": "Image To Image",
|
||||||
"unifiedCanvas": "Unified Canvas",
|
"unifiedCanvas": "Unified Canvas",
|
||||||
"linear": "Linear",
|
"linear": "Linear",
|
||||||
"nodes": "Node Editor",
|
"nodes": "Workflow Editor",
|
||||||
"batch": "Batch Manager",
|
"batch": "Batch Manager",
|
||||||
"modelManager": "Model Manager",
|
"modelManager": "Model Manager",
|
||||||
"postprocessing": "Post Processing",
|
"postprocessing": "Post Processing",
|
||||||
@ -95,7 +95,6 @@
|
|||||||
"statusModelConverted": "Model Converted",
|
"statusModelConverted": "Model Converted",
|
||||||
"statusMergingModels": "Merging Models",
|
"statusMergingModels": "Merging Models",
|
||||||
"statusMergedModels": "Models Merged",
|
"statusMergedModels": "Models Merged",
|
||||||
"pinOptionsPanel": "Pin Options Panel",
|
|
||||||
"loading": "Loading",
|
"loading": "Loading",
|
||||||
"loadingInvokeAI": "Loading Invoke AI",
|
"loadingInvokeAI": "Loading Invoke AI",
|
||||||
"random": "Random",
|
"random": "Random",
|
||||||
@ -116,7 +115,6 @@
|
|||||||
"maintainAspectRatio": "Maintain Aspect Ratio",
|
"maintainAspectRatio": "Maintain Aspect Ratio",
|
||||||
"autoSwitchNewImages": "Auto-Switch to New Images",
|
"autoSwitchNewImages": "Auto-Switch to New Images",
|
||||||
"singleColumnLayout": "Single Column Layout",
|
"singleColumnLayout": "Single Column Layout",
|
||||||
"pinGallery": "Pin Gallery",
|
|
||||||
"allImagesLoaded": "All Images Loaded",
|
"allImagesLoaded": "All Images Loaded",
|
||||||
"loadMore": "Load More",
|
"loadMore": "Load More",
|
||||||
"noImagesInGallery": "No Images to Display",
|
"noImagesInGallery": "No Images to Display",
|
||||||
@ -133,6 +131,7 @@
|
|||||||
"generalHotkeys": "General Hotkeys",
|
"generalHotkeys": "General Hotkeys",
|
||||||
"galleryHotkeys": "Gallery Hotkeys",
|
"galleryHotkeys": "Gallery Hotkeys",
|
||||||
"unifiedCanvasHotkeys": "Unified Canvas Hotkeys",
|
"unifiedCanvasHotkeys": "Unified Canvas Hotkeys",
|
||||||
|
"nodesHotkeys": "Nodes Hotkeys",
|
||||||
"invoke": {
|
"invoke": {
|
||||||
"title": "Invoke",
|
"title": "Invoke",
|
||||||
"desc": "Generate an image"
|
"desc": "Generate an image"
|
||||||
@ -332,6 +331,10 @@
|
|||||||
"acceptStagingImage": {
|
"acceptStagingImage": {
|
||||||
"title": "Accept Staging Image",
|
"title": "Accept Staging Image",
|
||||||
"desc": "Accept Current Staging Area Image"
|
"desc": "Accept Current Staging Area Image"
|
||||||
|
},
|
||||||
|
"addNodes": {
|
||||||
|
"title": "Add Nodes",
|
||||||
|
"desc": "Opens the add node menu"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"modelManager": {
|
"modelManager": {
|
||||||
@ -506,12 +509,9 @@
|
|||||||
"maskAdjustmentsHeader": "Mask Adjustments",
|
"maskAdjustmentsHeader": "Mask Adjustments",
|
||||||
"maskBlur": "Mask Blur",
|
"maskBlur": "Mask Blur",
|
||||||
"maskBlurMethod": "Mask Blur Method",
|
"maskBlurMethod": "Mask Blur Method",
|
||||||
"seamPaintingHeader": "Seam Painting",
|
"coherencePassHeader": "Coherence Pass",
|
||||||
"seamSize": "Seam Size",
|
"coherenceSteps": "Coherence Pass Steps",
|
||||||
"seamBlur": "Seam Blur",
|
"coherenceStrength": "Coherence Pass Strength",
|
||||||
"seamSteps": "Seam Steps",
|
|
||||||
"seamStrength": "Seam Strength",
|
|
||||||
"seamThreshold": "Seam Threshold",
|
|
||||||
"seamLowThreshold": "Low",
|
"seamLowThreshold": "Low",
|
||||||
"seamHighThreshold": "High",
|
"seamHighThreshold": "High",
|
||||||
"scaleBeforeProcessing": "Scale Before Processing",
|
"scaleBeforeProcessing": "Scale Before Processing",
|
||||||
@ -572,7 +572,7 @@
|
|||||||
"resetWebUI": "Reset Web UI",
|
"resetWebUI": "Reset Web UI",
|
||||||
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
|
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
|
||||||
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
|
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
|
||||||
"resetComplete": "Web UI has been reset. Refresh the page to reload.",
|
"resetComplete": "Web UI has been reset.",
|
||||||
"consoleLogLevel": "Log Level",
|
"consoleLogLevel": "Log Level",
|
||||||
"shouldLogToConsole": "Console Logging",
|
"shouldLogToConsole": "Console Logging",
|
||||||
"developer": "Developer",
|
"developer": "Developer",
|
||||||
@ -715,11 +715,12 @@
|
|||||||
"swapSizes": "Swap Sizes"
|
"swapSizes": "Swap Sizes"
|
||||||
},
|
},
|
||||||
"nodes": {
|
"nodes": {
|
||||||
"reloadSchema": "Reload Schema",
|
"reloadNodeTemplates": "Reload Node Templates",
|
||||||
"saveGraph": "Save Graph",
|
"saveWorkflow": "Save Workflow",
|
||||||
"loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)",
|
"loadWorkflow": "Load Workflow",
|
||||||
"clearGraph": "Clear Graph",
|
"resetWorkflow": "Reset Workflow",
|
||||||
"clearGraphDesc": "Are you sure you want to clear all nodes?",
|
"resetWorkflowDesc": "Are you sure you want to reset this workflow?",
|
||||||
|
"resetWorkflowDesc2": "Resetting the workflow will clear all nodes, edges and workflow details.",
|
||||||
"zoomInNodes": "Zoom In",
|
"zoomInNodes": "Zoom In",
|
||||||
"zoomOutNodes": "Zoom Out",
|
"zoomOutNodes": "Zoom Out",
|
||||||
"fitViewportNodes": "Fit View",
|
"fitViewportNodes": "Fit View",
|
||||||
|
@ -27,22 +27,10 @@ async function main() {
|
|||||||
* field accepts connection input. If it does, we can make the field optional.
|
* field accepts connection input. If it does, we can make the field optional.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Check if we are generating types for an invocation
|
if ('class' in schemaObject && schemaObject.class === 'invocation') {
|
||||||
const isInvocationPath = metadata.path.match(
|
|
||||||
/^#\/components\/schemas\/\w*Invocation$/
|
|
||||||
);
|
|
||||||
|
|
||||||
const hasInvocationProperties =
|
|
||||||
schemaObject.properties &&
|
|
||||||
['id', 'is_intermediate', 'type'].every(
|
|
||||||
(prop) => prop in schemaObject.properties
|
|
||||||
);
|
|
||||||
|
|
||||||
if (isInvocationPath && hasInvocationProperties) {
|
|
||||||
// We only want to make fields optional if they are required
|
// We only want to make fields optional if they are required
|
||||||
if (!Array.isArray(schemaObject?.required)) {
|
if (!Array.isArray(schemaObject?.required)) {
|
||||||
schemaObject.required = ['id', 'type'];
|
schemaObject.required = [];
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
schemaObject.required.forEach((prop) => {
|
schemaObject.required.forEach((prop) => {
|
||||||
@ -61,19 +49,13 @@ async function main() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
schemaObject.required = [
|
|
||||||
...new Set(schemaObject.required.concat(['id', 'type'])),
|
|
||||||
];
|
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// if (
|
|
||||||
// 'input' in schemaObject &&
|
// Check if we are generating types for an invocation output
|
||||||
// (schemaObject.input === 'any' || schemaObject.input === 'connection')
|
if ('class' in schemaObject && schemaObject.class === 'output') {
|
||||||
// ) {
|
// modify output types
|
||||||
// schemaObject.required = false;
|
}
|
||||||
// }
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
fs.writeFileSync(OUTPUT_FILE, types);
|
fs.writeFileSync(OUTPUT_FILE, types);
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Flex, Grid, Portal } from '@chakra-ui/react';
|
import { Flex, Grid } from '@chakra-ui/react';
|
||||||
import { useLogger } from 'app/logging/useLogger';
|
import { useLogger } from 'app/logging/useLogger';
|
||||||
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
|
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
@ -6,17 +6,15 @@ import { PartialAppConfig } from 'app/types/invokeai';
|
|||||||
import ImageUploader from 'common/components/ImageUploader';
|
import ImageUploader from 'common/components/ImageUploader';
|
||||||
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
|
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
|
||||||
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
|
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
|
||||||
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
|
||||||
import SiteHeader from 'features/system/components/SiteHeader';
|
import SiteHeader from 'features/system/components/SiteHeader';
|
||||||
import { configChanged } from 'features/system/store/configSlice';
|
import { configChanged } from 'features/system/store/configSlice';
|
||||||
import { languageSelector } from 'features/system/store/systemSelectors';
|
import { languageSelector } from 'features/system/store/systemSelectors';
|
||||||
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
|
|
||||||
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
|
|
||||||
import InvokeTabs from 'features/ui/components/InvokeTabs';
|
import InvokeTabs from 'features/ui/components/InvokeTabs';
|
||||||
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
|
||||||
import i18n from 'i18n';
|
import i18n from 'i18n';
|
||||||
import { size } from 'lodash-es';
|
import { size } from 'lodash-es';
|
||||||
import { ReactNode, memo, useEffect } from 'react';
|
import { ReactNode, memo, useCallback, useEffect } from 'react';
|
||||||
|
import { ErrorBoundary } from 'react-error-boundary';
|
||||||
|
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
|
||||||
import GlobalHotkeys from './GlobalHotkeys';
|
import GlobalHotkeys from './GlobalHotkeys';
|
||||||
import Toaster from './Toaster';
|
import Toaster from './Toaster';
|
||||||
|
|
||||||
@ -30,8 +28,13 @@ interface Props {
|
|||||||
const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
||||||
const language = useAppSelector(languageSelector);
|
const language = useAppSelector(languageSelector);
|
||||||
|
|
||||||
const logger = useLogger();
|
const logger = useLogger('system');
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const handleReset = useCallback(() => {
|
||||||
|
localStorage.clear();
|
||||||
|
location.reload();
|
||||||
|
return false;
|
||||||
|
}, []);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
i18n.changeLanguage(language);
|
i18n.changeLanguage(language);
|
||||||
@ -39,7 +42,7 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
|||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (size(config)) {
|
if (size(config)) {
|
||||||
logger.info({ namespace: 'App', config }, 'Received config');
|
logger.info({ config }, 'Received config');
|
||||||
dispatch(configChanged(config));
|
dispatch(configChanged(config));
|
||||||
}
|
}
|
||||||
}, [dispatch, config, logger]);
|
}, [dispatch, config, logger]);
|
||||||
@ -49,7 +52,10 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
|||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<ErrorBoundary
|
||||||
|
onReset={handleReset}
|
||||||
|
FallbackComponent={AppErrorBoundaryFallback}
|
||||||
|
>
|
||||||
<Grid w="100vw" h="100vh" position="relative" overflow="hidden">
|
<Grid w="100vw" h="100vh" position="relative" overflow="hidden">
|
||||||
<ImageUploader>
|
<ImageUploader>
|
||||||
<Grid
|
<Grid
|
||||||
@ -73,21 +79,12 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</Grid>
|
</Grid>
|
||||||
</ImageUploader>
|
</ImageUploader>
|
||||||
|
|
||||||
<GalleryDrawer />
|
|
||||||
<ParametersDrawer />
|
|
||||||
<Portal>
|
|
||||||
<FloatingParametersPanelButtons />
|
|
||||||
</Portal>
|
|
||||||
<Portal>
|
|
||||||
<FloatingGalleryButton />
|
|
||||||
</Portal>
|
|
||||||
</Grid>
|
</Grid>
|
||||||
<DeleteImageModal />
|
<DeleteImageModal />
|
||||||
<ChangeBoardModal />
|
<ChangeBoardModal />
|
||||||
<Toaster />
|
<Toaster />
|
||||||
<GlobalHotkeys />
|
<GlobalHotkeys />
|
||||||
</>
|
</ErrorBoundary>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -0,0 +1,97 @@
|
|||||||
|
import { Flex, Heading, Link, Text, useToast } from '@chakra-ui/react';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import newGithubIssueUrl from 'new-github-issue-url';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { FaCopy, FaExternalLinkAlt } from 'react-icons/fa';
|
||||||
|
import { FaArrowRotateLeft } from 'react-icons/fa6';
|
||||||
|
import { serializeError } from 'serialize-error';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
error: Error;
|
||||||
|
resetErrorBoundary: () => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
|
||||||
|
const toast = useToast();
|
||||||
|
|
||||||
|
const handleCopy = useCallback(() => {
|
||||||
|
const text = JSON.stringify(serializeError(error), null, 2);
|
||||||
|
navigator.clipboard.writeText(`\`\`\`\n${text}\n\`\`\``);
|
||||||
|
toast({
|
||||||
|
title: 'Error Copied',
|
||||||
|
});
|
||||||
|
}, [error, toast]);
|
||||||
|
|
||||||
|
const url = useMemo(
|
||||||
|
() =>
|
||||||
|
newGithubIssueUrl({
|
||||||
|
user: 'invoke-ai',
|
||||||
|
repo: 'InvokeAI',
|
||||||
|
template: 'BUG_REPORT.yml',
|
||||||
|
title: `[bug]: ${error.name}: ${error.message}`,
|
||||||
|
}),
|
||||||
|
[error.message, error.name]
|
||||||
|
);
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
layerStyle="body"
|
||||||
|
sx={{
|
||||||
|
w: '100vw',
|
||||||
|
h: '100vh',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
p: 4,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Flex
|
||||||
|
layerStyle="first"
|
||||||
|
sx={{
|
||||||
|
flexDir: 'column',
|
||||||
|
borderRadius: 'base',
|
||||||
|
justifyContent: 'center',
|
||||||
|
gap: 8,
|
||||||
|
p: 16,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Heading>Something went wrong</Heading>
|
||||||
|
<Flex
|
||||||
|
layerStyle="second"
|
||||||
|
sx={{
|
||||||
|
px: 8,
|
||||||
|
py: 4,
|
||||||
|
borderRadius: 'base',
|
||||||
|
gap: 4,
|
||||||
|
justifyContent: 'space-between',
|
||||||
|
alignItems: 'center',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Text
|
||||||
|
sx={{
|
||||||
|
fontWeight: 600,
|
||||||
|
color: 'error.500',
|
||||||
|
_dark: { color: 'error.400' },
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{error.name}: {error.message}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
<Flex sx={{ gap: 4 }}>
|
||||||
|
<IAIButton
|
||||||
|
leftIcon={<FaArrowRotateLeft />}
|
||||||
|
onClick={resetErrorBoundary}
|
||||||
|
>
|
||||||
|
Reset UI
|
||||||
|
</IAIButton>
|
||||||
|
<IAIButton leftIcon={<FaCopy />} onClick={handleCopy}>
|
||||||
|
Copy Error
|
||||||
|
</IAIButton>
|
||||||
|
<Link href={url} isExternal>
|
||||||
|
<IAIButton leftIcon={<FaExternalLinkAlt />}>Create Issue</IAIButton>
|
||||||
|
</Link>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(AppErrorBoundaryFallback);
|
@ -1,30 +1,21 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
|
||||||
import {
|
import {
|
||||||
ctrlKeyPressed,
|
ctrlKeyPressed,
|
||||||
metaKeyPressed,
|
metaKeyPressed,
|
||||||
shiftKeyPressed,
|
shiftKeyPressed,
|
||||||
} from 'features/ui/store/hotkeysSlice';
|
} from 'features/ui/store/hotkeysSlice';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||||
import {
|
|
||||||
setActiveTab,
|
|
||||||
toggleGalleryPanel,
|
|
||||||
toggleParametersPanel,
|
|
||||||
togglePinGalleryPanel,
|
|
||||||
togglePinParametersPanel,
|
|
||||||
} from 'features/ui/store/uiSlice';
|
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
import React, { memo } from 'react';
|
import React, { memo } from 'react';
|
||||||
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
|
||||||
const globalHotkeysSelector = createSelector(
|
const globalHotkeysSelector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ hotkeys, ui }) => {
|
({ hotkeys }) => {
|
||||||
const { shift, ctrl, meta } = hotkeys;
|
const { shift, ctrl, meta } = hotkeys;
|
||||||
const { shouldPinParametersPanel, shouldPinGallery } = ui;
|
return { shift, ctrl, meta };
|
||||||
return { shift, ctrl, meta, shouldPinGallery, shouldPinParametersPanel };
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
memoizeOptions: {
|
memoizeOptions: {
|
||||||
@ -41,9 +32,7 @@ const globalHotkeysSelector = createSelector(
|
|||||||
*/
|
*/
|
||||||
const GlobalHotkeys: React.FC = () => {
|
const GlobalHotkeys: React.FC = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { shift, ctrl, meta, shouldPinParametersPanel, shouldPinGallery } =
|
const { shift, ctrl, meta } = useAppSelector(globalHotkeysSelector);
|
||||||
useAppSelector(globalHotkeysSelector);
|
|
||||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'*',
|
'*',
|
||||||
@ -68,34 +57,6 @@ const GlobalHotkeys: React.FC = () => {
|
|||||||
[shift, ctrl, meta]
|
[shift, ctrl, meta]
|
||||||
);
|
);
|
||||||
|
|
||||||
useHotkeys('o', () => {
|
|
||||||
dispatch(toggleParametersPanel());
|
|
||||||
if (activeTabName === 'unifiedCanvas' && shouldPinParametersPanel) {
|
|
||||||
dispatch(requestCanvasRescale());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
useHotkeys(['shift+o'], () => {
|
|
||||||
dispatch(togglePinParametersPanel());
|
|
||||||
if (activeTabName === 'unifiedCanvas') {
|
|
||||||
dispatch(requestCanvasRescale());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
useHotkeys('g', () => {
|
|
||||||
dispatch(toggleGalleryPanel());
|
|
||||||
if (activeTabName === 'unifiedCanvas' && shouldPinGallery) {
|
|
||||||
dispatch(requestCanvasRescale());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
useHotkeys(['shift+g'], () => {
|
|
||||||
dispatch(togglePinGalleryPanel());
|
|
||||||
if (activeTabName === 'unifiedCanvas') {
|
|
||||||
dispatch(requestCanvasRescale());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
useHotkeys('1', () => {
|
useHotkeys('1', () => {
|
||||||
dispatch(setActiveTab('txt2img'));
|
dispatch(setActiveTab('txt2img'));
|
||||||
});
|
});
|
||||||
@ -112,6 +73,10 @@ const GlobalHotkeys: React.FC = () => {
|
|||||||
dispatch(setActiveTab('nodes'));
|
dispatch(setActiveTab('nodes'));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
useHotkeys('5', () => {
|
||||||
|
dispatch(setActiveTab('modelManager'));
|
||||||
|
});
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import {
|
|||||||
createLocalStorageManager,
|
createLocalStorageManager,
|
||||||
extendTheme,
|
extendTheme,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { ReactNode, useEffect, useMemo } from 'react';
|
import { ReactNode, memo, useEffect, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { theme as invokeAITheme } from 'theme/theme';
|
import { theme as invokeAITheme } from 'theme/theme';
|
||||||
|
|
||||||
@ -46,4 +46,4 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
export default ThemeLocaleProvider;
|
export default memo(ThemeLocaleProvider);
|
||||||
|
@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { toastQueueSelector } from 'features/system/store/systemSelectors';
|
import { toastQueueSelector } from 'features/system/store/systemSelectors';
|
||||||
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
|
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
|
||||||
import { MakeToastArg, makeToast } from 'features/system/util/makeToast';
|
import { MakeToastArg, makeToast } from 'features/system/util/makeToast';
|
||||||
import { useCallback, useEffect } from 'react';
|
import { memo, useCallback, useEffect } from 'react';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Logical component. Watches the toast queue and makes toasts when the queue is not empty.
|
* Logical component. Watches the toast queue and makes toasts when the queue is not empty.
|
||||||
@ -44,4 +44,4 @@ export const useAppToaster = () => {
|
|||||||
return toaster;
|
return toaster;
|
||||||
};
|
};
|
||||||
|
|
||||||
export default Toaster;
|
export default memo(Toaster);
|
||||||
|
@ -9,7 +9,7 @@ export const log = Roarr.child(BASE_CONTEXT);
|
|||||||
|
|
||||||
export const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
|
export const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
|
||||||
|
|
||||||
type LoggerNamespace =
|
export type LoggerNamespace =
|
||||||
| 'images'
|
| 'images'
|
||||||
| 'models'
|
| 'models'
|
||||||
| 'config'
|
| 'config'
|
||||||
|
@ -1,12 +1,17 @@
|
|||||||
import { useStore } from '@nanostores/react';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { createLogWriter } from '@roarr/browser-log-writer';
|
import { createLogWriter } from '@roarr/browser-log-writer';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
import { useEffect } from 'react';
|
import { useEffect, useMemo } from 'react';
|
||||||
import { ROARR, Roarr } from 'roarr';
|
import { ROARR, Roarr } from 'roarr';
|
||||||
import { $logger, BASE_CONTEXT, LOG_LEVEL_MAP } from './logger';
|
import {
|
||||||
|
$logger,
|
||||||
|
BASE_CONTEXT,
|
||||||
|
LOG_LEVEL_MAP,
|
||||||
|
LoggerNamespace,
|
||||||
|
logger,
|
||||||
|
} from './logger';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
systemSelector,
|
systemSelector,
|
||||||
@ -25,7 +30,7 @@ const selector = createSelector(
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
export const useLogger = () => {
|
export const useLogger = (namespace: LoggerNamespace) => {
|
||||||
const { consoleLogLevel, shouldLogToConsole } = useAppSelector(selector);
|
const { consoleLogLevel, shouldLogToConsole } = useAppSelector(selector);
|
||||||
|
|
||||||
// The provided Roarr browser log writer uses localStorage to config logging to console
|
// The provided Roarr browser log writer uses localStorage to config logging to console
|
||||||
@ -57,7 +62,7 @@ export const useLogger = () => {
|
|||||||
$logger.set(Roarr.child(newContext));
|
$logger.set(Roarr.child(newContext));
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const logger = useStore($logger);
|
const log = useMemo(() => logger(namespace), [namespace]);
|
||||||
|
|
||||||
return logger;
|
return log;
|
||||||
};
|
};
|
||||||
|
@ -1,13 +1,17 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
import {
|
||||||
|
controlNetImageChanged,
|
||||||
|
controlNetProcessedImageChanged,
|
||||||
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
|
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
|
||||||
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
|
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
|
||||||
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { isInvocationNode } from 'features/nodes/types/types';
|
||||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||||
import { clamp } from 'lodash-es';
|
import { clamp, forEach } from 'lodash-es';
|
||||||
import { api } from 'services/api';
|
import { api } from 'services/api';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { imagesAdapter } from 'services/api/util';
|
import { imagesAdapter } from 'services/api/util';
|
||||||
@ -73,22 +77,61 @@ export const addRequestedSingleImageDeletionListener = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
|
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
|
||||||
|
|
||||||
if (imageUsage.isCanvasImage) {
|
if (imageUsage.isCanvasImage) {
|
||||||
dispatch(resetCanvas());
|
dispatch(resetCanvas());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (imageUsage.isControlNetImage) {
|
imageDTOs.forEach((imageDTO) => {
|
||||||
dispatch(controlNetReset());
|
// reset init image if we deleted it
|
||||||
}
|
if (
|
||||||
|
getState().generation.initialImage?.imageName === imageDTO.image_name
|
||||||
if (imageUsage.isInitialImage) {
|
) {
|
||||||
dispatch(clearInitialImage());
|
dispatch(clearInitialImage());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (imageUsage.isNodesImage) {
|
// reset controlNets that use the deleted images
|
||||||
dispatch(nodeEditorReset());
|
forEach(getState().controlNet.controlNets, (controlNet) => {
|
||||||
|
if (
|
||||||
|
controlNet.controlImage === imageDTO.image_name ||
|
||||||
|
controlNet.processedControlImage === imageDTO.image_name
|
||||||
|
) {
|
||||||
|
dispatch(
|
||||||
|
controlNetImageChanged({
|
||||||
|
controlNetId: controlNet.controlNetId,
|
||||||
|
controlImage: null,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
controlNetProcessedImageChanged({
|
||||||
|
controlNetId: controlNet.controlNetId,
|
||||||
|
processedControlImage: null,
|
||||||
|
})
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// reset nodes that use the deleted images
|
||||||
|
getState().nodes.nodes.forEach((node) => {
|
||||||
|
if (!isInvocationNode(node)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
forEach(node.data.inputs, (input) => {
|
||||||
|
if (
|
||||||
|
input.type === 'ImageField' &&
|
||||||
|
input.value?.image_name === imageDTO.image_name
|
||||||
|
) {
|
||||||
|
dispatch(
|
||||||
|
fieldImageValueChanged({
|
||||||
|
nodeId: node.data.id,
|
||||||
|
fieldName: input.name,
|
||||||
|
value: undefined,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
// Delete from server
|
// Delete from server
|
||||||
const { requestId } = dispatch(
|
const { requestId } = dispatch(
|
||||||
@ -154,17 +197,58 @@ export const addRequestedMultipleImageDeletionListener = () => {
|
|||||||
dispatch(resetCanvas());
|
dispatch(resetCanvas());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (imagesUsage.some((i) => i.isControlNetImage)) {
|
imageDTOs.forEach((imageDTO) => {
|
||||||
dispatch(controlNetReset());
|
// reset init image if we deleted it
|
||||||
}
|
if (
|
||||||
|
getState().generation.initialImage?.imageName ===
|
||||||
if (imagesUsage.some((i) => i.isInitialImage)) {
|
imageDTO.image_name
|
||||||
|
) {
|
||||||
dispatch(clearInitialImage());
|
dispatch(clearInitialImage());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (imagesUsage.some((i) => i.isNodesImage)) {
|
// reset controlNets that use the deleted images
|
||||||
dispatch(nodeEditorReset());
|
forEach(getState().controlNet.controlNets, (controlNet) => {
|
||||||
|
if (
|
||||||
|
controlNet.controlImage === imageDTO.image_name ||
|
||||||
|
controlNet.processedControlImage === imageDTO.image_name
|
||||||
|
) {
|
||||||
|
dispatch(
|
||||||
|
controlNetImageChanged({
|
||||||
|
controlNetId: controlNet.controlNetId,
|
||||||
|
controlImage: null,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
controlNetProcessedImageChanged({
|
||||||
|
controlNetId: controlNet.controlNetId,
|
||||||
|
processedControlImage: null,
|
||||||
|
})
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// reset nodes that use the deleted images
|
||||||
|
getState().nodes.nodes.forEach((node) => {
|
||||||
|
if (!isInvocationNode(node)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
forEach(node.data.inputs, (input) => {
|
||||||
|
if (
|
||||||
|
input.type === 'ImageField' &&
|
||||||
|
input.value?.image_name === imageDTO.image_name
|
||||||
|
) {
|
||||||
|
dispatch(
|
||||||
|
fieldImageValueChanged({
|
||||||
|
nodeId: node.data.id,
|
||||||
|
fieldName: input.name,
|
||||||
|
value: undefined,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
} catch {
|
} catch {
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import { modelsApi } from 'services/api/endpoints/models';
|
|||||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||||
import { startAppListening } from '../..';
|
import { startAppListening } from '../..';
|
||||||
|
import { size } from 'lodash-es';
|
||||||
|
|
||||||
export const addSocketConnectedEventListener = () => {
|
export const addSocketConnectedEventListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
@ -18,7 +19,7 @@ export const addSocketConnectedEventListener = () => {
|
|||||||
|
|
||||||
const { disabledTabs } = config;
|
const { disabledTabs } = config;
|
||||||
|
|
||||||
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
if (!size(nodes.nodeTemplates) && !disabledTabs.includes('nodes')) {
|
||||||
dispatch(receivedOpenAPISchema());
|
dispatch(receivedOpenAPISchema());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,8 +8,8 @@ import {
|
|||||||
import { memo, ReactNode } from 'react';
|
import { memo, ReactNode } from 'react';
|
||||||
|
|
||||||
export interface IAIButtonProps extends ButtonProps {
|
export interface IAIButtonProps extends ButtonProps {
|
||||||
tooltip?: string;
|
tooltip?: TooltipProps['label'];
|
||||||
tooltipProps?: Omit<TooltipProps, 'children'>;
|
tooltipProps?: Omit<TooltipProps, 'children' | 'label'>;
|
||||||
isChecked?: boolean;
|
isChecked?: boolean;
|
||||||
children: ReactNode;
|
children: ReactNode;
|
||||||
}
|
}
|
||||||
|
@ -34,14 +34,10 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
|
|||||||
gap: 2,
|
gap: 2,
|
||||||
borderTopRadius: 'base',
|
borderTopRadius: 'base',
|
||||||
borderBottomRadius: isOpen ? 0 : 'base',
|
borderBottomRadius: isOpen ? 0 : 'base',
|
||||||
bg: isOpen
|
bg: mode('base.250', 'base.750')(colorMode),
|
||||||
? mode('base.200', 'base.750')(colorMode)
|
|
||||||
: mode('base.150', 'base.800')(colorMode),
|
|
||||||
color: mode('base.900', 'base.100')(colorMode),
|
color: mode('base.900', 'base.100')(colorMode),
|
||||||
_hover: {
|
_hover: {
|
||||||
bg: isOpen
|
bg: mode('base.300', 'base.700')(colorMode),
|
||||||
? mode('base.250', 'base.700')(colorMode)
|
|
||||||
: mode('base.200', 'base.750')(colorMode),
|
|
||||||
},
|
},
|
||||||
fontSize: 'sm',
|
fontSize: 'sm',
|
||||||
fontWeight: 600,
|
fontWeight: 600,
|
||||||
@ -90,9 +86,10 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
|
|||||||
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
|
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
|
||||||
<Box
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
p: 4,
|
p: 2,
|
||||||
|
pt: 3,
|
||||||
borderBottomRadius: 'base',
|
borderBottomRadius: 'base',
|
||||||
bg: 'base.100',
|
bg: 'base.150',
|
||||||
_dark: {
|
_dark: {
|
||||||
bg: 'base.800',
|
bg: 'base.800',
|
||||||
},
|
},
|
||||||
|
@ -100,14 +100,18 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
const handleMouseOver = useCallback(
|
const handleMouseOver = useCallback(
|
||||||
(e: MouseEvent<HTMLDivElement>) => {
|
(e: MouseEvent<HTMLDivElement>) => {
|
||||||
if (onMouseOver) onMouseOver(e);
|
if (onMouseOver) {
|
||||||
|
onMouseOver(e);
|
||||||
|
}
|
||||||
setIsHovered(true);
|
setIsHovered(true);
|
||||||
},
|
},
|
||||||
[onMouseOver]
|
[onMouseOver]
|
||||||
);
|
);
|
||||||
const handleMouseOut = useCallback(
|
const handleMouseOut = useCallback(
|
||||||
(e: MouseEvent<HTMLDivElement>) => {
|
(e: MouseEvent<HTMLDivElement>) => {
|
||||||
if (onMouseOut) onMouseOut(e);
|
if (onMouseOut) {
|
||||||
|
onMouseOut(e);
|
||||||
|
}
|
||||||
setIsHovered(false);
|
setIsHovered(false);
|
||||||
},
|
},
|
||||||
[onMouseOut]
|
[onMouseOut]
|
||||||
@ -122,7 +126,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
? {}
|
? {}
|
||||||
: {
|
: {
|
||||||
cursor: 'pointer',
|
cursor: 'pointer',
|
||||||
bg: mode('base.200', 'base.800')(colorMode),
|
bg: mode('base.200', 'base.700')(colorMode),
|
||||||
_hover: {
|
_hover: {
|
||||||
bg: mode('base.300', 'base.650')(colorMode),
|
bg: mode('base.300', 'base.650')(colorMode),
|
||||||
color: mode('base.500', 'base.300')(colorMode),
|
color: mode('base.500', 'base.300')(colorMode),
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { Box, Flex, Icon } from '@chakra-ui/react';
|
import { Box, Flex, Icon } from '@chakra-ui/react';
|
||||||
|
import { memo } from 'react';
|
||||||
import { FaExclamation } from 'react-icons/fa';
|
import { FaExclamation } from 'react-icons/fa';
|
||||||
|
|
||||||
const IAIErrorLoadingImageFallback = () => {
|
const IAIErrorLoadingImageFallback = () => {
|
||||||
@ -39,4 +40,4 @@ const IAIErrorLoadingImageFallback = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAIErrorLoadingImageFallback;
|
export default memo(IAIErrorLoadingImageFallback);
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { Box, Skeleton } from '@chakra-ui/react';
|
import { Box, Skeleton } from '@chakra-ui/react';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
const IAIFillSkeleton = () => {
|
const IAIFillSkeleton = () => {
|
||||||
return (
|
return (
|
||||||
@ -27,4 +28,4 @@ const IAIFillSkeleton = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAIFillSkeleton;
|
export default memo(IAIFillSkeleton);
|
||||||
|
@ -9,8 +9,8 @@ import { memo } from 'react';
|
|||||||
|
|
||||||
export type IAIIconButtonProps = IconButtonProps & {
|
export type IAIIconButtonProps = IconButtonProps & {
|
||||||
role?: string;
|
role?: string;
|
||||||
tooltip?: string;
|
tooltip?: TooltipProps['label'];
|
||||||
tooltipProps?: Omit<TooltipProps, 'children'>;
|
tooltipProps?: Omit<TooltipProps, 'children' | 'label'>;
|
||||||
isChecked?: boolean;
|
isChecked?: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { Badge, Flex } from '@chakra-ui/react';
|
import { Badge, Flex } from '@chakra-ui/react';
|
||||||
|
import { memo } from 'react';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
type ImageMetadataOverlayProps = {
|
type ImageMetadataOverlayProps = {
|
||||||
@ -26,4 +27,4 @@ const ImageMetadataOverlay = ({ imageDTO }: ImageMetadataOverlayProps) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default ImageMetadataOverlay;
|
export default memo(ImageMetadataOverlay);
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { Box, Flex, Heading } from '@chakra-ui/react';
|
import { Box, Flex, Heading } from '@chakra-ui/react';
|
||||||
|
import { memo } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
|
||||||
type ImageUploadOverlayProps = {
|
type ImageUploadOverlayProps = {
|
||||||
@ -87,4 +88,4 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
|
|||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
export default ImageUploadOverlay;
|
export default memo(ImageUploadOverlay);
|
||||||
|
@ -150,7 +150,9 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
|||||||
{...getRootProps({ style: {} })}
|
{...getRootProps({ style: {} })}
|
||||||
onKeyDown={(e: KeyboardEvent) => {
|
onKeyDown={(e: KeyboardEvent) => {
|
||||||
// Bail out if user hits spacebar - do not open the uploader
|
// Bail out if user hits spacebar - do not open the uploader
|
||||||
if (e.key === ' ') return;
|
if (e.key === ' ') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<input {...getInputProps()} />
|
<input {...getInputProps()} />
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { Flex, Icon } from '@chakra-ui/react';
|
import { Flex, Icon } from '@chakra-ui/react';
|
||||||
|
import { memo } from 'react';
|
||||||
import { FaImage } from 'react-icons/fa';
|
import { FaImage } from 'react-icons/fa';
|
||||||
|
|
||||||
const SelectImagePlaceholder = () => {
|
const SelectImagePlaceholder = () => {
|
||||||
@ -19,4 +20,4 @@ const SelectImagePlaceholder = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default SelectImagePlaceholder;
|
export default memo(SelectImagePlaceholder);
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { Box } from '@chakra-ui/react';
|
import { Box } from '@chakra-ui/react';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
isSelected: boolean;
|
isSelected: boolean;
|
||||||
@ -18,6 +19,7 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
|
|||||||
opacity: isSelected ? 1 : 0.7,
|
opacity: isSelected ? 1 : 0.7,
|
||||||
transitionProperty: 'common',
|
transitionProperty: 'common',
|
||||||
transitionDuration: '0.1s',
|
transitionDuration: '0.1s',
|
||||||
|
pointerEvents: 'none',
|
||||||
shadow: isSelected
|
shadow: isSelected
|
||||||
? isHovered
|
? isHovered
|
||||||
? 'hoverSelected.light'
|
? 'hoverSelected.light'
|
||||||
@ -39,4 +41,4 @@ const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default SelectionOverlay;
|
export default memo(SelectionOverlay);
|
||||||
|
@ -2,71 +2,108 @@ import { createSelector } from '@reduxjs/toolkit';
|
|||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
// import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
import { isInvocationNode } from 'features/nodes/types/types';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach, map } from 'lodash-es';
|
||||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
import { getConnectedEdges } from 'reactflow';
|
||||||
import { modelsApi } from '../../services/api/endpoints/models';
|
|
||||||
|
|
||||||
const readinessSelector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector, activeTabNameSelector],
|
[stateSelector, activeTabNameSelector],
|
||||||
(state, activeTabName) => {
|
(state, activeTabName) => {
|
||||||
const { generation, system } = state;
|
const { generation, system, nodes } = state;
|
||||||
const { initialImage } = generation;
|
const { initialImage, model } = generation;
|
||||||
|
|
||||||
const { isProcessing, isConnected } = system;
|
const { isProcessing, isConnected } = system;
|
||||||
|
|
||||||
let isReady = true;
|
const reasons: string[] = [];
|
||||||
const reasonsWhyNotReady: string[] = [];
|
|
||||||
|
|
||||||
if (activeTabName === 'img2img' && !initialImage) {
|
|
||||||
isReady = false;
|
|
||||||
reasonsWhyNotReady.push('No initial image selected');
|
|
||||||
}
|
|
||||||
|
|
||||||
const { isSuccess: mainModelsSuccessfullyLoaded } =
|
|
||||||
modelsApi.endpoints.getMainModels.select(NON_REFINER_BASE_MODELS)(state);
|
|
||||||
if (!mainModelsSuccessfullyLoaded) {
|
|
||||||
isReady = false;
|
|
||||||
reasonsWhyNotReady.push('Models are not loaded');
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: job queue
|
|
||||||
// Cannot generate if already processing an image
|
// Cannot generate if already processing an image
|
||||||
if (isProcessing) {
|
if (isProcessing) {
|
||||||
isReady = false;
|
reasons.push('System busy');
|
||||||
reasonsWhyNotReady.push('System Busy');
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cannot generate if not connected
|
// Cannot generate if not connected
|
||||||
if (!isConnected) {
|
if (!isConnected) {
|
||||||
isReady = false;
|
reasons.push('System disconnected');
|
||||||
reasonsWhyNotReady.push('System Disconnected');
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// // Cannot generate variations without valid seed weights
|
if (activeTabName === 'img2img' && !initialImage) {
|
||||||
// if (
|
reasons.push('No initial image selected');
|
||||||
// shouldGenerateVariations &&
|
}
|
||||||
// (!(validateSeedWeights(seedWeights) || seedWeights === '') || seed === -1)
|
|
||||||
// ) {
|
|
||||||
// isReady = false;
|
|
||||||
// reasonsWhyNotReady.push('Seed-Weights badly formatted.');
|
|
||||||
// }
|
|
||||||
|
|
||||||
forEach(state.controlNet.controlNets, (controlNet, id) => {
|
if (activeTabName === 'nodes' && nodes.shouldValidateGraph) {
|
||||||
if (!controlNet.model) {
|
if (!nodes.nodes.length) {
|
||||||
isReady = false;
|
reasons.push('No nodes in graph');
|
||||||
reasonsWhyNotReady.push(`ControlNet ${id} has no model selected.`);
|
}
|
||||||
|
|
||||||
|
nodes.nodes.forEach((node) => {
|
||||||
|
if (!isInvocationNode(node)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const nodeTemplate = nodes.nodeTemplates[node.data.type];
|
||||||
|
|
||||||
|
if (!nodeTemplate) {
|
||||||
|
// Node type not found
|
||||||
|
reasons.push('Missing node template');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const connectedEdges = getConnectedEdges([node], nodes.edges);
|
||||||
|
|
||||||
|
forEach(node.data.inputs, (field) => {
|
||||||
|
const fieldTemplate = nodeTemplate.inputs[field.name];
|
||||||
|
const hasConnection = connectedEdges.some(
|
||||||
|
(edge) =>
|
||||||
|
edge.target === node.id && edge.targetHandle === field.name
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!fieldTemplate) {
|
||||||
|
reasons.push('Missing field template');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (fieldTemplate.required && !field.value && !hasConnection) {
|
||||||
|
reasons.push(
|
||||||
|
`${node.data.label || nodeTemplate.title} -> ${
|
||||||
|
field.label || fieldTemplate.title
|
||||||
|
} missing input`
|
||||||
|
);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
if (!model) {
|
||||||
|
reasons.push('No model selected');
|
||||||
|
}
|
||||||
|
|
||||||
// All good
|
if (state.controlNet.isEnabled) {
|
||||||
return { isReady, reasonsWhyNotReady };
|
map(state.controlNet.controlNets).forEach((controlNet, i) => {
|
||||||
|
if (!controlNet.isEnabled) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!controlNet.model) {
|
||||||
|
reasons.push(`ControlNet ${i + 1} has no model selected.`);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
!controlNet.controlImage ||
|
||||||
|
(!controlNet.processedControlImage &&
|
||||||
|
controlNet.processorType !== 'none')
|
||||||
|
) {
|
||||||
|
reasons.push(`ControlNet ${i + 1} has no control image`);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return { isReady: !reasons.length, isProcessing, reasons };
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
export const useIsReadyToInvoke = () => {
|
export const useIsReadyToInvoke = () => {
|
||||||
const { isReady } = useAppSelector(readinessSelector);
|
const { isReady, isProcessing, reasons } = useAppSelector(selector);
|
||||||
return isReady;
|
return { isReady, isProcessing, reasons };
|
||||||
};
|
};
|
||||||
|
@ -11,8 +11,14 @@ export default function useResolution():
|
|||||||
const tabletResolutions = ['md', 'lg'];
|
const tabletResolutions = ['md', 'lg'];
|
||||||
const desktopResolutions = ['xl', '2xl'];
|
const desktopResolutions = ['xl', '2xl'];
|
||||||
|
|
||||||
if (mobileResolutions.includes(breakpointValue)) return 'mobile';
|
if (mobileResolutions.includes(breakpointValue)) {
|
||||||
if (tabletResolutions.includes(breakpointValue)) return 'tablet';
|
return 'mobile';
|
||||||
if (desktopResolutions.includes(breakpointValue)) return 'desktop';
|
}
|
||||||
|
if (tabletResolutions.includes(breakpointValue)) {
|
||||||
|
return 'tablet';
|
||||||
|
}
|
||||||
|
if (desktopResolutions.includes(breakpointValue)) {
|
||||||
|
return 'desktop';
|
||||||
|
}
|
||||||
return 'unknown';
|
return 'unknown';
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,2 @@
|
|||||||
|
export const colorTokenToCssVar = (colorToken: string) =>
|
||||||
|
`var(--invokeai-colors-${colorToken.split('.').join('-')}`;
|
@ -6,7 +6,11 @@ export const dateComparator = (a: string, b: string) => {
|
|||||||
const dateB = new Date(b);
|
const dateB = new Date(b);
|
||||||
|
|
||||||
// sort in ascending order
|
// sort in ascending order
|
||||||
if (dateA > dateB) return 1;
|
if (dateA > dateB) {
|
||||||
if (dateA < dateB) return -1;
|
return 1;
|
||||||
|
}
|
||||||
|
if (dateA < dateB) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
return 0;
|
return 0;
|
||||||
};
|
};
|
||||||
|
@ -5,7 +5,9 @@ type Base64AndCaption = {
|
|||||||
|
|
||||||
const openBase64ImageInTab = (images: Base64AndCaption[]) => {
|
const openBase64ImageInTab = (images: Base64AndCaption[]) => {
|
||||||
const w = window.open('');
|
const w = window.open('');
|
||||||
if (!w) return;
|
if (!w) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
images.forEach((i) => {
|
images.forEach((i) => {
|
||||||
const image = new Image();
|
const image = new Image();
|
||||||
|
@ -5,6 +5,7 @@ import { clearCanvasHistory } from 'features/canvas/store/canvasSlice';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaTrash } from 'react-icons/fa';
|
import { FaTrash } from 'react-icons/fa';
|
||||||
import { isStagingSelector } from '../store/canvasSelectors';
|
import { isStagingSelector } from '../store/canvasSelectors';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
const ClearCanvasHistoryButtonModal = () => {
|
const ClearCanvasHistoryButtonModal = () => {
|
||||||
const isStaging = useAppSelector(isStagingSelector);
|
const isStaging = useAppSelector(isStagingSelector);
|
||||||
@ -28,4 +29,4 @@ const ClearCanvasHistoryButtonModal = () => {
|
|||||||
</IAIAlertDialog>
|
</IAIAlertDialog>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
export default ClearCanvasHistoryButtonModal;
|
export default memo(ClearCanvasHistoryButtonModal);
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import { Box, chakra, Flex } from '@chakra-ui/react';
|
import { Box, chakra, Flex } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import {
|
import {
|
||||||
canvasSelector,
|
canvasSelector,
|
||||||
@ -9,7 +9,7 @@ import {
|
|||||||
import Konva from 'konva';
|
import Konva from 'konva';
|
||||||
import { KonvaEventObject } from 'konva/lib/Node';
|
import { KonvaEventObject } from 'konva/lib/Node';
|
||||||
import { Vector2d } from 'konva/lib/types';
|
import { Vector2d } from 'konva/lib/types';
|
||||||
import { useCallback, useRef } from 'react';
|
import { memo, useCallback, useEffect, useRef } from 'react';
|
||||||
import { Layer, Stage } from 'react-konva';
|
import { Layer, Stage } from 'react-konva';
|
||||||
import useCanvasDragMove from '../hooks/useCanvasDragMove';
|
import useCanvasDragMove from '../hooks/useCanvasDragMove';
|
||||||
import useCanvasHotkeys from '../hooks/useCanvasHotkeys';
|
import useCanvasHotkeys from '../hooks/useCanvasHotkeys';
|
||||||
@ -18,6 +18,7 @@ import useCanvasMouseMove from '../hooks/useCanvasMouseMove';
|
|||||||
import useCanvasMouseOut from '../hooks/useCanvasMouseOut';
|
import useCanvasMouseOut from '../hooks/useCanvasMouseOut';
|
||||||
import useCanvasMouseUp from '../hooks/useCanvasMouseUp';
|
import useCanvasMouseUp from '../hooks/useCanvasMouseUp';
|
||||||
import useCanvasWheel from '../hooks/useCanvasZoom';
|
import useCanvasWheel from '../hooks/useCanvasZoom';
|
||||||
|
import { canvasResized } from '../store/canvasSlice';
|
||||||
import {
|
import {
|
||||||
setCanvasBaseLayer,
|
setCanvasBaseLayer,
|
||||||
setCanvasStage,
|
setCanvasStage,
|
||||||
@ -106,7 +107,8 @@ const IAICanvas = () => {
|
|||||||
shouldAntialias,
|
shouldAntialias,
|
||||||
} = useAppSelector(selector);
|
} = useAppSelector(selector);
|
||||||
useCanvasHotkeys();
|
useCanvasHotkeys();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const containerRef = useRef<HTMLDivElement>(null);
|
||||||
const stageRef = useRef<Konva.Stage | null>(null);
|
const stageRef = useRef<Konva.Stage | null>(null);
|
||||||
const canvasBaseLayerRef = useRef<Konva.Layer | null>(null);
|
const canvasBaseLayerRef = useRef<Konva.Layer | null>(null);
|
||||||
|
|
||||||
@ -137,8 +139,30 @@ const IAICanvas = () => {
|
|||||||
const { handleDragStart, handleDragMove, handleDragEnd } =
|
const { handleDragStart, handleDragMove, handleDragEnd } =
|
||||||
useCanvasDragMove();
|
useCanvasDragMove();
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!containerRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const resizeObserver = new ResizeObserver((entries) => {
|
||||||
|
for (const entry of entries) {
|
||||||
|
if (entry.contentBoxSize) {
|
||||||
|
const { width, height } = entry.contentRect;
|
||||||
|
dispatch(canvasResized({ width, height }));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
resizeObserver.observe(containerRef.current);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
resizeObserver.disconnect();
|
||||||
|
};
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
|
id="canvas-container"
|
||||||
|
ref={containerRef}
|
||||||
sx={{
|
sx={{
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
height: '100%',
|
height: '100%',
|
||||||
@ -146,13 +170,18 @@ const IAICanvas = () => {
|
|||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Box sx={{ position: 'relative' }}>
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
// top: 0,
|
||||||
|
// insetInlineStart: 0,
|
||||||
|
}}
|
||||||
|
>
|
||||||
<ChakraStage
|
<ChakraStage
|
||||||
tabIndex={-1}
|
tabIndex={-1}
|
||||||
ref={canvasStageRefCallback}
|
ref={canvasStageRefCallback}
|
||||||
sx={{
|
sx={{
|
||||||
outline: 'none',
|
outline: 'none',
|
||||||
// boxShadow: '0px 0px 0px 1px var(--border-color-light)',
|
|
||||||
overflow: 'hidden',
|
overflow: 'hidden',
|
||||||
cursor: stageCursor ? stageCursor : undefined,
|
cursor: stageCursor ? stageCursor : undefined,
|
||||||
canvas: {
|
canvas: {
|
||||||
@ -213,11 +242,11 @@ const IAICanvas = () => {
|
|||||||
/>
|
/>
|
||||||
</Layer>
|
</Layer>
|
||||||
</ChakraStage>
|
</ChakraStage>
|
||||||
|
</Box>
|
||||||
<IAICanvasStatusText />
|
<IAICanvasStatusText />
|
||||||
<IAICanvasStagingAreaToolbar />
|
<IAICanvasStagingAreaToolbar />
|
||||||
</Box>
|
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvas;
|
export default memo(IAICanvas);
|
||||||
|
@ -4,6 +4,7 @@ import { isEqual } from 'lodash-es';
|
|||||||
|
|
||||||
import { Group, Rect } from 'react-konva';
|
import { Group, Rect } from 'react-konva';
|
||||||
import { canvasSelector } from '../store/canvasSelectors';
|
import { canvasSelector } from '../store/canvasSelectors';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
canvasSelector,
|
canvasSelector,
|
||||||
@ -67,4 +68,4 @@ const IAICanvasBoundingBoxOverlay = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasBoundingBoxOverlay;
|
export default memo(IAICanvasBoundingBoxOverlay);
|
||||||
|
@ -6,7 +6,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
import { isEqual, range } from 'lodash-es';
|
import { isEqual, range } from 'lodash-es';
|
||||||
|
|
||||||
import { ReactNode, useCallback, useLayoutEffect, useState } from 'react';
|
import { ReactNode, memo, useCallback, useLayoutEffect, useState } from 'react';
|
||||||
import { Group, Line as KonvaLine } from 'react-konva';
|
import { Group, Line as KonvaLine } from 'react-konva';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
@ -117,4 +117,4 @@ const IAICanvasGrid = () => {
|
|||||||
return <Group>{gridLines}</Group>;
|
return <Group>{gridLines}</Group>;
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasGrid;
|
export default memo(IAICanvasGrid);
|
||||||
|
@ -4,6 +4,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
|||||||
import useImage from 'use-image';
|
import useImage from 'use-image';
|
||||||
import { CanvasImage } from '../store/canvasTypes';
|
import { CanvasImage } from '../store/canvasTypes';
|
||||||
import { $authToken } from 'services/api/client';
|
import { $authToken } from 'services/api/client';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
type IAICanvasImageProps = {
|
type IAICanvasImageProps = {
|
||||||
canvasImage: CanvasImage;
|
canvasImage: CanvasImage;
|
||||||
@ -25,4 +26,4 @@ const IAICanvasImage = (props: IAICanvasImageProps) => {
|
|||||||
return <Image x={x} y={y} image={image} listening={false} />;
|
return <Image x={x} y={y} image={image} listening={false} />;
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasImage;
|
export default memo(IAICanvasImage);
|
||||||
|
@ -4,7 +4,7 @@ import { systemSelector } from 'features/system/store/systemSelectors';
|
|||||||
import { ImageConfig } from 'konva/lib/shapes/Image';
|
import { ImageConfig } from 'konva/lib/shapes/Image';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
|
||||||
import { useEffect, useState } from 'react';
|
import { memo, useEffect, useState } from 'react';
|
||||||
import { Image as KonvaImage } from 'react-konva';
|
import { Image as KonvaImage } from 'react-konva';
|
||||||
import { canvasSelector } from '../store/canvasSelectors';
|
import { canvasSelector } from '../store/canvasSelectors';
|
||||||
|
|
||||||
@ -66,4 +66,4 @@ const IAICanvasIntermediateImage = (props: Props) => {
|
|||||||
) : null;
|
) : null;
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasIntermediateImage;
|
export default memo(IAICanvasIntermediateImage);
|
||||||
|
@ -7,7 +7,7 @@ import { Rect } from 'react-konva';
|
|||||||
import { rgbaColorToString } from 'features/canvas/util/colorToString';
|
import { rgbaColorToString } from 'features/canvas/util/colorToString';
|
||||||
import Konva from 'konva';
|
import Konva from 'konva';
|
||||||
import { isNumber } from 'lodash-es';
|
import { isNumber } from 'lodash-es';
|
||||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
import { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||||
|
|
||||||
export const canvasMaskCompositerSelector = createSelector(
|
export const canvasMaskCompositerSelector = createSelector(
|
||||||
canvasSelector,
|
canvasSelector,
|
||||||
@ -125,7 +125,9 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
|
|||||||
}, [offset]);
|
}, [offset]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (fillPatternImage) return;
|
if (fillPatternImage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const image = new Image();
|
const image = new Image();
|
||||||
|
|
||||||
image.onload = () => {
|
image.onload = () => {
|
||||||
@ -135,7 +137,9 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
|
|||||||
}, [fillPatternImage, maskColorString]);
|
}, [fillPatternImage, maskColorString]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!fillPatternImage) return;
|
if (!fillPatternImage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
fillPatternImage.src = getColoredSVG(maskColorString);
|
fillPatternImage.src = getColoredSVG(maskColorString);
|
||||||
}, [fillPatternImage, maskColorString]);
|
}, [fillPatternImage, maskColorString]);
|
||||||
|
|
||||||
@ -151,8 +155,9 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
|
|||||||
!isNumber(stageScale) ||
|
!isNumber(stageScale) ||
|
||||||
!isNumber(stageDimensions.width) ||
|
!isNumber(stageDimensions.width) ||
|
||||||
!isNumber(stageDimensions.height)
|
!isNumber(stageDimensions.height)
|
||||||
)
|
) {
|
||||||
return null;
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Rect
|
<Rect
|
||||||
@ -172,4 +177,4 @@ const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasMaskCompositer;
|
export default memo(IAICanvasMaskCompositer);
|
||||||
|
@ -6,6 +6,7 @@ import { isEqual } from 'lodash-es';
|
|||||||
|
|
||||||
import { Group, Line } from 'react-konva';
|
import { Group, Line } from 'react-konva';
|
||||||
import { isCanvasMaskLine } from '../store/canvasTypes';
|
import { isCanvasMaskLine } from '../store/canvasTypes';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
export const canvasLinesSelector = createSelector(
|
export const canvasLinesSelector = createSelector(
|
||||||
[canvasSelector],
|
[canvasSelector],
|
||||||
@ -52,4 +53,4 @@ const IAICanvasLines = (props: InpaintingCanvasLinesProps) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasLines;
|
export default memo(IAICanvasLines);
|
||||||
|
@ -12,6 +12,7 @@ import {
|
|||||||
isCanvasFillRect,
|
isCanvasFillRect,
|
||||||
} from '../store/canvasTypes';
|
} from '../store/canvasTypes';
|
||||||
import IAICanvasImage from './IAICanvasImage';
|
import IAICanvasImage from './IAICanvasImage';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[canvasSelector],
|
[canvasSelector],
|
||||||
@ -33,7 +34,9 @@ const selector = createSelector(
|
|||||||
const IAICanvasObjectRenderer = () => {
|
const IAICanvasObjectRenderer = () => {
|
||||||
const { objects } = useAppSelector(selector);
|
const { objects } = useAppSelector(selector);
|
||||||
|
|
||||||
if (!objects) return null;
|
if (!objects) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Group name="outpainting-objects" listening={false}>
|
<Group name="outpainting-objects" listening={false}>
|
||||||
@ -101,4 +104,4 @@ const IAICanvasObjectRenderer = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasObjectRenderer;
|
export default memo(IAICanvasObjectRenderer);
|
||||||
|
@ -1,89 +0,0 @@
|
|||||||
import { Flex, Spinner } from '@chakra-ui/react';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import {
|
|
||||||
canvasSelector,
|
|
||||||
initialCanvasImageSelector,
|
|
||||||
} from 'features/canvas/store/canvasSelectors';
|
|
||||||
import {
|
|
||||||
resizeAndScaleCanvas,
|
|
||||||
resizeCanvas,
|
|
||||||
setCanvasContainerDimensions,
|
|
||||||
setDoesCanvasNeedScaling,
|
|
||||||
} from 'features/canvas/store/canvasSlice';
|
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
|
||||||
import { useLayoutEffect, useRef } from 'react';
|
|
||||||
|
|
||||||
const canvasResizerSelector = createSelector(
|
|
||||||
canvasSelector,
|
|
||||||
initialCanvasImageSelector,
|
|
||||||
activeTabNameSelector,
|
|
||||||
(canvas, initialCanvasImage, activeTabName) => {
|
|
||||||
const { doesCanvasNeedScaling, isCanvasInitialized } = canvas;
|
|
||||||
return {
|
|
||||||
doesCanvasNeedScaling,
|
|
||||||
activeTabName,
|
|
||||||
initialCanvasImage,
|
|
||||||
isCanvasInitialized,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const IAICanvasResizer = () => {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const {
|
|
||||||
doesCanvasNeedScaling,
|
|
||||||
activeTabName,
|
|
||||||
initialCanvasImage,
|
|
||||||
isCanvasInitialized,
|
|
||||||
} = useAppSelector(canvasResizerSelector);
|
|
||||||
|
|
||||||
const ref = useRef<HTMLDivElement>(null);
|
|
||||||
|
|
||||||
useLayoutEffect(() => {
|
|
||||||
window.setTimeout(() => {
|
|
||||||
if (!ref.current) return;
|
|
||||||
|
|
||||||
const { clientWidth, clientHeight } = ref.current;
|
|
||||||
|
|
||||||
dispatch(
|
|
||||||
setCanvasContainerDimensions({
|
|
||||||
width: clientWidth,
|
|
||||||
height: clientHeight,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!isCanvasInitialized) {
|
|
||||||
dispatch(resizeAndScaleCanvas());
|
|
||||||
} else {
|
|
||||||
dispatch(resizeCanvas());
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch(setDoesCanvasNeedScaling(false));
|
|
||||||
}, 0);
|
|
||||||
}, [
|
|
||||||
dispatch,
|
|
||||||
initialCanvasImage,
|
|
||||||
doesCanvasNeedScaling,
|
|
||||||
activeTabName,
|
|
||||||
isCanvasInitialized,
|
|
||||||
]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex
|
|
||||||
ref={ref}
|
|
||||||
sx={{
|
|
||||||
flexDirection: 'column',
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'center',
|
|
||||||
gap: 4,
|
|
||||||
width: '100%',
|
|
||||||
height: '100%',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Spinner thickness="2px" size="xl" />
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default IAICanvasResizer;
|
|
@ -6,6 +6,7 @@ import { isEqual } from 'lodash-es';
|
|||||||
|
|
||||||
import { Group, Rect } from 'react-konva';
|
import { Group, Rect } from 'react-konva';
|
||||||
import IAICanvasImage from './IAICanvasImage';
|
import IAICanvasImage from './IAICanvasImage';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[canvasSelector],
|
[canvasSelector],
|
||||||
@ -88,4 +89,4 @@ const IAICanvasStagingArea = (props: Props) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasStagingArea;
|
export default memo(IAICanvasStagingArea);
|
||||||
|
@ -13,7 +13,7 @@ import {
|
|||||||
} from 'features/canvas/store/canvasSlice';
|
} from 'features/canvas/store/canvasSlice';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import {
|
import {
|
||||||
@ -129,7 +129,9 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
currentStagingAreaImage?.imageName ?? skipToken
|
currentStagingAreaImage?.imageName ?? skipToken
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!currentStagingAreaImage) return null;
|
if (!currentStagingAreaImage) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
@ -138,11 +140,10 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
w="100%"
|
w="100%"
|
||||||
align="center"
|
align="center"
|
||||||
justify="center"
|
justify="center"
|
||||||
filter="drop-shadow(0 0.5rem 1rem rgba(0,0,0))"
|
|
||||||
onMouseOver={handleMouseOver}
|
onMouseOver={handleMouseOver}
|
||||||
onMouseOut={handleMouseOut}
|
onMouseOut={handleMouseOut}
|
||||||
>
|
>
|
||||||
<ButtonGroup isAttached>
|
<ButtonGroup isAttached borderRadius="base" shadow="dark-lg">
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
tooltip={`${t('unifiedCanvas.previous')} (Left)`}
|
tooltip={`${t('unifiedCanvas.previous')} (Left)`}
|
||||||
aria-label={`${t('unifiedCanvas.previous')} (Left)`}
|
aria-label={`${t('unifiedCanvas.previous')} (Left)`}
|
||||||
@ -207,4 +208,4 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasStagingAreaToolbar;
|
export default memo(IAICanvasStagingAreaToolbar);
|
||||||
|
@ -7,6 +7,7 @@ import { isEqual } from 'lodash-es';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import roundToHundreth from '../util/roundToHundreth';
|
import roundToHundreth from '../util/roundToHundreth';
|
||||||
import IAICanvasStatusTextCursorPos from './IAICanvasStatusText/IAICanvasStatusTextCursorPos';
|
import IAICanvasStatusTextCursorPos from './IAICanvasStatusText/IAICanvasStatusTextCursorPos';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
const warningColor = 'var(--invokeai-colors-warning-500)';
|
const warningColor = 'var(--invokeai-colors-warning-500)';
|
||||||
|
|
||||||
@ -162,4 +163,4 @@ const IAICanvasStatusText = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasStatusText;
|
export default memo(IAICanvasStatusText);
|
||||||
|
@ -10,6 +10,7 @@ import {
|
|||||||
COLOR_PICKER_SIZE,
|
COLOR_PICKER_SIZE,
|
||||||
COLOR_PICKER_STROKE_RADIUS,
|
COLOR_PICKER_STROKE_RADIUS,
|
||||||
} from '../util/constants';
|
} from '../util/constants';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
const canvasBrushPreviewSelector = createSelector(
|
const canvasBrushPreviewSelector = createSelector(
|
||||||
canvasSelector,
|
canvasSelector,
|
||||||
@ -134,7 +135,9 @@ const IAICanvasToolPreview = (props: GroupConfig) => {
|
|||||||
clip,
|
clip,
|
||||||
} = useAppSelector(canvasBrushPreviewSelector);
|
} = useAppSelector(canvasBrushPreviewSelector);
|
||||||
|
|
||||||
if (!shouldDrawBrushPreview) return null;
|
if (!shouldDrawBrushPreview) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Group listening={false} {...clip} {...rest}>
|
<Group listening={false} {...clip} {...rest}>
|
||||||
@ -206,4 +209,4 @@ const IAICanvasToolPreview = (props: GroupConfig) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasToolPreview;
|
export default memo(IAICanvasToolPreview);
|
||||||
|
@ -19,7 +19,7 @@ import { KonvaEventObject } from 'konva/lib/Node';
|
|||||||
import { Vector2d } from 'konva/lib/types';
|
import { Vector2d } from 'konva/lib/types';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
|
||||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
import { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { Group, Rect, Transformer } from 'react-konva';
|
import { Group, Rect, Transformer } from 'react-konva';
|
||||||
|
|
||||||
@ -85,7 +85,9 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
|
|||||||
useState(false);
|
useState(false);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!transformerRef.current || !shapeRef.current) return;
|
if (!transformerRef.current || !shapeRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
transformerRef.current.nodes([shapeRef.current]);
|
transformerRef.current.nodes([shapeRef.current]);
|
||||||
transformerRef.current.getLayer()?.batchDraw();
|
transformerRef.current.getLayer()?.batchDraw();
|
||||||
}, []);
|
}, []);
|
||||||
@ -133,7 +135,9 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
|
|||||||
* not its width and height. We need to un-scale the width and height before
|
* not its width and height. We need to un-scale the width and height before
|
||||||
* setting the values.
|
* setting the values.
|
||||||
*/
|
*/
|
||||||
if (!shapeRef.current) return;
|
if (!shapeRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const rect = shapeRef.current;
|
const rect = shapeRef.current;
|
||||||
|
|
||||||
@ -313,4 +317,4 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasBoundingBox;
|
export default memo(IAICanvasBoundingBox);
|
||||||
|
@ -20,6 +20,7 @@ import {
|
|||||||
} from 'features/canvas/store/canvasSlice';
|
} from 'features/canvas/store/canvasSlice';
|
||||||
import { rgbaColorToString } from 'features/canvas/util/colorToString';
|
import { rgbaColorToString } from 'features/canvas/util/colorToString';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -150,4 +151,4 @@ const IAICanvasMaskOptions = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasMaskOptions;
|
export default memo(IAICanvasMaskOptions);
|
||||||
|
@ -18,7 +18,7 @@ import {
|
|||||||
} from 'features/canvas/store/canvasSlice';
|
} from 'features/canvas/store/canvasSlice';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
|
||||||
import { ChangeEvent } from 'react';
|
import { ChangeEvent, memo } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaWrench } from 'react-icons/fa';
|
import { FaWrench } from 'react-icons/fa';
|
||||||
@ -163,4 +163,4 @@ const IAICanvasSettingsButtonPopover = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasSettingsButtonPopover;
|
export default memo(IAICanvasSettingsButtonPopover);
|
||||||
|
@ -18,6 +18,7 @@ import {
|
|||||||
} from 'features/canvas/store/canvasSlice';
|
} from 'features/canvas/store/canvasSlice';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||||
import { clamp, isEqual } from 'lodash-es';
|
import { clamp, isEqual } from 'lodash-es';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -252,4 +253,4 @@ const IAICanvasToolChooserOptions = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasToolChooserOptions;
|
export default memo(IAICanvasToolChooserOptions);
|
||||||
|
@ -18,7 +18,6 @@ import {
|
|||||||
import {
|
import {
|
||||||
resetCanvas,
|
resetCanvas,
|
||||||
resetCanvasView,
|
resetCanvasView,
|
||||||
resizeAndScaleCanvas,
|
|
||||||
setIsMaskEnabled,
|
setIsMaskEnabled,
|
||||||
setLayer,
|
setLayer,
|
||||||
setTool,
|
setTool,
|
||||||
@ -48,6 +47,7 @@ import IAICanvasRedoButton from './IAICanvasRedoButton';
|
|||||||
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
|
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
|
||||||
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
|
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
|
||||||
import IAICanvasUndoButton from './IAICanvasUndoButton';
|
import IAICanvasUndoButton from './IAICanvasUndoButton';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
export const selector = createSelector(
|
export const selector = createSelector(
|
||||||
[systemSelector, canvasSelector, isStagingSelector],
|
[systemSelector, canvasSelector, isStagingSelector],
|
||||||
@ -166,7 +166,9 @@ const IAICanvasToolbar = () => {
|
|||||||
|
|
||||||
const handleResetCanvasView = (shouldScaleTo1 = false) => {
|
const handleResetCanvasView = (shouldScaleTo1 = false) => {
|
||||||
const canvasBaseLayer = getCanvasBaseLayer();
|
const canvasBaseLayer = getCanvasBaseLayer();
|
||||||
if (!canvasBaseLayer) return;
|
if (!canvasBaseLayer) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const clientRect = canvasBaseLayer.getClientRect({
|
const clientRect = canvasBaseLayer.getClientRect({
|
||||||
skipTransform: true,
|
skipTransform: true,
|
||||||
});
|
});
|
||||||
@ -180,7 +182,6 @@ const IAICanvasToolbar = () => {
|
|||||||
|
|
||||||
const handleResetCanvas = () => {
|
const handleResetCanvas = () => {
|
||||||
dispatch(resetCanvas());
|
dispatch(resetCanvas());
|
||||||
dispatch(resizeAndScaleCanvas());
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleMergeVisible = () => {
|
const handleMergeVisible = () => {
|
||||||
@ -309,4 +310,4 @@ const IAICanvasToolbar = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default IAICanvasToolbar;
|
export default memo(IAICanvasToolbar);
|
||||||
|
@ -32,13 +32,17 @@ const useCanvasDrag = () => {
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
handleDragStart: useCallback(() => {
|
handleDragStart: useCallback(() => {
|
||||||
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) return;
|
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
dispatch(setIsMovingStage(true));
|
dispatch(setIsMovingStage(true));
|
||||||
}, [dispatch, isMovingBoundingBox, isStaging, tool]),
|
}, [dispatch, isMovingBoundingBox, isStaging, tool]),
|
||||||
|
|
||||||
handleDragMove: useCallback(
|
handleDragMove: useCallback(
|
||||||
(e: KonvaEventObject<MouseEvent>) => {
|
(e: KonvaEventObject<MouseEvent>) => {
|
||||||
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) return;
|
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const newCoordinates = { x: e.target.x(), y: e.target.y() };
|
const newCoordinates = { x: e.target.x(), y: e.target.y() };
|
||||||
|
|
||||||
@ -48,7 +52,9 @@ const useCanvasDrag = () => {
|
|||||||
),
|
),
|
||||||
|
|
||||||
handleDragEnd: useCallback(() => {
|
handleDragEnd: useCallback(() => {
|
||||||
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) return;
|
if (!((tool === 'move' || isStaging) && !isMovingBoundingBox)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
dispatch(setIsMovingStage(false));
|
dispatch(setIsMovingStage(false));
|
||||||
}, [dispatch, isMovingBoundingBox, isStaging, tool]),
|
}, [dispatch, isMovingBoundingBox, isStaging, tool]),
|
||||||
};
|
};
|
||||||
|
@ -134,7 +134,9 @@ const useInpaintingCanvasHotkeys = () => {
|
|||||||
useHotkeys(
|
useHotkeys(
|
||||||
['space'],
|
['space'],
|
||||||
(e: KeyboardEvent) => {
|
(e: KeyboardEvent) => {
|
||||||
if (e.repeat) return;
|
if (e.repeat) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
canvasStage?.container().focus();
|
canvasStage?.container().focus();
|
||||||
|
|
||||||
|
@ -38,7 +38,9 @@ const useCanvasMouseDown = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
|||||||
|
|
||||||
return useCallback(
|
return useCallback(
|
||||||
(e: KonvaEventObject<MouseEvent | TouchEvent>) => {
|
(e: KonvaEventObject<MouseEvent | TouchEvent>) => {
|
||||||
if (!stageRef.current) return;
|
if (!stageRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
stageRef.current.container().focus();
|
stageRef.current.container().focus();
|
||||||
|
|
||||||
@ -54,7 +56,9 @@ const useCanvasMouseDown = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
|||||||
|
|
||||||
const scaledCursorPosition = getScaledCursorPosition(stageRef.current);
|
const scaledCursorPosition = getScaledCursorPosition(stageRef.current);
|
||||||
|
|
||||||
if (!scaledCursorPosition) return;
|
if (!scaledCursorPosition) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
e.evt.preventDefault();
|
e.evt.preventDefault();
|
||||||
|
|
||||||
|
@ -41,11 +41,15 @@ const useCanvasMouseMove = (
|
|||||||
const { updateColorUnderCursor } = useColorPicker();
|
const { updateColorUnderCursor } = useColorPicker();
|
||||||
|
|
||||||
return useCallback(() => {
|
return useCallback(() => {
|
||||||
if (!stageRef.current) return;
|
if (!stageRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const scaledCursorPosition = getScaledCursorPosition(stageRef.current);
|
const scaledCursorPosition = getScaledCursorPosition(stageRef.current);
|
||||||
|
|
||||||
if (!scaledCursorPosition) return;
|
if (!scaledCursorPosition) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(setCursorPosition(scaledCursorPosition));
|
dispatch(setCursorPosition(scaledCursorPosition));
|
||||||
|
|
||||||
@ -56,7 +60,9 @@ const useCanvasMouseMove = (
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isDrawing || tool === 'move' || isStaging) return;
|
if (!isDrawing || tool === 'move' || isStaging) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
didMouseMoveRef.current = true;
|
didMouseMoveRef.current = true;
|
||||||
dispatch(
|
dispatch(
|
||||||
|
@ -47,7 +47,9 @@ const useCanvasMouseUp = (
|
|||||||
if (!didMouseMoveRef.current && isDrawing && stageRef.current) {
|
if (!didMouseMoveRef.current && isDrawing && stageRef.current) {
|
||||||
const scaledCursorPosition = getScaledCursorPosition(stageRef.current);
|
const scaledCursorPosition = getScaledCursorPosition(stageRef.current);
|
||||||
|
|
||||||
if (!scaledCursorPosition) return;
|
if (!scaledCursorPosition) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Extend the current line.
|
* Extend the current line.
|
||||||
|
@ -35,13 +35,17 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
|||||||
return useCallback(
|
return useCallback(
|
||||||
(e: KonvaEventObject<WheelEvent>) => {
|
(e: KonvaEventObject<WheelEvent>) => {
|
||||||
// stop default scrolling
|
// stop default scrolling
|
||||||
if (!stageRef.current || isMoveStageKeyHeld) return;
|
if (!stageRef.current || isMoveStageKeyHeld) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
e.evt.preventDefault();
|
e.evt.preventDefault();
|
||||||
|
|
||||||
const cursorPos = stageRef.current.getPointerPosition();
|
const cursorPos = stageRef.current.getPointerPosition();
|
||||||
|
|
||||||
if (!cursorPos) return;
|
if (!cursorPos) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const mousePointTo = {
|
const mousePointTo = {
|
||||||
x: (cursorPos.x - stageRef.current.x()) / stageScale,
|
x: (cursorPos.x - stageRef.current.x()) / stageScale,
|
||||||
|
@ -16,11 +16,15 @@ const useColorPicker = () => {
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
updateColorUnderCursor: () => {
|
updateColorUnderCursor: () => {
|
||||||
if (!stage || !canvasBaseLayer) return;
|
if (!stage || !canvasBaseLayer) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const position = stage.getPointerPosition();
|
const position = stage.getPointerPosition();
|
||||||
|
|
||||||
if (!position) return;
|
if (!position) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const pixelRatio = Konva.pixelRatio;
|
const pixelRatio = Konva.pixelRatio;
|
||||||
|
|
||||||
|
@ -3,8 +3,4 @@ import { CanvasState } from './canvasTypes';
|
|||||||
/**
|
/**
|
||||||
* Canvas slice persist denylist
|
* Canvas slice persist denylist
|
||||||
*/
|
*/
|
||||||
export const canvasPersistDenylist: (keyof CanvasState)[] = [
|
export const canvasPersistDenylist: (keyof CanvasState)[] = ['cursorPosition'];
|
||||||
'cursorPosition',
|
|
||||||
'isCanvasInitialized',
|
|
||||||
'doesCanvasNeedScaling',
|
|
||||||
];
|
|
||||||
|
@ -5,10 +5,6 @@ import {
|
|||||||
roundToMultiple,
|
roundToMultiple,
|
||||||
} from 'common/util/roundDownToMultiple';
|
} from 'common/util/roundDownToMultiple';
|
||||||
import { setAspectRatio } from 'features/parameters/store/generationSlice';
|
import { setAspectRatio } from 'features/parameters/store/generationSlice';
|
||||||
import {
|
|
||||||
setActiveTab,
|
|
||||||
setShouldUseCanvasBetaLayout,
|
|
||||||
} from 'features/ui/store/uiSlice';
|
|
||||||
import { IRect, Vector2d } from 'konva/lib/types';
|
import { IRect, Vector2d } from 'konva/lib/types';
|
||||||
import { clamp, cloneDeep } from 'lodash-es';
|
import { clamp, cloneDeep } from 'lodash-es';
|
||||||
import { RgbaColor } from 'react-colorful';
|
import { RgbaColor } from 'react-colorful';
|
||||||
@ -50,12 +46,9 @@ export const initialCanvasState: CanvasState = {
|
|||||||
boundingBoxScaleMethod: 'none',
|
boundingBoxScaleMethod: 'none',
|
||||||
brushColor: { r: 90, g: 90, b: 255, a: 1 },
|
brushColor: { r: 90, g: 90, b: 255, a: 1 },
|
||||||
brushSize: 50,
|
brushSize: 50,
|
||||||
canvasContainerDimensions: { width: 0, height: 0 },
|
|
||||||
colorPickerColor: { r: 90, g: 90, b: 255, a: 1 },
|
colorPickerColor: { r: 90, g: 90, b: 255, a: 1 },
|
||||||
cursorPosition: null,
|
cursorPosition: null,
|
||||||
doesCanvasNeedScaling: false,
|
|
||||||
futureLayerStates: [],
|
futureLayerStates: [],
|
||||||
isCanvasInitialized: false,
|
|
||||||
isDrawing: false,
|
isDrawing: false,
|
||||||
isMaskEnabled: true,
|
isMaskEnabled: true,
|
||||||
isMouseOverBoundingBox: false,
|
isMouseOverBoundingBox: false,
|
||||||
@ -208,7 +201,6 @@ export const canvasSlice = createSlice({
|
|||||||
};
|
};
|
||||||
state.futureLayerStates = [];
|
state.futureLayerStates = [];
|
||||||
|
|
||||||
state.isCanvasInitialized = false;
|
|
||||||
const newScale = calculateScale(
|
const newScale = calculateScale(
|
||||||
stageDimensions.width,
|
stageDimensions.width,
|
||||||
stageDimensions.height,
|
stageDimensions.height,
|
||||||
@ -228,7 +220,6 @@ export const canvasSlice = createSlice({
|
|||||||
);
|
);
|
||||||
state.stageScale = newScale;
|
state.stageScale = newScale;
|
||||||
state.stageCoordinates = newCoordinates;
|
state.stageCoordinates = newCoordinates;
|
||||||
state.doesCanvasNeedScaling = true;
|
|
||||||
},
|
},
|
||||||
setBoundingBoxDimensions: (state, action: PayloadAction<Dimensions>) => {
|
setBoundingBoxDimensions: (state, action: PayloadAction<Dimensions>) => {
|
||||||
const newDimensions = roundDimensionsTo64(action.payload);
|
const newDimensions = roundDimensionsTo64(action.payload);
|
||||||
@ -258,9 +249,6 @@ export const canvasSlice = createSlice({
|
|||||||
setBoundingBoxPreviewFill: (state, action: PayloadAction<RgbaColor>) => {
|
setBoundingBoxPreviewFill: (state, action: PayloadAction<RgbaColor>) => {
|
||||||
state.boundingBoxPreviewFill = action.payload;
|
state.boundingBoxPreviewFill = action.payload;
|
||||||
},
|
},
|
||||||
setDoesCanvasNeedScaling: (state, action: PayloadAction<boolean>) => {
|
|
||||||
state.doesCanvasNeedScaling = action.payload;
|
|
||||||
},
|
|
||||||
setStageScale: (state, action: PayloadAction<number>) => {
|
setStageScale: (state, action: PayloadAction<number>) => {
|
||||||
state.stageScale = action.payload;
|
state.stageScale = action.payload;
|
||||||
},
|
},
|
||||||
@ -397,7 +385,9 @@ export const canvasSlice = createSlice({
|
|||||||
const { tool, layer, brushColor, brushSize, shouldRestrictStrokesToBox } =
|
const { tool, layer, brushColor, brushSize, shouldRestrictStrokesToBox } =
|
||||||
state;
|
state;
|
||||||
|
|
||||||
if (tool === 'move' || tool === 'colorPicker') return;
|
if (tool === 'move' || tool === 'colorPicker') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const newStrokeWidth = brushSize / 2;
|
const newStrokeWidth = brushSize / 2;
|
||||||
|
|
||||||
@ -434,14 +424,18 @@ export const canvasSlice = createSlice({
|
|||||||
addPointToCurrentLine: (state, action: PayloadAction<number[]>) => {
|
addPointToCurrentLine: (state, action: PayloadAction<number[]>) => {
|
||||||
const lastLine = state.layerState.objects.findLast(isCanvasAnyLine);
|
const lastLine = state.layerState.objects.findLast(isCanvasAnyLine);
|
||||||
|
|
||||||
if (!lastLine) return;
|
if (!lastLine) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
lastLine.points.push(...action.payload);
|
lastLine.points.push(...action.payload);
|
||||||
},
|
},
|
||||||
undo: (state) => {
|
undo: (state) => {
|
||||||
const targetState = state.pastLayerStates.pop();
|
const targetState = state.pastLayerStates.pop();
|
||||||
|
|
||||||
if (!targetState) return;
|
if (!targetState) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
state.futureLayerStates.unshift(cloneDeep(state.layerState));
|
state.futureLayerStates.unshift(cloneDeep(state.layerState));
|
||||||
|
|
||||||
@ -454,7 +448,9 @@ export const canvasSlice = createSlice({
|
|||||||
redo: (state) => {
|
redo: (state) => {
|
||||||
const targetState = state.futureLayerStates.shift();
|
const targetState = state.futureLayerStates.shift();
|
||||||
|
|
||||||
if (!targetState) return;
|
if (!targetState) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||||
|
|
||||||
@ -485,97 +481,14 @@ export const canvasSlice = createSlice({
|
|||||||
state.layerState = initialLayerState;
|
state.layerState = initialLayerState;
|
||||||
state.futureLayerStates = [];
|
state.futureLayerStates = [];
|
||||||
},
|
},
|
||||||
setCanvasContainerDimensions: (
|
canvasResized: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<Dimensions>
|
action: PayloadAction<{ width: number; height: number }>
|
||||||
) => {
|
) => {
|
||||||
state.canvasContainerDimensions = action.payload;
|
const { width, height } = action.payload;
|
||||||
},
|
|
||||||
resizeAndScaleCanvas: (state) => {
|
|
||||||
const { width: containerWidth, height: containerHeight } =
|
|
||||||
state.canvasContainerDimensions;
|
|
||||||
|
|
||||||
const initialCanvasImage =
|
|
||||||
state.layerState.objects.find(isCanvasBaseImage);
|
|
||||||
|
|
||||||
const newStageDimensions = {
|
const newStageDimensions = {
|
||||||
width: Math.floor(containerWidth),
|
width: Math.floor(width),
|
||||||
height: Math.floor(containerHeight),
|
height: Math.floor(height),
|
||||||
};
|
|
||||||
|
|
||||||
if (!initialCanvasImage) {
|
|
||||||
const newScale = calculateScale(
|
|
||||||
newStageDimensions.width,
|
|
||||||
newStageDimensions.height,
|
|
||||||
512,
|
|
||||||
512,
|
|
||||||
STAGE_PADDING_PERCENTAGE
|
|
||||||
);
|
|
||||||
|
|
||||||
const newCoordinates = calculateCoordinates(
|
|
||||||
newStageDimensions.width,
|
|
||||||
newStageDimensions.height,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
512,
|
|
||||||
512,
|
|
||||||
newScale
|
|
||||||
);
|
|
||||||
|
|
||||||
const newBoundingBoxDimensions = { width: 512, height: 512 };
|
|
||||||
|
|
||||||
state.stageScale = newScale;
|
|
||||||
state.stageCoordinates = newCoordinates;
|
|
||||||
state.stageDimensions = newStageDimensions;
|
|
||||||
state.boundingBoxCoordinates = { x: 0, y: 0 };
|
|
||||||
state.boundingBoxDimensions = newBoundingBoxDimensions;
|
|
||||||
|
|
||||||
if (state.boundingBoxScaleMethod === 'auto') {
|
|
||||||
const scaledDimensions = getScaledBoundingBoxDimensions(
|
|
||||||
newBoundingBoxDimensions
|
|
||||||
);
|
|
||||||
state.scaledBoundingBoxDimensions = scaledDimensions;
|
|
||||||
}
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const { width: imageWidth, height: imageHeight } = initialCanvasImage;
|
|
||||||
|
|
||||||
const padding = 0.95;
|
|
||||||
|
|
||||||
const newScale = calculateScale(
|
|
||||||
containerWidth,
|
|
||||||
containerHeight,
|
|
||||||
imageWidth,
|
|
||||||
imageHeight,
|
|
||||||
padding
|
|
||||||
);
|
|
||||||
|
|
||||||
const newCoordinates = calculateCoordinates(
|
|
||||||
newStageDimensions.width,
|
|
||||||
newStageDimensions.height,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
imageWidth,
|
|
||||||
imageHeight,
|
|
||||||
newScale
|
|
||||||
);
|
|
||||||
|
|
||||||
state.minimumStageScale = newScale;
|
|
||||||
state.stageScale = newScale;
|
|
||||||
state.stageCoordinates = floorCoordinates(newCoordinates);
|
|
||||||
state.stageDimensions = newStageDimensions;
|
|
||||||
|
|
||||||
state.isCanvasInitialized = true;
|
|
||||||
},
|
|
||||||
resizeCanvas: (state) => {
|
|
||||||
const { width: containerWidth, height: containerHeight } =
|
|
||||||
state.canvasContainerDimensions;
|
|
||||||
|
|
||||||
const newStageDimensions = {
|
|
||||||
width: Math.floor(containerWidth),
|
|
||||||
height: Math.floor(containerHeight),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
state.stageDimensions = newStageDimensions;
|
state.stageDimensions = newStageDimensions;
|
||||||
@ -868,14 +781,6 @@ export const canvasSlice = createSlice({
|
|||||||
state.layerState.stagingArea = initialLayerState.stagingArea;
|
state.layerState.stagingArea = initialLayerState.stagingArea;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
builder.addCase(setShouldUseCanvasBetaLayout, (state) => {
|
|
||||||
state.doesCanvasNeedScaling = true;
|
|
||||||
});
|
|
||||||
|
|
||||||
builder.addCase(setActiveTab, (state) => {
|
|
||||||
state.doesCanvasNeedScaling = true;
|
|
||||||
});
|
|
||||||
builder.addCase(setAspectRatio, (state, action) => {
|
builder.addCase(setAspectRatio, (state, action) => {
|
||||||
const ratio = action.payload;
|
const ratio = action.payload;
|
||||||
if (ratio) {
|
if (ratio) {
|
||||||
@ -907,8 +812,6 @@ export const {
|
|||||||
resetCanvas,
|
resetCanvas,
|
||||||
resetCanvasInteractionState,
|
resetCanvasInteractionState,
|
||||||
resetCanvasView,
|
resetCanvasView,
|
||||||
resizeAndScaleCanvas,
|
|
||||||
resizeCanvas,
|
|
||||||
setBoundingBoxCoordinates,
|
setBoundingBoxCoordinates,
|
||||||
setBoundingBoxDimensions,
|
setBoundingBoxDimensions,
|
||||||
setBoundingBoxPreviewFill,
|
setBoundingBoxPreviewFill,
|
||||||
@ -916,10 +819,8 @@ export const {
|
|||||||
flipBoundingBoxAxes,
|
flipBoundingBoxAxes,
|
||||||
setBrushColor,
|
setBrushColor,
|
||||||
setBrushSize,
|
setBrushSize,
|
||||||
setCanvasContainerDimensions,
|
|
||||||
setColorPickerColor,
|
setColorPickerColor,
|
||||||
setCursorPosition,
|
setCursorPosition,
|
||||||
setDoesCanvasNeedScaling,
|
|
||||||
setInitialCanvasImage,
|
setInitialCanvasImage,
|
||||||
setIsDrawing,
|
setIsDrawing,
|
||||||
setIsMaskEnabled,
|
setIsMaskEnabled,
|
||||||
@ -958,6 +859,7 @@ export const {
|
|||||||
stagingAreaInitialized,
|
stagingAreaInitialized,
|
||||||
canvasSessionIdChanged,
|
canvasSessionIdChanged,
|
||||||
setShouldAntialias,
|
setShouldAntialias,
|
||||||
|
canvasResized,
|
||||||
} = canvasSlice.actions;
|
} = canvasSlice.actions;
|
||||||
|
|
||||||
export default canvasSlice.reducer;
|
export default canvasSlice.reducer;
|
||||||
|
@ -126,12 +126,9 @@ export interface CanvasState {
|
|||||||
boundingBoxScaleMethod: BoundingBoxScale;
|
boundingBoxScaleMethod: BoundingBoxScale;
|
||||||
brushColor: RgbaColor;
|
brushColor: RgbaColor;
|
||||||
brushSize: number;
|
brushSize: number;
|
||||||
canvasContainerDimensions: Dimensions;
|
|
||||||
colorPickerColor: RgbaColor;
|
colorPickerColor: RgbaColor;
|
||||||
cursorPosition: Vector2d | null;
|
cursorPosition: Vector2d | null;
|
||||||
doesCanvasNeedScaling: boolean;
|
|
||||||
futureLayerStates: CanvasLayerState[];
|
futureLayerStates: CanvasLayerState[];
|
||||||
isCanvasInitialized: boolean;
|
|
||||||
isDrawing: boolean;
|
isDrawing: boolean;
|
||||||
isMaskEnabled: boolean;
|
isMaskEnabled: boolean;
|
||||||
isMouseOverBoundingBox: boolean;
|
isMouseOverBoundingBox: boolean;
|
||||||
|
@ -1,16 +0,0 @@
|
|||||||
import { AppDispatch, AppGetState } from 'app/store/store';
|
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
|
||||||
import { debounce } from 'lodash-es';
|
|
||||||
import { setDoesCanvasNeedScaling } from '../canvasSlice';
|
|
||||||
|
|
||||||
const debouncedCanvasScale = debounce((dispatch: AppDispatch) => {
|
|
||||||
dispatch(setDoesCanvasNeedScaling(true));
|
|
||||||
}, 300);
|
|
||||||
|
|
||||||
export const requestCanvasRescale =
|
|
||||||
() => (dispatch: AppDispatch, getState: AppGetState) => {
|
|
||||||
const activeTabName = activeTabNameSelector(getState());
|
|
||||||
if (activeTabName === 'unifiedCanvas') {
|
|
||||||
debouncedCanvasScale(dispatch);
|
|
||||||
}
|
|
||||||
};
|
|
@ -5,7 +5,9 @@ const getScaledCursorPosition = (stage: Stage) => {
|
|||||||
|
|
||||||
const stageTransform = stage.getAbsoluteTransform().copy();
|
const stageTransform = stage.getAbsoluteTransform().copy();
|
||||||
|
|
||||||
if (!pointerPosition || !stageTransform) return;
|
if (!pointerPosition || !stageTransform) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const scaledCursorPosition = stageTransform.invert().point(pointerPosition);
|
const scaledCursorPosition = stageTransform.invert().point(pointerPosition);
|
||||||
|
|
||||||
|
@ -80,19 +80,19 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
sx={{
|
sx={{
|
||||||
flexDir: 'column',
|
flexDir: 'column',
|
||||||
gap: 3,
|
gap: 3,
|
||||||
p: 3,
|
p: 2,
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
bg: 'base.200',
|
bg: 'base.250',
|
||||||
_dark: {
|
_dark: {
|
||||||
bg: 'base.850',
|
bg: 'base.750',
|
||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Flex sx={{ gap: 2, alignItems: 'center' }}>
|
<Flex sx={{ gap: 2, alignItems: 'center' }}>
|
||||||
<IAISwitch
|
<IAISwitch
|
||||||
tooltip={'Toggle this ControlNet'}
|
tooltip="Toggle this ControlNet"
|
||||||
aria-label={'Toggle this ControlNet'}
|
aria-label="Toggle this ControlNet"
|
||||||
isChecked={isEnabled}
|
isChecked={isEnabled}
|
||||||
onChange={handleToggleIsEnabled}
|
onChange={handleToggleIsEnabled}
|
||||||
/>
|
/>
|
||||||
@ -194,7 +194,7 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
aspectRatio: '1/1',
|
aspectRatio: '1/1',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ControlNetImagePreview controlNet={controlNet} height={28} />
|
<ControlNetImagePreview controlNet={controlNet} isSmall />
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
@ -207,7 +207,7 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
|
|
||||||
{isExpanded && (
|
{isExpanded && (
|
||||||
<>
|
<>
|
||||||
<ControlNetImagePreview controlNet={controlNet} height="392px" />
|
<ControlNetImagePreview controlNet={controlNet} />
|
||||||
<ParamControlNetShouldAutoConfig controlNet={controlNet} />
|
<ParamControlNetShouldAutoConfig controlNet={controlNet} />
|
||||||
<ControlNetProcessorComponent controlNet={controlNet} />
|
<ControlNetProcessorComponent controlNet={controlNet} />
|
||||||
</>
|
</>
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
import { Box, Flex, Spinner, SystemStyleObject } from '@chakra-ui/react';
|
import { Box, Flex, Spinner } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import {
|
|
||||||
TypesafeDraggableData,
|
|
||||||
TypesafeDroppableData,
|
|
||||||
} from 'features/dnd/types';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
|
import {
|
||||||
|
TypesafeDraggableData,
|
||||||
|
TypesafeDroppableData,
|
||||||
|
} from 'features/dnd/types';
|
||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { FaUndo } from 'react-icons/fa';
|
import { FaUndo } from 'react-icons/fa';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
@ -21,7 +21,7 @@ import {
|
|||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNet: ControlNetConfig;
|
controlNet: ControlNetConfig;
|
||||||
height: SystemStyleObject['h'];
|
isSmall?: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
@ -36,15 +36,14 @@ const selector = createSelector(
|
|||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const ControlNetImagePreview = (props: Props) => {
|
const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
|
||||||
const { height } = props;
|
|
||||||
const {
|
const {
|
||||||
controlImage: controlImageName,
|
controlImage: controlImageName,
|
||||||
processedControlImage: processedControlImageName,
|
processedControlImage: processedControlImageName,
|
||||||
processorType,
|
processorType,
|
||||||
isEnabled,
|
isEnabled,
|
||||||
controlNetId,
|
controlNetId,
|
||||||
} = props.controlNet;
|
} = controlNet;
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
@ -109,7 +108,7 @@ const ControlNetImagePreview = (props: Props) => {
|
|||||||
sx={{
|
sx={{
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
w: 'full',
|
w: 'full',
|
||||||
h: height,
|
h: isSmall ? 28 : 366, // magic no touch
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
||||||
pointerEvents: isEnabled ? 'auto' : 'none',
|
pointerEvents: isEnabled ? 'auto' : 'none',
|
||||||
|
@ -4,7 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
import { isControlNetEnabledToggled } from 'features/controlNet/store/controlNetSlice';
|
import { isControlNetEnabledToggled } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
@ -36,4 +36,4 @@ const ParamControlNetFeatureToggle = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default ParamControlNetFeatureToggle;
|
export default memo(ParamControlNetFeatureToggle);
|
||||||
|
@ -23,7 +23,7 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
|||||||
return (
|
return (
|
||||||
<IAISlider
|
<IAISlider
|
||||||
isDisabled={!isEnabled}
|
isDisabled={!isEnabled}
|
||||||
label={'Weight'}
|
label="Weight"
|
||||||
value={weight}
|
value={weight}
|
||||||
onChange={handleWeightChanged}
|
onChange={handleWeightChanged}
|
||||||
min={0}
|
min={0}
|
||||||
|
@ -8,6 +8,7 @@ import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial
|
|||||||
import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled';
|
import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled';
|
||||||
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
|
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
|
||||||
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
@ -40,4 +41,4 @@ const ParamDynamicPromptsCollapse = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default ParamDynamicPromptsCollapse;
|
export default memo(ParamDynamicPromptsCollapse);
|
||||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { combinatorialToggled } from '../store/dynamicPromptsSlice';
|
import { combinatorialToggled } from '../store/dynamicPromptsSlice';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
@ -34,4 +34,4 @@ const ParamDynamicPromptsCombinatorial = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default ParamDynamicPromptsCombinatorial;
|
export default memo(ParamDynamicPromptsCombinatorial);
|
||||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { isEnabledToggled } from '../store/dynamicPromptsSlice';
|
import { isEnabledToggled } from '../store/dynamicPromptsSlice';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
@ -33,4 +33,4 @@ const ParamDynamicPromptsToggle = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default ParamDynamicPromptsToggle;
|
export default memo(ParamDynamicPromptsToggle);
|
||||||
|
@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import {
|
import {
|
||||||
maxPromptsChanged,
|
maxPromptsChanged,
|
||||||
maxPromptsReset,
|
maxPromptsReset,
|
||||||
@ -60,4 +60,4 @@ const ParamDynamicPromptsMaxPrompts = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default ParamDynamicPromptsMaxPrompts;
|
export default memo(ParamDynamicPromptsMaxPrompts);
|
||||||
|
@ -13,7 +13,7 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe
|
|||||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { PropsWithChildren, useCallback, useMemo, useRef } from 'react';
|
import { PropsWithChildren, memo, useCallback, useMemo, useRef } from 'react';
|
||||||
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
|
import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants';
|
||||||
|
|
||||||
@ -118,7 +118,7 @@ const ParamEmbeddingPopover = (props: Props) => {
|
|||||||
<IAIMantineSearchableSelect
|
<IAIMantineSearchableSelect
|
||||||
inputRef={inputRef}
|
inputRef={inputRef}
|
||||||
autoFocus
|
autoFocus
|
||||||
placeholder={'Add Embedding'}
|
placeholder="Add Embedding"
|
||||||
value={null}
|
value={null}
|
||||||
data={data}
|
data={data}
|
||||||
nothingFound="No matching Embeddings"
|
nothingFound="No matching Embeddings"
|
||||||
@ -140,4 +140,4 @@ const ParamEmbeddingPopover = (props: Props) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default ParamEmbeddingPopover;
|
export default memo(ParamEmbeddingPopover);
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { Badge, Flex } from '@chakra-ui/react';
|
import { Badge, Flex } from '@chakra-ui/react';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
const AutoAddIcon = () => {
|
const AutoAddIcon = () => {
|
||||||
return (
|
return (
|
||||||
@ -20,4 +21,4 @@ const AutoAddIcon = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default AutoAddIcon;
|
export default memo(AutoAddIcon);
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user