tweak installer to select correct components of HF SD3 diffusers models

This commit is contained in:
Lincoln Stein 2024-06-14 16:46:24 -04:00
parent 03b9d17d0b
commit 78f704e7d5

View File

@ -35,6 +35,16 @@ def filter_files(
The file list can be obtained from the `files` field of HuggingFaceMetadata,
as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
"""
# BRITTLENESS WARNING!!
# The following pattern is designed to match model files that are components of diffusers submodels,
# but not to match other random stuff found in huggingface repos.
# Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
# downloading random checkpoints that might also be in the repo. However there is no guarantee
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
# will adhere to this naming convention, so this is an area to be careful of.
DIFFUSERS_COMPONENT_PATTERN = r"model(-\d+-of-\d+)?(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
variant = variant or ModelRepoVariant.Default
paths: List[Path] = []
root = files[0].parts[0]
@ -45,24 +55,20 @@ def filter_files(
# Start by filtering on model file extensions, discarding images, docs, etc
for file in files:
if file.name.endswith((".json", ".txt")):
paths.append(file)
elif file.name.endswith(
if file.name.endswith(
(
".json",
".txt",
"learned_embeds.bin",
"ip_adapter.bin",
"lora_weights.safetensors",
"weights.pb",
"onnx_data",
"spiece.model",
)
):
paths.append(file)
# BRITTLENESS WARNING!!
# Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
# downloading random checkpoints that might also be in the repo. However there is no guarantee
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
# will adhere to this naming convention, so this is an area to be careful of.
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
elif re.search(DIFFUSERS_COMPONENT_PATTERN, file.name):
paths.append(file)
# limit search to subfolder if requested