Mirror of https://github.com/invoke-ai/InvokeAI
Allow multiple models to be imported by passing a directory. (#2529)
This change allows a directory containing multiple models to be passed for import; each .ckpt/.safetensors file found is imported after interactive confirmation. Diffusers directories still work as before, and some minor type issues have been fixed up.
Commit cd5c112fcd
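Used from the interactive CLI, the new behaviour looks roughly like this (the directory and model names are invented for illustration; the confirmation prompt comes from the click.confirm call added below, and the success message is the CLI's existing one):

    invoke> !import_model /path/to/weights/
    Import model-a ? [Y/n]: y
    Import model-b ? [Y/n]: n
    >> model-a successfully installed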
@@ -4,6 +4,10 @@ import sys
 import shlex
 import traceback
 
+from argparse import Namespace
+from pathlib import Path
+from typing import Optional, Union
+
 if sys.platform == "darwin":
     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 
@@ -16,10 +20,10 @@ from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
 from ldm.invoke.image_util import make_grid
 from ldm.invoke.log import write_log
 from ldm.invoke.model_manager import ModelManager
-from pathlib import Path
-from argparse import Namespace
-import pyparsing
+
+import click  # type: ignore
 import ldm.invoke
+import pyparsing  # type: ignore
 
 # global used in multiple functions (fix)
 infile = None
@@ -69,7 +73,7 @@ def main():
 
     # these two lines prevent a horrible warning message from appearing
     # when the frozen CLIP tokenizer is imported
-    import transformers
+    import transformers  # type: ignore
     transformers.logging.set_verbosity_error()
     import diffusers
     diffusers.logging.set_verbosity_error()
@@ -572,7 +576,7 @@ def set_default_output_dir(opt:Args, completer:Completer):
     completer.set_default_dir(opt.outdir)
 
 
-def import_model(model_path:str, gen, opt, completer):
+def import_model(model_path: str, gen, opt, completer):
     '''
     model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or
     (3) a huggingface repository id
@@ -581,12 +585,28 @@ def import_model(model_path:str, gen, opt, completer):
 
     if model_path.startswith(('http:','https:','ftp:')):
         model_name = import_ckpt_model(model_path, gen, opt, completer)
 
     elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path):
         model_name = import_ckpt_model(model_path, gen, opt, completer)
-    elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
-        model_name = import_diffuser_model(model_path, gen, opt, completer)
+
     elif os.path.isdir(model_path):
+
+        # Allow for a directory containing multiple models.
+        models = list(Path(model_path).rglob('*.ckpt')) + list(Path(model_path).rglob('*.safetensors'))
+
+        if models:
+            # Only the last model name will be used below.
+            for model in sorted(models):
+
+                if click.confirm(f'Import {model.stem} ?', default=True):
+                    model_name = import_ckpt_model(model, gen, opt, completer)
+                    print()
+        else:
             model_name = import_diffuser_model(Path(model_path), gen, opt, completer)
+
+    elif re.match(r'^[\w.+-]+/[\w.+-]+$', model_path):
+        model_name = import_diffuser_model(model_path, gen, opt, completer)
+
     else:
         print(f'** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can\'t import.')
 
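To see the new directory branch in isolation, here is a minimal, self-contained sketch of the same scan-and-confirm loop. The handle_model callback is a stand-in for the CLI's import_ckpt_model, which needs the generator, options, and completer objects; when no checkpoint files are found, the real code instead falls back to import_diffuser_model.

    from pathlib import Path
    from typing import Callable

    import click  # the new dependency added to pyproject.toml in this change


    def import_models_from_directory(model_dir: str, handle_model: Callable[[Path], None]) -> None:
        # Collect every .ckpt and .safetensors file anywhere under model_dir.
        models = list(Path(model_dir).rglob('*.ckpt')) + list(Path(model_dir).rglob('*.safetensors'))
        for model in sorted(models):
            # Confirm each model individually, defaulting to "yes".
            if click.confirm(f'Import {model.stem} ?', default=True):
                handle_model(model)


    # Hypothetical usage: just print what would be imported.
    # import_models_from_directory('/path/to/weights', lambda p: print(f'would import {p}'))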
@@ -604,7 +624,7 @@ def import_model(model_path:str, gen, opt, completer):
     completer.update_models(gen.model_manager.list_models())
     print(f'>> {model_name} successfully installed')
 
-def import_diffuser_model(path_or_repo:str, gen, opt, completer)->str:
+def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> Optional[str]:
     manager = gen.model_manager
     default_name = Path(path_or_repo).stem
     default_description = f'Imported model {default_name}'
@@ -627,7 +647,7 @@ def import_diffuser_model(path_or_repo:str, gen, opt, completer)->str:
         return None
     return model_name
 
-def import_ckpt_model(path_or_url:str, gen, opt, completer)->str:
+def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]:
     manager = gen.model_manager
     default_name = Path(path_or_url).stem
     default_description = f'Imported model {default_name}'
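Both import helpers are now annotated as returning Optional[str], which matches the existing return None on failure, so callers should guard against a missing name. A small sketch of that call pattern (the failure message here is hypothetical; the success message is the one printed by the CLI):

    from typing import Optional


    def report_import(model_name: Optional[str], model_path: str) -> None:
        # import_ckpt_model()/import_diffuser_model() return None when an import fails.
        if model_name is None:
            print(f'** import of {model_path} failed')  # hypothetical error text
        else:
            print(f'>> {model_name} successfully installed')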
@@ -763,6 +763,7 @@ class Args(object):
 !models -- list models in configs/models.yaml
 !switch <model_name> -- switch to model named <model_name>
 !import_model /path/to/weights/file.ckpt -- adds a .ckpt model to your config
+!import_model /path/to/weights/ -- interactively import models from a directory
 !import_model http://path_to_model.ckpt -- downloads and adds a .ckpt model to your config
 !import_model hakurei/waifu-diffusion -- downloads and adds a diffusers model to your config
 !optimize_model <model_name> -- converts a .ckpt model to a diffusers model
@@ -18,7 +18,7 @@ import warnings
 import safetensors.torch
 from pathlib import Path
 from shutil import move, rmtree
-from typing import Union, Any
+from typing import Any, Optional, Union
 from huggingface_hub import scan_cache_dir
 from ldm.util import download_with_progress_bar
 
@@ -880,14 +880,14 @@ class ModelManager(object):
         print('** Migration is done. Continuing...')
 
 
-    def _resolve_path(self, source:Union[str,Path], dest_directory:str)->Path:
+    def _resolve_path(self, source: Union[str, Path], dest_directory: str) -> Optional[Path]:
         resolved_path = None
-        if source.startswith(('http:','https:','ftp:')):
+        if str(source).startswith(('http:','https:','ftp:')):
             basename = os.path.basename(source)
             if not os.path.isabs(dest_directory):
                 dest_directory = os.path.join(Globals.root,dest_directory)
             dest = os.path.join(dest_directory,basename)
-            if download_with_progress_bar(source,dest):
+            if download_with_progress_bar(str(source), Path(dest)):
                 resolved_path = Path(dest)
         else:
             if not os.path.isabs(source):
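The str(source) and Path(dest) wrappers are what make the Union[str, Path] annotation honest: a Path object has no startswith() method, so the URL-scheme check only works on a string. A tiny illustration (the helper name is for this example only):

    from pathlib import Path
    from typing import Union


    def is_remote(source: Union[str, Path]) -> bool:
        # Path has no .startswith(), so normalise to str before checking the URL scheme.
        return str(source).startswith(('http:', 'https:', 'ftp:'))


    assert is_remote('https://example.com/model.ckpt')
    assert not is_remote(Path('/path/to/weights/model.ckpt'))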
@@ -284,9 +284,9 @@ class ProgressBar():
 
 def download_with_progress_bar(url:str, dest:Path)->bool:
     try:
-        if not os.path.exists(dest):
-            os.makedirs((os.path.dirname(dest) or '.'), exist_ok=True)
-            request.urlretrieve(url,dest,ProgressBar(os.path.basename(dest)))
+        if not dest.exists():
+            dest.parent.mkdir(parents=True, exist_ok=True)
+            request.urlretrieve(url,dest,ProgressBar(dest.stem))
             return True
         else:
             return True
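The util change is a direct os.path-to-pathlib translation: dest.exists() replaces os.path.exists(dest), dest.parent.mkdir(parents=True, exist_ok=True) replaces the os.makedirs(os.path.dirname(...)) dance, and dest.stem replaces os.path.basename(dest) for the progress-bar label (note that stem drops the file extension from the label, where basename kept it). A rough standalone equivalent, without the ProgressBar hook or error handling, might read:

    from pathlib import Path
    from urllib import request


    def download_if_missing(url: str, dest: Path) -> bool:
        # Create the destination directory tree if needed, then fetch the file once.
        if not dest.exists():
            dest.parent.mkdir(parents=True, exist_ok=True)
            request.urlretrieve(url, dest)
        return True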
@@ -36,6 +36,7 @@ classifiers = [
 dependencies = [
   "accelerate",
   "albumentations",
+  "click",
   "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
   "datasets",
   "diffusers[torch]~=0.11",