Fixup FLUX LoRA unit tests.

This commit is contained in:
Ryan Dick
2024-09-05 14:12:56 +00:00
committed by Kent Keirsey
parent 50c9410121
commit 92b8477299
2 changed files with 26 additions and 5 deletions

View File

@ -6,6 +6,7 @@ from invokeai.backend.flux.util import params
from invokeai.backend.lora.conversions.flux_lora_conversion_utils import (
convert_flux_kohya_state_dict_to_invoke_format,
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import state_dict_keys
@ -70,3 +71,27 @@ def test_convert_flux_kohya_state_dict_to_invoke_format_error():
with pytest.raises(ValueError):
convert_flux_kohya_state_dict_to_invoke_format(state_dict)
def test_lora_model_from_flux_kohya_state_dict():
"""Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
# Construct state_dict from state_dict_keys.
state_dict: dict[str, torch.Tensor] = {}
for k in state_dict_keys:
state_dict[k] = torch.empty(1)
lora_model = lora_model_from_flux_kohya_state_dict(state_dict)
# Prepare expected layer keys.
expected_layer_keys: set[str] = set()
for k in state_dict_keys:
k = k.replace("lora_unet_", "")
k = k.replace(".lora_up.weight", "")
k = k.replace(".lora_down.weight", "")
k = k.replace(".alpha", "")
expected_layer_keys.add(k)
# Assert that the lora_model has the expected layers.
lora_model_keys = set(lora_model.layers.keys())
lora_model_keys = {k.replace(".", "_") for k in lora_model_keys}
assert lora_model_keys == expected_layer_keys

View File

@ -1,7 +1,3 @@
# test that if the model's device changes while the lora is applied, the weights can still be restored
# test that LoRA patching works on both CPU and CUDA
import pytest
import torch
@ -18,7 +14,7 @@ from invokeai.backend.model_patcher import ModelPatcher
],
)
@torch.no_grad()
def test_apply_lora(device):
def test_apply_lora(device: str):
"""Test the basic behavior of ModelPatcher.apply_lora(...). Check that patching and unpatching produce the correct
result, and that model/LoRA tensors are moved between devices as expected.
"""