Merge branch 'main' into pr/6086

blessedcoolant committed 2024-05-01 00:37:06 +05:30
338 changed files with 11169 additions and 3081 deletions

View File

@ -12,7 +12,7 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.util.logging import logging
from invokeai.version import __version__
@ -100,7 +100,7 @@ async def get_app_deps() -> AppDependencyVersions:
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
async def get_config() -> AppConfig:
infill_methods = ["tile", "lama", "cv2"]
infill_methods = ["tile", "lama", "cv2", "color"] # TODO: add mosaic back
if PatchMatch.patchmatch_available():
infill_methods.append("patchmatch")

View File

@ -219,28 +219,13 @@ async def scan_for_models(
non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
resolved_installed_model_paths: list[str] = []
installed_model_sources: list[str] = []
# This call lists all installed models.
for model in installed_models:
path = pathlib.Path(model.path)
# If the model has a source, we need to add it to the list of installed sources.
if model.source:
installed_model_sources.append(model.source)
# If the path is not absolute, that means it is in the app models directory, and we need to join it with
# the models path before resolving.
if not path.is_absolute():
resolved_installed_model_paths.append(str(pathlib.Path(models_path, path).resolve()))
continue
resolved_installed_model_paths.append(str(path.resolve()))
scan_results: list[FoundModel] = []
# Check if the model is installed by comparing the resolved paths, appending to the scan result.
# Check if the model is installed by comparing paths, appending to the scan result.
for p in non_core_model_paths:
path = str(p)
is_installed = path in resolved_installed_model_paths or path in installed_model_sources
is_installed = any(str(models_path / m.path) == path for m in installed_models)
found_model = FoundModel(path=path, is_installed=is_installed)
scan_results.append(found_model)
except Exception as e:
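
The new one-line is_installed check relies on a pathlib property worth noting: joining a base directory with an absolute path yields the absolute path unchanged, which is why the removed absolute-vs-relative branching is not needed for the path comparison. A standalone sketch of that behavior (paths are made up for illustration, not taken from this diff):

import pathlib

models_path = pathlib.Path("/data/invokeai/models")

# A relative model path is interpreted relative to the models directory...
assert models_path / "sd-1/main/model.safetensors" == pathlib.Path("/data/invokeai/models/sd-1/main/model.safetensors")

# ...while an absolute model path passes through unchanged, so one expression
# covers both cases that the removed code handled separately.
assert models_path / "/mnt/external/model.safetensors" == pathlib.Path("/mnt/external/model.safetensors")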

View File

@ -28,7 +28,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import get_torch_device_name
from invokeai.backend.util.devices import TorchDevice
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
@ -63,7 +63,7 @@ logger = InvokeAILogger.get_logger(config=app_config)
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
torch_device_name = get_torch_device_name()
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")

View File

@ -5,7 +5,15 @@ from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
Input,
InputField,
OutputField,
TensorField,
UIComponent,
)
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
@ -14,10 +22,9 @@ from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
ExtraConditioningInfo,
SDXLConditioningInfo,
)
from invokeai.backend.util.devices import torch_dtype
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .model import CLIPField
@ -36,7 +43,7 @@ from .model import CLIPField
title="Prompt",
tags=["prompt", "compel"],
category="conditioning",
version="1.1.1",
version="1.2.0",
)
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
@ -51,6 +58,9 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
mask: Optional[TensorField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -89,7 +99,7 @@ class CompelInvocation(BaseInvocation):
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
)
@ -98,27 +108,19 @@ class CompelInvocation(BaseInvocation):
if context.config.get().log_tokenization:
log_tokenization_for_conjunction(conjunction, tokenizer)
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
ec = ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
c = c.detach().to("cpu")
conditioning_data = ConditioningFieldData(
conditionings=[
BasicConditioningInfo(
embeds=c,
extra_conditioning=ec,
)
]
)
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
)
)
class SDXLPromptInvocationBase:
@ -132,7 +134,7 @@ class SDXLPromptInvocationBase:
get_pooled: bool,
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tokenizer_info = context.models.load(clip_field.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
@ -159,7 +161,7 @@ class SDXLPromptInvocationBase:
)
else:
c_pooled = None
return c, c_pooled, None
return c, c_pooled
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
@ -191,7 +193,7 @@ class SDXLPromptInvocationBase:
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
@ -204,17 +206,12 @@ class SDXLPromptInvocationBase:
log_tokenization_for_conjunction(conjunction, tokenizer)
# TODO: ask for optimizations? to not run text_encoder twice
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
if get_pooled:
c_pooled = compel.conditioning_provider.get_pooled_embeddings([prompt])
else:
c_pooled = None
ec = ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
del tokenizer
del text_encoder
del tokenizer_info
@ -224,7 +221,7 @@ class SDXLPromptInvocationBase:
if c_pooled is not None:
c_pooled = c_pooled.detach().to("cpu")
return c, c_pooled, ec
return c, c_pooled
@invocation(
@ -232,7 +229,7 @@ class SDXLPromptInvocationBase:
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
version="1.1.1",
version="1.2.0",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@ -255,20 +252,19 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
target_height: int = InputField(default=1024, description="")
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
mask: Optional[TensorField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
)
c1, c1_pooled = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True)
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_compel(
c2, c2_pooled = self.run_clip_compel(
context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
)
else:
c2, c2_pooled, ec2 = self.run_clip_compel(
context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
)
c2, c2_pooled = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
@ -307,17 +303,19 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
embeds=torch.cat([c1, c2], dim=-1),
pooled_embeds=c2_pooled,
add_time_ids=add_time_ids,
extra_conditioning=ec1,
embeds=torch.cat([c1, c2], dim=-1), pooled_embeds=c2_pooled, add_time_ids=add_time_ids
)
]
)
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
)
)
@invocation(
@ -345,7 +343,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
# TODO: if there will appear lora for refiner - write proper prefix
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
c2, c2_pooled = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
@ -354,14 +352,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
assert c2_pooled is not None
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
embeds=c2,
pooled_embeds=c2_pooled,
add_time_ids=add_time_ids,
extra_conditioning=ec2, # or None
)
]
conditionings=[SDXLConditioningInfo(embeds=c2, pooled_embeds=c2_pooled, add_time_ids=add_time_ids)]
)
conditioning_name = context.conditioning.save(conditioning_data)

View File

@ -35,22 +35,16 @@ from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
CONTROLNET_RESIZE_VALUES = Literal[
"just_resize",
"crop_resize",
"fill_resize",
"just_resize_simple",
]
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
class ControlField(BaseModel):
@ -641,3 +635,27 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
resolution=self.image_resolution,
)
return processed_image
@invocation(
"heuristic_resize",
title="Heuristic Resize",
tags=["image, controlnet"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class HeuristicResizeInvocation(BaseInvocation):
"""Resize an image using a heuristic method. Preserves edge maps."""
image: ImageField = InputField(description="The image to resize")
width: int = InputField(default=512, gt=0, description="The width to resize to (px)")
height: int = InputField(default=512, gt=0, description="The height to resize to (px)")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
np_img = pil_to_np(image)
np_resized = heuristic_resize(np_img, (self.width, self.height))
resized = np_to_pil(np_resized)
image_dto = context.images.save(image=resized)
return ImageOutput.build(image_dto)

View File

@ -203,6 +203,12 @@ class DenoiseMaskField(BaseModel):
gradient: bool = Field(default=False, description="Used for gradient inpainting")
class TensorField(BaseModel):
"""A tensor primitive field."""
tensor_name: str = Field(description="The name of a tensor.")
class LatentsField(BaseModel):
"""A latents tensor primitive field"""
@ -226,7 +232,11 @@ class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
# endregion
mask: Optional[TensorField] = Field(
default=None,
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.",
)
class MetadataField(RootModel[dict[str, Any]]):
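
With the TensorField primitive and the optional mask on ConditioningField added above, a conditioning tensor can now be scoped to a region of the image. A small construction sketch (the names are placeholders, not values from this diff):

from invokeai.app.invocations.fields import ConditioningField, TensorField

# A regional prompt: the mask is True inside the region it applies to, False elsewhere.
regional = ConditioningField(
    conditioning_name="conditioning_1234",             # placeholder name
    mask=TensorField(tensor_name="mask_tensor_5678"),  # placeholder name
)

# A global prompt: no mask means the conditioning applies to the whole image.
global_prompt = ConditioningField(conditioning_name="conditioning_abcd")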

View File

@ -1,154 +1,91 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from abc import abstractmethod
from typing import Literal, get_args
import math
from typing import Literal, Optional, get_args
import numpy as np
from PIL import Image, ImageOps
from PIL import Image
from invokeai.app.invocations.fields import ColorField, ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
from invokeai.backend.image_util.lama import LaMA
from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.backend.image_util.infill_methods.cv2_inpaint import cv2_inpaint
from invokeai.backend.image_util.infill_methods.lama import LaMA
from invokeai.backend.image_util.infill_methods.mosaic import infill_mosaic
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch, infill_patchmatch
from invokeai.backend.image_util.infill_methods.tile import infill_tile
from invokeai.backend.util.logging import InvokeAILogger
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
logger = InvokeAILogger.get_logger()
def infill_methods() -> list[str]:
methods = ["tile", "solid", "lama", "cv2"]
def get_infill_methods():
methods = Literal["tile", "color", "lama", "cv2"] # TODO: add mosaic back
if PatchMatch.patchmatch_available():
methods.insert(0, "patchmatch")
methods = Literal["patchmatch", "tile", "color", "lama", "cv2"] # TODO: add mosaic back
return methods
INFILL_METHODS = Literal[tuple(infill_methods())]
INFILL_METHODS = get_infill_methods()
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
def infill_lama(im: Image.Image) -> Image.Image:
lama = LaMA()
return lama(im)
class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Base class for invocations that preprocess images for Infilling"""
image: ImageField = InputField(description="The image to process")
def infill_patchmatch(im: Image.Image) -> Image.Image:
if im.mode != "RGBA":
return im
@abstractmethod
def infill(self, image: Image.Image) -> Image.Image:
"""Infill the image with the specified method"""
pass
# Skip patchmatch if patchmatch isn't available
if not PatchMatch.patchmatch_available():
return im
def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
"""Process the image to have an alpha channel before being infilled"""
image = context.images.get_pil(self.image.image_name)
has_alpha = True if image.mode == "RGBA" else False
return image, has_alpha
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
im_patched = Image.fromarray(im_patched_np, mode="RGB")
return im_patched
def invoke(self, context: InvocationContext) -> ImageOutput:
# Retrieve and process image to be infilled
input_image, has_alpha = self.load_image(context)
# If the input image has no alpha channel, return it
if has_alpha is False:
return ImageOutput.build(context.images.get_dto(self.image.image_name))
def infill_cv2(im: Image.Image) -> Image.Image:
return cv2_inpaint(im)
# Perform Infill action
infilled_image = self.infill(input_image)
# Create ImageDTO for Infilled Image
infilled_image_dto = context.images.save(image=infilled_image)
def get_tile_images(image: np.ndarray, width=8, height=8):
_nrows, _ncols, depth = image.shape
_strides = image.strides
nrows, _m = divmod(_nrows, height)
ncols, _n = divmod(_ncols, width)
if _m != 0 or _n != 0:
return None
return np.lib.stride_tricks.as_strided(
np.ravel(image),
shape=(nrows, ncols, height, width, depth),
strides=(height * _strides[0], width * _strides[1], *_strides),
writeable=False,
)
def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
# Only fill if there's an alpha layer
if im.mode != "RGBA":
return im
a = np.asarray(im, dtype=np.uint8)
tile_size_tuple = (tile_size, tile_size)
# Get the image as tiles of a specified size
tiles = get_tile_images(a, *tile_size_tuple).copy()
# Get the mask as tiles
tiles_mask = tiles[:, :, :, :, 3]
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
tmask_shape = tiles_mask.shape
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
tiles_mask = tiles_mask > 0
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
# Get RGB tiles in single array and filter by the mask
tshape = tiles.shape
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
filtered_tiles = tiles_all[tiles_mask]
if len(filtered_tiles) == 0:
return im
# Find all invalid tiles and replace with a random valid tile
replace_count = (tiles_mask == False).sum() # noqa: E712
rng = np.random.default_rng(seed=seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
# Convert back to an image
tiles_all = tiles_all.reshape(tshape)
tiles_all = tiles_all.swapaxes(1, 2)
st = tiles_all.reshape(
(
math.prod(tiles_all.shape[0:2]),
math.prod(tiles_all.shape[2:4]),
tiles_all.shape[4],
)
)
si = Image.fromarray(st, mode="RGBA")
return si
# Return Infilled Image
return ImageOutput.build(infilled_image_dto)
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
class InfillColorInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image with a solid color"""
image: ImageField = InputField(description="The image to infill")
color: ColorField = InputField(
default=ColorField(r=127, g=127, b=127, a=255),
description="The color to use to infill",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
def infill(self, image: Image.Image):
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
infilled.paste(image, (0, 0), image.split()[-1])
image_dto = context.images.save(image=infilled)
return ImageOutput.build(image_dto)
return infilled
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.3")
class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
class InfillTileInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image with tiles of the image"""
image: ImageField = InputField(description="The image to infill")
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
seed: int = InputField(
default=0,
@ -157,92 +94,74 @@ class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
description="The seed to use for tile generation (omit for random)",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
infilled.paste(image, (0, 0), image.split()[-1])
image_dto = context.images.save(image=infilled)
return ImageOutput.build(image_dto)
def infill(self, image: Image.Image):
output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
return output.infilled
@invocation(
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2"
)
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the PatchMatch algorithm"""
image: ImageField = InputField(description="The image to infill")
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name).convert("RGBA")
def infill(self, image: Image.Image):
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
infill_image = image.copy()
width = int(image.width / self.downscale)
height = int(image.height / self.downscale)
infill_image = infill_image.resize(
infilled = image.resize(
(width, height),
resample=resample_mode,
)
if PatchMatch.patchmatch_available():
infilled = infill_patchmatch(infill_image)
else:
raise ValueError("PatchMatch is not available on this system")
infilled = infill_patchmatch(image)
infilled = infilled.resize(
(image.width, image.height),
resample=resample_mode,
)
infilled.paste(image, (0, 0), mask=image.split()[-1])
# image.paste(infilled, (0, 0), mask=image.split()[-1])
image_dto = context.images.save(image=infilled)
return ImageOutput.build(image_dto)
return infilled
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model"""
image: ImageField = InputField(description="The image to infill")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
# Downloads the LaMa model if it doesn't already exist
download_with_progress_bar(
name="LaMa Inpainting Model",
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
dest_path=context.config.get().models_path / "core/misc/lama/lama.pt",
)
infilled = infill_lama(image.copy())
image_dto = context.images.save(image=infilled)
return ImageOutput.build(image_dto)
def infill(self, image: Image.Image):
lama = LaMA()
return lama(image)
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard):
class CV2InfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using OpenCV Inpainting"""
def infill(self, image: Image.Image):
return cv2_inpaint(image)
# @invocation(
# "infill_mosaic", title="Mosaic Infill", tags=["image", "inpaint", "outpaint"], category="inpaint", version="1.0.0"
# )
class MosaicInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image with a mosaic pattern drawing colors from the rest of the image"""
image: ImageField = InputField(description="The image to infill")
tile_width: int = InputField(default=64, description="Width of the tile")
tile_height: int = InputField(default=64, description="Height of the tile")
min_color: ColorField = InputField(
default=ColorField(r=0, g=0, b=0, a=255),
description="The min threshold for color",
)
max_color: ColorField = InputField(
default=ColorField(r=255, g=255, b=255, a=255),
description="The max threshold for color",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
infilled = infill_cv2(image.copy())
image_dto = context.images.save(image=infilled)
return ImageOutput.build(image_dto)
def infill(self, image: Image.Image):
return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())
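
The refactor above concentrates the shared work (loading the image, skipping images without an alpha channel, saving the result) in InfillImageProcessorInvocation, so a concrete infill node only implements infill(). A hypothetical node written against that base class, as a sketch of the pattern (not part of this commit; assumes the imports already present in this module):

import numpy as np
from PIL import Image

from .baseinvocation import invocation
from .fields import InputField


@invocation("infill_noise", title="Noise Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
class NoiseInfillInvocation(InfillImageProcessorInvocation):
    """Hypothetical example: fills transparent areas of an image with random RGB noise."""

    seed: int = InputField(default=0, ge=0, description="Seed for the noise generator")

    def infill(self, image: Image.Image) -> Image.Image:
        # The base class only calls infill() when the image has an alpha channel,
        # and it handles fetching the input image and saving the output DTO.
        rng = np.random.default_rng(self.seed)
        noise = Image.fromarray(
            rng.integers(0, 256, (image.height, image.width, 3), dtype=np.uint8), mode="RGB"
        ).convert("RGBA")
        # Keep the opaque pixels of the original; only the transparent holes receive noise.
        noise.paste(image, (0, 0), image.split()[-1])
        return noise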

View File

@ -1,34 +1,41 @@
from builtins import float
from typing import List, Union
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
ModelType,
)
class IPAdapterField(BaseModel):
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.")
target_blocks: List[str] = Field(default=[], description="The IP Adapter blocks to apply")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
mask: Optional[TensorField] = Field(
default=None,
description="The bool mask associated with this IP-Adapter. Excluded regions should be set to False, included "
"regions should be set to True.",
)
@field_validator("weight")
@classmethod
@ -48,12 +55,15 @@ class IPAdapterOutput(BaseInvocationOutput):
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
# Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).", ui_order=1)
ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.",
title="IP-Adapter Model",
@ -61,16 +71,26 @@ class IPAdapterInvocation(BaseInvocation):
ui_order=-1,
ui_type=UIType.IPAdapterModel,
)
clip_vision_model: Literal["ViT-H", "ViT-G"] = InputField(
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
default="ViT-H",
ui_order=2,
)
weight: Union[float, List[float]] = InputField(
default=1, description="The weight given to the IP-Adapter", title="Weight"
)
method: Literal["full", "style", "composition"] = InputField(
default="full", description="The method to apply the IP-Adapter"
)
begin_step_percent: float = InputField(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
mask: Optional[TensorField] = InputField(
default=None, description="A mask defining the region that this IP-Adapter applies to."
)
@field_validator("weight")
@classmethod
@ -86,35 +106,68 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, IPAdapterConfig)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
else:
image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
if self.method == "style":
if ip_adapter_info.base == "sd-1":
target_blocks = ["up_blocks.1"]
elif ip_adapter_info.base == "sdxl":
target_blocks = ["up_blocks.0.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "composition":
if ip_adapter_info.base == "sd-1":
target_blocks = ["down_blocks.2", "mid_block"]
elif ip_adapter_info.base == "sdxl":
target_blocks = ["down_blocks.2.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "full":
target_blocks = ["block"]
else:
raise ValueError(f"Unexpected IP-Adapter method: '{self.method}'.")
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
weight=self.weight,
target_blocks=target_blocks,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
mask=self.mask,
),
)
def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
found = False
while not found:
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
if not len(image_encoder_models) > 0:
context.logger.warning(
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed. \
Downloading and installing now. This may take a while."
)
installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
installer.wait_for_job(job, timeout=600) # Wait for up to 10 minutes
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
found = len(image_encoder_models) > 0
if not found:
context.logger.warning(
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed."
)
context.logger.warning("Downloading and installing now. This may take a while.")
installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
installer.wait_for_job(job, timeout=600) # wait up to 10 minutes - then raise a TimeoutException
assert len(image_encoder_models) == 1
if len(image_encoder_models) == 0:
context.logger.error("Error while fetching CLIP Vision Image Encoder")
assert len(image_encoder_models) == 1
return image_encoder_models[0]
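
The branch logic added above selects which UNet attention blocks the IP-Adapter is applied to, based on the chosen method and the model base. Summarized as a lookup (restating the branches above, not a refactor proposed by this commit):

# (method, base) -> target_blocks
TARGET_BLOCKS = {
    ("full", "sd-1"): ["block"],
    ("full", "sdxl"): ["block"],
    ("style", "sd-1"): ["up_blocks.1"],
    ("style", "sdxl"): ["up_blocks.0.attentions.1"],
    ("composition", "sd-1"): ["down_blocks.2", "mid_block"],
    ("composition", "sdxl"): ["down_blocks.2.attentions.1"],
}
# Any other (method, base) combination raises a ValueError in the code above.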

View File

@ -1,5 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import math
from contextlib import ExitStack
from functools import singledispatchmethod
@ -9,6 +9,7 @@ import einops
import numpy as np
import numpy.typing as npt
import torch
import torchvision
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.configuration_utils import ConfigMixin
@ -43,44 +44,41 @@ from invokeai.app.invocations.fields import (
WithMetadata,
)
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.primitives import (
DenoiseMaskOutput,
ImageOutput,
LatentsOutput,
)
from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
IPAdapterConditioningInfo,
IPAdapterData,
Range,
SDXLConditioningInfo,
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings
from ...backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
IPAdapterData,
StableDiffusionGeneratorPipeline,
T2IAdapterData,
image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from ...backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .controlnet_image_processors import ControlField
from .model import ModelIdentifierField, UNetField, VAEField
if choose_torch_device() == torch.device("mps"):
from torch import mps
DEFAULT_PRECISION = choose_precision(choose_torch_device())
DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
@invocation_output("scheduler_output")
@ -188,7 +186,7 @@ class GradientMaskOutput(BaseInvocationOutput):
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.0.0",
version="1.1.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
@ -201,6 +199,32 @@ class CreateGradientMaskInvocation(BaseInvocation):
minimum_denoise: float = InputField(
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
)
image: Optional[ImageField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] Image",
ui_order=6,
)
unet: Optional[UNetField] = InputField(
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
default=None,
input=Input.Connection,
title="[OPTIONAL] UNet",
ui_order=5,
)
vae: Optional[VAEField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] VAE",
input=Input.Connection,
ui_order=7,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
fp32: bool = InputField(
default=DEFAULT_PRECISION == "float32",
description=FieldDescriptions.fp32,
ui_order=9,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
@ -236,8 +260,27 @@ class CreateGradientMaskInvocation(BaseInvocation):
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
expanded_image_dto = context.images.save(expanded_mask_image)
masked_latents_name = None
if self.unet is not None and self.vae is not None and self.image is not None:
# all three fields must be present at the same time
main_model_config = context.models.get_config(self.unet.unet.key)
assert isinstance(main_model_config, MainConfigBase)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = blur_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
masked_latents = ImageToLatentsInvocation.vae_encode(
vae_info, self.fp32, self.tiled, masked_image.clone()
)
masked_latents_name = context.tensors.save(tensor=masked_latents)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
)
@ -284,10 +327,10 @@ def get_scheduler(
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
positive_conditioning: ConditioningField = InputField(
positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
)
negative_conditioning: ConditioningField = InputField(
negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
)
noise: Optional[LatentsField] = InputField(
@ -298,7 +341,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: Union[float, List[float]] = InputField(
default=7.5, ge=1, description=FieldDescriptions.cfg_scale, title="CFG Scale"
default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale"
)
denoising_start: float = InputField(
default=0.0,
@ -365,33 +408,173 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v
def _get_text_embeddings_and_masks(
self,
cond_list: list[ConditioningField],
context: InvocationContext,
device: torch.device,
dtype: torch.dtype,
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
"""Get the text embeddings and masks from the input conditioning fields."""
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = []
for cond in cond_list:
cond_data = context.conditioning.load(cond.conditioning_name)
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
mask = cond.mask
if mask is not None:
mask = context.tensors.load(mask.tensor_name)
text_embeddings_masks.append(mask)
return text_embeddings, text_embeddings_masks
def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
Returns:
torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
"""
if mask is None:
return torch.ones((1, 1, target_height, target_width), dtype=dtype)
mask = to_standard_float_mask(mask, out_dtype=dtype)
tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
# Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
resized_mask = tf(mask)
return resized_mask
def _concat_regional_text_embeddings(
self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int,
latent_width: int,
dtype: torch.dtype,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
if masks is None:
masks = [None] * len(text_conditionings)
assert len(text_conditionings) == len(masks)
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
all_masks_are_none = all(mask is None for mask in masks)
text_embedding = []
pooled_embedding = None
add_time_ids = None
cur_text_embedding_len = 0
processed_masks = []
embedding_ranges = []
for prompt_idx, text_embedding_info in enumerate(text_conditionings):
mask = masks[prompt_idx]
if is_sdxl:
# We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
# prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
# global prompt information. In an ideal case, there should be exactly one global prompt without a
# mask, but we don't enforce this.
# HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
# fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
# them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
# pretty major breaking change to a popular node, so for now we use this hack.
if pooled_embedding is None or mask is None:
pooled_embedding = text_embedding_info.pooled_embeds
if add_time_ids is None or mask is None:
add_time_ids = text_embedding_info.add_time_ids
text_embedding.append(text_embedding_info.embeds)
if not all_masks_are_none:
embedding_ranges.append(
Range(
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
)
)
processed_masks.append(
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
)
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
text_embedding = torch.cat(text_embedding, dim=1)
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
regions = None
if not all_masks_are_none:
regions = TextConditioningRegions(
masks=torch.cat(processed_masks, dim=1),
ranges=embedding_ranges,
)
if is_sdxl:
return SDXLConditioningInfo(
embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids
), regions
return BasicConditioningInfo(embeds=text_embedding), regions
def get_conditioning_data(
self,
context: InvocationContext,
scheduler: Scheduler,
unet: UNet2DConditionModel,
seed: int,
) -> ConditioningData:
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
latent_height: int,
latent_width: int,
) -> TextConditioningData:
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
cond_list = self.positive_conditioning
if not isinstance(cond_list, list):
cond_list = [cond_list]
uncond_list = self.negative_conditioning
if not isinstance(uncond_list, list):
uncond_list = [uncond_list]
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
text_embeddings=c,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
)
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
scheduler,
# for ddim scheduler
eta=0.0, # ddim_eta
# for ancestral and sde schedulers
# flip all bits to have noise different from initial
generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
if isinstance(self.cfg_scale, list):
assert (
len(self.cfg_scale) == self.steps
), "cfg_scale (list) must have the same length as the number of steps"
conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding,
uncond_regions=uncond_regions,
cond_regions=cond_regions,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
)
return conditioning_data
@ -497,8 +680,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
self,
context: InvocationContext,
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
conditioning_data: ConditioningData,
exit_stack: ExitStack,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
) -> Optional[list[IPAdapterData]]:
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
to the `conditioning_data` (in-place).
@ -514,7 +699,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
return None
ip_adapter_data_list = []
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(single_ip_adapter.ip_adapter_model)
@ -537,16 +721,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
single_ipa_images, image_encoder_model
)
conditioning_data.ip_adapter_conditioning.append(
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
)
mask = single_ip_adapter.mask
if mask is not None:
mask = context.tensors.load(mask.tensor_name)
mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
ip_adapter_data_list.append(
IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=single_ip_adapter.weight,
target_blocks=single_ip_adapter.target_blocks,
begin_step_percent=single_ip_adapter.begin_step_percent,
end_step_percent=single_ip_adapter.end_step_percent,
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
mask=mask,
)
)
@ -636,6 +824,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
steps: int,
denoising_start: float,
denoising_end: float,
seed: int,
) -> Tuple[int, List[int], int]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
@ -664,7 +853,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order
return num_inference_steps, timesteps, init_timestep
scheduler_step_kwargs = {}
scheduler_step_signature = inspect.signature(scheduler.step)
if "generator" in scheduler_step_signature.parameters:
# At some point, someone decided that schedulers that accept a generator should use the original seed with
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
# reproducibility.
scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor
@ -758,7 +955,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
)
controlnet_data = self.prep_control_data(
context=context,
@ -772,16 +973,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
ip_adapter_data = self.prep_ip_adapter_data(
context=context,
ip_adapter=self.ip_adapter,
conditioning_data=conditioning_data,
exit_stack=exit_stack,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
result_latents = pipeline.latents_from_embeddings(
@ -794,6 +998,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
masked_latents=masked_latents,
gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=controlnet_data,
ip_adapter_data=ip_adapter_data,
@ -803,12 +1008,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
@invocation(
@ -872,9 +1075,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.disable_tiling()
# clear memory as vae decode can request a lot
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
with torch.inference_mode():
# copied from diffusers pipeline
@ -886,9 +1087,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
image_dto = context.images.save(image=image)
@ -927,9 +1126,7 @@ class ResizeLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
# TODO:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
@ -940,9 +1137,8 @@ class ResizeLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()
if device == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@ -969,8 +1165,7 @@ class ScaleLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
# TODO:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
# resizing
resized_latents = torch.nn.functional.interpolate(
@ -982,9 +1177,7 @@ class ScaleLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()
if device == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@ -1116,8 +1309,7 @@ class BlendLatentsInvocation(BaseInvocation):
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")
# TODO:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
@ -1170,9 +1362,8 @@ class BlendLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
torch.cuda.empty_cache()
if device == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents)
@ -1263,7 +1454,7 @@ class IdealSizeInvocation(BaseInvocation):
return tuple((x - x % multiple_of) for x in args)
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
unet_config = context.models.get_config(**self.unet.unet.model_dump())
unet_config = context.models.get_config(self.unet.unet.key)
aspect = self.width / self.height
dimension: float = 512
if unet_config.base == BaseModelType.StableDiffusion2:
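
The regional-prompting helpers above reduce each prompt's mask to the latent resolution: a missing mask becomes an all-ones tensor, and a present mask is converted to float and resized with nearest-neighbor interpolation. A standalone sketch of that resize step (float32 on CPU for illustration; the invocation itself uses the UNet dtype):

import torch
import torchvision


def resize_mask_to_latents(mask: torch.Tensor, latent_height: int, latent_width: int) -> torch.Tensor:
    # mask: (1, image_height, image_width), boolean or float
    mask = mask.to(torch.float32)
    tf = torchvision.transforms.Resize(
        (latent_height, latent_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
    )
    return tf(mask.unsqueeze(0))  # -> (1, 1, latent_height, latent_width)


rect = torch.zeros((1, 512, 512), dtype=torch.bool)
rect[:, 128:384, 128:384] = True
print(resize_mask_to_latents(rect, 64, 64).shape)  # torch.Size([1, 1, 64, 64])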

View File

@ -0,0 +1,120 @@
import numpy as np
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
from invokeai.app.invocations.primitives import MaskOutput
@invocation(
"rectangle_mask",
title="Create Rectangle Mask",
tags=["conditioning"],
category="conditioning",
version="1.0.1",
)
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
"""Create a rectangular mask."""
width: int = InputField(description="The width of the entire mask.")
height: int = InputField(description="The height of the entire mask.")
x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")
rectangle_width: int = InputField(description="The width of the rectangular masked region.")
rectangle_height: int = InputField(description="The height of the rectangular masked region.")
def invoke(self, context: InvocationContext) -> MaskOutput:
mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
mask[:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width] = (
True
)
mask_tensor_name = context.tensors.save(mask)
return MaskOutput(
mask=TensorField(tensor_name=mask_tensor_name),
width=self.width,
height=self.height,
)
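A minimal stand-alone check of the rectangle-mask construction above; the dimensions are arbitrary example values and the invocation framework is not involved:
import torch

height, width = 8, 8
y_top, x_left = 2, 3
rect_height, rect_width = 4, 2

mask = torch.zeros((1, height, width), dtype=torch.bool)
mask[:, y_top : y_top + rect_height, x_left : x_left + rect_width] = True

# Only the rectangular region is masked.
assert mask.sum().item() == rect_height * rect_width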
@invocation(
"alpha_mask_to_tensor",
title="Alpha Mask to Tensor",
tags=["conditioning"],
category="conditioning",
version="1.0.0",
classification=Classification.Beta,
)
class AlphaMaskToTensorInvocation(BaseInvocation):
"""Convert a mask image to a tensor. Opaque regions are 1 and transparent regions are 0."""
image: ImageField = InputField(description="The mask image to convert.")
invert: bool = InputField(default=False, description="Whether to invert the mask.")
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.images.get_pil(self.image.image_name)
mask = torch.zeros((1, image.height, image.width), dtype=torch.bool)
if self.invert:
mask[0] = torch.tensor(np.array(image)[:, :, 3] == 0, dtype=torch.bool)
else:
mask[0] = torch.tensor(np.array(image)[:, :, 3] > 0, dtype=torch.bool)
return MaskOutput(
mask=TensorField(tensor_name=context.tensors.save(mask)),
height=mask.shape[1],
width=mask.shape[2],
)
@invocation(
"invert_tensor_mask",
title="Invert Tensor Mask",
tags=["conditioning"],
category="conditioning",
version="1.0.0",
classification=Classification.Beta,
)
class InvertTensorMaskInvocation(BaseInvocation):
"""Inverts a tensor mask."""
mask: TensorField = InputField(description="The tensor mask to convert.")
def invoke(self, context: InvocationContext) -> MaskOutput:
mask = context.tensors.load(self.mask.tensor_name)
inverted = ~mask
return MaskOutput(
mask=TensorField(tensor_name=context.tensors.save(inverted)),
height=inverted.shape[1],
width=inverted.shape[2],
)
@invocation(
"image_mask_to_tensor",
title="Image Mask to Tensor",
tags=["conditioning"],
category="conditioning",
version="1.0.0",
)
class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
"""Convert a mask image to a tensor. Converts the image to grayscale and uses thresholding at the specified value."""
image: ImageField = InputField(description="The mask image to convert.")
cutoff: int = InputField(ge=0, le=255, description="Cutoff (<)", default=128)
invert: bool = InputField(default=False, description="Whether to invert the mask.")
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.images.get_pil(self.image.image_name, mode="L")
mask = torch.zeros((1, image.height, image.width), dtype=torch.bool)
if self.invert:
mask[0] = torch.tensor(np.array(image)[:, :] >= self.cutoff, dtype=torch.bool)
else:
mask[0] = torch.tensor(np.array(image)[:, :] < self.cutoff, dtype=torch.bool)
return MaskOutput(
mask=TensorField(tensor_name=context.tensors.save(mask)),
height=mask.shape[1],
width=mask.shape[2],
)
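For reference, the cutoff logic used by ImageMaskToTensorInvocation can be reproduced outside the invocation framework like this (the image and cutoff below are made-up example values):
import numpy as np
import torch
from PIL import Image

image = Image.new("L", (4, 4), color=200)  # uniform grey test image
cutoff = 128

np_img = np.array(image)
mask = torch.tensor(np_img < cutoff, dtype=torch.bool).unsqueeze(0)  # shape (1, H, W)

# Every pixel is >= the cutoff, so nothing is masked.
assert not mask.any()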

View File

@ -2,16 +2,7 @@ from typing import Any, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import (
CONTROLNET_MODE_VALUES,
CONTROLNET_RESIZE_VALUES,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@ -22,6 +13,7 @@ from invokeai.app.invocations.fields import (
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES
from ...version import __version__
@ -43,6 +35,8 @@ class IPAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
method: Literal["full", "style", "composition"] = Field(description="The method used to apply the IP-Adapter weights")
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")

View File

@ -9,7 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Laten
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.util.devices import TorchDevice
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -46,7 +46,7 @@ def get_noise(
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
dtype=TorchDevice.choose_torch_dtype(device=device),
device=noise_device_type,
generator=generator,
).to("cpu")
@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation):
@field_validator("seed", mode="before")
def modulo_seed(cls, v):
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
"""Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
return v % (SEED_MAX + 1)
def invoke(self, context: InvocationContext) -> NoiseOutput:
noise = get_noise(
width=self.width,
height=self.height,
device=choose_torch_device(),
device=TorchDevice.choose_torch_device(),
seed=self.seed,
use_cpu=self.use_cpu,
)

View File

@ -15,6 +15,7 @@ from invokeai.app.invocations.fields import (
InputField,
LatentsField,
OutputField,
TensorField,
UIComponent,
)
from invokeai.app.services.images.images_common import ImageDTO
@ -405,9 +406,19 @@ class ColorInvocation(BaseInvocation):
# endregion
# region Conditioning
@invocation_output("mask_output")
class MaskOutput(BaseInvocationOutput):
"""A torch mask tensor."""
mask: TensorField = OutputField(description="The mask.")
width: int = OutputField(description="The width of the mask in pixels.")
height: int = OutputField(description="The height of the mask in pixels.")
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""

View File

@ -8,11 +8,11 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
class T2IAdapterField(BaseModel):

View File

@ -4,7 +4,6 @@ from typing import Literal
import cv2
import numpy as np
import torch
from PIL import Image
from pydantic import ConfigDict
@ -14,7 +13,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
@ -35,9 +34,6 @@ ESRGAN_MODEL_URLS: dict[str, str] = {
"RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
}
if choose_torch_device() == torch.device("mps"):
from torch import mps
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
@ -120,9 +116,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
upscaled_image = upscaler.upscale(cv2_image)
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
image_dto = context.images.save(image=pil_image)

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import locale
import os
import re
import shutil
@ -26,12 +27,12 @@ DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.0"
CONFIG_SCHEMA_VERSION = "4.0.1"
def get_default_ram_cache_size() -> float:
@ -104,7 +105,7 @@ class InvokeAIAppConfig(BaseSettings):
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
@ -317,11 +318,10 @@ class InvokeAIAppConfig(BaseSettings):
@staticmethod
def find_root() -> Path:
"""Choose the runtime root directory when not specified on command line or init file."""
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ["INVOKEAI_ROOT"])
elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]):
root = (venv.parent).resolve()
elif venv := os.environ.get("VIRTUAL_ENV", None):
root = Path(venv).parent.resolve()
else:
root = Path("~/invokeai").expanduser().resolve()
return root
@ -370,6 +370,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
@ -392,6 +395,28 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
return config
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.0 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
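To illustrate the v4.0.0 migration added above, this is what the key rewriting does to a made-up config dictionary (only the dict handling is shown; validation through DefaultInvokeAIAppConfig is omitted):
old = {"schema_version": "4.0.0", "precision": "autocast", "ram": 10.0}

migrated: dict = {}
for k, v in old.items():
    if k == "precision" and v == "autocast":
        migrated["precision"] = "auto"  # autocast was removed in v4.0.1
    else:
        migrated[k] = v
    if k == "schema_version":
        migrated[k] = "4.0.1"  # CONFIG_SCHEMA_VERSION at this point

assert migrated == {"schema_version": "4.0.1", "precision": "auto", "ram": 10.0}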
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
@ -402,7 +427,7 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
"""
assert config_path.suffix == ".yaml"
with open(config_path) as file:
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict)
@ -418,17 +443,21 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config
else:
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
if loaded_config_dict["schema_version"] == "4.0.0":
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
@lru_cache(maxsize=1)

View File

@ -318,10 +318,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
in_progress_path.rename(job.download_path)
def _validate_filename(self, directory: str, filename: str) -> bool:
pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260 # hardcoded for windows
pc_path_max = (
os.pathconf(directory, "PC_PATH_MAX") if hasattr(os, "pathconf") else 32767
) # hardcoded for windows with long names enabled
pc_name_max = get_pc_name_max(directory)
pc_path_max = get_pc_path_max(directory)
if "/" in filename:
return False
if filename.startswith(".."):
@ -419,6 +417,26 @@ class DownloadQueueService(DownloadQueueServiceBase):
self._logger.warning(excp)
def get_pc_name_max(directory: str) -> int:
if hasattr(os, "pathconf"):
try:
return os.pathconf(directory, "PC_NAME_MAX")
except OSError:
# macOS with external drives may raise OSError here
pass
return 260 # hardcoded for windows
def get_pc_path_max(directory: str) -> int:
if hasattr(os, "pathconf"):
try:
return os.pathconf(directory, "PC_PATH_MAX")
except OSError:
# some platforms may not have this value
pass
return 32767 # hardcoded for windows with long names enabled
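A quick illustration of the two fallback helpers above; the returned limits are platform-dependent, and the fallbacks are the hard-coded Windows values:
import tempfile

# Assumes get_pc_name_max and get_pc_path_max (defined above) are in scope.
with tempfile.TemporaryDirectory() as directory:
    name_max = get_pc_name_max(directory)  # e.g. 255 on Linux, 260 fallback on Windows
    path_max = get_pc_path_max(directory)  # e.g. 4096 on Linux, 32767 fallback on Windows
    print(name_max, path_max)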
# Example on_progress event handler to display a TQDM status bar
# Activate with:
# download_service.download(DownloadJob('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().update))

View File

@ -1,8 +1,8 @@
"""Model installation class."""
import locale
import os
import re
import signal
import threading
import time
from hashlib import sha256
@ -12,6 +12,7 @@ from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Union
import torch
import yaml
from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl
@ -41,7 +42,8 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.devices import choose_precision, choose_torch_device
from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice
from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
@ -110,17 +112,6 @@ class ModelInstallService(ModelInstallServiceBase):
def start(self, invoker: Optional[Invoker] = None) -> None:
"""Start the installer thread."""
# Yes, this is weird. When the installer thread is running, the
# thread masks the ^C signal. When we receive a
# sigINT, we stop the thread, reset sigINT, and send a new
# sigINT to the parent process.
def sigint_handler(signum, frame):
self.stop()
signal.signal(signal.SIGINT, signal.SIG_DFL)
signal.raise_signal(signal.SIGINT)
signal.signal(signal.SIGINT, sigint_handler)
with self._lock:
if self._running:
raise Exception("Attempt to start the installer service twice")
@ -130,7 +121,8 @@ class ModelInstallService(ModelInstallServiceBase):
# In normal use, we do not want to scan the models directory - it should never have orphaned models.
# We should only do the scan when the flag is set (which should only be set when testing).
if self.app_config.scan_models_on_startup:
self._register_orphaned_models()
with catch_sigint():
self._register_orphaned_models()
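The diff replaces the hand-rolled SIGINT handler with the new catch_sigint() context manager from invokeai.backend.util.catch_sigint, whose implementation is not shown in this changeset. A rough sketch of what such a context manager might look like (an assumption, not the actual code):
import signal
from contextlib import contextmanager


@contextmanager
def catch_sigint_sketch():
    """Temporarily ignore SIGINT, restoring the previous handler afterwards."""
    previous = signal.getsignal(signal.SIGINT)
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        yield
    finally:
        signal.signal(signal.SIGINT, previous)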
# Check all models' paths and confirm they exist. A model could be missing if it was installed on a volume
# that isn't currently mounted. In this case, we don't want to delete the model from the database, but we do
@ -323,7 +315,8 @@ class ModelInstallService(ModelInstallServiceBase):
legacy_models_yaml_path = Path(self._app_config.root_path, legacy_models_yaml_path)
if legacy_models_yaml_path.exists():
legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())
with open(legacy_models_yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
legacy_models_yaml = yaml.safe_load(file)
yaml_metadata = legacy_models_yaml.pop("__metadata__")
yaml_version = yaml_metadata.get("version")
@ -564,7 +557,7 @@ class ModelInstallService(ModelInstallServiceBase):
# The model is not in the models directory - we don't need to move it.
return model
new_path = (models_dir / model.base.value / model.type.value / model.name).with_suffix(old_path.suffix)
new_path = models_dir / model.base.value / model.type.value / old_path.name
if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
return model
@ -632,11 +625,10 @@ class ModelInstallService(ModelInstallServiceBase):
self._next_job_id += 1
return id
@staticmethod
def _guess_variant() -> Optional[ModelRepoVariant]:
def _guess_variant(self) -> Optional[ModelRepoVariant]:
"""Guess the best HuggingFace variant type to download."""
precision = choose_precision(choose_torch_device())
return ModelRepoVariant.FP16 if precision == "float16" else None
precision = TorchDevice.choose_torch_dtype()
return ModelRepoVariant.FP16 if precision == torch.float16 else None
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
return ModelInstallJob(
@ -752,6 +744,8 @@ class ModelInstallService(ModelInstallServiceBase):
self._download_cache[download_job.source] = install_job # matches a download job to an install job
install_job.download_parts.add(download_job)
# only start the jobs once install_job.download_parts is fully populated
for download_job in install_job.download_parts:
self._download_queue.submit_download_job(
download_job,
on_start=self._download_started_callback,
@ -760,6 +754,7 @@ class ModelInstallService(ModelInstallServiceBase):
on_error=self._download_error_callback,
on_cancelled=self._download_cancelled_callback,
)
return install_job
def _stat_size(self, path: Path) -> int:

View File

@ -1,12 +1,14 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
@ -67,7 +69,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: torch.device = choose_torch_device(),
execution_device: Optional[torch.device] = None,
) -> Self:
"""
Construct the model manager service instance.
@ -80,8 +82,9 @@ class ModelManagerService(ModelManagerServiceBase):
ram_cache = ModelCache(
max_cache_size=app_config.ram,
max_vram_cache_size=app_config.vram,
lazy_offloading=app_config.lazy_offload,
logger=logger,
execution_device=execution_device,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(

View File

@ -1,6 +1,6 @@
import shutil
import tempfile
import typing
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional, TypeVar
@ -17,12 +17,6 @@ if TYPE_CHECKING:
T = TypeVar("T")
@dataclass
class DeleteAllResult:
deleted_count: int
freed_space_bytes: float
class ObjectSerializerDisk(ObjectSerializerBase[T]):
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
@ -35,6 +29,12 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
self._ephemeral = ephemeral
self._base_output_dir = output_dir
self._base_output_dir.mkdir(parents=True, exist_ok=True)
if self._ephemeral:
# Remove dangling tempdirs that might have been left over from an earlier unplanned shutdown.
for temp_dir in filter(Path.is_dir, self._base_output_dir.glob("tmp*")):
shutil.rmtree(temp_dir)
# Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows
self._tempdir = (
tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None

View File

@ -86,6 +86,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now()
elif event_name == "batch_enqueued":
self._poll_now()
elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [
"completed",
"failed",
"canceled",
]:
self._poll_now()
def resume(self) -> SessionProcessorStatus:
if not self._resume_event.is_set():

View File

@ -245,6 +245,18 @@ class ImagesInterface(InvocationContextInterface):
"""
return self._services.images.get_dto(image_name)
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
"""Gets the internal path to an image or thumbnail.
Args:
image_name: The name of the image to get the path of.
thumbnail: Get the path of the thumbnail instead of the full image
Returns:
The local path of the image or thumbnail.
"""
return self._services.images.get_path(image_name, thumbnail)
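A hypothetical call site for the new get_path() helper inside an invocation; the image name and wrapper function are made up for illustration:
from pathlib import Path


def log_image_paths(context, image_name: str) -> tuple[Path, Path]:
    """Return the full-resolution and thumbnail paths for an image (illustrative only)."""
    full_path = context.images.get_path(image_name)
    thumbnail_path = context.images.get_path(image_name, thumbnail=True)
    return full_path, thumbnail_path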
class TensorsInterface(InvocationContextInterface):
def save(self, tensor: Tensor) -> str:

View File

@ -1,13 +1,21 @@
from typing import Union
from typing import Any, Literal, Union
import cv2
import numpy as np
import torch
from controlnet_aux.util import HWC3
from diffusers.utils import PIL_INTERPOLATION
from einops import rearrange
from PIL import Image
from invokeai.backend.image_util.util import nms, normalize_image_channel_count
CONTROLNET_RESIZE_VALUES = Literal[
"just_resize",
"crop_resize",
"fill_resize",
"just_resize_simple",
]
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
###################################################################
# Copy of scripts/lvminthin.py from Mikubill/sd-webui-controlnet
###################################################################
@ -68,17 +76,6 @@ def lvmin_thin(x, prunings=True):
return y
def nake_nms(x):
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
y = np.zeros_like(x)
for f in [f1, f2, f3, f4]:
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
return y
################################################################################
# copied from Mikubill/sd-webui-controlnet external_code.py and modified for InvokeAI
################################################################################
@ -134,98 +131,122 @@ def pixel_perfect_resolution(
return int(np.round(estimation))
def clone_contiguous(x: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
"""Get a memory-contiguous clone of the given numpy array, as a safety measure and to improve computation efficiency."""
return np.ascontiguousarray(x).copy()
def np_img_to_torch(np_img: np.ndarray[Any, Any], device: torch.device) -> torch.Tensor:
"""Convert a numpy image to a PyTorch tensor. The image is normalized to 0-1, rearranged to BCHW format and sent to
the specified device."""
torch_img = torch.from_numpy(np_img)
normalized = torch_img.float() / 255.0
bchw = rearrange(normalized, "h w c -> 1 c h w")
on_device = bchw.to(device)
return on_device.clone()
def heuristic_resize(np_img: np.ndarray[Any, Any], size: tuple[int, int]) -> np.ndarray[Any, Any]:
"""Resizes an image using a heuristic to choose the best resizing strategy.
- If the image appears to be an edge map, special handling will be applied to ensure the edges are not distorted.
- Single-pixel edge maps use NMS and thinning to keep the edges as single-pixel lines.
- Low-color-count images are resized with nearest-neighbor to preserve color information (e.g. segmentation maps).
- The alpha channel is handled separately to ensure it is resized correctly.
Args:
np_img (np.ndarray): The input image.
size (tuple[int, int]): The target size for the image.
Returns:
np.ndarray: The resized image.
Adapted from https://github.com/Mikubill/sd-webui-controlnet.
"""
# Return early if the image is already at the requested size
if np_img.shape[0] == size[1] and np_img.shape[1] == size[0]:
return np_img
# If the image has an alpha channel, separate it for special handling later.
inpaint_mask = None
if np_img.ndim == 3 and np_img.shape[2] == 4:
inpaint_mask = np_img[:, :, 3]
np_img = np_img[:, :, 0:3]
new_size_is_smaller = (size[0] * size[1]) < (np_img.shape[0] * np_img.shape[1])
new_size_is_bigger = (size[0] * size[1]) > (np_img.shape[0] * np_img.shape[1])
unique_color_count = np.unique(np_img.reshape(-1, np_img.shape[2]), axis=0).shape[0]
is_one_pixel_edge = False
is_binary = False
if unique_color_count == 2:
# If the image has only two colors, it is likely binary. Check if the image has one-pixel edges.
is_binary = np.min(np_img) < 16 and np.max(np_img) > 240
if is_binary:
eroded = cv2.erode(np_img, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
dilated = cv2.dilate(eroded, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
one_pixel_edge_count = np.where(dilated < np_img)[0].shape[0]
all_edge_count = np.where(np_img > 127)[0].shape[0]
is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count
if 2 < unique_color_count < 200:
# With a low color count, we assume this is a map where exact colors are important. Nearest-neighbor preserves
# the colors as needed.
interpolation = cv2.INTER_NEAREST
elif new_size_is_smaller:
# This works best for downscaling
interpolation = cv2.INTER_AREA
else:
# Fall back for other cases
interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
# This may be further transformed depending on the binary nature of the image.
resized = cv2.resize(np_img, size, interpolation=interpolation)
if inpaint_mask is not None:
# Resize the inpaint mask to match the resized image using the same interpolation method.
inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
# If the image is binary, we will perform some additional processing to ensure the edges are preserved.
if is_binary:
resized = np.mean(resized.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
if is_one_pixel_edge:
# Use NMS and thinning to keep the edges as single-pixel lines.
resized = nms(resized)
_, resized = cv2.threshold(resized, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
resized = lvmin_thin(resized, prunings=new_size_is_bigger)
else:
_, resized = cv2.threshold(resized, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
resized = np.stack([resized] * 3, axis=2)
# Restore the alpha channel if it was present.
if inpaint_mask is not None:
inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
resized = np.concatenate([resized, inpaint_mask], axis=2)
return resized
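A hedged usage sketch for heuristic_resize; the input is a synthetic two-color map, and the function defined above is assumed to be in scope:
import numpy as np

np_img = np.zeros((64, 64, 3), dtype=np.uint8)
np_img[16:48, 16:48] = 255  # a simple binary map

resized = heuristic_resize(np_img, (128, 128))  # target size is (width, height)
assert resized.shape[:2] == (128, 128)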
###########################################################################
# Copied from detectmap_proc method in scripts/detectmap_proc.py in Mikubill/sd-webui-controlnet
# modified for InvokeAI
###########################################################################
# def detectmap_proc(detected_map, module, resize_mode, h, w):
def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device: torch.device = torch.device("cpu")):
# if 'inpaint' in module:
# np_img = np_img.astype(np.float32)
# else:
# np_img = HWC3(np_img)
np_img = HWC3(np_img)
def np_img_resize(
np_img: np.ndarray,
resize_mode: CONTROLNET_RESIZE_VALUES,
h: int,
w: int,
device: torch.device = torch.device("cpu"),
) -> tuple[torch.Tensor, np.ndarray[Any, Any]]:
np_img = normalize_image_channel_count(np_img)
def safe_numpy(x):
# A very safe method to make sure that Apple/Mac works
y = x
# below is very boring but do not change these. If you change these Apple or Mac may fail.
y = y.copy()
y = np.ascontiguousarray(y)
y = y.copy()
return y
def get_pytorch_control(x):
# A very safe method to make sure that Apple/Mac works
y = x
# below is very boring but do not change these. If you change these Apple or Mac may fail.
y = torch.from_numpy(y)
y = y.float() / 255.0
y = rearrange(y, "h w c -> 1 c h w")
y = y.clone()
# y = y.to(devices.get_device_for("controlnet"))
y = y.to(device)
y = y.clone()
return y
def high_quality_resize(x: np.ndarray, size):
# Written by lvmin
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
inpaint_mask = None
if x.ndim == 3 and x.shape[2] == 4:
inpaint_mask = x[:, :, 3]
x = x[:, :, 0:3]
new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
unique_color_count = np.unique(x.reshape(-1, x.shape[2]), axis=0).shape[0]
is_one_pixel_edge = False
is_binary = False
if unique_color_count == 2:
is_binary = np.min(x) < 16 and np.max(x) > 240
if is_binary:
xc = x
xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
one_pixel_edge_count = np.where(xc < x)[0].shape[0]
all_edge_count = np.where(x > 127)[0].shape[0]
is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count
if 2 < unique_color_count < 200:
interpolation = cv2.INTER_NEAREST
elif new_size_is_smaller:
interpolation = cv2.INTER_AREA
else:
interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
y = cv2.resize(x, size, interpolation=interpolation)
if inpaint_mask is not None:
inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
if is_binary:
y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
if is_one_pixel_edge:
y = nake_nms(y)
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
y = lvmin_thin(y, prunings=new_size_is_bigger)
else:
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
y = np.stack([y] * 3, axis=2)
if inpaint_mask is not None:
inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
y = np.concatenate([y, inpaint_mask], axis=2)
return y
# if resize_mode == external_code.ResizeMode.RESIZE:
if resize_mode == "just_resize": # RESIZE
np_img = high_quality_resize(np_img, (w, h))
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
np_img = heuristic_resize(np_img, (w, h))
np_img = clone_contiguous(np_img)
return np_img_to_torch(np_img, device), np_img
old_h, old_w, _ = np_img.shape
old_w = float(old_w)
@ -236,7 +257,6 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
def safeint(x: Union[int, float]) -> int:
return int(np.round(x))
# if resize_mode == external_code.ResizeMode.OUTER_FIT:
if resize_mode == "fill_resize": # OUTER_FIT
k = min(k0, k1)
borders = np.concatenate([np_img[0, :, :], np_img[-1, :, :], np_img[:, 0, :], np_img[:, -1, :]], axis=0)
@ -245,23 +265,23 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
# Inpaint hijack
high_quality_border_color[3] = 255
high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
np_img = heuristic_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape
pad_h = max(0, (h - new_h) // 2)
pad_w = max(0, (w - new_w) // 2)
high_quality_background[pad_h : pad_h + new_h, pad_w : pad_w + new_w] = np_img
np_img = high_quality_background
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
np_img = clone_contiguous(np_img)
return np_img_to_torch(np_img, device), np_img
else: # resize_mode == "crop_resize" (INNER_FIT)
k = max(k0, k1)
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
np_img = heuristic_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape
pad_h = max(0, (new_h - h) // 2)
pad_w = max(0, (new_w - w) // 2)
np_img = np_img[pad_h : pad_h + h, pad_w : pad_w + w]
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
np_img = clone_contiguous(np_img)
return np_img_to_torch(np_img, device), np_img
def prepare_control_image(
@ -269,12 +289,12 @@ def prepare_control_image(
width: int,
height: int,
num_channels: int = 3,
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced",
resize_mode="just_resize_simple",
):
device: str = "cuda",
dtype: torch.dtype = torch.float16,
control_mode: CONTROLNET_MODE_VALUES = "balanced",
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
do_classifier_free_guidance: bool = True,
) -> torch.Tensor:
"""Pre-process images for ControlNets or T2I-Adapters.
Args:
@ -292,26 +312,15 @@ def prepare_control_image(
resize_mode (str, optional): Defaults to "just_resize_simple".
Raises:
NotImplementedError: If resize_mode == "crop_resize_simple".
NotImplementedError: If resize_mode == "fill_resize_simple".
ValueError: If `resize_mode` is not recognized.
ValueError: If `num_channels` is out of range.
Returns:
torch.Tensor: The pre-processed input tensor.
"""
if (
resize_mode == "just_resize_simple"
or resize_mode == "crop_resize_simple"
or resize_mode == "fill_resize_simple"
):
if resize_mode == "just_resize_simple":
image = image.convert("RGB")
if resize_mode == "just_resize_simple":
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
elif resize_mode == "crop_resize_simple":
raise NotImplementedError(f"prepare_control_image is not implemented for resize_mode='{resize_mode}'.")
elif resize_mode == "fill_resize_simple":
raise NotImplementedError(f"prepare_control_image is not implemented for resize_mode='{resize_mode}'.")
image = image.resize((width, height), resample=Image.LANCZOS)
nimage = np.array(image)
nimage = nimage[None, :]
nimage = np.concatenate([nimage], axis=0)
@ -328,8 +337,7 @@ def prepare_control_image(
resize_mode=resize_mode,
h=height,
w=width,
# device=torch.device('cpu')
device=device,
device=torch.device(device),
)
else:
raise ValueError(f"Unsupported resize_mode: '{resize_mode}'.")

View File

@ -2,7 +2,7 @@
Initialization file for invokeai.backend.image_util methods.
"""
from .patchmatch import PatchMatch # noqa: F401
from .infill_methods.patchmatch import PatchMatch # noqa: F401
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
from .seamless import configure_model_padding # noqa: F401
from .util import InitImageResizer, make_grid # noqa: F401

View File

@ -13,7 +13,7 @@ from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
config = get_config()
@ -56,7 +56,7 @@ class DepthAnythingDetector:
def __init__(self) -> None:
self.model = None
self.model_size: Union[Literal["large", "base", "small"], None] = None
self.device = choose_torch_device()
self.device = TorchDevice.choose_torch_device()
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
@ -81,7 +81,7 @@ class DepthAnythingDetector:
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
self.model.eval()
self.model.to(choose_torch_device())
self.model.to(self.device)
return self.model
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
@ -94,7 +94,7 @@ class DepthAnythingDetector:
image_height, image_width = np_image.shape[:2]
np_image = transform({"image": np_image})["image"]
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
with torch.no_grad():
depth = self.model(tensor_image)

View File

@ -7,7 +7,7 @@ import onnxruntime as ort
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from .onnxdet import inference_detector
from .onnxpose import inference_pose
@ -28,9 +28,9 @@ config = get_config()
class Wholebody:
def __init__(self):
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)

View File

@ -8,7 +8,7 @@ from huggingface_hub import hf_hub_download
from PIL import Image
from invokeai.backend.image_util.util import (
non_maximum_suppression,
nms,
normalize_image_channel_count,
np_to_pil,
pil_to_np,
@ -134,7 +134,7 @@ class HEDProcessor:
detected_map = cv2.resize(detected_map, (width, height), interpolation=cv2.INTER_LINEAR)
if scribble:
detected_map = non_maximum_suppression(detected_map, 127, 3.0)
detected_map = nms(detected_map, 127, 3.0)
detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
detected_map[detected_map > 4] = 255
detected_map[detected_map < 255] = 0

View File

@ -7,7 +7,8 @@ from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.util.devices import choose_torch_device
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice
def norm_img(np_img):
@ -28,8 +29,16 @@ def load_jit_model(url_or_path, device):
class LaMA:
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
model_location = get_config().models_path / "core/misc/lama/lama.pt"
if not model_location.exists():
download_with_progress_bar(
name="LaMa Inpainting Model",
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
dest_path=model_location,
)
model = load_jit_model(model_location, device)
image = np.asarray(input_image.convert("RGB"))

View File

@ -0,0 +1,60 @@
from typing import Tuple
import numpy as np
from PIL import Image
def infill_mosaic(
image: Image.Image,
tile_shape: Tuple[int, int] = (64, 64),
min_color: Tuple[int, int, int, int] = (0, 0, 0, 0),
max_color: Tuple[int, int, int, int] = (255, 255, 255, 0),
) -> Image.Image:
"""
image: Image.Image - The PIL image to infill (must have an alpha channel)
tile_shape: Tuple[int, int] - Tile width & tile height
min_color: Tuple[int, int, int, int] - RGBA values for the lowest color to clip to (0-255)
max_color: Tuple[int, int, int, int] - RGBA values for the highest color to clip to (0-255)
"""
np_image = np.array(image) # Convert image to np array
alpha = np_image[:, :, 3] # Get the mask from the alpha channel of the image
non_transparent_pixels = np_image[alpha != 0, :3] # List of non-transparent pixels
# Create color tiles to paste in the empty areas of the image
tile_width, tile_height = tile_shape
# Clip the range of colors in the image to a particular spectrum only
r_min, g_min, b_min, _ = min_color
r_max, g_max, b_max, _ = max_color
non_transparent_pixels[:, 0] = np.clip(non_transparent_pixels[:, 0], r_min, r_max)
non_transparent_pixels[:, 1] = np.clip(non_transparent_pixels[:, 1], g_min, g_max)
non_transparent_pixels[:, 2] = np.clip(non_transparent_pixels[:, 2], b_min, b_max)
tiles = []
for _ in range(256):
color = non_transparent_pixels[np.random.randint(len(non_transparent_pixels))]
tile = np.zeros((tile_height, tile_width, 3), dtype=np.uint8)
tile[:, :] = color
tiles.append(tile)
# Fill the transparent area with tiles
filled_image = np.zeros((image.height, image.width, 3), dtype=np.uint8)
for x in range(image.width):
for y in range(image.height):
tile = tiles[np.random.randint(len(tiles))]
try:
filled_image[
y - (y % tile_height) : y - (y % tile_height) + tile_height,
x - (x % tile_width) : x - (x % tile_width) + tile_width,
] = tile
except ValueError:
# Need to handle edge cases - literally
pass
filled_image = Image.fromarray(filled_image) # Convert the filled tiles image to PIL
image = Image.composite(
image, filled_image, image.split()[-1]
) # Composite the original image on top of the filled tiles
return image
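An illustrative call, using a synthetic half-transparent RGBA image; infill_mosaic as defined above is assumed to be in scope:
import numpy as np
from PIL import Image

rgba = np.zeros((128, 128, 4), dtype=np.uint8)
rgba[:, :64] = [200, 120, 40, 255]  # left half: opaque orange, supplies the colors
rgba[:, 64:] = [0, 0, 0, 0]         # right half: transparent, to be infilled

image = Image.fromarray(rgba, mode="RGBA")
result = infill_mosaic(image, tile_shape=(16, 16))
print(result.size, result.mode)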

View File

@ -0,0 +1,67 @@
"""
This module defines a singleton object, "patchmatch" that
wraps the actual patchmatch object. It respects the global
"try_patchmatch" attribute, so that patchmatch loading can
be suppressed or deferred
"""
import numpy as np
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
class PatchMatch:
"""
Thin class wrapper around the patchmatch function.
"""
patch_match = None
tried_load: bool = False
def __init__(self):
super().__init__()
@classmethod
def _load_patch_match(cls):
if cls.tried_load:
return
if get_config().patchmatch:
from patchmatch import patch_match as pm
if pm.patchmatch_available:
logger.info("Patchmatch initialized")
cls.patch_match = pm
else:
logger.info("Patchmatch not loaded (nonfatal)")
else:
logger.info("Patchmatch loading disabled")
cls.tried_load = True
@classmethod
def patchmatch_available(cls) -> bool:
cls._load_patch_match()
if not cls.patch_match:
return False
return cls.patch_match.patchmatch_available
@classmethod
def inpaint(cls, image: Image.Image) -> Image.Image:
if cls.patch_match is None or not cls.patchmatch_available():
return image
np_image = np.array(image)
mask = 255 - np_image[:, :, 3]
infilled = cls.patch_match.inpaint(np_image[:, :, :3], mask, patch_size=3)
return Image.fromarray(infilled, mode="RGB")
def infill_patchmatch(image: Image.Image) -> Image.Image:
IS_PATCHMATCH_AVAILABLE = PatchMatch.patchmatch_available()
if not IS_PATCHMATCH_AVAILABLE:
logger.warning("PatchMatch is not available on this system")
return image
return PatchMatch.inpaint(image)
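A small usage sketch; because patchmatch is an optional dependency, the call degrades to returning the input unchanged when the library is unavailable:
from PIL import Image

img = Image.new("RGBA", (64, 64), (255, 0, 0, 255))
img.putpixel((10, 10), (0, 0, 0, 0))  # a transparent pixel to infill

result = infill_patchmatch(img)  # returns the input unchanged if patchmatch is missing
print(result.size)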

(Ten binary image files added: new test images ranging from 2.2 KiB to 60 KiB; diff-viewer previews are omitted here.)
View File

@ -0,0 +1,95 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"Smoke test for the tile infill\"\"\"\n",
"\n",
"from pathlib import Path\n",
"from typing import Optional\n",
"from PIL import Image\n",
"from invokeai.backend.image_util.infill_methods.tile import infill_tile\n",
"\n",
"images: list[tuple[str, Image.Image]] = []\n",
"\n",
"for i in sorted(Path(\"./test_images/\").glob(\"*.webp\")):\n",
" images.append((i.name, Image.open(i)))\n",
" images.append((i.name, Image.open(i).transpose(Image.FLIP_LEFT_RIGHT)))\n",
" images.append((i.name, Image.open(i).transpose(Image.FLIP_TOP_BOTTOM)))\n",
" images.append((i.name, Image.open(i).resize((512, 512))))\n",
" images.append((i.name, Image.open(i).resize((1234, 461))))\n",
"\n",
"outputs: list[tuple[str, Image.Image, Image.Image, Optional[Image.Image]]] = []\n",
"\n",
"for name, image in images:\n",
" try:\n",
" output = infill_tile(image, seed=0, tile_size=32)\n",
" outputs.append((name, image, output.infilled, output.tile_image))\n",
" except ValueError as e:\n",
" print(f\"Skipping image {name}: {e}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Display the images in jupyter notebook\n",
"import matplotlib.pyplot as plt\n",
"from PIL import ImageOps\n",
"\n",
"fig, axes = plt.subplots(len(outputs), 3, figsize=(10, 3 * len(outputs)))\n",
"plt.subplots_adjust(hspace=0)\n",
"\n",
"for i, (name, original, infilled, tile_image) in enumerate(outputs):\n",
" # Add a border to each image, helps to see the edges\n",
" size = original.size\n",
" original = ImageOps.expand(original, border=5, fill=\"red\")\n",
" filled = ImageOps.expand(infilled, border=5, fill=\"red\")\n",
" if tile_image:\n",
" tile_image = ImageOps.expand(tile_image, border=5, fill=\"red\")\n",
"\n",
" axes[i, 0].imshow(original)\n",
" axes[i, 0].axis(\"off\")\n",
" axes[i, 0].set_title(f\"Original ({name} - {size})\")\n",
"\n",
" if tile_image:\n",
" axes[i, 1].imshow(tile_image)\n",
" axes[i, 1].axis(\"off\")\n",
" axes[i, 1].set_title(\"Tile Image\")\n",
" else:\n",
" axes[i, 1].axis(\"off\")\n",
" axes[i, 1].set_title(\"NO TILES GENERATED (NO TRANSPARENCY)\")\n",
"\n",
" axes[i, 2].imshow(filled)\n",
" axes[i, 2].axis(\"off\")\n",
" axes[i, 2].set_title(\"Filled\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".invokeai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,122 @@
from dataclasses import dataclass
from typing import Optional
import numpy as np
from PIL import Image
def create_tile_pool(img_array: np.ndarray, tile_size: tuple[int, int]) -> list[np.ndarray]:
"""
Create a pool of tiles from non-transparent areas of the image by systematically walking through the image.
Args:
img_array: numpy array of the image.
tile_size: tuple (tile_width, tile_height) specifying the size of each tile.
Returns:
A list of numpy arrays, each representing a tile.
"""
tiles: list[np.ndarray] = []
rows, cols = img_array.shape[:2]
tile_width, tile_height = tile_size
for y in range(0, rows - tile_height + 1, tile_height):
for x in range(0, cols - tile_width + 1, tile_width):
tile = img_array[y : y + tile_height, x : x + tile_width]
# Check if the image has an alpha channel and the tile is completely opaque
if img_array.shape[2] == 4 and np.all(tile[:, :, 3] == 255):
tiles.append(tile)
elif img_array.shape[2] == 3: # If no alpha channel, append the tile
tiles.append(tile)
if not tiles:
raise ValueError(
"Not enough opaque pixels to generate any tiles. Use a smaller tile size or a different image."
)
return tiles
def create_filled_image(
img_array: np.ndarray, tile_pool: list[np.ndarray], tile_size: tuple[int, int], seed: int
) -> np.ndarray:
"""
Create an image of the same dimensions as the original, filled entirely with tiles from the pool.
Args:
img_array: numpy array of the original image.
tile_pool: A list of numpy arrays, each representing a tile.
tile_size: tuple (tile_width, tile_height) specifying the size of each tile.
seed: Seed for the random tile selection, making the fill reproducible.
Returns:
A numpy array representing the filled image.
"""
rows, cols, _ = img_array.shape
tile_width, tile_height = tile_size
# Prep an empty RGB image
filled_img_array = np.zeros((rows, cols, 3), dtype=img_array.dtype)
# Make the random tile selection reproducible
rng = np.random.default_rng(seed)
for y in range(0, rows, tile_height):
for x in range(0, cols, tile_width):
# Pick a random tile from the pool
tile = tile_pool[rng.integers(len(tile_pool))]
# Calculate the space available (may be less than tile size near the edges)
space_y = min(tile_height, rows - y)
space_x = min(tile_width, cols - x)
# Crop the tile if necessary to fit into the available space
cropped_tile = tile[:space_y, :space_x, :3]
# Fill the available space with the (possibly cropped) tile
filled_img_array[y : y + space_y, x : x + space_x, :3] = cropped_tile
return filled_img_array
@dataclass
class InfillTileOutput:
infilled: Image.Image
tile_image: Optional[Image.Image] = None
def infill_tile(image_to_infill: Image.Image, seed: int, tile_size: int) -> InfillTileOutput:
"""Infills an image with random tiles from the image itself.
If the image is not an RGBA image, it is returned untouched.
Args:
image_to_infill: The image to infill.
seed: Seed for the random tile selection, making the infill reproducible.
tile_size: The size (in pixels) of the square tiles to use for infilling.
Raises:
ValueError: If there are not enough opaque pixels to generate any tiles.
"""
if image_to_infill.mode != "RGBA":
return InfillTileOutput(infilled=image_to_infill)
# Internally, we want a tuple of (tile_width, tile_height). In the future, the tile size can be any rectangle.
_tile_size = (tile_size, tile_size)
np_image = np.array(image_to_infill, dtype=np.uint8)
# Create the pool of tiles that we will use to infill
tile_pool = create_tile_pool(np_image, _tile_size)
# Create an image from the tiles, same size as the original
tile_np_image = create_filled_image(np_image, tile_pool, _tile_size, seed)
# Paste the OG image over the tile image, effectively infilling the area
tile_image = Image.fromarray(tile_np_image, "RGB")
infilled = tile_image.copy()
infilled.paste(image_to_infill, (0, 0), image_to_infill.split()[-1])
# PIL's convert() returns a new image rather than modifying in place, so keep the result to get an RGBA output.
infilled = infilled.convert("RGBA")
return InfillTileOutput(infilled=infilled, tile_image=tile_image)
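A minimal smoke test mirroring the notebook above, with a synthetic half-opaque image; infill_tile is assumed to be in scope:
import numpy as np
from PIL import Image

rgba = np.zeros((96, 96, 4), dtype=np.uint8)
rgba[:48, :] = [40, 90, 200, 255]  # opaque top half supplies the tile pool
image = Image.fromarray(rgba, mode="RGBA")

output = infill_tile(image, seed=0, tile_size=16)
print(output.infilled.size, output.tile_image is not None)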

View File

@ -1,49 +0,0 @@
"""
This module defines a singleton object, "patchmatch" that
wraps the actual patchmatch object. It respects the global
"try_patchmatch" attribute, so that patchmatch loading can
be suppressed or deferred
"""
import numpy as np
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
class PatchMatch:
"""
Thin class wrapper around the patchmatch function.
"""
patch_match = None
tried_load: bool = False
def __init__(self):
super().__init__()
@classmethod
def _load_patch_match(self):
if self.tried_load:
return
if get_config().patchmatch:
from patchmatch import patch_match as pm
if pm.patchmatch_available:
logger.info("Patchmatch initialized")
else:
logger.info("Patchmatch not loaded (nonfatal)")
self.patch_match = pm
else:
logger.info("Patchmatch loading disabled")
self.tried_load = True
@classmethod
def patchmatch_available(self) -> bool:
self._load_patch_match()
return self.patch_match and self.patch_match.patchmatch_available
@classmethod
def inpaint(self, *args, **kwargs) -> np.ndarray:
if self.patchmatch_available():
return self.patch_match.inpaint(*args, **kwargs)

View File

@ -11,7 +11,7 @@ from cv2.typing import MatLike
from tqdm import tqdm
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
"""
Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
@ -65,7 +65,7 @@ class RealESRGAN:
self.pre_pad = pre_pad
self.mod_scale: Optional[int] = None
self.half = half
self.device = choose_torch_device()
self.device = TorchDevice.choose_torch_device()
loadnet = torch.load(model_path, map_location=torch.device("cpu"))

View File

@ -13,7 +13,7 @@ from transformers import AutoFeatureExtractor
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.silence_warnings import SilenceWarnings
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
@ -51,7 +51,7 @@ class SafetyChecker:
cls._load_safety_checker()
if cls.safety_checker is None or cls.feature_extractor is None:
return False
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
features = cls.feature_extractor([image], return_tensors="pt")
features.to(device)
cls.safety_checker.to(device)

View File

@ -1,4 +1,5 @@
from math import ceil, floor, sqrt
from typing import Optional
import cv2
import numpy as np
@ -143,20 +144,21 @@ def resize_image_to_resolution(input_image: np.ndarray, resolution: int) -> np.n
h = float(input_image.shape[0])
w = float(input_image.shape[1])
scaling_factor = float(resolution) / min(h, w)
h *= scaling_factor
w *= scaling_factor
h = int(np.round(h / 64.0)) * 64
w = int(np.round(w / 64.0)) * 64
h = int(h * scaling_factor)
w = int(w * scaling_factor)
if scaling_factor > 1:
return cv2.resize(input_image, (w, h), interpolation=cv2.INTER_LANCZOS4)
else:
return cv2.resize(input_image, (w, h), interpolation=cv2.INTER_AREA)
def non_maximum_suppression(image: np.ndarray, threshold: int, sigma: float):
def nms(np_img: np.ndarray, threshold: Optional[int] = None, sigma: Optional[float] = None) -> np.ndarray:
"""
Apply non-maximum suppression to an image.
If both threshold and sigma are provided, the image will be blurred before the suppression and thresholded afterwards,
resulting in a binary output image.
This function is adapted from https://github.com/lllyasviel/ControlNet.
Args:
@ -166,23 +168,36 @@ def non_maximum_suppression(image: np.ndarray, threshold: int, sigma: float):
Returns:
The image after non-maximum suppression.
Raises:
ValueError: If only one of threshold and sigma is provided.
"""
image = cv2.GaussianBlur(image.astype(np.float32), (0, 0), sigma)
# Raise a value error if only one of threshold and sigma is provided
if (threshold is None) != (sigma is None):
raise ValueError("Both threshold and sigma must be provided if one is provided.")
if sigma is not None and threshold is not None:
# Blurring the image can help to thin out features
np_img = cv2.GaussianBlur(np_img.astype(np.float32), (0, 0), sigma)
filter_1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
filter_2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
filter_3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
filter_4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
y = np.zeros_like(image)
nms_img = np.zeros_like(np_img)
for f in [filter_1, filter_2, filter_3, filter_4]:
np.putmask(y, cv2.dilate(image, kernel=f) == image, image)
np.putmask(nms_img, cv2.dilate(np_img, kernel=f) == np_img, np_img)
z = np.zeros_like(y, dtype=np.uint8)
z[y > threshold] = 255
return z
if sigma is not None and threshold is not None:
# We blurred - now threshold to get a binary image
thresholded = np.zeros_like(nms_img, dtype=np.uint8)
thresholded[nms_img > threshold] = 255
return thresholded
return nms_img
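A short usage sketch for the reworked helper above; the module path is an assumption (the hunk shows only the function), everything else follows the new signature.

import cv2
import numpy as np

from invokeai.backend.image_util.util import nms  # import path assumed

gray = np.random.randint(0, 256, (256, 256), dtype=np.uint8)
edges = cv2.Canny(gray, 100, 200)

thin = nms(edges)                           # suppression only: grayscale output
binary = nms(edges, threshold=10, sigma=3)  # blur, suppress, then threshold to a 0/255 image
# nms(edges, threshold=10) raises ValueError: threshold and sigma must be given together.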
def safe_step(x: np.ndarray, step: int = 2) -> np.ndarray:

View File

@ -1,182 +0,0 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
# tencent-ailab comment:
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
# loading.
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
def __init__(self):
DiffusersAttnProcessor2_0.__init__(self)
nn.Module.__init__(self)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
ip_adapter_image_prompt_embeds parameter.
"""
return DiffusersAttnProcessor2_0.__call__(
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
)
class IPAttnProcessor2_0(torch.nn.Module):
r"""
    Attention processor for IP-Adapter for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
assert len(weights) == len(scales)
self._weights = weights
self._scales = scales
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Apply IP-Adapter attention.
Args:
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
"""
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
for ipa_embed, ipa_weights, scale in zip(
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match.
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
ip_hidden_states = ipa_embed
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
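The deleted processor above leans on the standard multi-head reshape around F.scaled_dot_product_attention; the isolated sketch below (plain torch, no diffusers, made-up sizes) shows just that round trip.

import torch
import torch.nn.functional as F

batch, seq, heads, head_dim = 2, 77, 8, 64

q = torch.randn(batch, seq, heads * head_dim)
k = torch.randn(batch, seq, heads * head_dim)
v = torch.randn(batch, seq, heads * head_dim)

# (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
q, k, v = (t.view(batch, -1, heads, head_dim).transpose(1, 2) for t in (q, k, v))

out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)
assert out.shape == (batch, seq, heads * head_dim)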

View File

@ -1,8 +1,11 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
from typing import Optional, Union
import pathlib
from typing import List, Optional, TypedDict, Union
import safetensors
import safetensors.torch
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@ -13,10 +16,17 @@ from ..raw_model import RawModel
from .resampler import Resampler
class IPAdapterStateDict(TypedDict):
ip_adapter: dict[str, torch.Tensor]
image_proj: dict[str, torch.Tensor]
class ImageProjModel(torch.nn.Module):
"""Image Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
def __init__(
self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024, clip_extra_context_tokens: int = 4
):
super().__init__()
self.cross_attention_dim = cross_attention_dim
@ -25,7 +35,7 @@ class ImageProjModel(torch.nn.Module):
self.norm = torch.nn.LayerNorm(cross_attention_dim)
@classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
def from_state_dict(cls, state_dict: dict[str, torch.Tensor], clip_extra_context_tokens: int = 4):
"""Initialize an ImageProjModel from a state_dict.
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@ -45,7 +55,7 @@ class ImageProjModel(torch.nn.Module):
model.load_state_dict(state_dict)
return model
def forward(self, image_embeds):
def forward(self, image_embeds: torch.Tensor):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
@ -57,7 +67,7 @@ class ImageProjModel(torch.nn.Module):
class MLPProjModel(torch.nn.Module):
"""SD model with image prompt"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
def __init__(self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024):
super().__init__()
self.proj = torch.nn.Sequential(
@ -68,7 +78,7 @@ class MLPProjModel(torch.nn.Module):
)
@classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor]):
def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
"""Initialize an MLPProjModel from a state_dict.
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@ -87,7 +97,7 @@ class MLPProjModel(torch.nn.Module):
model.load_state_dict(state_dict)
return model
def forward(self, image_embeds):
def forward(self, image_embeds: torch.Tensor):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
@ -97,7 +107,7 @@ class IPAdapter(RawModel):
def __init__(
self,
state_dict: dict[str, torch.Tensor],
state_dict: IPAdapterStateDict,
device: torch.device,
dtype: torch.dtype = torch.float16,
num_tokens: int = 4,
@ -129,24 +139,27 @@ class IPAdapter(RawModel):
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
def _init_image_proj_model(self, state_dict):
def _init_image_proj_model(
self, state_dict: dict[str, torch.Tensor]
) -> Union[ImageProjModel, Resampler, MLPProjModel]:
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
try:
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
except RuntimeError as e:
raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e
class IPAdapterPlus(IPAdapter):
"""IP-Adapter with fine-grained features"""
def _init_image_proj_model(self, state_dict):
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]) -> Union[Resampler, MLPProjModel]:
return Resampler.from_state_dict(
state_dict=state_dict,
depth=4,
@ -157,31 +170,32 @@ class IPAdapterPlus(IPAdapter):
).to(self.device, dtype=self.dtype)
@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=self.dtype)
clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
-2
]
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
try:
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
except RuntimeError as e:
raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e
class IPAdapterFull(IPAdapterPlus):
"""IP-Adapter Plus with full features."""
def _init_image_proj_model(self, state_dict: dict[torch.Tensor]):
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)
class IPAdapterPlusXL(IPAdapterPlus):
"""IP-Adapter Plus for SDXL."""
def _init_image_proj_model(self, state_dict):
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
return Resampler.from_state_dict(
state_dict=state_dict,
depth=4,
@ -192,24 +206,48 @@ class IPAdapterPlusXL(IPAdapterPlus):
).to(self.device, dtype=self.dtype)
def build_ip_adapter(
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus]:
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) -> IPAdapterStateDict:
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
if "proj.weight" in state_dict["image_proj"]: # IPAdapter (with ImageProjModel).
if ip_adapter_ckpt_path.suffix == ".safetensors":
model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
for key in model.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
else:
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
else:
ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path / "ip_adapter.bin"
state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
return state_dict
def build_ip_adapter(
ip_adapter_ckpt_path: pathlib.Path, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterFull]:
state_dict = load_ip_adapter_tensors(ip_adapter_ckpt_path, device.type)
# IPAdapter (with ImageProjModel)
if "proj.weight" in state_dict["image_proj"]:
return IPAdapter(state_dict, device=device, dtype=dtype)
elif "proj_in.weight" in state_dict["image_proj"]: # IPAdaterPlus or IPAdapterPlusXL (with Resampler).
# IPAdaterPlus or IPAdapterPlusXL (with Resampler)
elif "proj_in.weight" in state_dict["image_proj"]:
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
# SD1 IP-Adapter Plus
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
return IPAdapterPlus(state_dict, device=device, dtype=dtype) # SD1 IP-Adapter Plus
elif cross_attention_dim == 2048:
# SDXL IP-Adapter Plus
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype) # SDXL IP-Adapter Plus
else:
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
elif "proj.0.weight" in state_dict["image_proj"]: # IPAdapterFull (with MLPProjModel).
# IPAdapterFull (with MLPProjModel)
elif "proj.0.weight" in state_dict["image_proj"]:
return IPAdapterFull(state_dict, device=device, dtype=dtype)
# Unrecognized IP Adapter Architectures
else:
raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")

View File

@ -9,8 +9,8 @@ import torch.nn as nn
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
def FeedForward(dim: int, mult: int = 4):
inner_dim = dim * mult
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
@ -19,8 +19,8 @@ def FeedForward(dim, mult=4):
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
def reshape_tensor(x: torch.Tensor, heads: int):
bs, length, _ = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
@ -31,7 +31,7 @@ def reshape_tensor(x, heads):
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
@ -45,7 +45,7 @@ class PerceiverAttention(nn.Module):
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
def forward(self, x: torch.Tensor, latents: torch.Tensor):
"""
Args:
x (torch.Tensor): image features
@ -80,14 +80,14 @@ class PerceiverAttention(nn.Module):
class Resampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
dim: int = 1024,
depth: int = 8,
dim_head: int = 64,
heads: int = 16,
num_queries: int = 8,
embedding_dim: int = 768,
output_dim: int = 1024,
ff_mult: int = 4,
):
super().__init__()
@ -110,7 +110,15 @@ class Resampler(nn.Module):
)
@classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
def from_state_dict(
cls,
state_dict: dict[str, torch.Tensor],
depth: int = 8,
dim_head: int = 64,
heads: int = 16,
num_queries: int = 8,
ff_mult: int = 4,
):
"""A convenience function that initializes a Resampler from a state_dict.
Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
@ -145,7 +153,7 @@ class Resampler(nn.Module):
model.load_state_dict(state_dict)
return model
def forward(self, x):
def forward(self, x: torch.Tensor):
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)

View File

@ -1,53 +0,0 @@
from contextlib import contextmanager
from diffusers.models import UNet2DConditionModel
from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
class UNetPatcher:
"""A class that contains multiple IP-Adapters and can apply them to a UNet."""
def __init__(self, ip_adapters: list[IPAdapter]):
self._ip_adapters = ip_adapters
self._scales = [1.0] * len(self._ip_adapters)
def set_scale(self, idx: int, value: float):
self._scales[idx] = value
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
weights into them.
Note that the `unet` param is only used to determine attention block dimensions and naming.
"""
# Construct a dict of attention processors based on the UNet's architecture.
attn_procs = {}
for idx, name in enumerate(unet.attn_processors.keys()):
if name.endswith("attn1.processor"):
attn_procs[name] = AttnProcessor2_0()
else:
# Collect the weights from each IP Adapter for the idx'th attention processor.
attn_procs[name] = IPAttnProcessor2_0(
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
self._scales,
)
return attn_procs
@contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
"""A context manager that patches `unet` with IP-Adapter attention processors."""
attn_procs = self._prepare_attention_processors(unet)
orig_attn_processors = unet.attn_processors
try:
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
# passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
# of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
unet.set_attn_processor(attn_procs)
yield None
finally:
unet.set_attn_processor(orig_attn_processors)
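Both the deleted UNetPatcher above and its replacement (UNetAttentionPatcher, referenced later in this diff) rely on the same swap-and-restore idiom; a minimal generic sketch, not the Invoke implementation:

from contextlib import contextmanager

@contextmanager
def swapped_attn_processors(unet, new_procs: dict):
    """Temporarily install `new_procs` on `unet`, restoring the originals on exit."""
    original = unet.attn_processors  # diffusers returns a dict of the currently installed processors
    # set_attn_processor() may mutate the dict it is given, so pass a shallow copy
    # if the caller wants to reuse `new_procs` afterwards.
    unet.set_attn_processor(dict(new_procs))
    try:
        yield
    finally:
        unet.set_attn_processor(original)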

View File

@ -301,12 +301,12 @@ class MainConfigBase(ModelConfigBase):
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
variant: ModelVariantType = ModelVariantType.Normal
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
"""Model config for main checkpoint models."""
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
@ -323,10 +323,13 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
class IPAdapterBaseConfig(ModelConfigBase):
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
"""Model config for IP Adapter diffusers format models."""
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI]
@ -335,6 +338,16 @@ class IPAdapterConfig(ModelConfigBase):
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
"""Model config for IP Adapter checkpoint format models."""
format: Literal[ModelFormat.Checkpoint]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision."""
@ -390,7 +403,8 @@ AnyModelConfig = Annotated[
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
],

View File

@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoad
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from invokeai.backend.util.devices import TorchDevice
# TO DO: The loader is not thread safe!
@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
self._logger = logger
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._torch_dtype = torch_dtype(choose_torch_device(), app_config)
self._torch_dtype = TorchDevice.choose_torch_dtype()
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""

View File

@ -117,7 +117,7 @@ class ModelCacheBase(ABC, Generic[T]):
@property
@abstractmethod
def stats(self) -> CacheStats:
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
pass

View File

@ -30,15 +30,12 @@ import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
from .model_locker import ModelLocker
if choose_torch_device() == torch.device("mps"):
from torch import mps
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
@ -244,9 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
)
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device.
@ -269,12 +264,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
if torch.device(source_device).type == torch.device(target_device).type:
return
# may raise an exception here if insufficient GPU VRAM
self._check_free_vram(target_device, cache_entry.size)
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
cache_entry.model.to(target_device)
try:
cache_entry.model.to(target_device)
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self.logger.debug(
@ -329,11 +326,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
)
def make_room(self, model_size: int) -> None:
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size."""
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
bytes_needed = size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = self.cache_size()
@ -388,12 +385,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
# 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1
del self._cache_stack[pos]
del self._cached_models[model_key]
self._delete_cache_entry(cache_entry)
del cache_entry
else:
@ -415,18 +411,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
self.stats.cleared = models_cleared
gc.collect()
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
if target_device.type != "cuda":
return
vram_device = ( # mem_get_info() needs an indexed device
target_device if target_device.index is not None else torch.device(str(target_device), index=0)
)
free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
if needed_size > free_mem:
raise torch.cuda.OutOfMemoryError
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]
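The CUDA/MPS branches removed above are now funneled through TorchDevice.empty_cache(); purely as orientation, a rough standalone sketch of that behavior, assuming nothing about the real TorchDevice beyond the calls visible in this diff:

import torch

def empty_device_cache() -> None:
    """Release cached allocator memory on whichever accelerator backend is active.

    Illustrative only; the actual TorchDevice.empty_cache() may differ.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        from torch import mps
        mps.empty_cache()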

View File

@ -34,7 +34,6 @@ class ModelLocker(ModelLockerBase):
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
self._cache_entry.lock()
try:
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
@ -51,6 +50,7 @@ class ModelLocker(ModelLockerBase):
except Exception:
self._cache_entry.unlock()
raise
return self.model
def unlock(self) -> None:

View File

@ -7,19 +7,13 @@ from typing import Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
from invokeai.backend.raw_model import RawModel
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.Checkpoint)
class IPAdapterInvokeAILoader(ModelLoader):
"""Class to load IP Adapter diffusers models."""
@ -32,7 +26,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
raise ValueError("There are no submodels in an IP-Adapter model.")
model_path = Path(config.path)
model: RawModel = build_ip_adapter(
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
ip_adapter_ckpt_path=model_path,
device=torch.device("cpu"),
dtype=self._torch_dtype,
)

View File

@ -17,7 +17,7 @@ from diffusers.utils import logging as dlogging
from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from invokeai.backend.util.devices import TorchDevice
from . import (
AnyModelConfig,
@ -43,6 +43,7 @@ class ModelMerger(object):
Initialize a ModelMerger object with the model installer.
"""
self._installer = installer
self._dtype = TorchDevice.choose_torch_dtype()
def merge_diffusion_models(
self,
@ -68,7 +69,7 @@ class ModelMerger(object):
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
dtype = torch.float16 if variant == "fp16" else self._dtype
# Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
# until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
@ -151,7 +152,7 @@ class ModelMerger(object):
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
dtype = torch.float16 if variant == "fp16" else self._dtype
merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)
# register model and get its unique key

View File

@ -51,6 +51,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
ModelVariantType.Inpaint: "sd_xl_inpaint.yaml",
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
@ -230,9 +231,10 @@ class ModelProbe(object):
return ModelType.LoRA
elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
return ModelType.ControlNet
elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
return ModelType.IPAdapter
elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion
else:
# diffusers-ti
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
@ -323,7 +325,7 @@ class ModelProbe(object):
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
cls._scan_model(model_path.name, model_path)
model = torch.load(model_path)
model = torch.load(model_path, map_location="cpu")
assert isinstance(model, dict)
return model
else:
@ -527,8 +529,25 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
class IPAdapterCheckpointProbe(CheckpointProbeBase):
"""Class for probing IP Adapters"""
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
checkpoint = self.checkpoint
for key in checkpoint.keys():
if not key.startswith(("image_proj.", "ip_adapter.")):
continue
cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
return BaseModelType.StableDiffusion1
elif cross_attention_dim == 1024:
return BaseModelType.StableDiffusion2
elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
)
raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
@ -768,7 +787,7 @@ class T2IAdapterFolderProbe(FolderProbeBase):
)
############## register probe classes ######
# Register probe classes
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)

View File

@ -155,7 +155,7 @@ STARTER_MODELS: list[StarterModel] = [
StarterModel(
name="IP Adapter",
base=BaseModelType.StableDiffusion1,
source="InvokeAI/ip_adapter_sd15",
source="https://huggingface.co/InvokeAI/ip_adapter_sd15/resolve/main/ip-adapter_sd15.safetensors",
description="IP-Adapter for SD 1.5 models",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sd_image_encoder],
@ -163,7 +163,7 @@ STARTER_MODELS: list[StarterModel] = [
StarterModel(
name="IP Adapter Plus",
base=BaseModelType.StableDiffusion1,
source="InvokeAI/ip_adapter_plus_sd15",
source="https://huggingface.co/InvokeAI/ip_adapter_plus_sd15/resolve/main/ip-adapter-plus_sd15.safetensors",
description="Refined IP-Adapter for SD 1.5 models",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sd_image_encoder],
@ -171,7 +171,7 @@ STARTER_MODELS: list[StarterModel] = [
StarterModel(
name="IP Adapter Plus Face",
base=BaseModelType.StableDiffusion1,
source="InvokeAI/ip_adapter_plus_face_sd15",
source="https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15/resolve/main/ip-adapter-plus-face_sd15.safetensors",
description="Refined IP-Adapter for SD 1.5 models, adapted for faces",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sd_image_encoder],
@ -179,7 +179,7 @@ STARTER_MODELS: list[StarterModel] = [
StarterModel(
name="IP Adapter SDXL",
base=BaseModelType.StableDiffusionXL,
source="InvokeAI/ip_adapter_sdxl",
source="https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h/resolve/main/ip-adapter_sdxl_vit-h.safetensors",
description="IP-Adapter for SDXL models",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sdxl_image_encoder],

View File

@ -21,12 +21,11 @@ from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import normalize_device
from invokeai.backend.util.devices import TorchDevice
@dataclass
@ -149,16 +148,6 @@ class ControlNetData:
resize_mode: str = Field(default="just_resize")
@dataclass
class IPAdapterData:
ip_adapter_model: IPAdapter = Field(default=None)
# TODO: change to polymorphic so can do different weights per step (once implemented...)
weight: Union[float, List[float]] = Field(default=1.0)
# weight: float = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
@dataclass
class T2IAdapterData:
"""A structure containing the information required to apply conditioning from a single T2I-Adapter model."""
@ -266,7 +255,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
mem_free = psutil.virtual_memory().free
elif self.unet.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
else:
raise ValueError(f"unrecognized device {self.unet.device}")
# input tensor of [1, 4, h/8, w/8]
@ -295,7 +284,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self,
latents: torch.Tensor,
num_inference_steps: int,
conditioning_data: ConditioningData,
scheduler_step_kwargs: dict[str, Any],
conditioning_data: TextConditioningData,
*,
noise: Optional[torch.Tensor],
timesteps: torch.Tensor,
@ -308,7 +298,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
gradient_mask: Optional[bool] = False,
seed: Optional[int] = None,
seed: int,
) -> torch.Tensor:
if init_timestep.shape[0] == 0:
return latents
@ -326,20 +316,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents = self.scheduler.add_noise(latents, noise, batched_t)
if mask is not None:
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
if is_inpainting_model(self.unet):
if masked_latents is None:
raise Exception("Source image required for inpaint mask when inpaint model used!")
@ -348,6 +324,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self._unet_forward, mask, masked_latents
)
else:
# if no noise provided, noisify unmasked area based on seed
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
try:
@ -355,6 +340,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents,
timesteps,
conditioning_data,
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
@ -380,7 +366,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self,
latents: torch.Tensor,
timesteps,
conditioning_data: ConditioningData,
conditioning_data: TextConditioningData,
scheduler_step_kwargs: dict[str, Any],
*,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
@ -397,22 +384,22 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if timesteps.shape[0] == 0:
return latents
ip_adapter_unet_patcher = None
extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
use_ip_adapter = ip_adapter_data is not None
use_regional_prompting = (
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
)
unet_attention_patcher = None
self.use_ip_adapter = use_ip_adapter
attn_ctx = nullcontext()
if use_ip_adapter or use_regional_prompting:
ip_adapters: Optional[List[UNetIPAdapterData]] = (
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
if use_ip_adapter
else None
)
self.use_ip_adapter = False
elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
self.use_ip_adapter = True
else:
attn_ctx = nullcontext()
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
with attn_ctx:
if callback is not None:
@ -435,11 +422,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data,
step_index=i,
total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
ip_adapter_unet_patcher=ip_adapter_unet_patcher,
)
latents = step_output.prev_sample
predicted_original = getattr(step_output, "pred_original_sample", None)
@ -463,14 +450,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self,
t: torch.Tensor,
latents: torch.Tensor,
conditioning_data: ConditioningData,
conditioning_data: TextConditioningData,
step_index: int,
total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
@ -485,23 +472,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
# handle IP-Adapter
if self.use_ip_adapter and ip_adapter_data is not None: # somewhat redundant but logic is clearer
for i, single_ip_adapter_data in enumerate(ip_adapter_data):
first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
weight = (
single_ip_adapter_data.weight[step_index]
if isinstance(single_ip_adapter_data.weight, List)
else single_ip_adapter_data.weight
)
if step_index >= first_adapter_step and step_index <= last_adapter_step:
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
ip_adapter_unet_patcher.set_scale(i, weight)
else:
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
ip_adapter_unet_patcher.set_scale(i, 0.0)
# Handle ControlNet(s)
down_block_additional_residuals = None
mid_block_additional_residual = None
@ -550,6 +520,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=step_index,
total_step_count=total_step_count,
conditioning_data=conditioning_data,
ip_adapter_data=ip_adapter_data,
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter
@ -569,7 +540,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
)
# compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
for guidance in additional_guidance:

View File

@ -1,27 +1,17 @@
import dataclasses
import inspect
from dataclasses import dataclass, field
from typing import Any, List, Optional, Union
import math
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
from .cross_attention_control import Arguments
@dataclass
class ExtraConditioningInfo:
tokens_count_including_eos_bos: int
cross_attention_control_args: Optional[Arguments] = None
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
@dataclass
class BasicConditioningInfo:
"""SD 1/2 text conditioning information produced by Compel."""
embeds: torch.Tensor
extra_conditioning: Optional[ExtraConditioningInfo]
def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype)
@ -35,6 +25,8 @@ class ConditioningFieldData:
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
"""SDXL text conditioning information produced by Compel."""
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
@ -57,37 +49,75 @@ class IPAdapterConditioningInfo:
@dataclass
class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo
text_embeddings: BasicConditioningInfo
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
"""
guidance_scale: Union[float, List[float]]
""" for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
"""
guidance_rescale_multiplier: float = 0
scheduler_args: dict[str, Any] = field(default_factory=dict)
class IPAdapterData:
ip_adapter_model: IPAdapter
ip_adapter_conditioning: IPAdapterConditioningInfo
mask: torch.Tensor
target_blocks: List[str]
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
# Either a single weight applied to all steps, or a list of weights for each step.
weight: Union[float, List[float]] = 1.0
begin_step_percent: float = 0.0
end_step_percent: float = 1.0
@property
def dtype(self):
return self.text_embeddings.dtype
def scale_for_step(self, step_index: int, total_steps: int) -> float:
first_adapter_step = math.floor(self.begin_step_percent * total_steps)
last_adapter_step = math.ceil(self.end_step_percent * total_steps)
weight = self.weight[step_index] if isinstance(self.weight, List) else self.weight
if step_index >= first_adapter_step and step_index <= last_adapter_step:
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
return weight
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
return 0.0
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
scheduler_args = dict(self.scheduler_args)
step_method = inspect.signature(scheduler.step)
for name, value in kwargs.items():
try:
step_method.bind_partial(**{name: value})
except TypeError:
# FIXME: don't silently discard arguments
pass # debug("%s does not accept argument named %r", scheduler, name)
else:
scheduler_args[name] = value
return dataclasses.replace(self, scheduler_args=scheduler_args)
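The removed add_scheduler_args_if_applicable() above filters kwargs against scheduler.step()'s signature; the same stdlib technique in isolation, as a sketch of how a caller can build scheduler_step_kwargs:

import inspect
from typing import Any, Callable

def filter_kwargs_for(fn: Callable, **kwargs: Any) -> dict[str, Any]:
    """Keep only the kwargs that `fn` actually accepts (e.g. scheduler.step)."""
    sig = inspect.signature(fn)
    accepted: dict[str, Any] = {}
    for name, value in kwargs.items():
        try:
            sig.bind_partial(**{name: value})
        except TypeError:
            continue  # fn has no parameter with this name
        accepted[name] = value
    return accepted

# e.g. scheduler_step_kwargs = filter_kwargs_for(scheduler.step, generator=generator, eta=0.0)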
@dataclass
class Range:
start: int
end: int
class TextConditioningRegions:
def __init__(
self,
masks: torch.Tensor,
ranges: list[Range],
):
# A binary mask indicating the regions of the image that the prompt should be applied to.
# Shape: (1, num_prompts, height, width)
# Dtype: torch.bool
self.masks = masks
        # A list of ranges indicating the start and end indices of the embeddings that the corresponding mask applies to.
# ranges[i] contains the embedding range for the i'th prompt / mask.
self.ranges = ranges
assert self.masks.shape[1] == len(self.ranges)
class TextConditioningData:
def __init__(
self,
uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
uncond_regions: Optional[TextConditioningRegions],
cond_regions: Optional[TextConditioningRegions],
guidance_scale: Union[float, List[float]],
guidance_rescale_multiplier: float = 0,
):
self.uncond_text = uncond_text
self.cond_text = cond_text
self.uncond_regions = uncond_regions
self.cond_regions = cond_regions
# Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
# `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
        # Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages generating
        # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
self.guidance_scale = guidance_scale
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
self.guidance_rescale_multiplier = guidance_rescale_multiplier
def is_sdxl(self):
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
return isinstance(self.cond_text, SDXLConditioningInfo)
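To make the begin/end gating in scale_for_step concrete, the same arithmetic as a standalone sketch (mirroring, not importing, the method above):

import math
from typing import List, Union

def ip_adapter_scale(step_index: int, total_steps: int,
                     weight: Union[float, List[float]] = 1.0,
                     begin_step_percent: float = 0.0,
                     end_step_percent: float = 1.0) -> float:
    first = math.floor(begin_step_percent * total_steps)
    last = math.ceil(end_step_percent * total_steps)
    w = weight[step_index] if isinstance(weight, list) else weight
    return w if first <= step_index <= last else 0.0

# With begin=0.25, end=0.75 and 20 steps, the adapter is active for steps 5..15 only:
scales = [ip_adapter_scale(i, 20, weight=0.8, begin_step_percent=0.25, end_step_percent=0.75) for i in range(20)]
# -> 0.0 for steps 0-4, 0.8 for steps 5-15, 0.0 for steps 16-19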

View File

@ -1,218 +0,0 @@
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
import enum
from dataclasses import dataclass, field
from typing import Optional
import torch
from compel.cross_attention_control import Arguments
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from invokeai.backend.util.devices import torch_dtype
class CrossAttentionType(enum.Enum):
SELF = 1
TOKENS = 2
class CrossAttnControlContext:
def __init__(self, arguments: Arguments):
"""
:param arguments: Arguments for the cross-attention control process
"""
self.cross_attention_mask: Optional[torch.Tensor] = None
self.cross_attention_index_map: Optional[torch.Tensor] = None
self.arguments = arguments
def get_active_cross_attention_control_types_for_step(
self, percent_through: float = None
) -> list[CrossAttentionType]:
"""
Should cross-attention control be applied on the given step?
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
"""
if percent_through is None:
return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]
opts = self.arguments.edit_options
to_control = []
if opts["s_start"] <= percent_through < opts["s_end"]:
to_control.append(CrossAttentionType.SELF)
if opts["t_start"] <= percent_through < opts["t_end"]:
to_control.append(CrossAttentionType.TOKENS)
return to_control
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
:param model: The unet model to inject into.
:return: None
"""
# adapted from init_attention_edit
device = context.arguments.edited_conditioning.device
# urgh. should this be hardcoded?
max_length = 77
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
mask = torch.zeros(max_length, dtype=torch_dtype(device))
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
if b0 < max_length:
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
context.cross_attention_mask = mask.to(device)
context.cross_attention_index_map = indices.to(device)
old_attn_processors = unet.attn_processors
if torch.backends.mps.is_available():
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
unet.set_attn_processor(SwapCrossAttnProcessor())
else:
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next(
(p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size
)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
@dataclass
class SwapCrossAttnContext:
modified_text_embeddings: torch.Tensor
index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
mask: torch.Tensor # in the target space of the index_map
cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
return attn_type in self.cross_attention_types_to_do
@classmethod
def make_mask_and_index_map(
cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int
) -> tuple[torch.Tensor, torch.Tensor]:
# mask=1 means use original prompt attention, mask=0 means use modified prompt attention
mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in edit_opcodes:
if b0 < max_length:
if name == "equal":
# these tokens remain the same as in the original prompt
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
return mask, indices
class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
# TODO: dynamically pick slice size based on memory conditions
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
# kwargs
swap_cross_attn_context: SwapCrossAttnContext = None,
**kwargs,
):
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
# if cross-attention control is not in play, just call through to the base implementation.
if (
attention_type is CrossAttentionType.SELF
or swap_cross_attn_context is None
or not swap_cross_attn_context.wants_cross_attention_control(attention_type)
):
# print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
# else:
# print(f"SwapCrossAttnContext for {attention_type} active")
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(
attention_mask=attention_mask,
target_length=sequence_length,
batch_size=batch_size,
)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
original_text_embeddings = encoder_hidden_states
modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
original_text_key = attn.to_k(original_text_embeddings)
modified_text_key = attn.to_k(modified_text_embeddings)
original_value = attn.to_v(original_text_embeddings)
modified_value = attn.to_v(modified_text_embeddings)
original_text_key = attn.head_to_batch_dim(original_text_key)
modified_text_key = attn.head_to_batch_dim(modified_text_key)
original_value = attn.head_to_batch_dim(original_value)
modified_value = attn.head_to_batch_dim(modified_value)
# compute slices and prepare output tensor
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // attn.heads),
device=query.device,
dtype=query.dtype,
)
# do slices
for i in range(max(1, hidden_states.shape[0] // self.slice_size)):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
query_slice = query[start_idx:end_idx]
original_key_slice = original_text_key[start_idx:end_idx]
modified_key_slice = modified_text_key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
# because the prompt modifications may result in token sequences shifted forwards or backwards,
# the original attention probabilities must be remapped to account for token index changes in the
# modified prompt
remapped_original_attn_slice = torch.index_select(
original_attn_slice, -1, swap_cross_attn_context.index_map
)
# only some tokens taken from the original attention probabilities. this is controlled by the mask.
mask = swap_cross_attn_context.mask
inverse_mask = 1 - mask
attn_slice = remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
del remapped_original_attn_slice, modified_attn_slice
attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
# done
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
def __init__(self):
super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice

View File

@ -0,0 +1,214 @@
from dataclasses import dataclass
from typing import List, Optional, cast
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
@dataclass
class IPAdapterAttentionWeights:
ip_adapter_weights: IPAttentionProcessorWeights
skip: bool
class CustomAttnProcessor2_0(AttnProcessor2_0):
"""A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
This implementation is based on
https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204
Supported custom features:
- IP-Adapter
- Regional prompt attention
"""
def __init__(
self,
ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
):
"""Initialize a CustomAttnProcessor2_0.
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
layer-specific are passed to __init__().
Args:
ip_adapter_attention_weights: The IP-Adapter attention weights. ip_adapter_attention_weights[i] contains the
attention weights for the i'th IP-Adapter.
"""
super().__init__()
self._ip_adapter_attention_weights = ip_adapter_attention_weights
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
# For Regional Prompting:
regional_prompt_data: Optional[RegionalPromptData] = None,
percent_through: Optional[torch.Tensor] = None,
# For IP-Adapter:
regional_ip_data: Optional[RegionalIPData] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
"""Apply attention.
Args:
regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
apply regional prompt masking.
regional_ip_data: The IP-Adapter data for the current batch.
"""
# If true, we are doing cross-attention, if false we are doing self-attention.
is_cross_attention = encoder_hidden_states is not None
# Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End unmodified block from AttnProcessor2_0.
_, query_seq_len, _ = hidden_states.shape
# Handle regional prompt attention masks.
if regional_prompt_data is not None and is_cross_attention:
assert percent_through is not None
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
query_seq_len=query_seq_len, key_seq_len=sequence_length
)
if attention_mask is None:
attention_mask = prompt_region_attention_mask
else:
attention_mask = prompt_region_attention_mask + attention_mask
# Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End unmodified block from AttnProcessor2_0.
# Apply IP-Adapter conditioning.
if is_cross_attention:
if self._ip_adapter_attention_weights:
assert regional_ip_data is not None
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
assert (
len(regional_ip_data.image_prompt_embeds)
== len(self._ip_adapter_attention_weights)
== len(regional_ip_data.scales)
== ip_masks.shape[1]
)
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
ipa_scale = regional_ip_data.scales[ipa_index]
ip_mask = ip_masks[0, ipa_index, ...]
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match.
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
ip_hidden_states = ipa_embed
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
if not self._ip_adapter_attention_weights[ipa_index].skip:
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# Expected ip_key and ip_value shape:
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Expected ip_key and ip_value shape:
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
else:
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
assert regional_ip_data is None
# Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End of unmodified block from AttnProcessor2_0
# casting torch.Tensor to torch.FloatTensor to avoid type issues
return cast(torch.FloatTensor, hidden_states)
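The regional mask above is merged additively with any existing attention mask and passed to scaled_dot_product_attention. A standalone sketch with purely illustrative shapes (none of these sizes come from the diff):
import torch
import torch.nn.functional as F

# Batch 1, 2 heads, 16 query positions (a 4x4 latent), 8 prompt tokens, head_dim 64.
query = torch.randn(1, 2, 16, 64)
key = torch.randn(1, 2, 8, 64)
value = torch.randn(1, 2, 8, 64)

# A regional prompt mask contributes 0.0 for tokens that apply to a query position and a
# large negative score for tokens that do not, so the softmax effectively ignores them.
regional_mask = torch.zeros(1, 1, 16, 8)
regional_mask[..., 4:] = -10000.0  # hypothetical: tokens 4..7 suppressed everywhere

out = F.scaled_dot_product_attention(query, key, value, attn_mask=regional_mask)
print(out.shape)  # torch.Size([1, 2, 16, 64])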

View File

@ -0,0 +1,72 @@
import torch
class RegionalIPData:
"""A class to manage the data for regional IP-Adapter conditioning."""
def __init__(
self,
image_prompt_embeds: list[torch.Tensor],
scales: list[float],
masks: list[torch.Tensor],
dtype: torch.dtype,
device: torch.device,
max_downscale_factor: int = 8,
):
"""Initialize a `IPAdapterConditioningData` object."""
assert len(image_prompt_embeds) == len(scales) == len(masks)
# The image prompt embeddings.
# image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
# has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
self.image_prompt_embeds = image_prompt_embeds
# The scales for the IP-Adapter attention.
# scales[i] contains the attention scale for the i'th IP-Adapter.
self.scales = scales
# The IP-Adapter masks.
# self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of
# s. It has shape (batch_size, num_ip_images, query_seq_len, 1). The masks have values of 1.0 for included
# regions and 0.0 for excluded regions.
self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype)
def _prepare_masks(
self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype
) -> dict[int, torch.Tensor]:
"""Prepare the masks for the IP-Adapter attention."""
# Concatenate the masks so that they can be processed more efficiently.
mask_tensor = torch.cat(masks, dim=1)
mask_tensor = mask_tensor.to(device=device, dtype=dtype)
masks_by_seq_len: dict[int, torch.Tensor] = {}
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
downscale_factor = 1
while downscale_factor <= max_downscale_factor:
b, num_ip_adapters, h, w = mask_tensor.shape
# Assert that the batch size is 1, because I haven't thought through batch handling for this feature yet.
assert b == 1
# The IP-Adapters are applied in the cross-attention layers, where the query sequence length is the h * w of
# the spatial features.
query_seq_len = h * w
masks_by_seq_len[query_seq_len] = mask_tensor.view((b, num_ip_adapters, -1, 1))
downscale_factor *= 2
if downscale_factor <= max_downscale_factor:
# We use max pooling because we downscale to a pretty low resolution, so we don't want small mask
# regions to be lost entirely.
#
# ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
#
# TODO(ryand): In the future, we may want to experiment with other downsampling methods.
mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2, ceil_mode=True)
return masks_by_seq_len
def get_masks(self, query_seq_len: int) -> torch.Tensor:
"""Get the mask for the given query sequence length."""
return self._masks_by_seq_len[query_seq_len]
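A minimal construction sketch (all sizes hypothetical) showing how the masks are pre-pooled per query sequence length:
import torch

# One IP-Adapter, an 8x8 latent grid, 4 image-prompt tokens, embedding width 768.
embeds = [torch.randn(2, 1, 4, 768)]  # (batch_size, num_ip_images, seq_len, ip_embedding_len)
masks = [torch.ones(1, 1, 8, 8)]      # one (1, 1, h, w) spatial mask per adapter
regional_ip = RegionalIPData(
    image_prompt_embeds=embeds,
    scales=[1.0],
    masks=masks,
    dtype=torch.float32,
    device=torch.device("cpu"),
)
# Masks are stored for query lengths 64, 16, 4 and 1 (the 8x8 grid downscaled in 2x steps).
print(regional_ip.get_masks(query_seq_len=64).shape)  # torch.Size([1, 1, 64, 1])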

View File

@ -0,0 +1,105 @@
import torch
import torch.nn.functional as F
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningRegions,
)
class RegionalPromptData:
"""A class to manage the prompt data for regional conditioning."""
def __init__(
self,
regions: list[TextConditioningRegions],
device: torch.device,
dtype: torch.dtype,
max_downscale_factor: int = 8,
):
"""Initialize a `RegionalPromptData` object.
Args:
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
batch.
device (torch.device): The device to use for the attention masks.
dtype (torch.dtype): The data type to use for the attention masks.
max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
in steps of 2x.
"""
self._regions = regions
self._device = device
self._dtype = dtype
# self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
# sequence length of s.
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
regions, max_downscale_factor
)
self._negative_cross_attn_mask_score = -10000.0
def _prepare_spatial_masks(
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
) -> list[dict[int, torch.Tensor]]:
"""Prepare the spatial masks for all downscaling factors."""
# batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
# of s.
batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
for batch_sample_regions in regions:
batch_sample_masks_by_seq_len.append({})
batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
downscale_factor = 1
while downscale_factor <= max_downscale_factor:
b, _num_prompts, h, w = batch_sample_masks.shape
assert b == 1
query_seq_len = h * w
batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks
downscale_factor *= 2
if downscale_factor <= max_downscale_factor:
# We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
# regions to be lost entirely.
#
# ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
#
# TODO(ryand): In the future, we may want to experiment with other downsampling methods (e.g.
# nearest interpolation), and could potentially use a weighted mask rather than a binary mask.
batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2, ceil_mode=True)
return batch_sample_masks_by_seq_len
def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
"""Get the cross-attention mask for the given query sequence length.
Args:
query_seq_len: The length of the flattened spatial features at the current downscaling level.
key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the cross-attention
layers). This is most likely equal to the max embedding range end, but we pass it explicitly to be sure.
Returns:
torch.Tensor: The cross-attention score mask.
shape: (batch_size, query_seq_len, key_seq_len).
dtype: float
"""
batch_size = len(self._spatial_masks_by_seq_len)
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
# Create an empty attention mask with the correct shape.
attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device)
for batch_idx in range(batch_size):
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
batch_sample_regions = self._regions[batch_idx]
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
_, num_prompts, _, _ = batch_sample_spatial_masks.shape
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone()
batch_sample_query_mask = batch_sample_query_scores > 0.5
batch_sample_query_scores[batch_sample_query_mask] = 0.0
batch_sample_query_scores[~batch_sample_query_mask] = self._negative_cross_attn_mask_score
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
return attn_mask
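A usage sketch for the cross-attention mask, assuming the Range and TextConditioningRegions dataclasses from conditioning_data (they are used the same way further down in this diff); sizes are hypothetical:
import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range, TextConditioningRegions

# Two prompt regions covering the left and right halves of a 2x2 latent grid, owning token
# ranges [0, 3) and [3, 6) of the concatenated prompt embedding.
left = torch.tensor([[[[1.0, 0.0], [1.0, 0.0]]]])  # (1, num_prompts=1, h=2, w=2)
regions = TextConditioningRegions(
    masks=torch.cat([left, 1.0 - left], dim=1),  # (1, 2, 2, 2)
    ranges=[Range(start=0, end=3), Range(start=3, end=6)],
)
rpd = RegionalPromptData(regions=[regions], device=torch.device("cpu"), dtype=torch.float32)
attn_mask = rpd.get_cross_attn_mask(query_seq_len=4, key_seq_len=6)
# attn_mask[0, q, k] is 0.0 where query position q lies inside the region that owns token k,
# and -10000.0 otherwise.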

View File

@ -1,26 +1,20 @@
from __future__ import annotations
import math
from contextlib import contextmanager
from typing import Any, Callable, Optional, Union
import torch
from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData,
ExtraConditioningInfo,
SDXLConditioningInfo,
)
from .cross_attention_control import (
CrossAttentionType,
CrossAttnControlContext,
SwapCrossAttnContext,
setup_cross_attention_control_attention_processors,
IPAdapterData,
Range,
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
ModelForwardCallback: TypeAlias = Union[
# x, t, conditioning, Optional[cross-attention kwargs]
@ -58,31 +52,8 @@ class InvokeAIDiffuserComponent:
self.conditioning = None
self.model = model
self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
self.sequential_guidance = config.sequential_guidance
@contextmanager
def custom_attention_context(
self,
unet: UNet2DConditionModel,
extra_conditioning_info: Optional[ExtraConditioningInfo],
):
old_attn_processors = unet.attn_processors
try:
self.cross_attention_control_context = CrossAttnControlContext(
arguments=extra_conditioning_info.cross_attention_control_args,
)
setup_cross_attention_control_attention_processors(
unet,
self.cross_attention_control_context,
)
yield None
finally:
self.cross_attention_control_context = None
unet.set_attn_processor(old_attn_processors)
def do_controlnet_step(
self,
control_data,
@ -90,7 +61,7 @@ class InvokeAIDiffuserComponent:
timestep: torch.Tensor,
step_index: int,
total_step_count: int,
conditioning_data,
conditioning_data: TextConditioningData,
):
down_block_res_samples, mid_block_res_sample = None, None
@ -123,28 +94,28 @@ class InvokeAIDiffuserComponent:
added_cond_kwargs = None
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids,
"text_embeds": conditioning_data.cond_text.pooled_embeds,
"time_ids": conditioning_data.cond_text.add_time_ids,
}
encoder_hidden_states = conditioning_data.text_embeddings.embeds
encoder_hidden_states = conditioning_data.cond_text.embeds
encoder_attention_mask = None
else:
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds,
conditioning_data.uncond_text.pooled_embeds,
conditioning_data.cond_text.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.unconditioned_embeddings.add_time_ids,
conditioning_data.text_embeddings.add_time_ids,
conditioning_data.uncond_text.add_time_ids,
conditioning_data.cond_text.add_time_ids,
],
dim=0,
),
@ -153,8 +124,8 @@ class InvokeAIDiffuserComponent:
encoder_hidden_states,
encoder_attention_mask,
) = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.text_embeddings.embeds,
conditioning_data.uncond_text.embeds,
conditioning_data.cond_text.embeds,
)
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
@ -198,24 +169,15 @@ class InvokeAIDiffuserComponent:
self,
sample: torch.Tensor,
timestep: torch.Tensor,
conditioning_data: ConditioningData,
conditioning_data: TextConditioningData,
ip_adapter_data: Optional[list[IPAdapterData]],
step_index: int,
total_step_count: int,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
):
cross_attention_control_types_to_do = []
if self.cross_attention_control_context is not None:
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = (
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
)
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
if wants_cross_attention_control or self.sequential_guidance:
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
# control is currently only supported in sequential mode.
if self.sequential_guidance:
(
unconditioned_next_x,
conditioned_next_x,
@ -223,7 +185,9 @@ class InvokeAIDiffuserComponent:
x=sample,
sigma=timestep,
conditioning_data=conditioning_data,
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
ip_adapter_data=ip_adapter_data,
step_index=step_index,
total_step_count=total_step_count,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@ -236,6 +200,9 @@ class InvokeAIDiffuserComponent:
x=sample,
sigma=timestep,
conditioning_data=conditioning_data,
ip_adapter_data=ip_adapter_data,
step_index=step_index,
total_step_count=total_step_count,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@ -294,53 +261,84 @@ class InvokeAIDiffuserComponent:
def _apply_standard_conditioning(
self,
x,
sigma,
conditioning_data: ConditioningData,
x: torch.Tensor,
sigma: torch.Tensor,
conditioning_data: TextConditioningData,
ip_adapter_data: Optional[list[IPAdapterData]],
step_index: int,
total_step_count: int,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
):
) -> tuple[torch.Tensor, torch.Tensor]:
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
the cost of higher memory usage.
"""
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {}
if ip_adapter_data is not None:
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.stack(
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
image_prompt_embeds = [
torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
for ipa_conditioning in ip_adapter_conditioning
]
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
ip_masks = [ipa.mask for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
)
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
added_cond_kwargs = None
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds,
conditioning_data.uncond_text.pooled_embeds,
conditioning_data.cond_text.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.unconditioned_embeddings.add_time_ids,
conditioning_data.text_embeddings.add_time_ids,
conditioning_data.uncond_text.add_time_ids,
conditioning_data.cond_text.add_time_ids,
],
dim=0,
),
}
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
regions = []
for c, r in [
(conditioning_data.uncond_text, conditioning_data.uncond_regions),
(conditioning_data.cond_text, conditioning_data.cond_regions),
]:
if r is None:
# Create a dummy mask and range for text conditioning that doesn't have region masks.
_, _, h, w = x.shape
r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=x.dtype),
ranges=[Range(start=0, end=c.embeds.shape[1])],
)
regions.append(r)
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=regions, device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = step_index / total_step_count
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
)
both_results = self.model_forward_callback(
x_twice,
@ -360,8 +358,10 @@ class InvokeAIDiffuserComponent:
self,
x: torch.Tensor,
sigma,
conditioning_data: ConditioningData,
cross_attention_control_types_to_do: list[CrossAttentionType],
conditioning_data: TextConditioningData,
ip_adapter_data: Optional[list[IPAdapterData]],
step_index: int,
total_step_count: int,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
@ -391,53 +391,48 @@ class InvokeAIDiffuserComponent:
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
# If cross-attention control is enabled, prepare the SwapCrossAttnContext.
cross_attn_processor_context = None
if self.cross_attention_control_context is not None:
# Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do.
# This list is empty because cross-attention control is not applied in the unconditioned pass. This field
# will be populated before the conditioned pass.
cross_attn_processor_context = SwapCrossAttnContext(
modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning,
index_map=self.cross_attention_control_context.cross_attention_index_map,
mask=self.cross_attention_control_context.cross_attention_mask,
cross_attention_types_to_do=[],
)
#####################
# Unconditioned pass
#####################
cross_attention_kwargs = None
cross_attention_kwargs = {}
# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
if conditioning_data.ip_adapter_conditioning is not None:
if ip_adapter_data is not None:
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
image_prompt_embeds = [
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
for ipa_conditioning in ip_adapter_conditioning
]
# Prepare cross-attention control kwargs for the unconditioned pass.
if cross_attn_processor_context is not None:
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
ip_masks = [ipa.mask for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
)
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
# Prepare SDXL conditioning kwargs for the unconditioned pass.
added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
"text_embeds": conditioning_data.uncond_text.pooled_embeds,
"time_ids": conditioning_data.uncond_text.add_time_ids,
}
# Prepare prompt regions for the unconditioned pass.
if conditioning_data.uncond_regions is not None:
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = step_index / total_step_count
# Run unconditioned UNet denoising (i.e. negative prompt).
unconditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.uncond_text.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
@ -449,36 +444,43 @@ class InvokeAIDiffuserComponent:
# Conditioned pass
###################
cross_attention_kwargs = None
cross_attention_kwargs = {}
# Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
if conditioning_data.ip_adapter_conditioning is not None:
if ip_adapter_data is not None:
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
image_prompt_embeds = [
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
for ipa_conditioning in ip_adapter_conditioning
]
# Prepare cross-attention control kwargs for the conditioned pass.
if cross_attn_processor_context is not None:
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
ip_masks = [ipa.mask for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
)
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
# Prepare SDXL conditioning kwargs for the conditioned pass.
added_cond_kwargs = None
if is_sdxl:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids,
"text_embeds": conditioning_data.cond_text.pooled_embeds,
"time_ids": conditioning_data.cond_text.add_time_ids,
}
# Prepare prompt regions for the conditioned pass.
if conditioning_data.cond_regions is not None:
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = step_index / total_step_count
# Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning_data.text_embeddings.embeds,
conditioning_data.cond_text.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,

View File

@ -0,0 +1,68 @@
from contextlib import contextmanager
from typing import List, Optional, TypedDict
from diffusers.models import UNet2DConditionModel
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
CustomAttnProcessor2_0,
IPAdapterAttentionWeights,
)
class UNetIPAdapterData(TypedDict):
ip_adapter: IPAdapter
target_blocks: List[str]
class UNetAttentionPatcher:
"""A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
self._ip_adapters = ip_adapter_data
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
weights into them (if IP-Adapters are being applied).
Note that the `unet` param is only used to determine attention block dimensions and naming.
"""
# Construct a dict of attention processors based on the UNet's architecture.
attn_procs = {}
for idx, name in enumerate(unet.attn_processors.keys()):
if name.endswith("attn1.processor") or self._ip_adapters is None:
# "attn1" processors do not use IP-Adapters.
attn_procs[name] = CustomAttnProcessor2_0()
else:
# Collect the weights from each IP Adapter for the idx'th attention processor.
ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
for ip_adapter in self._ip_adapters:
ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
skip = True
for block in ip_adapter["target_blocks"]:
if block in name:
skip = False
break
ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
ip_adapter_weights=ip_adapter_weights, skip=skip
)
ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)
return attn_procs
@contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
"""A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
attn_procs = self._prepare_attention_processors(unet)
orig_attn_processors = unet.attn_processors
try:
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
# the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
# moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
unet.set_attn_processor(attn_procs)
yield None
finally:
unet.set_attn_processor(orig_attn_processors)
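A usage sketch for the patcher; the checkpoint id is only an example, not taken from the diff:
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)

# With ip_adapter_data=None, every attention layer receives a plain CustomAttnProcessor2_0
# (regional prompting support, no IP-Adapter weights).
patcher = UNetAttentionPatcher(ip_adapter_data=None)
with patcher.apply_ip_adapter_attention(unet):
    ...  # run denoising here; the original processors are restored on exit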

View File

@ -2,7 +2,6 @@
Initialization file for invokeai.backend.util
"""
from .devices import choose_precision, choose_torch_device
from .logging import InvokeAILogger
from .util import GIG, Chdir, directory_size
@ -11,6 +10,4 @@ __all__ = [
"directory_size",
"Chdir",
"InvokeAILogger",
"choose_precision",
"choose_torch_device",
]

View File

@ -0,0 +1,29 @@
"""
This module defines a context manager `catch_sigint()` which temporarily replaces
the SIGINT handler installed by the ASGI web server so that the user can ^C the application
and shut it down immediately. This was implemented to allow the user to interrupt
slow model hashing during startup.
Use like this:
from invokeai.backend.util.catch_sigint import catch_sigint
with catch_sigint():
run_some_hard_to_interrupt_process()
"""
import signal
from contextlib import contextmanager
from typing import Generator
def sigint_handler(signum, frame): # type: ignore
signal.signal(signal.SIGINT, signal.SIG_DFL)
signal.raise_signal(signal.SIGINT)
@contextmanager
def catch_sigint() -> Generator[None, None, None]:
original_handler = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGINT, sigint_handler)
yield
signal.signal(signal.SIGINT, original_handler)

View File

@ -1,91 +1,110 @@
from __future__ import annotations
from contextlib import nullcontext
from typing import Literal, Optional, Union
from typing import Dict, Literal, Optional, Union
import torch
from torch import autocast
from deprecated import deprecated
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.config.config_default import get_config
# legacy APIs
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
def choose_precision(device: torch.device) -> TorchPrecisionNames:
"""Return the string representation of the recommended torch device."""
torch_dtype = TorchDevice.choose_torch_dtype(device)
return PRECISION_TO_NAME[torch_dtype]
@deprecated("Use TorchDevice.choose_torch_device() instead.") # type: ignore
def choose_torch_device() -> torch.device:
"""Convenience routine for guessing which GPU device to run model on"""
config = get_config()
if config.device == "auto":
if torch.cuda.is_available():
return torch.device("cuda")
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
"""Return the torch.device to use for accelerated inference."""
return TorchDevice.choose_torch_device()
@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
def torch_dtype(device: torch.device) -> torch.dtype:
"""Return the torch precision for the recommended torch device."""
return TorchDevice.choose_torch_dtype(device)
NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()}
class TorchDevice:
"""Abstraction layer for torch devices."""
@classmethod
def choose_torch_device(cls) -> torch.device:
"""Return the torch.device to use for accelerated inference."""
app_config = get_config()
if app_config.device != "auto":
device = torch.device(app_config.device)
elif torch.cuda.is_available():
device = CUDA_DEVICE
elif torch.backends.mps.is_available():
device = MPS_DEVICE
else:
return CPU_DEVICE
else:
return torch.device(config.device)
device = CPU_DEVICE
return cls.normalize(device)
def get_torch_device_name() -> str:
device = choose_torch_device()
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
# We are in transition here from using a single global AppConfig to allowing multiple
# configurations. It is strongly recommended to pass the app_config to this function.
def choose_precision(
device: torch.device, app_config: Optional[InvokeAIAppConfig] = None
) -> Literal["float32", "float16", "bfloat16"]:
"""Return an appropriate precision for the given torch device."""
app_config = app_config or get_config()
if device.type == "cuda":
device_name = torch.cuda.get_device_name(device)
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
if app_config.precision == "float32":
return "float32"
elif app_config.precision == "bfloat16":
return "bfloat16"
@classmethod
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
"""Return the precision to use for accelerated inference."""
device = device or cls.choose_torch_device()
config = get_config()
if device.type == "cuda" and torch.cuda.is_available():
device_name = torch.cuda.get_device_name(device)
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
# These GPUs have limited support for float16
return cls._to_dtype("float32")
elif config.precision == "auto":
# Default to float16 for CUDA devices
return cls._to_dtype("float16")
else:
return "float16"
elif device.type == "mps":
return "float16"
return "float32"
# Use the user-defined precision
return cls._to_dtype(config.precision)
elif device.type == "mps" and torch.backends.mps.is_available():
if config.precision == "auto":
# Default to float16 for MPS devices
return cls._to_dtype("float16")
else:
# Use the user-defined precision
return cls._to_dtype(config.precision)
# CPU / safe fallback
return cls._to_dtype("float32")
# We are in transition here from using a single global AppConfig to allowing multiple
# configurations. It is strongly recommended to pass the app_config to this function.
def torch_dtype(
device: Optional[torch.device] = None,
app_config: Optional[InvokeAIAppConfig] = None,
) -> torch.dtype:
device = device or choose_torch_device()
precision = choose_precision(device, app_config)
if precision == "float16":
return torch.float16
if precision == "bfloat16":
return torch.bfloat16
else:
# "auto", "autocast", "float32"
return torch.float32
@classmethod
def get_torch_device_name(cls) -> str:
"""Return the device name for the current torch device."""
device = cls.choose_torch_device()
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
def choose_autocast(precision):
"""Returns an autocast context or nullcontext for the given precision string"""
# float16 currently requires autocast to avoid errors like:
# 'expected scalar type Half but found Float'
if precision == "autocast" or precision == "float16":
return autocast
return nullcontext
def normalize_device(device: Union[str, torch.device]) -> torch.device:
"""Ensure device has a device index defined, if appropriate."""
device = torch.device(device)
if device.index is None:
# cuda might be the only torch backend that currently uses the device index?
# I don't see anything like `current_device` for cpu or mps.
if device.type == "cuda":
@classmethod
def normalize(cls, device: Union[str, torch.device]) -> torch.device:
"""Add the device index to CUDA devices."""
device = torch.device(device)
if device.index is None and device.type == "cuda" and torch.cuda.is_available():
device = torch.device(device.type, torch.cuda.current_device())
return device
return device
@classmethod
def empty_cache(cls) -> None:
"""Clear the GPU device cache."""
if torch.backends.mps.is_available():
torch.mps.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
@classmethod
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
return NAME_TO_PRECISION[precision_name]
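A usage sketch of the new class-based API, replacing the deprecated module-level helpers above:
device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_torch_dtype(device)
print(TorchDevice.get_torch_device_name(), dtype)
TorchDevice.empty_cache()  # no-op on CPU-only machines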

View File

@ -0,0 +1,53 @@
import torch
def to_standard_mask_dim(mask: torch.Tensor) -> torch.Tensor:
"""Standardize the dimensions of a mask tensor.
Args:
mask (torch.Tensor): A mask tensor. The shape can be (1, h, w) or (h, w).
Returns:
torch.Tensor: The output mask tensor. The shape is (1, h, w).
"""
# Get the mask height and width.
if mask.ndim == 2:
mask = mask.unsqueeze(0)
elif mask.ndim == 3 and mask.shape[0] == 1:
pass
else:
raise ValueError(f"Unsupported mask shape: {mask.shape}. Expected (1, h, w) or (h, w).")
return mask
def to_standard_float_mask(mask: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
"""Standardize the format of a mask tensor.
Args:
mask (torch.Tensor): A mask tensor. The dtype can be any bool, float, or int type. The shape must be (1, h, w)
or (h, w).
out_dtype (torch.dtype): The dtype of the output mask tensor. Must be a float type.
Returns:
torch.Tensor: The output mask tensor. The dtype is out_dtype. The shape is (1, h, w). All values are either 0.0
or 1.0.
"""
if not out_dtype.is_floating_point:
raise ValueError(f"out_dtype must be a float type, but got {out_dtype}")
mask = to_standard_mask_dim(mask)
# Set masked regions to 1.0. A bool -> float conversion already yields exact 0.0/1.0 values;
# other dtypes are converted and then binarized around 0.5.
if mask.dtype == torch.bool:
mask = mask.to(out_dtype)
else:
mask = mask.to(out_dtype)
mask_region = mask > 0.5
mask[mask_region] = 1.0
mask[~mask_region] = 0.0
return mask
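A usage sketch (input values hypothetical):
import torch

# A boolean (h, w) mask becomes a float (1, h, w) mask of exact 0.0/1.0 values; float inputs
# are binarized around 0.5.
bool_mask = torch.zeros(8, 8, dtype=torch.bool)
bool_mask[2:6, 2:6] = True
float_mask = to_standard_float_mask(bool_mask, out_dtype=torch.float32)
assert float_mask.shape == (1, 8, 8)
assert float_mask.max() == 1.0 and float_mask.min() == 0.0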

View File

@ -0,0 +1,98 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: True
in_channels: 9
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

View File

@ -11,6 +11,7 @@ import { createStore } from '../src/app/store/store';
// @ts-ignore
import translationEN from '../public/locales/en.json';
import { ReduxInit } from './ReduxInit';
import { $store } from 'app/store/nanostores/store';
i18n.use(initReactI18next).init({
lng: 'en',
@ -25,6 +26,7 @@ i18n.use(initReactI18next).init({
});
const store = createStore(undefined, false);
$store.set(store);
$baseUrl.set('http://localhost:9090');
const preview: Preview = {

View File

@ -8,7 +8,7 @@
<meta http-equiv="Pragma" content="no-cache">
<meta http-equiv="Expires" content="0">
<title>Invoke - Community Edition</title>
<link rel="icon" type="icon" href="assets/images/invoke-favicon.svg" />
<link id="invoke-favicon" rel="icon" type="icon" href="assets/images/invoke-favicon.svg" />
<style>
html,
body {
@ -23,4 +23,4 @@
<script type="module" src="/src/main.tsx"></script>
</body>
</html>
</html>

View File

@ -1,6 +1,7 @@
import type { KnipConfig } from 'knip';
const config: KnipConfig = {
project: ['src/**/*.{ts,tsx}!'],
ignore: [
// This file is only used during debugging
'src/app/store/middleware/debugLoggerMiddleware.ts',
@ -10,6 +11,9 @@ const config: KnipConfig = {
'src/features/nodes/types/v2/**',
],
ignoreBinaries: ['only-allow'],
paths: {
'public/*': ['public/*'],
},
};
export default config;

View File

@ -24,8 +24,8 @@
"build": "pnpm run lint && vite build",
"typegen": "node scripts/typegen.js",
"preview": "vite preview",
"lint:knip": "knip --tags=-@knipignore",
"lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:1 src/main.tsx",
"lint:knip": "knip",
"lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:0 src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",
"lint:prettier": "prettier --check .",
"lint:tsc": "tsc --noEmit",
@ -52,6 +52,7 @@
},
"dependencies": {
"@chakra-ui/react-use-size": "^2.1.0",
"@dagrejs/dagre": "^1.1.1",
"@dagrejs/graphlib": "^2.2.1",
"@dnd-kit/core": "^6.1.0",
"@dnd-kit/sortable": "^8.0.0",
@ -94,11 +95,13 @@
"reactflow": "^11.10.4",
"redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^5.1.0",
"redux-undo": "^1.1.0",
"rfdc": "^1.3.1",
"roarr": "^7.21.1",
"serialize-error": "^11.0.3",
"socket.io-client": "^4.7.5",
"use-debounce": "^10.0.0",
"use-device-pixel-ratio": "^1.1.2",
"use-image": "^1.1.1",
"uuid": "^9.0.1",
"zod": "^3.22.4",

View File

@ -11,6 +11,9 @@ dependencies:
'@chakra-ui/react-use-size':
specifier: ^2.1.0
version: 2.1.0(react@18.2.0)
'@dagrejs/dagre':
specifier: ^1.1.1
version: 1.1.1
'@dagrejs/graphlib':
specifier: ^2.2.1
version: 2.2.1
@ -137,6 +140,9 @@ dependencies:
redux-remember:
specifier: ^5.1.0
version: 5.1.0(redux@5.0.1)
redux-undo:
specifier: ^1.1.0
version: 1.1.0
rfdc:
specifier: ^1.3.1
version: 1.3.1
@ -152,6 +158,9 @@ dependencies:
use-debounce:
specifier: ^10.0.0
version: 10.0.0(react@18.2.0)
use-device-pixel-ratio:
specifier: ^1.1.2
version: 1.1.2(react@18.2.0)
use-image:
specifier: ^1.1.1
version: 1.1.1(react-dom@18.2.0)(react@18.2.0)
@ -3092,6 +3101,12 @@ packages:
dev: true
optional: true
/@dagrejs/dagre@1.1.1:
resolution: {integrity: sha512-AQfT6pffEuPE32weFzhS/u3UpX+bRXUARIXL7UqLaxz497cN8pjuBlX6axO4IIECE2gBV8eLFQkGCtKX5sDaUA==}
dependencies:
'@dagrejs/graphlib': 2.2.1
dev: false
/@dagrejs/graphlib@2.2.1:
resolution: {integrity: sha512-xJsN1v6OAxXk6jmNdM+OS/bBE8nDCwM0yDNprXR18ZNatL6to9ggod9+l2XtiLhXfLm0NkE7+Er/cpdlM+SkUA==}
engines: {node: '>17.0.0'}
@ -11953,6 +11968,10 @@ packages:
redux: 5.0.1
dev: false
/redux-undo@1.1.0:
resolution: {integrity: sha512-zzLFh2qeF0MTIlzDhDLm9NtkfBqCllQJ3OCuIl5RKlG/ayHw6GUdIFdMhzMS9NnrnWdBX5u//ExMOHpfudGGOg==}
dev: false
/redux@5.0.1:
resolution: {integrity: sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==}
dev: false
@ -13308,6 +13327,14 @@ packages:
react: 18.2.0
dev: false
/use-device-pixel-ratio@1.1.2(react@18.2.0):
resolution: {integrity: sha512-nFxV0HwLdRUt20kvIgqHYZe6PK/v4mU1X8/eLsT1ti5ck0l2ob0HDRziaJPx+YWzBo6dMm4cTac3mcyk68Gh+A==}
peerDependencies:
react: '>=16.8.0'
dependencies:
react: 18.2.0
dev: false
/use-image@1.1.1(react-dom@18.2.0)(react@18.2.0):
resolution: {integrity: sha512-n4YO2k8AJG/BcDtxmBx8Aa+47kxY5m335dJiCQA5tTeVU4XdhrhqR6wT0WISRXwdMEOv5CSjqekDZkEMiiWaYQ==}
peerDependencies:

View File

@ -0,0 +1,5 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="16" height="16" rx="2" fill="#E6FD13"/>
<path d="M9.61889 5.45H12.5V3.5H3.5V5.45H6.38111L9.61889 10.55H12.5V12.5H3.5V10.55H6.38111" stroke="black"/>
<circle cx="12" cy="4" r="3" fill="#f5480c" stroke="#0d1117" stroke-width="1"/>
</svg>

Binary file not shown.


View File

@ -291,7 +291,6 @@
"canvasMerged": "تم دمج الخط",
"sentToImageToImage": "تم إرسال إلى صورة إلى صورة",
"sentToUnifiedCanvas": "تم إرسال إلى لوحة موحدة",
"parametersSet": "تم تعيين المعلمات",
"parametersNotSet": "لم يتم تعيين المعلمات",
"metadataLoadFailed": "فشل تحميل البيانات الوصفية"
},

View File

@ -75,7 +75,8 @@
"copy": "Kopieren",
"aboutHeading": "Nutzen Sie Ihre kreative Energie",
"toResolve": "Lösen",
"add": "Hinzufügen"
"add": "Hinzufügen",
"loglevel": "Protokoll Stufe"
},
"gallery": {
"galleryImageSize": "Bildgröße",
@ -84,7 +85,8 @@
"loadMore": "Mehr laden",
"noImagesInGallery": "Keine Bilder in der Galerie",
"loading": "Lade",
"deleteImage": "Lösche Bild",
"deleteImage_one": "Lösche Bild",
"deleteImage_other": "",
"copy": "Kopieren",
"download": "Runterladen",
"setCurrentImage": "Setze aktuelle Bild",
@ -388,7 +390,14 @@
"vaePrecision": "VAE-Präzision",
"variant": "Variante",
"modelDeleteFailed": "Modell konnte nicht gelöscht werden",
"noModelSelected": "Kein Modell ausgewählt"
"noModelSelected": "Kein Modell ausgewählt",
"huggingFace": "HuggingFace",
"defaultSettings": "Standardeinstellungen",
"edit": "Bearbeiten",
"cancel": "Stornieren",
"defaultSettingsSaved": "Standardeinstellungen gespeichert",
"addModels": "Model hinzufügen",
"deleteModelImage": "Lösche Model Bild"
},
"parameters": {
"images": "Bilder",
@ -472,7 +481,6 @@
"canvasMerged": "Leinwand zusammengeführt",
"sentToImageToImage": "Gesendet an Bild zu Bild",
"sentToUnifiedCanvas": "Gesendet an Leinwand",
"parametersSet": "Parameter festlegen",
"parametersNotSet": "Parameter nicht festgelegt",
"metadataLoadFailed": "Metadaten konnten nicht geladen werden",
"setCanvasInitialImage": "Ausgangsbild setzen",
@ -677,7 +685,8 @@
"body": "Körper",
"hands": "Hände",
"dwOpenpose": "DW Openpose",
"dwOpenposeDescription": "Posenschätzung mit DW Openpose"
"dwOpenposeDescription": "Posenschätzung mit DW Openpose",
"selectCLIPVisionModel": "Wähle ein CLIP Vision Model aus"
},
"queue": {
"status": "Status",
@ -765,7 +774,10 @@
"recallParameters": "Parameter wiederherstellen",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"allPrompts": "Alle Prompts",
"imageDimensions": "Bilder Auslösungen"
"imageDimensions": "Bilder Auslösungen",
"parameterSet": "Parameter {{parameter}} setzen",
"recallParameter": "{{label}} Abrufen",
"parsingFailed": "Parsing Fehlgeschlagen"
},
"popovers": {
"noiseUseCPU": {
@ -1030,7 +1042,8 @@
"title": "Bild"
},
"advanced": {
"title": "Erweitert"
"title": "Erweitert",
"options": "$t(accordions.advanced.title) Optionen"
},
"control": {
"title": "Kontrolle"

View File

@ -69,6 +69,7 @@
"auto": "Auto",
"back": "Back",
"batch": "Batch Manager",
"beta": "Beta",
"cancel": "Cancel",
"copy": "Copy",
"copyError": "$t(gallery.copy) Error",
@ -83,6 +84,8 @@
"direction": "Direction",
"ipAdapter": "IP Adapter",
"t2iAdapter": "T2I Adapter",
"positivePrompt": "Positive Prompt",
"negativePrompt": "Negative Prompt",
"discordLabel": "Discord",
"dontAskMeAgain": "Don't ask me again",
"error": "Error",
@ -135,7 +138,9 @@
"red": "Red",
"green": "Green",
"blue": "Blue",
"alpha": "Alpha"
"alpha": "Alpha",
"selected": "Selected",
"viewer": "Viewer"
},
"controlnet": {
"controlAdapter_one": "Control Adapter",
@ -151,6 +156,7 @@
"balanced": "Balanced",
"base": "Base",
"beginEndStepPercent": "Begin / End Step Percentage",
"beginEndStepPercentShort": "Begin/End %",
"bgth": "bg_th",
"canny": "Canny",
"cannyDescription": "Canny edge detection",
@ -213,11 +219,17 @@
"resize": "Resize",
"resizeSimple": "Resize (Simple)",
"resizeMode": "Resize Mode",
"ipAdapterMethod": "Method",
"full": "Full",
"style": "Style Only",
"composition": "Composition Only",
"safe": "Safe",
"saveControlImage": "Save Control Image",
"scribble": "scribble",
"selectModel": "Select a model",
"setControlImageDimensions": "Set Control Image Dimensions To W/H",
"selectCLIPVisionModel": "Select a CLIP Vision model",
"setControlImageDimensions": "Copy size to W/H (optimize for model)",
"setControlImageDimensionsForce": "Copy size to W/H (ignore model)",
"showAdvanced": "Show Advanced",
"small": "Small",
"toggleControlNet": "Toggle this ControlNet",
@ -325,7 +337,8 @@
"drop": "Drop",
"dropOrUpload": "$t(gallery.drop) or Upload",
"dropToUpload": "$t(gallery.drop) to Upload",
"deleteImage": "Delete Image",
"deleteImage_one": "Delete Image",
"deleteImage_other": "Delete {{count}} Images",
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
"deleteImagePermanent": "Deleted images cannot be restored.",
"download": "Download",
@ -655,6 +668,7 @@
"install": "Install",
"installAll": "Install All",
"installRepo": "Install Repo",
"ipAdapters": "IP Adapters",
"load": "Load",
"localOnly": "local only",
"manual": "Manual",
@ -682,6 +696,7 @@
"noModelsInstalled": "No Models Installed",
"noModelsInstalledDesc1": "Install models with the",
"noModelSelected": "No Model Selected",
"noMatchingModels": "No matching Models",
"none": "none",
"path": "Path",
"pathToConfig": "Path To Config",
@ -766,6 +781,8 @@
"float": "Float",
"fullyContainNodes": "Fully Contain Nodes to Select",
"fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected",
"showEdgeLabels": "Show Edge Labels",
"showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes",
"hideLegendNodes": "Hide Field Type Legend",
"hideMinimapnodes": "Hide MiniMap",
"inputMayOnlyHaveOneConnection": "Input may only have one connection",
@ -846,6 +863,7 @@
"version": "Version",
"versionUnknown": " Version Unknown",
"workflow": "Workflow",
"graph": "Graph",
"workflowAuthor": "Author",
"workflowContact": "Contact",
"workflowDescription": "Short Description",
@ -881,10 +899,16 @@
"denoisingStrength": "Denoising Strength",
"downloadImage": "Download Image",
"general": "General",
"globalSettings": "Global Settings",
"height": "Height",
"imageFit": "Fit Initial Image To Output Size",
"images": "Images",
"infillMethod": "Infill Method",
"infillMosaicTileWidth": "Tile Width",
"infillMosaicTileHeight": "Tile Height",
"infillMosaicMinColor": "Min Color",
"infillMosaicMaxColor": "Max Color",
"infillColorValue": "Fill Color",
"info": "Info",
"invoke": {
"addingImagesTo": "Adding images to",
@ -1033,10 +1057,10 @@
"metadataLoadFailed": "Failed to load metadata",
"modelAddedSimple": "Model Added to Queue",
"modelImportCanceled": "Model Import Canceled",
"parameters": "Parameters",
"parameterNotSet": "{{parameter}} not set",
"parameterSet": "{{parameter}} set",
"parametersNotSet": "Parameters Not Set",
"parametersSet": "Parameters Set",
"problemCopyingCanvas": "Problem Copying Canvas",
"problemCopyingCanvasDesc": "Unable to export base layer",
"problemCopyingImage": "Unable to Copy Image",
@ -1166,6 +1190,10 @@
"heading": "Resize Mode",
"paragraphs": ["Method to fit Control Adapter's input image size to the output generation size."]
},
"ipAdapterMethod": {
"heading": "Method",
"paragraphs": ["Method by which to apply the current IP Adapter."]
},
"controlNetWeight": {
"heading": "Weight",
"paragraphs": [
@ -1415,6 +1443,8 @@
"eraseBoundingBox": "Erase Bounding Box",
"eraser": "Eraser",
"fillBoundingBox": "Fill Bounding Box",
"hideBoundingBox": "Hide Bounding Box",
"initialFitImageSize": "Fit Image Size on Drop",
"invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
"layer": "Layer",
"limitStrokesToBox": "Limit Strokes to Box",
@ -1431,6 +1461,7 @@
"saveMask": "Save $t(unifiedCanvas.mask)",
"saveToGallery": "Save To Gallery",
"scaledBoundingBox": "Scaled Bounding Box",
"showBoundingBox": "Show Bounding Box",
"showCanvasDebugInfo": "Show Additional Canvas Info",
"showGrid": "Show Grid",
"showResultsOn": "Show Results (On)",
@ -1473,9 +1504,44 @@
"workflowName": "Workflow Name",
"newWorkflowCreated": "New Workflow Created",
"workflowCleared": "Workflow Cleared",
"workflowEditorMenu": "Workflow Editor Menu"
"workflowEditorMenu": "Workflow Editor Menu",
"loadFromGraph": "Load Workflow from Graph",
"convertGraph": "Convert Graph",
"loadWorkflow": "$t(common.load) Workflow",
"autoLayout": "Auto Layout"
},
"app": {
"storeNotInitialized": "Store is not initialized"
},
"controlLayers": {
"deleteAll": "Delete All",
"addLayer": "Add Layer",
"moveToFront": "Move to Front",
"moveToBack": "Move to Back",
"moveForward": "Move Forward",
"moveBackward": "Move Backward",
"brushSize": "Brush Size",
"controlLayers": "Control Layers (BETA)",
"globalMaskOpacity": "Global Mask Opacity",
"autoNegative": "Auto Negative",
"toggleVisibility": "Toggle Layer Visibility",
"deletePrompt": "Delete Prompt",
"resetRegion": "Reset Region",
"debugLayers": "Debug Layers",
"rectangle": "Rectangle",
"maskPreviewColor": "Mask Preview Color",
"addPositivePrompt": "Add $t(common.positivePrompt)",
"addNegativePrompt": "Add $t(common.negativePrompt)",
"addIPAdapter": "Add $t(common.ipAdapter)",
"regionalGuidance": "Regional Guidance",
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
"controlNetLayer": "$t(common.controlNet) $t(unifiedCanvas.layer)",
"ipAdapterLayer": "$t(common.ipAdapter) $t(unifiedCanvas.layer)",
"opacity": "Opacity",
"globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
"globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
"globalIPAdapter": "Global $t(common.ipAdapter)",
"globalIPAdapterLayer": "Global $t(common.ipAdapter) $t(unifiedCanvas.layer)",
"opacityFilter": "Opacity Filter"
}
}
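The en.json changes above lean on two i18next-style mechanics: plural suffixes (`deleteImage_one` / `deleteImage_other`, selected by a `count` option) and key nesting via `$t(...)`. A minimal, self-contained sketch of the plural behaviour, assuming a standard i18next setup — the resources below are copied from the hunk, but the init call is illustrative and not the app's real configuration:

```ts
import i18next from 'i18next';

// Illustrative init only; the app wires its own resources and namespaces.
await i18next.init({
  lng: 'en',
  resources: {
    en: {
      translation: {
        gallery: {
          deleteImage_one: 'Delete Image',
          deleteImage_other: 'Delete {{count}} Images',
        },
      },
    },
  },
});

i18next.t('gallery.deleteImage', { count: 1 }); // "Delete Image"
i18next.t('gallery.deleteImage', { count: 4 }); // "Delete 4 Images"

// Nested keys such as "$t(gallery.drop) or Upload" resolve the referenced key
// at translation time, so renaming the base string updates every derived one.
```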

View File

@ -33,7 +33,9 @@
"autoSwitchNewImages": "Auto seleccionar Imágenes nuevas",
"loadMore": "Cargar más",
"noImagesInGallery": "No hay imágenes para mostrar",
"deleteImage": "Eliminar Imagen",
"deleteImage_one": "Eliminar Imagen",
"deleteImage_many": "",
"deleteImage_other": "",
"deleteImageBin": "Las imágenes eliminadas se enviarán a la papelera de tu sistema operativo.",
"deleteImagePermanent": "Las imágenes eliminadas no se pueden restaurar.",
"assets": "Activos",
@ -363,7 +365,6 @@
"canvasMerged": "Lienzo consolidado",
"sentToImageToImage": "Enviar hacia Imagen a Imagen",
"sentToUnifiedCanvas": "Enviar hacia Lienzo Consolidado",
"parametersSet": "Parámetros establecidos",
"parametersNotSet": "Parámetros no establecidos",
"metadataLoadFailed": "Error al cargar metadatos",
"serverError": "Error en el servidor",

View File

@ -298,7 +298,6 @@
"canvasMerged": "Canvas fusionné",
"sentToImageToImage": "Envoyé à Image à Image",
"sentToUnifiedCanvas": "Envoyé à Canvas unifié",
"parametersSet": "Paramètres définis",
"parametersNotSet": "Paramètres non définis",
"metadataLoadFailed": "Échec du chargement des métadonnées"
},

View File

@ -306,7 +306,6 @@
"canvasMerged": "קנבס מוזג",
"sentToImageToImage": "נשלח לתמונה לתמונה",
"sentToUnifiedCanvas": "נשלח אל קנבס מאוחד",
"parametersSet": "הגדרת פרמטרים",
"parametersNotSet": "פרמטרים לא הוגדרו",
"metadataLoadFailed": "טעינת מטא-נתונים נכשלה"
},

View File

@ -82,7 +82,9 @@
"autoSwitchNewImages": "Passaggio automatico a nuove immagini",
"loadMore": "Carica altro",
"noImagesInGallery": "Nessuna immagine da visualizzare",
"deleteImage": "Elimina l'immagine",
"deleteImage_one": "Elimina l'immagine",
"deleteImage_many": "Elimina {{count}} immagini",
"deleteImage_other": "Elimina {{count}} immagini",
"deleteImagePermanent": "Le immagini eliminate non possono essere ripristinate.",
"deleteImageBin": "Le immagini eliminate verranno spostate nel cestino del tuo sistema operativo.",
"assets": "Risorse",
@ -366,7 +368,7 @@
"modelConverted": "Modello convertito",
"alpha": "Alpha",
"convertToDiffusersHelpText1": "Questo modello verrà convertito nel formato 🧨 Diffusori.",
"convertToDiffusersHelpText3": "Il file Checkpoint su disco verrà eliminato se si trova nella cartella principale di InvokeAI. Se si trova invece in una posizione personalizzata, NON verrà eliminato.",
"convertToDiffusersHelpText3": "Il file del modello su disco verrà eliminato se si trova nella cartella principale di InvokeAI. Se si trova invece in una posizione personalizzata, NON verrà eliminato.",
"v2_base": "v2 (512px)",
"v2_768": "v2 (768px)",
"none": "nessuno",
@ -443,7 +445,9 @@
"noModelsInstalled": "Nessun modello installato",
"hfTokenInvalidErrorMessage2": "Aggiornalo in ",
"main": "Principali",
"noModelsInstalledDesc1": "Installa i modelli con"
"noModelsInstalledDesc1": "Installa i modelli con",
"ipAdapters": "Adattatori IP",
"noMatchingModels": "Nessun modello corrispondente"
},
"parameters": {
"images": "Immagini",
@ -525,7 +529,12 @@
"aspect": "Aspetto",
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (potrebbe essere troppo grande)",
"remixImage": "Remixa l'immagine",
"coherenceEdgeSize": "Dim. bordo"
"coherenceEdgeSize": "Dim. bordo",
"infillMosaicTileWidth": "Larghezza piastrella",
"infillMosaicMinColor": "Colore minimo",
"infillMosaicMaxColor": "Colore massimo",
"infillMosaicTileHeight": "Altezza piastrella",
"infillColorValue": "Colore di riempimento"
},
"settings": {
"models": "Modelli",
@ -568,7 +577,6 @@
"canvasMerged": "Tela unita",
"sentToImageToImage": "Inviato a Immagine a Immagine",
"sentToUnifiedCanvas": "Inviato a Tela Unificata",
"parametersSet": "Parametri impostati",
"parametersNotSet": "Parametri non impostati",
"metadataLoadFailed": "Impossibile caricare i metadati",
"serverError": "Errore del Server",
@ -620,7 +628,8 @@
"uploadInitialImage": "Carica l'immagine iniziale",
"problemDownloadingImage": "Impossibile scaricare l'immagine",
"prunedQueue": "Coda ripulita",
"modelImportCanceled": "Importazione del modello annullata"
"modelImportCanceled": "Importazione del modello annullata",
"parameters": "Parametri"
},
"tooltip": {
"feature": {
@ -689,7 +698,10 @@
"coherenceModeBoxBlur": "Sfocatura Box",
"coherenceModeStaged": "Maschera espansa",
"invertBrushSizeScrollDirection": "Inverti scorrimento per dimensione pennello",
"discardCurrent": "Scarta l'attuale"
"discardCurrent": "Scarta l'attuale",
"initialFitImageSize": "Adatta dimensione immagine al rilascio",
"hideBoundingBox": "Nascondi il rettangolo di selezione",
"showBoundingBox": "Mostra il rettangolo di selezione"
},
"accessibility": {
"invokeProgressBar": "Barra di avanzamento generazione",
@ -832,7 +844,8 @@
"editMode": "Modifica nell'editor del flusso di lavoro",
"resetToDefaultValue": "Ripristina il valore predefinito",
"noFieldsViewMode": "Questo flusso di lavoro non ha campi selezionati da visualizzare. Visualizza il flusso di lavoro completo per configurare i valori.",
"edit": "Modifica"
"edit": "Modifica",
"graph": "Grafico"
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
@ -937,7 +950,8 @@
"controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
"mediapipeFace": "Mediapipe Volto",
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))"
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))",
"selectCLIPVisionModel": "Seleziona un modello CLIP Vision"
},
"queue": {
"queueFront": "Aggiungi all'inizio della coda",
@ -1345,13 +1359,13 @@
]
},
"seamlessTilingXAxis": {
"heading": "Asse X di piastrellatura senza cuciture",
"heading": "Piastrella senza giunte sull'asse X",
"paragraphs": [
"Affianca senza soluzione di continuità un'immagine lungo l'asse orizzontale."
]
},
"seamlessTilingYAxis": {
"heading": "Asse Y di piastrellatura senza cuciture",
"heading": "Piastrella senza giunte sull'asse Y",
"paragraphs": [
"Affianca senza soluzione di continuità un'immagine lungo l'asse verticale."
]
@ -1475,7 +1489,11 @@
"name": "Nome",
"updated": "Aggiornato",
"projectWorkflows": "Flussi di lavoro del progetto",
"opened": "Aperto"
"opened": "Aperto",
"convertGraph": "Converti grafico",
"loadWorkflow": "$t(common.load) Flusso di lavoro",
"autoLayout": "Disposizione automatica",
"loadFromGraph": "Carica il flusso di lavoro dal grafico"
},
"app": {
"storeNotInitialized": "Il negozio non è inizializzato"

View File

@ -90,7 +90,7 @@
"problemDeletingImages": "画像の削除中に問題が発生",
"drop": "ドロップ",
"dropOrUpload": "$t(gallery.drop) またはアップロード",
"deleteImage": "画像を削除",
"deleteImage_other": "画像を削除",
"deleteImageBin": "削除された画像はOSのゴミ箱に送られます。",
"deleteImagePermanent": "削除された画像は復元できません。",
"download": "ダウンロード",

View File

@ -82,7 +82,7 @@
"drop": "드랍",
"problemDeletingImages": "이미지 삭제 중 발생한 문제",
"downloadSelection": "선택 항목 다운로드",
"deleteImage": "이미지 삭제",
"deleteImage_other": "이미지 삭제",
"currentlyInUse": "이 이미지는 현재 다음 기능에서 사용되고 있습니다:",
"dropOrUpload": "$t(gallery.drop) 또는 업로드",
"copy": "복사",

View File

@ -42,7 +42,8 @@
"autoSwitchNewImages": "Wissel autom. naar nieuwe afbeeldingen",
"loadMore": "Laad meer",
"noImagesInGallery": "Geen afbeeldingen om te tonen",
"deleteImage": "Verwijder afbeelding",
"deleteImage_one": "Verwijder afbeelding",
"deleteImage_other": "",
"deleteImageBin": "Verwijderde afbeeldingen worden naar de prullenbak van je besturingssysteem gestuurd.",
"deleteImagePermanent": "Verwijderde afbeeldingen kunnen niet worden hersteld.",
"assets": "Eigen onderdelen",
@ -420,7 +421,6 @@
"canvasMerged": "Canvas samengevoegd",
"sentToImageToImage": "Gestuurd naar Afbeelding naar afbeelding",
"sentToUnifiedCanvas": "Gestuurd naar Centraal canvas",
"parametersSet": "Parameters ingesteld",
"parametersNotSet": "Parameters niet ingesteld",
"metadataLoadFailed": "Fout bij laden metagegevens",
"serverError": "Serverfout",

View File

@ -267,7 +267,6 @@
"canvasMerged": "Scalono widoczne warstwy",
"sentToImageToImage": "Wysłano do Obraz na obraz",
"sentToUnifiedCanvas": "Wysłano do trybu uniwersalnego",
"parametersSet": "Ustawiono parametry",
"parametersNotSet": "Nie ustawiono parametrów",
"metadataLoadFailed": "Błąd wczytywania metadanych"
},

View File

@ -310,7 +310,6 @@
"canvasMerged": "Tela Fundida",
"sentToImageToImage": "Mandar Para Imagem Para Imagem",
"sentToUnifiedCanvas": "Enviada para a Tela Unificada",
"parametersSet": "Parâmetros Definidos",
"parametersNotSet": "Parâmetros Não Definidos",
"metadataLoadFailed": "Falha ao tentar carregar metadados"
},

View File

@ -307,7 +307,6 @@
"canvasMerged": "Tela Fundida",
"sentToImageToImage": "Mandar Para Imagem Para Imagem",
"sentToUnifiedCanvas": "Enviada para a Tela Unificada",
"parametersSet": "Parâmetros Definidos",
"parametersNotSet": "Parâmetros Não Definidos",
"metadataLoadFailed": "Falha ao tentar carregar metadados"
},

View File

@ -86,7 +86,9 @@
"noImagesInGallery": "Изображений нет",
"deleteImagePermanent": "Удаленные изображения невозможно восстановить.",
"deleteImageBin": "Удаленные изображения будут отправлены в корзину вашей операционной системы.",
"deleteImage": "Удалить изображение",
"deleteImage_one": "Удалить изображение",
"deleteImage_few": "",
"deleteImage_many": "",
"assets": "Ресурсы",
"autoAssignBoardOnClick": "Авто-назначение доски по клику",
"deleteSelection": "Удалить выделенное",
@ -448,7 +450,9 @@
"loraModels": "LoRAs",
"main": "Основные",
"noModelsInstalled": "Нет установленных моделей",
"noModelsInstalledDesc1": "Установите модели с помощью"
"noModelsInstalledDesc1": "Установите модели с помощью",
"noMatchingModels": "Нет подходящих моделей",
"ipAdapters": "IP адаптеры"
},
"parameters": {
"images": "Изображения",
@ -532,7 +536,12 @@
"lockAspectRatio": "Заблокировать соотношение",
"remixImage": "Ремикс изображения",
"coherenceMinDenoise": "Мин. шумоподавление",
"coherenceEdgeSize": "Размер края"
"coherenceEdgeSize": "Размер края",
"infillMosaicTileWidth": "Ширина плиток",
"infillMosaicTileHeight": "Высота плиток",
"infillMosaicMinColor": "Мин цвет",
"infillMosaicMaxColor": "Макс цвет",
"infillColorValue": "Цвет заливки"
},
"settings": {
"models": "Модели",
@ -575,7 +584,6 @@
"canvasMerged": "Холст объединен",
"sentToImageToImage": "Отправить в img2img",
"sentToUnifiedCanvas": "Отправлено на Единый холст",
"parametersSet": "Параметры заданы",
"parametersNotSet": "Параметры не заданы",
"metadataLoadFailed": "Не удалось загрузить метаданные",
"serverError": "Ошибка сервера",
@ -627,7 +635,8 @@
"uploadInitialImage": "Загрузить начальное изображение",
"resetInitialImage": "Сбросить начальное изображение",
"prunedQueue": "Урезанная очередь",
"modelImportCanceled": "Импорт модели отменен"
"modelImportCanceled": "Импорт модели отменен",
"parameters": "Параметры"
},
"tooltip": {
"feature": {
@ -696,7 +705,8 @@
"coherenceModeGaussianBlur": "Размытие по Гауссу",
"coherenceModeBoxBlur": "коробчатое размытие",
"discardCurrent": "Отбросить текущее",
"invertBrushSizeScrollDirection": "Инвертировать прокрутку для размера кисти"
"invertBrushSizeScrollDirection": "Инвертировать прокрутку для размера кисти",
"initialFitImageSize": "Подогнать размер изображения при перебросе"
},
"accessibility": {
"uploadImage": "Загрузить изображение",
@ -922,7 +932,8 @@
"modelSize": "Размер модели",
"small": "Маленький",
"body": "Тело",
"hands": "Руки"
"hands": "Руки",
"selectCLIPVisionModel": "Выбрать модель CLIP Vision"
},
"boards": {
"autoAddBoard": "Авто добавление Доски",

View File

@ -298,7 +298,8 @@
"noImagesInGallery": "Gösterilecek Görsel Yok",
"autoSwitchNewImages": "Yeni Görseli Biter Bitmez Gör",
"currentlyInUse": "Bu görsel şurada kullanımda:",
"deleteImage": "Görseli Sil",
"deleteImage_one": "Görseli Sil",
"deleteImage_other": "",
"loadMore": "Daha Getir",
"setCurrentImage": "Çalışma Görseli Yap",
"unableToLoad": "Galeri Yüklenemedi",

View File

@ -315,7 +315,6 @@
"canvasMerged": "Полотно об'єднане",
"sentToImageToImage": "Надіслати до img2img",
"sentToUnifiedCanvas": "Надіслати на полотно",
"parametersSet": "Параметри задані",
"parametersNotSet": "Параметри не задані",
"metadataLoadFailed": "Не вдалося завантажити метадані",
"serverError": "Помилка сервера",

View File

@ -65,7 +65,12 @@
"nextPage": "下一页",
"saveAs": "保存为",
"ai": "ai",
"or": "或"
"or": "或",
"aboutDesc": "使用 Invoke 工作?查看:",
"add": "添加",
"loglevel": "日志级别",
"copy": "复制",
"localSystem": "本地系统"
},
"gallery": {
"galleryImageSize": "预览大小",
@ -73,7 +78,7 @@
"autoSwitchNewImages": "自动切换到新图像",
"loadMore": "加载更多",
"noImagesInGallery": "无图像可用于显示",
"deleteImage": "删除图片",
"deleteImage_other": "删除图片",
"deleteImageBin": "被删除的图片会发送到你操作系统的回收站。",
"deleteImagePermanent": "删除的图片无法被恢复。",
"assets": "素材",
@ -487,7 +492,6 @@
"canvasMerged": "画布已合并",
"sentToImageToImage": "已发送到图生图",
"sentToUnifiedCanvas": "已发送到统一画布",
"parametersSet": "参数已设定",
"parametersNotSet": "参数未设定",
"metadataLoadFailed": "加载元数据失败",
"uploadFailedInvalidUploadDesc": "必须是单张的 PNG 或 JPEG 图片",
@ -600,7 +604,8 @@
"loadMore": "加载更多",
"mode": "模式",
"resetUI": "$t(accessibility.reset) UI",
"createIssue": "创建问题"
"createIssue": "创建问题",
"about": "关于"
},
"tooltip": {
"feature": {
@ -1202,7 +1207,16 @@
"workflows": "工作流",
"noDescription": "无描述",
"uploadWorkflow": "从文件中加载",
"newWorkflowCreated": "已创建新的工作流"
"newWorkflowCreated": "已创建新的工作流",
"name": "名称",
"defaultWorkflows": "默认工作流",
"created": "已创建",
"ascending": "升序",
"descending": "降序",
"updated": "已更新",
"userWorkflows": "我的工作流",
"projectWorkflows": "项目工作流",
"opened": "已打开"
},
"app": {
"storeNotInitialized": "商店尚未初始化"
@ -1220,7 +1234,8 @@
"title": "生成"
},
"advanced": {
"title": "高级"
"title": "高级",
"options": "$t(accordions.advanced.title) 选项"
},
"image": {
"title": "图像"

View File

@ -1,5 +1,6 @@
import { Box, useGlobalModifiersInit } from '@invoke-ai/ui-library';
import { useSocketIO } from 'app/hooks/useSocketIO';
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
import { useLogger } from 'app/logging/useLogger';
import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@ -70,6 +71,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
}, [dispatch]);
useStarterModelsToast();
useSyncQueueStatus();
return (
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>

View File

@ -0,0 +1,25 @@
import { useEffect } from 'react';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
const baseTitle = document.title;
const invokeLogoSVG = 'assets/images/invoke-favicon.svg';
const invokeAlertLogoSVG = 'assets/images/invoke-alert-favicon.svg';
/**
* This hook synchronizes the queue status with the page's title and favicon.
* It should be considered a singleton and only used once in the component tree.
*/
export const useSyncQueueStatus = () => {
const { queueSize } = useGetQueueStatusQuery(undefined, {
selectFromResult: (res) => ({
queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0,
}),
});
useEffect(() => {
document.title = queueSize > 0 ? `(${queueSize}) ${baseTitle}` : baseTitle;
const faviconEl = document.getElementById('invoke-favicon');
if (faviconEl instanceof HTMLLinkElement) {
faviconEl.href = queueSize > 0 ? invokeAlertLogoSVG : invokeLogoSVG;
}
}, [queueSize]);
};
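Note that useSyncQueueStatus silently does nothing if the page lacks a `<link>` element with id `invoke-favicon` (the `instanceof HTMLLinkElement` check guards against that). A defensive sketch of the element the hook expects; in practice the tag would live in the app's index.html, and the helper below is hypothetical, not part of this diff:

```ts
// Hypothetical helper: ensure the favicon <link> that useSyncQueueStatus
// updates actually exists before the hook runs.
const ensureInvokeFavicon = (): HTMLLinkElement => {
  let link = document.getElementById('invoke-favicon') as HTMLLinkElement | null;
  if (!link) {
    link = document.createElement('link');
    link.id = 'invoke-favicon';
    link.rel = 'icon';
    link.type = 'image/svg+xml';
    link.href = 'assets/images/invoke-favicon.svg'; // same asset the hook swaps in
    document.head.appendChild(link);
  }
  return link;
};
```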

View File

@ -27,7 +27,8 @@ export type LoggerNamespace =
| 'socketio'
| 'session'
| 'queue'
| 'dnd';
| 'dnd'
| 'controlLayers';
export const logger = (namespace: LoggerNamespace) => $logger.get().child({ namespace });
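With 'controlLayers' added to LoggerNamespace, feature code can obtain a child logger tagged with the new namespace. A hedged usage sketch; the import path and the level method are assumptions about the surrounding logging setup, which this diff does not show:

```ts
// Assumed import path for the module edited above.
import { logger } from 'app/logging/logger';

// Child logger tagged with the new namespace; assumes the underlying logger
// exposes the usual level methods (debug/info/warn/error).
const log = logger('controlLayers');
log.debug('regional guidance layer added');
```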

Some files were not shown because too many files have changed in this diff.