diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index 931e17a3a1..8f971534f7 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -4,6 +4,10 @@ import sys import shlex import traceback +from argparse import Namespace +from pathlib import Path +from typing import Optional, Union + if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" @@ -16,10 +20,10 @@ from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata from ldm.invoke.image_util import make_grid from ldm.invoke.log import write_log from ldm.invoke.model_manager import ModelManager -from pathlib import Path -from argparse import Namespace -import pyparsing + +import click # type: ignore import ldm.invoke +import pyparsing # type: ignore # global used in multiple functions (fix) infile = None @@ -69,7 +73,7 @@ def main(): # these two lines prevent a horrible warning message from appearing # when the frozen CLIP tokenizer is imported - import transformers + import transformers # type: ignore transformers.logging.set_verbosity_error() import diffusers diffusers.logging.set_verbosity_error() @@ -572,7 +576,7 @@ def set_default_output_dir(opt:Args, completer:Completer): completer.set_default_dir(opt.outdir) -def import_model(model_path:str, gen, opt, completer): +def import_model(model_path: str, gen, opt, completer): ''' model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or (3) a huggingface repository id @@ -581,12 +585,28 @@ def import_model(model_path:str, gen, opt, completer): if model_path.startswith(('http:','https:','ftp:')): model_name = import_ckpt_model(model_path, gen, opt, completer) + elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path): model_name = import_ckpt_model(model_path, gen, opt, completer) - elif re.match('^[\w.+-]+/[\w.+-]+$',model_path): - model_name = import_diffuser_model(model_path, gen, opt, completer) + elif os.path.isdir(model_path): - model_name = import_diffuser_model(Path(model_path), gen, opt, completer) + + # Allow for a directory containing multiple models. + models = list(Path(model_path).rglob('*.ckpt')) + list(Path(model_path).rglob('*.safetensors')) + + if models: + # Only the last model name will be used below. + for model in sorted(models): + + if click.confirm(f'Import {model.stem} ?', default=True): + model_name = import_ckpt_model(model, gen, opt, completer) + print() + else: + model_name = import_diffuser_model(Path(model_path), gen, opt, completer) + + elif re.match(r'^[\w.+-]+/[\w.+-]+$', model_path): + model_name = import_diffuser_model(model_path, gen, opt, completer) + else: print(f'** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can\'t import.') @@ -604,7 +624,7 @@ def import_model(model_path:str, gen, opt, completer): completer.update_models(gen.model_manager.list_models()) print(f'>> {model_name} successfully installed') -def import_diffuser_model(path_or_repo:str, gen, opt, completer)->str: +def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> Optional[str]: manager = gen.model_manager default_name = Path(path_or_repo).stem default_description = f'Imported model {default_name}' @@ -627,7 +647,7 @@ def import_diffuser_model(path_or_repo:str, gen, opt, completer)->str: return None return model_name -def import_ckpt_model(path_or_url:str, gen, opt, completer)->str: +def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]: manager = gen.model_manager default_name = Path(path_or_url).stem default_description = f'Imported model {default_name}' diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index 4707565424..c9c7ffe3b0 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -763,6 +763,7 @@ class Args(object): !models -- list models in configs/models.yaml !switch -- switch to model named !import_model /path/to/weights/file.ckpt -- adds a .ckpt model to your config + !import_model /path/to/weights/ -- interactively import models from a directory !import_model http://path_to_model.ckpt -- downloads and adds a .ckpt model to your config !import_model hakurei/waifu-diffusion -- downloads and adds a diffusers model to your config !optimize_model -- converts a .ckpt model to a diffusers model diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py index b9421255a2..43854d7938 100644 --- a/ldm/invoke/model_manager.py +++ b/ldm/invoke/model_manager.py @@ -18,7 +18,7 @@ import warnings import safetensors.torch from pathlib import Path from shutil import move, rmtree -from typing import Union, Any +from typing import Any, Optional, Union from huggingface_hub import scan_cache_dir from ldm.util import download_with_progress_bar @@ -880,14 +880,14 @@ class ModelManager(object): print('** Migration is done. Continuing...') - def _resolve_path(self, source:Union[str,Path], dest_directory:str)->Path: + def _resolve_path(self, source: Union[str, Path], dest_directory: str) -> Optional[Path]: resolved_path = None - if source.startswith(('http:','https:','ftp:')): + if str(source).startswith(('http:','https:','ftp:')): basename = os.path.basename(source) if not os.path.isabs(dest_directory): dest_directory = os.path.join(Globals.root,dest_directory) dest = os.path.join(dest_directory,basename) - if download_with_progress_bar(source,dest): + if download_with_progress_bar(str(source), Path(dest)): resolved_path = Path(dest) else: if not os.path.isabs(source): diff --git a/ldm/util.py b/ldm/util.py index 7d44dcd266..447875537f 100644 --- a/ldm/util.py +++ b/ldm/util.py @@ -284,9 +284,9 @@ class ProgressBar(): def download_with_progress_bar(url:str, dest:Path)->bool: try: - if not os.path.exists(dest): - os.makedirs((os.path.dirname(dest) or '.'), exist_ok=True) - request.urlretrieve(url,dest,ProgressBar(os.path.basename(dest))) + if not dest.exists(): + dest.parent.mkdir(parents=True, exist_ok=True) + request.urlretrieve(url,dest,ProgressBar(dest.stem)) return True else: return True diff --git a/pyproject.toml b/pyproject.toml index 18df1e2e2f..f3dfa69b91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ classifiers = [ dependencies = [ "accelerate", "albumentations", + "click", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "datasets", "diffusers[torch]~=0.11",