From 35343661469ac1c7cb43a3dc4cd0de0c68e8e639 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Mar 2024 16:10:29 +1100 Subject: [PATCH] fix(mm): fix extraneous downloaded files in diffusers Sometimes, diffusers model components (tokenizer, unet, etc.) have multiple weights files in the same directory. In this situation, we assume the files are different versions of the same weights. For example, we may have multiple formats (`.bin`, `.safetensors`) with different precisions. When downloading model files, we want to select only the best of these files for the requested format and precision/variant. The previous logic assumed that each model weights file would have the same base filename, but this assumption was not always true. The logic is revised score each file and choose the best scoring file, resulting in only a single file being downloaded for each submodel/subdirectory. --- .../model_manager/util/select_hf_files.py | 161 +++++++++++++++--- .../util/test_hf_model_select.py | 89 +++++++++- 2 files changed, 223 insertions(+), 27 deletions(-) diff --git a/invokeai/backend/model_manager/util/select_hf_files.py b/invokeai/backend/model_manager/util/select_hf_files.py index 2fd7a3721a..b84ca06b94 100644 --- a/invokeai/backend/model_manager/util/select_hf_files.py +++ b/invokeai/backend/model_manager/util/select_hf_files.py @@ -13,6 +13,7 @@ files_to_download = select_hf_model_files(metadata.files, variant='onnx') """ import re +from dataclasses import dataclass from pathlib import Path from typing import Dict, List, Optional, Set @@ -73,10 +74,34 @@ def filter_files( return sorted(_filter_by_variant(paths, variant)) +def get_variant_label(path: Path) -> Optional[str]: + suffixes = path.suffixes + if len(suffixes) == 2: + variant_label, _ = suffixes + else: + variant_label = None + return variant_label + + +def get_suffix(path: Path) -> str: + suffixes = path.suffixes + if len(suffixes) == 2: + _, suffix = suffixes + else: + suffix = suffixes[0] + return suffix + + +@dataclass +class SubfolderCandidate: + path: Path + score: int + + def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]: """Select the proper variant files from a list of HuggingFace repo_id paths.""" - result = set() - basenames: Dict[Path, Path] = {} + result: set[Path] = set() + subfolder_weights: dict[Path, list[SubfolderCandidate]] = {} for path in files: if path.suffix in [".onnx", ".pb", ".onnx_data"]: if variant == ModelRepoVariant.ONNX: @@ -93,38 +118,49 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path elif path.suffix in [".json", ".txt"]: result.add(path) - elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [ + elif variant in [ ModelRepoVariant.FP16, ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT, - ]: - parent = path.parent - suffixes = path.suffixes - if len(suffixes) == 2: - variant_label, suffix = suffixes - basename = parent / Path(path.stem).stem - else: - variant_label = "" - suffix = suffixes[0] - basename = parent / path.stem + ] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]: + # For weights files, we want to select the best one for each subfolder. For example, we may have multiple + # text encoders: + # + # - text_encoder/model.fp16.safetensors + # - text_encoder/model.safetensors + # - text_encoder/pytorch_model.bin + # - text_encoder/pytorch_model.fp16.bin + # + # We prefer safetensors over other file formats and an exact variant match. We'll score each file based on + # variant and format and select the best one. - if previous := basenames.get(basename): - if ( - previous.suffix != ".safetensors" and suffix == ".safetensors" - ): # replace non-safetensors with safetensors when available - basenames[basename] = path - if variant_label == f".{variant}": - basenames[basename] = path - elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]: - basenames[basename] = path - else: - basenames[basename] = path + parent = path.parent + score = 0 + + if path.suffix == ".safetensors": + score += 1 + + candidate_variant_label = get_variant_label(path) + + # Some special handling is needed here if there is not an exact match and if we cannot infer the variant + # from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT. + if candidate_variant_label == f".{variant}" or ( + not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT] + ): + score += 1 + + if parent not in subfolder_weights: + subfolder_weights[parent] = [] + + subfolder_weights[parent].append(SubfolderCandidate(path=path, score=score)) else: continue - for v in basenames.values(): - result.add(v) + for candidate_list in subfolder_weights.values(): + highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score) + if highest_score_candidate: + result.add(highest_score_candidate.path) # If one of the architecture-related variants was specified and no files matched other than # config and text files then we return an empty list @@ -144,3 +180,76 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path directories[x.parent] = directories.get(x.parent, 0) + 1 return {x for x in result if directories[x.parent] > 1 or x.name != "config.json"} + + +# def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]: +# """Select the proper variant files from a list of HuggingFace repo_id paths.""" +# result: set[Path] = set() +# basenames: Dict[Path, Path] = {} +# for path in files: +# if path.suffix in [".onnx", ".pb", ".onnx_data"]: +# if variant == ModelRepoVariant.ONNX: +# result.add(path) + +# elif "openvino_model" in path.name: +# if variant == ModelRepoVariant.OPENVINO: +# result.add(path) + +# elif "flax_model" in path.name: +# if variant == ModelRepoVariant.FLAX: +# result.add(path) + +# elif path.suffix in [".json", ".txt"]: +# result.add(path) + +# elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [ +# ModelRepoVariant.FP16, +# ModelRepoVariant.FP32, +# ModelRepoVariant.DEFAULT, +# ]: +# parent = path.parent +# suffixes = path.suffixes +# if len(suffixes) == 2: +# variant_label, suffix = suffixes +# basename = parent / Path(path.stem).stem +# else: +# variant_label = "" +# suffix = suffixes[0] +# basename = parent / path.stem + +# if previous := basenames.get(basename): +# if ( +# previous.suffix != ".safetensors" and suffix == ".safetensors" +# ): # replace non-safetensors with safetensors when available +# basenames[basename] = path +# if variant_label == f".{variant}": +# basenames[basename] = path +# elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]: +# basenames[basename] = path +# else: +# basenames[basename] = path + +# else: +# continue + +# for v in basenames.values(): +# result.add(v) + +# # If one of the architecture-related variants was specified and no files matched other than +# # config and text files then we return an empty list +# if ( +# variant +# and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX] +# and not any(variant.value in x.name for x in result) +# ): +# return set() + +# # Prune folders that contain just a `config.json`. This happens when +# # the requested variant (e.g. "onnx") is missing +# directories: Dict[Path, int] = {} +# for x in result: +# if not x.parent: +# continue +# directories[x.parent] = directories.get(x.parent, 0) + 1 + +# return {x for x in result if directories[x.parent] > 1 or x.name != "config.json"} diff --git a/tests/backend/model_manager/util/test_hf_model_select.py b/tests/backend/model_manager/util/test_hf_model_select.py index 5bef9cb2e1..e111738c30 100644 --- a/tests/backend/model_manager/util/test_hf_model_select.py +++ b/tests/backend/model_manager/util/test_hf_model_select.py @@ -235,7 +235,94 @@ def sdxl_base_files() -> List[Path]: ), ], ) -def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[Path]) -> None: +def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[str]) -> None: print(f"testing variant {variant}") filtered_files = filter_files(sdxl_base_files, variant) assert set(filtered_files) == {Path(x) for x in expected_list} + + +@pytest.fixture +def sd15_test_files() -> list[Path]: + return [ + Path(f) + for f in [ + "feature_extractor/preprocessor_config.json", + "safety_checker/config.json", + "safety_checker/model.fp16.safetensors", + "safety_checker/model.safetensors", + "safety_checker/pytorch_model.bin", + "safety_checker/pytorch_model.fp16.bin", + "scheduler/scheduler_config.json", + "text_encoder/config.json", + "text_encoder/model.fp16.safetensors", + "text_encoder/model.safetensors", + "text_encoder/pytorch_model.bin", + "text_encoder/pytorch_model.fp16.bin", + "tokenizer/merges.txt", + "tokenizer/special_tokens_map.json", + "tokenizer/tokenizer_config.json", + "tokenizer/vocab.json", + "unet/config.json", + "unet/diffusion_pytorch_model.bin", + "unet/diffusion_pytorch_model.fp16.bin", + "unet/diffusion_pytorch_model.fp16.safetensors", + "unet/diffusion_pytorch_model.non_ema.bin", + "unet/diffusion_pytorch_model.non_ema.safetensors", + "unet/diffusion_pytorch_model.safetensors", + "vae/config.json", + "vae/diffusion_pytorch_model.bin", + "vae/diffusion_pytorch_model.fp16.bin", + "vae/diffusion_pytorch_model.fp16.safetensors", + "vae/diffusion_pytorch_model.safetensors", + ] + ] + + +@pytest.mark.parametrize( + "variant,expected_files", + [ + ( + ModelRepoVariant.FP16, + { + "feature_extractor/preprocessor_config.json", + "safety_checker/config.json", + "safety_checker/model.fp16.safetensors", + "scheduler/scheduler_config.json", + "text_encoder/config.json", + "text_encoder/model.fp16.safetensors", + "tokenizer/merges.txt", + "tokenizer/special_tokens_map.json", + "tokenizer/tokenizer_config.json", + "tokenizer/vocab.json", + "unet/config.json", + "unet/diffusion_pytorch_model.fp16.safetensors", + "vae/config.json", + "vae/diffusion_pytorch_model.fp16.safetensors", + }, + ), + ( + ModelRepoVariant.FP32, + { + "feature_extractor/preprocessor_config.json", + "safety_checker/config.json", + "safety_checker/model.safetensors", + "scheduler/scheduler_config.json", + "text_encoder/config.json", + "text_encoder/model.safetensors", + "tokenizer/merges.txt", + "tokenizer/special_tokens_map.json", + "tokenizer/tokenizer_config.json", + "tokenizer/vocab.json", + "unet/config.json", + "unet/diffusion_pytorch_model.safetensors", + "vae/config.json", + "vae/diffusion_pytorch_model.safetensors", + }, + ), + ], +) +def test_select_multiple_weights( + sd15_test_files: list[Path], variant: ModelRepoVariant, expected_files: set[str] +) -> None: + filtered_files = filter_files(sd15_test_files, variant) + assert {str(f) for f in filtered_files} == expected_files