Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 4f8a4b0f22: Merge branch 'main' into depth_anything_v2
@@ -80,12 +80,12 @@ class CompelInvocation(BaseInvocation):
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
+            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
             ModelPatcher.apply_lora_text_encoder(
                 text_encoder,
                 loras=_lora_loader(),
-                model_state_dict=model_state_dict,
+                cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
@@ -175,13 +175,13 @@ class SDXLPromptInvocationBase:
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (state_dict, text_encoder),
+            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
             ModelPatcher.apply_lora(
                 text_encoder,
                 loras=_lora_loader(),
                 prefix=lora_prefix,
-                model_state_dict=state_dict,
+                cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
@@ -39,7 +39,7 @@ class GradientMaskOutput(BaseInvocationOutput):
     title="Create Gradient Mask",
     tags=["mask", "denoise"],
     category="latents",
-    version="1.1.0",
+    version="1.2.0",
 )
 class CreateGradientMaskInvocation(BaseInvocation):
     """Creates mask for denoising model run."""
@@ -93,6 +93,7 @@ class CreateGradientMaskInvocation(BaseInvocation):

             # redistribute blur so that the original edges are 0 and blur outwards to 1
             blur_tensor = (blur_tensor - 0.5) * 2
+            blur_tensor[blur_tensor < 0] = 0.0

             threshold = 1 - self.minimum_denoise

@@ -62,6 +62,7 @@ from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetEx
 from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
 from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
+from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
 from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
@@ -845,6 +846,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if self.unet.freeu_config:
             ext_manager.add_extension(FreeUExt(self.unet.freeu_config))

+        ### lora
+        if self.unet.loras:
+            for lora_field in self.unet.loras:
+                ext_manager.add_extension(
+                    LoRAExt(
+                        node_context=context,
+                        model_id=lora_field.lora,
+                        weight=lora_field.weight,
+                    )
+                )
         ### seamless
         if self.unet.seamless_axes:
             ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
@@ -964,14 +975,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
-            unet_info.model_on_device() as (model_state_dict, unet),
+            unet_info.model_on_device() as (cached_weights, unet),
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(
                 unet,
                 loras=_lora_loader(),
-                model_state_dict=model_state_dict,
+                cached_weights=cached_weights,
             ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Any, Callable, Optional, Tuple

-from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
+from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
 from pydantic.fields import _Unset
 from pydantic_core import PydanticUndefined

@@ -242,6 +242,31 @@ class ConditioningField(BaseModel):
     )


+class BoundingBoxField(BaseModel):
+    """A bounding box primitive value."""
+
+    x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
+    x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
+    y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
+    y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
+
+    score: Optional[float] = Field(
+        default=None,
+        ge=0.0,
+        le=1.0,
+        description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
+        "when the bounding box was produced by a detector and has an associated confidence score.",
+    )
+
+    @model_validator(mode="after")
+    def check_coords(self):
+        if self.x_min > self.x_max:
+            raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
+        if self.y_min > self.y_max:
+            raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
+        return self
+
+
 class MetadataField(RootModel[dict[str, Any]]):
     """
     Pydantic model for metadata with custom root of type dict[str, Any].
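Example (not part of the diff): a minimal re-creation of how the new coordinate validator behaves, assuming pydantic v2. The class name `Box` is illustrative only; the real field type is `BoundingBoxField` above.

```python
# Minimal sketch of the BoundingBoxField validation behaviour (pydantic v2 assumed).
from typing import Optional

from pydantic import BaseModel, Field, ValidationError, model_validator


class Box(BaseModel):
    x_min: int = Field(ge=0)
    x_max: int = Field(ge=0)
    y_min: int = Field(ge=0)
    y_max: int = Field(ge=0)
    score: Optional[float] = Field(default=None, ge=0.0, le=1.0)

    @model_validator(mode="after")
    def check_coords(self):
        # Reject boxes whose min coordinate exceeds the max coordinate.
        if self.x_min > self.x_max or self.y_min > self.y_max:
            raise ValueError("min coordinate must not exceed max coordinate")
        return self


Box(x_min=10, x_max=120, y_min=5, y_max=80, score=0.9)  # valid
try:
    Box(x_min=50, x_max=10, y_min=0, y_max=10)  # x_min > x_max -> rejected
except ValidationError as e:
    print(e)
```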
invokeai/app/invocations/grounding_dino.py (new file, 100 lines)
@@ -0,0 +1,100 @@
+from pathlib import Path
+from typing import Literal
+
+import torch
+from PIL import Image
+from transformers import pipeline
+from transformers.pipelines import ZeroShotObjectDetectionPipeline
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
+from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
+from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
+
+GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
+GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
+    "grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
+    "grounding-dino-base": "IDEA-Research/grounding-dino-base",
+}
+
+
+@invocation(
+    "grounding_dino",
+    title="Grounding DINO (Text Prompt Object Detection)",
+    tags=["prompt", "object detection"],
+    category="image",
+    version="1.0.0",
+)
+class GroundingDinoInvocation(BaseInvocation):
+    """Runs a Grounding DINO model. Performs zero-shot bounding-box object detection from a text prompt."""
+
+    # Reference:
+    # - https://arxiv.org/pdf/2303.05499
+    # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
+    # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
+
+    model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
+    prompt: str = InputField(description="The prompt describing the object to segment.")
+    image: ImageField = InputField(description="The image to segment.")
+    detection_threshold: float = InputField(
+        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
+        ge=0.0,
+        le=1.0,
+        default=0.3,
+    )
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
+        # The model expects a 3-channel RGB image.
+        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        detections = self._detect(
+            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
+        )
+
+        # Convert detections to BoundingBoxCollectionOutput.
+        bounding_boxes: list[BoundingBoxField] = []
+        for detection in detections:
+            bounding_boxes.append(
+                BoundingBoxField(
+                    x_min=detection.box.xmin,
+                    x_max=detection.box.xmax,
+                    y_min=detection.box.ymin,
+                    y_max=detection.box.ymax,
+                    score=detection.score,
+                )
+            )
+        return BoundingBoxCollectionOutput(collection=bounding_boxes)
+
+    @staticmethod
+    def _load_grounding_dino(model_path: Path):
+        grounding_dino_pipeline = pipeline(
+            model=str(model_path),
+            task="zero-shot-object-detection",
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
+        return GroundingDinoPipeline(grounding_dino_pipeline)
+
+    def _detect(
+        self,
+        context: InvocationContext,
+        image: Image.Image,
+        labels: list[str],
+        threshold: float = 0.3,
+    ) -> list[DetectionResult]:
+        """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
+        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
+        # actually makes a difference.
+        labels = [label if label.endswith(".") else label + "." for label in labels]
+
+        with context.models.load_remote_model(
+            source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
+        ) as detector:
+            assert isinstance(detector, GroundingDinoPipeline)
+            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
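Example (not part of the diff): a hedged sketch of the underlying data flow. The transformers zero-shot-object-detection pipeline returns a list of dicts with `score`, `label`, and `box`, which the wrapper later validates into `DetectionResult` objects. This bypasses the InvokeAI model manager; the model id and threshold mirror the values above, everything else is illustrative.

```python
# Hedged sketch: raw transformers usage of the same Grounding DINO checkpoint.
from PIL import Image
from transformers import pipeline

detector = pipeline(task="zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny")
image = Image.open("example.jpg").convert("RGB")  # placeholder image path

# Grounding DINO expects each label to end with "." (see the TODO in _detect above).
results = detector(image=image, candidate_labels=["a cat."], threshold=0.3)

# Each result is a dict like:
# {"score": 0.41, "label": "a cat.", "box": {"xmin": 10, "ymin": 22, "xmax": 180, "ymax": 240}}
for r in results:
    print(r["score"], r["box"])
```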
@@ -1,9 +1,10 @@
 import numpy as np
 import torch
+from PIL import Image

 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
-from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
-from invokeai.app.invocations.primitives import MaskOutput
+from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
+from invokeai.app.invocations.primitives import ImageOutput, MaskOutput


 @invocation(
@@ -118,3 +119,27 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
             height=mask.shape[1],
             width=mask.shape[2],
         )
+
+
+@invocation(
+    "tensor_mask_to_image",
+    title="Tensor Mask to Image",
+    tags=["mask"],
+    category="mask",
+    version="1.0.0",
+)
+class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Convert a mask tensor to an image."""
+
+    mask: TensorField = InputField(description="The mask tensor to convert.")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        mask = context.tensors.load(self.mask.tensor_name)
+        # Ensure that the mask is binary.
+        if mask.dtype != torch.bool:
+            mask = mask > 0.5
+        mask_np = (mask.float() * 255).byte().cpu().numpy()
+
+        mask_pil = Image.fromarray(mask_np, mode="L")
+        image_dto = context.images.save(image=mask_pil)
+        return ImageOutput.build(image_dto)
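Example (not part of the diff): the conversion in MaskTensorToImageInvocation is a bool-tensor to 8-bit grayscale mapping. A standalone sketch of the same round trip, using a synthetic mask:

```python
# Standalone sketch of the bool-mask -> PIL "L" image conversion used above (illustrative only).
import torch
from PIL import Image

mask = torch.zeros(64, 64, dtype=torch.bool)
mask[16:48, 16:48] = True

# Non-bool inputs are binarized at 0.5, exactly as in the invocation.
if mask.dtype != torch.bool:
    mask = mask > 0.5

mask_np = (mask.float() * 255).byte().cpu().numpy()  # values 0 or 255, dtype uint8
mask_pil = Image.fromarray(mask_np, mode="L")
assert mask_pil.size == (64, 64)
```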
@@ -7,6 +7,7 @@ import torch
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
+    BoundingBoxField,
     ColorField,
     ConditioningField,
     DenoiseMaskField,
@@ -469,3 +470,42 @@ class ConditioningCollectionInvocation(BaseInvocation):


 # endregion
+
+# region BoundingBox
+
+
+@invocation_output("bounding_box_output")
+class BoundingBoxOutput(BaseInvocationOutput):
+    """Base class for nodes that output a single bounding box"""
+
+    bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
+
+
+@invocation_output("bounding_box_collection_output")
+class BoundingBoxCollectionOutput(BaseInvocationOutput):
+    """Base class for nodes that output a collection of bounding boxes"""
+
+    collection: list[BoundingBoxField] = OutputField(description="The output bounding boxes.", title="Bounding Boxes")
+
+
+@invocation(
+    "bounding_box",
+    title="Bounding Box",
+    tags=["primitives", "segmentation", "collection", "bounding box"],
+    category="primitives",
+    version="1.0.0",
+)
+class BoundingBoxInvocation(BaseInvocation):
+    """Create a bounding box manually by supplying box coordinates"""
+
+    x_min: int = InputField(default=0, description="x-coordinate of the bounding box's top left vertex")
+    y_min: int = InputField(default=0, description="y-coordinate of the bounding box's top left vertex")
+    x_max: int = InputField(default=0, description="x-coordinate of the bounding box's bottom right vertex")
+    y_max: int = InputField(default=0, description="y-coordinate of the bounding box's bottom right vertex")
+
+    def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
+        bounding_box = BoundingBoxField(x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max)
+        return BoundingBoxOutput(bounding_box=bounding_box)
+
+
+# endregion
invokeai/app/invocations/segment_anything.py (new file, 161 lines)
@@ -0,0 +1,161 @@
+from pathlib import Path
+from typing import Literal
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoModelForMaskGeneration, AutoProcessor
+from transformers.models.sam import SamModel
+from transformers.models.sam.processing_sam import SamProcessor
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
+from invokeai.app.invocations.primitives import MaskOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
+from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
+
+SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
+SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
+    "segment-anything-base": "facebook/sam-vit-base",
+    "segment-anything-large": "facebook/sam-vit-large",
+    "segment-anything-huge": "facebook/sam-vit-huge",
+}
+
+
+@invocation(
+    "segment_anything",
+    title="Segment Anything",
+    tags=["prompt", "segmentation"],
+    category="segmentation",
+    version="1.0.0",
+)
+class SegmentAnythingInvocation(BaseInvocation):
+    """Runs a Segment Anything Model."""
+
+    # Reference:
+    # - https://arxiv.org/pdf/2304.02643
+    # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
+    # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
+
+    model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
+    image: ImageField = InputField(description="The image to segment.")
+    bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
+    apply_polygon_refinement: bool = InputField(
+        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
+        default=True,
+    )
+    mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
+        description="The filtering to apply to the detected masks before merging them into a final output.",
+        default="all",
+    )
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> MaskOutput:
+        # The models expect a 3-channel RGB image.
+        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        if len(self.bounding_boxes) == 0:
+            combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
+        else:
+            masks = self._segment(context=context, image=image_pil)
+            masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
+
+            # masks contains bool values, so we merge them via max-reduce.
+            combined_mask, _ = torch.stack(masks).max(dim=0)
+
+        mask_tensor_name = context.tensors.save(combined_mask)
+        height, width = combined_mask.shape
+        return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
+
+    @staticmethod
+    def _load_sam_model(model_path: Path):
+        sam_model = AutoModelForMaskGeneration.from_pretrained(
+            model_path,
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        assert isinstance(sam_model, SamModel)
+
+        sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
+        assert isinstance(sam_processor, SamProcessor)
+        return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
+
+    def _segment(
+        self,
+        context: InvocationContext,
+        image: Image.Image,
+    ) -> list[torch.Tensor]:
+        """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
+        # Convert the bounding boxes to the SAM input format.
+        sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
+
+        with (
+            context.models.load_remote_model(
+                source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
+            ) as sam_pipeline,
+        ):
+            assert isinstance(sam_pipeline, SegmentAnythingPipeline)
+            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
+
+        masks = self._process_masks(masks)
+        if self.apply_polygon_refinement:
+            masks = self._apply_polygon_refinement(masks)
+
+        return masks
+
+    def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
+        """Convert the tensor output from the Segment Anything model from a tensor of shape
+        [num_masks, channels, height, width] to a list of tensors of shape [height, width].
+        """
+        assert masks.dtype == torch.bool
+        # [num_masks, channels, height, width] -> [num_masks, height, width]
+        masks, _ = masks.max(dim=1)
+        # Split the first dimension into a list of masks.
+        return list(masks.cpu().unbind(dim=0))
+
+    def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
+        """Apply polygon refinement to the masks.
+
+        Convert each mask to a polygon, then back to a mask. This has the following effect:
+        - Smooth the edges of the mask slightly.
+        - Ensure that each mask consists of a single closed polygon
+        - Removes small mask pieces.
+        - Removes holes from the mask.
+        """
+        # Convert tensor masks to np masks.
+        np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
+
+        # Apply polygon refinement.
+        for idx, mask in enumerate(np_masks):
+            shape = mask.shape
+            assert len(shape) == 2  # Assert length to satisfy type checker.
+            polygon = mask_to_polygon(mask)
+            mask = polygon_to_mask(polygon, shape)
+            np_masks[idx] = mask
+
+        # Convert np masks back to tensor masks.
+        masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
+
+        return masks
+
+    def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
+        """Filter the detected masks based on the specified mask filter."""
+        assert len(masks) == len(bounding_boxes)
+
+        if self.mask_filter == "all":
+            return masks
+        elif self.mask_filter == "largest":
+            # Find the largest mask.
+            return [max(masks, key=lambda x: float(x.sum()))]
+        elif self.mask_filter == "highest_box_score":
+            # Find the index of the bounding box with the highest score.
+            # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
+            # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
+            # reasonable fallback since the expected score range is [0.0, 1.0].
+            max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
+            return [masks[max_score_idx]]
+        else:
+            raise ValueError(f"Invalid mask filter: {self.mask_filter}")
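Example (not part of the diff): together, the two new invocations implement the Grounded-SAM recipe from the referenced notebook: Grounding DINO produces text-prompted boxes, SAM turns the boxes into boolean masks, and the masks are merged with a max-reduce. A small sketch of just the merge step in plain torch (the shapes and dummy masks are made up for the example):

```python
# Sketch of the per-box mask merge used by SegmentAnythingInvocation (max-reduce over bool masks).
import torch


def merge_masks(masks: list[torch.Tensor], image_hw: tuple[int, int]) -> torch.Tensor:
    """Merge per-box bool masks of shape [H, W] into one [H, W] mask, mirroring the invocation."""
    if len(masks) == 0:
        # No boxes -> empty mask, as in the invocation's zero-box branch.
        return torch.zeros(image_hw, dtype=torch.bool)
    combined, _ = torch.stack(masks).max(dim=0)
    return combined


# Two dummy 2x2 masks: the merge is an element-wise OR.
m1 = torch.tensor([[True, False], [False, False]])
m2 = torch.tensor([[False, False], [False, True]])
print(merge_masks([m1, m2], (2, 2)))
```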
invokeai/backend/image_util/grounding_dino/detection_result.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+from pydantic import BaseModel, ConfigDict
+
+
+class BoundingBox(BaseModel):
+    """Bounding box helper class."""
+
+    xmin: int
+    ymin: int
+    xmax: int
+    ymax: int
+
+
+class DetectionResult(BaseModel):
+    """Detection result from Grounding DINO."""
+
+    score: float
+    label: str
+    box: BoundingBox
+    model_config = ConfigDict(
+        # Allow arbitrary types for mask, since it will be a numpy array.
+        arbitrary_types_allowed=True
+    )
invokeai/backend/image_util/grounding_dino/grounding_dino_pipeline.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+from PIL import Image
+from transformers.pipelines import ZeroShotObjectDetectionPipeline
+
+from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
+from invokeai.backend.raw_model import RawModel
+
+
+class GroundingDinoPipeline(RawModel):
+    """A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
+    management system.
+    """
+
+    def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
+        self._pipeline = pipeline
+
+    def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
+        results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
+        assert results is not None
+        results = [DetectionResult.model_validate(result) for result in results]
+        return results
+
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
+        # HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
+        # CUDA.
+        if device is not None and device.type not in {"cpu", "cuda"}:
+            device = None
+        self._pipeline.model.to(device=device, dtype=dtype)
+        self._pipeline.device = self._pipeline.model.device
+
+    def calc_size(self) -> int:
+        # HACK(ryand): Fix the circular import issue.
+        from invokeai.backend.model_manager.load.model_util import calc_module_size
+
+        return calc_module_size(self._pipeline.model)
invokeai/backend/image_util/segment_anything/mask_refinement.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+# This file contains utilities for Grounded-SAM mask refinement based on:
+# https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
+
+
+import cv2
+import numpy as np
+import numpy.typing as npt
+
+
+def mask_to_polygon(mask: npt.NDArray[np.uint8]) -> list[tuple[int, int]]:
+    """Convert a binary mask to a polygon.
+
+    Returns:
+        list[list[int]]: List of (x, y) coordinates representing the vertices of the polygon.
+    """
+    # Find contours in the binary mask.
+    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+    # Find the contour with the largest area.
+    largest_contour = max(contours, key=cv2.contourArea)
+
+    # Extract the vertices of the contour.
+    polygon = largest_contour.reshape(-1, 2).tolist()
+
+    return polygon
+
+
+def polygon_to_mask(
+    polygon: list[tuple[int, int]], image_shape: tuple[int, int], fill_value: int = 1
+) -> npt.NDArray[np.uint8]:
+    """Convert a polygon to a segmentation mask.
+
+    Args:
+        polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
+        image_shape (tuple): Shape of the image (height, width) for the mask.
+        fill_value (int): Value to fill the polygon with.
+
+    Returns:
+        np.ndarray: Segmentation mask with the polygon filled (with value 255).
+    """
+    # Create an empty mask.
+    mask = np.zeros(image_shape, dtype=np.uint8)
+
+    # Convert polygon to an array of points.
+    pts = np.array(polygon, dtype=np.int32)
+
+    # Fill the polygon with white color (255).
+    cv2.fillPoly(mask, [pts], color=(fill_value,))
+
+    return mask
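Example (not part of the diff): a hedged usage sketch of the two helpers above, showing the largest-contour extraction followed by re-rasterization. It assumes OpenCV is installed and uses a synthetic square mask.

```python
# Illustrative round trip through the mask_to_polygon / polygon_to_mask logic on a synthetic mask.
import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=np.uint8)
cv2.rectangle(mask, (10, 10), (50, 50), color=1, thickness=-1)  # filled square

# mask -> polygon (same cv2 calls as mask_to_polygon above)
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
largest = max(contours, key=cv2.contourArea)
polygon = largest.reshape(-1, 2).tolist()

# polygon -> mask (same cv2 calls as polygon_to_mask above)
refined = np.zeros((64, 64), dtype=np.uint8)
cv2.fillPoly(refined, [np.array(polygon, dtype=np.int32)], color=(1,))

# Small disconnected pieces and holes are dropped because only the largest contour is kept.
print(mask.sum(), refined.sum())
```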
invokeai/backend/image_util/segment_anything/segment_anything_pipeline.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+from typing import Optional
+
+import torch
+from PIL import Image
+from transformers.models.sam import SamModel
+from transformers.models.sam.processing_sam import SamProcessor
+
+from invokeai.backend.raw_model import RawModel
+
+
+class SegmentAnythingPipeline(RawModel):
+    """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
+
+    def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
+        self._sam_model = sam_model
+        self._sam_processor = sam_processor
+
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
+        # HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
+        if device is not None and device.type not in {"cpu", "cuda"}:
+            device = None
+        self._sam_model.to(device=device, dtype=dtype)
+
+    def calc_size(self) -> int:
+        # HACK(ryand): Fix the circular import issue.
+        from invokeai.backend.model_manager.load.model_util import calc_module_size
+
+        return calc_module_size(self._sam_model)
+
+    def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor:
+        """Run the SAM model.
+
+        Args:
+            image (Image.Image): The image to segment.
+            bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
+                [xmin, ymin, xmax, ymax].
+
+        Returns:
+            torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
+        """
+        # Add batch dimension of 1 to the bounding boxes.
+        boxes = [bounding_boxes]
+        inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
+        outputs = self._sam_model(**inputs)
+        masks = self._sam_processor.post_process_masks(
+            masks=outputs.pred_masks,
+            original_sizes=inputs.original_sizes,
+            reshaped_input_sizes=inputs.reshaped_input_sizes,
+        )
+
+        # There should be only one batch.
+        assert len(masks) == 1
+        return masks[0]
@@ -3,12 +3,13 @@

 import bisect
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union

 import torch
 from safetensors.torch import load_file
 from typing_extensions import Self

+import invokeai.backend.util.logging as logger
 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.raw_model import RawModel

@@ -46,9 +47,19 @@ class LoRALayerBase:
         self.rank = None  # set in layer implementation
         self.layer_key = layer_key

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()

+    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+        return self.bias
+
+    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
+        params = {"weight": self.get_weight(orig_module.weight)}
+        bias = self.get_bias(orig_module.bias)
+        if bias is not None:
+            params["bias"] = bias
+        return params
+
     def calc_size(self) -> int:
         model_size = 0
         for val in [self.bias]:
@@ -60,6 +71,17 @@ class LoRALayerBase:
         if self.bias is not None:
             self.bias = self.bias.to(device=device, dtype=dtype)

+    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
+        """Log a warning if values contains unhandled keys."""
+        # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
+        # `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
+        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
+        unknown_keys = set(values.keys()) - all_known_keys
+        if unknown_keys:
+            logger.warning(
+                f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
+            )
+

 # TODO: find and debug lora/locon with bias
 class LoRALayer(LoRALayerBase):
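Example (not part of the diff): a stand-alone sketch of the set arithmetic the new check_keys guard performs when a LyCORIS layer carries keys the loader does not handle. The key names are examples; "dora_scale" is a hypothetical unhandled key used only for illustration.

```python
# Stand-alone sketch of the check_keys logic: warn on state-dict keys the layer does not handle.
values = {
    "lora_up.weight": ...,
    "lora_down.weight": ...,
    "alpha": ...,
    "dora_scale": ...,  # hypothetical unhandled key
}
known_keys = {"lora_up.weight", "lora_down.weight", "lora_mid.weight"}

# The hard-coded set mirrors the keys LoRALayerBase handles itself.
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys:
    print(f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}")
# -> Unexpected keys found ... Keys: {'dora_scale'}
```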
@@ -76,14 +98,19 @@ class LoRALayer(LoRALayerBase):

         self.up = values["lora_up.weight"]
         self.down = values["lora_down.weight"]
-        if "lora_mid.weight" in values:
-            self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
-        else:
-            self.mid = None
+        self.mid = values.get("lora_mid.weight", None)

         self.rank = self.down.shape[0]
+        self.check_keys(
+            values,
+            {
+                "lora_up.weight",
+                "lora_down.weight",
+                "lora_mid.weight",
+            },
+        )

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
             down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -125,20 +152,23 @@ class LoHALayer(LoRALayerBase):
         self.w1_b = values["hada_w1_b"]
         self.w2_a = values["hada_w2_a"]
         self.w2_b = values["hada_w2_b"]
-
-        if "hada_t1" in values:
-            self.t1: Optional[torch.Tensor] = values["hada_t1"]
-        else:
-            self.t1 = None
-
-        if "hada_t2" in values:
-            self.t2: Optional[torch.Tensor] = values["hada_t2"]
-        else:
-            self.t2 = None
+        self.t1 = values.get("hada_t1", None)
+        self.t2 = values.get("hada_t2", None)

         self.rank = self.w1_b.shape[0]
+        self.check_keys(
+            values,
+            {
+                "hada_w1_a",
+                "hada_w1_b",
+                "hada_w2_a",
+                "hada_w2_b",
+                "hada_t1",
+                "hada_t2",
+            },
+        )

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.t1 is None:
             weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)

@@ -186,37 +216,39 @@ class LoKRLayer(LoRALayerBase):
     ):
         super().__init__(layer_key, values)

-        if "lokr_w1" in values:
-            self.w1: Optional[torch.Tensor] = values["lokr_w1"]
-            self.w1_a = None
-            self.w1_b = None
-        else:
-            self.w1 = None
+        self.w1 = values.get("lokr_w1", None)
+        if self.w1 is None:
             self.w1_a = values["lokr_w1_a"]
             self.w1_b = values["lokr_w1_b"]

-        if "lokr_w2" in values:
-            self.w2: Optional[torch.Tensor] = values["lokr_w2"]
-            self.w2_a = None
-            self.w2_b = None
-        else:
-            self.w2 = None
+        self.w2 = values.get("lokr_w2", None)
+        if self.w2 is None:
             self.w2_a = values["lokr_w2_a"]
             self.w2_b = values["lokr_w2_b"]

-        if "lokr_t2" in values:
-            self.t2: Optional[torch.Tensor] = values["lokr_t2"]
-        else:
-            self.t2 = None
+        self.t2 = values.get("lokr_t2", None)

-        if "lokr_w1_b" in values:
-            self.rank = values["lokr_w1_b"].shape[0]
-        elif "lokr_w2_b" in values:
-            self.rank = values["lokr_w2_b"].shape[0]
+        if self.w1_b is not None:
+            self.rank = self.w1_b.shape[0]
+        elif self.w2_b is not None:
+            self.rank = self.w2_b.shape[0]
         else:
             self.rank = None  # unscaled

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+        self.check_keys(
+            values,
+            {
+                "lokr_w1",
+                "lokr_w1_a",
+                "lokr_w1_b",
+                "lokr_w2",
+                "lokr_w2_a",
+                "lokr_w2_b",
+                "lokr_t2",
+            },
+        )
+
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         w1: Optional[torch.Tensor] = self.w1
         if w1 is None:
             assert self.w1_a is not None
@@ -272,7 +304,9 @@ class LoKRLayer(LoRALayerBase):


 class FullLayer(LoRALayerBase):
+    # bias handled in LoRALayerBase(calc_size, to)
     # weight: torch.Tensor
+    # bias: Optional[torch.Tensor]

     def __init__(
         self,
@@ -282,15 +316,12 @@ class FullLayer(LoRALayerBase):
         super().__init__(layer_key, values)

         self.weight = values["diff"]
-
-        if len(values.keys()) > 1:
-            _keys = list(values.keys())
-            _keys.remove("diff")
-            raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
+        self.bias = values.get("diff_b", None)

         self.rank = None  # unscaled
+        self.check_keys(values, {"diff", "diff_b"})

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight

     def calc_size(self) -> int:
@@ -319,8 +350,9 @@ class IA3Layer(LoRALayerBase):
         self.on_input = values["on_input"]

         self.rank = None  # unscaled
+        self.check_keys(values, {"weight", "on_input"})

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         weight = self.weight
         if not self.on_input:
             weight = weight.reshape(-1, 1)
@@ -458,16 +490,19 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
             state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

         for layer_key, values in state_dict.items():
+            # Detect layers according to LyCORIS detection logic(`weight_list_det`)
+            # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
+
             # lora and locon
-            if "lora_down.weight" in values:
+            if "lora_up.weight" in values:
                 layer: AnyLoRALayer = LoRALayer(layer_key, values)

             # loha
-            elif "hada_w1_b" in values:
+            elif "hada_w1_a" in values:
                 layer = LoHALayer(layer_key, values)

             # lokr
-            elif "lokr_w1_b" in values or "lokr_w1" in values:
+            elif "lokr_w1" in values or "lokr_w1_a" in values:
                 layer = LoKRLayer(layer_key, values)

             # diff
@@ -475,7 +510,7 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
                 layer = FullLayer(layer_key, values)

             # ia3
-            elif "weight" in values and "on_input" in values:
+            elif "on_input" in values:
                 layer = IA3Layer(layer_key, values)

             else:
@@ -12,6 +12,8 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from transformers import CLIPTokenizer

 from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
+from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
+from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager.config import AnyModel
@@ -36,7 +38,16 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
         # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
         return 0
     elif isinstance(
-        model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel, DepthAnythingPipeline)
+        model,
+        (
+            TextualInversionModelRaw,
+            IPAdapter,
+            LoRAModelRaw,
+            SpandrelImageToImageModel,
+            GroundingDinoPipeline,
+            SegmentAnythingPipeline,
+            DepthAnythingPipeline,
+        ),
     ):
         return model.calc_size()
     else:
@@ -17,8 +17,9 @@ from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
-from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

 """
 loras = [
@@ -85,13 +86,13 @@ class ModelPatcher:
         cls,
         unet: UNet2DConditionModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
         with cls.apply_lora(
             unet,
             loras=loras,
             prefix="lora_unet_",
-            model_state_dict=model_state_dict,
+            cached_weights=cached_weights,
         ):
             yield

@@ -101,9 +102,9 @@ class ModelPatcher:
         cls,
         text_encoder: CLIPTextModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
-        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
+        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
             yield

     @classmethod
@@ -113,7 +114,7 @@ class ModelPatcher:
         model: AnyModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         prefix: str,
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
         """
         Apply one or more LoRAs to a model.
@@ -121,66 +122,26 @@ class ModelPatcher:
         :param model: The model to patch.
         :param loras: An iterator that returns the LoRA to patch in and its patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
-        original_weights = {}
+        original_weights = OriginalWeightsStorage(cached_weights)
         try:
-            with torch.no_grad():
-                for lora, lora_weight in loras:
-                    # assert lora.device.type == "cpu"
-                    for layer_key, layer in lora.layers.items():
-                        if not layer_key.startswith(prefix):
-                            continue
-
-                        # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
-                        # should be improved in the following ways:
-                        # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
-                        #    LoRA model is applied.
-                        # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
-                        #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
-                        #    weights to have valid keys.
-                        assert isinstance(model, torch.nn.Module)
-                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-
-                        # All of the LoRA weight calculations will be done on the same device as the module weight.
-                        # (Performance will be best if this is a CUDA device.)
-                        device = module.weight.device
-                        dtype = module.weight.dtype
-
-                        if module_key not in original_weights:
-                            if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
-                                original_weights[module_key] = model_state_dict[module_key + ".weight"]
-                            else:
-                                original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
-
-                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-
-                        # We intentionally move to the target device first, then cast. Experimentally, this was found to
-                        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-                        # same thing in a single call to '.to(...)'.
-                        layer.to(device=device)
-                        layer.to(dtype=torch.float32)
-                        # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
-                        # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-                        layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                        layer.to(device=TorchDevice.CPU_DEVICE)
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        if module.weight.shape != layer_weight.shape:
-                            # TODO: debug on lycoris
-                            assert hasattr(layer_weight, "reshape")
-                            layer_weight = layer_weight.reshape(module.weight.shape)
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        module.weight += layer_weight.to(dtype=dtype)
-
-            yield  # wait for context manager exit
+            for lora_model, lora_weight in loras:
+                LoRAExt.patch_model(
+                    model=model,
+                    prefix=prefix,
+                    lora=lora_model,
+                    lora_weight=lora_weight,
+                    original_weights=original_weights,
+                )
+                del lora_model
+
+            yield

         finally:
-            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
-                for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight)
+                for param_key, weight in original_weights.get_changed_weights():
+                    model.get_parameter(param_key).copy_(weight)

     @classmethod
     @contextmanager
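Example (not part of the diff): the refactor replaces the ad-hoc original_weights dict with an OriginalWeightsStorage object, whose implementation is not shown in this diff. The following is only a rough sketch of the save-then-restore pattern the new apply_lora relies on; the class name, method names, and behaviour are assumptions, not the real class.

```python
# Rough sketch of a save/restore store like the one apply_lora now uses.
# NOT the real OriginalWeightsStorage; its implementation is outside this diff.
from typing import Dict, Iterator, Optional, Tuple

import torch


class WeightsBackupSketch:
    def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
        # Optionally start from a CPU copy of the state dict; nothing is marked as changed yet.
        self._weights = dict(cached_weights or {})
        self._changed: set[str] = set()

    def save(self, key: str, weight: torch.Tensor):
        # Record a parameter's original value before a patch modifies it.
        if key not in self._weights:
            self._weights[key] = weight.detach().to("cpu", copy=True)
        self._changed.add(key)

    def get_changed_weights(self) -> Iterator[Tuple[str, torch.Tensor]]:
        # Only weights that were actually touched need to be copied back on unpatch.
        for key in self._changed:
            yield key, self._weights[key]
```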
@@ -2,14 +2,14 @@ from __future__ import annotations

 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List

-import torch
 from diffusers import UNet2DConditionModel

 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


 @dataclass
@ -56,5 +56,17 @@ class ExtensionBase:
|
|||||||
yield None
|
yield None
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
|
||||||
yield None
|
"""A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
|
||||||
|
diffusion process. Weight unpatching is handled upstream, and is achieved by saving unchanged weights by
|
||||||
|
`original_weights.save` function. Note that this enables some performance optimization by avoiding redundant
|
||||||
|
operations. All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched
|
||||||
|
by this context manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unet (UNet2DConditionModel): The UNet model on execution device to patch.
|
||||||
|
original_weights (OriginalWeightsStorage): A storage with copy of the model's original weights in CPU, for
|
||||||
|
unpatching purposes. Extension should save tensor which being modified in this storage, also extensions
|
||||||
|
can access original weights values.
|
||||||
|
"""
|
||||||
|
yield
|
||||||
|
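To make the contract concrete, a minimal hypothetical extension might look like the sketch below. The class name and the specific parameter it touches are invented for illustration; the `patch_unet` signature and the `original_weights.save(...)` call follow the interface defined above, and no in-context weight restore is needed because the upstream manager restores the saved originals.

from contextlib import contextmanager

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class ScaleConvInExt(ExtensionBase):
    """Hypothetical example: scales the UNet's conv_in weight while the context is active."""

    def __init__(self, scale: float):
        super().__init__()
        self._scale = scale

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        with torch.no_grad():
            # Save the original tensor first so the upstream manager can restore it after denoising.
            original_weights.save("conv_in.weight", unet.conv_in.weight)
            unet.conv_in.weight *= self._scale
        # Weight unpatching is handled upstream from the saved originals, so nothing to undo here.
        yield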
@ -1,15 +1,15 @@
 from __future__ import annotations

 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING

-import torch
 from diffusers import UNet2DConditionModel

 from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase

 if TYPE_CHECKING:
     from invokeai.app.shared.models import FreeUConfig
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


 class FreeUExt(ExtensionBase):
@ -21,7 +21,7 @@ class FreeUExt(ExtensionBase):
         self._freeu_config = freeu_config

     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         unet.enable_freeu(
             b1=self._freeu_config.b1,
             b2=self._freeu_config.b2,
invokeai/backend/stable_diffusion/extensions/lora.py (new file, 137 lines)
@ -0,0 +1,137 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Tuple

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.devices import TorchDevice

if TYPE_CHECKING:
    from invokeai.app.invocations.model import ModelIdentifierField
    from invokeai.app.services.shared.invocation_context import InvocationContext
    from invokeai.backend.lora import LoRAModelRaw
    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class LoRAExt(ExtensionBase):
    def __init__(
        self,
        node_context: InvocationContext,
        model_id: ModelIdentifierField,
        weight: float,
    ):
        super().__init__()
        self._node_context = node_context
        self._model_id = model_id
        self._weight = weight

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        lora_model = self._node_context.models.load(self._model_id).model
        self.patch_model(
            model=unet,
            prefix="lora_unet_",
            lora=lora_model,
            lora_weight=self._weight,
            original_weights=original_weights,
        )
        del lora_model

        yield

    @classmethod
    @torch.no_grad()
    def patch_model(
        cls,
        model: torch.nn.Module,
        prefix: str,
        lora: LoRAModelRaw,
        lora_weight: float,
        original_weights: OriginalWeightsStorage,
    ):
        """
        Apply one or more LoRAs to a model.
        :param model: The model to patch.
        :param lora: LoRA model to patch in.
        :param lora_weight: LoRA patch weight.
        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
        :param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
        """

        if lora_weight == 0:
            return

        # assert lora.device.type == "cpu"
        for layer_key, layer in lora.layers.items():
            if not layer_key.startswith(prefix):
                continue

            # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
            # should be improved in the following ways:
            # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
            #    LoRA model is applied.
            # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
            #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
            #    weights to have valid keys.
            assert isinstance(model, torch.nn.Module)
            module_key, module = cls._resolve_lora_key(model, layer_key, prefix)

            # All of the LoRA weight calculations will be done on the same device as the module weight.
            # (Performance will be best if this is a CUDA device.)
            device = module.weight.device
            dtype = module.weight.dtype

            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0

            # We intentionally move to the target device first, then cast. Experimentally, this was found to
            # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
            # same thing in a single call to '.to(...)'.
            layer.to(device=device)
            layer.to(dtype=torch.float32)

            # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
            # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
            for param_name, lora_param_weight in layer.get_parameters(module).items():
                param_key = module_key + "." + param_name
                module_param = module.get_parameter(param_name)

                # save original weight
                original_weights.save(param_key, module_param)

                if module_param.shape != lora_param_weight.shape:
                    # TODO: debug on lycoris
                    lora_param_weight = lora_param_weight.reshape(module_param.shape)

                lora_param_weight *= lora_weight * layer_scale
                module_param += lora_param_weight.to(dtype=dtype)

            layer.to(device=TorchDevice.CPU_DEVICE)

    @staticmethod
    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
        assert "." not in lora_key

        if not lora_key.startswith(prefix):
            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")

        module = model
        module_key = ""
        key_parts = lora_key[len(prefix) :].split("_")

        submodule_name = key_parts.pop(0)

        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)
            except Exception:
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)
        module_key = (module_key + "." + submodule_name).lstrip(".")

        return (module_key, module)
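`_resolve_lora_key` greedily joins underscore-separated key parts until a submodule lookup succeeds, which is how flat LoRA keys are mapped onto the model's dotted module tree. A rough sketch of that behavior on a toy module (the module names are invented for the example; a real UNet key looks similar but longer):

import torch

from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = torch.nn.Linear(4, 4)


class ToyUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.down_blocks = torch.nn.ModuleList([Block()])


model = ToyUNet()

# "lora_unet_down_blocks_0_proj_in" -> strip the prefix, then try "down", "down_blocks",
# "down_blocks.0", and finally "down_blocks.0.proj_in", keeping each part that resolves.
module_key, module = LoRAExt._resolve_lora_key(model, "lora_unet_down_blocks_0_proj_in", "lora_unet_")
assert module_key == "down_blocks.0.proj_in"
assert isinstance(module, torch.nn.Linear)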
@ -7,6 +7,7 @@ import torch
 from diffusers import UNet2DConditionModel

 from invokeai.app.services.session_processor.session_processor_common import CanceledException
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
@ -67,9 +68,15 @@ class ExtensionsManager:
         if self._is_canceled and self._is_canceled():
             raise CanceledException

-        # TODO: create weight patch logic in PR with extension which uses it
-        with ExitStack() as exit_stack:
-            for ext in self._extensions:
-                exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
+        original_weights = OriginalWeightsStorage(cached_weights)
+        try:
+            with ExitStack() as exit_stack:
+                for ext in self._extensions:
+                    exit_stack.enter_context(ext.patch_unet(unet, original_weights))

                 yield None
+
+        finally:
+            with torch.no_grad():
+                for param_key, weight in original_weights.get_changed_weights():
+                    unet.get_parameter(param_key).copy_(weight)
@ -20,10 +20,14 @@ from diffusers import (
 )
 from diffusers.schedulers.scheduling_utils import SchedulerMixin

+# TODO: add dpmpp_3s/dpmpp_3s_k when fix released
+# https://github.com/huggingface/diffusers/issues/9007
+
 SCHEDULER_NAME_VALUES = Literal[
     "ddim",
     "ddpm",
     "deis",
+    "deis_k",
     "lms",
     "lms_k",
     "pndm",
@ -33,16 +37,21 @@ SCHEDULER_NAME_VALUES = Literal[
     "euler_k",
     "euler_a",
     "kdpm_2",
+    "kdpm_2_k",
     "kdpm_2_a",
+    "kdpm_2_a_k",
     "dpmpp_2s",
     "dpmpp_2s_k",
     "dpmpp_2m",
     "dpmpp_2m_k",
     "dpmpp_2m_sde",
     "dpmpp_2m_sde_k",
+    "dpmpp_3m",
+    "dpmpp_3m_k",
     "dpmpp_sde",
     "dpmpp_sde_k",
     "unipc",
+    "unipc_k",
     "lcm",
     "tcd",
 ]
@ -50,7 +59,8 @@ SCHEDULER_NAME_VALUES = Literal[
 SCHEDULER_MAP: dict[SCHEDULER_NAME_VALUES, tuple[Type[SchedulerMixin], dict[str, Any]]] = {
     "ddim": (DDIMScheduler, {}),
     "ddpm": (DDPMScheduler, {}),
-    "deis": (DEISMultistepScheduler, {}),
+    "deis": (DEISMultistepScheduler, {"use_karras_sigmas": False}),
+    "deis_k": (DEISMultistepScheduler, {"use_karras_sigmas": True}),
     "lms": (LMSDiscreteScheduler, {"use_karras_sigmas": False}),
     "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
     "pndm": (PNDMScheduler, {}),
@ -59,17 +69,28 @@ SCHEDULER_MAP: dict[SCHEDULER_NAME_VALUES, tuple[Type[SchedulerMixin], dict[str,
     "euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
     "euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
     "euler_a": (EulerAncestralDiscreteScheduler, {}),
-    "kdpm_2": (KDPM2DiscreteScheduler, {}),
-    "kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
-    "dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
-    "dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
-    "dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False}),
-    "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
-    "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
-    "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
+    "kdpm_2": (KDPM2DiscreteScheduler, {"use_karras_sigmas": False}),
+    "kdpm_2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
+    "kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": False}),
+    "kdpm_2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
+    "dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False, "solver_order": 2}),
+    "dpmpp_2s_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True, "solver_order": 2}),
+    "dpmpp_2m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "solver_order": 2}),
+    "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 2}),
+    "dpmpp_2m_sde": (
+        DPMSolverMultistepScheduler,
+        {"use_karras_sigmas": False, "solver_order": 2, "algorithm_type": "sde-dpmsolver++"},
+    ),
+    "dpmpp_2m_sde_k": (
+        DPMSolverMultistepScheduler,
+        {"use_karras_sigmas": True, "solver_order": 2, "algorithm_type": "sde-dpmsolver++"},
+    ),
+    "dpmpp_3m": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "solver_order": 3}),
+    "dpmpp_3m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 3}),
     "dpmpp_sde": (DPMSolverSDEScheduler, {"use_karras_sigmas": False, "noise_sampler_seed": 0}),
     "dpmpp_sde_k": (DPMSolverSDEScheduler, {"use_karras_sigmas": True, "noise_sampler_seed": 0}),
-    "unipc": (UniPCMultistepScheduler, {"cpu_only": True}),
+    "unipc": (UniPCMultistepScheduler, {"use_karras_sigmas": False, "cpu_only": True}),
+    "unipc_k": (UniPCMultistepScheduler, {"use_karras_sigmas": True, "cpu_only": True}),
     "lcm": (LCMScheduler, {}),
     "tcd": (TCDScheduler, {}),
 }
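For context, each `SCHEDULER_MAP` entry pairs a diffusers scheduler class with the extra constructor kwargs that distinguish the variant (Karras sigmas, solver order, SDE algorithm). A rough sketch of how such an entry is typically consumed; the base config below is taken from a default `DDIMScheduler` purely for illustration, whereas the application would use the loaded model's scheduler config:

from diffusers import DDIMScheduler

# Stand-in for the scheduler config shipped with a loaded model.
base_config = DDIMScheduler().config

# Look up one of the new entries and build the scheduler, overriding the base config
# with the per-variant kwargs (here use_karras_sigmas=True and solver_order=3).
scheduler_class, extra_kwargs = SCHEDULER_MAP["dpmpp_3m_k"]
scheduler = scheduler_class.from_config(base_config, **extra_kwargs)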
invokeai/backend/util/original_weights_storage.py (new file, 39 lines)
@ -0,0 +1,39 @@
from __future__ import annotations

from typing import Dict, Iterator, Optional, Tuple

import torch

from invokeai.backend.util.devices import TorchDevice


class OriginalWeightsStorage:
    """A class for tracking the original weights of a model for patch/unpatch operations."""

    def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
        # The original weights of the model.
        self._weights: dict[str, torch.Tensor] = {}
        # The keys of the weights that have been changed (via `save()`) during the lifetime of this instance.
        self._changed_weights: set[str] = set()
        if cached_weights:
            self._weights.update(cached_weights)

    def save(self, key: str, weight: torch.Tensor, copy: bool = True):
        self._changed_weights.add(key)
        if key in self._weights:
            return

        self._weights[key] = weight.detach().to(device=TorchDevice.CPU_DEVICE, copy=copy)

    def get(self, key: str, copy: bool = False) -> Optional[torch.Tensor]:
        weight = self._weights.get(key, None)
        if weight is not None and copy:
            weight = weight.clone()
        return weight

    def contains(self, key: str) -> bool:
        return key in self._weights

    def get_changed_weights(self) -> Iterator[Tuple[str, torch.Tensor]]:
        for key in self._changed_weights:
            yield key, self._weights[key]
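A small, self-contained sketch of how the storage behaves (the layer and keys are invented for the example): weights already present as `cached_weights` are not copied again, and only keys touched via `save()` are reported by `get_changed_weights()`.

import torch

from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

layer = torch.nn.Linear(2, 2)

# Seed the storage with a weight that is already cached on CPU (e.g. by the model cache).
storage = OriginalWeightsStorage(cached_weights={"weight": layer.weight.detach().clone()})

# Record that both parameters are about to be modified. "weight" is already stored, so only
# "bias" is copied; both keys are marked as changed.
storage.save("weight", layer.weight)
storage.save("bias", layer.bias)

with torch.no_grad():
    layer.weight += 1.0
    layer.bias += 1.0

# Restore everything that was changed, addressing parameters by key.
with torch.no_grad():
    for key, original in storage.get_changed_weights():
        layer.get_parameter(key).copy_(original)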
@ -16,6 +16,8 @@ import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterM
 import { configChanged } from 'features/system/store/configSlice';
 import { languageSelector } from 'features/system/store/systemSelectors';
 import InvokeTabs from 'features/ui/components/InvokeTabs';
+import type { InvokeTabName } from 'features/ui/store/tabMap';
+import { setActiveTab } from 'features/ui/store/uiSlice';
 import { AnimatePresence } from 'framer-motion';
 import i18n from 'i18n';
 import { size } from 'lodash-es';
@ -34,9 +36,10 @@ interface Props {
     imageName: string;
     action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
   };
+  destination?: InvokeTabName | undefined;
 }

-const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
+const App = ({ config = DEFAULT_CONFIG, selectedImage, destination }: Props) => {
   const language = useAppSelector(languageSelector);
   const logger = useLogger('system');
   const dispatch = useAppDispatch();
@ -67,6 +70,12 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
     }
   }, [dispatch, config, logger]);

+  useEffect(() => {
+    if (destination) {
+      dispatch(setActiveTab(destination));
+    }
+  }, [dispatch, destination]);
+
   useEffect(() => {
     dispatch(appStarted());
   }, [dispatch]);
@ -19,6 +19,7 @@ import type { PartialAppConfig } from 'app/types/invokeai';
 import Loading from 'common/components/Loading/Loading';
 import AppDndContext from 'features/dnd/components/AppDndContext';
 import type { WorkflowCategory } from 'features/nodes/types/workflow';
+import type { InvokeTabName } from 'features/ui/store/tabMap';
 import type { PropsWithChildren, ReactNode } from 'react';
 import React, { lazy, memo, useEffect, useMemo } from 'react';
 import { Provider } from 'react-redux';
@ -43,6 +44,7 @@ interface Props extends PropsWithChildren {
     imageName: string;
     action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
   };
+  destination?: InvokeTabName;
   customStarUi?: CustomStarUi;
   socketOptions?: Partial<ManagerOptions & SocketOptions>;
   isDebugging?: boolean;
@ -62,6 +64,7 @@ const InvokeAIUI = ({
   projectUrl,
   queueId,
   selectedImage,
+  destination,
   customStarUi,
   socketOptions,
   isDebugging = false,
@ -218,7 +221,7 @@ const InvokeAIUI = ({
           <React.Suspense fallback={<Loading />}>
             <ThemeLocaleProvider>
               <AppDndContext>
-                <App config={config} selectedImage={selectedImage} />
+                <App config={config} selectedImage={selectedImage} destination={destination} />
               </AppDndContext>
             </ThemeLocaleProvider>
           </React.Suspense>
@ -32,6 +32,7 @@ export const zSchedulerField = z.enum([
   'ddpm',
   'dpmpp_2s',
   'dpmpp_2m',
+  'dpmpp_3m',
   'dpmpp_2m_sde',
   'dpmpp_sde',
   'heun',
@ -40,12 +41,17 @@ export const zSchedulerField = z.enum([
   'pndm',
   'unipc',
   'euler_k',
+  'deis_k',
   'dpmpp_2s_k',
   'dpmpp_2m_k',
+  'dpmpp_3m_k',
   'dpmpp_2m_sde_k',
   'dpmpp_sde_k',
   'heun_k',
+  'kdpm_2_k',
+  'kdpm_2_a_k',
   'lms_k',
+  'unipc_k',
   'euler_a',
   'kdpm_2_a',
   'lcm',
@ -125,19 +125,11 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
     g.addEdge(modelNode, 'unet', tiledMultidiffusionNode, 'unet');
     addSDXLLoRas(state, g, tiledMultidiffusionNode, modelNode, null, posCondNode, negCondNode);

-    const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
-
     g.upsertMetadata({
-      cfg_scale,
       positive_prompt: positivePrompt,
       negative_prompt: negativePrompt,
       positive_style_prompt: positiveStylePrompt,
       negative_style_prompt: negativeStylePrompt,
-      model: Graph.getModelMetadataField(modelConfig),
-      seed,
-      steps,
-      scheduler,
-      vae: vae ?? undefined,
     });
   } else {
     posCondNode = g.addNode({
@ -166,24 +158,33 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
     g.addEdge(modelNode, 'unet', tiledMultidiffusionNode, 'unet');
     addLoRAs(state, g, tiledMultidiffusionNode, modelNode, null, clipSkipNode, posCondNode, negCondNode);

-    const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
-    const upscaleModelConfig = await fetchModelConfigWithTypeGuard(upscaleModel.key, isSpandrelImageToImageModelConfig);
-
     g.upsertMetadata({
-      cfg_scale,
       positive_prompt: positivePrompt,
       negative_prompt: negativePrompt,
-      model: Graph.getModelMetadataField(modelConfig),
-      seed,
-      steps,
-      scheduler,
-      vae: vae ?? undefined,
-      upscale_model: Graph.getModelMetadataField(upscaleModelConfig),
-      creativity,
-      structure,
     });
   }

+  const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
+  const upscaleModelConfig = await fetchModelConfigWithTypeGuard(upscaleModel.key, isSpandrelImageToImageModelConfig);
+
+  g.upsertMetadata({
+    cfg_scale,
+    model: Graph.getModelMetadataField(modelConfig),
+    seed,
+    steps,
+    scheduler,
+    vae: vae ?? undefined,
+    upscale_model: Graph.getModelMetadataField(upscaleModelConfig),
+    creativity,
+    structure,
+    upscale_initial_image: {
+      image_name: upscaleInitialImage.image_name,
+      width: upscaleInitialImage.width,
+      height: upscaleInitialImage.height,
+    },
+    upscale_scale: scale,
+  });
+
   g.setMetadataReceivingNode(l2iNode);
   g.addEdgeToMetadata(upscaleNode, 'width', 'width');
   g.addEdgeToMetadata(upscaleNode, 'height', 'height');
@ -52,28 +52,34 @@ export const CLIP_SKIP_MAP = {
  * Mapping of schedulers to human readable name
  */
 export const SCHEDULER_OPTIONS: ComboboxOption[] = [
-  { value: 'euler', label: 'Euler' },
-  { value: 'deis', label: 'DEIS' },
   { value: 'ddim', label: 'DDIM' },
   { value: 'ddpm', label: 'DDPM' },
-  { value: 'dpmpp_sde', label: 'DPM++ SDE' },
+  { value: 'deis', label: 'DEIS' },
+  { value: 'deis_k', label: 'DEIS Karras' },
   { value: 'dpmpp_2s', label: 'DPM++ 2S' },
-  { value: 'dpmpp_2m', label: 'DPM++ 2M' },
-  { value: 'dpmpp_2m_sde', label: 'DPM++ 2M SDE' },
-  { value: 'heun', label: 'Heun' },
-  { value: 'kdpm_2', label: 'KDPM 2' },
-  { value: 'lms', label: 'LMS' },
-  { value: 'pndm', label: 'PNDM' },
-  { value: 'unipc', label: 'UniPC' },
-  { value: 'euler_k', label: 'Euler Karras' },
-  { value: 'dpmpp_sde_k', label: 'DPM++ SDE Karras' },
   { value: 'dpmpp_2s_k', label: 'DPM++ 2S Karras' },
+  { value: 'dpmpp_2m', label: 'DPM++ 2M' },
   { value: 'dpmpp_2m_k', label: 'DPM++ 2M Karras' },
+  { value: 'dpmpp_2m_sde', label: 'DPM++ 2M SDE' },
   { value: 'dpmpp_2m_sde_k', label: 'DPM++ 2M SDE Karras' },
-  { value: 'heun_k', label: 'Heun Karras' },
-  { value: 'lms_k', label: 'LMS Karras' },
+  { value: 'dpmpp_3m', label: 'DPM++ 3M' },
+  { value: 'dpmpp_3m_k', label: 'DPM++ 3M Karras' },
+  { value: 'dpmpp_sde', label: 'DPM++ SDE' },
+  { value: 'dpmpp_sde_k', label: 'DPM++ SDE Karras' },
+  { value: 'euler', label: 'Euler' },
+  { value: 'euler_k', label: 'Euler Karras' },
   { value: 'euler_a', label: 'Euler Ancestral' },
+  { value: 'heun', label: 'Heun' },
+  { value: 'heun_k', label: 'Heun Karras' },
+  { value: 'kdpm_2', label: 'KDPM 2' },
+  { value: 'kdpm_2_k', label: 'KDPM 2 Karras' },
   { value: 'kdpm_2_a', label: 'KDPM 2 Ancestral' },
+  { value: 'kdpm_2_a_k', label: 'KDPM 2 Ancestral Karras' },
   { value: 'lcm', label: 'LCM' },
+  { value: 'lms', label: 'LMS' },
+  { value: 'lms_k', label: 'LMS Karras' },
+  { value: 'pndm', label: 'PNDM' },
   { value: 'tcd', label: 'TCD' },
-].sort((a, b) => a.label.localeCompare(b.label));
+  { value: 'unipc', label: 'UniPC' },
+  { value: 'unipc_k', label: 'UniPC Karras' },
+];
@ -3553,7 +3553,7 @@ export type components = {
      * @default euler
      * @enum {string}
      */
-    scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
+    scheduler?: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
     /**
      * UNet
      * @description UNet (scheduler, LoRAs)
@ -8553,7 +8553,7 @@ export type components = {
      * Scheduler
      * @description Default scheduler for this model
      */
-    scheduler?: ("ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd") | null;
+    scheduler?: ("ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd") | null;
     /**
      * Steps
      * @description Default number of steps for this model
@ -11467,7 +11467,7 @@ export type components = {
      * @default euler
      * @enum {string}
      */
-    scheduler?: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
+    scheduler?: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
     /**
      * type
      * @default scheduler
@ -11483,7 +11483,7 @@ export type components = {
      * @description Scheduler to use during inference
      * @enum {string}
      */
-    scheduler: "ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm" | "tcd";
+    scheduler: "ddim" | "ddpm" | "deis" | "deis_k" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_k" | "kdpm_2_a" | "kdpm_2_a_k" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_3m" | "dpmpp_3m_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "unipc_k" | "lcm" | "tcd";
     /**
      * type
      * @default scheduler_output