Rename Structural Lora to Control Lora

This commit is contained in:
Brandon Rising
2024-12-12 13:45:07 -05:00
committed by Kent Keirsey
parent 040551d4fb
commit 046d19446c
34 changed files with 239 additions and 326 deletions

View File

@ -56,7 +56,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
CLIPLEmbedModel = "CLIPLEmbedModelField"
CLIPGEmbedModel = "CLIPGEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
StructuralLoRAModel = "StructuralLoRAModelField"
ControlLoRAModel = "ControlLoRAModelField"
# endregion
# region Misc Field Types
@ -144,7 +144,7 @@ class FieldDescriptions:
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
structural_lora_model = "Structural LoRA model to load"
control_lora_model = "Control LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sd3_model = "SD3 model (MMDiTX) to load"

View File

@ -0,0 +1,55 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ControlLoRAField, ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_control_lora_loader_output")
class FluxControlLoRALoaderOutput(BaseInvocationOutput):
"""Flux Control LoRA Loader Output"""
control_lora: Optional[ControlLoRAField] = OutputField(
title="Flux Control Lora", description="Control LoRAs to apply on model loading", default=None
)
@invocation(
"flux_control_lora_loader",
title="Flux Control LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxControlLoRALoaderInvocation(BaseInvocation):
"""LoRA model and Image to use with FLUX transformer generation."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.control_lora_model, title="Control LoRA", ui_type=UIType.ControlLoRAModel
)
image: ImageField = InputField(
description="The image to encode.",
)
def invoke(self, context: InvocationContext) -> FluxControlLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
output = FluxControlLoRALoaderOutput()
output.control_lora = ControlLoRAField(
lora=self.lora,
img=self.image,
weight=1,
)
return output

View File

@ -8,8 +8,6 @@ import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
DenoiseMaskField,
@ -24,7 +22,7 @@ from invokeai.app.invocations.fields import (
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.model import TransformerField, VAEField, StructuralLoRAField, LoRAField
from invokeai.app.invocations.model import ControlLoRAField, LoRAField, TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
@ -35,8 +33,10 @@ from invokeai.backend.flux.extensions.instantx_controlnet_extension import Insta
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling_utils import (
clip_timestep_schedule_fractional,
generate_img_ids,
@ -45,8 +45,6 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.flux.flux_tools_sampling_utils import prepare_control
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
@ -93,6 +91,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
title="Transformer",
)
control_lora: Optional[ControlLoRAField] = InputField(
description=FieldDescriptions.control_lora_model, input=Input.Connection, title="Control Lora", default=None
)
positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
@ -198,7 +199,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
)
transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in transformer_info.config.config_path
is_schnell = "schnell" in getattr(transformer_info.config, "config_path", "")
# Calculate the timestep schedule.
timesteps = get_schedule(
@ -289,12 +290,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
device=x.device,
)
img_cond = None
if struct_lora := self.transformer.structural_lora:
if self.control_lora:
# What should we do when we have multiple of these?
if not self.controlnet_vae:
raise ValueError("controlnet_vae must be set when using a strutural lora")
ae_info = context.models.load(self.controlnet_vae.vae)
img = context.images.get_pil(struct_lora.img.image_name)
img = context.images.get_pil(self.control_lora.img.image_name)
with ae_info as ae:
assert isinstance(ae, AutoEncoder)
img_cond = prepare_control(self.height, self.width, self.seed, ae, img)
@ -359,7 +360,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
controlnet_extensions=controlnet_extensions,
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
img_cond=img_cond
img_cond=img_cond,
)
x = unpack(x.float(), self.height, self.width)
@ -697,9 +698,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
return pos_ip_adapter_extensions, neg_ip_adapter_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
loras: list[Union[LoRAField, StructuralLoRAField]] = [*self.transformer.loras]
if self.transformer.structural_lora:
loras.append(self.transformer.structural_lora)
loras: list[Union[LoRAField, ControlLoRAField]] = [*self.transformer.loras]
if self.control_lora:
loras.append(self.control_lora)
for lora in loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)

View File

@ -81,8 +81,10 @@ class FluxModelLoaderInvocation(BaseInvocation):
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[], structural_loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], structural_loras=[], skipped_layers=0),
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(
tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0
),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],

View File

@ -1,70 +0,0 @@
from typing import Optional, Literal
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
from invokeai.app.invocations.model import VAEField, StructuralLoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_structural_lora_loader_output")
class FluxStructuralLoRALoaderOutput(BaseInvocationOutput):
"""Flux Structural LoRA Loader Output"""
transformer: Optional[TransformerField] = OutputField(
default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
)
@invocation(
"flux_structural_lora_loader",
title="Flux Structural LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxStructuralLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.structural_lora_model, title="Structural LoRA", ui_type=UIType.StructuralLoRAModel
)
transformer: TransformerField | None = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="FLUX Transformer",
)
image: ImageField = InputField(
description="The image to encode.",
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
def invoke(self, context: InvocationContext) -> FluxStructuralLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
# Check for existing LoRAs with the same key.
if self.transformer and self.transformer.structural_lora and self.transformer.structural_lora.lora.key == lora_key:
raise ValueError(f'Structural LoRA "{lora_key}" already applied to transformer.')
output = FluxStructuralLoRALoaderOutput()
# Attach LoRA layers to the models.
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.structural_lora = StructuralLoRAField(
lora=self.lora,
img=self.image,
weight=self.weight,
)
return output

View File

@ -1,5 +1,5 @@
import copy
from typing import List, Optional, Literal
from typing import List, Optional
from pydantic import BaseModel, Field
@ -10,7 +10,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, ImageField
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import (
@ -74,13 +74,15 @@ class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
class StructuralLoRAField(LoRAField):
class ControlLoRAField(LoRAField):
img: ImageField = Field(description="Image to use in structural conditioning")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
structural_lora: Optional[StructuralLoRAField] = Field(description="Structural LoRAs to apply on model loading", default=None)
@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):

View File

@ -1,10 +1,11 @@
import torch
import numpy as np
from PIL import Image
import torch
from einops import rearrange
from PIL import Image
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
def prepare_control(
height: int,
width: int,

View File

@ -1,10 +1,10 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
from typing import Optional
import torch
from torch import Tensor, nn
from typing import Optional
from invokeai.backend.flux.custom_block_processor import (
CustomDoubleStreamBlockProcessor,

View File

@ -1,50 +0,0 @@
import os
import cv2
import numpy as np
import torch
from einops import rearrange, repeat
from PIL import Image
from safetensors.torch import load_file as load_sft
from torch import nn
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
class DepthImageEncoder:
depth_model_name = "LiheYoung/depth-anything-large-hf"
def __init__(self, device):
self.device = device
self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
self.processor = AutoProcessor.from_pretrained(self.depth_model_name)
def __call__(self, img: torch.Tensor) -> torch.Tensor:
hw = img.shape[-2:]
img = torch.clamp(img, -1.0, 1.0)
img_byte = ((img + 1.0) * 127.5).byte()
img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
depth = self.depth_model(img.to(self.device)).predicted_depth
depth = repeat(depth, "b h w -> b 3 h w")
depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)
depth = depth / 127.5 - 1.0
return depth
class CannyImageEncoder:
def __init__(
self,
device,
min_t: int = 50,
max_t: int = 200,
):
self.device = device
self.min_t = min_t
self.max_t = max_t
def __call__(self, img: torch.Tensor) -> torch.Tensor:
assert img.shape[0] == 1, "Only batch size 1 is supported"
img = rearrange(img[0], "c h w -> h w c")
img = torch.clamp(img, -1.0, 1.0)
img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8)
# Apply Canny edge detection
canny = cv2.Canny(img_np, self.min_t, self.max_t)
# Convert back to torch tensor and reshape
canny = torch.from_numpy(canny).float() / 127.5 - 1.0
canny = rearrange(canny, "h w -> 1 1 h w")
canny = repeat(canny, "b 1 ... -> b 3 ...")
return canny.to(self.device)

View File

@ -1,21 +1,21 @@
import re
from typing import Any, Dict
import torch
from typing import Any, Dict
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
# A regex pattern that matches all of the keys in the Flux Dev/Canny LoRA format.
# Example keys:
# guidance_in.in_layer.lora_B.bias
# single_blocks.0.linear1.lora_A.weight
# double_blocks.0.img_attn.norm.key_norm.scale
FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX = r"(final_layer|vector_in|txt_in|time_in|img_in|guidance_in|\w+_blocks)(\.(\d+))?\.(lora_(A|B)|(in|out)_layer|adaLN_modulation|img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear|linear1|linear2|modulation|norm)\.?(.*)"
FLUX_CONTROL_TRANSFORMER_KEY_REGEX = r"(final_layer|vector_in|txt_in|time_in|img_in|guidance_in|\w+_blocks)(\.(\d+))?\.(lora_(A|B)|(in|out)_layer|adaLN_modulation|img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear|linear1|linear2|modulation|norm)\.?(.*)"
def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the FLUX Control LoRA format.
@ -24,10 +24,11 @@ def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool:
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
"""
return all(
re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_STRUCTURAL_TRANSFORMER_KEY_REGEX, k)
re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k)
for k in state_dict.keys()
)
def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw:
# converted_state_dict = _convert_lora_bfl_control(state_dict=state_dict)
# Group keys by layer.
@ -54,7 +55,7 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
None,
layer_state_dict["lora_A.weight"],
None,
layer_state_dict["lora_B.bias"]
layer_state_dict["lora_B.bias"],
)
elif "scale" in layer_state_dict:
layers[prefixed_key] = SetParameterLayer("scale", layer_state_dict["scale"])
@ -62,4 +63,3 @@ def lora_model_from_flux_control_state_dict(state_dict: Dict[str, torch.Tensor])
raise AssertionError(f"{layer_key} not expected")
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)

