tests: fix test_hf_model_select::test_select_multiple_weights on windows

This commit is contained in:
psychedelicious 2024-03-04 23:00:51 +11:00
parent 3391c19926
commit 02bde7bb75

View File

@ -283,7 +283,7 @@ def sd15_test_files() -> list[Path]:
[ [
( (
ModelRepoVariant.FP16, ModelRepoVariant.FP16,
{ [
"feature_extractor/preprocessor_config.json", "feature_extractor/preprocessor_config.json",
"safety_checker/config.json", "safety_checker/config.json",
"safety_checker/model.fp16.safetensors", "safety_checker/model.fp16.safetensors",
@ -298,11 +298,11 @@ def sd15_test_files() -> list[Path]:
"unet/diffusion_pytorch_model.fp16.safetensors", "unet/diffusion_pytorch_model.fp16.safetensors",
"vae/config.json", "vae/config.json",
"vae/diffusion_pytorch_model.fp16.safetensors", "vae/diffusion_pytorch_model.fp16.safetensors",
}, ],
), ),
( (
ModelRepoVariant.FP32, ModelRepoVariant.FP32,
{ [
"feature_extractor/preprocessor_config.json", "feature_extractor/preprocessor_config.json",
"safety_checker/config.json", "safety_checker/config.json",
"safety_checker/model.safetensors", "safety_checker/model.safetensors",
@ -317,12 +317,12 @@ def sd15_test_files() -> list[Path]:
"unet/diffusion_pytorch_model.safetensors", "unet/diffusion_pytorch_model.safetensors",
"vae/config.json", "vae/config.json",
"vae/diffusion_pytorch_model.safetensors", "vae/diffusion_pytorch_model.safetensors",
}, ],
), ),
], ],
) )
def test_select_multiple_weights( def test_select_multiple_weights(
sd15_test_files: list[Path], variant: ModelRepoVariant, expected_files: set[str] sd15_test_files: list[Path], variant: ModelRepoVariant, expected_files: list[str]
) -> None: ) -> None:
filtered_files = filter_files(sd15_test_files, variant) filtered_files = filter_files(sd15_test_files, variant)
assert {str(f) for f in filtered_files} == expected_files assert set(filtered_files) == {Path(f) for f in expected_files}