Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
fix(mm): fix extraneous downloaded files in diffusers
Sometimes, diffusers model components (tokenizer, unet, etc.) have multiple weights files in the same directory. In this situation, we assume the files are different versions of the same weights. For example, we may have multiple formats (`.bin`, `.safetensors`) with different precisions. When downloading model files, we want to select only the best of these files for the requested format and precision/variant.

The previous logic assumed that each model weights file would have the same base filename, but this assumption was not always true. The logic is revised to score each file and choose the best-scoring one, so that only a single file is downloaded for each submodel/subdirectory.
This commit is contained in: parent f2b5f8753f, commit 3534366146
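To make the failure mode concrete, here is a small standalone sketch (not part of the commit) of how the old basename-keyed grouping splits one text encoder's candidates into two groups, so two files get downloaded instead of one. The file list comes from the commit's test fixture; everything else is illustrative.

from pathlib import Path

# Four candidate weights files for a single text encoder, taken from the
# commit's test fixture. They are alternate versions of the same weights,
# but they do not share a base filename.
candidates = [
    Path("text_encoder/model.fp16.safetensors"),
    Path("text_encoder/model.safetensors"),
    Path("text_encoder/pytorch_model.bin"),
    Path("text_encoder/pytorch_model.fp16.bin"),
]

# The old logic keyed files by basename (parent / stem, with any variant
# label stripped), keeping one "best" file per key.
basenames = set()
for path in candidates:
    if len(path.suffixes) == 2:  # e.g. ['.fp16', '.safetensors']
        basenames.add(path.parent / Path(path.stem).stem)
    else:
        basenames.add(path.parent / path.stem)

print(basenames)
# {PosixPath('text_encoder/model'), PosixPath('text_encoder/pytorch_model')}
# Two keys -> two "best" files -> an extraneous download for this submodel.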
@@ -13,6 +13,7 @@ files_to_download = select_hf_model_files(metadata.files, variant='onnx')
 """
 
 import re
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Dict, List, Optional, Set
 
@@ -73,10 +74,34 @@ def filter_files(
     return sorted(_filter_by_variant(paths, variant))
 
 
+def get_variant_label(path: Path) -> Optional[str]:
+    suffixes = path.suffixes
+    if len(suffixes) == 2:
+        variant_label, _ = suffixes
+    else:
+        variant_label = None
+    return variant_label
+
+
+def get_suffix(path: Path) -> str:
+    suffixes = path.suffixes
+    if len(suffixes) == 2:
+        _, suffix = suffixes
+    else:
+        suffix = suffixes[0]
+    return suffix
+
+
+@dataclass
+class SubfolderCandidate:
+    path: Path
+    score: int
+
+
 def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
     """Select the proper variant files from a list of HuggingFace repo_id paths."""
-    result = set()
-    basenames: Dict[Path, Path] = {}
+    result: set[Path] = set()
+    subfolder_weights: dict[Path, list[SubfolderCandidate]] = {}
     for path in files:
         if path.suffix in [".onnx", ".pb", ".onnx_data"]:
            if variant == ModelRepoVariant.ONNX:
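As a quick illustration (standalone, not part of the diff), both helpers lean on `Path.suffixes`, which returns every dotted segment after the first: a file with both a variant label and a format extension yields exactly two suffixes. Note that any second-to-last dotted segment is treated as the variant label, so the `non_ema` files from the test fixture yield a label of `.non_ema`.

from pathlib import Path

for name in [
    "text_encoder/model.fp16.safetensors",
    "text_encoder/model.safetensors",
    "unet/diffusion_pytorch_model.non_ema.safetensors",
]:
    print(Path(name).suffixes)
# ['.fp16', '.safetensors']    -> variant label '.fp16',    suffix '.safetensors'
# ['.safetensors']             -> variant label None,       suffix '.safetensors'
# ['.non_ema', '.safetensors'] -> variant label '.non_ema', suffix '.safetensors'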
@@ -93,38 +118,49 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
         elif path.suffix in [".json", ".txt"]:
             result.add(path)
 
-        elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
+        elif variant in [
             ModelRepoVariant.FP16,
             ModelRepoVariant.FP32,
             ModelRepoVariant.DEFAULT,
-        ]:
-            parent = path.parent
-            suffixes = path.suffixes
-            if len(suffixes) == 2:
-                variant_label, suffix = suffixes
-                basename = parent / Path(path.stem).stem
-            else:
-                variant_label = ""
-                suffix = suffixes[0]
-                basename = parent / path.stem
+        ] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]:
+            # For weights files, we want to select the best one for each subfolder. For example, we may have multiple
+            # text encoders:
+            #
+            # - text_encoder/model.fp16.safetensors
+            # - text_encoder/model.safetensors
+            # - text_encoder/pytorch_model.bin
+            # - text_encoder/pytorch_model.fp16.bin
+            #
+            # We prefer safetensors over other file formats and an exact variant match. We'll score each file based on
+            # variant and format and select the best one.
 
-            if previous := basenames.get(basename):
-                if (
-                    previous.suffix != ".safetensors" and suffix == ".safetensors"
-                ):  # replace non-safetensors with safetensors when available
-                    basenames[basename] = path
-                if variant_label == f".{variant}":
-                    basenames[basename] = path
-                elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
-                    basenames[basename] = path
-            else:
-                basenames[basename] = path
+            parent = path.parent
+            score = 0
+
+            if path.suffix == ".safetensors":
+                score += 1
+
+            candidate_variant_label = get_variant_label(path)
+
+            # Some special handling is needed here if there is not an exact match and if we cannot infer the variant
+            # from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
+            if candidate_variant_label == f".{variant}" or (
+                not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]
+            ):
+                score += 1
+
+            if parent not in subfolder_weights:
+                subfolder_weights[parent] = []
+
+            subfolder_weights[parent].append(SubfolderCandidate(path=path, score=score))
 
         else:
             continue
 
-    for v in basenames.values():
-        result.add(v)
+    for candidate_list in subfolder_weights.values():
+        highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
+        if highest_score_candidate:
+            result.add(highest_score_candidate.path)
 
     # If one of the architecture-related variants was specified and no files matched other than
     # config and text files then we return an empty list
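Under the new scoring, the four text encoder candidates from the comment above come out as follows for a requested FP16 variant. The snippet below is a standalone restatement of the diff's logic for illustration; the plain strings "fp16", "fp32", and "default" stand in for the ModelRepoVariant enum members.

from pathlib import Path

def score(path: Path, variant: str) -> int:
    s = 0
    if path.suffix == ".safetensors":
        s += 1  # prefer safetensors over other formats
    label = path.suffixes[0] if len(path.suffixes) == 2 else None
    if label == f".{variant}" or (not label and variant in ("fp32", "default")):
        s += 1  # prefer an exact variant match (or no label for FP32/DEFAULT)
    return s

candidates = [
    Path("text_encoder/model.fp16.safetensors"),  # score 2
    Path("text_encoder/model.safetensors"),       # score 1
    Path("text_encoder/pytorch_model.bin"),       # score 0
    Path("text_encoder/pytorch_model.fp16.bin"),  # score 1
]
best = max(candidates, key=lambda p: score(p, "fp16"))
print(best)  # text_encoder/model.fp16.safetensors

Note that Python's `max` returns the first of equally scored candidates, so ties (e.g. `model.safetensors` vs. `pytorch_model.fp16.bin`, both scoring 1 here) resolve in input order.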
@@ -144,3 +180,76 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
         directories[x.parent] = directories.get(x.parent, 0) + 1
 
     return {x for x in result if directories[x.parent] > 1 or x.name != "config.json"}
+
+
+# def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
+#     """Select the proper variant files from a list of HuggingFace repo_id paths."""
+#     result: set[Path] = set()
+#     basenames: Dict[Path, Path] = {}
+#     for path in files:
+#         if path.suffix in [".onnx", ".pb", ".onnx_data"]:
+#             if variant == ModelRepoVariant.ONNX:
+#                 result.add(path)
+
+#         elif "openvino_model" in path.name:
+#             if variant == ModelRepoVariant.OPENVINO:
+#                 result.add(path)
+
+#         elif "flax_model" in path.name:
+#             if variant == ModelRepoVariant.FLAX:
+#                 result.add(path)
+
+#         elif path.suffix in [".json", ".txt"]:
+#             result.add(path)
+
+#         elif path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"] and variant in [
+#             ModelRepoVariant.FP16,
+#             ModelRepoVariant.FP32,
+#             ModelRepoVariant.DEFAULT,
+#         ]:
+#             parent = path.parent
+#             suffixes = path.suffixes
+#             if len(suffixes) == 2:
+#                 variant_label, suffix = suffixes
+#                 basename = parent / Path(path.stem).stem
+#             else:
+#                 variant_label = ""
+#                 suffix = suffixes[0]
+#                 basename = parent / path.stem
+
+#             if previous := basenames.get(basename):
+#                 if (
+#                     previous.suffix != ".safetensors" and suffix == ".safetensors"
+#                 ):  # replace non-safetensors with safetensors when available
+#                     basenames[basename] = path
+#                 if variant_label == f".{variant}":
+#                     basenames[basename] = path
+#                 elif not variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]:
+#                     basenames[basename] = path
+#             else:
+#                 basenames[basename] = path
+
+#         else:
+#             continue
+
+#     for v in basenames.values():
+#         result.add(v)
+
+#     # If one of the architecture-related variants was specified and no files matched other than
+#     # config and text files then we return an empty list
+#     if (
+#         variant
+#         and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
+#         and not any(variant.value in x.name for x in result)
+#     ):
+#         return set()
+
+#     # Prune folders that contain just a `config.json`. This happens when
+#     # the requested variant (e.g. "onnx") is missing
+#     directories: Dict[Path, int] = {}
+#     for x in result:
+#         if not x.parent:
+#             continue
+#         directories[x.parent] = directories.get(x.parent, 0) + 1
+
+#     return {x for x in result if directories[x.parent] > 1 or x.name != "config.json"}
The accompanying tests:

@@ -235,7 +235,94 @@ def sdxl_base_files() -> List[Path]:
         ),
     ],
 )
-def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[Path]) -> None:
+def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[str]) -> None:
     print(f"testing variant {variant}")
     filtered_files = filter_files(sdxl_base_files, variant)
     assert set(filtered_files) == {Path(x) for x in expected_list}
+
+
+@pytest.fixture
+def sd15_test_files() -> list[Path]:
+    return [
+        Path(f)
+        for f in [
+            "feature_extractor/preprocessor_config.json",
+            "safety_checker/config.json",
+            "safety_checker/model.fp16.safetensors",
+            "safety_checker/model.safetensors",
+            "safety_checker/pytorch_model.bin",
+            "safety_checker/pytorch_model.fp16.bin",
+            "scheduler/scheduler_config.json",
+            "text_encoder/config.json",
+            "text_encoder/model.fp16.safetensors",
+            "text_encoder/model.safetensors",
+            "text_encoder/pytorch_model.bin",
+            "text_encoder/pytorch_model.fp16.bin",
+            "tokenizer/merges.txt",
+            "tokenizer/special_tokens_map.json",
+            "tokenizer/tokenizer_config.json",
+            "tokenizer/vocab.json",
+            "unet/config.json",
+            "unet/diffusion_pytorch_model.bin",
+            "unet/diffusion_pytorch_model.fp16.bin",
+            "unet/diffusion_pytorch_model.fp16.safetensors",
+            "unet/diffusion_pytorch_model.non_ema.bin",
+            "unet/diffusion_pytorch_model.non_ema.safetensors",
+            "unet/diffusion_pytorch_model.safetensors",
+            "vae/config.json",
+            "vae/diffusion_pytorch_model.bin",
+            "vae/diffusion_pytorch_model.fp16.bin",
+            "vae/diffusion_pytorch_model.fp16.safetensors",
+            "vae/diffusion_pytorch_model.safetensors",
+        ]
+    ]
+
+
+@pytest.mark.parametrize(
+    "variant,expected_files",
+    [
+        (
+            ModelRepoVariant.FP16,
+            {
+                "feature_extractor/preprocessor_config.json",
+                "safety_checker/config.json",
+                "safety_checker/model.fp16.safetensors",
+                "scheduler/scheduler_config.json",
+                "text_encoder/config.json",
+                "text_encoder/model.fp16.safetensors",
+                "tokenizer/merges.txt",
+                "tokenizer/special_tokens_map.json",
+                "tokenizer/tokenizer_config.json",
+                "tokenizer/vocab.json",
+                "unet/config.json",
+                "unet/diffusion_pytorch_model.fp16.safetensors",
+                "vae/config.json",
+                "vae/diffusion_pytorch_model.fp16.safetensors",
+            },
+        ),
+        (
+            ModelRepoVariant.FP32,
+            {
+                "feature_extractor/preprocessor_config.json",
+                "safety_checker/config.json",
+                "safety_checker/model.safetensors",
+                "scheduler/scheduler_config.json",
+                "text_encoder/config.json",
+                "text_encoder/model.safetensors",
+                "tokenizer/merges.txt",
+                "tokenizer/special_tokens_map.json",
+                "tokenizer/tokenizer_config.json",
+                "tokenizer/vocab.json",
+                "unet/config.json",
+                "unet/diffusion_pytorch_model.safetensors",
+                "vae/config.json",
+                "vae/diffusion_pytorch_model.safetensors",
+            },
+        ),
+    ],
+)
+def test_select_multiple_weights(
+    sd15_test_files: list[Path], variant: ModelRepoVariant, expected_files: set[str]
+) -> None:
+    filtered_files = filter_files(sd15_test_files, variant)
+    assert {str(f) for f in filtered_files} == expected_files
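For completeness, a hedged usage sketch of the public entry point exercised by these tests. The import paths below are assumptions based on where `filter_files` and `ModelRepoVariant` appear to live in the repo and may not match exactly.

from pathlib import Path

# Assumed import paths -- adjust to the actual module locations in the repo.
from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager.util.select_hf_files import filter_files

files = [
    Path("text_encoder/config.json"),
    Path("text_encoder/model.fp16.safetensors"),
    Path("text_encoder/model.safetensors"),
    Path("text_encoder/pytorch_model.bin"),
    Path("text_encoder/pytorch_model.fp16.bin"),
]

# Expect exactly one weights file per subfolder, plus the config.
print(filter_files(files, ModelRepoVariant.FP16))
# [Path('text_encoder/config.json'), Path('text_encoder/model.fp16.safetensors')]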