View File

@ -9,4 +9,6 @@ from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer]
AnyLoRALayer = Union[
LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer, ConcatenatedLoRALayer, SetParameterLayer
]

View File

@ -1,34 +0,0 @@
from typing import Dict, Optional
import torch
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
class ReshapeWeightLayer(LoRALayerBase):
# TODO: Just everything in this class
def __init__(self, weight: Optional[torch.Tensor], bias: Optional[torch.Tensor], scale: Optional[torch.Tensor]):
super().__init__(alpha=None, bias=bias)
self.weight = torch.nn.Parameter(weight) if weight is not None else None
self.bias = torch.nn.Parameter(bias) if bias is not None else None
self.manual_scale = scale
def scale(self):
return self.manual_scale.float() if self.manual_scale is not None else super().scale()
def rank(self) -> int | None:
return None
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return orig_weight
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
super().to(device=device, dtype=dtype)
if self.weight is not None:
self.weight = self.weight.to(device=device, dtype=dtype)
if self.manual_scale is not None:
self.manual_scale = self.manual_scale.to(device=device, dtype=dtype)
def calc_size(self) -> int:
return super().calc_size() + calc_tensor_size(self.manual_scale)

View File

@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Dict
import torch
@ -17,7 +17,7 @@ class SetParameterLayer(LoRALayerBase):
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight - orig_weight
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
return {self.param_name: self.get_weight(orig_module.get_parameter(self.param_name))}

