mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
clean up diagnostic messages
This commit is contained in:
parent
ca10d0652f
commit
5461318eda
@ -212,6 +212,7 @@ class Generate:
|
||||
else:
|
||||
print('>> xformers not installed')
|
||||
|
||||
|
||||
# model caching system for fast switching
|
||||
self.model_manager = ModelManager(mconfig, self.device, self.precision,
|
||||
max_loaded_models=max_loaded_models,
|
||||
@ -234,7 +235,7 @@ class Generate:
|
||||
# load safety checker if requested
|
||||
if safety_checker:
|
||||
try:
|
||||
print('>> Initializing safety checker')
|
||||
print('>> Initializing NSFW checker model')
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
@ -251,6 +252,9 @@ class Generate:
|
||||
except Exception:
|
||||
print('** An error was encountered while installing the safety checker:')
|
||||
print(traceback.format_exc())
|
||||
else:
|
||||
print('>> NSFW checker is disabled')
|
||||
|
||||
|
||||
def prompt2png(self, prompt, outdir, **kwargs):
|
||||
"""
|
||||
|
@ -5,7 +5,7 @@ import sys
|
||||
import traceback
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import Union
|
||||
|
||||
import click
|
||||
|
||||
@ -25,7 +25,6 @@ from ldm.invoke.model_manager import ModelManager
|
||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
|
||||
from ldm.invoke.prompt_parser import PromptParser
|
||||
from ldm.invoke.readline import Completer, get_completer
|
||||
from ldm.util import url_attachment_name
|
||||
|
||||
# global used in multiple functions (fix)
|
||||
infile = None
|
||||
@ -492,7 +491,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
||||
elif not os.path.exists(path[1]):
|
||||
print(f'** {path[1]}: model not found')
|
||||
else:
|
||||
optimize_model(path[1], gen, opt, completer)
|
||||
convert_model(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
@ -504,7 +503,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
|
||||
elif not path[1] in gen.model_manager.list_models():
|
||||
print(f'** {path[1]}: model not found')
|
||||
else:
|
||||
optimize_model(path[1], gen, opt, completer)
|
||||
convert_model(path[1], gen, opt, completer)
|
||||
completer.add_history(command)
|
||||
operation = None
|
||||
|
||||
@ -574,160 +573,54 @@ def set_default_output_dir(opt:Args, completer:Completer):
|
||||
completer.set_default_dir(opt.outdir)
|
||||
|
||||
|
||||
def import_model(model_path: str, gen, opt, completer):
|
||||
def import_model(model_path: str, gen, opt, completer, convert=False)->str:
|
||||
"""
|
||||
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
|
||||
(3) a huggingface repository id; or (4) a local directory containing a
|
||||
diffusers model.
|
||||
"""
|
||||
model.path = model_path.replace('\\','/') # windows
|
||||
model_path = model_path.replace('\\','/') # windows
|
||||
default_name = Path(model_path).stem
|
||||
default_description = f'Imported model {default_name}'
|
||||
model_name = None
|
||||
model_desc = None
|
||||
|
||||
if model_path.startswith(('http:','https:','ftp:')):
|
||||
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
||||
|
||||
elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path):
|
||||
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
||||
|
||||
elif os.path.isdir(model_path):
|
||||
|
||||
# Allow for a directory containing multiple models.
|
||||
models = list(Path(model_path).rglob('*.ckpt')) + list(Path(model_path).rglob('*.safetensors'))
|
||||
|
||||
if models:
|
||||
models = import_checkpoint_list(models, gen, opt, completer)
|
||||
model_name = models[0] if len(models) == 1 else None
|
||||
else:
|
||||
model_name = import_diffuser_model(Path(model_path), gen, opt, completer)
|
||||
|
||||
elif re.match(r'^[\w.+-]+/[\w.+-]+$', model_path):
|
||||
model_name = import_diffuser_model(model_path, gen, opt, completer)
|
||||
|
||||
if Path(model_path).is_dir() and not (Path(model_path) / 'model_index.json').exists():
|
||||
pass
|
||||
else:
|
||||
print(f'** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can\'t import.')
|
||||
|
||||
if not model_name:
|
||||
model_name, model_desc = _get_model_name_and_desc(
|
||||
gen.model_manager,
|
||||
completer,
|
||||
model_name=default_name,
|
||||
)
|
||||
imported_name = gen.model_manager.heuristic_import(
|
||||
model_path,
|
||||
model_name=model_name,
|
||||
description=model_desc,
|
||||
convert=convert,
|
||||
)
|
||||
|
||||
if not imported_name:
|
||||
print('** model failed to load. Aborting')
|
||||
return
|
||||
|
||||
if not _verify_load(model_name, gen):
|
||||
if not _verify_load(imported_name, gen):
|
||||
print('** model failed to load. Discarding configuration entry')
|
||||
gen.model_manager.del_model(model_name)
|
||||
gen.model_manager.del_model(imported_name)
|
||||
return
|
||||
if click.confirm('Make this the default model?', default=False):
|
||||
gen.model_manager.set_default_model(model_name)
|
||||
gen.model_manager.set_default_model(imported_name)
|
||||
|
||||
gen.model_manager.commit(opt.conf)
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
print(f'>> {model_name} successfully installed')
|
||||
|
||||
def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]:
|
||||
'''
|
||||
Does a mass import of all the checkpoint/safetensors on a path list
|
||||
'''
|
||||
model_names = list()
|
||||
choice = input('** Directory of checkpoint/safetensors models detected. Install <a>ll or <s>elected models? [a] ') or 'a'
|
||||
do_all = choice.startswith('a')
|
||||
if do_all:
|
||||
config_file = _ask_for_config_file(models[0], completer, plural=True)
|
||||
manager = gen.model_manager
|
||||
for model in sorted(models):
|
||||
model_name = f'{model.stem}'
|
||||
model_description = f'Imported model {model_name}'
|
||||
if model_name in manager.model_names():
|
||||
print(f'** {model_name} is already imported. Skipping.')
|
||||
elif manager.import_ckpt_model(
|
||||
model,
|
||||
config = config_file,
|
||||
model_name = model_name,
|
||||
model_description = model_description,
|
||||
commit_to_conf = opt.conf):
|
||||
model_names.append(model_name)
|
||||
print(f'>> Model {model_name} imported successfully')
|
||||
else:
|
||||
print(f'** Model {model} failed to import')
|
||||
else:
|
||||
for model in sorted(models):
|
||||
if click.confirm(f'Import {model.stem} ?', default=True):
|
||||
if model_name := import_ckpt_model(model, gen, opt, completer):
|
||||
print(f'>> Model {model.stem} imported successfully')
|
||||
model_names.append(model_name)
|
||||
else:
|
||||
printf('** Model {model} failed to import')
|
||||
print()
|
||||
return model_names
|
||||
|
||||
def import_diffuser_model(
|
||||
path_or_repo: Union[Path, str], gen, _, completer
|
||||
) -> Optional[str]:
|
||||
path_or_repo = path_or_repo.replace('\\','/') # windows
|
||||
manager = gen.model_manager
|
||||
default_name = Path(path_or_repo).stem
|
||||
default_description = f'Imported model {default_name}'
|
||||
model_name, model_description = _get_model_name_and_desc(
|
||||
manager,
|
||||
completer,
|
||||
model_name=default_name,
|
||||
model_description=default_description
|
||||
)
|
||||
vae = None
|
||||
if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
|
||||
vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
|
||||
|
||||
if not manager.import_diffuser_model(
|
||||
path_or_repo,
|
||||
model_name = model_name,
|
||||
vae = vae,
|
||||
description = model_description):
|
||||
print('** model failed to import')
|
||||
return None
|
||||
return model_name
|
||||
|
||||
def import_ckpt_model(
|
||||
path_or_url: Union[Path, str], gen, opt, completer
|
||||
) -> Optional[str]:
|
||||
path_or_url = path_or_url.replace('\\','/')
|
||||
manager = gen.model_manager
|
||||
is_a_url = str(path_or_url).startswith(('http:','https:'))
|
||||
base_name = Path(url_attachment_name(path_or_url)).name if is_a_url else Path(path_or_url).name
|
||||
default_name = Path(base_name).stem
|
||||
default_description = f"Imported model {default_name}"
|
||||
|
||||
model_name, model_description = _get_model_name_and_desc(
|
||||
manager,
|
||||
completer,
|
||||
model_name=default_name,
|
||||
model_description=default_description
|
||||
)
|
||||
|
||||
completer.complete_extensions(('.ckpt','.safetensors'))
|
||||
vae = None
|
||||
default = Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt')
|
||||
completer.set_line(str(default))
|
||||
done = False
|
||||
while not done:
|
||||
vae = input('VAE file for this model (leave blank for none): ').strip() or None
|
||||
done = (not vae) or os.path.exists(vae)
|
||||
completer.complete_extensions(None)
|
||||
config_file = _ask_for_config_file(path_or_url, completer)
|
||||
|
||||
if not manager.import_ckpt_model(
|
||||
path_or_url,
|
||||
config = config_file,
|
||||
vae = vae,
|
||||
model_name = model_name,
|
||||
model_description = model_description,
|
||||
commit_to_conf = opt.conf,
|
||||
):
|
||||
print('** model failed to import')
|
||||
return None
|
||||
|
||||
return model_name
|
||||
|
||||
def _verify_load(model_name:str, gen)->bool:
|
||||
print('>> Verifying that new model loads...')
|
||||
current_model = gen.model_name
|
||||
try:
|
||||
if not gen.model_manager.get_model(model_name):
|
||||
if not gen.set_model(model_name):
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f'** model failed to load: {str(e)}')
|
||||
@ -743,46 +636,12 @@ def _verify_load(model_name:str, gen)->bool:
|
||||
|
||||
def _get_model_name_and_desc(model_manager,completer,model_name:str='',model_description:str=''):
|
||||
model_name = _get_model_name(model_manager.list_models(),completer,model_name)
|
||||
model_description = model_description or f'Imported model {model_name}'
|
||||
completer.set_line(model_description)
|
||||
model_description = input(f'Description for this model [{model_description}]: ').strip() or model_description
|
||||
return model_name, model_description
|
||||
|
||||
def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=False)->Path:
|
||||
default = '1'
|
||||
if re.search('inpaint',str(model_path),flags=re.IGNORECASE):
|
||||
default = '3'
|
||||
choices={
|
||||
'1': 'v1-inference.yaml',
|
||||
'2': 'v2-inference-v.yaml',
|
||||
'3': 'v1-inpainting-inference.yaml',
|
||||
}
|
||||
|
||||
prompt = '''What type of models are these?:
|
||||
[1] Models based on Stable Diffusion 1.X
|
||||
[2] Models based on Stable Diffusion 2.X
|
||||
[3] Inpainting models based on Stable Diffusion 1.X
|
||||
[4] Something else''' if plural else '''What type of model is this?:
|
||||
[1] A model based on Stable Diffusion 1.X
|
||||
[2] A model based on Stable Diffusion 2.X
|
||||
[3] An inpainting models based on Stable Diffusion 1.X
|
||||
[4] Something else'''
|
||||
print(prompt)
|
||||
choice = input(f'Your choice: [{default}] ')
|
||||
choice = choice.strip() or default
|
||||
if config_file := choices.get(choice,None):
|
||||
return Path('configs','stable-diffusion',config_file)
|
||||
|
||||
# otherwise ask user to select
|
||||
done = False
|
||||
completer.complete_extensions(('.yaml','.yml'))
|
||||
completer.set_line(str(Path(Globals.root,'configs/stable-diffusion/')))
|
||||
while not done:
|
||||
config_path = input('Configuration file for this model (leave blank to abort): ').strip()
|
||||
done = not config_path or os.path.exists(config_path)
|
||||
return config_path
|
||||
|
||||
|
||||
def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
|
||||
def convert_model(model_name_or_path: Union[Path,str], gen, opt, completer)->str:
|
||||
model_name_or_path = model_name_or_path.replace('\\','/') # windows
|
||||
manager = gen.model_manager
|
||||
ckpt_path = None
|
||||
@ -796,58 +655,32 @@ def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
|
||||
original_config_file = Path(model_info['config'])
|
||||
model_name = model_name_or_path
|
||||
model_description = model_info['description']
|
||||
vae = model_info['vae']
|
||||
else:
|
||||
print(f'** {model_name_or_path} is not a legacy .ckpt weights file')
|
||||
return
|
||||
elif os.path.exists(model_name_or_path):
|
||||
original_config_file = original_config_file or _ask_for_config_file(model_name_or_path, completer)
|
||||
if not original_config_file:
|
||||
return
|
||||
ckpt_path = Path(model_name_or_path)
|
||||
model_name, model_description = _get_model_name_and_desc(
|
||||
manager,
|
||||
completer,
|
||||
ckpt_path.stem,
|
||||
f'Converted model {ckpt_path.stem}'
|
||||
if vae_repo:= ldm.invoke.model_manager.VAE_TO_REPO_ID.get(Path(vae).stem):
|
||||
vae_repo = dict(repo_id = vae_repo)
|
||||
else:
|
||||
vae_repo = None
|
||||
model_name = gen.model_manager.convert_and_import(
|
||||
ckpt_path,
|
||||
diffusers_path=Path(Globals.root, 'models', Globals.converted_ckpts_dir, model_name_or_path),
|
||||
model_name=model_name,
|
||||
model_description=model_description,
|
||||
original_config_file=original_config_file,
|
||||
vae = vae_repo,
|
||||
)
|
||||
else:
|
||||
print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file')
|
||||
model_name = import_model(model_name_or_path, gen, opt, completer, convert=True)
|
||||
|
||||
if not model_name:
|
||||
print('** Conversion failed. Aborting.')
|
||||
return
|
||||
|
||||
if not ckpt_path.is_absolute():
|
||||
ckpt_path = Path(Globals.root,ckpt_path)
|
||||
|
||||
if original_config_file and not original_config_file.is_absolute():
|
||||
original_config_file = Path(Globals.root,original_config_file)
|
||||
|
||||
diffuser_path = Path(Globals.root, 'models',Globals.converted_ckpts_dir,model_name)
|
||||
if diffuser_path.exists():
|
||||
print(f'** {model_name_or_path} is already optimized. Will not overwrite. If this is an error, please remove the directory {diffuser_path} and try again.')
|
||||
return
|
||||
|
||||
vae = None
|
||||
if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
|
||||
vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
|
||||
|
||||
new_config = gen.model_manager.convert_and_import(
|
||||
ckpt_path,
|
||||
diffuser_path,
|
||||
model_name=model_name,
|
||||
model_description=model_description,
|
||||
vae = vae,
|
||||
original_config_file = original_config_file,
|
||||
commit_to_conf=opt.conf,
|
||||
)
|
||||
if not new_config:
|
||||
return
|
||||
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
if click.confirm(f'Load optimized model {model_name}?', default=True):
|
||||
gen.set_model(model_name)
|
||||
|
||||
if click.confirm(f'Delete the original .ckpt file at {ckpt_path}?',default=False):
|
||||
ckpt_path.unlink(missing_ok=True)
|
||||
print(f'{ckpt_path} deleted')
|
||||
return model_name
|
||||
|
||||
def del_config(model_name:str, gen, opt, completer):
|
||||
current_model = gen.model_name
|
||||
|
@ -16,6 +16,7 @@ import sys
|
||||
import textwrap
|
||||
import time
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import Any, Optional, Union
|
||||
@ -33,11 +34,16 @@ from picklescan.scanner import scan_file_path
|
||||
|
||||
from ldm.invoke.generator.diffusers_pipeline import \
|
||||
StableDiffusionGeneratorPipeline
|
||||
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
|
||||
global_models_dir)
|
||||
from ldm.invoke.globals import (Globals, global_cache_dir)
|
||||
from ldm.util import (ask_user, download_with_resume,
|
||||
url_attachment_name, instantiate_from_config)
|
||||
|
||||
class SDLegacyType(Enum):
|
||||
V1 = 1
|
||||
V1_INPAINT = 2
|
||||
V2 = 3
|
||||
UNKNOWN = 99
|
||||
|
||||
DEFAULT_MAX_MODELS = 2
|
||||
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
|
||||
"vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse",
|
||||
@ -467,19 +473,6 @@ class ModelManager(object):
|
||||
for module in model.modules():
|
||||
if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
|
||||
module._orig_padding_mode = module.padding_mode
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
print(">> Model loaded in", "%4.2fs" % (toc - tic))
|
||||
|
||||
if self._has_cuda():
|
||||
print(
|
||||
">> Max VRAM used to load the model:",
|
||||
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9),
|
||||
"\n>> Current VRAM usage:"
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
|
||||
return model, width, height, model_hash
|
||||
|
||||
def _load_diffusers_model(self, mconfig):
|
||||
@ -611,15 +604,15 @@ class ModelManager(object):
|
||||
print("### Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
print(">> Model scanned ok!")
|
||||
print(">> Model scanned ok")
|
||||
|
||||
def import_diffuser_model(
|
||||
self,
|
||||
repo_or_path: Union[str, Path],
|
||||
model_name: str = None,
|
||||
description: str = None,
|
||||
vae: dict = None,
|
||||
commit_to_conf: Path = None,
|
||||
self,
|
||||
repo_or_path: Union[str, Path],
|
||||
model_name: str = None,
|
||||
model_description: str = None,
|
||||
vae: dict = None,
|
||||
commit_to_conf: Path = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Attempts to install the indicated diffuser model and returns True if successful.
|
||||
@ -657,7 +650,7 @@ class ModelManager(object):
|
||||
model_name: str = None,
|
||||
model_description: str = None,
|
||||
commit_to_conf: Path = None,
|
||||
) -> bool:
|
||||
) -> str:
|
||||
"""
|
||||
Attempts to install the indicated ckpt file and returns True if successful.
|
||||
|
||||
@ -674,6 +667,8 @@ class ModelManager(object):
|
||||
then these will be derived from the weight file name. If you provide a commit_to_conf
|
||||
path to the configuration file, then the new entry will be committed to the
|
||||
models.yaml file.
|
||||
|
||||
Return value is the name of the imported file, or None if an error occurred.
|
||||
"""
|
||||
if str(weights).startswith(("http:", "https:")):
|
||||
model_name = model_name or url_attachment_name(weights)
|
||||
@ -682,9 +677,9 @@ class ModelManager(object):
|
||||
config_path = self._resolve_path(config, "configs/stable-diffusion")
|
||||
|
||||
if weights_path is None or not weights_path.exists():
|
||||
return False
|
||||
return
|
||||
if config_path is None or not config_path.exists():
|
||||
return False
|
||||
return
|
||||
|
||||
model_name = model_name or Path(weights).stem # note this gives ugly pathnames if used on a URL without a Content-Disposition header
|
||||
model_description = (
|
||||
@ -703,41 +698,101 @@ class ModelManager(object):
|
||||
self.add_model(model_name, new_config, True)
|
||||
if commit_to_conf:
|
||||
self.commit(commit_to_conf)
|
||||
return True
|
||||
return model_name
|
||||
|
||||
@classmethod
|
||||
def probe_model_type(self, checkpoint: dict)->SDLegacyType:
|
||||
'''
|
||||
Given a pickle or safetensors model object, probes contents
|
||||
of the object and returns an SDLegacyType indicating its
|
||||
format. Valid return values include:
|
||||
SDLegacyType.V1
|
||||
SDLegacyType.V1_INPAINT
|
||||
SDLegacyType.V2
|
||||
UNKNOWN
|
||||
'''
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||
return SDLegacyType.V2
|
||||
|
||||
try:
|
||||
state_dict = checkpoint.get('state_dict') or checkpoint
|
||||
in_channels = state_dict['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
|
||||
if in_channels == 9:
|
||||
return SDLegacyType.V1_INPAINT
|
||||
elif in_channels == 4:
|
||||
return SDLegacyType.V1
|
||||
else:
|
||||
return SDLegacyType.UNKNOWN
|
||||
except KeyError:
|
||||
return SDLegacyType.UNKNOWN
|
||||
|
||||
def heuristic_import(
|
||||
self,
|
||||
path_url_or_repo: str,
|
||||
convert: bool= False,
|
||||
model_name: str = None,
|
||||
description: str = None,
|
||||
commit_to_conf: Path=None,
|
||||
):
|
||||
)->str:
|
||||
'''
|
||||
Accept a string which could be:
|
||||
- a HF diffusers repo_id
|
||||
- a URL pointing to a legacy .ckpt or .safetensors file
|
||||
- a local path pointing to a legacy .ckpt or .safetensors file
|
||||
- a local directory containing .ckpt and .safetensors files
|
||||
- a local directory containing a diffusers model
|
||||
|
||||
After determining the nature of the model and downloading it
|
||||
(if necessary), the file is probed to determine the correct
|
||||
configuration file (if needed) and it is imported.
|
||||
|
||||
The model_name and/or description can be provided. If not, they will
|
||||
be generated automatically.
|
||||
|
||||
If convert is true, legacy models will be converted to diffusers
|
||||
before importing.
|
||||
|
||||
If commit_to_conf is provided, the newly loaded model will be written
|
||||
to the `models.yaml` file at the indicated path. Otherwise, the changes
|
||||
will only remain in memory.
|
||||
|
||||
The (potentially derived) name of the model is returned on success, or None
|
||||
on failure. When multiple models are added from a directory, only the last
|
||||
imported one is returned.
|
||||
'''
|
||||
model_path: Path = None
|
||||
thing = path_url_or_repo # to save typing
|
||||
|
||||
print(f'>> Probing {thing} for import')
|
||||
|
||||
if thing.startswith(('http:','https:','ftp:')):
|
||||
print(f' | {thing} appears to be a URL')
|
||||
print(f' | {thing} appears to be a URL')
|
||||
model_path = self._resolve_path(thing, 'models/ldm/stable-diffusion-v1') # _resolve_path does a download if needed
|
||||
|
||||
elif Path(thing).is_file() and thing.endswith(('.ckpt','.safetensors')):
|
||||
print(f' | {thing} appears to be a checkpoint file on disk')
|
||||
print(f' | {thing} appears to be a checkpoint file on disk')
|
||||
model_path = self._resolve_path(thing, 'models/ldm/stable-diffusion-v1')
|
||||
|
||||
elif Path(thing).is_dir() and Path(thing, 'model_index.json').exists():
|
||||
print(f' | {thing} appears to be a diffusers file on disk')
|
||||
model_name = self.import_diffusers_model(
|
||||
print(f' | {thing} appears to be a diffusers file on disk')
|
||||
model_name = self.import_diffuser_model(
|
||||
thing,
|
||||
vae=dict(repo_id='stabilityai/sd-vae-ft-mse'),
|
||||
vae=dict(repo_id='stabilityai/sd-vae-ft-mse'),
|
||||
model_name=model_name,
|
||||
description=description,
|
||||
commit_to_conf=commit_to_conf
|
||||
)
|
||||
|
||||
elif Path(thing).is_dir():
|
||||
print(f'>> {thing} appears to be a directory. Will scan for models to import')
|
||||
for m in list(Path(thing).rglob('*.ckpt')) + list(Path(thing).rglob('*.safetensors')):
|
||||
self.heuristic_import(str(m), convert, commit_to_conf=commit_to_conf)
|
||||
return
|
||||
if model_name := self.heuristic_import(str(m), convert, commit_to_conf=commit_to_conf):
|
||||
print(f' >> {model_name} successfully imported')
|
||||
return model_name
|
||||
|
||||
elif re.match(r'^[\w.+-]+/[\w.+-]+$', thing):
|
||||
print(f' | {thing} appears to be a HuggingFace diffusers repo_id')
|
||||
print(f' | {thing} appears to be a HuggingFace diffusers repo_id')
|
||||
model_name = self.import_diffuser_model(thing, commit_to_conf=commit_to_conf)
|
||||
pipeline,_,_,_ = self._load_diffusers_model(self.config[model_name])
|
||||
|
||||
@ -750,68 +805,68 @@ class ModelManager(object):
|
||||
return
|
||||
|
||||
if model_path.stem in self.config: #already imported
|
||||
print(' > Already imported. Skipping')
|
||||
print(' | Already imported. Skipping')
|
||||
return
|
||||
|
||||
# another round of heuristics to guess the correct config file.
|
||||
model_config_file = None
|
||||
checkpoint = safetensors.torch.load_file(model_path) if model_path.suffix == '.safetensors' else torch.load(model_path)
|
||||
model_type = self.probe_model_type(checkpoint)
|
||||
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||
print(' > SD-v2 model detected; model will be converted to diffusers format')
|
||||
model_config_file = None
|
||||
if model_type == SDLegacyType.V1:
|
||||
print(' | SD-v1 model detected')
|
||||
model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
|
||||
elif model_type == SDLegacyType.V1_INPAINT:
|
||||
print(' | SD-v1 inpainting model detected')
|
||||
model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml')
|
||||
elif model_type == SDLegacyType.V2:
|
||||
print(' | SD-v2 model detected; model will be converted to diffusers format')
|
||||
model_config_file = Path(Globals.root,'configs/stable-diffusion/v2-inference-v.yaml')
|
||||
convert = True
|
||||
|
||||
if not model_config_file: # still trying
|
||||
in_channels = None
|
||||
try:
|
||||
state_dict = checkpoint.get('state_dict') or checkpoint
|
||||
in_channels = state_dict['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
|
||||
if in_channels == 9:
|
||||
print(' > SD-v1 inpainting model detected')
|
||||
model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml')
|
||||
elif in_channels == 4:
|
||||
print(' > SD-v1 model detected')
|
||||
model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
|
||||
else:
|
||||
print(f'** {thing} does not have an expected number of in_channels ({in_channels}). It will probably break when loaded.')
|
||||
model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
|
||||
except KeyError:
|
||||
print(f'** {thing} does not have the expected SD-v1 model fields. It will probably break when loaded.')
|
||||
model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
|
||||
else:
|
||||
print(f'** {thing} is a legacy checkpoint file of unkown format. Will treat as a regular v1.X model')
|
||||
model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
|
||||
|
||||
if convert:
|
||||
diffuser_path = Path(Globals.root, 'models',Globals.converted_ckpts_dir, model_path.stem)
|
||||
self.convert_and_import(
|
||||
model_name = self.convert_and_import(
|
||||
model_path,
|
||||
diffusers_path=diffuser_path,
|
||||
vae=dict(repo_id='stabilityai/sd-vae-ft-mse'),
|
||||
model_name=model_name,
|
||||
model_description=description,
|
||||
original_config_file=model_config_file,
|
||||
commit_to_conf=commit_to_conf,
|
||||
)
|
||||
else:
|
||||
self.import_ckpt_model(
|
||||
model_name = self.import_ckpt_model(
|
||||
model_path,
|
||||
config=model_config_file,
|
||||
model_name=model_name,
|
||||
model_description=description,
|
||||
vae=str(Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt')),
|
||||
commit_to_conf=commit_to_conf,
|
||||
)
|
||||
return model_name
|
||||
|
||||
def convert_and_import(
|
||||
self,
|
||||
ckpt_path: Path,
|
||||
diffusers_path: Path,
|
||||
model_name=None,
|
||||
model_description=None,
|
||||
vae=None,
|
||||
original_config_file: Path = None,
|
||||
commit_to_conf: Path = None,
|
||||
self,
|
||||
ckpt_path: Path,
|
||||
diffusers_path: Path,
|
||||
model_name=None,
|
||||
model_description=None,
|
||||
vae=None,
|
||||
original_config_file: Path = None,
|
||||
commit_to_conf: Path = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert a legacy ckpt weights file to diffuser model and import
|
||||
into models.yaml.
|
||||
"""
|
||||
ckpt_path = self._resolve_path(ckpt_path, 'models/ldm/stable-diffusion-v1')
|
||||
if original_config_file:
|
||||
original_config_file = self._resolve_path(original_config_file, 'configs/stable-diffusion')
|
||||
|
||||
new_config = None
|
||||
|
||||
from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
|
||||
@ -857,7 +912,7 @@ class ModelManager(object):
|
||||
"** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
||||
)
|
||||
|
||||
return new_config
|
||||
return model_name
|
||||
|
||||
def search_models(self, search_folder):
|
||||
print(f">> Finding Models In: {search_folder}")
|
||||
|
Loading…
Reference in New Issue
Block a user