Merge branch 'main' into depth_anything_v2

This commit is contained in:
blessedcoolant 2024-08-03 00:38:57 +05:30
commit 4f8a4b0f22
29 changed files with 958 additions and 185 deletions

View File

@ -80,12 +80,12 @@ class CompelInvocation(BaseInvocation):
with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(
text_encoder,
loras=_lora_loader(),
model_state_dict=model_state_dict,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
@ -175,13 +175,13 @@ class SDXLPromptInvocationBase:
with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (state_dict, text_encoder),
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora(
text_encoder,
loras=_lora_loader(),
prefix=lora_prefix,
model_state_dict=state_dict,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),

View File

@ -39,7 +39,7 @@ class GradientMaskOutput(BaseInvocationOutput):
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.1.0",
version="1.2.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
@ -93,6 +93,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
blur_tensor[blur_tensor < 0] = 0.0
threshold = 1 - self.minimum_denoise

View File

@ -62,6 +62,7 @@ from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetEx
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
@ -845,6 +846,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
if self.unet.freeu_config:
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
### lora
if self.unet.loras:
for lora_field in self.unet.loras:
ext_manager.add_extension(
LoRAExt(
node_context=context,
model_id=lora_field.lora,
weight=lora_field.weight,
)
)
### seamless
if self.unet.seamless_axes:
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
@ -964,14 +975,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
unet_info.model_on_device() as (model_state_dict, unet),
unet_info.model_on_device() as (cached_weights, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(
unet,
loras=_lora_loader(),
model_state_dict=model_state_dict,
cached_weights=cached_weights,
),
):
assert isinstance(unet, UNet2DConditionModel)

View File

@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Callable, Optional, Tuple
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
from pydantic.fields import _Unset
from pydantic_core import PydanticUndefined
@ -242,6 +242,31 @@ class ConditioningField(BaseModel):
)
class BoundingBoxField(BaseModel):
"""A bounding box primitive value."""
x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
score: Optional[float] = Field(
default=None,
ge=0.0,
le=1.0,
description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
"when the bounding box was produced by a detector and has an associated confidence score.",
)
@model_validator(mode="after")
def check_coords(self):
if self.x_min > self.x_max:
raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
if self.y_min > self.y_max:
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
return self
class MetadataField(RootModel[dict[str, Any]]):
"""
Pydantic model for metadata with custom root of type dict[str, Any].

View File

@ -0,0 +1,100 @@
from pathlib import Path
from typing import Literal
import torch
from PIL import Image
from transformers import pipeline
from transformers.pipelines import ZeroShotObjectDetectionPipeline
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
"grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
"grounding-dino-base": "IDEA-Research/grounding-dino-base",
}
@invocation(
"grounding_dino",
title="Grounding DINO (Text Prompt Object Detection)",
tags=["prompt", "object detection"],
category="image",
version="1.0.0",
)
class GroundingDinoInvocation(BaseInvocation):
"""Runs a Grounding DINO model. Performs zero-shot bounding-box object detection from a text prompt."""
# Reference:
# - https://arxiv.org/pdf/2303.05499
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
prompt: str = InputField(description="The prompt describing the object to segment.")
image: ImageField = InputField(description="The image to segment.")
detection_threshold: float = InputField(
description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
ge=0.0,
le=1.0,
default=0.3,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
# The model expects a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
detections = self._detect(
context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
)
# Convert detections to BoundingBoxCollectionOutput.
bounding_boxes: list[BoundingBoxField] = []
for detection in detections:
bounding_boxes.append(
BoundingBoxField(
x_min=detection.box.xmin,
x_max=detection.box.xmax,
y_min=detection.box.ymin,
y_max=detection.box.ymax,
score=detection.score,
)
)
return BoundingBoxCollectionOutput(collection=bounding_boxes)
@staticmethod
def _load_grounding_dino(model_path: Path):
grounding_dino_pipeline = pipeline(
model=str(model_path),
task="zero-shot-object-detection",
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
return GroundingDinoPipeline(grounding_dino_pipeline)
def _detect(
self,
context: InvocationContext,
image: Image.Image,
labels: list[str],
threshold: float = 0.3,
) -> list[DetectionResult]:
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
# actually makes a difference.
labels = [label if label.endswith(".") else label + "." for label in labels]
with context.models.load_remote_model(
source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
) as detector:
assert isinstance(detector, GroundingDinoPipeline)
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)

View File

@ -1,9 +1,10 @@
import numpy as np
import torch
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
from invokeai.app.invocations.primitives import MaskOutput
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
@invocation(
@ -118,3 +119,27 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
height=mask.shape[1],
width=mask.shape[2],
)
@invocation(
"tensor_mask_to_image",
title="Tensor Mask to Image",
tags=["mask"],
category="mask",
version="1.0.0",
)
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Convert a mask tensor to an image."""
mask: TensorField = InputField(description="The mask tensor to convert.")
def invoke(self, context: InvocationContext) -> ImageOutput:
mask = context.tensors.load(self.mask.tensor_name)
# Ensure that the mask is binary.
if mask.dtype != torch.bool:
mask = mask > 0.5
mask_np = (mask.float() * 255).byte().cpu().numpy()
mask_pil = Image.fromarray(mask_np, mode="L")
image_dto = context.images.save(image=mask_pil)
return ImageOutput.build(image_dto)

