bring in url download bugfix from PR 2630

This commit is contained in:
Lincoln Stein
2023-02-16 00:37:17 -05:00
4 changed files with 359 additions and 174 deletions

View File

@ -32,14 +32,12 @@ from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ldm.invoke.globals import (
Globals,
global_autoscan_dir,
global_cache_dir,
global_models_dir,
)
from ldm.util import ask_user, download_with_progress_bar, instantiate_from_config
from ldm.invoke.generator.diffusers_pipeline import \
StableDiffusionGeneratorPipeline
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
global_models_dir)
from ldm.util import (ask_user, download_with_resume,
url_attachment_name, instantiate_from_config)
DEFAULT_MAX_MODELS = 2
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
@ -673,15 +671,18 @@ class ModelManager(object):
path to the configuration file, then the new entry will be committed to the
models.yaml file.
"""
if str(weights).startswith(("http:", "https:")):
model_name = model_name or url_attachment_name(weights)
weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
config_path = self._resolve_path(config, "configs/stable-diffusion")
config_path = self._resolve_path(config, "configs/stable-diffusion")
if weights_path is None or not weights_path.exists():
return False
if config_path is None or not config_path.exists():
return False
model_name = model_name or Path(weights).stem
model_name = model_name or Path(weights).stem # note this gives ugly pathnames if used on a URL without a Content-Disposition header
model_description = (
model_description or f"imported stable diffusion weights file {model_name}"
)
@ -1042,16 +1043,15 @@ class ModelManager(object):
print("** Migration is done. Continuing...")
def _resolve_path(
self, source: Union[str, Path], dest_directory: str
self, source: Union[str, Path], dest_directory: str
) -> Optional[Path]:
resolved_path = None
if str(source).startswith(("http:", "https:", "ftp:")):
basename = os.path.basename(source)
if not os.path.isabs(dest_directory):
dest_directory = os.path.join(Globals.root, dest_directory)
dest = os.path.join(dest_directory, basename)
if download_with_progress_bar(str(source), Path(dest)):
resolved_path = Path(dest)
dest_directory = Path(dest_directory)
if not dest_directory.is_absolute():
dest_directory = Globals.root / dest_directory
dest_directory.mkdir(parents=True, exist_ok=True)
resolved_path = download_with_resume(str(source), dest_directory)
else:
if not os.path.isabs(source):
source = os.path.join(Globals.root, source)