fix broken !import_model downloads

1. Now works with sites that produce lots of redirects, such as CIVITAI
2. Derive name of destination model file from HTTP Content-Disposition header,
   if present.
3. Swap \\ for / in file paths provided by users, to hopefully fix issues with
   Windows.
This commit is contained in:
Lincoln Stein
2023-02-13 22:14:24 -05:00
parent 15a9412255
commit d38e7170fe
3 changed files with 210 additions and 119 deletions

View File

@ -36,8 +36,8 @@ from ldm.invoke.generator.diffusers_pipeline import \
StableDiffusionGeneratorPipeline
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
global_models_dir)
from ldm.util import (ask_user, download_with_progress_bar,
instantiate_from_config)
from ldm.util import (ask_user, download_with_resume,
url_attachment_name, instantiate_from_config)
DEFAULT_MAX_MODELS = 2
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
@ -670,15 +670,18 @@ class ModelManager(object):
path to the configuration file, then the new entry will be committed to the
models.yaml file.
"""
if str(weights).startswith(("http:", "https:")):
model_name = model_name or url_attachment_name(weights)
weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
config_path = self._resolve_path(config, "configs/stable-diffusion")
config_path = self._resolve_path(config, "configs/stable-diffusion")
if weights_path is None or not weights_path.exists():
return False
if config_path is None or not config_path.exists():
return False
model_name = model_name or Path(weights).stem
model_name = model_name or Path(weights).stem # note this gives ugly pathnames if used on a URL without a Content-Disposition header
model_description = (
model_description or f"imported stable diffusion weights file {model_name}"
)
@ -748,7 +751,6 @@ class ModelManager(object):
into models.yaml.
"""
new_config = None
import transformers
from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
@ -967,16 +969,15 @@ class ModelManager(object):
print("** Migration is done. Continuing...")
def _resolve_path(
self, source: Union[str, Path], dest_directory: str
self, source: Union[str, Path], dest_directory: str
) -> Optional[Path]:
resolved_path = None
if str(source).startswith(("http:", "https:", "ftp:")):
basename = os.path.basename(source)
if not os.path.isabs(dest_directory):
dest_directory = os.path.join(Globals.root, dest_directory)
dest = os.path.join(dest_directory, basename)
if download_with_progress_bar(str(source), Path(dest)):
resolved_path = Path(dest)
dest_directory = Path(dest_directory)
if not dest_directory.is_absolute():
dest_directory = Globals.root / dest_directory
dest_directory.mkdir(parents=True, exist_ok=True)
resolved_path = download_with_resume(str(source), dest_directory)
else:
if not os.path.isabs(source):
source = os.path.join(Globals.root, source)