View File

@ -9,7 +9,6 @@ from invokeai.backend.lora.layers.loha_layer import LoHALayer
from invokeai.backend.lora.layers.lokr_layer import LoKRLayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.norm_layer import NormLayer
from invokeai.backend.lora.layers.set_parameter_layer import SetParameterLayer
def any_lora_layer_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> AnyLoRALayer:

View File

@ -67,7 +67,7 @@ class ModelType(str, Enum):
Main = "main"
VAE = "vae"
LoRA = "lora"
StructuralLoRa = "structural_lora"
ControlLoRa = "control_lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
@ -274,16 +274,16 @@ class LoRALyCORISConfig(LoRAConfigBase):
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
class StructuralLoRALyCORISConfig(ModelConfigBase):
"""Model config for Structural LoRA/Lycoris models."""
class ControlLoRALyCORISConfig(ModelConfigBase):
"""Model config for Control LoRA models."""
type: Literal[ModelType.StructuralLoRa] = ModelType.StructuralLoRa
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.StructuralLoRa.value}.{ModelFormat.LyCORIS.value}")
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.LyCORIS.value}")
class LoRADiffusersConfig(LoRAConfigBase):
@ -548,7 +548,7 @@ AnyModelConfig = Annotated[
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[StructuralLoRALyCORISConfig, StructuralLoRALyCORISConfig.get_tag()],
Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],

