From d5a949e6c32957a7d8834d25abd965182cc4baeb Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 6 Aug 2024 19:34:49 +0000 Subject: [PATCH] Update HF download logic to work for black-forest-labs/FLUX.1-schnell. --- .../model_manager/util/select_hf_files.py | 24 +++++- .../util/test_hf_model_select.py | 77 +++++++++++++++++++ 2 files changed, 99 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/model_manager/util/select_hf_files.py b/invokeai/backend/model_manager/util/select_hf_files.py index b0a9551437..2e86d9a62e 100644 --- a/invokeai/backend/model_manager/util/select_hf_files.py +++ b/invokeai/backend/model_manager/util/select_hf_files.py @@ -54,6 +54,7 @@ def filter_files( "lora_weights.safetensors", "weights.pb", "onnx_data", + "spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`. ) ): paths.append(file) @@ -62,7 +63,7 @@ def filter_files( # downloading random checkpoints that might also be in the repo. However there is no guarantee # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models # will adhere to this naming convention, so this is an area to be careful of. - elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name): + elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name): paths.append(file) # limit search to subfolder if requested @@ -97,7 +98,9 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path if variant == ModelRepoVariant.Flax: result.add(path) - elif path.suffix in [".json", ".txt"]: + # Note: '.model' was added to support: + # https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model + elif path.suffix in [".json", ".txt", ".model"]: result.add(path) elif variant in [ @@ -140,6 +143,23 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path continue for candidate_list in subfolder_weights.values(): + # Check if at least one of the files has the explicit fp16 variant. + at_least_one_fp16 = False + for candidate in candidate_list: + if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16": + at_least_one_fp16 = True + break + + if not at_least_one_fp16: + # If none of the candidates in this candidate_list have the explicit fp16 variant label, then this + # candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case, + # we'll simply keep all the candidates. An example of a model that hits this case is + # `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd). + for candidate in candidate_list: + result.add(candidate.path) + + # The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring + # candidate. highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score) if highest_score_candidate: result.add(highest_score_candidate.path) diff --git a/tests/backend/model_manager/util/test_hf_model_select.py b/tests/backend/model_manager/util/test_hf_model_select.py index a29827e8c4..8b5a395fdb 100644 --- a/tests/backend/model_manager/util/test_hf_model_select.py +++ b/tests/backend/model_manager/util/test_hf_model_select.py @@ -326,3 +326,80 @@ def test_select_multiple_weights( ) -> None: filtered_files = filter_files(sd15_test_files, variant) assert set(filtered_files) == {Path(f) for f in expected_files} + + +@pytest.fixture +def flux_schnell_test_files() -> list[Path]: + return [ + Path(f) + for f in [ + "FLUX.1-schnell/.gitattributes", + "FLUX.1-schnell/README.md", + "FLUX.1-schnell/ae.safetensors", + "FLUX.1-schnell/flux1-schnell.safetensors", + "FLUX.1-schnell/model_index.json", + "FLUX.1-schnell/scheduler/scheduler_config.json", + "FLUX.1-schnell/schnell_grid.jpeg", + "FLUX.1-schnell/text_encoder/config.json", + "FLUX.1-schnell/text_encoder/model.safetensors", + "FLUX.1-schnell/text_encoder_2/config.json", + "FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors", + "FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors", + "FLUX.1-schnell/text_encoder_2/model.safetensors.index.json", + "FLUX.1-schnell/tokenizer/merges.txt", + "FLUX.1-schnell/tokenizer/special_tokens_map.json", + "FLUX.1-schnell/tokenizer/tokenizer_config.json", + "FLUX.1-schnell/tokenizer/vocab.json", + "FLUX.1-schnell/tokenizer_2/special_tokens_map.json", + "FLUX.1-schnell/tokenizer_2/spiece.model", + "FLUX.1-schnell/tokenizer_2/tokenizer.json", + "FLUX.1-schnell/tokenizer_2/tokenizer_config.json", + "FLUX.1-schnell/transformer/config.json", + "FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors", + "FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors", + "FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors", + "FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json", + "FLUX.1-schnell/vae/config.json", + "FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors", + ] + ] + + +@pytest.mark.parametrize( + ["variant", "expected_files"], + [ + ( + ModelRepoVariant.Default, + [ + "FLUX.1-schnell/model_index.json", + "FLUX.1-schnell/scheduler/scheduler_config.json", + "FLUX.1-schnell/text_encoder/config.json", + "FLUX.1-schnell/text_encoder/model.safetensors", + "FLUX.1-schnell/text_encoder_2/config.json", + "FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors", + "FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors", + "FLUX.1-schnell/text_encoder_2/model.safetensors.index.json", + "FLUX.1-schnell/tokenizer/merges.txt", + "FLUX.1-schnell/tokenizer/special_tokens_map.json", + "FLUX.1-schnell/tokenizer/tokenizer_config.json", + "FLUX.1-schnell/tokenizer/vocab.json", + "FLUX.1-schnell/tokenizer_2/special_tokens_map.json", + "FLUX.1-schnell/tokenizer_2/spiece.model", + "FLUX.1-schnell/tokenizer_2/tokenizer.json", + "FLUX.1-schnell/tokenizer_2/tokenizer_config.json", + "FLUX.1-schnell/transformer/config.json", + "FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors", + "FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors", + "FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors", + "FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json", + "FLUX.1-schnell/vae/config.json", + "FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors", + ], + ), + ], +) +def test_select_flux_schnell_files( + flux_schnell_test_files: list[Path], variant: ModelRepoVariant, expected_files: list[str] +) -> None: + filtered_files = filter_files(flux_schnell_test_files, variant) + assert set(filtered_files) == {Path(f) for f in expected_files}