Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
fix installer logic for tokenizer_3 and text_encoder_3
This commit is contained in:
parent 28f1d25973
commit 39881d3d7d
@@ -43,7 +43,9 @@ def filter_files(
     # downloading random checkpoints that might also be in the repo. However there is no guarantee
     # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
     # will adhere to this naming convention, so this is an area to be careful of.
-    DIFFUSERS_COMPONENT_PATTERN = r"model(-\d+-of-\d+)?(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
+    DIFFUSERS_COMPONENT_PATTERN = (
+        r"model(-fp16)?(-\d+-of-\d+)?(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
+    )
 
     variant = variant or ModelRepoVariant.Default
     paths: List[Path] = []
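The widened pattern adds an optional "-fp16" group so that hyphen-named fp16 shards are also recognized as diffusers component weights. A quick standalone check of what the new regex accepts (the filenames are hypothetical shard/variant names, and re.search is used here only for illustration, not as the installer's actual call site):

```python
import re

# Standalone sanity check of the widened pattern above; the filenames are
# hypothetical shard/variant names, not taken from any specific repo.
DIFFUSERS_COMPONENT_PATTERN = (
    r"model(-fp16)?(-\d+-of-\d+)?(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
)

candidates = [
    "model.safetensors",                      # plain single-file weight
    "model.fp16.safetensors",                 # fp16 variant, dot-separated
    "model-00001-of-00002.safetensors",       # sharded fp32 weight
    "model-fp16-00001-of-00002.safetensors",  # sharded fp16 weight (matched by the new "-fp16" group)
    "training_args.bin",                      # unrelated checkpoint, should not match
]

for name in candidates:
    print(f"{name}: {bool(re.search(DIFFUSERS_COMPONENT_PATTERN, name))}")
```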
@@ -75,7 +77,6 @@ def filter_files(
     if subfolder:
         subfolder = root / subfolder
         paths = [x for x in paths if x.parent == Path(subfolder)]
-
     # _filter_by_variant uniquifies the paths and returns a set
     return sorted(_filter_by_variant(paths, variant))
 
@@ -103,9 +104,22 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
             if variant == ModelRepoVariant.Flax:
                 result.add(path)
 
-        elif path.suffix in [".json", ".txt"]:
+        elif path.suffix in [".json", ".txt", ".model"]:
             result.add(path)
 
+        # handle shard patterns
+        elif re.match(r"model\.fp16-\d+-of-\d+\.safetensors", path.name):
+            if variant is ModelRepoVariant.FP16:
+                result.add(path)
+            else:
+                continue
+
+        elif re.match(r"model-\d+-of-\d+\.safetensors", path.name):
+            if variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]:
+                result.add(path)
+            else:
+                continue
+
         elif variant in [
             ModelRepoVariant.FP16,
             ModelRepoVariant.FP32,
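The new branches route sharded weights by filename: "model.fp16-NNNNN-of-NNNNN.safetensors" shards are kept only when the FP16 variant is requested, while plain "model-NNNNN-of-NNNNN.safetensors" shards are kept for FP32 or the default variant. A minimal sketch of that branching, with ModelRepoVariant reduced to plain strings purely for illustration:

```python
import re

# Minimal sketch of the shard-name branching added above, with ModelRepoVariant
# reduced to plain strings ("fp16", "fp32", "default") purely for illustration.
def keeps_shard(name: str, variant: str) -> bool:
    if re.match(r"model\.fp16-\d+-of-\d+\.safetensors", name):
        return variant == "fp16"
    if re.match(r"model-\d+-of-\d+\.safetensors", name):
        return variant in ("fp32", "default")
    return False

assert keeps_shard("model.fp16-00001-of-00002.safetensors", "fp16")
assert not keeps_shard("model.fp16-00001-of-00002.safetensors", "fp32")
assert keeps_shard("model-00001-of-00002.safetensors", "default")
assert not keeps_shard("model-00001-of-00002.safetensors", "fp16")
```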
@@ -129,6 +143,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
                 score += 1
 
             candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
+            candidate_variant_label, *_ = str(candidate_variant_label).split("-")  # handle shard pattern
 
             # Some special handling is needed here if there is not an exact match and if we cannot infer the variant
             # from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
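The added line strips the shard counter off the first suffix so the later variant comparison sees only the variant label (e.g. ".fp16"). A small illustration with a hypothetical sharded path:

```python
from pathlib import Path

# Illustration of the added "handle shard pattern" line: for a sharded fp16 file the
# first suffix carries the shard counter, and splitting on "-" leaves just the variant
# label. The path below is a hypothetical example.
path = Path("text_encoder_3/model.fp16-00001-of-00002.safetensors")

candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
print(candidate_variant_label)  # .fp16-00001-of-00002

candidate_variant_label, *_ = str(candidate_variant_label).split("-")  # handle shard pattern
print(candidate_variant_label)  # .fp16
```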
@@ -145,6 +160,8 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
         else:
             continue
 
+    print(subfolder_weights)
+
     for candidate_list in subfolder_weights.values():
         highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
         if highest_score_candidate:
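After scoring, the loop keeps the highest-scoring weight file per subfolder. A sketch of that selection, assuming SubfolderCandidate is a simple path/score record (the paths and scores below are made up):

```python
from dataclasses import dataclass
from pathlib import Path

# Sketch of the per-subfolder selection in the loop above, assuming SubfolderCandidate
# is a simple path/score record; the paths and scores are made up for illustration.
@dataclass
class SubfolderCandidate:
    path: Path
    score: int

subfolder_weights = {
    Path("text_encoder_3"): [
        SubfolderCandidate(Path("text_encoder_3/model-00001-of-00002.safetensors"), score=2),
        SubfolderCandidate(Path("text_encoder_3/model.fp16-00001-of-00002.safetensors"), score=1),
    ],
}

result: set[Path] = set()
for candidate_list in subfolder_weights.values():
    highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
    if highest_score_candidate:
        result.add(highest_score_candidate.path)

print(result)  # keeps the higher-scoring candidate for text_encoder_3
```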