mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix installer logic for tokenizer_3 and text_encoder_3
This commit is contained in:
parent
28f1d25973
commit
39881d3d7d
@ -43,7 +43,9 @@ def filter_files(
|
||||
# downloading random checkpoints that might also be in the repo. However there is no guarantee
|
||||
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
|
||||
# will adhere to this naming convention, so this is an area to be careful of.
|
||||
DIFFUSERS_COMPONENT_PATTERN = r"model(-\d+-of-\d+)?(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
|
||||
DIFFUSERS_COMPONENT_PATTERN = (
|
||||
r"model(-fp16)?(-\d+-of-\d+)?(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
|
||||
)
|
||||
|
||||
variant = variant or ModelRepoVariant.Default
|
||||
paths: List[Path] = []
|
||||
@ -75,7 +77,6 @@ def filter_files(
|
||||
if subfolder:
|
||||
subfolder = root / subfolder
|
||||
paths = [x for x in paths if x.parent == Path(subfolder)]
|
||||
|
||||
# _filter_by_variant uniquifies the paths and returns a set
|
||||
return sorted(_filter_by_variant(paths, variant))
|
||||
|
||||
@ -103,9 +104,22 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
if variant == ModelRepoVariant.Flax:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".json", ".txt"]:
|
||||
elif path.suffix in [".json", ".txt", ".model"]:
|
||||
result.add(path)
|
||||
|
||||
# handle shard patterns
|
||||
elif re.match(r"model\.fp16-\d+-of-\d+\.safetensors", path.name):
|
||||
if variant is ModelRepoVariant.FP16:
|
||||
result.add(path)
|
||||
else:
|
||||
continue
|
||||
|
||||
elif re.match(r"model-\d+-of-\d+\.safetensors", path.name):
|
||||
if variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]:
|
||||
result.add(path)
|
||||
else:
|
||||
continue
|
||||
|
||||
elif variant in [
|
||||
ModelRepoVariant.FP16,
|
||||
ModelRepoVariant.FP32,
|
||||
@ -129,6 +143,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
score += 1
|
||||
|
||||
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
|
||||
candidate_variant_label, *_ = str(candidate_variant_label).split("-") # handle shard pattern
|
||||
|
||||
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
|
||||
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
|
||||
@ -145,6 +160,8 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
else:
|
||||
continue
|
||||
|
||||
print(subfolder_weights)
|
||||
|
||||
for candidate_list in subfolder_weights.values():
|
||||
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
|
||||
if highest_score_candidate:
|
||||
|
Loading…
Reference in New Issue
Block a user