tests: fix test_hf_model_select::test_select_multiple_weights on windows

This commit is contained in:
psychedelicious 2024-03-04 23:00:51 +11:00
parent 3391c19926
commit 02bde7bb75

View File

@ -283,7 +283,7 @@ def sd15_test_files() -> list[Path]:
[
(
ModelRepoVariant.FP16,
{
[
"feature_extractor/preprocessor_config.json",
"safety_checker/config.json",
"safety_checker/model.fp16.safetensors",
@ -298,11 +298,11 @@ def sd15_test_files() -> list[Path]:
"unet/diffusion_pytorch_model.fp16.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.fp16.safetensors",
},
],
),
(
ModelRepoVariant.FP32,
{
[
"feature_extractor/preprocessor_config.json",
"safety_checker/config.json",
"safety_checker/model.safetensors",
@ -317,12 +317,12 @@ def sd15_test_files() -> list[Path]:
"unet/diffusion_pytorch_model.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.safetensors",
},
],
),
],
)
def test_select_multiple_weights(
sd15_test_files: list[Path], variant: ModelRepoVariant, expected_files: set[str]
sd15_test_files: list[Path], variant: ModelRepoVariant, expected_files: list[str]
) -> None:
filtered_files = filter_files(sd15_test_files, variant)
assert {str(f) for f in filtered_files} == expected_files
assert set(filtered_files) == {Path(f) for f in expected_files}