fix(mm): fix extraneous downloaded files in diffusers

Sometimes, diffusers model components (tokenizer, unet, etc.) have multiple weights files in the same directory.

In this situation, we assume the files are different versions of the same weights. For example, we may have multiple
formats (`.bin`, `.safetensors`) with different precisions. When downloading model files, we want to select only
the best of these files for the requested format and precision/variant.

The previous logic assumed that each model weights file would have the same base filename, but this assumption was
not always true. The logic is revised score each file and choose the best scoring file, resulting in only a single
file being downloaded for each submodel/subdirectory.
This commit is contained in:
psychedelicious 2024-03-04 16:10:29 +11:00
parent f2b5f8753f
commit 3534366146
2 changed files with 223 additions and 27 deletions

View File

@ -13,6 +13,7 @@ files_to_download = select_hf_model_files(metadata.files, variant='onnx')
"""
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set
@ -73,10 +74,34 @@ def filter_files(
return sorted(_filter_by_variant(paths, variant))
def get_variant_label(path: Path) -> Optional[str]:
suffixes = path.suffixes
if len(suffixes) == 2:
variant_label, _ = suffixes
else:
variant_label = None
return variant_label
def get_suffix(path: Path) -> str:
suffixes = path.suffixes
if len(suffixes) == 2:
_, suffix = suffixes
else:
suffix = suffixes[0]
return suffix
@dataclass
class SubfolderCandidate:
path: Path
score: int
def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
result = set()
basenames: Dict[Path, Path] = {}
result: set[Path] = set()
subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
for path in files:
if path.suffix in [".onnx", ".pb", ".onnx_data"]:
if variant == ModelRepoVariant.ONNX:
@ -93,38 +118,49 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
elif path.suffix in [".json", ".txt"]:
result.add(path)
elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
elif variant in [
ModelRepoVariant.FP16,
ModelRepoVariant.FP32,
ModelRepoVariant.DEFAULT,
]:
parent = path.parent
suffixes = path.suffixes
if len(suffixes) == 2:
variant_label, suffix = suffixes
basename = parent / Path(path.stem).stem
else:
variant_label = ""
suffix = suffixes[0]
basename = parent / path.stem
] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]:
# For weights files, we want to select the best one for each subfolder. For example, we may have multiple
# text encoders:
#
# - text_encoder/model.fp16.safetensors
# - text_encoder/model.safetensors
# - text_encoder/pytorch_model.bin
# - text_encoder/pytorch_model.fp16.bin
#
# We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
# variant and format and select the best one.
if previous := basenames.get(basename):
if (
previous.suffix != ".safetensors" and suffix == ".safetensors"
): # replace non-safetensors with safetensors when available
basenames[basename] = path
if variant_label == f".{variant}":
basenames[basename] = path
elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
basenames[basename] = path
else:
basenames[basename] = path
parent = path.parent
score = 0
if path.suffix == ".safetensors":
score += 1
candidate_variant_label = get_variant_label(path)
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
if candidate_variant_label == f".{variant}" or (
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]
):
score += 1
if parent not in subfolder_weights:
subfolder_weights[parent] = []
subfolder_weights[parent].append(SubfolderCandidate(path=path, score=score))
else:
continue
for v in basenames.values():
result.add(v)
for candidate_list in subfolder_weights.values():
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate:
result.add(highest_score_candidate.path)
# If one of the architecture-related variants was specified and no files matched other than
# config and text files then we return an empty list
@ -144,3 +180,76 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
directories[x.parent] = directories.get(x.parent, 0) + 1
return {x for x in result if directories[x.parent] > 1 or x.name != "config.json"}
# def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
# """Select the proper variant files from a list of HuggingFace repo_id paths."""
# result: set[Path] = set()
# basenames: Dict[Path, Path] = {}
# for path in files:
# if path.suffix in [".onnx", ".pb", ".onnx_data"]:
# if variant == ModelRepoVariant.ONNX:
# result.add(path)
# elif "openvino_model" in path.name:
# if variant == ModelRepoVariant.OPENVINO:
# result.add(path)
# elif "flax_model" in path.name:
# if variant == ModelRepoVariant.FLAX:
# result.add(path)
# elif path.suffix in [".json", ".txt"]:
# result.add(path)
# elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
# ModelRepoVariant.FP16,
# ModelRepoVariant.FP32,
# ModelRepoVariant.DEFAULT,
# ]:
# parent = path.parent
# suffixes = path.suffixes
# if len(suffixes) == 2:
# variant_label, suffix = suffixes
# basename = parent / Path(path.stem).stem
# else:
# variant_label = ""
# suffix = suffixes[0]
# basename = parent / path.stem
# if previous := basenames.get(basename):
# if (
# previous.suffix != ".safetensors" and suffix == ".safetensors"
# ): # replace non-safetensors with safetensors when available
# basenames[basename] = path
# if variant_label == f".{variant}":
# basenames[basename] = path
# elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
# basenames[basename] = path
# else:
# basenames[basename] = path
# else:
# continue
# for v in basenames.values():
# result.add(v)
# # If one of the architecture-related variants was specified and no files matched other than
# # config and text files then we return an empty list
# if (
# variant
# and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
# and not any(variant.value in x.name for x in result)
# ):
# return set()
# # Prune folders that contain just a `config.json`. This happens when
# # the requested variant (e.g. "onnx") is missing
# directories: Dict[Path, int] = {}
# for x in result:
# if not x.parent:
# continue
# directories[x.parent] = directories.get(x.parent, 0) + 1
# return {x for x in result if directories[x.parent] > 1 or x.name != "config.json"}

View File

@ -235,7 +235,94 @@ def sdxl_base_files() -> List[Path]:
),
],
)
def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[Path]) -> None:
def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[str]) -> None:
print(f"testing variant {variant}")
filtered_files = filter_files(sdxl_base_files, variant)
assert set(filtered_files) == {Path(x) for x in expected_list}
@pytest.fixture
def sd15_test_files() -> list[Path]:
return [
Path(f)
for f in [
"feature_extractor/preprocessor_config.json",
"safety_checker/config.json",
"safety_checker/model.fp16.safetensors",
"safety_checker/model.safetensors",
"safety_checker/pytorch_model.bin",
"safety_checker/pytorch_model.fp16.bin",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/model.fp16.safetensors",
"text_encoder/model.safetensors",
"text_encoder/pytorch_model.bin",
"text_encoder/pytorch_model.fp16.bin",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"unet/config.json",
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
"unet/diffusion_pytorch_model.non_ema.bin",
"unet/diffusion_pytorch_model.non_ema.safetensors",
"unet/diffusion_pytorch_model.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.bin",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"vae/diffusion_pytorch_model.safetensors",
]
]
@pytest.mark.parametrize(
"variant,expected_files",
[
(
ModelRepoVariant.FP16,
{
"feature_extractor/preprocessor_config.json",
"safety_checker/config.json",
"safety_checker/model.fp16.safetensors",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/model.fp16.safetensors",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"unet/config.json",
"unet/diffusion_pytorch_model.fp16.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.fp16.safetensors",
},
),
(
ModelRepoVariant.FP32,
{
"feature_extractor/preprocessor_config.json",
"safety_checker/config.json",
"safety_checker/model.safetensors",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/model.safetensors",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"unet/config.json",
"unet/diffusion_pytorch_model.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.safetensors",
},
),
],
)
def test_select_multiple_weights(
sd15_test_files: list[Path], variant: ModelRepoVariant, expected_files: set[str]
) -> None:
filtered_files = filter_files(sd15_test_files, variant)
assert {str(f) for f in filtered_files} == expected_files