# InvokeAI/invokeai/backend/install/model_install_backend.py
# (638 lines, 25 KiB, Python; captured from a git-blame view that ignores the
# revisions listed in .git-blame-ignore-revs)

"""
Utility (backend) functions used by model_install.py
"""
import os
import re
import shutil
import warnings
2023-07-28 13:46:44 +00:00
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable, Dict, List, Optional, Set, Union
import requests
import torch
from diffusers import DiffusionPipeline
2023-07-01 18:32:58 +00:00
from diffusers import logging as dlogging
from huggingface_hub import HfApi, HfFolder, hf_hub_url
from omegaconf import OmegaConf
from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
2023-08-18 15:13:28 +00:00
from invokeai.backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType
from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo, SchedulerPredictionType
from invokeai.backend.util import download_with_resume
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from ..util.logging import InvokeAILogger
# Silence all Python warnings for the installer process.
warnings.filterwarnings("ignore")

# --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(name="InvokeAI")

# The initial "configs" dir is now bundled in the `invokeai.configs` package;
# INITIAL_MODELS.yaml (the starter-model catalogue) lives inside it.
Dataset_path = Path(configs.__path__[0]).joinpath("INITIAL_MODELS.yaml")

Config_preamble = """
# This file describes the alternative machine learning models
# available to InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
"""

# Legacy YAML config file names for checkpoint-format models.
# SD-1 / SD-2 are indexed base -> variant -> scheduler prediction type;
# the SDXL entries have no prediction-type level and map the variant
# directly to a file name.
LEGACY_CONFIGS = {
    BaseModelType.StableDiffusion1: {
        ModelVariantType.Normal: {
            SchedulerPredictionType.Epsilon: "v1-inference.yaml",
            SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
        },
        ModelVariantType.Inpaint: {
            SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml",
            SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml",
        },
    },
    BaseModelType.StableDiffusion2: {
        ModelVariantType.Normal: {
            SchedulerPredictionType.Epsilon: "v2-inference.yaml",
            SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
        },
        ModelVariantType.Inpaint: {
            SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
            SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
        },
    },
    BaseModelType.StableDiffusionXL: {ModelVariantType.Normal: "sd_xl_base.yaml"},
    BaseModelType.StableDiffusionXLRefiner: {ModelVariantType.Normal: "sd_xl_refiner.yaml"},
}
@dataclass
class InstallSelections:
    """The user's installer choices: what to install and what to remove."""

    # identifiers handed to ModelInstall.heuristic_import (local paths, repo IDs or URLs)
    install_models: List[str] = field(default_factory=list)
    # model keys, as understood by ModelManager.parse_key, to delete
    remove_models: List[str] = field(default_factory=list)
@dataclass
class ModelLoadInfo:
    """Uniform description of one model, whether installed or merely available."""

    # model name, base model and type, as produced by ModelManager.parse_key
    name: str
    model_type: ModelType
    base_type: BaseModelType
    # where the model comes from: a local path, or a repo ID (optionally with a subfolder)
    path: Optional[Path] = None
    repo_id: Optional[str] = None
    subfolder: Optional[str] = None
    description: str = ""
    # status flags used by the installer UI
    installed: bool = False
    recommended: bool = False
    default: bool = False
    # paths / repo IDs of other models that must be installed alongside this one
    requires: Optional[List[str]] = field(default_factory=list)