View File

@ -9,13 +9,17 @@ import torch
from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora.conversions.flux_control_lora_utils import (
is_state_dict_likely_flux_control,
lora_model_from_flux_control_state_dict,
)
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
lora_model_from_flux_diffusers_state_dict,
)
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format, lora_model_from_flux_kohya_state_dict,
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control, lora_model_from_flux_control_state_dict
from invokeai.backend.lora.conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.lora.conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.model_manager import (
@ -33,7 +37,7 @@ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoade
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.StructuralLoRa, format=ModelFormat.LyCORIS)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.LyCORIS)
class LoRALoader(ModelLoader):
"""Class to load LoRA models."""

View File

@ -15,10 +15,10 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.lora.conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import is_state_dict_likely_in_flux_kohya_format
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
@ -269,7 +269,7 @@ class ModelProbe(object):
and isinstance(tensor_b, torch.Tensor)
and tensor_b.shape[0] == 3072
):
return ModelType.StructuralLoRa
return ModelType.ControlLoRa
for key in [str(k) for k in ckpt.keys()]:
if key.startswith(
@ -1061,7 +1061,7 @@ ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelI
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.StructuralLoRa, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlLoRa, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)

View File

@ -809,7 +809,7 @@
"starterBundleHelpText": "Easily install all models needed to get started with a base model, including a main model, controlnets, IP adapters, and more. Selecting a bundle will skip any models that you already have installed.",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"structuralLora": "Structural LoRA",
"controlLora": "Control LoRA",
"syncModels": "Sync Models",
"textualInversions": "Textual Inversions",
"triggerPhrases": "Trigger Phrases",

View File

