add back the heuristic_import() method and extend repo_ids to arbitrary file paths

This commit is contained in:
Lincoln Stein
2024-02-11 23:37:49 -05:00
committed by psychedelicious
parent a23dedd2ee
commit 4027e845d4
6 changed files with 199 additions and 12 deletions

View File

@ -36,6 +36,11 @@ def filter_files(
"""
variant = variant or ModelRepoVariant.DEFAULT
paths: List[Path] = []
root = files[0].parts[0]
# if the subfolder is a single file, then bypass the selection and just return it
if subfolder and subfolder.suffix in [".safetensors", ".bin", ".onnx", ".xml", ".pth", ".pt", ".ckpt", ".msgpack"]:
return [root / subfolder]
# Start by filtering on model file extensions, discarding images, docs, etc
for file in files:
@ -61,6 +66,7 @@ def filter_files(
# limit search to subfolder if requested
if subfolder:
subfolder = root / subfolder
paths = [x for x in paths if x.parent == Path(subfolder)]
# _filter_by_variant uniquifies the paths and returns a set