View File

@ -7,6 +7,7 @@ import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
BoundingBoxField,
ColorField,
ConditioningField,
DenoiseMaskField,
@ -469,3 +470,42 @@ class ConditioningCollectionInvocation(BaseInvocation):
# endregion
# region BoundingBox
@invocation_output("bounding_box_output")
class BoundingBoxOutput(BaseInvocationOutput):
"""Base class for nodes that output a single bounding box"""
bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
@invocation_output("bounding_box_collection_output")
class BoundingBoxCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of bounding boxes"""
collection: list[BoundingBoxField] = OutputField(description="The output bounding boxes.", title="Bounding Boxes")
@invocation(
"bounding_box",
title="Bounding Box",
tags=["primitives", "segmentation", "collection", "bounding box"],
category="primitives",
version="1.0.0",
)
class BoundingBoxInvocation(BaseInvocation):
"""Create a bounding box manually by supplying box coordinates"""
x_min: int = InputField(default=0, description="x-coordinate of the bounding box's top left vertex")
y_min: int = InputField(default=0, description="y-coordinate of the bounding box's top left vertex")
x_max: int = InputField(default=0, description="x-coordinate of the bounding box's bottom right vertex")
y_max: int = InputField(default=0, description="y-coordinate of the bounding box's bottom right vertex")
def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
bounding_box = BoundingBoxField(x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max)
return BoundingBoxOutput(bounding_box=bounding_box)
# endregion

View File

@ -0,0 +1,161 @@
from pathlib import Path
from typing import Literal
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
from invokeai.app.invocations.primitives import MaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
"segment-anything-base": "facebook/sam-vit-base",
"segment-anything-large": "facebook/sam-vit-large",
"segment-anything-huge": "facebook/sam-vit-huge",
}
@invocation(
"segment_anything",
title="Segment Anything",
tags=["prompt", "segmentation"],
category="segmentation",
version="1.0.0",
)
class SegmentAnythingInvocation(BaseInvocation):
"""Runs a Segment Anything Model."""
# Reference:
# - https://arxiv.org/pdf/2304.02643
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
image: ImageField = InputField(description="The image to segment.")
bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
apply_polygon_refinement: bool = InputField(
description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
default=True,
)
mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
description="The filtering to apply to the detected masks before merging them into a final output.",
default="all",
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> MaskOutput:
# The models expect a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
if len(self.bounding_boxes) == 0:
combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
else:
masks = self._segment(context=context, image=image_pil)
masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
# masks contains bool values, so we merge them via max-reduce.
combined_mask, _ = torch.stack(masks).max(dim=0)
mask_tensor_name = context.tensors.save(combined_mask)
height, width = combined_mask.shape
return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
@staticmethod
def _load_sam_model(model_path: Path):
sam_model = AutoModelForMaskGeneration.from_pretrained(
model_path,
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(sam_model, SamModel)
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
assert isinstance(sam_processor, SamProcessor)
return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
def _segment(
self,
context: InvocationContext,
image: Image.Image,
) -> list[torch.Tensor]:
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
# Convert the bounding boxes to the SAM input format.
sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
with (
context.models.load_remote_model(
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
) as sam_pipeline,
):
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
masks = self._process_masks(masks)
if self.apply_polygon_refinement:
masks = self._apply_polygon_refinement(masks)
return masks
def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
"""Convert the tensor output from the Segment Anything model from a tensor of shape
[num_masks, channels, height, width] to a list of tensors of shape [height, width].
"""
assert masks.dtype == torch.bool
# [num_masks, channels, height, width] -> [num_masks, height, width]
masks, _ = masks.max(dim=1)
# Split the first dimension into a list of masks.
return list(masks.cpu().unbind(dim=0))
def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
"""Apply polygon refinement to the masks.
Convert each mask to a polygon, then back to a mask. This has the following effect:
- Smooth the edges of the mask slightly.
- Ensure that each mask consists of a single closed polygon
- Removes small mask pieces.
- Removes holes from the mask.
"""
# Convert tensor masks to np masks.
np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
# Apply polygon refinement.
for idx, mask in enumerate(np_masks):
shape = mask.shape
assert len(shape) == 2 # Assert length to satisfy type checker.
polygon = mask_to_polygon(mask)
mask = polygon_to_mask(polygon, shape)
np_masks[idx] = mask
# Convert np masks back to tensor masks.
masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
return masks
def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
"""Filter the detected masks based on the specified mask filter."""
assert len(masks) == len(bounding_boxes)
if self.mask_filter == "all":
return masks
elif self.mask_filter == "largest":
# Find the largest mask.
return [max(masks, key=lambda x: float(x.sum()))]
elif self.mask_filter == "highest_box_score":
# Find the index of the bounding box with the highest score.
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
# cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
# reasonable fallback since the expected score range is [0.0, 1.0].
max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
return [masks[max_score_idx]]
else:
raise ValueError(f"Invalid mask filter: {self.mask_filter}")

View File

@ -0,0 +1,22 @@
from pydantic import BaseModel, ConfigDict
class BoundingBox(BaseModel):
"""Bounding box helper class."""
xmin: int
ymin: int
xmax: int
ymax: int
class DetectionResult(BaseModel):
"""Detection result from Grounding DINO."""
score: float
label: str
box: BoundingBox
model_config = ConfigDict(
# Allow arbitrary types for mask, since it will be a numpy array.
arbitrary_types_allowed=True
)

View File

@ -0,0 +1,37 @@
from typing import Optional
import torch
from PIL import Image
from transformers.pipelines import ZeroShotObjectDetectionPipeline
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
from invokeai.backend.raw_model import RawModel
class GroundingDinoPipeline(RawModel):
"""A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
management system.
"""
def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
self._pipeline = pipeline
def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
assert results is not None
results = [DetectionResult.model_validate(result) for result in results]
return results
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
# HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
# CUDA.
if device is not None and device.type not in {"cpu", "cuda"}:
device = None
self._pipeline.model.to(device=device, dtype=dtype)
self._pipeline.device = self._pipeline.model.device
def calc_size(self) -> int:
# HACK(ryand): Fix the circular import issue.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._pipeline.model)

View File

@ -0,0 +1,50 @@
# This file contains utilities for Grounded-SAM mask refinement based on:
# https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
import cv2
import numpy as np
import numpy.typing as npt
def mask_to_polygon(mask: npt.NDArray[np.uint8]) -> list[tuple[int, int]]:
"""Convert a binary mask to a polygon.
Returns:
list[list[int]]: List of (x, y) coordinates representing the vertices of the polygon.
"""
# Find contours in the binary mask.
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Find the contour with the largest area.
largest_contour = max(contours, key=cv2.contourArea)
# Extract the vertices of the contour.
polygon = largest_contour.reshape(-1, 2).tolist()
return polygon
def polygon_to_mask(
polygon: list[tuple[int, int]], image_shape: tuple[int, int], fill_value: int = 1
) -> npt.NDArray[np.uint8]:
"""Convert a polygon to a segmentation mask.
Args:
polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
image_shape (tuple): Shape of the image (height, width) for the mask.
fill_value (int): Value to fill the polygon with.
Returns:
np.ndarray: Segmentation mask with the polygon filled (with value 255).
"""
# Create an empty mask.
mask = np.zeros(image_shape, dtype=np.uint8)
# Convert polygon to an array of points.
pts = np.array(polygon, dtype=np.int32)
# Fill the polygon with white color (255).
cv2.fillPoly(mask, [pts], color=(fill_value,))
return mask

View File

@ -0,0 +1,53 @@
from typing import Optional
import torch
from PIL import Image
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
from invokeai.backend.raw_model import RawModel
class SegmentAnythingPipeline(RawModel):
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
self._sam_model = sam_model
self._sam_processor = sam_processor
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
# HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
if device is not None and device.type not in {"cpu", "cuda"}:
device = None
self._sam_model.to(device=device, dtype=dtype)
def calc_size(self) -> int:
# HACK(ryand): Fix the circular import issue.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._sam_model)
def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor:
"""Run the SAM model.
Args:
image (Image.Image): The image to segment.
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
[xmin, ymin, xmax, ymax].
Returns:
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
"""
# Add batch dimension of 1 to the bounding boxes.
boxes = [bounding_boxes]
inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
outputs = self._sam_model(**inputs)
masks = self._sam_processor.post_process_masks(
masks=outputs.pred_masks,
original_sizes=inputs.original_sizes,
reshaped_input_sizes=inputs.reshaped_input_sizes,
)
# There should be only one batch.
assert len(masks) == 1
return masks[0]

View File

@ -3,12 +3,13 @@
import bisect
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
from safetensors.torch import load_file
from typing_extensions import Self
import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.raw_model import RawModel
@ -46,9 +47,19 @@ class LoRALayerBase:
self.rank = None # set in layer implementation
self.layer_key = layer_key
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias
return params
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
@ -60,6 +71,17 @@ class LoRALayerBase:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys:
logger.warning(
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
@ -76,14 +98,19 @@ class LoRALayer(LoRALayerBase):
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
else:
self.mid = None
self.mid = values.get("lora_mid.weight", None)
self.rank = self.down.shape[0]
self.check_keys(
values,
{
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
},
)
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@ -125,20 +152,23 @@ class LoHALayer(LoRALayerBase):
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1: Optional[torch.Tensor] = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2: Optional[torch.Tensor] = values["hada_t2"]
else:
self.t2 = None
self.t1 = values.get("hada_t1", None)
self.t2 = values.get("hada_t2", None)
self.rank = self.w1_b.shape[0]
self.check_keys(
values,
{
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
},
)
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
@ -186,37 +216,39 @@ class LoKRLayer(LoRALayerBase):
):
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1 = values.get("lokr_w1", None)
if self.w1 is None:
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2 = values.get("lokr_w2", None)
if self.w2 is None:
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
else:
self.t2 = None
self.t2 = values.get("lokr_t2", None)
if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
if self.w1_b is not None:
self.rank = self.w1_b.shape[0]
elif self.w2_b is not None:
self.rank = self.w2_b.shape[0]
else:
self.rank = None # unscaled
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
self.check_keys(
values,
{
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
@ -272,7 +304,9 @@ class LoKRLayer(LoRALayerBase):
class FullLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
@ -282,15 +316,12 @@ class FullLayer(LoRALayerBase):
super().__init__(layer_key, values)
self.weight = values["diff"]
if len(values.keys()) > 1:
_keys = list(values.keys())
_keys.remove("diff")
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
self.bias = values.get("diff_b", None)
self.rank = None # unscaled
self.check_keys(values, {"diff", "diff_b"})
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
@ -319,8 +350,9 @@ class IA3Layer(LoRALayerBase):
self.on_input = values["on_input"]
self.rank = None # unscaled
self.check_keys(values, {"weight", "on_input"})
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
@ -458,16 +490,19 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items():
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
# lora and locon
if "lora_down.weight" in values:
if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
elif "hada_w1_a" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
elif "lokr_w1" in values or "lokr_w1_a" in values:
layer = LoKRLayer(layer_key, values)
# diff
@ -475,7 +510,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
layer = FullLayer(layer_key, values)
# ia3
elif "weight" in values and "on_input" in values:
elif "on_input" in values:
layer = IA3Layer(layer_key, values)
else:

View File

@ -12,6 +12,8 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel
@ -36,7 +38,16 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
return 0
elif isinstance(
model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel, DepthAnythingPipeline)
model,
(
TextualInversionModelRaw,
IPAdapter,
LoRAModelRaw,
SpandrelImageToImageModel,
GroundingDinoPipeline,
SegmentAnythingPipeline,
DepthAnythingPipeline,
),
):
return model.calc_size()
else:

View File

@ -17,8 +17,9 @@ from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
"""
loras = [
@ -85,13 +86,13 @@ class ModelPatcher:
cls,
unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
with cls.apply_lora(
unet,
loras=loras,
prefix="lora_unet_",
model_state_dict=model_state_dict,
cached_weights=cached_weights,
):
yield
@ -101,9 +102,9 @@ class ModelPatcher:
cls,
text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
yield
@classmethod
@ -113,7 +114,7 @@ class ModelPatcher:
model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
"""
Apply one or more LoRAs to a model.
@ -121,66 +122,26 @@ class ModelPatcher:
:param model: The model to patch.
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
"""
original_weights = {}
original_weights = OriginalWeightsStorage(cached_weights)
try:
with torch.no_grad():
for lora, lora_weight in loras:
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
for lora_model, lora_weight in loras:
LoRAExt.patch_model(
model=model,
prefix=prefix,
lora=lora_model,
lora_weight=lora_weight,
original_weights=original_weights,
)
del lora_model
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
assert isinstance(model, torch.nn.Module)
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
if module_key not in original_weights:
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
original_weights[module_key] = model_state_dict[module_key + ".weight"]
else:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device=TorchDevice.CPU_DEVICE)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
assert hasattr(layer_weight, "reshape")
layer_weight = layer_weight.reshape(module.weight.shape)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
module.weight += layer_weight.to(dtype=dtype)
yield # wait for context manager exit
yield
finally:
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
@classmethod
@contextmanager

