mirror of https://github.com/invoke-ai/InvokeAI
synced 2025-07-26 05:17:55 +00:00

Rename Structural Lora to Control Lora
committed by Kent Keirsey
parent 040551d4fb
commit 046d19446c
@@ -56,7 +56,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
     CLIPLEmbedModel = "CLIPLEmbedModelField"
     CLIPGEmbedModel = "CLIPGEmbedModelField"
     SpandrelImageToImageModel = "SpandrelImageToImageModelField"
-    StructuralLoRAModel = "StructuralLoRAModelField"
+    ControlLoRAModel = "ControlLoRAModelField"
     # endregion

     # region Misc Field Types
@@ -144,7 +144,7 @@ class FieldDescriptions:
     controlnet_model = "ControlNet model to load"
     vae_model = "VAE model to load"
     lora_model = "LoRA model to load"
-    structural_lora_model = "Structural LoRA model to load"
+    control_lora_model = "Control LoRA model to load"
     main_model = "Main model (UNet, VAE, CLIP) to load"
     flux_model = "Flux model (Transformer) to load"
     sd3_model = "SD3 model (MMDiTX) to load"
invokeai/app/invocations/flux_control_lora_loader.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+from typing import Optional
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    Classification,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
+from invokeai.app.invocations.model import ControlLoRAField, ModelIdentifierField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+
+
+@invocation_output("flux_control_lora_loader_output")
+class FluxControlLoRALoaderOutput(BaseInvocationOutput):
+    """Flux Control LoRA Loader Output"""
+
+    control_lora: Optional[ControlLoRAField] = OutputField(
+        title="Flux Control Lora", description="Control LoRAs to apply on model loading", default=None
+    )
+
+
+@invocation(
+    "flux_control_lora_loader",
+    title="Flux Control LoRA",
+    tags=["lora", "model", "flux"],
+    category="model",
+    version="1.1.0",
+    classification=Classification.Prototype,
+)
+class FluxControlLoRALoaderInvocation(BaseInvocation):
+    """LoRA model and Image to use with FLUX transformer generation."""
+
+    lora: ModelIdentifierField = InputField(
+        description=FieldDescriptions.control_lora_model, title="Control LoRA", ui_type=UIType.ControlLoRAModel
+    )
+    image: ImageField = InputField(
+        description="The image to encode.",
+    )
+
+    def invoke(self, context: InvocationContext) -> FluxControlLoRALoaderOutput:
+        lora_key = self.lora.key
+
+        if not context.models.exists(lora_key):
+            raise ValueError(f"Unknown lora: {lora_key}!")
+
+        output = FluxControlLoRALoaderOutput()
+
+        output.control_lora = ControlLoRAField(
+            lora=self.lora,
+            img=self.image,
+            weight=1,
+        )
+        return output
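For orientation, a minimal sketch (not part of the commit) of the value this loader emits: the selected model and the conditioning image travel together in a ControlLoRAField, with the weight pinned to 1. The ModelIdentifierField values below are hypothetical placeholders; LoRAField's lora/weight members are assumed from their usage elsewhere in this diff.

from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.model import ControlLoRAField, ModelIdentifierField

control_lora = ControlLoRAField(
    lora=ModelIdentifierField(
        key="0b1c2d3e",             # hypothetical model-record key
        hash="blake3:abc123",       # hypothetical content hash
        name="flux-canny-control",  # hypothetical model name
        base="flux",
        type="control_lora",        # matches ModelType.ControlLoRa elsewhere in this commit
    ),
    img=ImageField(image_name="edges.png"),  # hypothetical image name; encoded at denoise time
    weight=1,  # the loader pins weight to 1 rather than exposing it as an input
)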
@@ -8,8 +8,6 @@
 import torchvision.transforms as tv_transforms
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
-
-from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
 from invokeai.app.invocations.fields import (
     DenoiseMaskField,
@@ -24,7 +22,7 @@ from invokeai.app.invocations.fields import (
 )
 from invokeai.app.invocations.flux_controlnet import FluxControlNetField
 from invokeai.app.invocations.ip_adapter import IPAdapterField
-from invokeai.app.invocations.model import TransformerField, VAEField, StructuralLoRAField, LoRAField
+from invokeai.app.invocations.model import ControlLoRAField, LoRAField, TransformerField, VAEField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
@@ -35,8 +33,10 @@ from invokeai.backend.flux.extensions.instantx_controlnet_extension import Insta
 from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
+from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
 from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
 from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.sampling_utils import (
     clip_timestep_schedule_fractional,
     generate_img_ids,
@@ -45,8 +45,6 @@ from invokeai.backend.flux.sampling_utils import (
     pack,
     unpack,
 )
-from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
 from invokeai.backend.flux.modules.conditioner import HFEncoder
 from invokeai.backend.flux.text_conditioning import FluxTextConditioning
 from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
@@ -93,6 +91,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         input=Input.Connection,
         title="Transformer",
     )
+    control_lora: Optional[ControlLoRAField] = InputField(
+        description=FieldDescriptions.control_lora_model, input=Input.Connection, title="Control Lora", default=None
+    )
     positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection
     )
@@ -198,7 +199,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         )

         transformer_info = context.models.load(self.transformer.transformer)
-        is_schnell = "schnell" in transformer_info.config.config_path
+        is_schnell = "schnell" in getattr(transformer_info.config, "config_path", "")

         # Calculate the timestep schedule.
         timesteps = get_schedule(
@@ -289,12 +290,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             device=x.device,
         )
         img_cond = None
-        if struct_lora := self.transformer.structural_lora:
+        if self.control_lora:
             # What should we do when we have multiple of these?
             if not self.controlnet_vae:
                 raise ValueError("controlnet_vae must be set when using a strutural lora")
             ae_info = context.models.load(self.controlnet_vae.vae)
-            img = context.images.get_pil(struct_lora.img.image_name)
+            img = context.images.get_pil(self.control_lora.img.image_name)
             with ae_info as ae:
                 assert isinstance(ae, AutoEncoder)
                 img_cond = prepare_control(self.height, self.width, self.seed, ae, img)
@@ -359,7 +360,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             controlnet_extensions=controlnet_extensions,
             pos_ip_adapter_extensions=pos_ip_adapter_extensions,
             neg_ip_adapter_extensions=neg_ip_adapter_extensions,
-            img_cond=img_cond
+            img_cond=img_cond,
         )

         x = unpack(x.float(), self.height, self.width)
@@ -697,9 +698,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         return pos_ip_adapter_extensions, neg_ip_adapter_extensions

     def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
-        loras: list[Union[LoRAField, StructuralLoRAField]] = [*self.transformer.loras]
-        if self.transformer.structural_lora:
-            loras.append(self.transformer.structural_lora)
+        loras: list[Union[LoRAField, ControlLoRAField]] = [*self.transformer.loras]
+        if self.control_lora:
+            loras.append(self.control_lora)
         for lora in loras:
             lora_info = context.models.load(lora.lora)
             assert isinstance(lora_info.model, LoRAModelRaw)
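Two design points fall out of the flux_denoise hunks above: the control LoRA now hangs off the denoise invocation itself rather than off TransformerField, and it is applied through the same weight-patching path as ordinary LoRAs while also contributing an image conditioning term (img_cond). A condensed, self-contained sketch of the iterator's shape; the field types here are stand-ins, and only the attribute names come from the diff.

from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple


@dataclass
class LoRAFieldSketch:
    lora: str  # stands in for ModelIdentifierField
    weight: float


@dataclass
class ControlLoRAFieldSketch(LoRAFieldSketch):
    img: str = ""  # name of the conditioning image


def lora_iterator(
    transformer_loras: List[LoRAFieldSketch], control_lora: Optional[ControlLoRAFieldSketch]
) -> Iterator[Tuple[str, float]]:
    loras: List[LoRAFieldSketch] = [*transformer_loras]
    if control_lora:
        loras.append(control_lora)  # the control LoRA rides the ordinary patching path
    for lora in loras:
        # The real code loads a LoRAModelRaw via context.models.load(lora.lora) here.
        yield lora.lora, lora.weight


# One regular LoRA plus a control LoRA pinned at weight 1 by its loader node:
print(list(lora_iterator([LoRAFieldSketch("style-lora", 0.6)], ControlLoRAFieldSketch("flux-canny", 1.0, "edges.png"))))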
@@ -81,8 +81,10 @@ class FluxModelLoaderInvocation(BaseInvocation):
         assert isinstance(transformer_config, CheckpointConfigBase)

         return FluxModelLoaderOutput(
-            transformer=TransformerField(transformer=transformer, loras=[], structural_loras=[]),
-            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], structural_loras=[], skipped_layers=0),
+            transformer=TransformerField(transformer=transformer, loras=[]),
+            clip=CLIPField(
+                tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0
+            ),
             t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
             vae=VAEField(vae=vae),
             max_seq_len=max_seq_lengths[transformer_config.config_path],
@@ -1,70 +0,0 @@
-from typing import Optional, Literal
-
-from invokeai.app.invocations.baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    Classification,
-    invocation,
-    invocation_output,
-)
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
-from invokeai.app.invocations.model import VAEField, StructuralLoRAField, ModelIdentifierField, TransformerField
-from invokeai.app.services.shared.invocation_context import InvocationContext
-
-
-@invocation_output("flux_structural_lora_loader_output")
-class FluxStructuralLoRALoaderOutput(BaseInvocationOutput):
-    """Flux Structural LoRA Loader Output"""
-
-    transformer: Optional[TransformerField] = OutputField(
-        default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
-    )
-
-
-@invocation(
-    "flux_structural_lora_loader",
-    title="Flux Structural LoRA",
-    tags=["lora", "model", "flux"],
-    category="model",
-    version="1.1.0",
-    classification=Classification.Prototype,
-)
-class FluxStructuralLoRALoaderInvocation(BaseInvocation):
-    """Apply a LoRA model to a FLUX transformer and/or text encoder."""
-
-    lora: ModelIdentifierField = InputField(
-        description=FieldDescriptions.structural_lora_model, title="Structural LoRA", ui_type=UIType.StructuralLoRAModel
-    )
-    transformer: TransformerField | None = InputField(
-        default=None,
-        description=FieldDescriptions.transformer,
-        input=Input.Connection,
-        title="FLUX Transformer",
-    )
-    image: ImageField = InputField(
-        description="The image to encode.",
-    )
-    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
-
-    def invoke(self, context: InvocationContext) -> FluxStructuralLoRALoaderOutput:
-        lora_key = self.lora.key
-
-        if not context.models.exists(lora_key):
-            raise ValueError(f"Unknown lora: {lora_key}!")
-
-        # Check for existing LoRAs with the same key.
-        if self.transformer and self.transformer.structural_lora and self.transformer.structural_lora.lora.key == lora_key:
-            raise ValueError(f'Structural LoRA "{lora_key}" already applied to transformer.')
-
-        output = FluxStructuralLoRALoaderOutput()
-
-        # Attach LoRA layers to the models.
-        if self.transformer is not None:
-            output.transformer = self.transformer.model_copy(deep=True)
-            output.transformer.structural_lora = StructuralLoRAField(
-                lora=self.lora,
-                img=self.image,
-                weight=self.weight,
-            )
-
-        return output
@@ -1,5 +1,5 @@
 import copy
-from typing import List, Optional, Literal
+from typing import List, Optional

 from pydantic import BaseModel, Field

@@ -10,7 +10,7 @@ from invokeai.app.invocations.baseinvocation import (
     invocation,
     invocation_output,
 )
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
+from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.model_manager.config import (
@@ -74,13 +74,15 @@ class VAEField(BaseModel):
     vae: ModelIdentifierField = Field(description="Info to load vae submodel")
     seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')


-class StructuralLoRAField(LoRAField):
+class ControlLoRAField(LoRAField):
     img: ImageField = Field(description="Image to use in structural conditioning")


 class TransformerField(BaseModel):
     transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
     loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
-    structural_lora: Optional[StructuralLoRAField] = Field(description="Structural LoRAs to apply on model loading", default=None)


 @invocation_output("unet_output")
 class UNetOutput(BaseInvocationOutput):
@@ -1,10 +1,11 @@
-import torch
 import numpy as np
-from PIL import Image
+import torch
 from einops import rearrange
+from PIL import Image

+from invokeai.backend.flux.modules.autoencoder import AutoEncoder


 def prepare_control(
     height: int,
     width: int,
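The body of prepare_control is not shown in this diff, so the following is only a rough sketch of what such a helper plausibly does, inferred from its call site in FluxDenoiseInvocation: encode the conditioning image with the FLUX VAE and pack it the same way the noise latents are packed. The encode() signature and the role of seed are assumptions.

import numpy as np
import torch
from einops import rearrange
from PIL import Image


def prepare_control_sketch(height: int, width: int, seed: int, ae, img: Image.Image) -> torch.Tensor:
    # Resize the conditioning image to the generation resolution and scale to [-1, 1].
    img = img.convert("RGB").resize((width, height), Image.LANCZOS)
    x = torch.from_numpy(np.asarray(img)).float() / 127.5 - 1.0
    x = rearrange(x, "h w c -> 1 c h w")
    # Encode with the FLUX autoencoder; in the real helper, seed presumably
    # drives any sampling of the VAE posterior.
    with torch.no_grad():
        latent = ae.encode(x)
    # Pack 2x2 latent patches into a token sequence, mirroring pack() for the noise.
    return rearrange(latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)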
@@ -1,10 +1,10 @@
 # Initially pulled from https://github.com/black-forest-labs/flux

 from dataclasses import dataclass
+from typing import Optional

 import torch
 from torch import Tensor, nn
-from typing import Optional

 from invokeai.backend.flux.custom_block_processor import (
     CustomDoubleStreamBlockProcessor,
@@ -1,50 +0,0 @@
-import os
-
-import cv2
-import numpy as np
-import torch
-from einops import rearrange, repeat
-from PIL import Image
-from safetensors.torch import load_file as load_sft
-from torch import nn
-from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
-
-
-class DepthImageEncoder:
-    depth_model_name = "LiheYoung/depth-anything-large-hf"
-
-    def __init__(self, device):
-        self.device = device
-        self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
-        self.processor = AutoProcessor.from_pretrained(self.depth_model_name)
-
-    def __call__(self, img: torch.Tensor) -> torch.Tensor:
-        hw = img.shape[-2:]
-        img = torch.clamp(img, -1.0, 1.0)
-        img_byte = ((img + 1.0) * 127.5).byte()
-        img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
-        depth = self.depth_model(img.to(self.device)).predicted_depth
-        depth = repeat(depth, "b h w -> b 3 h w")
-        depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)
-        depth = depth / 127.5 - 1.0
-        return depth
-
-
-class CannyImageEncoder:
-    def __init__(
-        self,
-        device,
-        min_t: int = 50,
-        max_t: int = 200,
-    ):
-        self.device = device
-        self.min_t = min_t
-        self.max_t = max_t
-
-    def __call__(self, img: torch.Tensor) -> torch.Tensor:
-        assert img.shape[0] == 1, "Only batch size 1 is supported"
-        img = rearrange(img[0], "c h w -> h w c")
-        img = torch.clamp(img, -1.0, 1.0)
-        img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8)
-        # Apply Canny edge detection
-        canny = cv2.Canny(img_np, self.min_t, self.max_t)
-        # Convert back to torch tensor and reshape
-        canny = torch.from_numpy(canny).float() / 127.5 - 1.0
-        canny = rearrange(canny, "h w -> 1 1 h w")
-        canny = repeat(canny, "b 1 ... -> b 3 ...")
-        return canny.to(self.device)
@@ -1,21 +1,21 @@
 import re
+from typing import Any, Dict

 import torch

-from typing import Any, Dict
-from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
-from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
+from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.lora_layer import LoRALayer
 from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw

 # A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
 # Example keys:
 # guidance_in.in_layer.lora_B.bias
 # single_blocks.0.linear1.lora_A.weight
 # double_blocks.0.img_attn.norm.key_norm.scale
-FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX = r"(final_layer|vector_in|txt_in|time_in|img_in|guidance_in|\w+_blocks)(\.(\d+))?\.(lora_(A|B)|(in|out)_layer|adaLN_modulation|img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear|linear1|linear2|modulation|norm)\.?(.*)"
+FLUX_CONTROL_TRANSFORMER_KEY_REGEX = r"(final_layer|vector_in|txt_in|time_in|img_in|guidance_in|\w+_blocks)(\.(\d+))?\.(lora_(A|B)|(in|out)_layer|adaLN_modulation|img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear|linear1|linear2|modulation|norm)\.?(.*)"


 def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
     """Checks if the provided state dict is likely in the FLUX Control LoRA format.
@@ -24,10 +24,11 @@ def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
     perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
     """
     return all(
-        re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k)
+        re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k)
         for k in state_dict.keys()
     )


 def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
     # converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
     # Group keys by layer.
@@ -54,7 +55,7 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
                 None,
                 layer_state_dict["lora_A.weight"],
                 None,
-                layer_state_dict["lora_B.bias"]
+                layer_state_dict["lora_B.bias"],
             )
         elif "scale" in layer_state_dict:
             layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
@@ -62,4 +63,3 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
             raise AssertionError(f"{layer_key} not expected")
     # Create and return the LoRAModelRaw.
     return LoRAModelRaw(layers=layers)
-
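As a quick sanity check, the three example keys quoted in the module comment all satisfy the renamed regex; this standalone snippet reproduces the pattern from the hunk above and can be run as-is.

import re

FLUX_CONTROL_TRANSFORMER_KEY_REGEX = r"(final_layer|vector_in|txt_in|time_in|img_in|guidance_in|\w+_blocks)(\.(\d+))?\.(lora_(A|B)|(in|out)_layer|adaLN_modulation|img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear|linear1|linear2|modulation|norm)\.?(.*)"

example_keys = [
    "guidance_in.in_layer.lora_B.bias",
    "single_blocks.0.linear1.lora_A.weight",
    "double_blocks.0.img_attn.norm.key_norm.scale",
]
assert all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) for k in example_keys)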
@@ -9,4 +9,6 @@ from invokeai.backend.lora.layers.lora_layer import LoRALayer
 from invokeai.backend.lora.layers.norm_layer import NormLayer
 from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer

-AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer]
+AnyLoRALayer = Union[
+    LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer
+]
@@ -1,34 +0,0 @@
-from typing import Dict, Optional
-
-import torch
-
-from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
-from invokeai.backend.util.calc_tensor_size import calc_tensor_size
-
-
-class ReshapeWeightLayer(LoRALayerBase):
-    # TODO: Just everything in this class
-    def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
-        super().__init__(alpha=None, bias=bias)
-        self.weight = torch.nn.Parameter(weight) if weight is not None else None
-        self.bias = torch.nn.Parameter(bias) if bias is not None else None
-        self.manual_scale = scale
-
-    def scale(self):
-        return self.manual_scale.float() if self.manual_scale is not None else super().scale()
-
-    def rank(self) -> int | None:
-        return None
-
-    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
-        return orig_weight
-
-    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
-        super().to(device=device, dtype=dtype)
-        if self.weight is not None:
-            self.weight = self.weight.to(device=device, dtype=dtype)
-        if self.manual_scale is not None:
-            self.manual_scale = self.manual_scale.to(device=device, dtype=dtype)
-
-    def calc_size(self) -> int:
-        return super().calc_size() + calc_tensor_size(self.manual_scale)
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict

 import torch

@@ -17,7 +17,7 @@ class SetParameterLayer(LoRALayerBase):

     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight - orig_weight

     def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
         return {self.param_name: self.get_weight(orig_module.get_parameter(self.param_name))}
@@ -9,7 +9,6 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
 from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
 from invokeai.backend.lora.layers.lora_layer import LoRALayer
 from invokeai.backend.lora.layers.norm_layer import NormLayer
 from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer


 def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:
@@ -67,7 +67,7 @@ class ModelType(str, Enum):
     Main = "main"
     VAE = "vae"
     LoRA = "lora"
-    StructuralLoRa = "structural_lora"
+    ControlLoRa = "control_lora"
     ControlNet = "controlnet"  # used by model_probe
     TextualInversion = "embedding"
     IPAdapter = "ip_adapter"
@@ -274,16 +274,16 @@ class LoRALyCORISConfig(LoRAConfigBase):
         return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")


-class StructuralLoRALyCORISConfig(ModelConfigBase):
-    """Model config for Structural LoRA/Lycoris models."""
+class ControlLoRALyCORISConfig(ModelConfigBase):
+    """Model config for Control LoRA models."""

-    type: Literal[ModelType.StructuralLoRa] = ModelType.StructuralLoRa
+    type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
     trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
     format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS

     @staticmethod
     def get_tag() -> Tag:
-        return Tag(f"{ModelType.StructuralLoRa.value}.{ModelFormat.LyCORIS.value}")
+        return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.LyCORIS.value}")


 class LoRADiffusersConfig(LoRAConfigBase):
@@ -548,7 +548,7 @@ AnyModelConfig = Annotated[
     Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
     Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
     Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
-    Annotated[StructuralLoRALyCORISConfig, StructuralLoRALyCORISConfig.get_tag()],
+    Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
     Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
     Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
     Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
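The tag is what routes deserialization in the AnyModelConfig discriminated union: a config whose type/format pair is control_lora/lycoris now resolves to ControlLoRALyCORISConfig. A minimal standalone sketch of that pydantic pattern, with simplified stand-in models:

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class LoRALyCORISSketch(BaseModel):
    type: Literal["lora"] = "lora"
    format: Literal["lycoris"] = "lycoris"


class ControlLoRALyCORISSketch(BaseModel):
    type: Literal["control_lora"] = "control_lora"
    format: Literal["lycoris"] = "lycoris"


def config_tag(v) -> str:
    # Mirrors get_tag(): the discriminator value is "<type>.<format>".
    if isinstance(v, dict):
        return f"{v['type']}.{v['format']}"
    return f"{v.type}.{v.format}"


AnyConfigSketch = Annotated[
    Union[
        Annotated[LoRALyCORISSketch, Tag("lora.lycoris")],
        Annotated[ControlLoRALyCORISSketch, Tag("control_lora.lycoris")],
    ],
    Discriminator(config_tag),
]

cfg = TypeAdapter(AnyConfigSketch).validate_python({"type": "control_lora", "format": "lycoris"})
assert isinstance(cfg, ControlLoRALyCORISSketch)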
@@ -9,13 +9,17 @@ import torch
 from safetensors.torch import load_file

 from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.backend.lora.conversions.flux_control_lora_utils import (
+    is_state_dict_likely_flux_control,
+    lora_model_from_flux_control_state_dict,
+)
 from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
     lora_model_from_flux_diffusers_state_dict,
 )
 from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
-    is_state_dict_likely_in_flux_kohya_format, lora_model_from_flux_kohya_state_dict,
+    is_state_dict_likely_in_flux_kohya_format,
+    lora_model_from_flux_kohya_state_dict,
 )
-from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control, lora_model_from_flux_control_state_dict
 from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
 from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
 from invokeai.backend.model_manager import (
@@ -33,7 +37,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade

 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
-@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.StructuralLoRa, format=ModelFormat.LyCORIS)
+@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.LyCORIS)
 class LoRALoader(ModelLoader):
     """Class to load LoRA models."""
@@ -15,10 +15,10 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
     is_state_dict_xlabs_controlnet,
 )
 from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
+from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
 from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
     is_state_dict_likely_in_flux_diffusers_format,
 )
-from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
 from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
 from invokeai.backend.model_manager.config import (
@@ -269,7 +269,7 @@ class ModelProbe(object):
                 and isinstance(tensor_b, torch.Tensor)
                 and tensor_b.shape[0] == 3072
             ):
-                return ModelType.StructuralLoRa
+                return ModelType.ControlLoRa

         for key in [str(k) for k in ckpt.keys()]:
             if key.startswith(
@@ -1061,7 +1061,7 @@ ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelI
 ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
-ModelProbe.register_probe("checkpoint", ModelType.StructuralLoRa, LoRACheckpointProbe)
+ModelProbe.register_probe("checkpoint", ModelType.ControlLoRa, LoRACheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
@@ -809,7 +809,7 @@
     "starterBundleHelpText": "Easily install all models needed to get started with a base model, including a main model, controlnets, IP adapters, and more. Selecting a bundle will skip any models that you already have installed.",
     "starterModels": "Starter Models",
     "starterModelsInModelManager": "Starter Models can be found in Model Manager",
-    "structuralLora": "Structural LoRA",
+    "controlLora": "Control LoRA",
     "syncModels": "Sync Models",
     "textualInversions": "Textual Inversions",
     "triggerPhrases": "Trigger Phrases",
@@ -24,7 +24,7 @@ import type {
   ParameterSeed,
   ParameterSteps,
   ParameterStrength,
-  ParameterStructuralLoRAModel,
+  ParameterControlLoRAModel,
   ParameterT5EncoderModel,
   ParameterVAEModel,
 } from 'features/parameters/types/parameterSchemas';
@@ -76,7 +76,7 @@ export type ParamsState = {
   clipEmbedModel: ParameterCLIPEmbedModel | null;
   clipLEmbedModel: ParameterCLIPLEmbedModel | null;
   clipGEmbedModel: ParameterCLIPGEmbedModel | null;
-  structuralLora: ParameterStructuralLoRAModel | null;
+  controlLora: ParameterControlLoRAModel | null;
 };

 const initialState: ParamsState = {
@@ -123,7 +123,7 @@ const initialState: ParamsState = {
   clipEmbedModel: null,
   clipLEmbedModel: null,
   clipGEmbedModel: null,
-  structuralLora: null,
+  controlLora: null,
 };

 export const paramsSlice = createSlice({
@@ -198,8 +198,8 @@ export const paramsSlice = createSlice({
     t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
       state.t5EncoderModel = action.payload;
     },
-    structuralLoRAModelSelected: (state, action: PayloadAction<ParameterStructuralLoRAModel | null>) => {
-      state.structuralLora = action.payload;
+    controlLoRAModelSelected: (state, action: PayloadAction<ParameterControlLoRAModel | null>) => {
+      state.controlLora = action.payload;
     },
     clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
       state.clipEmbedModel = action.payload;
@@ -46,7 +46,7 @@ import type {
   ParameterSeed,
   ParameterSteps,
   ParameterStrength,
-  ParameterStructuralLoRAModel,
+  ParameterControlLoRAModel,
   ParameterVAEModel,
   ParameterWidth,
 } from 'features/parameters/types/parameterSchemas';
@@ -81,7 +81,7 @@ import {
   isLoRAModelConfig,
   isNonRefinerMainModelConfig,
   isRefinerMainModelModelConfig,
-  isStructuralLoRAModelConfig,
+  isControlLoRAModelConfig,
   isT2IAdapterModelConfig,
   isVAEModelConfig,
 } from 'services/api/types';
@@ -228,10 +228,10 @@ const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) =>
   return modelIdentifier;
 };

-const parseStructuralLoRAModel: MetadataParseFunc<ParameterStructuralLoRAModel> = async (metadata) => {
-  const slora = await getProperty(metadata, 'structural_lora', undefined);
-  const key = await getModelKey(slora, 'structural_lora');
-  const sloraModelConfig = await fetchModelConfigWithTypeGuard(key, isStructuralLoRAModelConfig);
+const parseControlLoRAModel: MetadataParseFunc<ParameterControlLoRAModel> = async (metadata) => {
+  const slora = await getProperty(metadata, 'control_lora', undefined);
+  const key = await getModelKey(slora, 'control_lora');
+  const sloraModelConfig = await fetchModelConfigWithTypeGuard(key, isControlLoRAModelConfig);
   const modelIdentifier = zModelIdentifierField.parse(sloraModelConfig);
   return modelIdentifier;
 };
@@ -681,7 +681,7 @@ export const parsers = {
   mainModel: parseMainModel,
   refinerModel: parseRefinerModel,
   vaeModel: parseVAEModel,
-  structuralLora: parseStructuralLoRAModel,
+  controlLora: parseControlLoRAModel,
   lora: parseLoRA,
   loras: parseAllLoRAs,
   controlNet: parseControlNet,
@@ -18,7 +18,7 @@ import {
   useMainModels,
   useRefinerModels,
   useSpandrelImageToImageModels,
-  useStructuralLoRAModel,
+  useControlLoRAModel,
   useT2IAdapterModels,
   useT5EncoderModels,
   useVAEModels,
@@ -93,10 +93,10 @@ const ModelList = () => {
     [t5EncoderModels, searchTerm, filteredModelType]
   );

-  const [structuralLoRAModels, { isLoading: isLoadingStructuralLoRAModels }] = useStructuralLoRAModel();
-  const filteredStructuralLoRAModels = useMemo(
-    () => modelsFilter(structuralLoRAModels, searchTerm, filteredModelType),
-    [structuralLoRAModels, searchTerm, filteredModelType]
+  const [controlLoRAModels, { isLoading: isLoadingControlLoRAModels }] = useControlLoRAModel();
+  const filteredControlLoRAModels = useMemo(
+    () => modelsFilter(controlLoRAModels, searchTerm, filteredModelType),
+    [controlLoRAModels, searchTerm, filteredModelType]
   );

   const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels({ excludeSubmodels: true });
@@ -126,7 +126,7 @@ const ModelList = () => {
       filteredSpandrelImageToImageModels.length +
       t5EncoderModels.length +
       clipEmbedModels.length +
-      structuralLoRAModels.length
+      controlLoRAModels.length
     );
   }, [
     filteredControlNetModels.length,
@@ -141,7 +141,7 @@ const ModelList = () => {
     filteredSpandrelImageToImageModels.length,
     t5EncoderModels.length,
     clipEmbedModels.length,
-    structuralLoRAModels.length,
+    controlLoRAModels.length,
   ]);

   return (
@@ -204,13 +204,13 @@ const ModelList = () => {
         {!isLoadingT5EncoderModels && filteredT5EncoderModels.length > 0 && (
           <ModelListWrapper title={t('modelManager.t5Encoder')} modelList={filteredT5EncoderModels} key="t5-encoder" />
         )}
-        {/* Structural Lora List */}
-        {isLoadingStructuralLoRAModels && <FetchingModelsLoader loadingMessage="Loading Structural Loras..." />}
-        {!isLoadingStructuralLoRAModels && filteredStructuralLoRAModels.length > 0 && (
+        {/* Control Lora List */}
+        {isLoadingControlLoRAModels && <FetchingModelsLoader loadingMessage="Loading Control Loras..." />}
+        {!isLoadingControlLoRAModels && filteredControlLoRAModels.length > 0 && (
           <ModelListWrapper
-            title={t('modelManager.structuralLora')}
-            modelList={filteredStructuralLoRAModels}
-            key="structural-lora"
+            title={t('modelManager.controlLora')}
+            modelList={filteredControlLoRAModels}
+            key="control-lora"
           />
         )}
         {/* Clip Embed List */}
@@ -24,7 +24,7 @@ export const ModelTypeFilter = memo(() => {
       ip_adapter: t('common.ipAdapter'),
       clip_vision: 'CLIP Vision',
       spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
-      structural_lora: t('modelManager.structuralLora'),
+      control_lora: t('modelManager.controlLora'),
     }),
     [t]
   );
@@ -51,8 +51,8 @@ import {
   isSpandrelImageToImageModelFieldInputTemplate,
   isStringFieldInputInstance,
   isStringFieldInputTemplate,
-  isStructuralLoRAModelFieldInputInstance,
-  isStructuralLoRAModelFieldInputTemplate,
+  isControlLoRAModelFieldInputInstance,
+  isControlLoRAModelFieldInputTemplate,
   isT2IAdapterModelFieldInputInstance,
   isT2IAdapterModelFieldInputTemplate,
   isT5EncoderModelFieldInputInstance,
@@ -83,7 +83,7 @@ import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComp
 import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
 import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
 import StringFieldInputComponent from './inputs/StringFieldInputComponent';
-import StructuralLoRAModelFieldInputComponent from './inputs/StructuralLoraModelFieldInputComponent';
+import ControlLoRAModelFieldInputComponent from './inputs/ControlLoraModelFieldInputComponent';
 import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
 import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
 import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
@@ -160,11 +160,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
   }

   if (
-    isStructuralLoRAModelFieldInputInstance(fieldInstance) &&
-    isStructuralLoRAModelFieldInputTemplate(fieldTemplate)
+    isControlLoRAModelFieldInputInstance(fieldInstance) &&
+    isControlLoRAModelFieldInputTemplate(fieldTemplate)
   ) {
     return (
-      <StructuralLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />
+      <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />
     );
   }
@@ -1,34 +1,34 @@
 import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
-import { fieldStructuralLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
+import { fieldControlLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
 import type {
-  StructuralLoRAModelFieldInputInstance,
-  StructuralLoRAModelFieldInputTemplate,
+  ControlLoRAModelFieldInputInstance,
+  ControlLoRAModelFieldInputTemplate,
 } from 'features/nodes/types/field';
 import { memo, useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
-import { useStructuralLoRAModel } from 'services/api/hooks/modelsByType';
-import { isStructuralLoRAModelConfig, type StructuralLoRAModelConfig } from 'services/api/types';
+import { useControlLoRAModel } from 'services/api/hooks/modelsByType';
+import { isControlLoRAModelConfig, type ControlLoRAModelConfig } from 'services/api/types';

 import type { FieldComponentProps } from './types';

-type Props = FieldComponentProps<StructuralLoRAModelFieldInputInstance, StructuralLoRAModelFieldInputTemplate>;
+type Props = FieldComponentProps<ControlLoRAModelFieldInputInstance, ControlLoRAModelFieldInputTemplate>;

-const StructuralLoRAModelFieldInputComponent = (props: Props) => {
+const ControlLoRAModelFieldInputComponent = (props: Props) => {
   const { nodeId, field } = props;
   const { t } = useTranslation();
   const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
   const dispatch = useAppDispatch();
-  const [modelConfigs, { isLoading }] = useStructuralLoRAModel();
+  const [modelConfigs, { isLoading }] = useControlLoRAModel();

   const _onChange = useCallback(
-    (value: StructuralLoRAModelConfig | null) => {
+    (value: ControlLoRAModelConfig | null) => {
       if (!value) {
         return;
       }
       dispatch(
-        fieldStructuralLoRAModelValueChanged({
+        fieldControlLoRAModelValueChanged({
           nodeId,
           fieldName: field.name,
           value,
@@ -38,7 +38,7 @@ const StructuralLoRAModelFieldInputComponent = (props: Props) => {
     [dispatch, field.name, nodeId]
   );
   const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
-    modelConfigs: modelConfigs.filter((config) => isStructuralLoRAModelConfig(config)),
+    modelConfigs: modelConfigs.filter((config) => isControlLoRAModelConfig(config)),
     onChange: _onChange,
     isLoading,
     selectedModel: field.value,
@@ -62,4 +62,4 @@ const StructuralLoRAModelFieldInputComponent = (props: Props) => {
   );
 };

-export default memo(StructuralLoRAModelFieldInputComponent);
+export default memo(ControlLoRAModelFieldInputComponent);
@@ -28,7 +28,7 @@ import type {
   SpandrelImageToImageModelFieldValue,
   StatefulFieldValue,
   StringFieldValue,
-  StructuralLoRAModelFieldValue,
+  ControlLoRAModelFieldValue,
   T2IAdapterModelFieldValue,
   T5EncoderModelFieldValue,
   VAEModelFieldValue,
@@ -56,7 +56,7 @@ import {
   zSpandrelImageToImageModelFieldValue,
   zStatefulFieldValue,
   zStringFieldValue,
-  zStructuralLoRAModelFieldValue,
+  zControlLoRAModelFieldValue,
   zT2IAdapterModelFieldValue,
   zT5EncoderModelFieldValue,
   zVAEModelFieldValue,
@@ -371,8 +371,8 @@ export const nodesSlice = createSlice({
     fieldCLIPGEmbedValueChanged: (state, action: FieldValueAction<CLIPGEmbedModelFieldValue>) => {
       fieldValueReducer(state, action, zCLIPGEmbedModelFieldValue);
     },
-    fieldStructuralLoRAModelValueChanged: (state, action: FieldValueAction<StructuralLoRAModelFieldValue>) => {
-      fieldValueReducer(state, action, zStructuralLoRAModelFieldValue);
+    fieldControlLoRAModelValueChanged: (state, action: FieldValueAction<ControlLoRAModelFieldValue>) => {
+      fieldValueReducer(state, action, zControlLoRAModelFieldValue);
     },
     fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
       fieldValueReducer(state, action, zFluxVAEModelFieldValue);
@@ -443,7 +443,7 @@ export const {
   fieldCLIPEmbedValueChanged,
   fieldCLIPLEmbedValueChanged,
   fieldCLIPGEmbedValueChanged,
-  fieldStructuralLoRAModelValueChanged,
+  fieldControlLoRAModelValueChanged,
   fieldFluxVAEModelValueChanged,
   nodeEditorReset,
   nodeIsIntermediateChanged,
@@ -69,7 +69,7 @@ const zModelType = z.enum([
   'main',
   'vae',
   'lora',
-  'structural_lora',
+  'control_lora',
   'controlnet',
   't2i_adapter',
   'ip_adapter',
@@ -178,8 +178,8 @@ const zCLIPGEmbedModelFieldType = zFieldTypeBase.extend({
   name: z.literal('CLIPGEmbedModelField'),
   originalType: zStatelessFieldType.optional(),
 });
-const zStructuralLoRAModelFieldType = zFieldTypeBase.extend({
-  name: z.literal('StructuralLoRAModelField'),
+const zControlLoRAModelFieldType = zFieldTypeBase.extend({
+  name: z.literal('ControlLoRAModelField'),
   originalType: zStatelessFieldType.optional(),
 });
 const zFluxVAEModelFieldType = zFieldTypeBase.extend({
@@ -214,7 +214,7 @@ const zStatefulFieldType = z.union([
   zCLIPEmbedModelFieldType,
   zCLIPLEmbedModelFieldType,
   zCLIPGEmbedModelFieldType,
-  zStructuralLoRAModelFieldType,
+  zControlLoRAModelFieldType,
   zFluxVAEModelFieldType,
   zColorFieldType,
   zSchedulerFieldType,
@@ -869,26 +869,26 @@ export const isCLIPGEmbedModelFieldInputTemplate = (val: unknown): val is CLIPGE

 // #endregion

-// #region StructuralLoRAModelField
+// #region ControlLoRAModelField

-export const zStructuralLoRAModelFieldValue = zModelIdentifierField.optional();
-const zStructuralLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
-  value: zStructuralLoRAModelFieldValue,
+export const zControlLoRAModelFieldValue = zModelIdentifierField.optional();
+const zControlLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
+  value: zControlLoRAModelFieldValue,
 });
-const zStructuralLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
-  type: zStructuralLoRAModelFieldType,
+const zControlLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
+  type: zControlLoRAModelFieldType,
   originalType: zFieldType.optional(),
-  default: zStructuralLoRAModelFieldValue,
+  default: zControlLoRAModelFieldValue,
 });

-export type StructuralLoRAModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;
+export type ControlLoRAModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;

-export type StructuralLoRAModelFieldInputInstance = z.infer<typeof zStructuralLoRAModelFieldInputInstance>;
-export type StructuralLoRAModelFieldInputTemplate = z.infer<typeof zStructuralLoRAModelFieldInputTemplate>;
-export const isStructuralLoRAModelFieldInputInstance = (val: unknown): val is StructuralLoRAModelFieldInputInstance =>
-  zStructuralLoRAModelFieldInputInstance.safeParse(val).success;
-export const isStructuralLoRAModelFieldInputTemplate = (val: unknown): val is StructuralLoRAModelFieldInputTemplate =>
-  zStructuralLoRAModelFieldInputTemplate.safeParse(val).success;
+export type ControlLoRAModelFieldInputInstance = z.infer<typeof zControlLoRAModelFieldInputInstance>;
+export type ControlLoRAModelFieldInputTemplate = z.infer<typeof zControlLoRAModelFieldInputTemplate>;
+export const isControlLoRAModelFieldInputInstance = (val: unknown): val is ControlLoRAModelFieldInputInstance =>
+  zControlLoRAModelFieldInputInstance.safeParse(val).success;
+export const isControlLoRAModelFieldInputTemplate = (val: unknown): val is ControlLoRAModelFieldInputTemplate =>
+  zControlLoRAModelFieldInputTemplate.safeParse(val).success;

 // #endregion

@@ -987,7 +987,7 @@ export const zStatefulFieldValue = z.union([
   zCLIPEmbedModelFieldValue,
   zCLIPLEmbedModelFieldValue,
   zCLIPGEmbedModelFieldValue,
-  zStructuralLoRAModelFieldValue,
+  zControlLoRAModelFieldValue,
   zColorFieldValue,
   zSchedulerFieldValue,
 ]);
@@ -1059,7 +1059,7 @@ const zStatefulFieldInputTemplate = z.union([
   zCLIPEmbedModelFieldInputTemplate,
   zCLIPLEmbedModelFieldInputTemplate,
   zCLIPGEmbedModelFieldInputTemplate,
-  zStructuralLoRAModelFieldInputTemplate,
+  zControlLoRAModelFieldInputTemplate,
   zColorFieldInputTemplate,
   zSchedulerFieldInputTemplate,
   zStatelessFieldInputTemplate,
@@ -28,7 +28,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
   CLIPEmbedModelField: undefined,
   CLIPLEmbedModelField: undefined,
   CLIPGEmbedModelField: undefined,
-  StructuralLoRAModelField: undefined,
+  ControlLoRAModelField: undefined,
 };

 export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {
@@ -28,7 +28,7 @@ import type {
   StatefulFieldType,
   StatelessFieldInputTemplate,
   StringFieldInputTemplate,
-  StructuralLoRAModelFieldInputTemplate,
+  ControlLoRAModelFieldInputTemplate,
   T2IAdapterModelFieldInputTemplate,
   T5EncoderModelFieldInputTemplate,
   VAEModelFieldInputTemplate,
@@ -301,12 +301,12 @@ const buildCLIPGEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPGEmb
   return template;
 };

-const buildStructuralLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<StructuralLoRAModelFieldInputTemplate> = ({
+const buildControlLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<ControlLoRAModelFieldInputTemplate> = ({
   schemaObject,
   baseField,
   fieldType,
 }) => {
-  const template: StructuralLoRAModelFieldInputTemplate = {
+  const template: ControlLoRAModelFieldInputTemplate = {
     ...baseField,
     type: fieldType,
     default: schemaObject.default ?? undefined,
@@ -541,7 +541,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
   CLIPLEmbedModelField: buildCLIPLEmbedModelFieldInputTemplate,
   CLIPGEmbedModelField: buildCLIPGEmbedModelFieldInputTemplate,
   FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
-  StructuralLoRAModelField: buildStructuralLoRAModelFieldInputTemplate,
+  ControlLoRAModelField: buildControlLoRAModelFieldInputTemplate,
 } as const;

 export const buildFieldInputTemplate = (
@@ -113,9 +113,9 @@ export const zParameterVAEModel = zModelIdentifierField;
 export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
 // #endregion

-// #region Structural Lora Model
-export const zParameterStructuralLoRAModel = zModelIdentifierField;
-export type ParameterStructuralLoRAModel = z.infer<typeof zParameterStructuralLoRAModel>;
+// #region Control Lora Model
+export const zParameterControlLoRAModel = zModelIdentifierField;
+export type ParameterControlLoRAModel = z.infer<typeof zParameterControlLoRAModel>;
 // #endregion

 // #region T5Encoder Model
@@ -23,7 +23,7 @@ import {
   isSD3MainModelModelConfig,
   isSDXLMainModelModelConfig,
   isSpandrelImageToImageModelConfig,
-  isStructuralLoRAModelConfig,
+  isControlLoRAModelConfig,
   isT2IAdapterModelConfig,
   isT5EncoderModelConfig,
   isTIModelConfig,
@@ -59,7 +59,7 @@ export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
 export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
 export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
 export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
-export const useStructuralLoRAModel = buildModelsHook(isStructuralLoRAModelConfig);
+export const useControlLoRAModel = buildModelsHook(isControlLoRAModelConfig);
 export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
 export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
 export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
File diff suppressed because one or more lines are too long
@@ -44,7 +44,7 @@ export type BaseModelType = S['BaseModelType'];

 // Model Configs

-export type StructuralLoRAModelConfig = S['StructuralLoRALyCORISConfig'];
+export type ControlLoRAModelConfig = S['ControlLoRALyCORISConfig'];
 // TODO(MM2): Can we make key required in the pydantic model?
 export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
 // TODO(MM2): Can we rename this from Vae -> VAE
@@ -64,7 +64,7 @@ export type CheckpointModelConfig = S['MainCheckpointConfig'];
 type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
 export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
 export type AnyModelConfig =
-  | StructuralLoRAModelConfig
+  | ControlLoRAModelConfig
   | LoRAModelConfig
   | VAEModelConfig
   | ControlNetModelConfig
@@ -116,8 +116,8 @@ export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelCo
   return config.type === 'lora';
 };

-export const isStructuralLoRAModelConfig = (config: AnyModelConfig): config is StructuralLoRAModelConfig => {
-  return config.type === 'structural_lora';
+export const isControlLoRAModelConfig = (config: AnyModelConfig): config is ControlLoRAModelConfig => {
+  return config.type === 'control_lora';
 };

 export const isVAEModelConfig = (config: AnyModelConfig, excludeSubmodels?: boolean): config is VAEModelConfig => {
@@ -23,6 +23,7 @@ def test_is_state_dict_likely_in_flux_control_format_true(sd_keys: dict[str, lis

     assert is_state_dict_likely_flux_control(state_dict)

+
 @pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys])
 def test_is_state_dict_likely_in_flux_control_format_false(sd_keys: dict[str, list[int]]):
     """Test that is_state_dict_likely_flux_control() returns False for a state dict that is in the Diffusers