class ModelInstall(object):
    """
    Install, remove and list models on behalf of the model_install frontend.

    Merges the starter-model catalogue from INITIAL_MODELS.yaml (``self.datasets``)
    with the models registered in the ModelManager (``self.mgr``). New models can be
    imported from local paths, HuggingFace repo IDs or URLs.
    """

    def __init__(
        self,
        config: InvokeAIAppConfig,
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
        model_manager: Optional[ModelManager] = None,
        access_token: Optional[str] = None,
        civitai_api_key: Optional[str] = None,
    ):
        """
        :param config: application configuration (paths, precision, API keys).
        :param prediction_type_helper: optional callback passed to the model probe
            for deciding a checkpoint's scheduler prediction type.
        :param model_manager: manager to install into. When omitted, one is built
            from ``config.model_conf_path``.
        :param access_token: HuggingFace token. Falls back to the locally stored token.
        :param civitai_api_key: CivitAI API key. Falls back to ``config.civitai_api_key``.
        """
        self.config = config
        self.mgr = model_manager or ModelManager(config.model_conf_path)
        self.datasets = OmegaConf.load(Dataset_path)
        self.prediction_helper = prediction_type_helper
        self.access_token = access_token or HfFolder.get_token()
        self.civitai_api_key = civitai_api_key or config.civitai_api_key
        self.reverse_paths = self._reverse_paths(self.datasets)
        # Set by heuristic_import() so that _make_attributes() can look up catalogue
        # descriptions. Initialised here so direct _install_path() calls don't fail.
        self.current_id = None

    def all_models(self) -> Dict[str, ModelLoadInfo]:
        """
        Return a dict of model_key => ModelLoadInfo.

        Consolidates the entries in both models.yaml and INITIAL_MODELS.yaml so
        they can be treated uniformly, and orders the result alphabetically by
        (case-insensitive) model name to improve the display.
        """
        model_dict = {}

        # first populate with the entries in INITIAL_MODELS.yaml
        for key, value in self.datasets.items():
            name, base, model_type = ModelManager.parse_key(key)
            value["name"] = name
            value["base_type"] = base
            value["model_type"] = model_type
            model_info = ModelLoadInfo(**value)
            if model_info.subfolder and model_info.repo_id:
                model_info.repo_id += f":{model_info.subfolder}"
            model_dict[key] = model_info

        # supplement with entries in models.yaml
        for md in self.mgr.list_models():
            base = md["base_model"]
            model_type = md["model_type"]
            name = md["model_name"]
            key = ModelManager.create_key(name, base, model_type)
            if key in model_dict:
                model_dict[key].installed = True
            else:
                model_dict[key] = ModelLoadInfo(
                    name=name,
                    base_type=base,
                    model_type=model_type,
                    # Use this model's own record. The previous code read the stale
                    # `value` left over from the INITIAL_MODELS loop above.
                    path=md.get("path"),
                    installed=True,
                )

        return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}

    def _is_autoloaded(self, model_info: dict) -> bool:
        """
        Return True if the model's path lies inside one of the configured
        auto-import directories (autoimport, lora, embedding or controlnet).
        """
        path = model_info.get("path")
        if not path:
            return False
        for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]:
            if autodir_path := getattr(self.config, autodir):
                autodir_path = self.config.root_path / autodir_path
                if Path(path).is_relative_to(autodir_path):
                    return True
        return False

    def list_models(self, model_type):
        """Print a table of the installed models of the given type to stdout."""
        installed = self.mgr.list_models(model_type=model_type)
        print()
        print(f"Installed models of type `{model_type}`:")
        print(f"{'Model Key':50} Model Path")
        for i in installed:
            print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}")
        print()

    # logic here a little reversed to maintain backward compatibility
    def starter_models(self, all_models: bool = False) -> Set[str]:
        """
        Return the catalogue keys offered as starter models.

        By default only main and VAE models are included. Pass all_models=True
        to include every catalogue entry.
        """
        models = set()
        for key in self.datasets.keys():
            _, _, model_type = ModelManager.parse_key(key)
            if all_models or model_type in [ModelType.Main, ModelType.Vae]:
                models.add(key)
        return models

    def recommended_models(self) -> Set[str]:
        """Return the catalogue keys flagged ``recommended`` in INITIAL_MODELS.yaml."""
        starters = self.starter_models(all_models=True)
        return {x for x in starters if self.datasets[x].get("recommended", False)}

    def default_model(self) -> str:
        """
        Return the first starter model flagged ``default``.

        Raises IndexError if no starter model carries the flag.
        """
        starters = self.starter_models()
        defaults = [x for x in starters if self.datasets[x].get("default", False)]
        return defaults[0]

    def install(self, selections: InstallSelections):
        """
        Apply the user's selections: delete the requested models, then install
        the requested ones (plus any models they require), and commit.

        ``selections.install_models`` is modified in place. Entries that are
        already installed are dropped, and missing requirements are appended.
        """
        verbosity = dlogging.get_verbosity()  # quench NSFW nags
        dlogging.set_verbosity_error()
        try:
            job = 1
            jobs = len(selections.remove_models) + len(selections.install_models)

            # remove requested models
            for key in selections.remove_models:
                name, base, mtype = self.mgr.parse_key(key)
                logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]")
                try:
                    self.mgr.del_model(name, base, mtype)
                except FileNotFoundError as e:
                    logger.warning(e)
                job += 1

            # add requested models
            self._remove_installed(selections.install_models)
            self._add_required_models(selections.install_models)
            # the install list may have shrunk or grown just above; keep the total accurate
            jobs = job - 1 + len(selections.install_models)
            for path in selections.install_models:
                logger.info(f"Installing {path} [{job}/{jobs}]")
                try:
                    self.heuristic_import(path)
                except (ValueError, KeyError) as e:
                    logger.error(str(e))
                job += 1
        finally:
            # always restore the caller's diffusers verbosity, even on an unexpected error
            dlogging.set_verbosity(verbosity)

        self.mgr.commit()

    def heuristic_import(
        self,
        model_path_id_or_url: Union[str, Path],
        models_installed: Optional[Dict[str, AddModelResult]] = None,
    ) -> Dict[str, AddModelResult]:
        """
        Import a model from a local path, a HuggingFace repo ID, or a URL.

        :param model_path_id_or_url: A Path to a local model to import, or a string
            representing its repo_id or URL.
        :param models_installed: Accumulator dict, used for recursive invocation.
        :returns: dict mapping each imported path/ID/URL to the result of adding it
            to models.yaml.
        :raises KeyError: if the argument is not recognised as a path, repo ID or URL.
        """
        # `is None`, not falsiness: recursive calls pass the (possibly still empty)
        # accumulator and must add to it rather than to a fresh dict.
        if models_installed is None:
            models_installed = {}

        model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")

        # A little hack to allow nested routines to retrieve info on the requested ID
        self.current_id = model_path_id_or_url
        path = Path(model_path_id_or_url)

        # fix relative paths
        if path.exists() and not path.is_absolute():
            path = path.absolute()  # make relative to current WD

        # checkpoint file, or similar
        if path.is_file():
            models_installed.update({str(path): self._install_path(path)})

        # folders style or similar
        elif path.is_dir() and any(
            (path / x).exists()
            for x in {
                "config.json",
                "model_index.json",
                "learned_embeds.bin",
                "pytorch_lora_weights.bin",
                "pytorch_lora_weights.safetensors",
            }
        ):
            models_installed.update({str(model_path_id_or_url): self._install_path(path)})

        # recursive scan
        elif path.is_dir():
            for child in path.iterdir():
                self.heuristic_import(child, models_installed=models_installed)

        # huggingface repo
        elif len(str(model_path_id_or_url).split("/")) == 2:
            models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})

        # a URL
        elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
            models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})

        else:
            raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping")

        return models_installed

    def _remove_installed(self, model_list: List[str]):
        """Drop, in place, the catalogue entries in model_list that are already installed."""
        all_models = self.all_models()
        models_to_remove = []
        for path in model_list:
            key = self.reverse_paths.get(path)
            if key and all_models[key].installed:
                models_to_remove.append(path)

        for path in models_to_remove:
            logger.warning(f"{path} already installed. Skipping")
            model_list.remove(path)

    def _add_required_models(self, model_list: List[str]):
        """
        Append to model_list, in place, each not-yet-installed requirement of the
        catalogue entries it contains.

        Requirements already queued are not added twice. Requirements that are not
        in the catalogue are reported and skipped instead of raising KeyError.
        """
        additional_models = []
        all_models = self.all_models()
        for path in model_list:
            if not (key := self.reverse_paths.get(path)):
                continue
            for requirement in all_models[key].requires or []:
                requirement_key = self.reverse_paths.get(requirement)
                if requirement_key is None:
                    logger.warning(f"{path} requires {requirement}, which is not a known starter model. Skipping it.")
                    continue
                if all_models[requirement_key].installed:
                    continue
                if requirement in model_list or requirement in additional_models:
                    continue
                additional_models.append(requirement)
        model_list.extend(additional_models)

    # install a model from a local path. The optional info parameter is there to prevent
    # the model from being probed twice in the event that it has already been probed.
    def _install_path(self, path: Path, info: ModelProbeInfo = None) -> Optional[AddModelResult]:
        """
        Register the model at ``path`` with the model manager.

        Returns None if the model format cannot be determined.
        Raises ValueError if a model of the same name, base and type is already installed.
        """
        info = info or ModelProbe().heuristic_probe(path, self.prediction_helper)
        if not info:
            logger.warning(f"Unable to parse format of {path}")
            return None
        model_name = path.stem if path.is_file() else path.name
        if self.mgr.model_exists(model_name, info.base_type, info.model_type):
            raise ValueError(f'A model named "{model_name}" is already installed.')
        attributes = self._make_attributes(path, info)
        return self.mgr.add_model(
            model_name=model_name,
            base_model=info.base_type,
            model_type=info.model_type,
            model_attributes=attributes,
        )

    def _install_url(self, url: str) -> Optional[AddModelResult]:
        """
        Download a model from ``url`` and install it.

        The CivitAI API key is sent only to civitai.com URLs. Returns None if the
        download fails or the downloaded file cannot be probed.
        """
        with TemporaryDirectory(dir=self.config.models_path) as staging:
            is_civitai = re.search(r"civitai\.com", url, re.IGNORECASE) is not None
            location = download_with_resume(
                url, Path(staging), access_token=self.civitai_api_key if is_civitai else None
            )
            if not location:
                logger.error(f"Unable to download {url}. Skipping.")
                return None
            info = ModelProbe().heuristic_probe(location, self.prediction_helper)
            if not info:
                logger.warning(f"Unable to parse format of {location}")
                return None
            dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
            dest.parent.mkdir(parents=True, exist_ok=True)
            models_path = shutil.move(location, dest)

        # staged version will be garbage-collected at this time
        return self._install_path(Path(models_path), info)

    def _install_repo(self, repo_id: str) -> AddModelResult:
        """
        Download a HuggingFace repo (optionally given as "owner/repo:subfolder")
        and install it.

        Downloads only the files needed for the detected model layout. Returns an
        empty dict if the layout cannot be determined or probed.
        """
        # hack to recover models stored in subfolders --
        # Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster
        subfolder = None
        if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id):
            repo_id = match.group(1)
            subfolder = match.group(2)

        hinfo = HfApi().model_info(repo_id)

        # we try to figure out how to download this most economically
        # list all the files in the repo
        files = [x.rfilename for x in hinfo.siblings]
        if subfolder:
            files = [x for x in files if x.startswith(f"{subfolder}/")]
        prefix = f"{subfolder}/" if subfolder else ""

        location = None

        with TemporaryDirectory(dir=self.config.models_path) as staging:
            staging = Path(staging)
            if f"{prefix}model_index.json" in files:
                location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder)  # pipeline
            elif f"{prefix}unet/model.onnx" in files:
                # `files` already carry the subfolder prefix, so none is passed here
                location = self._download_hf_model(repo_id, files, staging)
            else:
                for suffix in ["safetensors", "bin"]:
                    # `wanted` is kept separate from `files`, the list being searched
                    wanted = None
                    if f"{prefix}pytorch_lora_weights.{suffix}" in files:
                        wanted = [f"pytorch_lora_weights.{suffix}"]  # LoRA
                    elif (
                        self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files
                    ):  # vae, controlnet or some other standalone
                        wanted = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
                    elif f"{prefix}diffusion_pytorch_model.{suffix}" in files:
                        wanted = ["config.json", f"diffusion_pytorch_model.{suffix}"]
                    elif f"{prefix}learned_embeds.{suffix}" in files:
                        wanted = [f"learned_embeds.{suffix}"]
                    elif f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files:  # IP-Adapter
                        wanted = ["image_encoder.txt", f"ip_adapter.{suffix}"]
                    elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files:
                        # This condition is pretty fragile, but it is intended to handle CLIP Vision
                        # models hosted by InvokeAI for use with IP-Adapters.
                        wanted = ["config.json", f"model.{suffix}"]
                    if wanted:
                        location = self._download_hf_model(repo_id, wanted, staging, subfolder=subfolder)
                        break

            if not location:
                logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
                return {}

            info = ModelProbe().heuristic_probe(location, self.prediction_helper)
            if not info:
                logger.warning(f"Could not probe {location}. Skipping install.")
                return {}

            dest = (
                self.config.models_path
                / info.base_type.value
                / info.model_type.value
                / self._get_model_name(repo_id, location)
            )
            if dest.exists():
                shutil.rmtree(dest)
            # copy out of the staging dir before it is deleted on leaving the `with`
            shutil.copytree(location, dest)
            return self._install_path(dest, info)

    def _get_model_name(self, path_name: str, location: Path) -> str:
        """
        Calculate a name for the model - primitive implementation.

        Uses the catalogue name if path_name is a known catalogue entry, otherwise
        the directory name or the file stem of ``location``.
        """
        if key := self.reverse_paths.get(path_name):
            (name, _base, _mtype) = ModelManager.parse_key(key)
            return name
        elif location.is_dir():
            return location.name
        else:
            return location.stem

    def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict:
        """
        Build the models.yaml stanza attributes for the model at ``path``.

        The description comes from the catalogue when the current ID is a catalogue
        entry. Checkpoint-format main/ONNX and ControlNet models also get a legacy
        ``config`` entry. A sibling ``.yaml`` file is preferred; otherwise a config
        is picked from LEGACY_CONFIGS or the bundled controlnet configs.
        """
        model_name = path.name if path.is_dir() else path.stem
        description = f"{info.base_type.value} {info.model_type.value} model {model_name}"
        if key := self.reverse_paths.get(self.current_id):
            if key in self.datasets:
                description = self.datasets[key].get("description") or description

        rel_path = self.relative_to_root(path, self.config.models_path)

        attributes = {
            "path": str(rel_path),
            "description": str(description),
            "model_format": info.format,
        }
        legacy_conf = None
        if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
            attributes.update(
                {
                    "variant": info.variant_type,
                }
            )
            if info.format == "checkpoint":
                try:
                    possible_conf = path.with_suffix(".yaml")
                    if possible_conf.exists():
                        legacy_conf = str(self.relative_to_root(possible_conf))
                    elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
                        # SD-1/SD-2 tables are additionally keyed by prediction type
                        legacy_conf = Path(
                            self.config.legacy_conf_dir,
                            LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
                        )
                    else:
                        legacy_conf = Path(
                            self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]
                        )
                except KeyError:
                    legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml")  # best guess

        if info.model_type == ModelType.ControlNet and info.format == "checkpoint":
            possible_conf = path.with_suffix(".yaml")
            if possible_conf.exists():
                legacy_conf = str(self.relative_to_root(possible_conf))
            else:
                legacy_conf = Path(
                    self.config.root_path,
                    "configs/controlnet",
                    ("cldm_v15.yaml" if info.base_type == BaseModelType("sd-1") else "cldm_v21.yaml"),
                )

        if legacy_conf:
            attributes.update({"config": str(legacy_conf)})
        return attributes

    def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
        """Return ``path`` relative to ``root`` (default: the InvokeAI root) if it lies inside it, else unchanged."""
        root = root or self.config.root_path
        if path.is_relative_to(root):
            return path.relative_to(root)
        else:
            return path

    def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Optional[Path]:
        """
        Retrieve a StableDiffusion model from cache or remote, then call
        save_pretrained() into the indicated staging area.

        The fp16 variant is tried first when the device dtype is float16; otherwise
        the default variant is tried first. Returns the saved directory, or None if
        neither variant could be loaded.
        """
        _, name = repo_id.split("/")
        precision = torch_dtype(choose_torch_device())
        variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]

        model = None
        for variant in variants:
            try:
                model = DiffusionPipeline.from_pretrained(
                    repo_id,
                    variant=variant,
                    torch_dtype=precision,
                    safety_checker=None,
                    subfolder=subfolder,
                )
            except Exception as e:  # most errors are due to fp16 not being present. Fix this to catch other errors
                if "fp16" not in str(e):
                    logger.warning(f"Could not load {repo_id} (variant={variant}): {e}")
            if model is not None:
                break

        if model is None:
            logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
            return None
        model.save_pretrained(staging / name, safe_serialization=True)
        return staging / name

    def _download_hf_model(
        self, repo_id: str, files: List[str], staging: Path, subfolder: Optional[str] = None
    ) -> Optional[Path]:
        """
        Download the listed repo files into ``staging / <repo name>``.

        ``subfolder`` now defaults to None; it was previously a required argument
        that the ONNX branch of _install_repo did not pass. Returns the staging
        directory if at least one file was downloaded, otherwise None.
        """
        _, name = repo_id.split("/")
        location = staging / name
        paths = []
        for filename in files:
            file_path = Path(filename)
            p = hf_download_with_resume(
                repo_id,
                model_dir=location / file_path.parent,
                model_name=file_path.name,
                access_token=self.access_token,
                subfolder=file_path.parent / subfolder if subfolder else file_path.parent,
            )
            if p:
                paths.append(p)
            else:
                logger.warning(f"Could not download {filename} from {repo_id}.")

        return location if len(paths) > 0 else None

    @classmethod
    def _reverse_paths(cls, datasets) -> dict:
        """
        Reverse mapping from repo_id/path to destination name (catalogue key).
        """
        return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()}
