diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index 2ffc1e6ff4..d412d0226f 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -104,12 +104,14 @@ class ModelInstall(object): prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, model_manager: Optional[ModelManager] = None, access_token: Optional[str] = None, + civit_api_key: Optional[str] = None, ): self.config = config self.mgr = model_manager or ModelManager(config.model_conf_path) self.datasets = OmegaConf.load(Dataset_path) self.prediction_helper = prediction_type_helper self.access_token = access_token or HfFolder.get_token() + self.civit_api_key = civit_api_key or os.environ.get("CIVIT_API_KEY") self.reverse_paths = self._reverse_paths(self.datasets) def all_models(self) -> Dict[str, ModelLoadInfo]: @@ -326,7 +328,10 @@ class ModelInstall(object): def _install_url(self, url: str) -> AddModelResult: with TemporaryDirectory(dir=self.config.models_path) as staging: - location = download_with_resume(url, Path(staging)) + CIVITAI_RE = r".*civitai.com.*" + civit_url = re.match(CIVITAI_RE, url, re.IGNORECASE) + print(civit_url) + location = download_with_resume(url, Path(staging), access_token=self.civit_api_key if civit_url else None) if not location: logger.error(f"Unable to download {url}. Skipping.") info = ModelProbe().heuristic_probe(location, self.prediction_helper) diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 4612b42cb9..85a4488915 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -286,9 +286,8 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path open_mode = "wb" exist_size = 0 - resp = requests.get(url, header, stream=True) + resp = requests.get(url, headers=header, stream=True, allow_redirects=True) content_length = int(resp.headers.get("content-length", 0)) - if dest.is_dir(): try: file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)