View File

@ -2,14 +2,14 @@ from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Callable, Dict, List
import torch
from diffusers import UNet2DConditionModel
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
@dataclass
@ -56,5 +56,17 @@ class ExtensionBase:
yield None
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
yield None
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
"""A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
diffusion process. Weight unpatching is handled upstream, and is achieved by saving unchanged weights by
`original_weights.save` function. Note that this enables some performance optimization by avoiding redundant
operations. All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched
by this context manager.
Args:
unet (UNet2DConditionModel): The UNet model on execution device to patch.
original_weights (OriginalWeightsStorage): A storage with copy of the model's original weights in CPU, for
unpatching purposes. Extension should save tensor which being modified in this storage, also extensions
can access original weights values.
"""
yield

View File

@ -1,15 +1,15 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING
import torch
from diffusers import UNet2DConditionModel
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
if TYPE_CHECKING:
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
class FreeUExt(ExtensionBase):
@ -21,7 +21,7 @@ class FreeUExt(ExtensionBase):
self._freeu_config = freeu_config
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
unet.enable_freeu(
b1=self._freeu_config.b1,
b2=self._freeu_config.b2,

View File

@ -0,0 +1,137 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, Tuple
import torch
from diffusers import UNet2DConditionModel
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
class LoRAExt(ExtensionBase):
def __init__(
self,
node_context: InvocationContext,
model_id: ModelIdentifierField,
weight: float,
):
super().__init__()
self._node_context = node_context
self._model_id = model_id
self._weight = weight
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
self.patch_model(
model=unet,
prefix="lora_unet_",
lora=lora_model,
lora_weight=self._weight,
original_weights=original_weights,
)
del lora_model
yield
@classmethod
@torch.no_grad()
def patch_model(
cls,
model: torch.nn.Module,
prefix: str,
lora: LoRAModelRaw,
lora_weight: float,
original_weights: OriginalWeightsStorage,
):
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param lora: LoRA model to patch in.
:param lora_weight: LoRA patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
"""
if lora_weight == 0:
return
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
assert isinstance(model, torch.nn.Module)
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, lora_param_weight in layer.get_parameters(module).items():
param_key = module_key + "." + param_name
module_param = module.get_parameter(param_name)
# save original weight
original_weights.save(param_key, module_param)
if module_param.shape != lora_param_weight.shape:
# TODO: debug on lycoris
lora_param_weight = lora_param_weight.reshape(module_param.shape)
lora_param_weight *= lora_weight * layer_scale
module_param += lora_param_weight.to(dtype=dtype)
layer.to(device=TorchDevice.CPU_DEVICE)
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
if not lora_key.startswith(prefix):
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
module = model
module_key = ""
key_parts = lora_key[len(prefix) :].split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)

View File

@ -7,6 +7,7 @@ import torch
from diffusers import UNet2DConditionModel
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
@ -67,9 +68,15 @@ class ExtensionsManager:
if self._is_canceled and self._is_canceled():
raise CanceledException
# TODO: create weight patch logic in PR with extension which uses it
with ExitStack() as exit_stack:
for ext in self._extensions:
exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
original_weights = OriginalWeightsStorage(cached_weights)
try:
with ExitStack() as exit_stack:
for ext in self._extensions:
exit_stack.enter_context(ext.patch_unet(unet, original_weights))
yield None
yield None
finally:
with torch.no_grad():
for param_key, weight in original_weights.get_changed_weights():
unet.get_parameter(param_key).copy_(weight)

View File

@ -20,10 +20,14 @@ from diffusers import (
)
from diffusers.schedulers.scheduling_utils import SchedulerMixin
# TODO: add dpmpp_3s/dpmpp_3s_k when fix released
# https://github.com/huggingface/diffusers/issues/9007
SCHEDULER_NAME_VALUES = Literal[
"ddim",
"ddpm",
"deis",
"deis_k",
"lms",
"lms_k",
"pndm",
@ -33,16 +37,21 @@ SCHEDULER_NAME_VALUES = Literal[
"euler_k",
"euler_a",
"kdpm_2",
"kdpm_2_k",
"kdpm_2_a",
"kdpm_2_a_k",
"dpmpp_2s",
"dpmpp_2s_k",
"dpmpp_2m",
"dpmpp_2m_k",
"dpmpp_2m_sde",
"dpmpp_2m_sde_k",
"dpmpp_3m",
"dpmpp_3m_k",
"dpmpp_sde",
"dpmpp_sde_k",
"unipc",
"unipc_k",
"lcm",
"tcd",
]
@ -50,7 +59,8 @@ SCHEDULER_NAME_VALUES = Literal[
SCHEDULER_MAP: dict[SCHEDULER_NAME_VALUES, tuple[Type[SchedulerMixin], dict[str, Any]]] = {
"ddim": (DDIMScheduler, {}),
"ddpm": (DDPMScheduler, {}),
"deis": (DEISMultistepScheduler, {}),
"deis": (DEISMultistepScheduler, {"use_karras_sigmas": False}),
"deis_k": (DEISMultistepScheduler, {"use_karras_sigmas": True}),
"lms": (LMSDiscreteScheduler, {"use_karras_sigmas": False}),
"lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
"pndm": (PNDMScheduler, {}),
@ -59,17 +69,28 @@ SCHEDULER_MAP: dict[SCHEDULER_NAME_VALUES, tuple[Type[SchedulerMixin], dict[str,
"euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
"euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
"euler_a": (EulerAncestralDiscreteScheduler, {}),
"kdpm_2": (KDPM2DiscreteScheduler, {}),
"kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
"dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
"dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
"dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False}),
"dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
"dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
"dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
"kdpm_2": (KDPM2DiscreteScheduler, {"use_karras_sigmas": False}),
"kdpm_2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
"kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": False}),
"kdpm_2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
"dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False, "solver_order": 2}),
"dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, "solver_order": 2}),
"dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "solver_order": 2}),
"dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 2}),
"dpmpp_2m_sde": (
DPMSolverMultistepScheduler,
{"use_karras_sigmas": False, "solver_order": 2, "algorithm_type": "sde-dpmsolver++"},
),
"dpmpp_2m_sde_k": (
DPMSolverMultistepScheduler,
{"use_karras_sigmas": True, "solver_order": 2, "algorithm_type": "sde-dpmsolver++"},
),
"dpmpp_3m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "solver_order": 3}),
"dpmpp_3m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 3}),
"dpmpp_sde": (DPMSolverSDEScheduler, {"use_karras_sigmas": False, "noise_sampler_seed": 0}),
"dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}),
"unipc": (UniPCMultistepScheduler, {"cpu_only": True}),
"unipc": (UniPCMultistepScheduler, {"use_karras_sigmas": False, "cpu_only": True}),
"unipc_k": (UniPCMultistepScheduler, {"use_karras_sigmas": True, "cpu_only": True}),
"lcm": (LCMScheduler, {}),
"tcd": (TCDScheduler, {}),
}

View File

@ -0,0 +1,39 @@
from __future__ import annotations
from typing import Dict, Iterator, Optional, Tuple
import torch
from invokeai.backend.util.devices import TorchDevice
class OriginalWeightsStorage:
"""A class for tracking the original weights of a model for patch/unpatch operations."""
def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
# The original weights of the model.
self._weights: dict[str, torch.Tensor] = {}
# The keys of the weights that have been changed (via `save()`) during the lifetime of this instance.
self._changed_weights: set[str] = set()
if cached_weights:
self._weights.update(cached_weights)
def save(self, key: str, weight: torch.Tensor, copy: bool = True):
self._changed_weights.add(key)
if key in self._weights:
return
self._weights[key] = weight.detach().to(device=TorchDevice.CPU_DEVICE, copy=copy)
def get(self, key: str, copy: bool = False) -> Optional[torch.Tensor]:
weight = self._weights.get(key, None)
if weight is not None and copy:
weight = weight.clone()
return weight
def contains(self, key: str) -> bool:
return key in self._weights
def get_changed_weights(self) -> Iterator[Tuple[str, torch.Tensor]]:
for key in self._changed_weights:
yield key, self._weights[key]

View File

@ -16,6 +16,8 @@ import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterM
import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors';
import InvokeTabs from 'features/ui/components/InvokeTabs';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { AnimatePresence } from 'framer-motion';
import i18n from 'i18n';
import { size } from 'lodash-es';
@ -34,9 +36,10 @@ interface Props {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
destination?: InvokeTabName | undefined;
}
const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
const App = ({ config = DEFAULT_CONFIG, selectedImage, destination }: Props) => {
const language = useAppSelector(languageSelector);
const logger = useLogger('system');
const dispatch = useAppDispatch();
@ -67,6 +70,12 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
}
}, [dispatch, config, logger]);
useEffect(() => {
if (destination) {
dispatch(setActiveTab(destination));
}
}, [dispatch, destination]);
useEffect(() => {
dispatch(appStarted());
}, [dispatch]);

View File

@ -19,6 +19,7 @@ import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useMemo } from 'react';
import { Provider } from 'react-redux';
@ -43,6 +44,7 @@ interface Props extends PropsWithChildren {
imageName: string;
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
destination?: InvokeTabName;
customStarUi?: CustomStarUi;
socketOptions?: Partial<ManagerOptions & SocketOptions>;
isDebugging?: boolean;
@ -62,6 +64,7 @@ const InvokeAIUI = ({
projectUrl,
queueId,
selectedImage,
destination,
customStarUi,
socketOptions,
isDebugging = false,
@ -218,7 +221,7 @@ const InvokeAIUI = ({
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<AppDndContext>
<App config={config} selectedImage={selectedImage} />
<App config={config} selectedImage={selectedImage} destination={destination} />
</AppDndContext>
</ThemeLocaleProvider>
</React.Suspense>

View File

@ -32,6 +32,7 @@ export const zSchedulerField = z.enum([
'ddpm',
'dpmpp_2s',
'dpmpp_2m',
'dpmpp_3m',
'dpmpp_2m_sde',
'dpmpp_sde',
'heun',
@ -40,12 +41,17 @@ export const zSchedulerField = z.enum([
'pndm',
'unipc',
'euler_k',
'deis_k',
'dpmpp_2s_k',
'dpmpp_2m_k',
'dpmpp_3m_k',
'dpmpp_2m_sde_k',
'dpmpp_sde_k',
'heun_k',
'kdpm_2_k',
'kdpm_2_a_k',
'lms_k',
'unipc_k',
'euler_a',
'kdpm_2_a',
'lcm',

View File

@ -125,19 +125,11 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
g.addEdge(modelNode, 'unet', tiledMultidiffusionNode, 'unet');
addSDXLLoRas(state, g, tiledMultidiffusionNode, modelNode, null, posCondNode, negCondNode);
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
g.upsertMetadata({
cfg_scale,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
model: Graph.getModelMetadataField(modelConfig),
seed,
steps,
scheduler,
vae: vae ?? undefined,
});
} else {
posCondNode = g.addNode({
@ -166,24 +158,33 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
g.addEdge(modelNode, 'unet', tiledMultidiffusionNode, 'unet');
addLoRAs(state, g, tiledMultidiffusionNode, modelNode, null, clipSkipNode, posCondNode, negCondNode);
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
const upscaleModelConfig = await fetchModelConfigWithTypeGuard(upscaleModel.key, isSpandrelImageToImageModelConfig);
g.upsertMetadata({
cfg_scale,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: Graph.getModelMetadataField(modelConfig),
seed,
steps,
scheduler,
vae: vae ?? undefined,
upscale_model: Graph.getModelMetadataField(upscaleModelConfig),
creativity,
structure,
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
const upscaleModelConfig = await fetchModelConfigWithTypeGuard(upscaleModel.key, isSpandrelImageToImageModelConfig);
g.upsertMetadata({
cfg_scale,
model: Graph.getModelMetadataField(modelConfig),
seed,
steps,
scheduler,
vae: vae ?? undefined,
upscale_model: Graph.getModelMetadataField(upscaleModelConfig),
creativity,
structure,
upscale_initial_image: {
image_name: upscaleInitialImage.image_name,
width: upscaleInitialImage.width,
height: upscaleInitialImage.height,
},
upscale_scale: scale,
});
g.setMetadataReceivingNode(l2iNode);
g.addEdgeToMetadata(upscaleNode, 'width', 'width');
g.addEdgeToMetadata(upscaleNode, 'height', 'height');

View File

@ -52,28 +52,34 @@ export const CLIP_SKIP_MAP = {
* Mapping of schedulers to human readable name
*/
export const SCHEDULER_OPTIONS: ComboboxOption[] = [
{ value: 'euler', label: 'Euler' },
{ value: 'deis', label: 'DEIS' },
{ value: 'ddim', label: 'DDIM' },
{ value: 'ddpm', label: 'DDPM' },
{ value: 'dpmpp_sde', label: 'DPM++ SDE' },
{ value: 'deis', label: 'DEIS' },
{ value: 'deis_k', label: 'DEIS Karras' },
{ value: 'dpmpp_2s', label: 'DPM++ 2S' },
{ value: 'dpmpp_2m', label: 'DPM++ 2M' },
{ value: 'dpmpp_2m_sde', label: 'DPM++ 2M SDE' },
{ value: 'heun', label: 'Heun' },
{ value: 'kdpm_2', label: 'KDPM 2' },
{ value: 'lms', label: 'LMS' },
{ value: 'pndm', label: 'PNDM' },
{ value: 'unipc', label: 'UniPC' },
{ value: 'euler_k', label: 'Euler Karras' },
{ value: 'dpmpp_sde_k', label: 'DPM++ SDE Karras' },
{ value: 'dpmpp_2s_k', label: 'DPM++ 2S Karras' },
{ value: 'dpmpp_2m', label: 'DPM++ 2M' },
{ value: 'dpmpp_2m_k', label: 'DPM++ 2M Karras' },
{ value: 'dpmpp_2m_sde', label: 'DPM++ 2M SDE' },
{ value: 'dpmpp_2m_sde_k', label: 'DPM++ 2M SDE Karras' },
{ value: 'heun_k', label: 'Heun Karras' },
{ value: 'lms_k', label: 'LMS Karras' },
{ value: 'dpmpp_3m', label: 'DPM++ 3M' },
{ value: 'dpmpp_3m_k', label: 'DPM++ 3M Karras' },
{ value: 'dpmpp_sde', label: 'DPM++ SDE' },
{ value: 'dpmpp_sde_k', label: 'DPM++ SDE Karras' },
{ value: 'euler', label: 'Euler' },
{ value: 'euler_k', label: 'Euler Karras' },
{ value: 'euler_a', label: 'Euler Ancestral' },
{ value: 'heun', label: 'Heun' },
{ value: 'heun_k', label: 'Heun Karras' },
{ value: 'kdpm_2', label: 'KDPM 2' },
{ value: 'kdpm_2_k', label: 'KDPM 2 Karras' },
{ value: 'kdpm_2_a', label: 'KDPM 2 Ancestral' },
{ value: 'kdpm_2_a_k', label: 'KDPM 2 Ancestral Karras' },
{ value: 'lcm', label: 'LCM' },
{ value: 'lms', label: 'LMS' },
{ value: 'lms_k', label: 'LMS Karras' },
{ value: 'pndm', label: 'PNDM' },
{ value: 'tcd', label: 'TCD' },
].sort((a, b) => a.label.localeCompare(b.label));
{ value: 'unipc', label: 'UniPC' },
{ value: 'unipc_k', label: 'UniPC Karras' },
];

View File

@ -3553,7 +3553,7 @@ export type components = {
* @default euler
* @enum {string}
*/
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
scheduler?: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
/**
* UNet
* @description UNet (scheduler, LoRAs)
@ -8553,7 +8553,7 @@ export type components = {
* Scheduler
* @description Default scheduler for this model
*/
scheduler?: ("ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd") | null;
scheduler?: ("ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd") | null;
/**
* Steps
* @description Default number of steps for this model
@ -11467,7 +11467,7 @@ export type components = {
* @default euler
* @enum {string}
*/
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
scheduler?: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
/**
* type
* @default scheduler
@ -11483,7 +11483,7 @@ export type components = {
* @description Scheduler to use during inference
* @enum {string}
*/
scheduler: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
scheduler: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
/**
* type
* @default scheduler_output
@ -13261,7 +13261,7 @@ export type components = {
* @default euler
* @enum {string}
*/
scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
scheduler?: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
/**
* UNet
* @description UNet (scheduler, LoRAs)