# -------------------------------------
def yes_or_no(prompt: str, default_yes=True):
    """
    Ask a yes/no question on the console and return the answer as a bool.

    An empty reply selects the default. When the default is "yes", any reply
    not starting with "n"/"N" counts as yes. When the default is "no", only a
    reply starting with "y"/"Y" counts as yes.
    """
    fallback = "y" if default_yes else "n"
    answer = input(f"{prompt} [{fallback}] ")
    if not answer:
        answer = fallback
    first_char = answer[0]
    if not default_yes:
        return first_char in ("y", "Y")
    return first_char not in ("n", "N")
# ---------------------------------------------
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
    """
    Load ``model_name`` via ``model_class.from_pretrained`` (with resumable
    download) and save it to ``destination`` in safetensors format.

    :param model_class: any class exposing ``from_pretrained``; the instance must
        expose ``save_pretrained``.
    :param model_name: repo ID or path passed to ``from_pretrained``.
    :param destination: directory to save the model into.
    :param kwargs: extra keyword arguments forwarded to ``from_pretrained``.
    :returns: ``destination``.
    """
    logger = InvokeAILogger.get_logger("InvokeAI")

    def _drop_fp16_nag(record) -> bool:
        return "fp16 is not a valid" not in record.getMessage()

    # Install the filter only for the duration of the download. Previously a new
    # filter was added to the shared logger on every call and never removed.
    logger.addFilter(_drop_fp16_nag)
    try:
        model = model_class.from_pretrained(
            model_name,
            resume_download=True,
            **kwargs,
        )
    finally:
        logger.removeFilter(_drop_fp16_nag)
    model.save_pretrained(destination, safe_serialization=True)
    return destination
