fix(mm): fix extraneous downloaded files in diffusers

Sometimes, diffusers model components (tokenizer, unet, etc.) have multiple weights files in the same directory.

In this situation, we assume the files are different versions of the same weights. For example, we may have multiple
formats (`.bin`, `.safetensors`) with different precisions. When downloading model files, we want to select only
the best of these files for the requested format and precision/variant.

The previous logic assumed that each model weights file would have the same base filename, but this assumption was
not always true. The logic is revised score each file and choose the best scoring file, resulting in only a single
file being downloaded for each submodel/subdirectory.
This commit is contained in:
psychedelicious 2024-03-04 16:10:29 +11:00
parent f2b5f8753f
commit 3534366146
2 changed files with 223 additions and 27 deletions

View File

@ -13,6 +13,7 @@ files_to_download = select_hf_model_files(metadata.files, variant='onnx')
""" """
import re import re
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Set from typing import Dict, List, Optional, Set
@ -73,10 +74,34 @@ def filter_files(
return sorted(_filter_by_variant(paths, variant)) return sorted(_filter_by_variant(paths, variant))
def get_variant_label(path: Path) -> Optional[str]:
suffixes = path.suffixes
if len(suffixes) == 2:
variant_label, _ = suffixes
else:
variant_label = None
return variant_label
def get_suffix(path: Path) -> str:
suffixes = path.suffixes
if len(suffixes) == 2:
_, suffix = suffixes
else:
suffix = suffixes[0]
return suffix
@dataclass
class SubfolderCandidate:
path: Path
score: int
def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]: def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths.""" """Select the proper variant files from a list of HuggingFace repo_id paths."""
result = set() result: set[Path] = set()
basenames: Dict[Path, Path] = {} subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
for path in files: for path in files:
if path.suffix in [".onnx", ".pb", ".onnx_data"]: if path.suffix in [".onnx", ".pb", ".onnx_data"]:
if variant == ModelRepoVariant.ONNX: if variant == ModelRepoVariant.ONNX:
@ -93,38 +118,49 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
elif path.suffix in [".json", ".txt"]: elif path.suffix in [".json", ".txt"]:
result.add(path) result.add(path)
elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [ elif variant in [
ModelRepoVariant.FP16, ModelRepoVariant.FP16,
ModelRepoVariant.FP32, ModelRepoVariant.FP32,
ModelRepoVariant.DEFAULT, ModelRepoVariant.DEFAULT,
]: ] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]:
parent = path.parent # For weights files, we want to select the best one for each subfolder. For example, we may have multiple
suffixes = path.suffixes # text encoders:
if len(suffixes) == 2: #
variant_label, suffix = suffixes # - text_encoder/model.fp16.safetensors
basename = parent / Path(path.stem).stem # - text_encoder/model.safetensors
else: # - text_encoder/pytorch_model.bin
variant_label = "" # - text_encoder/pytorch_model.fp16.bin
suffix = suffixes[0] #
basename = parent / path.stem # We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
# variant and format and select the best one.
if previous := basenames.get(basename): parent = path.parent
if ( score = 0
previous.suffix != ".safetensors" and suffix == ".safetensors"
): # replace non-safetensors with safetensors when available if path.suffix == ".safetensors":
basenames[basename] = path score += 1
if variant_label == f".{variant}":
basenames[basename] = path candidate_variant_label = get_variant_label(path)
elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
basenames[basename] = path # Some special handling is needed here if there is not an exact match and if we cannot infer the variant
else: # from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
basenames[basename] = path if candidate_variant_label == f".{variant}" or (
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]
):
score += 1
if parent not in subfolder_weights:
subfolder_weights[parent] = []
subfolder_weights[parent].append(SubfolderCandidate(path=path, score=score))
else: else:
continue continue
for v in basenames.values(): for candidate_list in subfolder_weights.values():
result.add(v) highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate:
result.add(highest_score_candidate.path)
# If one of the architecture-related variants was specified and no files matched other than # If one of the architecture-related variants was specified and no files matched other than
# config and text files then we return an empty list # config and text files then we return an empty list
@ -144,3 +180,76 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
directories[x.parent] = directories.get(x.parent, 0) + 1 directories[x.parent] = directories.get(x.parent, 0) + 1
return {x for x in result if directories[x.parent] > 1 or x.name != "config.json"} return {x for x in result if directories[x.parent] > 1 or x.name != "config.json"}
# def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
# """Select the proper variant files from a list of HuggingFace repo_id paths."""
# result: set[Path] = set()
# basenames: Dict[Path, Path] = {}
# for path in files:
# if path.suffix in [".onnx", ".pb", ".onnx_data"]:
# if variant == ModelRepoVariant.ONNX:
# result.add(path)
# elif "openvino_model" in path.name:
# if variant == ModelRepoVariant.OPENVINO:
# result.add(path)
# elif "flax_model" in path.name:
# if variant == ModelRepoVariant.FLAX:
# result.add(path)
# elif path.suffix in [".json", ".txt"]:
# result.add(path)
# elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
# ModelRepoVariant.FP16,
# ModelRepoVariant.FP32,
# ModelRepoVariant.DEFAULT,
# ]:
# parent = path.parent
# suffixes = path.suffixes
# if len(suffixes) == 2:
# variant_label, suffix = suffixes
# basename = parent / Path(path.stem).stem
# else:
# variant_label = ""
# suffix = suffixes[0]
# basename = parent / path.stem
# if previous := basenames.get(basename):
# if (
# previous.suffix != ".safetensors" and suffix == ".safetensors"
# ): # replace non-safetensors with safetensors when available
# basenames[basename] = path
# if variant_label == f".{variant}":
# basenames[basename] = path
# elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
# basenames[basename] = path
# else:
# basenames[basename] = path
# else:
# continue
# for v in basenames.values():
# result.add(v)
# # If one of the architecture-related variants was specified and no files matched other than
# # config and text files then we return an empty list
# if (
# variant
# and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
# and not any(variant.value in x.name for x in result)
# ):
# return set()
# # Prune folders that contain just a `config.json`. This happens when
# # the requested variant (e.g. "onnx") is missing
# directories: Dict[Path, int] = {}
# for x in result:
# if not x.parent:
# continue
# directories[x.parent] = directories.get(x.parent, 0) + 1
# return {x for x in result if directories[x.parent] > 1 or x.name != "config.json"}

View File

@ -235,7 +235,94 @@ def sdxl_base_files() -> List[Path]:
), ),
], ],
) )
def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[Path]) -> None: def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[str]) -> None:
print(f"testing variant {variant}") print(f"testing variant {variant}")
filtered_files = filter_files(sdxl_base_files, variant) filtered_files = filter_files(sdxl_base_files, variant)
assert set(filtered_files) == {Path(x) for x in expected_list} assert set(filtered_files) == {Path(x) for x in expected_list}
@pytest.fixture
def sd15_test_files() -> list[Path]:
return [
Path(f)
for f in [
"feature_extractor/preprocessor_config.json",
"safety_checker/config.json",
"safety_checker/model.fp16.safetensors",
"safety_checker/model.safetensors",
"safety_checker/pytorch_model.bin",
"safety_checker/pytorch_model.fp16.bin",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/model.fp16.safetensors",
"text_encoder/model.safetensors",
"text_encoder/pytorch_model.bin",
"text_encoder/pytorch_model.fp16.bin",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"unet/config.json",
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
"unet/diffusion_pytorch_model.non_ema.bin",
"unet/diffusion_pytorch_model.non_ema.safetensors",
"unet/diffusion_pytorch_model.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.bin",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"vae/diffusion_pytorch_model.safetensors",
]
]
@pytest.mark.parametrize(
"variant,expected_files",
[
(
ModelRepoVariant.FP16,
{
"feature_extractor/preprocessor_config.json",
"safety_checker/config.json",
"safety_checker/model.fp16.safetensors",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/model.fp16.safetensors",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"unet/config.json",
"unet/diffusion_pytorch_model.fp16.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.fp16.safetensors",
},
),
(
ModelRepoVariant.FP32,
{
"feature_extractor/preprocessor_config.json",
"safety_checker/config.json",
"safety_checker/model.safetensors",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/model.safetensors",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"unet/config.json",
"unet/diffusion_pytorch_model.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.safetensors",
},
),
],
)
def test_select_multiple_weights(
sd15_test_files: list[Path], variant: ModelRepoVariant, expected_files: set[str]
) -> None:
filtered_files = filter_files(sd15_test_files, variant)
assert {str(f) for f in filtered_files} == expected_files