Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
fix(mm): fix extraneous downloaded files in diffusers
Sometimes, diffusers model components (tokenizer, unet, etc.) have multiple weights files in the same directory. In this situation, we assume the files are different versions of the same weights. For example, we may have multiple formats (`.bin`, `.safetensors`) with different precisions. When downloading model files, we want to select only the best of these files for the requested format and precision/variant. The previous logic assumed that each model weights file would have the same base filename, but this assumption was not always true. The logic is revised to score each file and choose the best-scoring one, so that only a single file is downloaded for each submodel/subdirectory.
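For illustration, here is a minimal sketch of the per-directory scoring idea described above. It is not the actual `filter_files()` implementation: the helper name `pick_best_weights` and the specific score values are hypothetical, and it only considers the weights files themselves (config and tokenizer files are left to the surrounding logic).

```python
# Hypothetical sketch of per-directory scoring -- not the actual filter_files() code.
# Each candidate weights file is scored (preferring .safetensors over .bin and the
# requested precision/variant suffix), and only the best-scoring file per
# submodel/subdirectory is kept.
from pathlib import Path


def pick_best_weights(files: list[Path], variant_suffix: str | None = "fp16") -> list[Path]:
    best: dict[Path, tuple[int, Path]] = {}
    for f in files:
        suffixes = f.suffixes  # e.g. [".fp16", ".safetensors"]
        if not suffixes or suffixes[-1] not in (".safetensors", ".bin"):
            continue  # only weights files are scored in this sketch
        score = 0
        if suffixes[-1] == ".safetensors":
            score += 2  # prefer safetensors over pickle-based .bin
        if variant_suffix is not None:
            if f".{variant_suffix}" in suffixes:
                score += 1  # prefer the requested variant (e.g. fp16)
        elif len(suffixes) == 1:
            score += 1  # no variant requested: prefer the plain weights file
        key = f.parent  # one winner per submodel directory, regardless of base filename
        if key not in best or score > best[key][0]:
            best[key] = (score, f)
    return [picked for _, picked in best.values()]
```

With the SD1.5 file list from the test fixture below, `pick_best_weights(files, "fp16")` would keep only `unet/diffusion_pytorch_model.fp16.safetensors` from the six unet weights files, even though their base filenames differ.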
@@ -235,7 +235,94 @@ def sdxl_base_files() -> List[Path]:
         ),
     ],
 )
-def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[Path]) -> None:
+def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[str]) -> None:
     print(f"testing variant {variant}")
     filtered_files = filter_files(sdxl_base_files, variant)
     assert set(filtered_files) == {Path(x) for x in expected_list}
+
+
+@pytest.fixture
+def sd15_test_files() -> list[Path]:
+    return [
+        Path(f)
+        for f in [
+            "feature_extractor/preprocessor_config.json",
+            "safety_checker/config.json",
+            "safety_checker/model.fp16.safetensors",
+            "safety_checker/model.safetensors",
+            "safety_checker/pytorch_model.bin",
+            "safety_checker/pytorch_model.fp16.bin",
+            "scheduler/scheduler_config.json",
+            "text_encoder/config.json",
+            "text_encoder/model.fp16.safetensors",
+            "text_encoder/model.safetensors",
+            "text_encoder/pytorch_model.bin",
+            "text_encoder/pytorch_model.fp16.bin",
+            "tokenizer/merges.txt",
+            "tokenizer/special_tokens_map.json",
+            "tokenizer/tokenizer_config.json",
+            "tokenizer/vocab.json",
+            "unet/config.json",
+            "unet/diffusion_pytorch_model.bin",
+            "unet/diffusion_pytorch_model.fp16.bin",
+            "unet/diffusion_pytorch_model.fp16.safetensors",
+            "unet/diffusion_pytorch_model.non_ema.bin",
+            "unet/diffusion_pytorch_model.non_ema.safetensors",
+            "unet/diffusion_pytorch_model.safetensors",
+            "vae/config.json",
+            "vae/diffusion_pytorch_model.bin",
+            "vae/diffusion_pytorch_model.fp16.bin",
+            "vae/diffusion_pytorch_model.fp16.safetensors",
+            "vae/diffusion_pytorch_model.safetensors",
+        ]
+    ]
+
+
+@pytest.mark.parametrize(
+    "variant,expected_files",
+    [
+        (
+            ModelRepoVariant.FP16,
+            {
+                "feature_extractor/preprocessor_config.json",
+                "safety_checker/config.json",
+                "safety_checker/model.fp16.safetensors",
+                "scheduler/scheduler_config.json",
+                "text_encoder/config.json",
+                "text_encoder/model.fp16.safetensors",
+                "tokenizer/merges.txt",
+                "tokenizer/special_tokens_map.json",
+                "tokenizer/tokenizer_config.json",
+                "tokenizer/vocab.json",
+                "unet/config.json",
+                "unet/diffusion_pytorch_model.fp16.safetensors",
+                "vae/config.json",
+                "vae/diffusion_pytorch_model.fp16.safetensors",
+            },
+        ),
+        (
+            ModelRepoVariant.FP32,
+            {
+                "feature_extractor/preprocessor_config.json",
+                "safety_checker/config.json",
+                "safety_checker/model.safetensors",
+                "scheduler/scheduler_config.json",
+                "text_encoder/config.json",
+                "text_encoder/model.safetensors",
+                "tokenizer/merges.txt",
+                "tokenizer/special_tokens_map.json",
+                "tokenizer/tokenizer_config.json",
+                "tokenizer/vocab.json",
+                "unet/config.json",
+                "unet/diffusion_pytorch_model.safetensors",
+                "vae/config.json",
+                "vae/diffusion_pytorch_model.safetensors",
+            },
+        ),
+    ],
+)
+def test_select_multiple_weights(
+    sd15_test_files: list[Path], variant: ModelRepoVariant, expected_files: set[str]
+) -> None:
+    filtered_files = filter_files(sd15_test_files, variant)
+    assert {str(f) for f in filtered_files} == expected_files