mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

Merge remote-tracking branch 'origin/main' into maryhipp/style-presets

This commit is contained in: commit a7b83c8b5b
@@ -80,12 +80,12 @@ class CompelInvocation(BaseInvocation):
 
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
+            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
             ModelPatcher.apply_lora_text_encoder(
                 text_encoder,
                 loras=_lora_loader(),
-                model_state_dict=model_state_dict,
+                cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
@@ -175,13 +175,13 @@ class SDXLPromptInvocationBase:
 
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (state_dict, text_encoder),
+            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
             ModelPatcher.apply_lora(
                 text_encoder,
                 loras=_lora_loader(),
                 prefix=lora_prefix,
-                model_state_dict=state_dict,
+                cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
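The only functional change in these two hunks is renaming the handle yielded by `model_on_device()` from `model_state_dict` to `cached_weights` and passing it to the LoRA patcher under the new keyword. As a rough, hypothetical sketch of the underlying idea (patch in place, restore from the cached copy on exit) — not the real `ModelPatcher` internals:

# A minimal sketch, assuming a cache of original weights is available; `apply_lora_patch`
# and `deltas` are illustrative names, not InvokeAI APIs.
from contextlib import contextmanager

import torch


@contextmanager
def apply_lora_patch(model: torch.nn.Module, deltas: dict[str, torch.Tensor], cached_weights: dict[str, torch.Tensor]):
    try:
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in deltas:
                    param.add_(deltas[name].to(device=param.device, dtype=param.dtype))
        yield model
    finally:
        # Restore originals from the cached copies rather than re-reading the checkpoint from disk.
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in cached_weights:
                    param.copy_(cached_weights[name].to(device=param.device, dtype=param.dtype))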
@@ -21,6 +21,8 @@ from controlnet_aux import (
 from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
 from pydantic import BaseModel, Field, field_validator, model_validator
+from transformers import pipeline
+from transformers.pipelines import DepthEstimationPipeline
 
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
@@ -44,13 +46,12 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
 from invokeai.backend.image_util.canny import get_canny_edges
-from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
+from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
 from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
 from invokeai.backend.image_util.hed import HEDProcessor
 from invokeai.backend.image_util.lineart import LineartProcessor
 from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
 from invokeai.backend.image_util.util import np_to_pil, pil_to_np
-from invokeai.backend.util.devices import TorchDevice
 
 
 class ControlField(BaseModel):
@@ -592,7 +593,14 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
         return color_map
 
 
-DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
+DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
+# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
+DEPTH_ANYTHING_MODELS = {
+    "large": "LiheYoung/depth-anything-large-hf",
+    "base": "LiheYoung/depth-anything-base-hf",
+    "small": "LiheYoung/depth-anything-small-hf",
+    "small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
+}
 
 
 @invocation(
@@ -600,28 +608,33 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
     title="Depth Anything Processor",
     tags=["controlnet", "depth", "depth anything"],
     category="controlnet",
-    version="1.1.2",
+    version="1.1.3",
 )
 class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
     """Generates a depth map based on the Depth Anything algorithm"""
 
     model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
-        default="small", description="The size of the depth model to use"
+        default="small_v2", description="The size of the depth model to use"
     )
     resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        def loader(model_path: Path):
-            return DepthAnythingDetector.load_model(
-                model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
-            )
+        def load_depth_anything(model_path: Path):
+            depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
+            assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
+            return DepthAnythingPipeline(depth_anything_pipeline)
 
         with self._context.models.load_remote_model(
-            source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
-        ) as model:
-            depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
-            processed_image = depth_anything_detector(image=image, resolution=self.resolution)
-            return processed_image
+            source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
+        ) as depth_anything_detector:
+            assert isinstance(depth_anything_detector, DepthAnythingPipeline)
+            depth_map = depth_anything_detector.generate_depth(image)
+
+            # Resizing to user target specified size
+            new_height = int(image.size[1] * (self.resolution / image.size[0]))
+            depth_map = depth_map.resize((self.resolution, new_height))
+
+            return depth_map
 
 
 @invocation(
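The rewritten `run_processor` above delegates depth estimation to the Hugging Face `transformers` depth-estimation pipeline instead of the bundled `DepthAnythingDetector`. A standalone sketch of the same flow outside Invoke's model manager (the model id matches the new `small_v2` entry; file paths are illustrative):

from PIL import Image
from transformers import pipeline

# Load a depth-estimation pipeline; this id corresponds to the "small_v2" entry added above.
depth_estimator = pipeline(task="depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")

image = Image.open("input.png").convert("RGB")
depth_map = depth_estimator(image)["depth"]  # a PIL image the same size as the input

# Resize to a target width while preserving aspect ratio, mirroring run_processor() above.
resolution = 512
new_height = int(image.size[1] * (resolution / image.size[0]))
depth_map = depth_map.resize((resolution, new_height))
depth_map.save("depth.png")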
@@ -62,6 +62,7 @@ from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetEx
 from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
 from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
+from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
 from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
@@ -845,6 +846,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if self.unet.freeu_config:
             ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
 
+        ### lora
+        if self.unet.loras:
+            for lora_field in self.unet.loras:
+                ext_manager.add_extension(
+                    LoRAExt(
+                        node_context=context,
+                        model_id=lora_field.lora,
+                        weight=lora_field.weight,
+                    )
+                )
         ### seamless
         if self.unet.seamless_axes:
             ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
@@ -964,14 +975,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
-            unet_info.model_on_device() as (model_state_dict, unet),
+            unet_info.model_on_device() as (cached_weights, unet),
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(
                 unet,
                 loras=_lora_loader(),
-                model_state_dict=model_state_dict,
+                cached_weights=cached_weights,
             ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Any, Callable, Optional, Tuple
 
-from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
+from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
 from pydantic.fields import _Unset
 from pydantic_core import PydanticUndefined
 
@@ -242,6 +242,31 @@ class ConditioningField(BaseModel):
     )
 
 
+class BoundingBoxField(BaseModel):
+    """A bounding box primitive value."""
+
+    x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
+    x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
+    y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
+    y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
+
+    score: Optional[float] = Field(
+        default=None,
+        ge=0.0,
+        le=1.0,
+        description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
+        "when the bounding box was produced by a detector and has an associated confidence score.",
+    )
+
+    @model_validator(mode="after")
+    def check_coords(self):
+        if self.x_min > self.x_max:
+            raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
+        if self.y_min > self.y_max:
+            raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
+        return self
+
+
 class MetadataField(RootModel[dict[str, Any]]):
     """
     Pydantic model for metadata with custom root of type dict[str, Any].
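A short usage sketch of the new `BoundingBoxField` and its `check_coords` validator (the values are made up):

from pydantic import ValidationError

from invokeai.app.invocations.fields import BoundingBoxField

# A valid box with an optional detector confidence score.
box = BoundingBoxField(x_min=10, y_min=20, x_max=110, y_max=220, score=0.87)

# Swapped coordinates are rejected by the model_validator added above.
try:
    BoundingBoxField(x_min=110, y_min=20, x_max=10, y_max=220)
except ValidationError as e:
    print(e)  # reports: x_min (110) is greater than x_max (10).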
invokeai/app/invocations/grounding_dino.py (new file, 100 lines)
@@ -0,0 +1,100 @@
+from pathlib import Path
+from typing import Literal
+
+import torch
+from PIL import Image
+from transformers import pipeline
+from transformers.pipelines import ZeroShotObjectDetectionPipeline
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
+from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
+from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
+
+GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
+GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
+    "grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
+    "grounding-dino-base": "IDEA-Research/grounding-dino-base",
+}
+
+
+@invocation(
+    "grounding_dino",
+    title="Grounding DINO (Text Prompt Object Detection)",
+    tags=["prompt", "object detection"],
+    category="image",
+    version="1.0.0",
+)
+class GroundingDinoInvocation(BaseInvocation):
+    """Runs a Grounding DINO model. Performs zero-shot bounding-box object detection from a text prompt."""
+
+    # Reference:
+    # - https://arxiv.org/pdf/2303.05499
+    # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
+    # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
+
+    model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
+    prompt: str = InputField(description="The prompt describing the object to segment.")
+    image: ImageField = InputField(description="The image to segment.")
+    detection_threshold: float = InputField(
+        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
+        ge=0.0,
+        le=1.0,
+        default=0.3,
+    )
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
+        # The model expects a 3-channel RGB image.
+        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        detections = self._detect(
+            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
+        )
+
+        # Convert detections to BoundingBoxCollectionOutput.
+        bounding_boxes: list[BoundingBoxField] = []
+        for detection in detections:
+            bounding_boxes.append(
+                BoundingBoxField(
+                    x_min=detection.box.xmin,
+                    x_max=detection.box.xmax,
+                    y_min=detection.box.ymin,
+                    y_max=detection.box.ymax,
+                    score=detection.score,
+                )
+            )
+        return BoundingBoxCollectionOutput(collection=bounding_boxes)
+
+    @staticmethod
+    def _load_grounding_dino(model_path: Path):
+        grounding_dino_pipeline = pipeline(
+            model=str(model_path),
+            task="zero-shot-object-detection",
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
+        return GroundingDinoPipeline(grounding_dino_pipeline)
+
+    def _detect(
+        self,
+        context: InvocationContext,
+        image: Image.Image,
+        labels: list[str],
+        threshold: float = 0.3,
+    ) -> list[DetectionResult]:
+        """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
+        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
+        # actually makes a difference.
+        labels = [label if label.endswith(".") else label + "." for label in labels]
+
+        with context.models.load_remote_model(
+            source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
+        ) as detector:
+            assert isinstance(detector, GroundingDinoPipeline)
+            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
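For reference, the detection step above maps onto the plain `transformers` zero-shot object detection pipeline that `GroundingDinoPipeline` wraps. A hedged, standalone sketch (the image path is illustrative; the trailing period mirrors the label handling in `_detect`):

from PIL import Image
from transformers import pipeline

detector = pipeline(task="zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny")

image = Image.open("photo.png").convert("RGB")
# Grounding DINO expects each candidate label to end with a period.
results = detector(image, candidate_labels=["a cat."], threshold=0.3)

for r in results:
    # Each result is a dict like {"score": ..., "label": "a cat.", "box": {"xmin": ..., "ymin": ..., "xmax": ..., "ymax": ...}}
    print(r["label"], r["score"], r["box"])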
@@ -1,9 +1,10 @@
 import numpy as np
 import torch
+from PIL import Image
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
-from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
-from invokeai.app.invocations.primitives import MaskOutput
+from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
+from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
 
 
 @invocation(
@@ -118,3 +119,27 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
             height=mask.shape[1],
             width=mask.shape[2],
         )
+
+
+@invocation(
+    "tensor_mask_to_image",
+    title="Tensor Mask to Image",
+    tags=["mask"],
+    category="mask",
+    version="1.0.0",
+)
+class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Convert a mask tensor to an image."""
+
+    mask: TensorField = InputField(description="The mask tensor to convert.")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        mask = context.tensors.load(self.mask.tensor_name)
+        # Ensure that the mask is binary.
+        if mask.dtype != torch.bool:
+            mask = mask > 0.5
+        mask_np = (mask.float() * 255).byte().cpu().numpy()
+
+        mask_pil = Image.fromarray(mask_np, mode="L")
+        image_dto = context.images.save(image=mask_pil)
+        return ImageOutput.build(image_dto)
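The new `MaskTensorToImageInvocation` is essentially thresholding plus a grayscale PIL conversion; a standalone sketch of the same conversion:

import torch
from PIL import Image

mask = torch.rand(256, 256) > 0.5  # stand-in for a tensor loaded via context.tensors.load(...)
mask_np = (mask.float() * 255).byte().cpu().numpy()  # bool -> {0, 255} uint8
mask_pil = Image.fromarray(mask_np, mode="L")  # single-channel grayscale image
mask_pil.save("mask.png")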
@@ -7,6 +7,7 @@ import torch
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
+    BoundingBoxField,
     ColorField,
     ConditioningField,
     DenoiseMaskField,
@@ -469,3 +470,42 @@ class ConditioningCollectionInvocation(BaseInvocation):
 
 
 # endregion
+
+# region BoundingBox
+
+
+@invocation_output("bounding_box_output")
+class BoundingBoxOutput(BaseInvocationOutput):
+    """Base class for nodes that output a single bounding box"""
+
+    bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
+
+
+@invocation_output("bounding_box_collection_output")
+class BoundingBoxCollectionOutput(BaseInvocationOutput):
+    """Base class for nodes that output a collection of bounding boxes"""
+
+    collection: list[BoundingBoxField] = OutputField(description="The output bounding boxes.", title="Bounding Boxes")
+
+
+@invocation(
+    "bounding_box",
+    title="Bounding Box",
+    tags=["primitives", "segmentation", "collection", "bounding box"],
+    category="primitives",
+    version="1.0.0",
+)
+class BoundingBoxInvocation(BaseInvocation):
+    """Create a bounding box manually by supplying box coordinates"""
+
+    x_min: int = InputField(default=0, description="x-coordinate of the bounding box's top left vertex")
+    y_min: int = InputField(default=0, description="y-coordinate of the bounding box's top left vertex")
+    x_max: int = InputField(default=0, description="x-coordinate of the bounding box's bottom right vertex")
+    y_max: int = InputField(default=0, description="y-coordinate of the bounding box's bottom right vertex")
+
+    def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
+        bounding_box = BoundingBoxField(x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max)
+        return BoundingBoxOutput(bounding_box=bounding_box)
+
+
+# endregion
invokeai/app/invocations/segment_anything.py (new file, 161 lines)
@@ -0,0 +1,161 @@
+from pathlib import Path
+from typing import Literal
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoModelForMaskGeneration, AutoProcessor
+from transformers.models.sam import SamModel
+from transformers.models.sam.processing_sam import SamProcessor
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
+from invokeai.app.invocations.primitives import MaskOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
+from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
+
+SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
+SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
+    "segment-anything-base": "facebook/sam-vit-base",
+    "segment-anything-large": "facebook/sam-vit-large",
+    "segment-anything-huge": "facebook/sam-vit-huge",
+}
+
+
+@invocation(
+    "segment_anything",
+    title="Segment Anything",
+    tags=["prompt", "segmentation"],
+    category="segmentation",
+    version="1.0.0",
+)
+class SegmentAnythingInvocation(BaseInvocation):
+    """Runs a Segment Anything Model."""
+
+    # Reference:
+    # - https://arxiv.org/pdf/2304.02643
+    # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
+    # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
+
+    model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
+    image: ImageField = InputField(description="The image to segment.")
+    bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
+    apply_polygon_refinement: bool = InputField(
+        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
+        default=True,
+    )
+    mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
+        description="The filtering to apply to the detected masks before merging them into a final output.",
+        default="all",
+    )
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> MaskOutput:
+        # The models expect a 3-channel RGB image.
+        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        if len(self.bounding_boxes) == 0:
+            combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
+        else:
+            masks = self._segment(context=context, image=image_pil)
+            masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
+
+            # masks contains bool values, so we merge them via max-reduce.
+            combined_mask, _ = torch.stack(masks).max(dim=0)
+
+        mask_tensor_name = context.tensors.save(combined_mask)
+        height, width = combined_mask.shape
+        return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
+
+    @staticmethod
+    def _load_sam_model(model_path: Path):
+        sam_model = AutoModelForMaskGeneration.from_pretrained(
+            model_path,
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        assert isinstance(sam_model, SamModel)
+
+        sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
+        assert isinstance(sam_processor, SamProcessor)
+        return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
+
+    def _segment(
+        self,
+        context: InvocationContext,
+        image: Image.Image,
+    ) -> list[torch.Tensor]:
+        """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
+        # Convert the bounding boxes to the SAM input format.
+        sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
+
+        with (
+            context.models.load_remote_model(
+                source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
+            ) as sam_pipeline,
+        ):
+            assert isinstance(sam_pipeline, SegmentAnythingPipeline)
+            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
+
+        masks = self._process_masks(masks)
+        if self.apply_polygon_refinement:
+            masks = self._apply_polygon_refinement(masks)
+
+        return masks
+
+    def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
+        """Convert the tensor output from the Segment Anything model from a tensor of shape
+        [num_masks, channels, height, width] to a list of tensors of shape [height, width].
+        """
+        assert masks.dtype == torch.bool
+        # [num_masks, channels, height, width] -> [num_masks, height, width]
+        masks, _ = masks.max(dim=1)
+        # Split the first dimension into a list of masks.
+        return list(masks.cpu().unbind(dim=0))
+
+    def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
+        """Apply polygon refinement to the masks.
+
+        Convert each mask to a polygon, then back to a mask. This has the following effect:
+        - Smooth the edges of the mask slightly.
+        - Ensure that each mask consists of a single closed polygon
+        - Removes small mask pieces.
+        - Removes holes from the mask.
+        """
+        # Convert tensor masks to np masks.
+        np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
+
+        # Apply polygon refinement.
+        for idx, mask in enumerate(np_masks):
+            shape = mask.shape
+            assert len(shape) == 2  # Assert length to satisfy type checker.
+            polygon = mask_to_polygon(mask)
+            mask = polygon_to_mask(polygon, shape)
+            np_masks[idx] = mask
+
+        # Convert np masks back to tensor masks.
+        masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
+
+        return masks
+
+    def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
+        """Filter the detected masks based on the specified mask filter."""
+        assert len(masks) == len(bounding_boxes)
+
+        if self.mask_filter == "all":
+            return masks
+        elif self.mask_filter == "largest":
+            # Find the largest mask.
+            return [max(masks, key=lambda x: float(x.sum()))]
+        elif self.mask_filter == "highest_box_score":
+            # Find the index of the bounding box with the highest score.
+            # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
+            # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
+            # reasonable fallback since the expected score range is [0.0, 1.0].
+            max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
+            return [masks[max_score_idx]]
+        else:
+            raise ValueError(f"Invalid mask filter: {self.mask_filter}")
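The merge and filter logic above is plain tensor math. A small sketch of the `highest_box_score` filter and the max-reduce merge on toy data:

import torch

# Three toy boolean masks of shape [H, W], as _process_masks() would return.
masks = [torch.zeros(4, 4, dtype=torch.bool) for _ in range(3)]
masks[1][1:3, 1:3] = True
masks[2][0, :] = True
scores = [0.2, 0.9, None]  # per-box scores; None falls back to -1.0, as in _filter_masks()

# "highest_box_score": keep only the mask whose prompt box scored best.
best = max(range(len(masks)), key=lambda i: scores[i] or -1.0)
selected = [masks[best]]

# Merge the selected boolean masks via max-reduce (a logical OR across masks).
combined, _ = torch.stack(selected).max(dim=0)
print(combined.shape, combined.dtype)  # torch.Size([4, 4]) torch.bool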
@@ -1,11 +1,10 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
 from pathlib import Path
 from queue import Queue
-from typing import Dict, Optional, Union
+from typing import Optional, Union
 
 from PIL import Image, PngImagePlugin
 from PIL.Image import Image as PILImageType
-from send2trash import send2trash
 
 from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
 from invokeai.app.services.image_files.image_files_common import (
@@ -20,18 +19,12 @@ from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
 class DiskImageFileStorage(ImageFileStorageBase):
     """Stores images on disk"""
 
-    __output_folder: Path
-    __cache_ids: Queue  # TODO: this is an incredibly naive cache
-    __cache: Dict[Path, PILImageType]
-    __max_cache_size: int
-    __invoker: Invoker
-
     def __init__(self, output_folder: Union[str, Path]):
-        self.__cache = {}
-        self.__cache_ids = Queue()
+        self.__cache: dict[Path, PILImageType] = {}
+        self.__cache_ids = Queue[Path]()
         self.__max_cache_size = 10  # TODO: get this from config
 
-        self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
+        self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
         self.__thumbnails_folder = self.__output_folder / "thumbnails"
         # Validate required output folders at launch
         self.__validate_storage_folders()
@@ -103,7 +96,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
             image_path = self.get_path(image_name)
 
             if image_path.exists():
-                send2trash(image_path)
+                image_path.unlink()
             if image_path in self.__cache:
                 del self.__cache[image_path]
 
@@ -111,7 +104,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
             thumbnail_path = self.get_path(thumbnail_name, True)
 
             if thumbnail_path.exists():
-                send2trash(thumbnail_path)
+                thumbnail_path.unlink()
             if thumbnail_path in self.__cache:
                 del self.__cache[thumbnail_path]
         except Exception as e:
@@ -2,7 +2,6 @@ from pathlib import Path
 
 from PIL import Image
 from PIL.Image import Image as PILImageType
-from send2trash import send2trash
 
 from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
@@ -70,7 +69,7 @@ class ModelImageFileStorageDisk(ModelImageFileStorageBase):
             if not self._validate_path(path):
                 raise ModelImageFileNotFoundException
 
-            send2trash(path)
+            path.unlink()
 
         except Exception as e:
             raise ModelImageFileDeleteException from e
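These hunks replace the `send2trash` dependency with a plain `Path.unlink()` guarded by an existence check. A slightly more compact equivalent, shown only for illustration (not the code above):

from pathlib import Path


def delete_file(path: Path) -> None:
    # unlink(missing_ok=True) collapses the exists()/unlink() pair and avoids a race between the
    # check and the delete, but note that it removes the file permanently instead of trashing it.
    path.unlink(missing_ok=True)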
@@ -1,90 +0,0 @@
-from pathlib import Path
-from typing import Literal
-
-import cv2
-import numpy as np
-import torch
-import torch.nn.functional as F
-from einops import repeat
-from PIL import Image
-from torchvision.transforms import Compose
-
-from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
-from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
-from invokeai.backend.util.logging import InvokeAILogger
-
-config = get_config()
-logger = InvokeAILogger.get_logger(config=config)
-
-DEPTH_ANYTHING_MODELS = {
-    "large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
-    "base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
-    "small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
-}
-
-
-transform = Compose(
-    [
-        Resize(
-            width=518,
-            height=518,
-            resize_target=False,
-            keep_aspect_ratio=True,
-            ensure_multiple_of=14,
-            resize_method="lower_bound",
-            image_interpolation_method=cv2.INTER_CUBIC,
-        ),
-        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-        PrepareForNet(),
-    ]
-)
-
-
-class DepthAnythingDetector:
-    def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
-        self.model = model
-        self.device = device
-
-    @staticmethod
-    def load_model(
-        model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
-    ) -> DPT_DINOv2:
-        match model_size:
-            case "small":
-                model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
-            case "base":
-                model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
-            case "large":
-                model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
-
-        model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
-        model.eval()
-
-        model.to(device)
-        return model
-
-    def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
-        if not self.model:
-            logger.warn("DepthAnything model was not loaded. Returning original image")
-            return image
-
-        np_image = np.array(image, dtype=np.uint8)
-        np_image = np_image[:, :, ::-1] / 255.0
-
-        image_height, image_width = np_image.shape[:2]
-        np_image = transform({"image": np_image})["image"]
-        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
-
-        with torch.no_grad():
-            depth = self.model(tensor_image)
-            depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
-            depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
-
-        depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
-        depth_map = Image.fromarray(depth_map)
-
-        new_height = int(image_height * (resolution / image_width))
-        depth_map = depth_map.resize((resolution, new_height))
-
-        return depth_map
@@ -0,0 +1,31 @@
+from typing import Optional
+
+import torch
+from PIL import Image
+from transformers.pipelines import DepthEstimationPipeline
+
+from invokeai.backend.raw_model import RawModel
+
+
+class DepthAnythingPipeline(RawModel):
+    """Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
+    for Invoke's Model Management System"""
+
+    def __init__(self, pipeline: DepthEstimationPipeline) -> None:
+        self._pipeline = pipeline
+
+    def generate_depth(self, image: Image.Image) -> Image.Image:
+        depth_map = self._pipeline(image)["depth"]
+        assert isinstance(depth_map, Image.Image)
+        return depth_map
+
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
+        if device is not None and device.type not in {"cpu", "cuda"}:
+            device = None
+        self._pipeline.model.to(device=device, dtype=dtype)
+        self._pipeline.device = self._pipeline.model.device
+
+    def calc_size(self) -> int:
+        from invokeai.backend.model_manager.load.model_util import calc_module_size
+
+        return calc_module_size(self._pipeline.model)
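A hedged sketch of how this wrapper is intended to be used (the model id and image path are illustrative; `calc_module_size` is assumed to report the wrapped module's in-memory footprint for the model manager):

import torch
from PIL import Image
from transformers import pipeline
from transformers.pipelines import DepthEstimationPipeline

from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline

hf_pipeline = pipeline(task="depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
assert isinstance(hf_pipeline, DepthEstimationPipeline)
wrapper = DepthAnythingPipeline(hf_pipeline)

# The model manager can move the wrapped model between devices and query its size.
wrapper.to(device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
depth = wrapper.generate_depth(Image.open("input.png").convert("RGB"))
print(wrapper.calc_size())  # approximate model footprint in bytes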
@@ -1,145 +0,0 @@
-import torch.nn as nn
-
-
-def _make_scratch(in_shape, out_shape, groups=1, expand=False):
-    scratch = nn.Module()
-
-    out_shape1 = out_shape
-    out_shape2 = out_shape
-    out_shape3 = out_shape
-    if len(in_shape) >= 4:
-        out_shape4 = out_shape
-
-    if expand:
-        out_shape1 = out_shape
-        out_shape2 = out_shape * 2
-        out_shape3 = out_shape * 4
-        if len(in_shape) >= 4:
-            out_shape4 = out_shape * 8
-
-    scratch.layer1_rn = nn.Conv2d(
-        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
-    )
-    scratch.layer2_rn = nn.Conv2d(
-        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
-    )
-    scratch.layer3_rn = nn.Conv2d(
-        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
-    )
-    if len(in_shape) >= 4:
-        scratch.layer4_rn = nn.Conv2d(
-            in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
-        )
-
-    return scratch
-
-
-class ResidualConvUnit(nn.Module):
-    """Residual convolution module."""
-
-    def __init__(self, features, activation, bn):
-        """Init.
-
-        Args:
-            features (int): number of features
-        """
-        super().__init__()
-
-        self.bn = bn
-
-        self.groups = 1
-
-        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
-
-        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
-
-        if self.bn:
-            self.bn1 = nn.BatchNorm2d(features)
-            self.bn2 = nn.BatchNorm2d(features)
-
-        self.activation = activation
-
-        self.skip_add = nn.quantized.FloatFunctional()
-
-    def forward(self, x):
-        """Forward pass.
-
-        Args:
-            x (tensor): input
-
-        Returns:
-            tensor: output
-        """
-
-        out = self.activation(x)
-        out = self.conv1(out)
-        if self.bn:
-            out = self.bn1(out)
-
-        out = self.activation(out)
-        out = self.conv2(out)
-        if self.bn:
-            out = self.bn2(out)
-
-        if self.groups > 1:
-            out = self.conv_merge(out)
-
-        return self.skip_add.add(out, x)
-
-
-class FeatureFusionBlock(nn.Module):
-    """Feature fusion block."""
-
-    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
-        """Init.
-
-        Args:
-            features (int): number of features
-        """
-        super(FeatureFusionBlock, self).__init__()
-
-        self.deconv = deconv
-        self.align_corners = align_corners
-
-        self.groups = 1
-
-        self.expand = expand
-        out_features = features
-        if self.expand:
-            out_features = features // 2
-
-        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
-
-        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
-        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
-
-        self.skip_add = nn.quantized.FloatFunctional()
-
-        self.size = size
-
-    def forward(self, *xs, size=None):
-        """Forward pass.
-
-        Returns:
-            tensor: output
-        """
-        output = xs[0]
-
-        if len(xs) == 2:
-            res = self.resConfUnit1(xs[1])
-            output = self.skip_add.add(output, res)
-
-        output = self.resConfUnit2(output)
-
-        if (size is None) and (self.size is None):
-            modifier = {"scale_factor": 2}
-        elif size is None:
-            modifier = {"size": self.size}
-        else:
-            modifier = {"size": size}
-
-        output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
-
-        output = self.out_conv(output)
-
-        return output
@@ -1,183 +0,0 @@
-from pathlib import Path
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from invokeai.backend.image_util.depth_anything.model.blocks import FeatureFusionBlock, _make_scratch
-
-torchhub_path = Path(__file__).parent.parent / "torchhub"
-
-
-def _make_fusion_block(features, use_bn, size=None):
-    return FeatureFusionBlock(
-        features,
-        nn.ReLU(False),
-        deconv=False,
-        bn=use_bn,
-        expand=False,
-        align_corners=True,
-        size=size,
-    )
-
-
-class DPTHead(nn.Module):
-    def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
-        super(DPTHead, self).__init__()
-
-        self.nclass = nclass
-        self.use_clstoken = use_clstoken
-
-        self.projects = nn.ModuleList(
-            [
-                nn.Conv2d(
-                    in_channels=in_channels,
-                    out_channels=out_channel,
-                    kernel_size=1,
-                    stride=1,
-                    padding=0,
-                )
-                for out_channel in out_channels
-            ]
-        )
-
-        self.resize_layers = nn.ModuleList(
-            [
-                nn.ConvTranspose2d(
-                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
-                ),
-                nn.ConvTranspose2d(
-                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
-                ),
-                nn.Identity(),
-                nn.Conv2d(
-                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
-                ),
-            ]
-        )
-
-        if use_clstoken:
-            self.readout_projects = nn.ModuleList()
-            for _ in range(len(self.projects)):
-                self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
-
-        self.scratch = _make_scratch(
-            out_channels,
-            features,
-            groups=1,
-            expand=False,
-        )
-
-        self.scratch.stem_transpose = None
-
-        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
-        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
-        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
-        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
-
-        head_features_1 = features
-        head_features_2 = 32
-
-        if nclass > 1:
-            self.scratch.output_conv = nn.Sequential(
-                nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
-                nn.ReLU(True),
-                nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
-            )
-        else:
-            self.scratch.output_conv1 = nn.Conv2d(
-                head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
-            )
-
-            self.scratch.output_conv2 = nn.Sequential(
-                nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
-                nn.ReLU(True),
-                nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
-                nn.ReLU(True),
-                nn.Identity(),
-            )
-
-    def forward(self, out_features, patch_h, patch_w):
-        out = []
-        for i, x in enumerate(out_features):
-            if self.use_clstoken:
-                x, cls_token = x[0], x[1]
-                readout = cls_token.unsqueeze(1).expand_as(x)
-                x = self.readout_projects[i](torch.cat((x, readout), -1))
-            else:
-                x = x[0]
-
-            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
-
-            x = self.projects[i](x)
-            x = self.resize_layers[i](x)
-
-            out.append(x)
-
-        layer_1, layer_2, layer_3, layer_4 = out
-
-        layer_1_rn = self.scratch.layer1_rn(layer_1)
-        layer_2_rn = self.scratch.layer2_rn(layer_2)
-        layer_3_rn = self.scratch.layer3_rn(layer_3)
-        layer_4_rn = self.scratch.layer4_rn(layer_4)
-
-        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
-        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
-        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
-        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
-
-        out = self.scratch.output_conv1(path_1)
-        out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
-        out = self.scratch.output_conv2(out)
-
-        return out
-
-
-class DPT_DINOv2(nn.Module):
-    def __init__(
-        self,
-        features,
-        out_channels,
-        encoder="vitl",
-        use_bn=False,
-        use_clstoken=False,
-    ):
-        super(DPT_DINOv2, self).__init__()
-
-        assert encoder in ["vits", "vitb", "vitl"]
-
-        # # in case the Internet connection is not stable, please load the DINOv2 locally
-        # if use_local:
-        #     self.pretrained = torch.hub.load(
-        #         torchhub_path / "facebookresearch_dinov2_main",
-        #         "dinov2_{:}14".format(encoder),
-        #         source="local",
-        #         pretrained=False,
-        #     )
-        # else:
-        #     self.pretrained = torch.hub.load(
-        #         "facebookresearch/dinov2",
-        #         "dinov2_{:}14".format(encoder),
-        #     )
-
-        self.pretrained = torch.hub.load(
-            "facebookresearch/dinov2",
-            "dinov2_{:}14".format(encoder),
-        )
-
-        dim = self.pretrained.blocks[0].attn.qkv.in_features
-
-        self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
-
-    def forward(self, x):
-        h, w = x.shape[-2:]
-
-        features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
-
-        patch_h, patch_w = h // 14, w // 14
-
-        depth = self.depth_head(features, patch_h, patch_w)
-        depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
-        depth = F.relu(depth)
-
-        return depth.squeeze(1)
@@ -1,227 +0,0 @@
import math

import cv2
import numpy as np
import torch
import torch.nn.functional as F


def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Resize the sample to ensure the given size. Keeps aspect ratio.

    Args:
        sample (dict): sample
        size (tuple): image size

    Returns:
        tuple: new size
    """
    shape = list(sample["disparity"].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize
    sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)

    sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
    sample["mask"] = cv2.resize(
        sample["mask"].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = sample["mask"].astype(bool)

    return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller
                    than given size.)
                "minimal": Scale as least as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as least as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(f"resize_method {self.__resize_method} not implemented")

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)

            if "semseg_mask" in sample:
                # sample["semseg_mask"] = cv2.resize(
                #     sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
                # )
                sample["semseg_mask"] = F.interpolate(
                    torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode="nearest"
                ).numpy()[0, 0]

            if "mask" in sample:
                sample["mask"] = cv2.resize(
                    sample["mask"].astype(np.float32),
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )
                # sample["mask"] = sample["mask"].astype(bool)

        # print(sample['image'].shape, sample['depth'].shape)
        return sample


class NormalizeImage(object):
    """Normalize image by given mean and std."""

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input."""

    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(image).astype(np.float32)

        if "mask" in sample:
            sample["mask"] = sample["mask"].astype(np.float32)
            sample["mask"] = np.ascontiguousarray(sample["mask"])

        if "depth" in sample:
            depth = sample["depth"].astype(np.float32)
            sample["depth"] = np.ascontiguousarray(depth)

        if "semseg_mask" in sample:
            sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
            sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])

        return sample
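
For reference, the deleted Resize transform's size selection can be illustrated with a small sketch (an illustrative example only; it assumes the Resize class above is in scope and the numbers are made up):

# Illustrative sketch of Resize.get_size() with "lower_bound" and ensure_multiple_of=14.
resize = Resize(width=518, height=518, keep_aspect_ratio=True, ensure_multiple_of=14, resize_method="lower_bound")

# For a 1280x720 input, the larger scale factor (518/720) is used so that both sides
# end up >= 518, and each side is then rounded to a multiple of 14: (924, 518).
print(resize.get_size(width=1280, height=720))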
@@ -0,0 +1,22 @@
from pydantic import BaseModel, ConfigDict


class BoundingBox(BaseModel):
    """Bounding box helper class."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int


class DetectionResult(BaseModel):
    """Detection result from Grounding DINO."""

    score: float
    label: str
    box: BoundingBox
    model_config = ConfigDict(
        # Allow arbitrary types for mask, since it will be a numpy array.
        arbitrary_types_allowed=True
    )
@@ -0,0 +1,37 @@
from typing import Optional

import torch
from PIL import Image
from transformers.pipelines import ZeroShotObjectDetectionPipeline

from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
from invokeai.backend.raw_model import RawModel


class GroundingDinoPipeline(RawModel):
    """A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
    management system.
    """

    def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
        self._pipeline = pipeline

    def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
        results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
        assert results is not None
        results = [DetectionResult.model_validate(result) for result in results]
        return results

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        # HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
        # CUDA.
        if device is not None and device.type not in {"cpu", "cuda"}:
            device = None
        self._pipeline.model.to(device=device, dtype=dtype)
        self._pipeline.device = self._pipeline.model.device

    def calc_size(self) -> int:
        # HACK(ryand): Fix the circular import issue.
        from invokeai.backend.model_manager.load.model_util import calc_module_size

        return calc_module_size(self._pipeline.model)
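
A minimal usage sketch of the new wrapper (not part of the change; the checkpoint name and the direct transformers pipeline() call are assumptions for illustration, since InvokeAI loads the pipeline through its model manager):

from PIL import Image
from transformers import pipeline

from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline

# Assumed checkpoint, used here only for illustration.
detector = pipeline(task="zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny")
wrapper = GroundingDinoPipeline(detector)

image = Image.open("example.jpg").convert("RGB")
detections = wrapper.detect(image, candidate_labels=["a cat."], threshold=0.3)
for d in detections:
    print(d.label, round(d.score, 3), (d.box.xmin, d.box.ymin, d.box.xmax, d.box.ymax))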
@@ -0,0 +1,50 @@
# This file contains utilities for Grounded-SAM mask refinement based on:
# https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb

import cv2
import numpy as np
import numpy.typing as npt


def mask_to_polygon(mask: npt.NDArray[np.uint8]) -> list[tuple[int, int]]:
    """Convert a binary mask to a polygon.

    Returns:
        list[list[int]]: List of (x, y) coordinates representing the vertices of the polygon.
    """
    # Find contours in the binary mask.
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Find the contour with the largest area.
    largest_contour = max(contours, key=cv2.contourArea)

    # Extract the vertices of the contour.
    polygon = largest_contour.reshape(-1, 2).tolist()

    return polygon


def polygon_to_mask(
    polygon: list[tuple[int, int]], image_shape: tuple[int, int], fill_value: int = 1
) -> npt.NDArray[np.uint8]:
    """Convert a polygon to a segmentation mask.

    Args:
        polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
        image_shape (tuple): Shape of the image (height, width) for the mask.
        fill_value (int): Value to fill the polygon with.

    Returns:
        np.ndarray: Segmentation mask with the polygon filled (with value 255).
    """
    # Create an empty mask.
    mask = np.zeros(image_shape, dtype=np.uint8)

    # Convert polygon to an array of points.
    pts = np.array(polygon, dtype=np.int32)

    # Fill the polygon with white color (255).
    cv2.fillPoly(mask, [pts], color=(fill_value,))

    return mask
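
These two helpers are typically combined to clean up a SAM mask by keeping only its largest region; a minimal sketch (assuming the two functions above are in scope and the mask has at least one foreground pixel):

import numpy as np

def keep_largest_region(mask: np.ndarray) -> np.ndarray:
    # mask: (H, W) binary uint8 array; round-tripping through a polygon drops small islands and holes.
    polygon = mask_to_polygon(mask)
    return polygon_to_mask(polygon, image_shape=mask.shape, fill_value=1)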
@@ -0,0 +1,53 @@
from typing import Optional

import torch
from PIL import Image
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor

from invokeai.backend.raw_model import RawModel


class SegmentAnythingPipeline(RawModel):
    """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""

    def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
        self._sam_model = sam_model
        self._sam_processor = sam_processor

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        # HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
        if device is not None and device.type not in {"cpu", "cuda"}:
            device = None
        self._sam_model.to(device=device, dtype=dtype)

    def calc_size(self) -> int:
        # HACK(ryand): Fix the circular import issue.
        from invokeai.backend.model_manager.load.model_util import calc_module_size

        return calc_module_size(self._sam_model)

    def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor:
        """Run the SAM model.

        Args:
            image (Image.Image): The image to segment.
            bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
                [xmin, ymin, xmax, ymax].

        Returns:
            torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
        """
        # Add batch dimension of 1 to the bounding boxes.
        boxes = [bounding_boxes]
        inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
        outputs = self._sam_model(**inputs)
        masks = self._sam_processor.post_process_masks(
            masks=outputs.pred_masks,
            original_sizes=inputs.original_sizes,
            reshaped_input_sizes=inputs.reshaped_input_sizes,
        )

        # There should be only one batch.
        assert len(masks) == 1
        return masks[0]
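
A minimal usage sketch of the SAM wrapper (the checkpoint id is an assumption; InvokeAI normally loads and caches the model through its model manager):

from PIL import Image
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor

from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline

# Assumed checkpoint, used here only for illustration.
sam = SegmentAnythingPipeline(
    sam_model=SamModel.from_pretrained("facebook/sam-vit-base"),
    sam_processor=SamProcessor.from_pretrained("facebook/sam-vit-base"),
)

image = Image.open("example.jpg").convert("RGB")
masks = sam.segment(image, bounding_boxes=[[100, 100, 400, 400]])  # one [xmin, ymin, xmax, ymax] prompt
print(masks.shape, masks.dtype)  # [num_boxes, num_mask_proposals, H, W], torch.bool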
@@ -3,12 +3,13 @@
 import bisect
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import torch
 from safetensors.torch import load_file
 from typing_extensions import Self
 
+import invokeai.backend.util.logging as logger
 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.raw_model import RawModel
@@ -46,9 +47,19 @@ class LoRALayerBase:
         self.rank = None  # set in layer implementation
         self.layer_key = layer_key
 
-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()
 
+    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+        return self.bias
+
+    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
+        params = {"weight": self.get_weight(orig_module.weight)}
+        bias = self.get_bias(orig_module.bias)
+        if bias is not None:
+            params["bias"] = bias
+        return params
+
     def calc_size(self) -> int:
         model_size = 0
         for val in [self.bias]:
@@ -60,6 +71,17 @@ class LoRALayerBase:
         if self.bias is not None:
             self.bias = self.bias.to(device=device, dtype=dtype)
 
+    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
+        """Log a warning if values contains unhandled keys."""
+        # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
+        # `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
+        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
+        unknown_keys = set(values.keys()) - all_known_keys
+        if unknown_keys:
+            logger.warning(
+                f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
+            )
+
 
 # TODO: find and debug lora/locon with bias
 class LoRALayer(LoRALayerBase):
@@ -76,14 +98,19 @@ class LoRALayer(LoRALayerBase):
 
         self.up = values["lora_up.weight"]
         self.down = values["lora_down.weight"]
-        if "lora_mid.weight" in values:
-            self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
-        else:
-            self.mid = None
+        self.mid = values.get("lora_mid.weight", None)
 
         self.rank = self.down.shape[0]
+        self.check_keys(
+            values,
+            {
+                "lora_up.weight",
+                "lora_down.weight",
+                "lora_mid.weight",
+            },
+        )
 
-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
             down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -125,20 +152,23 @@ class LoHALayer(LoRALayerBase):
         self.w1_b = values["hada_w1_b"]
         self.w2_a = values["hada_w2_a"]
         self.w2_b = values["hada_w2_b"]
-
-        if "hada_t1" in values:
-            self.t1: Optional[torch.Tensor] = values["hada_t1"]
-        else:
-            self.t1 = None
-
-        if "hada_t2" in values:
-            self.t2: Optional[torch.Tensor] = values["hada_t2"]
-        else:
-            self.t2 = None
+        self.t1 = values.get("hada_t1", None)
+        self.t2 = values.get("hada_t2", None)
 
         self.rank = self.w1_b.shape[0]
+        self.check_keys(
+            values,
+            {
+                "hada_w1_a",
+                "hada_w1_b",
+                "hada_w2_a",
+                "hada_w2_b",
+                "hada_t1",
+                "hada_t2",
+            },
+        )
 
-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.t1 is None:
             weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
 
@@ -186,37 +216,45 @@ class LoKRLayer(LoRALayerBase):
     ):
         super().__init__(layer_key, values)
 
-        if "lokr_w1" in values:
-            self.w1: Optional[torch.Tensor] = values["lokr_w1"]
-            self.w1_a = None
-            self.w1_b = None
-        else:
-            self.w1 = None
+        self.w1 = values.get("lokr_w1", None)
+        if self.w1 is None:
             self.w1_a = values["lokr_w1_a"]
             self.w1_b = values["lokr_w1_b"]
+        else:
+            self.w1_b = None
+            self.w1_a = None
 
-        if "lokr_w2" in values:
-            self.w2: Optional[torch.Tensor] = values["lokr_w2"]
-            self.w2_a = None
-            self.w2_b = None
-        else:
-            self.w2 = None
+        self.w2 = values.get("lokr_w2", None)
+        if self.w2 is None:
             self.w2_a = values["lokr_w2_a"]
             self.w2_b = values["lokr_w2_b"]
+        else:
+            self.w2_a = None
+            self.w2_b = None
 
-        if "lokr_t2" in values:
-            self.t2: Optional[torch.Tensor] = values["lokr_t2"]
-        else:
-            self.t2 = None
+        self.t2 = values.get("lokr_t2", None)
 
-        if "lokr_w1_b" in values:
-            self.rank = values["lokr_w1_b"].shape[0]
-        elif "lokr_w2_b" in values:
-            self.rank = values["lokr_w2_b"].shape[0]
+        if self.w1_b is not None:
+            self.rank = self.w1_b.shape[0]
+        elif self.w2_b is not None:
+            self.rank = self.w2_b.shape[0]
         else:
             self.rank = None  # unscaled
 
-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+        self.check_keys(
+            values,
+            {
+                "lokr_w1",
+                "lokr_w1_a",
+                "lokr_w1_b",
+                "lokr_w2",
+                "lokr_w2_a",
+                "lokr_w2_b",
+                "lokr_t2",
+            },
+        )
+
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         w1: Optional[torch.Tensor] = self.w1
         if w1 is None:
             assert self.w1_a is not None
@@ -272,7 +310,9 @@ class LoKRLayer(LoRALayerBase):
 
 
 class FullLayer(LoRALayerBase):
+    # bias handled in LoRALayerBase(calc_size, to)
     # weight: torch.Tensor
+    # bias: Optional[torch.Tensor]
 
     def __init__(
         self,
@@ -282,15 +322,12 @@ class FullLayer(LoRALayerBase):
         super().__init__(layer_key, values)
 
         self.weight = values["diff"]
-
-        if len(values.keys()) > 1:
-            _keys = list(values.keys())
-            _keys.remove("diff")
-            raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
+        self.bias = values.get("diff_b", None)
 
         self.rank = None  # unscaled
+        self.check_keys(values, {"diff", "diff_b"})
 
-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight
 
     def calc_size(self) -> int:
@@ -319,8 +356,9 @@ class IA3Layer(LoRALayerBase):
         self.on_input = values["on_input"]
 
         self.rank = None  # unscaled
+        self.check_keys(values, {"weight", "on_input"})
 
-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         weight = self.weight
         if not self.on_input:
             weight = weight.reshape(-1, 1)
@@ -340,7 +378,39 @@ class IA3Layer(LoRALayerBase):
         self.on_input = self.on_input.to(device=device, dtype=dtype)
 
 
-AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
+class NormLayer(LoRALayerBase):
+    # bias handled in LoRALayerBase(calc_size, to)
+    # weight: torch.Tensor
+    # bias: Optional[torch.Tensor]
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: Dict[str, torch.Tensor],
+    ):
+        super().__init__(layer_key, values)
+
+        self.weight = values["w_norm"]
+        self.bias = values.get("b_norm", None)
+
+        self.rank = None  # unscaled
+        self.check_keys(values, {"w_norm", "b_norm"})
+
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
+        return self.weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        model_size += self.weight.nelement() * self.weight.element_size()
+        return model_size
+
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+        super().to(device=device, dtype=dtype)
+
+        self.weight = self.weight.to(device=device, dtype=dtype)
+
+
+AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
 
 
 class LoRAModelRaw(RawModel):  # (torch.nn.Module):
@@ -458,16 +528,19 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
             state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
 
         for layer_key, values in state_dict.items():
+            # Detect layers according to LyCORIS detection logic(`weight_list_det`)
+            # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
+
             # lora and locon
-            if "lora_down.weight" in values:
+            if "lora_up.weight" in values:
                 layer: AnyLoRALayer = LoRALayer(layer_key, values)
 
             # loha
-            elif "hada_w1_b" in values:
+            elif "hada_w1_a" in values:
                 layer = LoHALayer(layer_key, values)
 
             # lokr
-            elif "lokr_w1_b" in values or "lokr_w1" in values:
+            elif "lokr_w1" in values or "lokr_w1_a" in values:
                 layer = LoKRLayer(layer_key, values)
 
             # diff
@@ -475,9 +548,13 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
                 layer = FullLayer(layer_key, values)
 
             # ia3
-            elif "weight" in values and "on_input" in values:
+            elif "on_input" in values:
                 layer = IA3Layer(layer_key, values)
 
+            # norms
+            elif "w_norm" in values:
+                layer = NormLayer(layer_key, values)
+
             else:
                 print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
                 raise Exception("Unknown lora format!")
|
|||||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
|
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||||
|
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||||
|
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_manager.config import AnyModel
|
from invokeai.backend.model_manager.config import AnyModel
|
||||||
@ -34,7 +37,18 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
|||||||
elif isinstance(model, CLIPTokenizer):
|
elif isinstance(model, CLIPTokenizer):
|
||||||
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
|
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
|
||||||
return 0
|
return 0
|
||||||
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
|
elif isinstance(
|
||||||
|
model,
|
||||||
|
(
|
||||||
|
TextualInversionModelRaw,
|
||||||
|
IPAdapter,
|
||||||
|
LoRAModelRaw,
|
||||||
|
SpandrelImageToImageModel,
|
||||||
|
GroundingDinoPipeline,
|
||||||
|
SegmentAnythingPipeline,
|
||||||
|
DepthAnythingPipeline,
|
||||||
|
),
|
||||||
|
):
|
||||||
return model.calc_size()
|
return model.calc_size()
|
||||||
else:
|
else:
|
||||||
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
|
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
|
||||||
|
@@ -17,8 +17,9 @@ from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
-from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 """
 loras = [
@@ -85,13 +86,13 @@ class ModelPatcher:
         cls,
         unet: UNet2DConditionModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
         with cls.apply_lora(
             unet,
             loras=loras,
             prefix="lora_unet_",
-            model_state_dict=model_state_dict,
+            cached_weights=cached_weights,
         ):
             yield
 
@@ -101,9 +102,9 @@ class ModelPatcher:
         cls,
         text_encoder: CLIPTextModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
-        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
+        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
             yield
 
     @classmethod
@@ -113,7 +114,7 @@ class ModelPatcher:
         model: AnyModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         prefix: str,
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
         """
         Apply one or more LoRAs to a model.
@@ -121,66 +122,26 @@ class ModelPatcher:
         :param model: The model to patch.
         :param loras: An iterator that returns the LoRA to patch in and its patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
-        original_weights = {}
+        original_weights = OriginalWeightsStorage(cached_weights)
         try:
-            with torch.no_grad():
-                for lora, lora_weight in loras:
-                    # assert lora.device.type == "cpu"
-                    for layer_key, layer in lora.layers.items():
-                        if not layer_key.startswith(prefix):
-                            continue
-
-                        # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
-                        # should be improved in the following ways:
-                        # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
-                        #    LoRA model is applied.
-                        # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
-                        #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
-                        #    weights to have valid keys.
-                        assert isinstance(model, torch.nn.Module)
-                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-
-                        # All of the LoRA weight calculations will be done on the same device as the module weight.
-                        # (Performance will be best if this is a CUDA device.)
-                        device = module.weight.device
-                        dtype = module.weight.dtype
-
-                        if module_key not in original_weights:
-                            if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
-                                original_weights[module_key] = model_state_dict[module_key + ".weight"]
-                            else:
-                                original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
-
-                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-
-                        # We intentionally move to the target device first, then cast. Experimentally, this was found to
-                        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-                        # same thing in a single call to '.to(...)'.
-                        layer.to(device=device)
-                        layer.to(dtype=torch.float32)
-                        # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
-                        # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-                        layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                        layer.to(device=TorchDevice.CPU_DEVICE)
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        if module.weight.shape != layer_weight.shape:
-                            # TODO: debug on lycoris
-                            assert hasattr(layer_weight, "reshape")
-                            layer_weight = layer_weight.reshape(module.weight.shape)
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        module.weight += layer_weight.to(dtype=dtype)
-
-            yield  # wait for context manager exit
+            for lora_model, lora_weight in loras:
+                LoRAExt.patch_model(
+                    model=model,
+                    prefix=prefix,
+                    lora=lora_model,
+                    lora_weight=lora_weight,
+                    original_weights=original_weights,
+                )
+                del lora_model
+
+            yield
 
         finally:
-            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
-                for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight)
+                for param_key, weight in original_weights.get_changed_weights():
+                    model.get_parameter(param_key).copy_(weight)
 
     @classmethod
     @contextmanager
@@ -2,14 +2,14 @@ from __future__ import annotations
 
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List
 
-import torch
 from diffusers import UNet2DConditionModel
 
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 
 @dataclass
@@ -56,5 +56,17 @@ class ExtensionBase:
         yield None
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
-        yield None
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
+        """A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
+        diffusion process. Weight unpatching is handled upstream and is achieved by saving the unchanged weights via
+        the `original_weights.save` function. Note that this enables some performance optimization by avoiding
+        redundant operations. All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should
+        be unpatched by this context manager.
+
+        Args:
+            unet (UNet2DConditionModel): The UNet model on the execution device to patch.
+            original_weights (OriginalWeightsStorage): A storage with a CPU copy of the model's original weights, for
+                unpatching purposes. Extensions should save the tensors they modify to this storage; they can also
+                read the original weight values from it.
+        """
+        yield
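
A minimal sketch of an extension that follows this contract (a hypothetical example, not part of the diff): it saves the parameter it touches to original_weights so that the manager can restore it after denoising.

from contextlib import contextmanager

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class ScaleConvInExt(ExtensionBase):
    """Hypothetical example: halves the UNet's conv_in weight for the duration of denoising."""

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        param_key = "conv_in.weight"
        param = unet.get_parameter(param_key)

        # Save the unchanged weight first; the manager copies it back after denoising.
        original_weights.save(param_key, param)
        with torch.no_grad():
            param *= 0.5

        yield  # weight unpatching is handled upstream via original_weights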
@@ -1,15 +1,15 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING
 
-import torch
 from diffusers import UNet2DConditionModel
 
 from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
 
 if TYPE_CHECKING:
     from invokeai.app.shared.models import FreeUConfig
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 
 class FreeUExt(ExtensionBase):
@@ -21,7 +21,7 @@ class FreeUExt(ExtensionBase):
         self._freeu_config = freeu_config
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         unet.enable_freeu(
             b1=self._freeu_config.b1,
             b2=self._freeu_config.b2,
invokeai/backend/stable_diffusion/extensions/lora.py (new file, 137 lines)
@@ -0,0 +1,137 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Tuple

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.devices import TorchDevice

if TYPE_CHECKING:
    from invokeai.app.invocations.model import ModelIdentifierField
    from invokeai.app.services.shared.invocation_context import InvocationContext
    from invokeai.backend.lora import LoRAModelRaw
    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class LoRAExt(ExtensionBase):
    def __init__(
        self,
        node_context: InvocationContext,
        model_id: ModelIdentifierField,
        weight: float,
    ):
        super().__init__()
        self._node_context = node_context
        self._model_id = model_id
        self._weight = weight

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        lora_model = self._node_context.models.load(self._model_id).model
        self.patch_model(
            model=unet,
            prefix="lora_unet_",
            lora=lora_model,
            lora_weight=self._weight,
            original_weights=original_weights,
        )
        del lora_model

        yield

    @classmethod
    @torch.no_grad()
    def patch_model(
        cls,
        model: torch.nn.Module,
        prefix: str,
        lora: LoRAModelRaw,
        lora_weight: float,
        original_weights: OriginalWeightsStorage,
    ):
        """
        Apply a LoRA to a model.
        :param model: The model to patch.
        :param lora: LoRA model to patch in.
        :param lora_weight: LoRA patch weight.
        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
        :param original_weights: Storage for the original weights of the parameters that the LoRA patches, used for unpatching.
        """

        if lora_weight == 0:
            return

        # assert lora.device.type == "cpu"
        for layer_key, layer in lora.layers.items():
            if not layer_key.startswith(prefix):
                continue

            # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
            # should be improved in the following ways:
            # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
            #    LoRA model is applied.
            # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
            #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
            #    weights to have valid keys.
            assert isinstance(model, torch.nn.Module)
            module_key, module = cls._resolve_lora_key(model, layer_key, prefix)

            # All of the LoRA weight calculations will be done on the same device as the module weight.
            # (Performance will be best if this is a CUDA device.)
            device = module.weight.device
            dtype = module.weight.dtype

            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0

            # We intentionally move to the target device first, then cast. Experimentally, this was found to
            # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
            # same thing in a single call to '.to(...)'.
            layer.to(device=device)
            layer.to(dtype=torch.float32)

            # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
            # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
            for param_name, lora_param_weight in layer.get_parameters(module).items():
                param_key = module_key + "." + param_name
                module_param = module.get_parameter(param_name)

                # save original weight
                original_weights.save(param_key, module_param)

                if module_param.shape != lora_param_weight.shape:
                    # TODO: debug on lycoris
                    lora_param_weight = lora_param_weight.reshape(module_param.shape)

                lora_param_weight *= lora_weight * layer_scale
                module_param += lora_param_weight.to(dtype=dtype)

            layer.to(device=TorchDevice.CPU_DEVICE)

    @staticmethod
    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
        assert "." not in lora_key

        if not lora_key.startswith(prefix):
            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")

        module = model
        module_key = ""
        key_parts = lora_key[len(prefix) :].split("_")

        submodule_name = key_parts.pop(0)

        while len(key_parts) > 0:
            try:
                module = module.get_submodule(submodule_name)
                module_key += "." + submodule_name
                submodule_name = key_parts.pop(0)
            except Exception:
                submodule_name += "_" + key_parts.pop(0)

        module = module.get_submodule(submodule_name)
        module_key = (module_key + "." + submodule_name).lstrip(".")

        return (module_key, module)
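
The greedy key resolution above merges underscore-separated segments until they match a real submodule name; a short sketch of the resulting mapping (instantiating a default UNet only for illustration, which is memory-heavy):

from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt

unet = UNet2DConditionModel()  # default config, used here only to resolve module names
module_key, module = LoRAExt._resolve_lora_key(unet, "lora_unet_down_blocks_0_attentions_0_proj_in", "lora_unet_")
print(module_key)  # down_blocks.0.attentions.0.proj_in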
@@ -7,6 +7,7 @@ import torch
 from diffusers import UNet2DConditionModel
 
 from invokeai.app.services.session_processor.session_processor_common import CanceledException
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
@@ -67,9 +68,15 @@ class ExtensionsManager:
         if self._is_canceled and self._is_canceled():
             raise CanceledException
 
-        # TODO: create weight patch logic in PR with extension which uses it
-        with ExitStack() as exit_stack:
-            for ext in self._extensions:
-                exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
+        original_weights = OriginalWeightsStorage(cached_weights)
+        try:
+            with ExitStack() as exit_stack:
+                for ext in self._extensions:
+                    exit_stack.enter_context(ext.patch_unet(unet, original_weights))
 
-            yield None
+                yield None
+
+        finally:
+            with torch.no_grad():
+                for param_key, weight in original_weights.get_changed_weights():
+                    unet.get_parameter(param_key).copy_(weight)
invokeai/backend/util/original_weights_storage.py (new file, 39 lines)
@@ -0,0 +1,39 @@
from __future__ import annotations

from typing import Dict, Iterator, Optional, Tuple

import torch

from invokeai.backend.util.devices import TorchDevice


class OriginalWeightsStorage:
    """A class for tracking the original weights of a model for patch/unpatch operations."""

    def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
        # The original weights of the model.
        self._weights: dict[str, torch.Tensor] = {}
        # The keys of the weights that have been changed (via `save()`) during the lifetime of this instance.
        self._changed_weights: set[str] = set()
        if cached_weights:
            self._weights.update(cached_weights)

    def save(self, key: str, weight: torch.Tensor, copy: bool = True):
        self._changed_weights.add(key)
        if key in self._weights:
            return

        self._weights[key] = weight.detach().to(device=TorchDevice.CPU_DEVICE, copy=copy)

    def get(self, key: str, copy: bool = False) -> Optional[torch.Tensor]:
        weight = self._weights.get(key, None)
        if weight is not None and copy:
            weight = weight.clone()
        return weight

    def contains(self, key: str) -> bool:
        return key in self._weights

    def get_changed_weights(self) -> Iterator[Tuple[str, torch.Tensor]]:
        for key in self._changed_weights:
            yield key, self._weights[key]
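
A minimal usage sketch of the new storage class (not part of the diff): save a parameter before patching it, then restore only what changed.

import torch

from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

model = torch.nn.Linear(4, 4)
storage = OriginalWeightsStorage()

# Patch: record the untouched weight (detached to CPU), then modify it in place.
storage.save("weight", model.weight)
with torch.no_grad():
    model.weight += 0.1

# Unpatch: copy back only the tensors that were registered as changed.
with torch.no_grad():
    for key, original in storage.get_changed_weights():
        model.get_parameter(key).copy_(original)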
@ -91,7 +91,8 @@
|
|||||||
"viewingDesc": "Bilder in großer Galerie ansehen",
|
"viewingDesc": "Bilder in großer Galerie ansehen",
|
||||||
"tab": "Tabulator",
|
"tab": "Tabulator",
|
||||||
"enabled": "Aktiviert",
|
"enabled": "Aktiviert",
|
||||||
"disabled": "Ausgeschaltet"
|
"disabled": "Ausgeschaltet",
|
||||||
|
"dontShowMeThese": "Zeig mir diese nicht"
|
||||||
},
|
},
|
||||||
"gallery": {
|
"gallery": {
|
||||||
"galleryImageSize": "Bildgröße",
|
"galleryImageSize": "Bildgröße",
|
||||||
@ -106,7 +107,6 @@
|
|||||||
"download": "Runterladen",
|
"download": "Runterladen",
|
||||||
"setCurrentImage": "Setze aktuelle Bild",
|
"setCurrentImage": "Setze aktuelle Bild",
|
||||||
"featuresWillReset": "Wenn Sie dieses Bild löschen, werden diese Funktionen sofort zurückgesetzt.",
|
"featuresWillReset": "Wenn Sie dieses Bild löschen, werden diese Funktionen sofort zurückgesetzt.",
|
||||||
"deleteImageBin": "Gelöschte Bilder werden an den Papierkorb Ihres Betriebssystems gesendet.",
|
|
||||||
"unableToLoad": "Galerie kann nicht geladen werden",
|
"unableToLoad": "Galerie kann nicht geladen werden",
|
||||||
"downloadSelection": "Auswahl herunterladen",
|
"downloadSelection": "Auswahl herunterladen",
|
||||||
"currentlyInUse": "Dieses Bild wird derzeit in den folgenden Funktionen verwendet:",
|
"currentlyInUse": "Dieses Bild wird derzeit in den folgenden Funktionen verwendet:",
|
||||||
@ -628,7 +628,10 @@
|
|||||||
"private": "Private Ordner",
|
"private": "Private Ordner",
|
||||||
"shared": "Geteilte Ordner",
|
"shared": "Geteilte Ordner",
|
||||||
"archiveBoard": "Ordner archivieren",
|
"archiveBoard": "Ordner archivieren",
|
||||||
"archived": "Archiviert"
|
"archived": "Archiviert",
|
||||||
|
"noBoards": "Kein {boardType}} Ordner",
|
||||||
|
"hideBoards": "Ordner verstecken",
|
||||||
|
"viewBoards": "Ordner ansehen"
|
||||||
},
|
},
|
||||||
"controlnet": {
|
"controlnet": {
|
||||||
"showAdvanced": "Zeige Erweitert",
|
"showAdvanced": "Zeige Erweitert",
|
||||||
@ -943,6 +946,21 @@
|
|||||||
"paragraphs": [
|
"paragraphs": [
|
||||||
"Reduziert das Ausgangsbild auf die Breite und Höhe des Ausgangsbildes. Empfohlen zu aktivieren."
|
"Reduziert das Ausgangsbild auf die Breite und Höhe des Ausgangsbildes. Empfohlen zu aktivieren."
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
"structure": {
|
||||||
|
"paragraphs": [
|
||||||
|
"Die Struktur steuert, wie genau sich das Ausgabebild an das Layout des Originals hält. Eine niedrige Struktur erlaubt größere Änderungen, während eine hohe Struktur die ursprüngliche Komposition und das Layout strikter beibehält."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"creativity": {
|
||||||
|
"paragraphs": [
|
||||||
|
"Die Kreativität bestimmt den Grad der Freiheit, die dem Modell beim Hinzufügen von Details gewährt wird. Eine niedrige Kreativität hält sich eng an das Originalbild, während eine hohe Kreativität mehr Veränderungen zulässt. Bei der Verwendung eines Prompts erhöht eine hohe Kreativität den Einfluss des Prompts."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"scale": {
|
||||||
|
"paragraphs": [
|
||||||
|
"Die Skalierung steuert die Größe des Ausgabebildes und basiert auf einem Vielfachen der Auflösung des Originalbildes. So würde z. B. eine 2-fache Hochskalierung eines 1024x1024px Bildes eine 2048x2048px große Ausgabe erzeugen."
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"invocationCache": {
|
"invocationCache": {
|
||||||
|
@ -200,6 +200,7 @@
|
|||||||
"delete": "Delete",
|
"delete": "Delete",
|
||||||
"depthAnything": "Depth Anything",
|
"depthAnything": "Depth Anything",
|
||||||
"depthAnythingDescription": "Depth map generation using the Depth Anything technique",
|
"depthAnythingDescription": "Depth map generation using the Depth Anything technique",
|
||||||
|
"depthAnythingSmallV2": "Small V2",
|
||||||
"depthMidas": "Depth (Midas)",
|
"depthMidas": "Depth (Midas)",
|
||||||
"depthMidasDescription": "Depth map generation using Midas",
|
"depthMidasDescription": "Depth map generation using Midas",
|
||||||
"depthZoe": "Depth (Zoe)",
|
"depthZoe": "Depth (Zoe)",
|
||||||
@ -373,7 +374,6 @@
|
|||||||
"dropToUpload": "$t(gallery.drop) to Upload",
|
"dropToUpload": "$t(gallery.drop) to Upload",
|
||||||
"deleteImage_one": "Delete Image",
|
"deleteImage_one": "Delete Image",
|
||||||
"deleteImage_other": "Delete {{count}} Images",
|
"deleteImage_other": "Delete {{count}} Images",
|
||||||
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
|
||||||
"deleteImagePermanent": "Deleted images cannot be restored.",
|
"deleteImagePermanent": "Deleted images cannot be restored.",
|
||||||
"displayBoardSearch": "Display Board Search",
|
"displayBoardSearch": "Display Board Search",
|
||||||
"displaySearch": "Display Search",
|
"displaySearch": "Display Search",
|
||||||
@ -1053,11 +1053,7 @@
|
|||||||
"remixImage": "Remix Image",
|
"remixImage": "Remix Image",
|
||||||
"usePrompt": "Use Prompt",
|
"usePrompt": "Use Prompt",
|
||||||
"useSeed": "Use Seed",
|
"useSeed": "Use Seed",
|
||||||
"width": "Width",
|
"width": "Width"
|
||||||
"isAllowedToUpscale": {
|
|
||||||
"useX2Model": "Image is too large to upscale with x4 model, use x2 model",
|
|
||||||
"tooLarge": "Image is too large to upscale, select smaller image"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"dynamicPrompts": {
|
"dynamicPrompts": {
|
||||||
"showDynamicPrompts": "Show Dynamic Prompts",
|
"showDynamicPrompts": "Show Dynamic Prompts",
|
||||||
@ -1678,6 +1674,8 @@
|
|||||||
},
|
},
|
||||||
"upscaling": {
|
"upscaling": {
|
||||||
"creativity": "Creativity",
|
"creativity": "Creativity",
|
||||||
|
"exceedsMaxSize": "Upscale settings exceed max size limit",
|
||||||
|
"exceedsMaxSizeDetails": "Max upscale limit is {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixels. Please try a smaller image or decrease your scale selection.",
|
||||||
"structure": "Structure",
|
"structure": "Structure",
|
||||||
"upscaleModel": "Upscale Model",
|
"upscaleModel": "Upscale Model",
|
||||||
"postProcessingModel": "Post-Processing Model",
|
"postProcessingModel": "Post-Processing Model",
|
||||||
|
@ -88,7 +88,6 @@
|
|||||||
"deleteImage_one": "Eliminar Imagen",
|
"deleteImage_one": "Eliminar Imagen",
|
||||||
"deleteImage_many": "",
|
"deleteImage_many": "",
|
||||||
"deleteImage_other": "",
|
"deleteImage_other": "",
|
||||||
"deleteImageBin": "Las imágenes eliminadas se enviarán a la papelera de tu sistema operativo.",
|
|
||||||
"deleteImagePermanent": "Las imágenes eliminadas no se pueden restaurar.",
|
"deleteImagePermanent": "Las imágenes eliminadas no se pueden restaurar.",
|
||||||
"assets": "Activos",
|
"assets": "Activos",
|
||||||
"autoAssignBoardOnClick": "Asignación automática de tableros al hacer clic"
|
"autoAssignBoardOnClick": "Asignación automática de tableros al hacer clic"
|
||||||
|
@@ -89,7 +89,8 @@
"enabled": "Abilitato",
"disabled": "Disabilitato",
"comparingDesc": "Confronta due immagini",
-"comparing": "Confronta"
+"comparing": "Confronta",
+"dontShowMeThese": "Non mostrarmi questi"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -101,7 +102,6 @@
"deleteImage_many": "Elimina {{count}} immagini",
"deleteImage_other": "Elimina {{count}} immagini",
"deleteImagePermanent": "Le immagini eliminate non possono essere ripristinate.",
-"deleteImageBin": "Le immagini eliminate verranno spostate nel cestino del tuo sistema operativo.",
"assets": "Risorse",
"autoAssignBoardOnClick": "Assegna automaticamente la bacheca al clic",
"featuresWillReset": "Se elimini questa immagine, quelle funzionalità verranno immediatamente ripristinate.",
@@ -154,7 +154,9 @@
"selectAllOnPage": "Seleziona tutto nella pagina",
"selectAllOnBoard": "Seleziona tutto nella bacheca",
"exitBoardSearch": "Esci da Ricerca bacheca",
-"exitSearch": "Esci dalla ricerca"
+"exitSearch": "Esci dalla ricerca",
+"go": "Vai",
+"jump": "Salta"
},
"hotkeys": {
"keyboardShortcuts": "Tasti di scelta rapida",
@@ -571,10 +573,6 @@
},
"useCpuNoise": "Usa la CPU per generare rumore",
"iterations": "Iterazioni",
-"isAllowedToUpscale": {
-"useX2Model": "L'immagine è troppo grande per l'ampliamento con il modello x4, utilizza il modello x2",
-"tooLarge": "L'immagine è troppo grande per l'ampliamento, seleziona un'immagine più piccola"
-},
"imageActions": "Azioni Immagine",
"cfgRescaleMultiplier": "Moltiplicatore riscala CFG",
"useSize": "Usa Dimensioni",
@@ -630,7 +628,9 @@
"enableNSFWChecker": "Abilita controllo NSFW",
"enableInvisibleWatermark": "Abilita filigrana invisibile",
"enableInformationalPopovers": "Abilita testo informativo a comparsa",
-"reloadingIn": "Ricaricando in"
+"reloadingIn": "Ricaricando in",
+"informationalPopoversDisabled": "Testo informativo a comparsa disabilitato",
+"informationalPopoversDisabledDesc": "I testi informativi a comparsa sono disabilitati. Attivali nelle impostazioni."
},
"toast": {
"uploadFailed": "Caricamento fallito",
@@ -951,7 +951,7 @@
"deleteBoardOnly": "solo la Bacheca",
"deleteBoard": "Elimina Bacheca",
"deleteBoardAndImages": "Bacheca e Immagini",
-"deletedBoardsCannotbeRestored": "Le bacheche eliminate non possono essere ripristinate",
+"deletedBoardsCannotbeRestored": "Le bacheche eliminate non possono essere ripristinate. Selezionando \"Elimina solo bacheca\" le immagini verranno spostate nella bacheca \"Non categorizzato\".",
"movingImagesToBoard_one": "Spostare {{count}} immagine nella bacheca:",
"movingImagesToBoard_many": "Spostare {{count}} immagini nella bacheca:",
"movingImagesToBoard_other": "Spostare {{count}} immagini nella bacheca:",
@@ -972,7 +972,8 @@
"addPrivateBoard": "Aggiungi una Bacheca Privata",
"noBoards": "Nessuna bacheca {{boardType}}",
"hideBoards": "Nascondi bacheche",
-"viewBoards": "Visualizza bacheche"
+"viewBoards": "Visualizza bacheche",
+"deletedPrivateBoardsCannotbeRestored": "Le bacheche cancellate non possono essere ripristinate. Selezionando 'Cancella solo bacheca', le immagini verranno spostate nella bacheca \"Non categorizzato\" privata dell'autore dell'immagine."
},
"controlnet": {
"contentShuffleDescription": "Rimescola il contenuto di un'immagine",
@@ -1516,6 +1517,30 @@
"paragraphs": [
"Metodo con cui applicare l'adattatore IP corrente."
]
+},
+"scale": {
+"heading": "Scala",
+"paragraphs": [
+"La scala controlla la dimensione dell'immagine di uscita e si basa su un multiplo della risoluzione dell'immagine di ingresso. Ad esempio, un ampliamento 2x su un'immagine 1024x1024 produrrebbe in uscita a 2048x2048."
+]
+},
+"upscaleModel": {
+"paragraphs": [
+"Il modello di ampliamento ridimensiona l'immagine alle dimensioni di uscita prima che vengano aggiunti i dettagli. È possibile utilizzare qualsiasi modello di ampliamento supportato, ma alcuni sono specializzati per diversi tipi di immagini, come foto o disegni al tratto."
+],
+"heading": "Modello di ampliamento"
+},
+"creativity": {
+"heading": "Creatività",
+"paragraphs": [
+"La creatività controlla quanta libertà è concessa al modello quando si aggiungono dettagli. Una creatività bassa rimane vicina all'immagine originale, mentre una creatività alta consente più cambiamenti. Quando si usa un prompt, una creatività alta aumenta l'influenza del prompt."
+]
+},
+"structure": {
+"heading": "Struttura",
+"paragraphs": [
+"La struttura determina quanto l'immagine finale rispecchierà il layout dell'originale. Una struttura bassa permette cambiamenti significativi, mentre una struttura alta conserva la composizione e il layout originali."
+]
}
},
"sdxl": {
@@ -109,7 +109,6 @@
"drop": "ドロップ",
"dropOrUpload": "$t(gallery.drop) またはアップロード",
"deleteImage_other": "画像を削除",
-"deleteImageBin": "削除された画像はOSのゴミ箱に送られます。",
"deleteImagePermanent": "削除された画像は復元できません。",
"download": "ダウンロード",
"unableToLoad": "ギャラリーをロードできません",
@@ -70,7 +70,6 @@
"gallerySettings": "갤러리 설정",
"deleteSelection": "선택 항목 삭제",
"featuresWillReset": "이 이미지를 삭제하면 해당 기능이 즉시 재설정됩니다.",
-"deleteImageBin": "삭제된 이미지는 운영 체제의 Bin으로 전송됩니다.",
"assets": "자산",
"problemDeletingImagesDesc": "하나 이상의 이미지를 삭제할 수 없습니다",
"noImagesInGallery": "보여줄 이미지가 없음",
@@ -97,7 +97,6 @@
"noImagesInGallery": "Geen afbeeldingen om te tonen",
"deleteImage_one": "Verwijder afbeelding",
"deleteImage_other": "",
-"deleteImageBin": "Verwijderde afbeeldingen worden naar de prullenbak van je besturingssysteem gestuurd.",
"deleteImagePermanent": "Verwijderde afbeeldingen kunnen niet worden hersteld.",
"assets": "Eigen onderdelen",
"autoAssignBoardOnClick": "Ken automatisch bord toe bij klikken",
@@ -467,10 +466,6 @@
},
"imageNotProcessedForControlAdapter": "De afbeelding van controle-adapter #{{number}} is niet verwerkt"
},
-"isAllowedToUpscale": {
-"useX2Model": "Afbeelding is te groot om te vergroten met het x4-model. Gebruik hiervoor het x2-model",
-"tooLarge": "Afbeelding is te groot om te vergoten. Kies een kleinere afbeelding"
-},
"patchmatchDownScaleSize": "Verklein",
"useCpuNoise": "Gebruik CPU-ruis",
"imageActions": "Afbeeldingshandeling",
@@ -100,7 +100,6 @@
"loadMore": "Показать больше",
"noImagesInGallery": "Изображений нет",
"deleteImagePermanent": "Удаленные изображения невозможно восстановить.",
-"deleteImageBin": "Удаленные изображения будут отправлены в корзину вашей операционной системы.",
"deleteImage_one": "Удалить изображение",
"deleteImage_few": "Удалить {{count}} изображения",
"deleteImage_many": "Удалить {{count}} изображений",
@@ -567,10 +566,6 @@
"ipAdapterNoImageSelected": "изображение IP-адаптера не выбрано"
}
},
-"isAllowedToUpscale": {
-"useX2Model": "Изображение слишком велико для увеличения с помощью модели x4. Используйте модель x2",
-"tooLarge": "Изображение слишком велико для увеличения. Выберите изображение меньшего размера"
-},
"cfgRescaleMultiplier": "Множитель масштабирования CFG",
"patchmatchDownScaleSize": "уменьшить",
"useCpuNoise": "Использовать шум CPU",
@@ -278,7 +278,6 @@
"enable": "Aç"
},
"gallery": {
-"deleteImageBin": "Silinen görseller işletim sisteminin çöp kutusuna gönderilir.",
"deleteImagePermanent": "Silinen görseller geri getirilemez.",
"assets": "Özkaynaklar",
"autoAssignBoardOnClick": "Tıklanan Panoya Otomatik Atama",
@@ -622,10 +621,6 @@
"controlNetControlMode": "Yönetim Kipi",
"general": "Genel",
"seamlessYAxis": "Dikişsiz Döşeme Y Ekseni",
-"isAllowedToUpscale": {
-"tooLarge": "Görsel, büyütme işlemi için çok büyük, daha küçük bir boyut seçin",
-"useX2Model": "Görsel 4 kat büyütme işlemi için çok geniş, 2 kat büyütmeyi kullanın"
-},
"maskBlur": "Bulandırma",
"images": "Görseller",
"info": "Bilgi",
@@ -6,7 +6,7 @@
"settingsLabel": "设置",
"img2img": "图生图",
"unifiedCanvas": "统一画布",
-"nodes": "工作流编辑器",
+"nodes": "工作流",
"upload": "上传",
"load": "加载",
"statusDisconnected": "未连接",
@@ -86,7 +86,12 @@
"editing": "编辑中",
"green": "绿",
"blue": "蓝",
-"editingDesc": "在控制图层画布上编辑"
+"editingDesc": "在控制图层画布上编辑",
+"goTo": "前往",
+"dontShowMeThese": "请勿显示这些内容",
+"beta": "测试版",
+"toResolve": "解决",
+"tab": "标签页"
},
"gallery": {
"galleryImageSize": "预览大小",
@@ -94,8 +99,7 @@
"autoSwitchNewImages": "自动切换到新图像",
"loadMore": "加载更多",
"noImagesInGallery": "无图像可用于显示",
-"deleteImage_other": "删除图片",
+"deleteImage_other": "删除{{count}}张图片",
-"deleteImageBin": "被删除的图片会发送到你操作系统的回收站。",
"deleteImagePermanent": "删除的图片无法被恢复。",
"assets": "素材",
"autoAssignBoardOnClick": "点击后自动分配面板",
@@ -133,7 +137,24 @@
"hover": "悬停",
"selectAllOnPage": "选择本页全部",
"swapImages": "交换图像",
-"compareOptions": "比较选项"
+"compareOptions": "比较选项",
+"exitBoardSearch": "退出面板搜索",
+"exitSearch": "退出搜索",
+"oldestFirst": "最旧在前",
+"sortDirection": "排序方向",
+"showStarredImagesFirst": "优先显示收藏的图片",
+"compareHelp3": "按 <Kbd>C</Kbd> 键对调正在比较的图片。",
+"showArchivedBoards": "显示已归档的面板",
+"newestFirst": "最新在前",
+"compareHelp4": "按 <Kbd>Z</Kbd>或 <Kbd>Esc</Kbd> 键退出。",
+"searchImages": "按元数据搜索",
+"jump": "跳过",
+"compareHelp2": "按 <Kbd>M</Kbd> 键切换不同的比较模式。",
+"displayBoardSearch": "显示面板搜索",
+"displaySearch": "显示搜索",
+"stretchToFit": "拉伸以适应",
+"exitCompare": "退出对比",
+"compareHelp1": "在点击图库中的图片或使用箭头键切换比较图片时,请按住<Kbd>Alt</Kbd> 键。"
},
"hotkeys": {
"keyboardShortcuts": "快捷键",
@@ -348,7 +369,19 @@
"desc": "打开和关闭选项和图库面板",
"title": "开关选项和图库"
},
-"clearSearch": "清除检索项"
+"clearSearch": "清除检索项",
+"toggleViewer": {
+"desc": "在当前标签页的图片查看模式和编辑工作区之间切换.",
+"title": "切换图片查看器"
+},
+"postProcess": {
+"desc": "使用选定的后期处理模型对当前图像进行处理",
+"title": "处理图像"
+},
+"remixImage": {
+"title": "重新混合图像",
+"desc": "使用当前图像的所有参数,但不包括随机种子"
+}
},
"modelManager": {
"modelManager": "模型管理器",
@@ -396,14 +429,71 @@
"modelConversionFailed": "模型转换失败",
"baseModel": "基底模型",
"convertingModelBegin": "模型转换中. 请稍候.",
-"predictionType": "预测类型(适用于 Stable Diffusion 2.x 模型和部分 Stable Diffusion 1.x 模型)",
+"predictionType": "预测类型",
"advanced": "高级",
"modelType": "模型类别",
"variant": "变体",
"vae": "VAE",
"alpha": "Alpha",
"vaePrecision": "VAE 精度",
-"noModelSelected": "无选中的模型"
+"noModelSelected": "无选中的模型",
+"modelImageUpdateFailed": "模型图像更新失败",
+"scanFolder": "扫描文件夹",
+"path": "路径",
+"pathToConfig": "配置路径",
+"cancel": "取消",
+"hfTokenUnableToVerify": "无法验证HuggingFace token",
+"install": "安装",
+"simpleModelPlaceholder": "本地文件或diffusers文件夹的URL或路径",
+"hfTokenInvalidErrorMessage": "无效或缺失的HuggingFace token.",
+"noModelsInstalledDesc1": "安装模型时使用",
+"inplaceInstallDesc": "安装模型时,不复制文件,直接从原位置加载。如果关闭此选项,模型文件将在安装过程中被复制到Invoke管理的模型文件夹中.",
+"installAll": "安装全部",
+"noModelsInstalled": "无已安装的模型",
+"urlOrLocalPathHelper": "链接应该指向单个文件.本地路径可以指向单个文件,或者对于单个扩散模型(diffusers model),可以指向一个文件夹.",
+"modelSettings": "模型设置",
+"useDefaultSettings": "使用默认设置",
+"scanPlaceholder": "本地文件夹路径",
+"installRepo": "安装仓库",
+"modelImageDeleted": "模型图像已删除",
+"modelImageDeleteFailed": "模型图像删除失败",
+"scanFolderHelper": "此文件夹将进行递归扫描以寻找模型.对于大型文件夹,这可能需要一些时间.",
+"scanResults": "扫描结果",
+"noMatchingModels": "无匹配的模型",
+"pruneTooltip": "清理队列中已完成的导入任务",
+"urlOrLocalPath": "链接或本地路径",
+"localOnly": "仅本地",
+"hfTokenHelperText": "需要HuggingFace token才能使用Checkpoint模型。点击此处创建或获取您的token.",
+"huggingFaceHelper": "如果在此代码库中检测到多个模型,系统将提示您选择其中一个进行安装.",
+"hfTokenUnableToVerifyErrorMessage": "无法验证HuggingFace token.可能是网络问题所致.请稍后再试.",
+"hfTokenSaved": "HuggingFace token已保存",
+"imageEncoderModelId": "图像编码器模型ID",
+"modelImageUpdated": "模型图像已更新",
+"modelName": "模型名称",
+"prune": "清理",
+"repoVariant": "代码库版本",
+"defaultSettings": "默认设置",
+"inplaceInstall": "就地安装",
+"main": "主界面",
+"starterModels": "初始模型",
+"installQueue": "安装队列",
+"hfTokenInvalidErrorMessage2": "更新于其中 ",
+"hfTokenInvalid": "无效或缺失的HuggingFace token",
+"mainModelTriggerPhrases": "主模型触发词",
+"typePhraseHere": "在此输入触发词",
+"triggerPhrases": "触发词",
+"metadata": "元数据",
+"deleteModelImage": "删除模型图片",
+"edit": "编辑",
+"source": "来源",
+"uploadImage": "上传图像",
+"addModels": "添加模型",
+"textualInversions": "文本逆向生成",
+"upcastAttention": "是否为高精度权重",
+"defaultSettingsSaved": "默认设置已保存",
+"huggingFacePlaceholder": "所有者或模型名称",
+"huggingFaceRepoID": "HuggingFace仓库ID",
+"loraTriggerPhrases": "LoRA 触发词"
},
"parameters": {
"images": "图像",
@@ -446,7 +536,7 @@
"scheduler": "调度器",
"general": "通用",
"controlNetControlMode": "控制模式",
-"maskBlur": "模糊",
+"maskBlur": "遮罩模糊",
"invoke": {
"noNodesInGraph": "节点图中无节点",
"noModelSelected": "无已选中的模型",
@@ -460,7 +550,21 @@
"noPrompts": "没有已生成的提示词",
"noControlImageForControlAdapter": "有 #{{number}} 个 Control Adapter 缺失控制图像",
"noModelForControlAdapter": "有 #{{number}} 个 Control Adapter 没有选择模型。",
-"incompatibleBaseModelForControlAdapter": "有 #{{number}} 个 Control Adapter 模型与主模型不兼容。"
+"incompatibleBaseModelForControlAdapter": "有 #{{number}} 个 Control Adapter 模型与主模型不兼容。",
+"layer": {
+"initialImageNoImageSelected": "未选择初始图像",
+"controlAdapterImageNotProcessed": "Control Adapter图像尚未处理",
+"ipAdapterNoModelSelected": "未选择IP adapter",
+"controlAdapterNoModelSelected": "未选择Control Adapter模型",
+"controlAdapterNoImageSelected": "未选择Control Adapter图像",
+"rgNoPromptsOrIPAdapters": "无文本提示或IP Adapters",
+"controlAdapterIncompatibleBaseModel": "Control Adapter的基础模型不兼容",
+"ipAdapterIncompatibleBaseModel": "IP Adapter的基础模型不兼容",
+"t2iAdapterIncompatibleDimensions": "T2I Adapter需要图像尺寸为{{multiple}}的倍数",
+"ipAdapterNoImageSelected": "未选择IP Adapter图像",
+"rgNoRegion": "未选择区域"
+},
+"imageNotProcessedForControlAdapter": "Control Adapter #{{number}} 的图像未处理"
},
"patchmatchDownScaleSize": "缩小",
"clipSkip": "CLIP 跳过层",
@@ -468,10 +572,6 @@
"coherenceMode": "模式",
"imageActions": "图像操作",
"iterations": "迭代数",
-"isAllowedToUpscale": {
-"useX2Model": "图像太大,无法使用 x4 模型,使用 x2 模型作为替代",
-"tooLarge": "图像太大无法进行放大,请选择更小的图像"
-},
"cfgRescaleMultiplier": "CFG 重缩放倍数",
"useSize": "使用尺寸",
"setToOptimalSize": "优化模型大小",
@@ -479,7 +579,21 @@
"lockAspectRatio": "锁定纵横比",
"swapDimensions": "交换尺寸",
"aspect": "纵横",
-"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (可能过大)"
+"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (可能过大)",
+"globalNegativePromptPlaceholder": "全局反向提示词",
+"remixImage": "重新混合图像",
+"coherenceEdgeSize": "边缘尺寸",
+"postProcessing": "后处理(Shift + U)",
+"infillMosaicTileWidth": "瓦片宽度",
+"sendToUpscale": "发送到放大",
+"processImage": "处理图像",
+"globalPositivePromptPlaceholder": "全局正向提示词",
+"globalSettings": "全局设置",
+"infillMosaicTileHeight": "瓦片高度",
+"infillMosaicMinColor": "最小颜色",
+"infillMosaicMaxColor": "最大颜色",
+"infillColorValue": "填充颜色",
+"coherenceMinDenoise": "最小去噪"
},
"settings": {
"models": "模型",
@@ -509,7 +623,9 @@
"enableNSFWChecker": "启用成人内容检测器",
"enableInvisibleWatermark": "启用不可见水印",
"enableInformationalPopovers": "启用信息弹窗",
-"reloadingIn": "重新加载中"
+"reloadingIn": "重新加载中",
+"informationalPopoversDisabled": "信息提示框已禁用",
+"informationalPopoversDisabledDesc": "信息提示框已被禁用.请在设置中重新启用."
},
"toast": {
"uploadFailed": "上传失败",
@@ -518,16 +634,16 @@
"canvasMerged": "画布已合并",
"sentToImageToImage": "已发送到图生图",
"sentToUnifiedCanvas": "已发送到统一画布",
-"parametersNotSet": "参数未设定",
+"parametersNotSet": "参数未恢复",
"metadataLoadFailed": "加载元数据失败",
"uploadFailedInvalidUploadDesc": "必须是单张的 PNG 或 JPEG 图片",
"connected": "服务器连接",
-"parameterSet": "参数已设定",
+"parameterSet": "参数已恢复",
-"parameterNotSet": "参数未设定",
+"parameterNotSet": "参数未恢复",
"serverError": "服务器错误",
"canceled": "处理取消",
"problemCopyingImage": "无法复制图像",
-"modelAddedSimple": "已添加模型",
+"modelAddedSimple": "模型已加入队列",
"imageSavingFailed": "图像保存失败",
"canvasSentControlnetAssets": "画布已发送到 ControlNet & 素材",
"problemCopyingCanvasDesc": "无法导出基础层",
@@ -557,12 +673,28 @@
"canvasSavedGallery": "画布已保存到图库",
"imageUploadFailed": "图像上传失败",
"problemImportingMask": "导入遮罩时出现问题",
-"baseModelChangedCleared_other": "基础模型已更改, 已清除或禁用 {{count}} 个不兼容的子模型",
+"baseModelChangedCleared_other": "已清除或禁用{{count}}个不兼容的子模型",
"setAsCanvasInitialImage": "设为画布初始图像",
"invalidUpload": "无效的上传",
"problemDeletingWorkflow": "删除工作流时出现问题",
"workflowDeleted": "已删除工作流",
-"problemRetrievingWorkflow": "检索工作流时发生问题"
+"problemRetrievingWorkflow": "检索工作流时发生问题",
+"baseModelChanged": "基础模型已更改",
+"problemDownloadingImage": "无法下载图像",
+"outOfMemoryError": "内存不足错误",
+"parameters": "参数",
+"resetInitialImage": "重置初始图像",
+"parameterNotSetDescWithMessage": "无法恢复 {{parameter}}: {{message}}",
+"parameterSetDesc": "已恢复 {{parameter}}",
+"parameterNotSetDesc": "无法恢复{{parameter}}",
+"sessionRef": "会话: {{sessionId}}",
+"somethingWentWrong": "出现错误",
+"prunedQueue": "已清理队列",
+"uploadInitialImage": "上传初始图像",
+"outOfMemoryErrorDesc": "您当前的生成设置已超出系统处理能力.请调整设置后再次尝试.",
+"parametersSet": "参数已恢复",
+"errorCopied": "错误信息已复制",
+"modelImportCanceled": "模型导入已取消"
},
"unifiedCanvas": {
"layer": "图层",
@@ -616,7 +748,15 @@
"antialiasing": "抗锯齿",
"showResultsOn": "显示结果 (开)",
"showResultsOff": "显示结果 (关)",
-"saveMask": "保存 $t(unifiedCanvas.mask)"
+"saveMask": "保存 $t(unifiedCanvas.mask)",
+"coherenceModeBoxBlur": "盒子模糊",
+"showBoundingBox": "显示边界框",
+"coherenceModeGaussianBlur": "高斯模糊",
+"coherenceModeStaged": "分阶段",
+"hideBoundingBox": "隐藏边界框",
+"initialFitImageSize": "在拖放时调整图像大小以适配",
+"invertBrushSizeScrollDirection": "反转滚动操作以调整画笔大小",
+"discardCurrent": "放弃当前设置"
},
"accessibility": {
"invokeProgressBar": "Invoke 进度条",
@@ -746,11 +886,11 @@
"unableToExtractSchemaNameFromRef": "无法从参考中提取架构名",
"unknownOutput": "未知输出:{{name}}",
"unknownErrorValidatingWorkflow": "验证工作流时出现未知错误",
-"collectionFieldType": "{{name}} 合集",
+"collectionFieldType": "{{name}}(合集)",
"unknownNodeType": "未知节点类型",
"targetNodeDoesNotExist": "无效的边缘:{{node}} 的目标/输入节点不存在",
"unknownFieldType": "$t(nodes.unknownField) 类型:{{type}}",
-"collectionOrScalarFieldType": "{{name}} 合集 | 标量",
+"collectionOrScalarFieldType": "{{name}} (单一项目或项目集合)",
"nodeVersion": "节点版本",
"deletedInvalidEdge": "已删除无效的边缘 {{source}} -> {{target}}",
"unknownInput": "未知输入:{{name}}",
@@ -759,7 +899,27 @@
"newWorkflow": "新建工作流",
"newWorkflowDesc": "是否创建一个新的工作流?",
"newWorkflowDesc2": "当前工作流有未保存的更改。",
-"unsupportedAnyOfLength": "联合(union)数据类型数目过多 ({{count}})"
+"unsupportedAnyOfLength": "联合(union)数据类型数目过多 ({{count}})",
+"resetToDefaultValue": "重置为默认值",
+"clearWorkflowDesc2": "您当前的工作流有未保存的更改.",
+"missingNode": "缺少调用节点",
+"missingInvocationTemplate": "缺少调用模版",
+"noFieldsViewMode": "此工作流程未选择任何要显示的字段.请查看完整工作流程以进行配置.",
+"reorderLinearView": "调整线性视图顺序",
+"viewMode": "在线性视图中使用",
+"showEdgeLabelsHelp": "在边缘上显示标签,指示连接的节点",
+"cannotMixAndMatchCollectionItemTypes": "集合项目类型不能混用",
+"missingFieldTemplate": "缺少字段模板",
+"editMode": "在工作流编辑器中编辑",
+"showEdgeLabels": "显示边缘标签",
+"clearWorkflowDesc": "是否清除当前工作流并创建新的?",
+"graph": "图表",
+"noGraph": "无图表",
+"edit": "编辑",
+"clearWorkflow": "清除工作流",
+"imageAccessError": "无法找到图像 {{image_name}},正在恢复默认设置",
+"boardAccessError": "无法找到面板 {{board_id}},正在恢复默认设置",
+"modelAccessError": "无法找到模型 {{key}},正在恢复默认设置"
},
"controlnet": {
"resize": "直接缩放",
@@ -799,7 +959,7 @@
"mediapipeFaceDescription": "使用 Mediapipe 检测面部",
"depthZoeDescription": "使用 Zoe 生成深度图",
"hedDescription": "整体嵌套边缘检测",
-"setControlImageDimensions": "设定控制图像尺寸宽/高为",
+"setControlImageDimensions": "复制尺寸到宽度/高度(为模型优化)",
"amult": "角度倍率 (a_mult)",
"bgth": "背景移除阈值 (bg_th)",
"lineartAnimeDescription": "动漫风格线稿处理",
@@ -810,7 +970,7 @@
"addControlNet": "添加 $t(common.controlNet)",
"addIPAdapter": "添加 $t(common.ipAdapter)",
"safe": "保守模式",
-"scribble": "草绘 (scribble)",
+"scribble": "草绘",
"maxFaces": "最大面部数",
"pidi": "PIDI",
"normalBae": "Normal BAE",
@@ -925,7 +1085,8 @@
"steps": "步数",
"posStylePrompt": "正向样式提示词",
"refiner": "Refiner",
-"freePromptStyle": "手动输入样式提示词"
+"freePromptStyle": "手动输入样式提示词",
+"refinerSteps": "精炼步数"
},
"metadata": {
"positivePrompt": "正向提示词",
@@ -952,7 +1113,12 @@
"recallParameters": "召回参数",
"noRecallParameters": "未找到要召回的参数",
"vae": "VAE",
-"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
+"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
+"allPrompts": "所有提示",
+"parsingFailed": "解析失败",
+"recallParameter": "调用{{label}}",
+"imageDimensions": "图像尺寸",
+"parameterSet": "已设置参数{{parameter}}"
},
"models": {
"noMatchingModels": "无相匹配的模型",
@@ -965,7 +1131,8 @@
"esrganModel": "ESRGAN 模型",
"addLora": "添加 LoRA",
"lora": "LoRA",
-"defaultVAE": "默认 VAE"
+"defaultVAE": "默认 VAE",
+"concepts": "概念"
},
"boards": {
"autoAddBoard": "自动添加面板",
@@ -987,8 +1154,23 @@
"deleteBoardOnly": "仅删除面板",
"deleteBoard": "删除面板",
"deleteBoardAndImages": "删除面板和图像",
-"deletedBoardsCannotbeRestored": "已删除的面板无法被恢复",
+"deletedBoardsCannotbeRestored": "删除的面板无法恢复。选择“仅删除面板”选项后,相关图片将会被移至未分类区域。",
-"movingImagesToBoard_other": "移动 {{count}} 张图像到面板:"
+"movingImagesToBoard_other": "移动 {{count}} 张图像到面板:",
+"selectedForAutoAdd": "已选中自动添加",
+"hideBoards": "隐藏面板",
+"noBoards": "没有{{boardType}}类型的面板",
+"unarchiveBoard": "恢复面板",
+"viewBoards": "查看面板",
+"addPrivateBoard": "创建私密面板",
+"addSharedBoard": "创建共享面板",
+"boards": "面板",
+"imagesWithCount_other": "{{count}}张图片",
+"deletedPrivateBoardsCannotbeRestored": "删除的面板无法恢复。选择“仅删除面板”后,相关图片将会被移至图片创建者的私密未分类区域。",
+"private": "私密面板",
+"shared": "共享面板",
+"archiveBoard": "归档面板",
+"archived": "已归档",
+"assetsWithCount_other": "{{count}}项资源"
},
"dynamicPrompts": {
"seedBehaviour": {
@@ -1030,32 +1212,33 @@
"paramVAEPrecision": {
"heading": "VAE 精度",
"paragraphs": [
-"VAE 编解码过程种使用的精度。FP16/半精度以微小的图像变化为代价提高效率。"
+"在VAE编码和解码过程中使用的精度.",
+"Fp16/半精度更高效,但可能会造成图像的一些微小差异."
]
},
"compositingCoherenceMode": {
"heading": "模式",
"paragraphs": [
-"一致性层模式。"
+"用于将新生成的遮罩区域与原图像融合的方法."
]
},
"controlNetResizeMode": {
"heading": "缩放模式",
"paragraphs": [
-"ControlNet 输入图像适应输出图像大小的方法。"
+"调整Control Adapter输入图像大小以适应输出图像尺寸的方法."
]
},
"clipSkip": {
"paragraphs": [
-"选择要跳过 CLIP 模型多少层。",
+"跳过CLIP模型的层数.",
-"部分模型跳过特定数值的层时效果会更好。"
+"某些模型更适合结合CLIP Skip功能使用."
],
"heading": "CLIP 跳过层"
},
"paramModel": {
"heading": "模型",
"paragraphs": [
-"用于去噪过程的模型。"
+"用于图像生成的模型.不同的模型经过训练,专门用于产生不同的美学效果和内容."
]
},
"paramIterations": {
@@ -1087,19 +1270,21 @@
"paramScheduler": {
"heading": "调度器",
"paragraphs": [
-"调度器 (采样器) 定义如何在图像迭代过程中添加噪声,或者定义如何根据一个模型的输出来更新采样。"
+"生成过程中所使用的调度器.",
+"每个调度器决定了在生成过程中如何逐步向图像添加噪声,或者如何根据模型的输出更新样本."
]
},
"controlNetWeight": {
"heading": "权重",
"paragraphs": [
-"ControlNet 对生成图像的影响强度。"
+"Control Adapter的权重.权重越高,对最终图像的影响越大."
]
},
"paramCFGScale": {
"heading": "CFG 等级",
"paragraphs": [
-"控制提示词对生成过程的影响程度。"
+"控制提示对生成过程的影响程度.",
+"较高的CFG比例值可能会导致生成结果过度饱和和扭曲. "
]
},
"paramSteps": {
@@ -1117,28 +1302,29 @@
]
},
"lora": {
-"heading": "LoRA 权重",
+"heading": "LoRA",
"paragraphs": [
-"更高的 LoRA 权重会对最终图像产生更大的影响。"
+"与基础模型结合使用的轻量级模型."
]
},
"infillMethod": {
"heading": "填充方法",
"paragraphs": [
-"填充选定区域的方式。"
+"在重绘过程中使用的填充方法."
]
},
"controlNetBeginEnd": {
"heading": "开始 / 结束步数百分比",
"paragraphs": [
-"去噪过程中在哪部分步数应用 ControlNet。",
+"去噪过程中将应用Control Adapter 的部分.",
-"在组合处理开始阶段应用 ControlNet,且在引导细节生成的结束阶段应用 ControlNet。"
+"通常,在去噪过程初期应用的Control Adapters用于指导整体构图,而在后期应用的Control Adapters则用于调整细节。"
]
},
"scaleBeforeProcessing": {
"heading": "处理前缩放",
"paragraphs": [
-"生成图像前将所选区域缩放为最适合模型的大小。"
+"\"自动\"选项会在图像生成之前将所选区域调整到最适合模型的大小.",
+"\"手动\"选项允许您在图像生成之前自行选择所选区域的宽度和高度."
]
},
"paramDenoisingStrength": {
@@ -1152,13 +1338,13 @@
"heading": "种子",
"paragraphs": [
"控制用于生成的起始噪声。",
-"禁用 “随机种子” 来以相同设置生成相同的结果。"
+"禁用\"随机\"选项,以使用相同的生成设置产生一致的结果."
]
},
"controlNetControlMode": {
"heading": "控制模式",
"paragraphs": [
-"给提示词或 ControlNet 增加更大的权重。"
+"在提示词和ControlNet之间分配更多的权重."
]
},
"dynamicPrompts": {
@@ -1199,7 +1385,171 @@
"paramCFGRescaleMultiplier": {
"heading": "CFG 重缩放倍数",
"paragraphs": [
-"CFG 引导的重缩放倍率,用于通过 zero-terminal SNR (ztsnr) 训练的模型。推荐设为 0.7。"
+"CFG指导的重缩放乘数,适用于使用零终端信噪比(ztsnr)训练的模型.",
+"对于这些模型,建议的数值为0.7."
+]
+},
+"imageFit": {
+"paragraphs": [
+"将初始图像调整到与输出图像相同的宽度和高度.建议启用此功能."
+],
+"heading": "将初始图像适配到输出大小"
+},
+"paramAspect": {
+"paragraphs": [
+"生成图像的宽高比.调整宽高比会相应地更新图像的宽度和高度.",
+"选择\"优化\"将把图像的宽度和高度设置为所选模型的最优尺寸."
+],
+"heading": "宽高比"
+},
+"refinerSteps": {
+"paragraphs": [
+"在图像生成过程中的细化阶段将执行的步骤数.",
+"与生成步骤相似."
+],
+"heading": "步数"
+},
+"compositingMaskBlur": {
+"heading": "遮罩模糊",
+"paragraphs": [
+"遮罩的模糊范围."
+]
+},
+"compositingCoherenceMinDenoise": {
+"paragraphs": [
+"连贯模式下的最小去噪力度",
+"在图像修复或重绘过程中,连贯区域的最小去噪力度"
+],
+"heading": "最小去噪"
+},
+"loraWeight": {
+"paragraphs": [
+"LoRA的权重,权重越高对最终图像的影响越大."
+],
+"heading": "权重"
+},
+"paramHrf": {
+"heading": "启用高分辨率修复",
+"paragraphs": [
+"以高于模型最优分辨率的大分辨率生成高质量图像.这通常用于防止生成图像中出现重复内容."
+]
+},
+"compositingCoherenceEdgeSize": {
+"paragraphs": [
+"连贯处理的边缘尺寸."
+],
+"heading": "边缘尺寸"
+},
+"paramWidth": {
+"paragraphs": [
+"生成图像的宽度.必须是8的倍数."
+],
+"heading": "宽度"
+},
+"refinerScheduler": {
+"paragraphs": [
+"在图像生成过程中的细化阶段所使用的调度程序.",
+"与生成调度程序相似."
+],
+"heading": "调度器"
+},
+"seamlessTilingXAxis": {
+"paragraphs": [
+"沿水平轴将图像进行无缝平铺."
+],
+"heading": "无缝平铺X轴"
+},
+"paramUpscaleMethod": {
+"heading": "放大方法",
+"paragraphs": [
+"用于高分辨率修复的图像放大方法."
+]
+},
+"refinerModel": {
+"paragraphs": [
+"在图像生成过程中的细化阶段所使用的模型.",
+"与生成模型相似."
+],
+"heading": "精炼模型"
+},
+"paramHeight": {
+"paragraphs": [
+"生成图像的高度.必须是8的倍数."
+],
+"heading": "高"
+},
+"patchmatchDownScaleSize": {
+"heading": "缩小",
+"paragraphs": [
+"在填充之前图像缩小的程度.",
+"较高的缩小比例会提升处理速度,但可能会降低图像质量."
+]
+},
+"seamlessTilingYAxis": {
+"heading": "Y轴上的无缝平铺",
+"paragraphs": [
+"沿垂直轴将图像进行无缝平铺."
+]
+},
+"ipAdapterMethod": {
+"paragraphs": [
+"当前IP Adapter的应用方法."
+],
+"heading": "方法"
+},
+"controlNetProcessor": {
+"paragraphs": [
+"处理输入图像以引导生成过程的方法.不同的处理器会在生成图像中产生不同的效果或风格."
+],
+"heading": "处理器"
+},
+"refinerPositiveAestheticScore": {
+"paragraphs": [
+"根据训练数据,对生成结果进行加权,使其更接近于具有高美学评分的图像."
+],
+"heading": "正面美学评分"
+},
+"refinerStart": {
+"paragraphs": [
+"在图像生成过程中精炼阶段开始被使用的时刻.",
+"0表示精炼器将全程参与图像生成,0.8表示细化器仅在生成过程的最后20%阶段被使用."
+],
+"heading": "精炼开始"
+},
+"refinerCfgScale": {
+"paragraphs": [
+"控制提示对生成过程的影响程度.",
+"与生成CFG Scale相似."
+]
+},
+"structure": {
+"heading": "结构",
+"paragraphs": [
+"结构决定了输出图像在多大程度上保持原始图像的布局.较低的结构设置允许进行较大的变化,而较高的结构设置则会严格保持原始图像的构图和布局."
+]
+},
+"creativity": {
+"paragraphs": [
+"创造力决定了模型在添加细节时的自由度.较低的创造力会使生成结果更接近原始图像,而较高的创造力则允许更多的变化.在使用提示时,较高的创造力会增加提示对生成结果的影响."
+],
+"heading": "创造力"
+},
+"refinerNegativeAestheticScore": {
+"paragraphs": [
+"根据训练数据,对生成结果进行加权,使其更接近于具有低美学评分的图像."
+],
+"heading": "负面美学评分"
+},
+"upscaleModel": {
+"heading": "放大模型",
+"paragraphs": [
+"上采样模型在添加细节之前将图像放大到输出尺寸.虽然可以使用任何支持的上采样模型,但有些模型更适合处理特定类型的图像,例如照片或线条画."
+]
+},
+"scale": {
+"heading": "缩放",
+"paragraphs": [
+"比例控制决定了输出图像的大小,它是基于输入图像分辨率的倍数来计算的.例如对一张1024x1024的图像进行2倍上采样,将会得到一张2048x2048的输出图像."
]
}
},
@@ -1259,7 +1609,16 @@
"updated": "已更新",
"userWorkflows": "我的工作流",
"projectWorkflows": "项目工作流",
-"opened": "已打开"
+"opened": "已打开",
+"noRecentWorkflows": "没有最近的工作流",
+"workflowCleared": "工作流已清除",
+"saveWorkflowToProject": "保存工作流到项目",
+"noWorkflows": "无工作流",
+"convertGraph": "转换图表",
+"loadWorkflow": "$t(common.load) 工作流",
+"noUserWorkflows": "没有用户工作流",
+"loadFromGraph": "从图表加载工作流",
+"autoLayout": "自动布局"
},
"app": {
"storeNotInitialized": "商店尚未初始化"
@@ -1287,5 +1646,68 @@
"prompt": {
"addPromptTrigger": "添加提示词触发器",
"noMatchingTriggers": "没有匹配的触发器"
+},
+"controlLayers": {
+"autoNegative": "自动反向",
+"opacityFilter": "透明度滤镜",
+"deleteAll": "删除所有",
+"moveForward": "向前移动",
+"layers_other": "层",
+"globalControlAdapterLayer": "全局 $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
+"moveBackward": "向后移动",
+"regionalGuidance": "区域导向",
+"controlLayers": "控制层",
+"moveToBack": "移动到后面",
+"brushSize": "笔刷尺寸",
+"moveToFront": "移动到前面",
+"addLayer": "添加层",
+"deletePrompt": "删除提示词",
+"resetRegion": "重置区域",
+"debugLayers": "调试图层",
+"maskPreviewColor": "遮罩预览颜色",
+"addPositivePrompt": "添加 $t(common.positivePrompt)",
+"addNegativePrompt": "添加 $t(common.negativePrompt)",
+"addIPAdapter": "添加 $t(common.ipAdapter)",
+"globalIPAdapterLayer": "全局 $t(common.ipAdapter) $t(unifiedCanvas.layer)",
+"globalInitialImage": "全局初始图像",
+"noLayersAdded": "没有层被添加",
+"globalIPAdapter": "全局 $t(common.ipAdapter)",
+"resetProcessor": "重置处理器至默认值",
+"globalMaskOpacity": "全局遮罩透明度",
+"rectangle": "矩形",
+"opacity": "透明度",
+"clearProcessor": "清除处理器",
+"globalControlAdapter": "全局 $t(controlnet.controlAdapter_one)"
+},
+"ui": {
+"tabs": {
+"generation": "生成",
+"queue": "队列",
+"canvas": "画布",
+"upscaling": "放大中",
+"workflows": "工作流",
+"models": "模型"
+}
+},
+"upscaling": {
+"structure": "结构",
+"upscaleModel": "放大模型",
+"missingUpscaleModel": "缺少放大模型",
+"missingTileControlNetModel": "没有安装有效的tile ControlNet 模型",
+"missingUpscaleInitialImage": "缺少用于放大的原始图像",
+"creativity": "创造力",
+"postProcessingModel": "后处理模型",
+"scale": "缩放",
+"tileControlNetModelDesc": "根据所选的主模型架构,选择相应的Tile ControlNet模型",
+"upscaleModelDesc": "图像放大(图像到图像转换)模型",
+"postProcessingMissingModelWarning": "请访问 <LinkComponent>模型管理器</LinkComponent>来安装一个后处理(图像到图像转换)模型.",
+"missingModelsWarning": "请访问<LinkComponent>模型管理器</LinkComponent> 安装所需的模型:",
+"mainModelDesc": "主模型(SD1.5或SDXL架构)"
+},
+"upsell": {
+"inviteTeammates": "邀请团队成员",
+"professional": "专业",
+"professionalUpsell": "可在 Invoke 的专业版中使用.点击此处或访问 invoke.com/pricing 了解更多详情.",
+"shareAccess": "共享访问权限"
}
}
@@ -65,6 +65,11 @@ export type AppConfig = {
*/
shouldUpdateImagesOnConnect: boolean;
shouldFetchMetadataFromApi: boolean;
+/**
+ * Sets a size limit for outputs on the upscaling tab. This is a maximum dimension, so the actual max number of pixels
+ * will be the square of this value.
+ */
+maxUpscaleDimension?: number;
allowPrivateBoards: boolean;
disabledTabs: InvokeTabName[];
disabledFeatures: AppFeature[];
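A brief illustration of the new maxUpscaleDimension option above (an editorial sketch with an assumed value, not part of the patch): the setting is a single dimension, so the effective pixel budget is its square.

// Hypothetical config value chosen for illustration only.
const exampleConfig: { maxUpscaleDimension?: number } = { maxUpscaleDimension: 2048 };

// With maxUpscaleDimension = 2048, the largest allowed upscale output is
// 2048 * 2048 = 4,194,304 pixels, regardless of the output's aspect ratio.
const maxPixels = (exampleConfig.maxUpscaleDimension ?? 0) ** 2;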
@@ -16,6 +16,7 @@ import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettin
import { isInvocationNode } from 'features/nodes/types/invocation';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { selectUpscalelice } from 'features/parameters/store/upscaleSlice';
+import { selectConfigSlice } from 'features/system/store/configSlice';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import i18n from 'i18next';
@@ -42,6 +43,7 @@ const createSelector = (templates: Templates) =>
selectControlLayersSlice,
activeTabNameSelector,
selectUpscalelice,
+selectConfigSlice,
],
(
controlAdapters,
@@ -52,7 +54,8 @@ const createSelector = (templates: Templates) =>
dynamicPrompts,
controlLayers,
activeTabName,
-upscale
+upscale,
+config
) => {
const { model } = generation;
const { size } = controlLayers.present;
@@ -209,6 +212,16 @@ const createSelector = (templates: Templates) =>
} else if (activeTabName === 'upscaling') {
if (!upscale.upscaleInitialImage) {
reasons.push({ content: i18n.t('upscaling.missingUpscaleInitialImage') });
+} else if (config.maxUpscaleDimension) {
+const { width, height } = upscale.upscaleInitialImage;
+const { scale } = upscale;
+
+const maxPixels = config.maxUpscaleDimension ** 2;
+const upscaledPixels = width * scale * height * scale;
+
+if (upscaledPixels > maxPixels) {
+reasons.push({ content: i18n.t('upscaling.exceedsMaxSize') });
+}
}
if (!upscale.upscaleModel) {
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
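A worked example of the size check added in the hunk above, written as comments; the numbers are illustrative assumptions, not values taken from the patch.

// Assume config.maxUpscaleDimension = 2048, so maxPixels = 2048 ** 2 = 4,194,304.
// A 1024x1024 initial image at scale 2 yields (1024 * 2) * (1024 * 2) = 4,194,304 pixels,
// which does not exceed maxPixels, so no reason is pushed.
// The same image at scale 4 yields 4096 * 4096 = 16,777,216 pixels, which exceeds maxPixels,
// so the selector pushes the 'upscaling.exceedsMaxSize' reason.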
@@ -42,6 +42,7 @@ const DepthAnythingProcessor = (props: Props) => {

const options: { label: string; value: DepthAnythingModelSize }[] = useMemo(
() => [
+{ label: t('controlnet.depthAnythingSmallV2'), value: 'small_v2' },
{ label: t('controlnet.small'), value: 'small' },
{ label: t('controlnet.base'), value: 'base' },
{ label: t('controlnet.large'), value: 'large' },
@@ -94,7 +94,7 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
buildDefaults: (baseModel?: BaseModelType) => ({
id: 'depth_anything_image_processor',
type: 'depth_anything_image_processor',
-model_size: 'small',
+model_size: 'small_v2',
resolution: baseModel === 'sdxl' ? 1024 : 512,
}),
},
@@ -84,7 +84,7 @@ export type RequiredDepthAnythingImageProcessorInvocation = O.Required<
'type' | 'model_size' | 'resolution' | 'offload'
>;

-const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
+const zDepthAnythingModelSize = z.enum(['large', 'base', 'small', 'small_v2']);
export type DepthAnythingModelSize = z.infer<typeof zDepthAnythingModelSize>;
export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSize =>
zDepthAnythingModelSize.safeParse(v).success;
@@ -24,6 +24,7 @@ export const DepthAnythingProcessor = memo(({ onChange, config }: Props) => {

const options: { label: string; value: DepthAnythingModelSize }[] = useMemo(
() => [
+{ label: t('controlnet.depthAnythingSmallV2'), value: 'small_v2' },
{ label: t('controlnet.small'), value: 'small' },
{ label: t('controlnet.base'), value: 'base' },
{ label: t('controlnet.large'), value: 'large' },
@@ -36,7 +36,7 @@ const zContentShuffleProcessorConfig = z.object({
});
export type ContentShuffleProcessorConfig = z.infer<typeof zContentShuffleProcessorConfig>;

-const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
+const zDepthAnythingModelSize = z.enum(['large', 'base', 'small', 'small_v2']);
export type DepthAnythingModelSize = z.infer<typeof zDepthAnythingModelSize>;
export const isDepthAnythingModelSize = (v: unknown): v is DepthAnythingModelSize =>
zDepthAnythingModelSize.safeParse(v).success;
@@ -298,7 +298,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
buildDefaults: () => ({
id: 'depth_anything_image_processor',
type: 'depth_anything_image_processor',
-model_size: 'small',
+model_size: 'small_v2',
}),
buildNode: (image, config) => ({
...config,
@@ -56,7 +56,6 @@ const DeleteImageModal = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const shouldConfirmOnDelete = useAppSelector((s) => s.system.shouldConfirmOnDelete);
-const canRestoreDeletedImagesFromBin = useAppSelector((s) => s.config.canRestoreDeletedImagesFromBin);
const isModalOpen = useAppSelector((s) => s.deleteImageModal.isModalOpen);
const { imagesToDelete, imagesUsage, imageUsageSummary } = useAppSelector(selectImageUsages);

@@ -90,7 +89,7 @@ const DeleteImageModal = () => {
<Flex direction="column" gap={3}>
<ImageUsageMessage imageUsage={imageUsageSummary} />
<Divider />
-<Text>{canRestoreDeletedImagesFromBin ? t('gallery.deleteImageBin') : t('gallery.deleteImagePermanent')}</Text>
+<Text>{t('gallery.deleteImagePermanent')}</Text>
<Text>{t('common.areYouSure')}</Text>
<FormControl>
<FormLabel>{t('common.dontAskMeAgain')}</FormLabel>
@@ -35,7 +35,6 @@ type Props = {
const DeleteBoardModal = (props: Props) => {
const { boardToDelete, setBoardToDelete } = props;
const { t } = useTranslation();
-const canRestoreDeletedImagesFromBin = useAppSelector((s) => s.config.canRestoreDeletedImagesFromBin);
const { currentData: boardImageNames, isFetching: isFetchingBoardNames } = useListAllImageNamesForBoardQuery(
boardToDelete?.board_id ?? skipToken
);
@@ -125,9 +124,7 @@ const DeleteBoardModal = (props: Props) => {
? t('boards.deletedPrivateBoardsCannotbeRestored')
: t('boards.deletedBoardsCannotbeRestored')}
</Text>
-<Text>
+<Text>{t('gallery.deleteImagePermanent')}</Text>
-{canRestoreDeletedImagesFromBin ? t('gallery.deleteImageBin') : t('gallery.deleteImagePermanent')}
-</Text>
</Flex>
</AlertDialogBody>
<AlertDialogFooter>
@@ -125,19 +125,11 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
g.addEdge(modelNode, 'unet', tiledMultidiffusionNode, 'unet');
addSDXLLoRas(state, g, tiledMultidiffusionNode, modelNode, null, posCondNode, negCondNode);

-const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
-
g.upsertMetadata({
-cfg_scale,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
-model: Graph.getModelMetadataField(modelConfig),
-seed,
-steps,
-scheduler,
-vae: vae ?? undefined,
});
} else {
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
@ -168,24 +160,33 @@ export const buildMultidiffusionUpscaleGraph = async (state: RootState): Promise
|
|||||||
g.addEdge(modelNode, 'unet', tiledMultidiffusionNode, 'unet');
|
g.addEdge(modelNode, 'unet', tiledMultidiffusionNode, 'unet');
|
||||||
addLoRAs(state, g, tiledMultidiffusionNode, modelNode, null, clipSkipNode, posCondNode, negCondNode);
|
addLoRAs(state, g, tiledMultidiffusionNode, modelNode, null, clipSkipNode, posCondNode, negCondNode);
|
||||||
|
|
||||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
|
||||||
const upscaleModelConfig = await fetchModelConfigWithTypeGuard(upscaleModel.key, isSpandrelImageToImageModelConfig);
|
|
||||||
|
|
||||||
g.upsertMetadata({
|
g.upsertMetadata({
|
||||||
cfg_scale,
|
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
negative_prompt: negativePrompt,
|
negative_prompt: negativePrompt,
|
||||||
model: Graph.getModelMetadataField(modelConfig),
|
|
||||||
seed,
|
|
||||||
steps,
|
|
||||||
scheduler,
|
|
||||||
vae: vae ?? undefined,
|
|
||||||
upscale_model: Graph.getModelMetadataField(upscaleModelConfig),
|
|
||||||
creativity,
|
|
||||||
structure,
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
|
const upscaleModelConfig = await fetchModelConfigWithTypeGuard(upscaleModel.key, isSpandrelImageToImageModelConfig);
|
||||||
|
|
||||||
|
g.upsertMetadata({
|
||||||
|
cfg_scale,
|
||||||
|
model: Graph.getModelMetadataField(modelConfig),
|
||||||
|
seed,
|
||||||
|
steps,
|
||||||
|
scheduler,
|
||||||
|
vae: vae ?? undefined,
|
||||||
|
upscale_model: Graph.getModelMetadataField(upscaleModelConfig),
|
||||||
|
creativity,
|
||||||
|
structure,
|
||||||
|
upscale_initial_image: {
|
||||||
|
image_name: upscaleInitialImage.image_name,
|
||||||
|
width: upscaleInitialImage.width,
|
||||||
|
height: upscaleInitialImage.height,
|
||||||
|
},
|
||||||
|
upscale_scale: scale,
|
||||||
|
});
|
||||||
|
|
||||||
g.setMetadataReceivingNode(l2iNode);
|
g.setMetadataReceivingNode(l2iNode);
|
||||||
g.addEdgeToMetadata(upscaleNode, 'width', 'width');
|
g.addEdgeToMetadata(upscaleNode, 'width', 'width');
|
||||||
g.addEdgeToMetadata(upscaleNode, 'height', 'height');
|
g.addEdgeToMetadata(upscaleNode, 'height', 'height');
|
||||||
|
@@ -0,0 +1,29 @@
+import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
+import { useAppSelector } from 'app/store/storeHooks';
+import { selectUpscalelice } from 'features/parameters/store/upscaleSlice';
+import { selectConfigSlice } from 'features/system/store/configSlice';
+import { useMemo } from 'react';
+import type { ImageDTO } from 'services/api/types';
+
+const createIsTooLargeToUpscaleSelector = (imageDTO?: ImageDTO) =>
+createMemoizedSelector(selectUpscalelice, selectConfigSlice, (upscale, config) => {
+const { upscaleModel, scale } = upscale;
+const { maxUpscaleDimension } = config;
+
+if (!maxUpscaleDimension || !upscaleModel || !imageDTO) {
+// When these are missing, another warning will be shown
+return false;
+}
+
+const { width, height } = imageDTO;
+
+const maxPixels = maxUpscaleDimension ** 2;
+const upscaledPixels = width * scale * height * scale;
+
+return upscaledPixels > maxPixels;
+});
+
+export const useIsTooLargeToUpscale = (imageDTO?: ImageDTO) => {
+const selectIsTooLargeToUpscale = useMemo(() => createIsTooLargeToUpscaleSelector(imageDTO), [imageDTO]);
+return useAppSelector(selectIsTooLargeToUpscale);
+};
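For context, a minimal worked example of the size check implemented by the new useIsTooLargeToUpscale hook above. The comparison mirrors the selector; the concrete numbers are illustrative assumptions, not values from the diff:

// Illustrative values only; the formula mirrors createIsTooLargeToUpscaleSelector above.
const maxUpscaleDimension = 4096; // assumed config value
const image = { width: 2048, height: 1536 }; // assumed initial image
const scale = 4; // assumed upscale factor

const maxPixels = maxUpscaleDimension ** 2; // 16,777,216
const upscaledPixels = image.width * scale * image.height * scale; // 50,331,648

const isTooLargeToUpscale = upscaledPixels > maxPixels; // true, so the upscale warning is shown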
@@ -1,4 +1,4 @@
-import { Flex } from '@invoke-ai/ui-library';
+import { Flex, Text } from '@invoke-ai/ui-library';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import IAIDndImage from 'common/components/IAIDndImage';
 import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
@@ -41,13 +41,30 @@ export const UpscaleInitialImage = () => {
 postUploadAction={postUploadAction}
 />
 {imageDTO && (
-<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
-<IAIDndImageIcon
-onClick={onReset}
-icon={<PiArrowCounterClockwiseBold size={16} />}
-tooltip={t('controlnet.resetControlImage')}
-/>
-</Flex>
+<>
+<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
+<IAIDndImageIcon
+onClick={onReset}
+icon={<PiArrowCounterClockwiseBold size={16} />}
+tooltip={t('controlnet.resetControlImage')}
+/>
+</Flex>
+<Text
+position="absolute"
+background="base.900"
+color="base.50"
+fontSize="sm"
+fontWeight="semibold"
+bottom={0}
+left={0}
+opacity={0.7}
+px={2}
+lineHeight={1.25}
+borderTopEndRadius="base"
+borderBottomStartRadius="base"
+pointerEvents="none"
+>{`${imageDTO.width}x${imageDTO.height}`}</Text>
+</>
 )}
 </Flex>
 </Flex>
@@ -1,6 +1,7 @@
 import { Button, Flex, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import { $installModelsTab } from 'features/modelManagerV2/subpanels/InstallModels';
+import { useIsTooLargeToUpscale } from 'features/parameters/hooks/useIsTooLargeToUpscale';
 import { tileControlnetModelChanged } from 'features/parameters/store/upscaleSlice';
 import { setActiveTab } from 'features/ui/store/uiSlice';
 import { useCallback, useEffect, useMemo } from 'react';
@@ -12,10 +13,13 @@ export const UpscaleWarning = () => {
 const model = useAppSelector((s) => s.generation.model);
 const upscaleModel = useAppSelector((s) => s.upscale.upscaleModel);
 const tileControlnetModel = useAppSelector((s) => s.upscale.tileControlnetModel);
+const upscaleInitialImage = useAppSelector((s) => s.upscale.upscaleInitialImage);
 const dispatch = useAppDispatch();
 const [modelConfigs, { isLoading }] = useControlNetModels();
 const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
 const shouldShowButton = useMemo(() => !disabledTabs.includes('models'), [disabledTabs]);
+const maxUpscaleDimension = useAppSelector((s) => s.config.maxUpscaleDimension);
+const isTooLargeToUpscale = useIsTooLargeToUpscale(upscaleInitialImage || undefined);
 
 useEffect(() => {
 const validModel = modelConfigs.find((cnetModel) => {
@@ -24,7 +28,7 @@ export const UpscaleWarning = () => {
 dispatch(tileControlnetModelChanged(validModel || null));
 }, [model?.base, modelConfigs, dispatch]);
 
-const warnings = useMemo(() => {
+const modelWarnings = useMemo(() => {
 const _warnings: string[] = [];
 if (!model) {
 _warnings.push(t('upscaling.mainModelDesc'));
@@ -35,33 +39,44 @@ export const UpscaleWarning = () => {
 if (!upscaleModel) {
 _warnings.push(t('upscaling.upscaleModelDesc'));
 }
 
 return _warnings;
 }, [model, tileControlnetModel, upscaleModel, t]);
 
+const otherWarnings = useMemo(() => {
+const _warnings: string[] = [];
+if (isTooLargeToUpscale && maxUpscaleDimension) {
+_warnings.push(
+t('upscaling.exceedsMaxSizeDetails', { maxUpscaleDimension: maxUpscaleDimension.toLocaleString() })
+);
+}
+return _warnings;
+}, [isTooLargeToUpscale, t, maxUpscaleDimension]);
+
 const handleGoToModelManager = useCallback(() => {
 dispatch(setActiveTab('models'));
 $installModelsTab.set(3);
 }, [dispatch]);
 
-if (!warnings.length || isLoading || !shouldShowButton) {
+if ((!modelWarnings.length && !otherWarnings.length) || isLoading || !shouldShowButton) {
 return null;
 }
 
 return (
 <Flex bg="error.500" borderRadius="base" padding={4} direction="column" fontSize="sm" gap={2}>
-<Text>
-<Trans
-i18nKey="upscaling.missingModelsWarning"
-components={{
-LinkComponent: (
-<Button size="sm" flexGrow={0} variant="link" color="base.50" onClick={handleGoToModelManager} />
-),
-}}
-/>
-</Text>
+{!!modelWarnings.length && (
+<Text>
+<Trans
+i18nKey="upscaling.missingModelsWarning"
+components={{
+LinkComponent: (
+<Button size="sm" flexGrow={0} variant="link" color="base.50" onClick={handleGoToModelManager} />
+),
+}}
+/>
+</Text>
+)}
 <UnorderedList>
-{warnings.map((warning) => (
+{[...modelWarnings, ...otherWarnings].map((warning) => (
 <ListItem key={warning}>{warning}</ListItem>
 ))}
 </UnorderedList>
@@ -24,7 +24,6 @@ const initialConfigState: AppConfig = {
 disabledSDFeatures: ['variation', 'symmetry', 'hires', 'perlinNoise', 'noiseThreshold'],
 nodesAllowlist: undefined,
 nodesDenylist: undefined,
-canRestoreDeletedImagesFromBin: true,
 sd: {
 disabledControlNetModels: [],
 disabledControlNetProcessors: [],
@@ -1298,6 +1298,11 @@ export type components = {
 };
 /** Body_update_style_preset */
 Body_update_style_preset: {
+/**
+* Image
+* @description The image file to upload
+*/
+image?: Blob | null;
 /**
 * Name
 * @description The name of the style preset to create
@@ -1313,11 +1318,6 @@ export type components = {
 * @description The negative prompt of the style preset
 */
 negative_prompt: string;
-/**
-* Image
-* @description The image file to upload
-*/
-image?: Blob | null;
 };
 /** Body_update_workflow */
 Body_update_workflow: {
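A hedged sketch of how a client might assemble the Body_update_style_preset payload above as multipart form data. Only the image, name, and negative_prompt fields appear in the hunks; the concrete values, the FormData wiring, and any remaining fields are assumptions:

// Illustrative sketch only; field names mirror the Body_update_style_preset schema above.
const body = new FormData();
body.append('name', 'My preset'); // "Name" property from the schema
body.append('negative_prompt', 'blurry, low quality');
// image is optional (Blob | null) per the schema and can be omitted:
// body.append('image', imageBlob);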
@@ -7388,147 +7388,147 @@ export type components = {
 project_id: string | null;
 };
 InvocationOutputMap: {
-integer: components["schemas"]["IntegerOutput"];
-color_map_image_processor: components["schemas"]["ImageOutput"];
-rand_int: components["schemas"]["IntegerOutput"];
-infill_cv2: components["schemas"]["ImageOutput"];
-img_scale: components["schemas"]["ImageOutput"];
-random_range: components["schemas"]["IntegerCollectionOutput"];
-latents: components["schemas"]["LatentsOutput"];
-cv_inpaint: components["schemas"]["ImageOutput"];
-float_range: components["schemas"]["FloatCollectionOutput"];
-img_pad_crop: components["schemas"]["ImageOutput"];
-range_of_size: components["schemas"]["IntegerCollectionOutput"];
-vae_loader: components["schemas"]["VAEOutput"];
-lora_collection_loader: components["schemas"]["LoRALoaderOutput"];
-freeu: components["schemas"]["UNetOutput"];
-float_collection: components["schemas"]["FloatCollectionOutput"];
-sdxl_compel_prompt: components["schemas"]["ConditioningOutput"];
-calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"];
-range: components["schemas"]["IntegerCollectionOutput"];
-string_replace: components["schemas"]["StringOutput"];
-boolean: components["schemas"]["BooleanOutput"];
-show_image: components["schemas"]["ImageOutput"];
-img_hue_adjust: components["schemas"]["ImageOutput"];
-metadata: components["schemas"]["MetadataOutput"];
-img_conv: components["schemas"]["ImageOutput"];
-sub: components["schemas"]["IntegerOutput"];
-pair_tile_image: components["schemas"]["PairTileImageOutput"];
-save_image: components["schemas"]["ImageOutput"];
-rectangle_mask: components["schemas"]["MaskOutput"];
-ideal_size: components["schemas"]["IdealSizeOutput"];
-lresize: components["schemas"]["LatentsOutput"];
-lblend: components["schemas"]["LatentsOutput"];
-conditioning_collection: components["schemas"]["ConditioningCollectionOutput"];
-rand_float: components["schemas"]["FloatOutput"];
-prompt_from_file: components["schemas"]["StringCollectionOutput"];
-tomask: components["schemas"]["ImageOutput"];
-boolean_collection: components["schemas"]["BooleanCollectionOutput"];
 color_correct: components["schemas"]["ImageOutput"];
-img_channel_offset: components["schemas"]["ImageOutput"];
-compel: components["schemas"]["ConditioningOutput"];
-infill_tile: components["schemas"]["ImageOutput"];
-img_resize: components["schemas"]["ImageOutput"];
-create_denoise_mask: components["schemas"]["DenoiseMaskOutput"];
-sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"];
 mlsd_image_processor: components["schemas"]["ImageOutput"];
-img_channel_multiply: components["schemas"]["ImageOutput"];
-latents_collection: components["schemas"]["LatentsCollectionOutput"];
-midas_depth_image_processor: components["schemas"]["ImageOutput"];
-string_collection: components["schemas"]["StringCollectionOutput"];
-mask_from_id: components["schemas"]["ImageOutput"];
-string: components["schemas"]["StringOutput"];
-float: components["schemas"]["FloatOutput"];
-model_identifier: components["schemas"]["ModelIdentifierOutput"];
-pidi_image_processor: components["schemas"]["ImageOutput"];
-string_join: components["schemas"]["StringOutput"];
-spandrel_image_to_image_autoscale: components["schemas"]["ImageOutput"];
-lora_selector: components["schemas"]["LoRASelectorOutput"];
-clip_skip: components["schemas"]["CLIPSkipInvocationOutput"];
-image_mask_to_tensor: components["schemas"]["MaskOutput"];
-normalbae_image_processor: components["schemas"]["ImageOutput"];
-face_off: components["schemas"]["FaceOffOutput"];
-mul: components["schemas"]["IntegerOutput"];
-segment_anything_processor: components["schemas"]["ImageOutput"];
-round_float: components["schemas"]["FloatOutput"];
-sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"];
-denoise_latents: components["schemas"]["LatentsOutput"];
-string_split_neg: components["schemas"]["StringPosNegOutput"];
-string_split: components["schemas"]["String2Output"];
-invert_tensor_mask: components["schemas"]["MaskOutput"];
-main_model_loader: components["schemas"]["ModelLoaderOutput"];
-img_crop: components["schemas"]["ImageOutput"];
-img_watermark: components["schemas"]["ImageOutput"];
-dw_openpose_image_processor: components["schemas"]["ImageOutput"];
-add: components["schemas"]["IntegerOutput"];
-conditioning: components["schemas"]["ConditioningOutput"];
-esrgan: components["schemas"]["ImageOutput"];
-t2i_adapter: components["schemas"]["T2IAdapterOutput"];
-sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"];
-mediapipe_face_processor: components["schemas"]["ImageOutput"];
-img_chan: components["schemas"]["ImageOutput"];
-face_mask_detection: components["schemas"]["FaceMaskOutput"];
-lineart_image_processor: components["schemas"]["ImageOutput"];
-blank_image: components["schemas"]["ImageOutput"];
-image_collection: components["schemas"]["ImageCollectionOutput"];
-img_nsfw: components["schemas"]["ImageOutput"];
-unsharp_mask: components["schemas"]["ImageOutput"];
-scheduler: components["schemas"]["SchedulerOutput"];
-metadata_item: components["schemas"]["MetadataItemOutput"];
-crop_latents: components["schemas"]["LatentsOutput"];
-string_join_three: components["schemas"]["StringOutput"];
-content_shuffle_image_processor: components["schemas"]["ImageOutput"];
-zoe_depth_image_processor: components["schemas"]["ImageOutput"];
-depth_anything_image_processor: components["schemas"]["ImageOutput"];
-controlnet: components["schemas"]["ControlOutput"];
-mask_edge: components["schemas"]["ImageOutput"];
-img_ilerp: components["schemas"]["ImageOutput"];
-lora_loader: components["schemas"]["LoRALoaderOutput"];
-calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"];
-face_identifier: components["schemas"]["ImageOutput"];
-i2l: components["schemas"]["LatentsOutput"];
-infill_lama: components["schemas"]["ImageOutput"];
-sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"];
-mask_combine: components["schemas"]["ImageOutput"];
-noise: components["schemas"]["NoiseOutput"];
-div: components["schemas"]["IntegerOutput"];
-img_paste: components["schemas"]["ImageOutput"];
-create_gradient_mask: components["schemas"]["GradientMaskOutput"];
-iterate: components["schemas"]["IterateInvocationOutput"];
-merge_tiles_to_image: components["schemas"]["ImageOutput"];
-l2i: components["schemas"]["ImageOutput"];
-float_math: components["schemas"]["FloatOutput"];
-img_lerp: components["schemas"]["ImageOutput"];
-spandrel_image_to_image: components["schemas"]["ImageOutput"];
-tiled_multi_diffusion_denoise_latents: components["schemas"]["LatentsOutput"];
-ip_adapter: components["schemas"]["IPAdapterOutput"];
-step_param_easing: components["schemas"]["FloatCollectionOutput"];
-heuristic_resize: components["schemas"]["ImageOutput"];
-canny_image_processor: components["schemas"]["ImageOutput"];
-hed_image_processor: components["schemas"]["ImageOutput"];
-img_mul: components["schemas"]["ImageOutput"];
-merge_metadata: components["schemas"]["MetadataOutput"];
-color: components["schemas"]["ColorOutput"];
-lscale: components["schemas"]["LatentsOutput"];
-integer_math: components["schemas"]["IntegerOutput"];
-infill_rgba: components["schemas"]["ImageOutput"];
-lineart_anime_image_processor: components["schemas"]["ImageOutput"];
-tile_image_processor: components["schemas"]["ImageOutput"];
-img_blur: components["schemas"]["ImageOutput"];
-float_to_int: components["schemas"]["IntegerOutput"];
-alpha_mask_to_tensor: components["schemas"]["MaskOutput"];
-collect: components["schemas"]["CollectInvocationOutput"];
-tile_to_properties: components["schemas"]["TileToPropertiesOutput"];
-infill_patchmatch: components["schemas"]["ImageOutput"];
-image: components["schemas"]["ImageOutput"];
-leres_image_processor: components["schemas"]["ImageOutput"];
-seamless: components["schemas"]["SeamlessModeOutput"];
-integer_collection: components["schemas"]["IntegerCollectionOutput"];
-core_metadata: components["schemas"]["MetadataOutput"];
 dynamic_prompt: components["schemas"]["StringCollectionOutput"];
+midas_depth_image_processor: components["schemas"]["ImageOutput"];
+sdxl_compel_prompt: components["schemas"]["ConditioningOutput"];
+pidi_image_processor: components["schemas"]["ImageOutput"];
+integer_math: components["schemas"]["IntegerOutput"];
+iterate: components["schemas"]["IterateInvocationOutput"];
+img_resize: components["schemas"]["ImageOutput"];
+infill_lama: components["schemas"]["ImageOutput"];
+string_split_neg: components["schemas"]["StringPosNegOutput"];
+sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"];
 calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"];
+latents: components["schemas"]["LatentsOutput"];
+img_channel_multiply: components["schemas"]["ImageOutput"];
+crop_latents: components["schemas"]["LatentsOutput"];
+denoise_latents: components["schemas"]["LatentsOutput"];
+normalbae_image_processor: components["schemas"]["ImageOutput"];
+color: components["schemas"]["ColorOutput"];
+segment_anything_processor: components["schemas"]["ImageOutput"];
+t2i_adapter: components["schemas"]["T2IAdapterOutput"];
+collect: components["schemas"]["CollectInvocationOutput"];
+img_crop: components["schemas"]["ImageOutput"];
+image_collection: components["schemas"]["ImageCollectionOutput"];
+clip_skip: components["schemas"]["CLIPSkipInvocationOutput"];
+lblend: components["schemas"]["LatentsOutput"];
+range_of_size: components["schemas"]["IntegerCollectionOutput"];
+invert_tensor_mask: components["schemas"]["MaskOutput"];
+face_mask_detection: components["schemas"]["FaceMaskOutput"];
+prompt_from_file: components["schemas"]["StringCollectionOutput"];
 sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"];
+save_image: components["schemas"]["ImageOutput"];
+controlnet: components["schemas"]["ControlOutput"];
+merge_metadata: components["schemas"]["MetadataOutput"];
+l2i: components["schemas"]["ImageOutput"];
+sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"];
+lineart_anime_image_processor: components["schemas"]["ImageOutput"];
+step_param_easing: components["schemas"]["FloatCollectionOutput"];
+img_watermark: components["schemas"]["ImageOutput"];
+mediapipe_face_processor: components["schemas"]["ImageOutput"];
+tiled_multi_diffusion_denoise_latents: components["schemas"]["LatentsOutput"];
+string_replace: components["schemas"]["StringOutput"];
+alpha_mask_to_tensor: components["schemas"]["MaskOutput"];
+image_mask_to_tensor: components["schemas"]["MaskOutput"];
+img_chan: components["schemas"]["ImageOutput"];
+cv_inpaint: components["schemas"]["ImageOutput"];
+blank_image: components["schemas"]["ImageOutput"];
+face_identifier: components["schemas"]["ImageOutput"];
+dw_openpose_image_processor: components["schemas"]["ImageOutput"];
+content_shuffle_image_processor: components["schemas"]["ImageOutput"];
+img_ilerp: components["schemas"]["ImageOutput"];
+zoe_depth_image_processor: components["schemas"]["ImageOutput"];
+infill_cv2: components["schemas"]["ImageOutput"];
+lineart_image_processor: components["schemas"]["ImageOutput"];
+unsharp_mask: components["schemas"]["ImageOutput"];
+ip_adapter: components["schemas"]["IPAdapterOutput"];
+float_to_int: components["schemas"]["IntegerOutput"];
+infill_patchmatch: components["schemas"]["ImageOutput"];
+calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"];
+compel: components["schemas"]["ConditioningOutput"];
+lscale: components["schemas"]["LatentsOutput"];
+core_metadata: components["schemas"]["MetadataOutput"];
+add: components["schemas"]["IntegerOutput"];
+lora_selector: components["schemas"]["LoRASelectorOutput"];
+metadata: components["schemas"]["MetadataOutput"];
+depth_anything_image_processor: components["schemas"]["ImageOutput"];
+sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"];
+float_range: components["schemas"]["FloatCollectionOutput"];
+img_paste: components["schemas"]["ImageOutput"];
+sub: components["schemas"]["IntegerOutput"];
+ideal_size: components["schemas"]["IdealSizeOutput"];
+float: components["schemas"]["FloatOutput"];
+create_denoise_mask: components["schemas"]["DenoiseMaskOutput"];
+noise: components["schemas"]["NoiseOutput"];
+latents_collection: components["schemas"]["LatentsCollectionOutput"];
+mul: components["schemas"]["IntegerOutput"];
+esrgan: components["schemas"]["ImageOutput"];
+img_lerp: components["schemas"]["ImageOutput"];
+lora_collection_loader: components["schemas"]["LoRALoaderOutput"];
+integer: components["schemas"]["IntegerOutput"];
+spandrel_image_to_image_autoscale: components["schemas"]["ImageOutput"];
+merge_tiles_to_image: components["schemas"]["ImageOutput"];
+integer_collection: components["schemas"]["IntegerCollectionOutput"];
+boolean_collection: components["schemas"]["BooleanCollectionOutput"];
+img_mul: components["schemas"]["ImageOutput"];
+conditioning_collection: components["schemas"]["ConditioningCollectionOutput"];
+tile_image_processor: components["schemas"]["ImageOutput"];
+div: components["schemas"]["IntegerOutput"];
+canny_image_processor: components["schemas"]["ImageOutput"];
+image: components["schemas"]["ImageOutput"];
+hed_image_processor: components["schemas"]["ImageOutput"];
+heuristic_resize: components["schemas"]["ImageOutput"];
+seamless: components["schemas"]["SeamlessModeOutput"];
+round_float: components["schemas"]["FloatOutput"];
+leres_image_processor: components["schemas"]["ImageOutput"];
+img_blur: components["schemas"]["ImageOutput"];
+float_collection: components["schemas"]["FloatCollectionOutput"];
+lresize: components["schemas"]["LatentsOutput"];
+string: components["schemas"]["StringOutput"];
+boolean: components["schemas"]["BooleanOutput"];
+tile_to_properties: components["schemas"]["TileToPropertiesOutput"];
+infill_tile: components["schemas"]["ImageOutput"];
+metadata_item: components["schemas"]["MetadataItemOutput"];
+string_split: components["schemas"]["String2Output"];
+color_map_image_processor: components["schemas"]["ImageOutput"];
+string_collection: components["schemas"]["StringCollectionOutput"];
 canvas_paste_back: components["schemas"]["ImageOutput"];
+string_join: components["schemas"]["StringOutput"];
+rectangle_mask: components["schemas"]["MaskOutput"];
+infill_rgba: components["schemas"]["ImageOutput"];
+show_image: components["schemas"]["ImageOutput"];
+random_range: components["schemas"]["IntegerCollectionOutput"];
+scheduler: components["schemas"]["SchedulerOutput"];
+mask_combine: components["schemas"]["ImageOutput"];
+string_join_three: components["schemas"]["StringOutput"];
+img_channel_offset: components["schemas"]["ImageOutput"];
+img_pad_crop: components["schemas"]["ImageOutput"];
+range: components["schemas"]["IntegerCollectionOutput"];
+img_scale: components["schemas"]["ImageOutput"];
+main_model_loader: components["schemas"]["ModelLoaderOutput"];
+freeu: components["schemas"]["UNetOutput"];
+face_off: components["schemas"]["FaceOffOutput"];
+create_gradient_mask: components["schemas"]["GradientMaskOutput"];
+lora_loader: components["schemas"]["LoRALoaderOutput"];
+rand_int: components["schemas"]["IntegerOutput"];
+sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"];
+img_conv: components["schemas"]["ImageOutput"];
+mask_edge: components["schemas"]["ImageOutput"];
+img_hue_adjust: components["schemas"]["ImageOutput"];
+img_nsfw: components["schemas"]["ImageOutput"];
+vae_loader: components["schemas"]["VAEOutput"];
+i2l: components["schemas"]["LatentsOutput"];
+model_identifier: components["schemas"]["ModelIdentifierOutput"];
+float_math: components["schemas"]["FloatOutput"];
+rand_float: components["schemas"]["FloatOutput"];
+spandrel_image_to_image: components["schemas"]["ImageOutput"];
+conditioning: components["schemas"]["ConditioningOutput"];
+calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"];
+tomask: components["schemas"]["ImageOutput"];
+pair_tile_image: components["schemas"]["PairTileImageOutput"];
+mask_from_id: components["schemas"]["ImageOutput"];
 };
 /**
 * InvocationStartedEvent
@@ -74,7 +74,8 @@ dependencies = [
 "easing-functions",
 "einops",
 "facexlib",
-"matplotlib", # needed for plotting of Penner easing functions
+# Exclude 3.9.1 which has a problem on windows, see https://github.com/matplotlib/matplotlib/issues/28551
+"matplotlib!=3.9.1",
 "npyscreen",
 "omegaconf",
 "picklescan",
@@ -89,7 +90,6 @@ dependencies = [
 "rich~=13.3",
 "scikit-image~=0.21.0",
 "semver~=3.0.1",
-"send2trash",
 "test-tube~=0.7.5",
 "windows-curses; sys_platform=='win32'",
 ]