fix installer logic for tokenizer_3 and text_encoder_3

This commit is contained in:
Lincoln Stein 2024-06-21 23:34:18 -04:00
parent 28f1d25973
commit 39881d3d7d

View File

@ -43,7 +43,9 @@ def filter_files(
# downloading random checkpoints that might also be in the repo. However, there is no guarantee
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
# will adhere to this naming convention, so this is an area that needs care.
DIFFUSERS_COMPONENT_PATTERN = r"model(-\d+-of-\d+)?(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
DIFFUSERS_COMPONENT_PATTERN = (
r"model(-fp16)?(-\d+-of-\d+)?(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
)
variant = variant or ModelRepoVariant.Default
paths: List[Path] = []
@ -75,7 +77,6 @@ def filter_files(
if subfolder:
subfolder = root / subfolder
paths = [x for x in paths if x.parent == Path(subfolder)]
# _filter_by_variant uniquifies the paths and returns a set
return sorted(_filter_by_variant(paths, variant))
@ -103,9 +104,22 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
if variant == ModelRepoVariant.Flax:
result.add(path)
elif path.suffix in [".json", ".txt"]:
elif path.suffix in [".json", ".txt", ".model"]:
result.add(path)
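# Handle sharded safetensors weights: "model.fp16-<i>-of-<n>.safetensors" is kept only for the FP16
# variant, and "model-<i>-of-<n>.safetensors" is kept only for the FP32 and Default variants.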
elif re.match(r"model\.fp16-\d+-of-\d+\.safetensors", path.name):
if variant is ModelRepoVariant.FP16:
result.add(path)
else:
continue
elif re.match(r"model-\d+-of-\d+\.safetensors", path.name):
if variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]:
result.add(path)
else:
continue
elif variant in [
ModelRepoVariant.FP16,
ModelRepoVariant.FP32,
@ -129,6 +143,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
score += 1
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
candidate_variant_label, *_ = str(candidate_variant_label).split("-") # handle shard pattern
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
@ -145,6 +160,8 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
else:
continue
print(subfolder_weights)
for candidate_list in subfolder_weights.values():
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate: