Allow multiple models to be imported by passing a directory. (#2529)

This change allows passing a directory with multiple models in it to be
imported.

Ensures that diffusers directories will still work.

Fixed up some minor type issues.
This commit is contained in:
Lincoln Stein 2023-02-05 13:36:00 -05:00 committed by GitHub
commit cd5c112fcd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 39 additions and 17 deletions

View File

@ -4,6 +4,10 @@ import sys
import shlex import shlex
import traceback import traceback
from argparse import Namespace
from pathlib import Path
from typing import Optional, Union
if sys.platform == "darwin": if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@ -16,10 +20,10 @@ from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.image_util import make_grid from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log from ldm.invoke.log import write_log
from ldm.invoke.model_manager import ModelManager from ldm.invoke.model_manager import ModelManager
from pathlib import Path
from argparse import Namespace import click # type: ignore
import pyparsing
import ldm.invoke import ldm.invoke
import pyparsing # type: ignore
# global used in multiple functions (fix) # global used in multiple functions (fix)
infile = None infile = None
@ -69,7 +73,7 @@ def main():
# these two lines prevent a horrible warning message from appearing # these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported # when the frozen CLIP tokenizer is imported
import transformers import transformers # type: ignore
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
import diffusers import diffusers
diffusers.logging.set_verbosity_error() diffusers.logging.set_verbosity_error()
@ -572,7 +576,7 @@ def set_default_output_dir(opt:Args, completer:Completer):
completer.set_default_dir(opt.outdir) completer.set_default_dir(opt.outdir)
def import_model(model_path:str, gen, opt, completer): def import_model(model_path: str, gen, opt, completer):
''' '''
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or
(3) a huggingface repository id (3) a huggingface repository id
@ -581,12 +585,28 @@ def import_model(model_path:str, gen, opt, completer):
if model_path.startswith(('http:','https:','ftp:')): if model_path.startswith(('http:','https:','ftp:')):
model_name = import_ckpt_model(model_path, gen, opt, completer) model_name = import_ckpt_model(model_path, gen, opt, completer)
elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path): elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path):
model_name = import_ckpt_model(model_path, gen, opt, completer) model_name = import_ckpt_model(model_path, gen, opt, completer)
elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
model_name = import_diffuser_model(model_path, gen, opt, completer)
elif os.path.isdir(model_path): elif os.path.isdir(model_path):
# Allow for a directory containing multiple models.
models = list(Path(model_path).rglob('*.ckpt')) + list(Path(model_path).rglob('*.safetensors'))
if models:
# Only the last model name will be used below.
for model in sorted(models):
if click.confirm(f'Import {model.stem} ?', default=True):
model_name = import_ckpt_model(model, gen, opt, completer)
print()
else:
model_name = import_diffuser_model(Path(model_path), gen, opt, completer) model_name = import_diffuser_model(Path(model_path), gen, opt, completer)
elif re.match(r'^[\w.+-]+/[\w.+-]+$', model_path):
model_name = import_diffuser_model(model_path, gen, opt, completer)
else: else:
print(f'** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can\'t import.') print(f'** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can\'t import.')
@ -604,7 +624,7 @@ def import_model(model_path:str, gen, opt, completer):
completer.update_models(gen.model_manager.list_models()) completer.update_models(gen.model_manager.list_models())
print(f'>> {model_name} successfully installed') print(f'>> {model_name} successfully installed')
def import_diffuser_model(path_or_repo:str, gen, opt, completer)->str: def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> Optional[str]:
manager = gen.model_manager manager = gen.model_manager
default_name = Path(path_or_repo).stem default_name = Path(path_or_repo).stem
default_description = f'Imported model {default_name}' default_description = f'Imported model {default_name}'
@ -627,7 +647,7 @@ def import_diffuser_model(path_or_repo:str, gen, opt, completer)->str:
return None return None
return model_name return model_name
def import_ckpt_model(path_or_url:str, gen, opt, completer)->str: def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]:
manager = gen.model_manager manager = gen.model_manager
default_name = Path(path_or_url).stem default_name = Path(path_or_url).stem
default_description = f'Imported model {default_name}' default_description = f'Imported model {default_name}'

View File

@ -763,6 +763,7 @@ class Args(object):
!models -- list models in configs/models.yaml !models -- list models in configs/models.yaml
!switch <model_name> -- switch to model named <model_name> !switch <model_name> -- switch to model named <model_name>
!import_model /path/to/weights/file.ckpt -- adds a .ckpt model to your config !import_model /path/to/weights/file.ckpt -- adds a .ckpt model to your config
!import_model /path/to/weights/ -- interactively import models from a directory
!import_model http://path_to_model.ckpt -- downloads and adds a .ckpt model to your config !import_model http://path_to_model.ckpt -- downloads and adds a .ckpt model to your config
!import_model hakurei/waifu-diffusion -- downloads and adds a diffusers model to your config !import_model hakurei/waifu-diffusion -- downloads and adds a diffusers model to your config
!optimize_model <model_name> -- converts a .ckpt model to a diffusers model !optimize_model <model_name> -- converts a .ckpt model to a diffusers model

View File

@ -18,7 +18,7 @@ import warnings
import safetensors.torch import safetensors.torch
from pathlib import Path from pathlib import Path
from shutil import move, rmtree from shutil import move, rmtree
from typing import Union, Any from typing import Any, Optional, Union
from huggingface_hub import scan_cache_dir from huggingface_hub import scan_cache_dir
from ldm.util import download_with_progress_bar from ldm.util import download_with_progress_bar
@ -880,14 +880,14 @@ class ModelManager(object):
print('** Migration is done. Continuing...') print('** Migration is done. Continuing...')
def _resolve_path(self, source:Union[str,Path], dest_directory:str)->Path: def _resolve_path(self, source: Union[str, Path], dest_directory: str) -> Optional[Path]:
resolved_path = None resolved_path = None
if source.startswith(('http:','https:','ftp:')): if str(source).startswith(('http:','https:','ftp:')):
basename = os.path.basename(source) basename = os.path.basename(source)
if not os.path.isabs(dest_directory): if not os.path.isabs(dest_directory):
dest_directory = os.path.join(Globals.root,dest_directory) dest_directory = os.path.join(Globals.root,dest_directory)
dest = os.path.join(dest_directory,basename) dest = os.path.join(dest_directory,basename)
if download_with_progress_bar(source,dest): if download_with_progress_bar(str(source), Path(dest)):
resolved_path = Path(dest) resolved_path = Path(dest)
else: else:
if not os.path.isabs(source): if not os.path.isabs(source):

View File

@ -284,9 +284,9 @@ class ProgressBar():
def download_with_progress_bar(url:str, dest:Path)->bool: def download_with_progress_bar(url:str, dest:Path)->bool:
try: try:
if not os.path.exists(dest): if not dest.exists():
os.makedirs((os.path.dirname(dest) or '.'), exist_ok=True) dest.parent.mkdir(parents=True, exist_ok=True)
request.urlretrieve(url,dest,ProgressBar(os.path.basename(dest))) request.urlretrieve(url,dest,ProgressBar(dest.stem))
return True return True
else: else:
return True return True

View File

@ -36,6 +36,7 @@ classifiers = [
dependencies = [ dependencies = [
"accelerate", "accelerate",
"albumentations", "albumentations",
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"datasets", "datasets",
"diffusers[torch]~=0.11", "diffusers[torch]~=0.11",