Update HF download logic to work for black-forest-labs/FLUX.1-schnell.
commit 7d447cbb88
parent 3bbba7e4b1
@@ -54,6 +54,7 @@ def filter_files(
                 "lora_weights.safetensors",
                 "weights.pb",
                 "onnx_data",
+                "spiece.model",  # Added for `black-forest-labs/FLUX.1-schnell`.
             )
         ):
             paths.append(file)
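
For orientation (not part of the diff): this allow-list feeds a plain `str.endswith` check against a tuple of name endings, so adding the bare filename "spiece.model" is enough to let the T5 SentencePiece vocab through. A minimal sketch of that mechanism, with a hypothetical file listing and an abridged allow-list:

    from pathlib import Path

    # Abridged allow-list; the real tuple also carries ".json", ".txt", etc.
    KEEP_ENDINGS = (".json", ".txt", "spiece.model")

    files = [Path("tokenizer_2/spiece.model"), Path("schnell_grid.jpeg")]

    # str.endswith accepts a tuple, so one call checks every allowed ending.
    paths = [f for f in files if f.name.endswith(KEEP_ENDINGS)]
    print(paths)  # [PosixPath('tokenizer_2/spiece.model')]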
@@ -62,7 +63,7 @@ def filter_files(
         # downloading random checkpoints that might also be in the repo. However there is no guarantee
         # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
         # will adhere to this naming convention, so this is an area to be careful of.
-        elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
+        elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
             paths.append(file)
 
     # limit search to subfolder if requested
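
The regex relaxation is what admits sharded checkpoints: the old pattern allowed at most one extra dotted segment after "model" (the fp16-style variant label), so shard names like those in FLUX.1-schnell's second text encoder never matched. A standalone comparison, with the two pattern strings copied from the hunk above:

    import re

    OLD = r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
    NEW = r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"

    for name in [
        "model.safetensors",                 # plain weight file
        "model.fp16.safetensors",            # variant-labelled weight file
        "model-00001-of-00002.safetensors",  # shard: only the new pattern matches
    ]:
        print(name, bool(re.search(OLD, name)), bool(re.search(NEW, name)))

    # model.safetensors True True
    # model.fp16.safetensors True True
    # model-00001-of-00002.safetensors False True

The looser `.*` also widens exposure to unrelated files with "model" in the name, which is exactly the risk the surrounding comment warns about.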
@@ -97,7 +98,9 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
         if variant == ModelRepoVariant.Flax:
             result.add(path)
 
-        elif path.suffix in [".json", ".txt"]:
+        # Note: '.model' was added to support:
+        # https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model
+        elif path.suffix in [".json", ".txt", ".model"]:
             result.add(path)
 
         elif variant in [
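
Worth noting: `Path.suffix` is only the final extension, so `spiece.model` reports ".model" and now passes the extended allow-list. A quick illustration:

    from pathlib import Path

    print(Path("tokenizer_2/spiece.model").suffix)           # '.model'
    print(Path("tokenizer_2/tokenizer_config.json").suffix)  # '.json'

    # The extended allow-list therefore keeps the T5 SentencePiece vocab file:
    print(Path("tokenizer_2/spiece.model").suffix in [".json", ".txt", ".model"])  # True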
@@ -140,6 +143,23 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
                 continue
 
     for candidate_list in subfolder_weights.values():
+        # Check if at least one of the files has the explicit fp16 variant.
+        at_least_one_fp16 = False
+        for candidate in candidate_list:
+            if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
+                at_least_one_fp16 = True
+                break
+
+        if not at_least_one_fp16:
+            # If none of the candidates in this candidate_list have the explicit fp16 variant label, then this
+            # candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case,
+            # we'll simply keep all the candidates. An example of a model that hits this case is
+            # `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd).
+            for candidate in candidate_list:
+                result.add(candidate.path)
+
+        # The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring
+        # candidate.
         highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
         if highest_score_candidate:
             result.add(highest_score_candidate.path)
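
The fp16 probe leans on `Path.suffixes`, which returns every dot-delimited extension of the final path component: a variant-labelled weight file has exactly two suffixes with ".fp16" first, while FLUX.1-schnell's shards carry only ".safetensors" and so fall into the keep-everything branch. A minimal sketch of the heuristic, using a hypothetical `has_fp16_variant` helper over bare paths (the real code inspects `candidate.path`):

    from pathlib import Path

    def has_fp16_variant(paths: list[Path]) -> bool:
        # Exactly two suffixes with '.fp16' first, e.g. 'model.fp16.safetensors'.
        return any(len(p.suffixes) == 2 and p.suffixes[0] == ".fp16" for p in paths)

    print(has_fp16_variant([Path("unet/diffusion_pytorch_model.fp16.safetensors")]))  # True
    print(has_fp16_variant([Path("text_encoder_2/model-00001-of-00002.safetensors")]))  # False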
@@ -326,3 +326,80 @@ def test_select_multiple_weights(
 ) -> None:
     filtered_files = filter_files(sd15_test_files, variant)
     assert set(filtered_files) == {Path(f) for f in expected_files}
+
+
+@pytest.fixture
+def flux_schnell_test_files() -> list[Path]:
+    return [
+        Path(f)
+        for f in [
+            "FLUX.1-schnell/.gitattributes",
+            "FLUX.1-schnell/README.md",
+            "FLUX.1-schnell/ae.safetensors",
+            "FLUX.1-schnell/flux1-schnell.safetensors",
+            "FLUX.1-schnell/model_index.json",
+            "FLUX.1-schnell/scheduler/scheduler_config.json",
+            "FLUX.1-schnell/schnell_grid.jpeg",
+            "FLUX.1-schnell/text_encoder/config.json",
+            "FLUX.1-schnell/text_encoder/model.safetensors",
+            "FLUX.1-schnell/text_encoder_2/config.json",
+            "FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors",
+            "FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors",
+            "FLUX.1-schnell/text_encoder_2/model.safetensors.index.json",
+            "FLUX.1-schnell/tokenizer/merges.txt",
+            "FLUX.1-schnell/tokenizer/special_tokens_map.json",
+            "FLUX.1-schnell/tokenizer/tokenizer_config.json",
+            "FLUX.1-schnell/tokenizer/vocab.json",
+            "FLUX.1-schnell/tokenizer_2/special_tokens_map.json",
+            "FLUX.1-schnell/tokenizer_2/spiece.model",
+            "FLUX.1-schnell/tokenizer_2/tokenizer.json",
+            "FLUX.1-schnell/tokenizer_2/tokenizer_config.json",
+            "FLUX.1-schnell/transformer/config.json",
+            "FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors",
+            "FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors",
+            "FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors",
+            "FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json",
+            "FLUX.1-schnell/vae/config.json",
+            "FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors",
+        ]
+    ]
+
+
+@pytest.mark.parametrize(
+    ["variant", "expected_files"],
+    [
+        (
+            ModelRepoVariant.Default,
+            [
+                "FLUX.1-schnell/model_index.json",
+                "FLUX.1-schnell/scheduler/scheduler_config.json",
+                "FLUX.1-schnell/text_encoder/config.json",
+                "FLUX.1-schnell/text_encoder/model.safetensors",
+                "FLUX.1-schnell/text_encoder_2/config.json",
+                "FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors",
+                "FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors",
+                "FLUX.1-schnell/text_encoder_2/model.safetensors.index.json",
+                "FLUX.1-schnell/tokenizer/merges.txt",
+                "FLUX.1-schnell/tokenizer/special_tokens_map.json",
+                "FLUX.1-schnell/tokenizer/tokenizer_config.json",
+                "FLUX.1-schnell/tokenizer/vocab.json",
+                "FLUX.1-schnell/tokenizer_2/special_tokens_map.json",
+                "FLUX.1-schnell/tokenizer_2/spiece.model",
+                "FLUX.1-schnell/tokenizer_2/tokenizer.json",
+                "FLUX.1-schnell/tokenizer_2/tokenizer_config.json",
+                "FLUX.1-schnell/transformer/config.json",
+                "FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors",
+                "FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors",
+                "FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors",
+                "FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json",
+                "FLUX.1-schnell/vae/config.json",
+                "FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors",
+            ],
+        ),
+    ],
+)
+def test_select_flux_schnell_files(
+    flux_schnell_test_files: list[Path], variant: ModelRepoVariant, expected_files: list[str]
+) -> None:
+    filtered_files = filter_files(flux_schnell_test_files, variant)
+    assert set(filtered_files) == {Path(f) for f in expected_files}
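
Reading the Default-variant case: the delta between the fixture and the expected list is exactly what filter_files should drop for this repo, i.e. repo housekeeping files, the preview image, and the single-file weights (`ae.safetensors`, `flux1-schnell.safetensors`) that are not part of the diffusers-format layout. A quick sanity check of that delta, assuming the fixture paths and expected strings are bound to `all_files` and `expected`:

    dropped = set(all_files) - {Path(f) for f in expected}
    print(sorted(str(p) for p in dropped))
    # ['FLUX.1-schnell/.gitattributes', 'FLUX.1-schnell/README.md',
    #  'FLUX.1-schnell/ae.safetensors', 'FLUX.1-schnell/flux1-schnell.safetensors',
    #  'FLUX.1-schnell/schnell_grid.jpeg']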