Allow multiple models to be imported by passing a directory.

This commit is contained in:
Dan Sully 2023-02-04 19:05:27 -08:00
parent 4895fe8395
commit 2ec864e37e
5 changed files with 39 additions and 17 deletions

View File

@ -4,6 +4,10 @@ import sys
import shlex
import traceback
from argparse import Namespace
from pathlib import Path
from typing import Optional, Union
if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@ -16,10 +20,10 @@ from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log
from ldm.invoke.model_manager import ModelManager
from pathlib import Path
from argparse import Namespace
import pyparsing
import click # type: ignore
import ldm.invoke
import pyparsing # type: ignore
# global used in multiple functions (fix)
infile = None
@ -69,7 +73,7 @@ def main():
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers
import transformers # type: ignore
transformers.logging.set_verbosity_error()
import diffusers
diffusers.logging.set_verbosity_error()
@ -572,7 +576,7 @@ def set_default_output_dir(opt:Args, completer:Completer):
completer.set_default_dir(opt.outdir)
def import_model(model_path:str, gen, opt, completer):
def import_model(model_path: str, gen, opt, completer):
'''
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or
(3) a huggingface repository id
@ -581,12 +585,28 @@ def import_model(model_path:str, gen, opt, completer):
if model_path.startswith(('http:','https:','ftp:')):
model_name = import_ckpt_model(model_path, gen, opt, completer)
elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path):
model_name = import_ckpt_model(model_path, gen, opt, completer)
elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
model_name = import_diffuser_model(model_path, gen, opt, completer)
elif os.path.isdir(model_path):
model_name = import_diffuser_model(Path(model_path), gen, opt, completer)
# Allow for a directory containing multiple models.
models = list(Path(model_path).rglob('*.ckpt')) + list(Path(model_path).rglob('*.safetensors'))
if models:
# Only the last model name will be used below.
for model in sorted(models):
if click.confirm(f'Import {model.stem} ?', default=True):
model_name = import_ckpt_model(model, gen, opt, completer)
print()
else:
model_name = import_diffuser_model(Path(model_path), gen, opt, completer)
elif re.match(r'^[\w.+-]+/[\w.+-]+$', model_path):
model_name = import_diffuser_model(model_path, gen, opt, completer)
else:
print(f'** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can\'t import.')
@ -604,7 +624,7 @@ def import_model(model_path:str, gen, opt, completer):
completer.update_models(gen.model_manager.list_models())
print(f'>> {model_name} successfully installed')
def import_diffuser_model(path_or_repo:str, gen, opt, completer)->str:
def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> Optional[str]:
manager = gen.model_manager
default_name = Path(path_or_repo).stem
default_description = f'Imported model {default_name}'
@ -627,7 +647,7 @@ def import_diffuser_model(path_or_repo:str, gen, opt, completer)->str:
return None
return model_name
def import_ckpt_model(path_or_url:str, gen, opt, completer)->str:
def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]:
manager = gen.model_manager
default_name = Path(path_or_url).stem
default_description = f'Imported model {default_name}'

View File

@ -756,6 +756,7 @@ class Args(object):
!models -- list models in configs/models.yaml
!switch <model_name> -- switch to model named <model_name>
!import_model /path/to/weights/file.ckpt -- adds a .ckpt model to your config
!import_model /path/to/weights/ -- prompts to add models in a directory
!import_model http://path_to_model.ckpt -- downloads and adds a .ckpt model to your config
!import_model hakurei/waifu-diffusion -- downloads and adds a diffusers model to your config
!optimize_model <model_name> -- converts a .ckpt model to a diffusers model

View File

@ -18,7 +18,7 @@ import warnings
import safetensors.torch
from pathlib import Path
from shutil import move, rmtree
from typing import Union, Any
from typing import Any, Optional, Union
from huggingface_hub import scan_cache_dir
from ldm.util import download_with_progress_bar
@ -880,14 +880,14 @@ class ModelManager(object):
print('** Migration is done. Continuing...')
def _resolve_path(self, source:Union[str,Path], dest_directory:str)->Path:
def _resolve_path(self, source: Union[str, Path], dest_directory: str) -> Optional[Path]:
resolved_path = None
if source.startswith(('http:','https:','ftp:')):
if str(source).startswith(('http:','https:','ftp:')):
basename = os.path.basename(source)
if not os.path.isabs(dest_directory):
dest_directory = os.path.join(Globals.root,dest_directory)
dest = os.path.join(dest_directory,basename)
if download_with_progress_bar(source,dest):
if download_with_progress_bar(str(source), Path(dest)):
resolved_path = Path(dest)
else:
if not os.path.isabs(source):

View File

@ -284,9 +284,9 @@ class ProgressBar():
def download_with_progress_bar(url:str, dest:Path)->bool:
try:
if not os.path.exists(dest):
os.makedirs((os.path.dirname(dest) or '.'), exist_ok=True)
request.urlretrieve(url,dest,ProgressBar(os.path.basename(dest)))
if not dest.exists():
dest.parent.mkdir(parents=True, exist_ok=True)
request.urlretrieve(url,dest,ProgressBar(dest.stem))
return True
else:
return True

View File

@ -36,6 +36,7 @@ classifiers = [
dependencies = [
"accelerate",
"albumentations",
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"datasets",
"diffusers[torch]~=0.11",