# ---------------------------------------------
def hf_download_with_resume(
    repo_id: str,
    model_dir: str,
    model_name: str,
    model_dest: Path = None,
    access_token: str = None,
    subfolder: str = None,
) -> Path:
    """
    Download one file from a HuggingFace repo, resuming a partial download if present.

    :param repo_id: HuggingFace repo ID ("owner/name").
    :param model_dir: local directory to download into (created if needed).
    :param model_name: file name within the repo (and, by default, on disk).
    :param model_dest: explicit local destination; defaults to model_dir/model_name.
    :param access_token: optional bearer token sent with the request.
    :param subfolder: optional subfolder of the repo containing the file.
    :returns: the local path on success, or None on failure.
    """
    model_dest = model_dest or Path(os.path.join(model_dir, model_name))
    os.makedirs(model_dir, exist_ok=True)

    url = hf_hub_url(repo_id, model_name, subfolder=subfolder)

    header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
    open_mode = "wb"
    exist_size = 0

    if os.path.exists(model_dest):
        exist_size = os.path.getsize(model_dest)
        header["Range"] = f"bytes={exist_size}-"
        open_mode = "ab"

    # `with` ensures the streamed connection is released on every return path
    with requests.get(url, headers=header, stream=True) as resp:
        total = int(resp.headers.get("content-length", 0))

        if resp.status_code == 416:  # "range not satisfiable", which means nothing to return
            logger.info(f"{model_name}: complete file found. Skipping.")
            return model_dest
        elif resp.status_code == 404:
            logger.warning("File not found")
            return None
        elif resp.status_code not in (200, 206):
            # Previously the error body was written into the model file and reported as success.
            logger.warning(f"{model_name}: {resp.reason}")
            return None
        elif exist_size > 0 and resp.status_code == 200:
            # The server ignored the Range header and is sending the whole file;
            # appending it would corrupt the partial download, so start over.
            logger.info(f"{model_name}: server does not support resuming. Restarting download...")
            exist_size = 0
            open_mode = "wb"
        elif exist_size > 0:
            logger.info(f"{model_name}: partial file found. Resuming...")
        else:
            logger.info(f"{model_name}: Downloading...")

        try:
            with (
                open(model_dest, open_mode) as file,
                tqdm(
                    desc=model_name,
                    initial=exist_size,
                    total=total + exist_size,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1000,
                ) as bar,
            ):
                for data in resp.iter_content(chunk_size=1024):
                    size = file.write(data)
                    bar.update(size)
        except Exception as e:
            logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
            return None
    return model_dest