@ -24,7 +24,7 @@ import type {
ParameterSeed,
ParameterSteps,
ParameterStrength,
ParameterStructuralLoRAModel,
ParameterControlLoRAModel,
ParameterT5EncoderModel,
ParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
@ -76,7 +76,7 @@ export type ParamsState = {
clipEmbedModel: ParameterCLIPEmbedModel | null;
clipLEmbedModel: ParameterCLIPLEmbedModel | null;
clipGEmbedModel: ParameterCLIPGEmbedModel | null;
structuralLora: ParameterStructuralLoRAModel | null;
controlLora: ParameterControlLoRAModel | null;
};
const initialState: ParamsState = {
@ -123,7 +123,7 @@ const initialState: ParamsState = {
clipEmbedModel: null,
clipLEmbedModel: null,
clipGEmbedModel: null,
structuralLora: null,
controlLora: null,
};
export const paramsSlice = createSlice({
@ -198,8 +198,8 @@ export const paramsSlice = createSlice({
t5EncoderModelSelected: (state, action: PayloadAction<ParameterT5EncoderModel | null>) => {
state.t5EncoderModel = action.payload;
},
structuralLoRAModelSelected: (state, action: PayloadAction<ParameterStructuralLoRAModel | null>) => {
state.structuralLora = action.payload;
controlLoRAModelSelected: (state, action: PayloadAction<ParameterControlLoRAModel | null>) => {
state.controlLora = action.payload;
},
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
state.clipEmbedModel = action.payload;

View File

@ -46,7 +46,7 @@ import type {
ParameterSeed,
ParameterSteps,
ParameterStrength,
ParameterStructuralLoRAModel,
ParameterControlLoRAModel,
ParameterVAEModel,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
@ -81,7 +81,7 @@ import {
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isStructuralLoRAModelConfig,
isControlLoRAModelConfig,
isT2IAdapterModelConfig,
isVAEModelConfig,
} from 'services/api/types';
@ -228,10 +228,10 @@ const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) =>
return modelIdentifier;
};
const parseStructuralLoRAModel: MetadataParseFunc<ParameterStructuralLoRAModel> = async (metadata) => {
const slora = await getProperty(metadata, 'structural_lora', undefined);
const key = await getModelKey(slora, 'structural_lora');
const sloraModelConfig = await fetchModelConfigWithTypeGuard(key, isStructuralLoRAModelConfig);
const parseControlLoRAModel: MetadataParseFunc<ParameterControlLoRAModel> = async (metadata) => {
const slora = await getProperty(metadata, 'control_lora', undefined);
const key = await getModelKey(slora, 'control_lora');
const sloraModelConfig = await fetchModelConfigWithTypeGuard(key, isControlLoRAModelConfig);
const modelIdentifier = zModelIdentifierField.parse(sloraModelConfig);
return modelIdentifier;
};
@ -681,7 +681,7 @@ export const parsers = {
mainModel: parseMainModel,
refinerModel: parseRefinerModel,
vaeModel: parseVAEModel,
structuralLora: parseStructuralLoRAModel,
controlLora: parseControlLoRAModel,
lora: parseLoRA,
loras: parseAllLoRAs,
controlNet: parseControlNet,

View File

@ -18,7 +18,7 @@ import {
useMainModels,
useRefinerModels,
useSpandrelImageToImageModels,
useStructuralLoRAModel,
useControlLoRAModel,
useT2IAdapterModels,
useT5EncoderModels,
useVAEModels,
@ -93,10 +93,10 @@ const ModelList = () => {
[t5EncoderModels, searchTerm, filteredModelType]
);
const [structuralLoRAModels, { isLoading: isLoadingStructuralLoRAModels }] = useStructuralLoRAModel();
const filteredStructuralLoRAModels = useMemo(
() => modelsFilter(structuralLoRAModels, searchTerm, filteredModelType),
[structuralLoRAModels, searchTerm, filteredModelType]
const [controlLoRAModels, { isLoading: isLoadingControlLoRAModels }] = useControlLoRAModel();
const filteredControlLoRAModels = useMemo(
() => modelsFilter(controlLoRAModels, searchTerm, filteredModelType),
[controlLoRAModels, searchTerm, filteredModelType]
);
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels({ excludeSubmodels: true });
@ -126,7 +126,7 @@ const ModelList = () => {
filteredSpandrelImageToImageModels.length +
t5EncoderModels.length +
clipEmbedModels.length +
structuralLoRAModels.length
controlLoRAModels.length
);
}, [
filteredControlNetModels.length,
@ -141,7 +141,7 @@ const ModelList = () => {
filteredSpandrelImageToImageModels.length,
t5EncoderModels.length,
clipEmbedModels.length,
structuralLoRAModels.length,
controlLoRAModels.length,
]);
return (
@ -204,13 +204,13 @@ const ModelList = () => {
{!isLoadingT5EncoderModels && filteredT5EncoderModels.length > 0 && (
<ModelListWrapper title={t('modelManager.t5Encoder')} modelList={filteredT5EncoderModels} key="t5-encoder" />
)}
{/* Structural Lora List */}
{isLoadingStructuralLoRAModels && <FetchingModelsLoader loadingMessage="Loading Structural Loras..." />}
{!isLoadingStructuralLoRAModels && filteredStructuralLoRAModels.length > 0 && (
{/* Control Lora List */}
{isLoadingControlLoRAModels && <FetchingModelsLoader loadingMessage="Loading Control Loras..." />}
{!isLoadingControlLoRAModels && filteredControlLoRAModels.length > 0 && (
<ModelListWrapper
title={t('modelManager.structuralLora')}
modelList={filteredStructuralLoRAModels}
key="structural-lora"
title={t('modelManager.controlLora')}
modelList={filteredControlLoRAModels}
key="control-lora"
/>
)}
{/* Clip Embed List */}

View File

@ -24,7 +24,7 @@ export const ModelTypeFilter = memo(() => {
ip_adapter: t('common.ipAdapter'),
clip_vision: 'CLIP Vision',
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
structural_lora: t('modelManager.structuralLora'),
control_lora: t('modelManager.controlLora'),
}),
[t]
);

View File

@ -51,8 +51,8 @@ import {
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isStructuralLoRAModelFieldInputInstance,
isStructuralLoRAModelFieldInputTemplate,
isControlLoRAModelFieldInputInstance,
isControlLoRAModelFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
isT2IAdapterModelFieldInputTemplate,
isT5EncoderModelFieldInputInstance,
@ -83,7 +83,7 @@ import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComp
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import StructuralLoRAModelFieldInputComponent from './inputs/StructuralLoraModelFieldInputComponent';
import ControlLoRAModelFieldInputComponent from './inputs/ControlLoraModelFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
@ -160,11 +160,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
}
if (
isStructuralLoRAModelFieldInputInstance(fieldInstance) &&
isStructuralLoRAModelFieldInputTemplate(fieldTemplate)
isControlLoRAModelFieldInputInstance(fieldInstance) &&
isControlLoRAModelFieldInputTemplate(fieldTemplate)
) {
return (
<StructuralLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />
<ControlLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />
);
}

View File

@ -1,34 +1,34 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldStructuralLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import { fieldControlLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
StructuralLoRAModelFieldInputInstance,
StructuralLoRAModelFieldInputTemplate,
ControlLoRAModelFieldInputInstance,
ControlLoRAModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useStructuralLoRAModel } from 'services/api/hooks/modelsByType';
import { isStructuralLoRAModelConfig, type StructuralLoRAModelConfig } from 'services/api/types';
import { useControlLoRAModel } from 'services/api/hooks/modelsByType';
import { isControlLoRAModelConfig, type ControlLoRAModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<StructuralLoRAModelFieldInputInstance, StructuralLoRAModelFieldInputTemplate>;
type Props = FieldComponentProps<ControlLoRAModelFieldInputInstance, ControlLoRAModelFieldInputTemplate>;
const StructuralLoRAModelFieldInputComponent = (props: Props) => {
const ControlLoRAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useStructuralLoRAModel();
const [modelConfigs, { isLoading }] = useControlLoRAModel();
const _onChange = useCallback(
(value: StructuralLoRAModelConfig | null) => {
(value: ControlLoRAModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldStructuralLoRAModelValueChanged({
fieldControlLoRAModelValueChanged({
nodeId,
fieldName: field.name,
value,
@ -38,7 +38,7 @@ const StructuralLoRAModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isStructuralLoRAModelConfig(config)),
modelConfigs: modelConfigs.filter((config) => isControlLoRAModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
@ -62,4 +62,4 @@ const StructuralLoRAModelFieldInputComponent = (props: Props) => {
);
};
export default memo(StructuralLoRAModelFieldInputComponent);
export default memo(ControlLoRAModelFieldInputComponent);

View File

@ -28,7 +28,7 @@ import type {
SpandrelImageToImageModelFieldValue,
StatefulFieldValue,
StringFieldValue,
StructuralLoRAModelFieldValue,
ControlLoRAModelFieldValue,
T2IAdapterModelFieldValue,
T5EncoderModelFieldValue,
VAEModelFieldValue,
@ -56,7 +56,7 @@ import {
zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue,
zStringFieldValue,
zStructuralLoRAModelFieldValue,
zControlLoRAModelFieldValue,
zT2IAdapterModelFieldValue,
zT5EncoderModelFieldValue,
zVAEModelFieldValue,
@ -371,8 +371,8 @@ export const nodesSlice = createSlice({
fieldCLIPGEmbedValueChanged: (state, action: FieldValueAction<CLIPGEmbedModelFieldValue>) => {
fieldValueReducer(state, action, zCLIPGEmbedModelFieldValue);
},
fieldStructuralLoRAModelValueChanged: (state, action: FieldValueAction<StructuralLoRAModelFieldValue>) => {
fieldValueReducer(state, action, zStructuralLoRAModelFieldValue);
fieldControlLoRAModelValueChanged: (state, action: FieldValueAction<ControlLoRAModelFieldValue>) => {
fieldValueReducer(state, action, zControlLoRAModelFieldValue);
},
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
@ -443,7 +443,7 @@ export const {
fieldCLIPEmbedValueChanged,
fieldCLIPLEmbedValueChanged,
fieldCLIPGEmbedValueChanged,
fieldStructuralLoRAModelValueChanged,
fieldControlLoRAModelValueChanged,
fieldFluxVAEModelValueChanged,
nodeEditorReset,
nodeIsIntermediateChanged,

View File

@ -69,7 +69,7 @@ const zModelType = z.enum([
'main',
'vae',
'lora',
'structural_lora',
'control_lora',
'controlnet',
't2i_adapter',
'ip_adapter',

View File

@ -178,8 +178,8 @@ const zCLIPGEmbedModelFieldType = zFieldTypeBase.extend({
name: z.literal('CLIPGEmbedModelField'),
originalType: zStatelessFieldType.optional(),
});
const zStructuralLoRAModelFieldType = zFieldTypeBase.extend({
name: z.literal('StructuralLoRAModelField'),
const zControlLoRAModelFieldType = zFieldTypeBase.extend({
name: z.literal('ControlLoRAModelField'),
originalType: zStatelessFieldType.optional(),
});
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
@ -214,7 +214,7 @@ const zStatefulFieldType = z.union([
zCLIPEmbedModelFieldType,
zCLIPLEmbedModelFieldType,
zCLIPGEmbedModelFieldType,
zStructuralLoRAModelFieldType,
zControlLoRAModelFieldType,
zFluxVAEModelFieldType,
zColorFieldType,
zSchedulerFieldType,
@ -869,26 +869,26 @@ export const isCLIPGEmbedModelFieldInputTemplate = (val: unknown): val is CLIPGE
// #endregion
// #region StructuralLoRAModelField
// #region ControlLoRAModelField
export const zStructuralLoRAModelFieldValue = zModelIdentifierField.optional();
const zStructuralLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zStructuralLoRAModelFieldValue,
export const zControlLoRAModelFieldValue = zModelIdentifierField.optional();
const zControlLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zControlLoRAModelFieldValue,
});
const zStructuralLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zStructuralLoRAModelFieldType,
const zControlLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zControlLoRAModelFieldType,
originalType: zFieldType.optional(),
default: zStructuralLoRAModelFieldValue,
default: zControlLoRAModelFieldValue,
});
export type StructuralLoRAModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;
export type ControlLoRAModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValue>;
export type StructuralLoRAModelFieldInputInstance = z.infer<typeof zStructuralLoRAModelFieldInputInstance>;
export type StructuralLoRAModelFieldInputTemplate = z.infer<typeof zStructuralLoRAModelFieldInputTemplate>;
export const isStructuralLoRAModelFieldInputInstance = (val: unknown): val is StructuralLoRAModelFieldInputInstance =>
zStructuralLoRAModelFieldInputInstance.safeParse(val).success;
export const isStructuralLoRAModelFieldInputTemplate = (val: unknown): val is StructuralLoRAModelFieldInputTemplate =>
zStructuralLoRAModelFieldInputTemplate.safeParse(val).success;
export type ControlLoRAModelFieldInputInstance = z.infer<typeof zControlLoRAModelFieldInputInstance>;
export type ControlLoRAModelFieldInputTemplate = z.infer<typeof zControlLoRAModelFieldInputTemplate>;
export const isControlLoRAModelFieldInputInstance = (val: unknown): val is ControlLoRAModelFieldInputInstance =>
zControlLoRAModelFieldInputInstance.safeParse(val).success;
export const isControlLoRAModelFieldInputTemplate = (val: unknown): val is ControlLoRAModelFieldInputTemplate =>
zControlLoRAModelFieldInputTemplate.safeParse(val).success;
// #endregion
@ -987,7 +987,7 @@ export const zStatefulFieldValue = z.union([
zCLIPEmbedModelFieldValue,
zCLIPLEmbedModelFieldValue,
zCLIPGEmbedModelFieldValue,
zStructuralLoRAModelFieldValue,
zControlLoRAModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
]);
@ -1059,7 +1059,7 @@ const zStatefulFieldInputTemplate = z.union([
zCLIPEmbedModelFieldInputTemplate,
zCLIPLEmbedModelFieldInputTemplate,
zCLIPGEmbedModelFieldInputTemplate,
zStructuralLoRAModelFieldInputTemplate,
zControlLoRAModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,

View File

@ -28,7 +28,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
CLIPEmbedModelField: undefined,
CLIPLEmbedModelField: undefined,
CLIPGEmbedModelField: undefined,
StructuralLoRAModelField: undefined,
ControlLoRAModelField: undefined,
};
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {

View File

@ -28,7 +28,7 @@ import type {
StatefulFieldType,
StatelessFieldInputTemplate,
StringFieldInputTemplate,
StructuralLoRAModelFieldInputTemplate,
ControlLoRAModelFieldInputTemplate,
T2IAdapterModelFieldInputTemplate,
T5EncoderModelFieldInputTemplate,
VAEModelFieldInputTemplate,
@ -301,12 +301,12 @@ const buildCLIPGEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPGEmb
return template;
};
const buildStructuralLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<StructuralLoRAModelFieldInputTemplate> = ({
const buildControlLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<ControlLoRAModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: StructuralLoRAModelFieldInputTemplate = {
const template: ControlLoRAModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
@ -541,7 +541,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
CLIPLEmbedModelField: buildCLIPLEmbedModelFieldInputTemplate,
CLIPGEmbedModelField: buildCLIPGEmbedModelFieldInputTemplate,
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
StructuralLoRAModelField: buildStructuralLoRAModelFieldInputTemplate,
ControlLoRAModelField: buildControlLoRAModelFieldInputTemplate,
} as const;
export const buildFieldInputTemplate = (

View File

@ -113,9 +113,9 @@ export const zParameterVAEModel = zModelIdentifierField;
export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>;
// #endregion
// #region Structural Lora Model
export const zParameterStructuralLoRAModel = zModelIdentifierField;
export type ParameterStructuralLoRAModel = z.infer<typeof zParameterStructuralLoRAModel>;
// #region Control Lora Model
export const zParameterControlLoRAModel = zModelIdentifierField;
export type ParameterControlLoRAModel = z.infer<typeof zParameterControlLoRAModel>;
// #endregion
// #region T5Encoder Model

View File

@ -23,7 +23,7 @@ import {
isSD3MainModelModelConfig,
isSDXLMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isStructuralLoRAModelConfig,
isControlLoRAModelConfig,
isT2IAdapterModelConfig,
isT5EncoderModelConfig,
isTIModelConfig,
@ -59,7 +59,7 @@ export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useStructuralLoRAModel = buildModelsHook(isStructuralLoRAModelConfig);
export const useControlLoRAModel = buildModelsHook(isControlLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);

File diff suppressed because one or more lines are too long

View File

@ -44,7 +44,7 @@ export type BaseModelType = S['BaseModelType'];
// Model Configs
export type StructuralLoRAModelConfig = S['StructuralLoRALyCORISConfig'];
export type ControlLoRAModelConfig = S['ControlLoRALyCORISConfig'];
// TODO(MM2): Can we make key required in the pydantic model?
export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
// TODO(MM2): Can we rename this from Vae -> VAE
@ -64,7 +64,7 @@ export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| StructuralLoRAModelConfig
| ControlLoRAModelConfig
| LoRAModelConfig
| VAEModelConfig
| ControlNetModelConfig
@ -116,8 +116,8 @@ export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelCo
return config.type === 'lora';
};
export const isStructuralLoRAModelConfig = (config: AnyModelConfig): config is StructuralLoRAModelConfig => {
return config.type === 'structural_lora';
export const isControlLoRAModelConfig = (config: AnyModelConfig): config is ControlLoRAModelConfig => {
return config.type === 'control_lora';
};
export const isVAEModelConfig = (config: AnyModelConfig, excludeSubmodels?: boolean): config is VAEModelConfig => {

View File

@ -23,6 +23,7 @@ def test_is_state_dict_likely_in_flux_control_format_true(sd_keys: dict[str, lis
assert is_state_dict_likely_flux_control(state_dict)
@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys])
def test_is_state_dict_likely_in_flux_control_format_false(sd_keys: dict[str, list[int]]):
"""Test that is_state_dict_likely_flux_control() returns False for a state dict that is in the Diffusers