mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix merge conflicts with main
This commit is contained in:
commit
470a39935c
@ -28,7 +28,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
|||||||
from invokeai.app.invocations.model import ModelIdentifierField
|
from invokeai.app.invocations.model import ModelIdentifierField
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||||
from invokeai.backend.util.devices import get_torch_device_name
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
from ..backend.util.logging import InvokeAILogger
|
from ..backend.util.logging import InvokeAILogger
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
@ -63,7 +63,7 @@ logger = InvokeAILogger.get_logger(config=app_config)
|
|||||||
mimetypes.add_type("application/javascript", ".js")
|
mimetypes.add_type("application/javascript", ".js")
|
||||||
mimetypes.add_type("text/css", ".css")
|
mimetypes.add_type("text/css", ".css")
|
||||||
|
|
||||||
torch_device_name = get_torch_device_name()
|
torch_device_name = TorchDevice.get_torch_device_name()
|
||||||
logger.info(f"Using torch device: {torch_device_name}")
|
logger.info(f"Using torch device: {torch_device_name}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|||||||
ConditioningFieldData,
|
ConditioningFieldData,
|
||||||
SDXLConditioningInfo,
|
SDXLConditioningInfo,
|
||||||
)
|
)
|
||||||
from invokeai.backend.util.devices import torch_dtype
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||||
from .model import CLIPField
|
from .model import CLIPField
|
||||||
@ -99,7 +99,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
textual_inversion_manager=ti_manager,
|
textual_inversion_manager=ti_manager,
|
||||||
dtype_for_device_getter=torch_dtype,
|
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||||
truncate_long_prompts=False,
|
truncate_long_prompts=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -193,7 +193,7 @@ class SDXLPromptInvocationBase:
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
textual_inversion_manager=ti_manager,
|
textual_inversion_manager=ti_manager,
|
||||||
dtype_for_device_getter=torch_dtype,
|
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||||
truncate_long_prompts=False, # TODO:
|
truncate_long_prompts=False, # TODO:
|
||||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||||
requires_pooled=get_pooled,
|
requires_pooled=get_pooled,
|
||||||
|
@ -72,15 +72,12 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
|||||||
image_resized_to_grid_as_tensor,
|
image_resized_to_grid_as_tensor,
|
||||||
)
|
)
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
from ...backend.util.devices import TorchDevice
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||||
from .controlnet_image_processors import ControlField
|
from .controlnet_image_processors import ControlField
|
||||||
from .model import ModelIdentifierField, UNetField, VAEField
|
from .model import ModelIdentifierField, UNetField, VAEField
|
||||||
|
|
||||||
if choose_torch_device() == torch.device("mps"):
|
DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
|
||||||
from torch import mps
|
|
||||||
|
|
||||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("scheduler_output")
|
@invocation_output("scheduler_output")
|
||||||
@ -959,9 +956,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
result_latents = result_latents.to("cpu")
|
result_latents = result_latents.to("cpu")
|
||||||
torch.cuda.empty_cache()
|
TorchDevice.empty_cache()
|
||||||
if choose_torch_device() == torch.device("mps"):
|
|
||||||
mps.empty_cache()
|
|
||||||
|
|
||||||
name = context.tensors.save(tensor=result_latents)
|
name = context.tensors.save(tensor=result_latents)
|
||||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|
||||||
@ -1028,9 +1023,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
vae.disable_tiling()
|
vae.disable_tiling()
|
||||||
|
|
||||||
# clear memory as vae decode can request a lot
|
# clear memory as vae decode can request a lot
|
||||||
torch.cuda.empty_cache()
|
TorchDevice.empty_cache()
|
||||||
if choose_torch_device() == torch.device("mps"):
|
|
||||||
mps.empty_cache()
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
# copied from diffusers pipeline
|
# copied from diffusers pipeline
|
||||||
@ -1042,9 +1035,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
|
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
TorchDevice.empty_cache()
|
||||||
if choose_torch_device() == torch.device("mps"):
|
|
||||||
mps.empty_cache()
|
|
||||||
|
|
||||||
image_dto = context.images.save(image=image)
|
image_dto = context.images.save(image=image)
|
||||||
|
|
||||||
@ -1083,9 +1074,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.tensors.load(self.latents.latents_name)
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
|
device = TorchDevice.choose_torch_device()
|
||||||
# TODO:
|
|
||||||
device = choose_torch_device()
|
|
||||||
|
|
||||||
resized_latents = torch.nn.functional.interpolate(
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
latents.to(device),
|
latents.to(device),
|
||||||
@ -1096,9 +1085,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
resized_latents = resized_latents.to("cpu")
|
resized_latents = resized_latents.to("cpu")
|
||||||
torch.cuda.empty_cache()
|
|
||||||
if device == torch.device("mps"):
|
TorchDevice.empty_cache()
|
||||||
mps.empty_cache()
|
|
||||||
|
|
||||||
name = context.tensors.save(tensor=resized_latents)
|
name = context.tensors.save(tensor=resized_latents)
|
||||||
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
@ -1125,8 +1113,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.tensors.load(self.latents.latents_name)
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
|
|
||||||
# TODO:
|
device = TorchDevice.choose_torch_device()
|
||||||
device = choose_torch_device()
|
|
||||||
|
|
||||||
# resizing
|
# resizing
|
||||||
resized_latents = torch.nn.functional.interpolate(
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
@ -1138,9 +1125,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
resized_latents = resized_latents.to("cpu")
|
resized_latents = resized_latents.to("cpu")
|
||||||
torch.cuda.empty_cache()
|
TorchDevice.empty_cache()
|
||||||
if device == torch.device("mps"):
|
|
||||||
mps.empty_cache()
|
|
||||||
|
|
||||||
name = context.tensors.save(tensor=resized_latents)
|
name = context.tensors.save(tensor=resized_latents)
|
||||||
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
@ -1272,8 +1257,7 @@ class BlendLatentsInvocation(BaseInvocation):
|
|||||||
if latents_a.shape != latents_b.shape:
|
if latents_a.shape != latents_b.shape:
|
||||||
raise Exception("Latents to blend must be the same size.")
|
raise Exception("Latents to blend must be the same size.")
|
||||||
|
|
||||||
# TODO:
|
device = TorchDevice.choose_torch_device()
|
||||||
device = choose_torch_device()
|
|
||||||
|
|
||||||
def slerp(
|
def slerp(
|
||||||
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
|
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
|
||||||
@ -1326,9 +1310,8 @@ class BlendLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
blended_latents = blended_latents.to("cpu")
|
blended_latents = blended_latents.to("cpu")
|
||||||
torch.cuda.empty_cache()
|
|
||||||
if device == torch.device("mps"):
|
TorchDevice.empty_cache()
|
||||||
mps.empty_cache()
|
|
||||||
|
|
||||||
name = context.tensors.save(tensor=blended_latents)
|
name = context.tensors.save(tensor=blended_latents)
|
||||||
return LatentsOutput.build(latents_name=name, latents=blended_latents)
|
return LatentsOutput.build(latents_name=name, latents=blended_latents)
|
||||||
|
@ -9,7 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Laten
|
|||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.misc import SEED_MAX
|
from invokeai.app.util.misc import SEED_MAX
|
||||||
|
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import TorchDevice
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
@ -46,7 +46,7 @@ def get_noise(
|
|||||||
height // downsampling_factor,
|
height // downsampling_factor,
|
||||||
width // downsampling_factor,
|
width // downsampling_factor,
|
||||||
],
|
],
|
||||||
dtype=torch_dtype(device),
|
dtype=TorchDevice.choose_torch_dtype(device=device),
|
||||||
device=noise_device_type,
|
device=noise_device_type,
|
||||||
generator=generator,
|
generator=generator,
|
||||||
).to("cpu")
|
).to("cpu")
|
||||||
@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@field_validator("seed", mode="before")
|
@field_validator("seed", mode="before")
|
||||||
def modulo_seed(cls, v):
|
def modulo_seed(cls, v):
|
||||||
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
"""Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||||
return v % (SEED_MAX + 1)
|
return v % (SEED_MAX + 1)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||||
noise = get_noise(
|
noise = get_noise(
|
||||||
width=self.width,
|
width=self.width,
|
||||||
height=self.height,
|
height=self.height,
|
||||||
device=choose_torch_device(),
|
device=TorchDevice.choose_torch_device(),
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
use_cpu=self.use_cpu,
|
use_cpu=self.use_cpu,
|
||||||
)
|
)
|
||||||
|
@ -3,7 +3,6 @@ from typing import Literal
|
|||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
@ -12,7 +11,7 @@ from invokeai.app.invocations.primitives import ImageOutput
|
|||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, invocation
|
from .baseinvocation import BaseInvocation, invocation
|
||||||
from .fields import InputField, WithBoard, WithMetadata
|
from .fields import InputField, WithBoard, WithMetadata
|
||||||
@ -33,9 +32,6 @@ ESRGAN_MODEL_URLS: dict[str, str] = {
|
|||||||
"RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
"RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||||
}
|
}
|
||||||
|
|
||||||
if choose_torch_device() == torch.device("mps"):
|
|
||||||
from torch import mps
|
|
||||||
|
|
||||||
|
|
||||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
|
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
|
||||||
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
@ -115,9 +111,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
TorchDevice.empty_cache()
|
||||||
if choose_torch_device() == torch.device("mps"):
|
|
||||||
mps.empty_cache()
|
|
||||||
|
|
||||||
image_dto = context.images.save(image=pil_image)
|
image_dto = context.images.save(image=pil_image)
|
||||||
|
|
||||||
|
@ -27,12 +27,12 @@ DEFAULT_RAM_CACHE = 10.0
|
|||||||
DEFAULT_VRAM_CACHE = 0.25
|
DEFAULT_VRAM_CACHE = 0.25
|
||||||
DEFAULT_CONVERT_CACHE = 20.0
|
DEFAULT_CONVERT_CACHE = 20.0
|
||||||
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
|
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
|
||||||
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
|
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
|
||||||
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
|
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
|
||||||
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
|
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
|
||||||
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
|
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
|
||||||
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
|
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
|
||||||
CONFIG_SCHEMA_VERSION = "4.0.0"
|
CONFIG_SCHEMA_VERSION = "4.0.1"
|
||||||
|
|
||||||
|
|
||||||
def get_default_ram_cache_size() -> float:
|
def get_default_ram_cache_size() -> float:
|
||||||
@ -106,7 +106,7 @@ class InvokeAIAppConfig(BaseSettings):
|
|||||||
lazy_offload: Keep models in VRAM until their space is needed.
|
lazy_offload: Keep models in VRAM until their space is needed.
|
||||||
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
|
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
|
||||||
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
|
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
|
||||||
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
|
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
|
||||||
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
|
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
|
||||||
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
|
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
|
||||||
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
|
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
|
||||||
@ -377,6 +377,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
|||||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||||
parsed_config_dict["vram"] = v
|
parsed_config_dict["vram"] = v
|
||||||
|
# autocast was removed in v4.0.1
|
||||||
|
if k == "precision" and v == "autocast":
|
||||||
|
parsed_config_dict["precision"] = "auto"
|
||||||
if k == "conf_path":
|
if k == "conf_path":
|
||||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||||
if k == "legacy_conf_dir":
|
if k == "legacy_conf_dir":
|
||||||
@ -399,6 +402,28 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||||
|
"""Migrate v4.0.0 config dictionary to a current config object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_dict: A dictionary of settings from a v4.0.0 config file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instance of `InvokeAIAppConfig` with the migrated settings.
|
||||||
|
"""
|
||||||
|
parsed_config_dict: dict[str, Any] = {}
|
||||||
|
for k, v in config_dict.items():
|
||||||
|
# autocast was removed from precision in v4.0.1
|
||||||
|
if k == "precision" and v == "autocast":
|
||||||
|
parsed_config_dict["precision"] = "auto"
|
||||||
|
else:
|
||||||
|
parsed_config_dict[k] = v
|
||||||
|
if k == "schema_version":
|
||||||
|
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
|
||||||
|
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||||
"""Load and migrate a config file to the latest version.
|
"""Load and migrate a config file to the latest version.
|
||||||
|
|
||||||
@ -425,17 +450,21 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
|||||||
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
|
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
|
||||||
migrated_config.write_file(config_path)
|
migrated_config.write_file(config_path)
|
||||||
return migrated_config
|
return migrated_config
|
||||||
else:
|
|
||||||
# Attempt to load as a v4 config file
|
if loaded_config_dict["schema_version"] == "4.0.0":
|
||||||
try:
|
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
|
||||||
# Meta is not included in the model fields, so we need to validate it separately
|
loaded_config_dict.write_file(config_path)
|
||||||
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
|
|
||||||
assert (
|
# Attempt to load as a v4 config file
|
||||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
try:
|
||||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
# Meta is not included in the model fields, so we need to validate it separately
|
||||||
return config
|
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
|
||||||
except Exception as e:
|
assert (
|
||||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||||
|
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
||||||
|
return config
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
|
@ -13,6 +13,7 @@ from shutil import copyfile, copytree, move, rmtree
|
|||||||
from tempfile import mkdtemp
|
from tempfile import mkdtemp
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from huggingface_hub import HfFolder
|
from huggingface_hub import HfFolder
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
@ -42,7 +43,7 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
|
|||||||
from invokeai.backend.model_manager.probe import ModelProbe
|
from invokeai.backend.model_manager.probe import ModelProbe
|
||||||
from invokeai.backend.model_manager.search import ModelSearch
|
from invokeai.backend.model_manager.search import ModelSearch
|
||||||
from invokeai.backend.util import InvokeAILogger
|
from invokeai.backend.util import InvokeAILogger
|
||||||
from invokeai.backend.util.devices import choose_precision, choose_torch_device
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
from .model_install_base import (
|
from .model_install_base import (
|
||||||
MODEL_SOURCE_TO_TYPE_MAP,
|
MODEL_SOURCE_TO_TYPE_MAP,
|
||||||
@ -643,11 +644,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._next_job_id += 1
|
self._next_job_id += 1
|
||||||
return id
|
return id
|
||||||
|
|
||||||
@staticmethod
|
def _guess_variant(self) -> Optional[ModelRepoVariant]:
|
||||||
def _guess_variant() -> Optional[ModelRepoVariant]:
|
|
||||||
"""Guess the best HuggingFace variant type to download."""
|
"""Guess the best HuggingFace variant type to download."""
|
||||||
precision = choose_precision(choose_torch_device())
|
precision = TorchDevice.choose_torch_dtype()
|
||||||
return ModelRepoVariant.FP16 if precision == "float16" else None
|
return ModelRepoVariant.FP16 if precision == torch.float16 else None
|
||||||
|
|
||||||
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||||
return ModelInstallJob(
|
return ModelInstallJob(
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||||
"""Implementation of ModelManagerServiceBase."""
|
"""Implementation of ModelManagerServiceBase."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
from ..config import InvokeAIAppConfig
|
from ..config import InvokeAIAppConfig
|
||||||
@ -67,7 +69,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
model_record_service: ModelRecordServiceBase,
|
model_record_service: ModelRecordServiceBase,
|
||||||
download_queue: DownloadQueueServiceBase,
|
download_queue: DownloadQueueServiceBase,
|
||||||
events: EventServiceBase,
|
events: EventServiceBase,
|
||||||
execution_device: torch.device = choose_torch_device(),
|
execution_device: Optional[torch.device] = None,
|
||||||
) -> Self:
|
) -> Self:
|
||||||
"""
|
"""
|
||||||
Construct the model manager service instance.
|
Construct the model manager service instance.
|
||||||
@ -82,7 +84,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
max_vram_cache_size=app_config.vram,
|
max_vram_cache_size=app_config.vram,
|
||||||
lazy_offloading=app_config.lazy_offload,
|
lazy_offloading=app_config.lazy_offload,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
execution_device=execution_device,
|
execution_device=execution_device or TorchDevice.choose_torch_device(),
|
||||||
)
|
)
|
||||||
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
|
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
|
||||||
loader = ModelLoadService(
|
loader = ModelLoadService(
|
||||||
|
@ -12,7 +12,7 @@ from invokeai.app.services.config.config_default import get_config
|
|||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
config = get_config()
|
config = get_config()
|
||||||
@ -47,7 +47,7 @@ class DepthAnythingDetector:
|
|||||||
self.context = context
|
self.context = context
|
||||||
self.model: Optional[DPT_DINOv2] = None
|
self.model: Optional[DPT_DINOv2] = None
|
||||||
self.model_size: Union[Literal["large", "base", "small"], None] = None
|
self.model_size: Union[Literal["large", "base", "small"], None] = None
|
||||||
self.device = choose_torch_device()
|
self.device = TorchDevice.choose_torch_device()
|
||||||
|
|
||||||
def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2:
|
def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2:
|
||||||
depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size])
|
depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size])
|
||||||
@ -68,7 +68,7 @@ class DepthAnythingDetector:
|
|||||||
self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu"))
|
self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu"))
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
self.model.to(choose_torch_device())
|
self.model.to(self.device)
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
|
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
|
||||||
@ -81,7 +81,7 @@ class DepthAnythingDetector:
|
|||||||
|
|
||||||
image_height, image_width = np_image.shape[:2]
|
image_height, image_width = np_image.shape[:2]
|
||||||
np_image = transform({"image": np_image})["image"]
|
np_image = transform({"image": np_image})["image"]
|
||||||
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
|
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
depth = self.model(tensor_image)
|
depth = self.model(tensor_image)
|
||||||
|
@ -4,11 +4,10 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
from .onnxdet import inference_detector
|
from .onnxdet import inference_detector
|
||||||
from .onnxpose import inference_pose
|
from .onnxpose import inference_pose
|
||||||
@ -23,9 +22,9 @@ config = get_config()
|
|||||||
|
|
||||||
class Wholebody:
|
class Wholebody:
|
||||||
def __init__(self, context: InvocationContext):
|
def __init__(self, context: InvocationContext):
|
||||||
device = choose_torch_device()
|
device = TorchDevice.choose_torch_device()
|
||||||
|
|
||||||
providers = ["CUDAExecutionProvider"] if device == torch.device("cuda") else ["CPUExecutionProvider"]
|
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
|
||||||
|
|
||||||
onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
|
onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
|
||||||
onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
|
onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
|
||||||
|
@ -11,7 +11,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||||
from invokeai.backend.model_manager.config import AnyModel
|
from invokeai.backend.model_manager.config import AnyModel
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
|
Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
|
||||||
@ -65,7 +65,7 @@ class RealESRGAN:
|
|||||||
self.pre_pad = pre_pad
|
self.pre_pad = pre_pad
|
||||||
self.mod_scale: Optional[int] = None
|
self.mod_scale: Optional[int] = None
|
||||||
self.half = half
|
self.half = half
|
||||||
self.device = choose_torch_device()
|
self.device = TorchDevice.choose_torch_device()
|
||||||
|
|
||||||
# prefer to use params_ema
|
# prefer to use params_ema
|
||||||
if "params_ema" in loadnet:
|
if "params_ema" in loadnet:
|
||||||
|
@ -13,7 +13,7 @@ from transformers import AutoFeatureExtractor
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
|
|
||||||
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
|
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
|
||||||
@ -51,7 +51,7 @@ class SafetyChecker:
|
|||||||
cls._load_safety_checker()
|
cls._load_safety_checker()
|
||||||
if cls.safety_checker is None or cls.feature_extractor is None:
|
if cls.safety_checker is None or cls.feature_extractor is None:
|
||||||
return False
|
return False
|
||||||
device = choose_torch_device()
|
device = TorchDevice.choose_torch_device()
|
||||||
features = cls.feature_extractor([image], return_tensors="pt")
|
features = cls.feature_extractor([image], return_tensors="pt")
|
||||||
features.to(device)
|
features.to(device)
|
||||||
cls.safety_checker.to(device)
|
cls.safety_checker.to(device)
|
||||||
|
@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoad
|
|||||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
||||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
|
|
||||||
# TO DO: The loader is not thread safe!
|
# TO DO: The loader is not thread safe!
|
||||||
@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
|
|||||||
self._logger = logger
|
self._logger = logger
|
||||||
self._ram_cache = ram_cache
|
self._ram_cache = ram_cache
|
||||||
self._convert_cache = convert_cache
|
self._convert_cache = convert_cache
|
||||||
self._torch_dtype = torch_dtype(choose_torch_device())
|
self._torch_dtype = TorchDevice.choose_torch_dtype()
|
||||||
|
|
||||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
"""
|
"""
|
||||||
|
@ -31,15 +31,12 @@ import torch
|
|||||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
|
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
|
||||||
from .model_locker import ModelLocker
|
from .model_locker import ModelLocker
|
||||||
|
|
||||||
if choose_torch_device() == torch.device("mps"):
|
|
||||||
from torch import mps
|
|
||||||
|
|
||||||
# Maximum size of the cache, in gigs
|
# Maximum size of the cache, in gigs
|
||||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||||
@ -245,9 +242,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
TorchDevice.empty_cache()
|
||||||
if choose_torch_device() == torch.device("mps"):
|
|
||||||
mps.empty_cache()
|
|
||||||
|
|
||||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||||
"""Move model into the indicated device.
|
"""Move model into the indicated device.
|
||||||
@ -417,10 +412,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
self.stats.cleared = models_cleared
|
self.stats.cleared = models_cleared
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
TorchDevice.empty_cache()
|
||||||
if choose_torch_device() == torch.device("mps"):
|
|
||||||
mps.empty_cache()
|
|
||||||
|
|
||||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||||
|
|
||||||
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
|
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
|
||||||
|
@ -17,7 +17,7 @@ from diffusers.utils import logging as dlogging
|
|||||||
|
|
||||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
@ -43,6 +43,7 @@ class ModelMerger(object):
|
|||||||
Initialize a ModelMerger object with the model installer.
|
Initialize a ModelMerger object with the model installer.
|
||||||
"""
|
"""
|
||||||
self._installer = installer
|
self._installer = installer
|
||||||
|
self._dtype = TorchDevice.choose_torch_dtype()
|
||||||
|
|
||||||
def merge_diffusion_models(
|
def merge_diffusion_models(
|
||||||
self,
|
self,
|
||||||
@ -68,7 +69,7 @@ class ModelMerger(object):
|
|||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
verbosity = dlogging.get_verbosity()
|
verbosity = dlogging.get_verbosity()
|
||||||
dlogging.set_verbosity_error()
|
dlogging.set_verbosity_error()
|
||||||
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
|
dtype = torch.float16 if variant == "fp16" else self._dtype
|
||||||
|
|
||||||
# Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
|
# Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
|
||||||
# until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
|
# until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
|
||||||
@ -151,7 +152,7 @@ class ModelMerger(object):
|
|||||||
dump_path.mkdir(parents=True, exist_ok=True)
|
dump_path.mkdir(parents=True, exist_ok=True)
|
||||||
dump_path = dump_path / merged_model_name
|
dump_path = dump_path / merged_model_name
|
||||||
|
|
||||||
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
|
dtype = torch.float16 if variant == "fp16" else self._dtype
|
||||||
merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)
|
merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)
|
||||||
|
|
||||||
# register model and get its unique key
|
# register model and get its unique key
|
||||||
|
@ -28,7 +28,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|||||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
|
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
|
||||||
from invokeai.backend.util.attention import auto_detect_slice_size
|
from invokeai.backend.util.attention import auto_detect_slice_size
|
||||||
from invokeai.backend.util.devices import normalize_device
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -258,7 +258,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
|
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
|
||||||
mem_free = psutil.virtual_memory().free
|
mem_free = psutil.virtual_memory().free
|
||||||
elif self.unet.device.type == "cuda":
|
elif self.unet.device.type == "cuda":
|
||||||
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
|
mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unrecognized device {self.unet.device}")
|
raise ValueError(f"unrecognized device {self.unet.device}")
|
||||||
# input tensor of [1, 4, h/8, w/8]
|
# input tensor of [1, 4, h/8, w/8]
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
Initialization file for invokeai.backend.util
|
Initialization file for invokeai.backend.util
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .devices import choose_precision, choose_torch_device
|
|
||||||
from .logging import InvokeAILogger
|
from .logging import InvokeAILogger
|
||||||
from .util import GIG, Chdir, directory_size
|
from .util import GIG, Chdir, directory_size
|
||||||
|
|
||||||
@ -11,6 +10,4 @@ __all__ = [
|
|||||||
"directory_size",
|
"directory_size",
|
||||||
"Chdir",
|
"Chdir",
|
||||||
"InvokeAILogger",
|
"InvokeAILogger",
|
||||||
"choose_precision",
|
|
||||||
"choose_torch_device",
|
|
||||||
]
|
]
|
||||||
|
@ -1,89 +1,110 @@
|
|||||||
from __future__ import annotations
|
from typing import Dict, Literal, Optional, Union
|
||||||
|
|
||||||
from contextlib import nullcontext
|
|
||||||
from typing import Literal, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import autocast
|
from deprecated import deprecated
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import PRECISION, get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
|
|
||||||
|
# legacy APIs
|
||||||
|
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
|
||||||
CPU_DEVICE = torch.device("cpu")
|
CPU_DEVICE = torch.device("cpu")
|
||||||
CUDA_DEVICE = torch.device("cuda")
|
CUDA_DEVICE = torch.device("cuda")
|
||||||
MPS_DEVICE = torch.device("mps")
|
MPS_DEVICE = torch.device("mps")
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
|
||||||
|
def choose_precision(device: torch.device) -> TorchPrecisionNames:
|
||||||
|
"""Return the string representation of the recommended torch device."""
|
||||||
|
torch_dtype = TorchDevice.choose_torch_dtype(device)
|
||||||
|
return PRECISION_TO_NAME[torch_dtype]
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated("Use TorchDevice.choose_torch_device() instead.") # type: ignore
|
||||||
def choose_torch_device() -> torch.device:
|
def choose_torch_device() -> torch.device:
|
||||||
"""Convenience routine for guessing which GPU device to run model on"""
|
"""Return the torch.device to use for accelerated inference."""
|
||||||
config = get_config()
|
return TorchDevice.choose_torch_device()
|
||||||
if config.device == "auto":
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
return torch.device("cuda")
|
@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
|
||||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
def torch_dtype(device: torch.device) -> torch.dtype:
|
||||||
return torch.device("mps")
|
"""Return the torch precision for the recommended torch device."""
|
||||||
|
return TorchDevice.choose_torch_dtype(device)
|
||||||
|
|
||||||
|
|
||||||
|
NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = {
|
||||||
|
"float32": torch.float32,
|
||||||
|
"float16": torch.float16,
|
||||||
|
"bfloat16": torch.bfloat16,
|
||||||
|
}
|
||||||
|
PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()}
|
||||||
|
|
||||||
|
|
||||||
|
class TorchDevice:
|
||||||
|
"""Abstraction layer for torch devices."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def choose_torch_device(cls) -> torch.device:
|
||||||
|
"""Return the torch.device to use for accelerated inference."""
|
||||||
|
app_config = get_config()
|
||||||
|
if app_config.device != "auto":
|
||||||
|
device = torch.device(app_config.device)
|
||||||
|
elif torch.cuda.is_available():
|
||||||
|
device = CUDA_DEVICE
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
device = MPS_DEVICE
|
||||||
else:
|
else:
|
||||||
return CPU_DEVICE
|
device = CPU_DEVICE
|
||||||
else:
|
return cls.normalize(device)
|
||||||
return torch.device(config.device)
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
|
||||||
|
"""Return the precision to use for accelerated inference."""
|
||||||
|
device = device or cls.choose_torch_device()
|
||||||
|
config = get_config()
|
||||||
|
if device.type == "cuda" and torch.cuda.is_available():
|
||||||
|
device_name = torch.cuda.get_device_name(device)
|
||||||
|
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
|
||||||
|
# These GPUs have limited support for float16
|
||||||
|
return cls._to_dtype("float32")
|
||||||
|
elif config.precision == "auto":
|
||||||
|
# Default to float16 for CUDA devices
|
||||||
|
return cls._to_dtype("float16")
|
||||||
|
else:
|
||||||
|
# Use the user-defined precision
|
||||||
|
return cls._to_dtype(config.precision)
|
||||||
|
|
||||||
def get_torch_device_name() -> str:
|
elif device.type == "mps" and torch.backends.mps.is_available():
|
||||||
device = choose_torch_device()
|
if config.precision == "auto":
|
||||||
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
|
# Default to float16 for MPS devices
|
||||||
|
return cls._to_dtype("float16")
|
||||||
|
else:
|
||||||
|
# Use the user-defined precision
|
||||||
|
return cls._to_dtype(config.precision)
|
||||||
|
# CPU / safe fallback
|
||||||
|
return cls._to_dtype("float32")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_torch_device_name(cls) -> str:
|
||||||
|
"""Return the device name for the current torch device."""
|
||||||
|
device = cls.choose_torch_device()
|
||||||
|
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
|
||||||
|
|
||||||
def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]:
|
@classmethod
|
||||||
"""Return an appropriate precision for the given torch device."""
|
def normalize(cls, device: Union[str, torch.device]) -> torch.device:
|
||||||
app_config = get_config()
|
"""Add the device index to CUDA devices."""
|
||||||
if device.type == "cuda":
|
device = torch.device(device)
|
||||||
device_name = torch.cuda.get_device_name(device)
|
if device.index is None and device.type == "cuda" and torch.cuda.is_available():
|
||||||
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
|
|
||||||
# These GPUs have limited support for float16
|
|
||||||
return "float32"
|
|
||||||
elif app_config.precision == "auto" or app_config.precision == "autocast":
|
|
||||||
# Default to float16 for CUDA devices
|
|
||||||
return "float16"
|
|
||||||
else:
|
|
||||||
# Use the user-defined precision
|
|
||||||
return app_config.precision
|
|
||||||
elif device.type == "mps":
|
|
||||||
if app_config.precision == "auto" or app_config.precision == "autocast":
|
|
||||||
# Default to float16 for MPS devices
|
|
||||||
return "float16"
|
|
||||||
else:
|
|
||||||
# Use the user-defined precision
|
|
||||||
return app_config.precision
|
|
||||||
# CPU / safe fallback
|
|
||||||
return "float32"
|
|
||||||
|
|
||||||
|
|
||||||
def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
|
|
||||||
device = device or choose_torch_device()
|
|
||||||
precision = choose_precision(device)
|
|
||||||
if precision == "float16":
|
|
||||||
return torch.float16
|
|
||||||
if precision == "bfloat16":
|
|
||||||
return torch.bfloat16
|
|
||||||
else:
|
|
||||||
# "auto", "autocast", "float32"
|
|
||||||
return torch.float32
|
|
||||||
|
|
||||||
|
|
||||||
def choose_autocast(precision: PRECISION):
|
|
||||||
"""Returns an autocast context or nullcontext for the given precision string"""
|
|
||||||
# float16 currently requires autocast to avoid errors like:
|
|
||||||
# 'expected scalar type Half but found Float'
|
|
||||||
if precision == "autocast" or precision == "float16":
|
|
||||||
return autocast
|
|
||||||
return nullcontext
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_device(device: Union[str, torch.device]) -> torch.device:
|
|
||||||
"""Ensure device has a device index defined, if appropriate."""
|
|
||||||
device = torch.device(device)
|
|
||||||
if device.index is None:
|
|
||||||
# cuda might be the only torch backend that currently uses the device index?
|
|
||||||
# I don't see anything like `current_device` for cpu or mps.
|
|
||||||
if device.type == "cuda":
|
|
||||||
device = torch.device(device.type, torch.cuda.current_device())
|
device = torch.device(device.type, torch.cuda.current_device())
|
||||||
return device
|
return device
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def empty_cache(cls) -> None:
|
||||||
|
"""Clear the GPU device cache."""
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
torch.mps.empty_cache()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
|
||||||
|
return NAME_TO_PRECISION[precision_name]
|
||||||
|
@ -770,6 +770,8 @@
|
|||||||
"float": "Float",
|
"float": "Float",
|
||||||
"fullyContainNodes": "Fully Contain Nodes to Select",
|
"fullyContainNodes": "Fully Contain Nodes to Select",
|
||||||
"fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected",
|
"fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected",
|
||||||
|
"showEdgeLabels": "Show Edge Labels",
|
||||||
|
"showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes",
|
||||||
"hideLegendNodes": "Hide Field Type Legend",
|
"hideLegendNodes": "Hide Field Type Legend",
|
||||||
"hideMinimapnodes": "Hide MiniMap",
|
"hideMinimapnodes": "Hide MiniMap",
|
||||||
"inputMayOnlyHaveOneConnection": "Input may only have one connection",
|
"inputMayOnlyHaveOneConnection": "Input may only have one connection",
|
||||||
|
@ -9,7 +9,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
|||||||
|
|
||||||
export const useGlobalHotkeys = () => {
|
export const useGlobalHotkeys = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const isModelManagerEnabled = useFeatureStatus('modelManager').isFeatureEnabled;
|
const isModelManagerEnabled = useFeatureStatus('modelManager');
|
||||||
const { queueBack, isDisabled: isDisabledQueueBack, isLoading: isLoadingQueueBack } = useQueueBack();
|
const { queueBack, isDisabled: isDisabledQueueBack, isLoading: isLoadingQueueBack } = useQueueBack();
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
|
@ -32,7 +32,7 @@ const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props
|
|||||||
|
|
||||||
const isSelectedForAutoAdd = useAppSelector(selectIsSelectedForAutoAdd);
|
const isSelectedForAutoAdd = useAppSelector(selectIsSelectedForAutoAdd);
|
||||||
const boardName = useBoardName(board_id);
|
const boardName = useBoardName(board_id);
|
||||||
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
|
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
|
||||||
|
|
||||||
const [bulkDownload] = useBulkDownloadImagesMutation();
|
const [bulkDownload] = useBulkDownloadImagesMutation();
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ const CurrentImageButtons = () => {
|
|||||||
const selection = useAppSelector((s) => s.gallery.selection);
|
const selection = useAppSelector((s) => s.gallery.selection);
|
||||||
const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons);
|
const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons);
|
||||||
|
|
||||||
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
|
const isUpscalingEnabled = useFeatureStatus('upscaling');
|
||||||
const isQueueMutationInProgress = useIsQueueMutationInProgress();
|
const isQueueMutationInProgress = useIsQueueMutationInProgress();
|
||||||
const toaster = useAppToaster();
|
const toaster = useAppToaster();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
@ -20,7 +20,7 @@ const MultipleSelectionMenuItems = () => {
|
|||||||
const selection = useAppSelector((s) => s.gallery.selection);
|
const selection = useAppSelector((s) => s.gallery.selection);
|
||||||
const customStarUi = useStore($customStarUI);
|
const customStarUi = useStore($customStarUI);
|
||||||
|
|
||||||
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
|
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
|
||||||
|
|
||||||
const [starImages] = useStarImagesMutation();
|
const [starImages] = useStarImagesMutation();
|
||||||
const [unstarImages] = useUnstarImagesMutation();
|
const [unstarImages] = useUnstarImagesMutation();
|
||||||
|
@ -45,7 +45,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const toaster = useAppToaster();
|
const toaster = useAppToaster();
|
||||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
const isCanvasEnabled = useFeatureStatus('unifiedCanvas');
|
||||||
const customStarUi = useStore($customStarUI);
|
const customStarUi = useStore($customStarUI);
|
||||||
const { downloadImage } = useDownloadImage();
|
const { downloadImage } = useDownloadImage();
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
|
|||||||
[imageDTO?.image_name]
|
[imageDTO?.image_name]
|
||||||
);
|
);
|
||||||
const isSelected = useAppSelector(selectIsSelected);
|
const isSelected = useAppSelector(selectIsSelected);
|
||||||
const isMultiSelectEnabled = useFeatureStatus('multiselect').isFeatureEnabled;
|
const isMultiSelectEnabled = useFeatureStatus('multiselect');
|
||||||
|
|
||||||
const handleClick = useCallback(
|
const handleClick = useCallback(
|
||||||
(e: MouseEvent<HTMLDivElement>) => {
|
(e: MouseEvent<HTMLDivElement>) => {
|
||||||
|
@ -8,7 +8,7 @@ import ParamHrfStrength from './ParamHrfStrength';
|
|||||||
import ParamHrfToggle from './ParamHrfToggle';
|
import ParamHrfToggle from './ParamHrfToggle';
|
||||||
|
|
||||||
export const HrfSettings = memo(() => {
|
export const HrfSettings = memo(() => {
|
||||||
const isHRFFeatureEnabled = useFeatureStatus('hrf').isFeatureEnabled;
|
const isHRFFeatureEnabled = useFeatureStatus('hrf');
|
||||||
const hrfEnabled = useAppSelector((s) => s.hrf.hrfEnabled);
|
const hrfEnabled = useAppSelector((s) => s.hrf.hrfEnabled);
|
||||||
|
|
||||||
if (!isHRFFeatureEnabled) {
|
if (!isHRFFeatureEnabled) {
|
||||||
|
@ -10,7 +10,7 @@ const TOAST_ID = 'starterModels';
|
|||||||
|
|
||||||
export const useStarterModelsToast = () => {
|
export const useStarterModelsToast = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const isEnabled = useFeatureStatus('starterModels').isFeatureEnabled;
|
const isEnabled = useFeatureStatus('starterModels');
|
||||||
const [didToast, setDidToast] = useState(false);
|
const [didToast, setDidToast] = useState(false);
|
||||||
const [mainModels, { data }] = useMainModels();
|
const [mainModels, { data }] = useMainModels();
|
||||||
const toast = useToast();
|
const toast = useToast();
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
|
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import type { CSSProperties } from 'react';
|
import type { CSSProperties } from 'react';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import type { EdgeProps } from 'reactflow';
|
import type { EdgeProps } from 'reactflow';
|
||||||
import { BaseEdge, getBezierPath } from 'reactflow';
|
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
|
||||||
|
|
||||||
import { makeEdgeSelector } from './util/makeEdgeSelector';
|
import { makeEdgeSelector } from './util/makeEdgeSelector';
|
||||||
|
|
||||||
@ -25,9 +26,10 @@ const InvocationDefaultEdge = ({
|
|||||||
[source, sourceHandleId, target, targetHandleId, selected]
|
[source, sourceHandleId, target, targetHandleId, selected]
|
||||||
);
|
);
|
||||||
|
|
||||||
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
|
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
|
||||||
|
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels);
|
||||||
|
|
||||||
const [edgePath] = getBezierPath({
|
const [edgePath, labelX, labelY] = getBezierPath({
|
||||||
sourceX,
|
sourceX,
|
||||||
sourceY,
|
sourceY,
|
||||||
sourcePosition,
|
sourcePosition,
|
||||||
@ -47,7 +49,33 @@ const InvocationDefaultEdge = ({
|
|||||||
[isSelected, shouldAnimate, stroke]
|
[isSelected, shouldAnimate, stroke]
|
||||||
);
|
);
|
||||||
|
|
||||||
return <BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />;
|
return (
|
||||||
|
<>
|
||||||
|
<BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />
|
||||||
|
{label && shouldShowEdgeLabels && (
|
||||||
|
<EdgeLabelRenderer>
|
||||||
|
<Flex
|
||||||
|
className="nodrag nopan"
|
||||||
|
pointerEvents="all"
|
||||||
|
position="absolute"
|
||||||
|
transform={`translate(-50%, -50%) translate(${labelX}px,${labelY}px)`}
|
||||||
|
bg="base.800"
|
||||||
|
borderRadius="base"
|
||||||
|
borderWidth={1}
|
||||||
|
borderColor={isSelected ? 'undefined' : 'transparent'}
|
||||||
|
opacity={isSelected ? 1 : 0.5}
|
||||||
|
py={1}
|
||||||
|
px={3}
|
||||||
|
shadow="md"
|
||||||
|
>
|
||||||
|
<Text size="sm" fontWeight="semibold" color={isSelected ? 'base.100' : 'base.300'}>
|
||||||
|
{label}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
</EdgeLabelRenderer>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(InvocationDefaultEdge);
|
export default memo(InvocationDefaultEdge);
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
|
import { selectFieldOutputTemplate, selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
|
|
||||||
import { getFieldColor } from './getEdgeColor';
|
import { getFieldColor } from './getEdgeColor';
|
||||||
@ -10,6 +10,7 @@ const defaultReturnValue = {
|
|||||||
isSelected: false,
|
isSelected: false,
|
||||||
shouldAnimate: false,
|
shouldAnimate: false,
|
||||||
stroke: colorTokenToCssVar('base.500'),
|
stroke: colorTokenToCssVar('base.500'),
|
||||||
|
label: '',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const makeEdgeSelector = (
|
export const makeEdgeSelector = (
|
||||||
@ -19,25 +20,34 @@ export const makeEdgeSelector = (
|
|||||||
targetHandleId: string | null | undefined,
|
targetHandleId: string | null | undefined,
|
||||||
selected?: boolean
|
selected?: boolean
|
||||||
) =>
|
) =>
|
||||||
createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => {
|
createMemoizedSelector(
|
||||||
const sourceNode = nodes.nodes.find((node) => node.id === source);
|
selectNodesSlice,
|
||||||
const targetNode = nodes.nodes.find((node) => node.id === target);
|
(nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
|
||||||
|
const sourceNode = nodes.nodes.find((node) => node.id === source);
|
||||||
|
const targetNode = nodes.nodes.find((node) => node.id === target);
|
||||||
|
|
||||||
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
|
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
|
||||||
|
|
||||||
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
|
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
|
||||||
if (!sourceNode || !sourceHandleId) {
|
if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
|
||||||
return defaultReturnValue;
|
return defaultReturnValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
|
||||||
|
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
|
||||||
|
|
||||||
|
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
|
||||||
|
|
||||||
|
const sourceNodeTemplate = selectNodeTemplate(nodes, sourceNode.id);
|
||||||
|
const targetNodeTemplate = selectNodeTemplate(nodes, targetNode.id);
|
||||||
|
|
||||||
|
const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
|
||||||
|
|
||||||
|
return {
|
||||||
|
isSelected,
|
||||||
|
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
|
||||||
|
stroke,
|
||||||
|
label,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
);
|
||||||
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
|
|
||||||
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
|
|
||||||
|
|
||||||
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
|
|
||||||
|
|
||||||
return {
|
|
||||||
isSelected,
|
|
||||||
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
|
|
||||||
stroke,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
@ -16,7 +16,7 @@ const props: ChakraProps = { w: 'unset' };
|
|||||||
|
|
||||||
const InvocationNodeFooter = ({ nodeId }: Props) => {
|
const InvocationNodeFooter = ({ nodeId }: Props) => {
|
||||||
const hasImageOutput = useHasImageOutput(nodeId);
|
const hasImageOutput = useHasImageOutput(nodeId);
|
||||||
const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
|
const isCacheEnabled = useFeatureStatus('invocationCache');
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
className={DRAG_HANDLE_CLASSNAME}
|
className={DRAG_HANDLE_CLASSNAME}
|
||||||
|
@ -24,6 +24,7 @@ import {
|
|||||||
selectNodesSlice,
|
selectNodesSlice,
|
||||||
shouldAnimateEdgesChanged,
|
shouldAnimateEdgesChanged,
|
||||||
shouldColorEdgesChanged,
|
shouldColorEdgesChanged,
|
||||||
|
shouldShowEdgeLabelsChanged,
|
||||||
shouldSnapToGridChanged,
|
shouldSnapToGridChanged,
|
||||||
shouldValidateGraphChanged,
|
shouldValidateGraphChanged,
|
||||||
} from 'features/nodes/store/nodesSlice';
|
} from 'features/nodes/store/nodesSlice';
|
||||||
@ -35,12 +36,20 @@ import { SelectionMode } from 'reactflow';
|
|||||||
const formLabelProps: FormLabelProps = { flexGrow: 1 };
|
const formLabelProps: FormLabelProps = { flexGrow: 1 };
|
||||||
|
|
||||||
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||||
const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionMode } = nodes;
|
const {
|
||||||
|
shouldAnimateEdges,
|
||||||
|
shouldValidateGraph,
|
||||||
|
shouldSnapToGrid,
|
||||||
|
shouldColorEdges,
|
||||||
|
shouldShowEdgeLabels,
|
||||||
|
selectionMode,
|
||||||
|
} = nodes;
|
||||||
return {
|
return {
|
||||||
shouldAnimateEdges,
|
shouldAnimateEdges,
|
||||||
shouldValidateGraph,
|
shouldValidateGraph,
|
||||||
shouldSnapToGrid,
|
shouldSnapToGrid,
|
||||||
shouldColorEdges,
|
shouldColorEdges,
|
||||||
|
shouldShowEdgeLabels,
|
||||||
selectionModeIsChecked: selectionMode === SelectionMode.Full,
|
selectionModeIsChecked: selectionMode === SelectionMode.Full,
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
@ -52,8 +61,14 @@ type Props = {
|
|||||||
const WorkflowEditorSettings = ({ children }: Props) => {
|
const WorkflowEditorSettings = ({ children }: Props) => {
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionModeIsChecked } =
|
const {
|
||||||
useAppSelector(selector);
|
shouldAnimateEdges,
|
||||||
|
shouldValidateGraph,
|
||||||
|
shouldSnapToGrid,
|
||||||
|
shouldColorEdges,
|
||||||
|
shouldShowEdgeLabels,
|
||||||
|
selectionModeIsChecked,
|
||||||
|
} = useAppSelector(selector);
|
||||||
|
|
||||||
const handleChangeShouldValidate = useCallback(
|
const handleChangeShouldValidate = useCallback(
|
||||||
(e: ChangeEvent<HTMLInputElement>) => {
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
@ -90,6 +105,13 @@ const WorkflowEditorSettings = ({ children }: Props) => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const handleChangeShouldShowEdgeLabels = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
dispatch(shouldShowEdgeLabelsChanged(e.target.checked));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -137,6 +159,14 @@ const WorkflowEditorSettings = ({ children }: Props) => {
|
|||||||
<FormHelperText>{t('nodes.fullyContainNodesHelp')}</FormHelperText>
|
<FormHelperText>{t('nodes.fullyContainNodesHelp')}</FormHelperText>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<Divider />
|
<Divider />
|
||||||
|
<FormControl>
|
||||||
|
<Flex w="full">
|
||||||
|
<FormLabel>{t('nodes.showEdgeLabels')}</FormLabel>
|
||||||
|
<Switch isChecked={shouldShowEdgeLabels} onChange={handleChangeShouldShowEdgeLabels} />
|
||||||
|
</Flex>
|
||||||
|
<FormHelperText>{t('nodes.showEdgeLabelsHelp')}</FormHelperText>
|
||||||
|
</FormControl>
|
||||||
|
<Divider />
|
||||||
<Heading size="sm" pt={4}>
|
<Heading size="sm" pt={4}>
|
||||||
{t('common.advanced')}
|
{t('common.advanced')}
|
||||||
</Heading>
|
</Heading>
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||||
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
import { selectNodeTemplate } from 'features/nodes/store/selectors';
|
||||||
@ -10,7 +10,7 @@ import { useMemo } from 'react';
|
|||||||
export const useOutputFieldNames = (nodeId: string) => {
|
export const useOutputFieldNames = (nodeId: string) => {
|
||||||
const selector = useMemo(
|
const selector = useMemo(
|
||||||
() =>
|
() =>
|
||||||
createSelector(selectNodesSlice, (nodes) => {
|
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||||
const template = selectNodeTemplate(nodes, nodeId);
|
const template = selectNodeTemplate(nodes, nodeId);
|
||||||
if (!template) {
|
if (!template) {
|
||||||
return EMPTY_ARRAY;
|
return EMPTY_ARRAY;
|
||||||
|
@ -5,8 +5,7 @@ import { useHasImageOutput } from './useHasImageOutput';
|
|||||||
|
|
||||||
export const useWithFooter = (nodeId: string) => {
|
export const useWithFooter = (nodeId: string) => {
|
||||||
const hasImageOutput = useHasImageOutput(nodeId);
|
const hasImageOutput = useHasImageOutput(nodeId);
|
||||||
const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
|
const isCacheEnabled = useFeatureStatus('invocationCache');
|
||||||
|
|
||||||
const withFooter = useMemo(() => hasImageOutput || isCacheEnabled, [hasImageOutput, isCacheEnabled]);
|
const withFooter = useMemo(() => hasImageOutput || isCacheEnabled, [hasImageOutput, isCacheEnabled]);
|
||||||
return withFooter;
|
return withFooter;
|
||||||
};
|
};
|
||||||
|
@ -103,6 +103,7 @@ const initialNodesState: NodesState = {
|
|||||||
shouldAnimateEdges: true,
|
shouldAnimateEdges: true,
|
||||||
shouldSnapToGrid: false,
|
shouldSnapToGrid: false,
|
||||||
shouldColorEdges: true,
|
shouldColorEdges: true,
|
||||||
|
shouldShowEdgeLabels: false,
|
||||||
isAddNodePopoverOpen: false,
|
isAddNodePopoverOpen: false,
|
||||||
nodeOpacity: 1,
|
nodeOpacity: 1,
|
||||||
selectedNodes: [],
|
selectedNodes: [],
|
||||||
@ -549,6 +550,9 @@ export const nodesSlice = createSlice({
|
|||||||
shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => {
|
shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldAnimateEdges = action.payload;
|
state.shouldAnimateEdges = action.payload;
|
||||||
},
|
},
|
||||||
|
shouldShowEdgeLabelsChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.shouldShowEdgeLabels = action.payload;
|
||||||
|
},
|
||||||
shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => {
|
shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldSnapToGrid = action.payload;
|
state.shouldSnapToGrid = action.payload;
|
||||||
},
|
},
|
||||||
@ -831,6 +835,7 @@ export const {
|
|||||||
viewportChanged,
|
viewportChanged,
|
||||||
edgeAdded,
|
edgeAdded,
|
||||||
nodeTemplatesBuilt,
|
nodeTemplatesBuilt,
|
||||||
|
shouldShowEdgeLabelsChanged,
|
||||||
} = nodesSlice.actions;
|
} = nodesSlice.actions;
|
||||||
|
|
||||||
// This is used for tracking `state.workflow.isTouched`
|
// This is used for tracking `state.workflow.isTouched`
|
||||||
|
@ -32,6 +32,7 @@ export type NodesState = {
|
|||||||
isAddNodePopoverOpen: boolean;
|
isAddNodePopoverOpen: boolean;
|
||||||
addNewNodePosition: XYPosition | null;
|
addNewNodePosition: XYPosition | null;
|
||||||
selectionMode: SelectionMode;
|
selectionMode: SelectionMode;
|
||||||
|
shouldShowEdgeLabels: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type WorkflowMode = 'edit' | 'view';
|
export type WorkflowMode = 'edit' | 'view';
|
||||||
|
@ -1,24 +1,18 @@
|
|||||||
import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIColorPicker from 'common/components/IAIColorPicker';
|
import IAIColorPicker from 'common/components/IAIColorPicker';
|
||||||
import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice';
|
import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import type { RgbaColor } from 'react-colorful';
|
import type { RgbaColor } from 'react-colorful';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
const selectInfillColor = createMemoizedSelector(selectGenerationSlice, (generation) => generation.infillColorValue);
|
||||||
|
|
||||||
const ParamInfillColorOptions = () => {
|
const ParamInfillColorOptions = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const selector = useMemo(
|
const infillColor = useAppSelector(selectInfillColor);
|
||||||
() =>
|
|
||||||
createSelector(selectGenerationSlice, (generation) => ({
|
|
||||||
infillColor: generation.infillColorValue,
|
|
||||||
})),
|
|
||||||
[]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { infillColor } = useAppSelector(selector);
|
|
||||||
|
|
||||||
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
|
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
|
||||||
|
|
||||||
|
@ -1,35 +1,23 @@
|
|||||||
import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIColorPicker from 'common/components/IAIColorPicker';
|
import IAIColorPicker from 'common/components/IAIColorPicker';
|
||||||
import {
|
import {
|
||||||
selectGenerationSlice,
|
|
||||||
setInfillMosaicMaxColor,
|
setInfillMosaicMaxColor,
|
||||||
setInfillMosaicMinColor,
|
setInfillMosaicMinColor,
|
||||||
setInfillMosaicTileHeight,
|
setInfillMosaicTileHeight,
|
||||||
setInfillMosaicTileWidth,
|
setInfillMosaicTileWidth,
|
||||||
} from 'features/parameters/store/generationSlice';
|
} from 'features/parameters/store/generationSlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import type { RgbaColor } from 'react-colorful';
|
import type { RgbaColor } from 'react-colorful';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
const ParamInfillMosaicTileSize = () => {
|
const ParamInfillMosaicTileSize = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const selector = useMemo(
|
const infillMosaicTileWidth = useAppSelector((s) => s.generation.infillMosaicTileWidth);
|
||||||
() =>
|
const infillMosaicTileHeight = useAppSelector((s) => s.generation.infillMosaicTileHeight);
|
||||||
createSelector(selectGenerationSlice, (generation) => ({
|
const infillMosaicMinColor = useAppSelector((s) => s.generation.infillMosaicMinColor);
|
||||||
infillMosaicTileWidth: generation.infillMosaicTileWidth,
|
const infillMosaicMaxColor = useAppSelector((s) => s.generation.infillMosaicMaxColor);
|
||||||
infillMosaicTileHeight: generation.infillMosaicTileHeight,
|
|
||||||
infillMosaicMinColor: generation.infillMosaicMinColor,
|
|
||||||
infillMosaicMaxColor: generation.infillMosaicMaxColor,
|
|
||||||
})),
|
|
||||||
[]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { infillMosaicTileWidth, infillMosaicTileHeight, infillMosaicMinColor, infillMosaicMaxColor } =
|
|
||||||
useAppSelector(selector);
|
|
||||||
|
|
||||||
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
|
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
@ -27,8 +27,8 @@ export const QueueActionsMenuButton = memo(() => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const clearQueueDisclosure = useDisclosure();
|
const clearQueueDisclosure = useDisclosure();
|
||||||
const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
|
const isPauseEnabled = useFeatureStatus('pauseQueue');
|
||||||
const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
|
const isResumeEnabled = useFeatureStatus('resumeQueue');
|
||||||
const { queueSize } = useGetQueueStatusQuery(undefined, {
|
const { queueSize } = useGetQueueStatusQuery(undefined, {
|
||||||
selectFromResult: (res) => ({
|
selectFromResult: (res) => ({
|
||||||
queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0,
|
queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0,
|
||||||
|
@ -9,7 +9,7 @@ import { InvokeQueueBackButton } from './InvokeQueueBackButton';
|
|||||||
import { QueueActionsMenuButton } from './QueueActionsMenuButton';
|
import { QueueActionsMenuButton } from './QueueActionsMenuButton';
|
||||||
|
|
||||||
const QueueControls = () => {
|
const QueueControls = () => {
|
||||||
const isPrependEnabled = useFeatureStatus('prependQueue').isFeatureEnabled;
|
const isPrependEnabled = useFeatureStatus('prependQueue');
|
||||||
return (
|
return (
|
||||||
<Flex w="full" position="relative" borderRadius="base" gap={2} pt={2} flexDir="column">
|
<Flex w="full" position="relative" borderRadius="base" gap={2} pt={2} flexDir="column">
|
||||||
<ButtonGroup size="lg" isAttached={false}>
|
<ButtonGroup size="lg" isAttached={false}>
|
||||||
|
@ -8,7 +8,7 @@ import QueueStatus from './QueueStatus';
|
|||||||
import QueueTabQueueControls from './QueueTabQueueControls';
|
import QueueTabQueueControls from './QueueTabQueueControls';
|
||||||
|
|
||||||
const QueueTabContent = () => {
|
const QueueTabContent = () => {
|
||||||
const isInvocationCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
|
const isInvocationCacheEnabled = useFeatureStatus('invocationCache');
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
|
<Flex borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
|
||||||
|
@ -8,8 +8,8 @@ import PruneQueueButton from './PruneQueueButton';
|
|||||||
import ResumeProcessorButton from './ResumeProcessorButton';
|
import ResumeProcessorButton from './ResumeProcessorButton';
|
||||||
|
|
||||||
const QueueTabQueueControls = () => {
|
const QueueTabQueueControls = () => {
|
||||||
const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
|
const isPauseEnabled = useFeatureStatus('pauseQueue');
|
||||||
const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
|
const isResumeEnabled = useFeatureStatus('resumeQueue');
|
||||||
return (
|
return (
|
||||||
<Flex layerStyle="first" borderRadius="base" p={2} gap={2}>
|
<Flex layerStyle="first" borderRadius="base" p={2} gap={2}>
|
||||||
{isPauseEnabled || isResumeEnabled ? (
|
{isPauseEnabled || isResumeEnabled ? (
|
||||||
|
@ -13,7 +13,7 @@ export const useQueueFront = () => {
|
|||||||
const [_, { isLoading }] = useEnqueueBatchMutation({
|
const [_, { isLoading }] = useEnqueueBatchMutation({
|
||||||
fixedCacheKey: 'enqueueBatch',
|
fixedCacheKey: 'enqueueBatch',
|
||||||
});
|
});
|
||||||
const prependEnabled = useFeatureStatus('prependQueue').isFeatureEnabled;
|
const prependEnabled = useFeatureStatus('prependQueue');
|
||||||
|
|
||||||
const isDisabled = useMemo(() => {
|
const isDisabled = useMemo(() => {
|
||||||
return !isReady || !prependEnabled;
|
return !isReady || !prependEnabled;
|
||||||
|
@ -62,7 +62,7 @@ const selector = createMemoizedSelector(selectControlAdaptersSlice, (controlAdap
|
|||||||
export const ControlSettingsAccordion: React.FC = memo(() => {
|
export const ControlSettingsAccordion: React.FC = memo(() => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { controlAdapterIds, badges } = useAppSelector(selector);
|
const { controlAdapterIds, badges } = useAppSelector(selector);
|
||||||
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
|
const isControlNetEnabled = useFeatureStatus('controlNet');
|
||||||
const { isOpen, onToggle } = useStandaloneAccordionToggle({
|
const { isOpen, onToggle } = useStandaloneAccordionToggle({
|
||||||
id: 'control-settings',
|
id: 'control-settings',
|
||||||
defaultIsOpen: true,
|
defaultIsOpen: true,
|
||||||
@ -71,7 +71,7 @@ export const ControlSettingsAccordion: React.FC = memo(() => {
|
|||||||
const [addIPAdapter, isAddIPAdapterDisabled] = useAddControlAdapter('ip_adapter');
|
const [addIPAdapter, isAddIPAdapterDisabled] = useAddControlAdapter('ip_adapter');
|
||||||
const [addT2IAdapter, isAddT2IAdapterDisabled] = useAddControlAdapter('t2i_adapter');
|
const [addT2IAdapter, isAddT2IAdapterDisabled] = useAddControlAdapter('t2i_adapter');
|
||||||
|
|
||||||
if (isControlNetDisabled) {
|
if (!isControlNetEnabled) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ export const SettingsLanguageSelect = memo(() => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const language = useAppSelector((s) => s.system.language);
|
const language = useAppSelector((s) => s.system.language);
|
||||||
const isLocalizationEnabled = useFeatureStatus('localization').isFeatureEnabled;
|
const isLocalizationEnabled = useFeatureStatus('localization');
|
||||||
|
|
||||||
const value = useMemo(() => options.find((o) => o.value === language), [language]);
|
const value = useMemo(() => options.find((o) => o.value === language), [language]);
|
||||||
|
|
||||||
|
@ -23,9 +23,9 @@ const SettingsMenu = () => {
|
|||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
useGlobalMenuClose(onClose);
|
useGlobalMenuClose(onClose);
|
||||||
|
|
||||||
const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled;
|
const isBugLinkEnabled = useFeatureStatus('bugLink');
|
||||||
const isDiscordLinkEnabled = useFeatureStatus('discordLink').isFeatureEnabled;
|
const isDiscordLinkEnabled = useFeatureStatus('discordLink');
|
||||||
const isGithubLinkEnabled = useFeatureStatus('githubLink').isFeatureEnabled;
|
const isGithubLinkEnabled = useFeatureStatus('githubLink');
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose}>
|
<Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose}>
|
||||||
|
@ -1,32 +1,24 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import type { AppFeature, SDFeature } from 'app/types/invokeai';
|
import type { AppFeature, SDFeature } from 'app/types/invokeai';
|
||||||
|
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||||
import type { InvokeTabName } from 'features/ui/store/tabMap';
|
import type { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
export const useFeatureStatus = (feature: AppFeature | SDFeature | InvokeTabName) => {
|
export const useFeatureStatus = (feature: AppFeature | SDFeature | InvokeTabName) => {
|
||||||
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
|
const selectIsFeatureEnabled = useMemo(
|
||||||
|
|
||||||
const disabledFeatures = useAppSelector((s) => s.config.disabledFeatures);
|
|
||||||
|
|
||||||
const disabledSDFeatures = useAppSelector((s) => s.config.disabledSDFeatures);
|
|
||||||
|
|
||||||
const isFeatureDisabled = useMemo(
|
|
||||||
() =>
|
() =>
|
||||||
disabledFeatures.includes(feature as AppFeature) ||
|
createSelector(selectConfigSlice, (config) => {
|
||||||
disabledSDFeatures.includes(feature as SDFeature) ||
|
return !(
|
||||||
disabledTabs.includes(feature as InvokeTabName),
|
config.disabledFeatures.includes(feature as AppFeature) ||
|
||||||
[disabledFeatures, disabledSDFeatures, disabledTabs, feature]
|
config.disabledSDFeatures.includes(feature as SDFeature) ||
|
||||||
|
config.disabledTabs.includes(feature as InvokeTabName)
|
||||||
|
);
|
||||||
|
}),
|
||||||
|
[feature]
|
||||||
);
|
);
|
||||||
|
|
||||||
const isFeatureEnabled = useMemo(
|
const isFeatureEnabled = useAppSelector(selectIsFeatureEnabled);
|
||||||
() =>
|
|
||||||
!(
|
|
||||||
disabledFeatures.includes(feature as AppFeature) ||
|
|
||||||
disabledSDFeatures.includes(feature as SDFeature) ||
|
|
||||||
disabledTabs.includes(feature as InvokeTabName)
|
|
||||||
),
|
|
||||||
[disabledFeatures, disabledSDFeatures, disabledTabs, feature]
|
|
||||||
);
|
|
||||||
|
|
||||||
return { isFeatureDisabled, isFeatureEnabled };
|
return isFeatureEnabled;
|
||||||
};
|
};
|
||||||
|
132
tests/backend/util/test_devices.py
Normal file
132
tests/backend/util/test_devices.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
"""
|
||||||
|
Test abstract device class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from invokeai.app.services.config import get_config
|
||||||
|
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
|
||||||
|
|
||||||
|
devices = ["cpu", "cuda:0", "cuda:1", "mps"]
|
||||||
|
device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
|
||||||
|
device_types_cuda = [("cpu", torch.float32), ("cuda:0", torch.float16), ("mps", torch.float32)]
|
||||||
|
device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float16)]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device_name", devices)
|
||||||
|
def test_device_choice(device_name):
|
||||||
|
config = get_config()
|
||||||
|
config.device = device_name
|
||||||
|
torch_device = TorchDevice.choose_torch_device()
|
||||||
|
assert torch_device == torch.device(device_name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
|
||||||
|
def test_device_dtype_cpu(device_dtype_pair):
|
||||||
|
with (
|
||||||
|
patch("torch.cuda.is_available", return_value=False),
|
||||||
|
patch("torch.backends.mps.is_available", return_value=False),
|
||||||
|
):
|
||||||
|
device_name, dtype = device_dtype_pair
|
||||||
|
config = get_config()
|
||||||
|
config.device = device_name
|
||||||
|
torch_dtype = TorchDevice.choose_torch_dtype()
|
||||||
|
assert torch_dtype == dtype
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
|
||||||
|
def test_device_dtype_cuda(device_dtype_pair):
|
||||||
|
with (
|
||||||
|
patch("torch.cuda.is_available", return_value=True),
|
||||||
|
patch("torch.cuda.get_device_name", return_value="RTX4070"),
|
||||||
|
patch("torch.backends.mps.is_available", return_value=False),
|
||||||
|
):
|
||||||
|
device_name, dtype = device_dtype_pair
|
||||||
|
config = get_config()
|
||||||
|
config.device = device_name
|
||||||
|
torch_dtype = TorchDevice.choose_torch_dtype()
|
||||||
|
assert torch_dtype == dtype
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device_dtype_pair", device_types_mps)
|
||||||
|
def test_device_dtype_mps(device_dtype_pair):
|
||||||
|
with (
|
||||||
|
patch("torch.cuda.is_available", return_value=False),
|
||||||
|
patch("torch.backends.mps.is_available", return_value=True),
|
||||||
|
):
|
||||||
|
device_name, dtype = device_dtype_pair
|
||||||
|
config = get_config()
|
||||||
|
config.device = device_name
|
||||||
|
torch_dtype = TorchDevice.choose_torch_dtype()
|
||||||
|
assert torch_dtype == dtype
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
|
||||||
|
def test_device_dtype_override(device_dtype_pair):
|
||||||
|
with (
|
||||||
|
patch("torch.cuda.get_device_name", return_value="RTX4070"),
|
||||||
|
patch("torch.cuda.is_available", return_value=True),
|
||||||
|
patch("torch.backends.mps.is_available", return_value=False),
|
||||||
|
):
|
||||||
|
device_name, dtype = device_dtype_pair
|
||||||
|
config = get_config()
|
||||||
|
config.device = device_name
|
||||||
|
config.precision = "float32"
|
||||||
|
torch_dtype = TorchDevice.choose_torch_dtype()
|
||||||
|
assert torch_dtype == torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize():
|
||||||
|
assert (
|
||||||
|
TorchDevice.normalize("cuda") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
TorchDevice.normalize("cuda:0") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
TorchDevice.normalize("cuda:1") == torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cuda")
|
||||||
|
)
|
||||||
|
assert TorchDevice.normalize("mps") == torch.device("mps")
|
||||||
|
assert TorchDevice.normalize("cpu") == torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device_name", devices)
|
||||||
|
def test_legacy_device_choice(device_name):
|
||||||
|
config = get_config()
|
||||||
|
config.device = device_name
|
||||||
|
with pytest.deprecated_call():
|
||||||
|
torch_device = choose_torch_device()
|
||||||
|
assert torch_device == torch.device(device_name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
|
||||||
|
def test_legacy_device_dtype_cpu(device_dtype_pair):
|
||||||
|
with (
|
||||||
|
patch("torch.cuda.is_available", return_value=False),
|
||||||
|
patch("torch.backends.mps.is_available", return_value=False),
|
||||||
|
patch("torch.cuda.get_device_name", return_value="RTX9090"),
|
||||||
|
):
|
||||||
|
device_name, dtype = device_dtype_pair
|
||||||
|
config = get_config()
|
||||||
|
config.device = device_name
|
||||||
|
with pytest.deprecated_call():
|
||||||
|
torch_device = choose_torch_device()
|
||||||
|
returned_dtype = torch_dtype(torch_device)
|
||||||
|
assert returned_dtype == dtype
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_precision_name():
|
||||||
|
config = get_config()
|
||||||
|
config.precision = "auto"
|
||||||
|
with (
|
||||||
|
pytest.deprecated_call(),
|
||||||
|
patch("torch.cuda.is_available", return_value=True),
|
||||||
|
patch("torch.backends.mps.is_available", return_value=True),
|
||||||
|
patch("torch.cuda.get_device_name", return_value="RTX9090"),
|
||||||
|
):
|
||||||
|
assert "float16" == choose_precision(torch.device("cuda"))
|
||||||
|
assert "float16" == choose_precision(torch.device("mps"))
|
||||||
|
assert "float32" == choose_precision(torch.device("cpu"))
|
Loading…
Reference in New Issue
Block a user