use variant instead of revision

This commit is contained in:
Lincoln Stein 2023-08-03 19:23:52 -04:00
parent f4981f26d5
commit ab5d938a1d

View File

@ -419,18 +419,24 @@ class ModelInstall(object):
"""
_, name = repo_id.split("/")
precision = torch_dtype(choose_torch_device())
revisions = ["fp16", "main"] if precision == torch.float16 else ["main"]
variants = ["fp16",None] if precision == torch.float16 else [None,"fp16"]
model = None
for revision in revisions:
for variant in variants:
try:
model = DiffusionPipeline.from_pretrained(
repo_id, revision=revision, safety_checker=None, torch_dtype=precision
repo_id,
variant=variant,
torch_dtype=precision,
safety_checker=None,
)
except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors
if "fp16" not in str(e):
print(e)
if model:
break
if not model:
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
return None