Mirror of https://github.com/invoke-ai/InvokeAI (synced 2025-07-26 05:17:55 +00:00)
Fixup FLUX LoRA unit tests.
@@ -6,6 +6,7 @@ from invokeai.backend.flux.util import params
 from invokeai.backend.lora.conversions.flux_lora_conversion_utils import (
     convert_flux_kohya_state_dict_to_invoke_format,
     is_state_dict_likely_in_flux_kohya_format,
+    lora_model_from_flux_kohya_state_dict,
 )
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import state_dict_keys
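
For orientation, here is a minimal sketch of how the detection helper imported above can be driven with the same fixture keys the new test uses. Whether it returns True for a dummy dict built this way is an assumption based on the helper's name and the test pattern below, not something verified here:

import torch

from invokeai.backend.lora.conversions.flux_lora_conversion_utils import (
    is_state_dict_likely_in_flux_kohya_format,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import state_dict_keys

# Build a dummy state dict from the shared fixture keys, as the tests do.
sd = {k: torch.empty(1) for k in state_dict_keys}
print(is_state_dict_likely_in_flux_kohya_format(sd))  # expected: True (assumption)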
@@ -70,3 +71,27 @@ def test_convert_flux_kohya_state_dict_to_invoke_format_error():
 
     with pytest.raises(ValueError):
         convert_flux_kohya_state_dict_to_invoke_format(state_dict)
+
+
+def test_lora_model_from_flux_kohya_state_dict():
+    """Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
+    # Construct state_dict from state_dict_keys.
+    state_dict: dict[str, torch.Tensor] = {}
+    for k in state_dict_keys:
+        state_dict[k] = torch.empty(1)
+
+    lora_model = lora_model_from_flux_kohya_state_dict(state_dict)
+
+    # Prepare expected layer keys.
+    expected_layer_keys: set[str] = set()
+    for k in state_dict_keys:
+        k = k.replace("lora_unet_", "")
+        k = k.replace(".lora_up.weight", "")
+        k = k.replace(".lora_down.weight", "")
+        k = k.replace(".alpha", "")
+        expected_layer_keys.add(k)
+
+    # Assert that the lora_model has the expected layers.
+    lora_model_keys = set(lora_model.layers.keys())
+    lora_model_keys = {k.replace(".", "_") for k in lora_model_keys}
+    assert lora_model_keys == expected_layer_keys
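
To make the expected-key derivation concrete, here is the same normalization traced over one illustrative Kohya-style key (the module name is hypothetical, shaped like FLUX's double_blocks modules):

k = "lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight"
k = k.replace("lora_unet_", "")       # drop the Kohya UNet prefix
k = k.replace(".lora_up.weight", "")  # drop the LoRA tensor suffix
# k is now "double_blocks_0_img_attn_qkv".
# The converted model stores its layer keys with dots (e.g. "double_blocks.0.img_attn.qkv"),
# which is why the test maps them back with k.replace(".", "_") before comparing sets.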
@@ -1,7 +1,3 @@
-# test that if the model's device changes while the lora is applied, the weights can still be restored
-
-# test that LoRA patching works on both CPU and CUDA
-
 import pytest
 import torch
@@ -18,7 +14,7 @@ from invokeai.backend.model_patcher import ModelPatcher
     ],
 )
 @torch.no_grad()
-def test_apply_lora(device):
+def test_apply_lora(device: str):
     """Test the basic behavior of ModelPatcher.apply_lora(...). Check that patching and unpatching produce the correct
     result, and that model/LoRA tensors are moved between devices as expected.
     """
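
The invariant this test checks (patch, run, unpatch, and the original weights are restored exactly, even when tensors move between devices) can be illustrated with a self-contained torch sketch. This shows the general pattern only; it is not InvokeAI's ModelPatcher implementation:

import torch

# Illustrative LoRA patch/restore pattern, not InvokeAI's actual API.
linear = torch.nn.Linear(8, 8)
original_weight = linear.weight.detach().cpu().clone()

rank = 2
up = torch.randn(8, rank)    # stand-ins for a LoRA layer's up/down matrices
down = torch.randn(rank, 8)
lora_weight = 0.5

with torch.no_grad():
    # Patch: add the scaled low-rank delta to the base weight.
    linear.weight += lora_weight * (up @ down)
    # ... run inference with the patched model here ...
    # Unpatch: restore the saved copy, moving it back to the model's device.
    linear.weight.copy_(original_weight.to(linear.weight.device))

assert torch.allclose(linear.weight.detach().cpu(), original_weight)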