Multiple refinements on loaders:

- Cache stat collection enabled.
- Implemented ONNX loading.
- Add ability to specify the repo version variant in installer CLI.
- If caller asks for a repo version that doesn't exist, will fall back
  to empty version rather than raising an error.
This commit is contained in:
Lincoln Stein
2024-02-05 21:55:11 -05:00
committed by psychedelicious
parent 0d3addc69b
commit 5745ce9c7d
18 changed files with 215 additions and 49 deletions

View File

@ -41,13 +41,21 @@ def filter_files(
for file in files:
if file.name.endswith((".json", ".txt")):
paths.append(file)
elif file.name.endswith(("learned_embeds.bin", "ip_adapter.bin", "lora_weights.safetensors")):
elif file.name.endswith(
(
"learned_embeds.bin",
"ip_adapter.bin",
"lora_weights.safetensors",
"weights.pb",
"onnx_data",
)
):
paths.append(file)
# BRITTLENESS WARNING!!
# Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
# downloading random checkpoints that might also be in the repo. However there is no guarantee
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
# will adhere to this naming convention, so this is an area of brittleness.
# will adhere to this naming convention, so this is an area to be careful of.
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
paths.append(file)
@ -64,7 +72,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
result = set()
basenames: Dict[Path, Path] = {}
for path in files:
if path.suffix == ".onnx":
if path.suffix in [".onnx", ".pb", ".onnx_data"]:
if variant == ModelRepoVariant.ONNX:
result.add(path)