clean up diagnostic messages

Lincoln Stein
2023-02-19 19:38:29 -05:00
parent ca10d0652f
commit 5461318eda
3 changed files with 176 additions and 284 deletions

View File

@@ -212,6 +212,7 @@ class Generate:
         else:
             print('>> xformers not installed')
         # model caching system for fast switching
         self.model_manager = ModelManager(mconfig, self.device, self.precision,
                                           max_loaded_models=max_loaded_models,
@@ -234,7 +235,7 @@ class Generate:
         # load safety checker if requested
         if safety_checker:
             try:
-                print('>> Initializing safety checker')
+                print('>> Initializing NSFW checker model')
                 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
                 from transformers import AutoFeatureExtractor
                 safety_model_id = "CompVis/stable-diffusion-safety-checker"
@@ -251,6 +252,9 @@ class Generate:
             except Exception:
                 print('** An error was encountered while installing the safety checker:')
                 print(traceback.format_exc())
+        else:
+            print('>> NSFW checker is disabled')

     def prompt2png(self, prompt, outdir, **kwargs):
         """

View File

@@ -5,7 +5,7 @@ import sys
 import traceback
 from argparse import Namespace
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Union

 import click
@@ -25,7 +25,6 @@ from ldm.invoke.model_manager import ModelManager
 from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
 from ldm.invoke.prompt_parser import PromptParser
 from ldm.invoke.readline import Completer, get_completer
-from ldm.util import url_attachment_name

 # global used in multiple functions (fix)
 infile = None
@@ -492,7 +491,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
         elif not os.path.exists(path[1]):
             print(f'** {path[1]}: model not found')
         else:
-            optimize_model(path[1], gen, opt, completer)
+            convert_model(path[1], gen, opt, completer)
         completer.add_history(command)
         operation = None
@@ -504,7 +503,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
         elif not path[1] in gen.model_manager.list_models():
             print(f'** {path[1]}: model not found')
         else:
-            optimize_model(path[1], gen, opt, completer)
+            convert_model(path[1], gen, opt, completer)
         completer.add_history(command)
         operation = None
@@ -574,160 +573,54 @@ def set_default_output_dir(opt:Args, completer:Completer):
     completer.set_default_dir(opt.outdir)

-def import_model(model_path: str, gen, opt, completer):
+def import_model(model_path: str, gen, opt, completer, convert=False)->str:
     """
     model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
     (3) a huggingface repository id; or (4) a local directory containing a
     diffusers model.
     """
-    model.path = model_path.replace('\\','/') # windows
+    model_path = model_path.replace('\\','/') # windows
+    default_name = Path(model_path).stem
+    default_description = f'Imported model {default_name}'
     model_name = None
+    model_desc = None

-    if model_path.startswith(('http:','https:','ftp:')):
-        model_name = import_ckpt_model(model_path, gen, opt, completer)
-    elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path):
-        model_name = import_ckpt_model(model_path, gen, opt, completer)
-    elif os.path.isdir(model_path):
-        # Allow for a directory containing multiple models.
-        models = list(Path(model_path).rglob('*.ckpt')) + list(Path(model_path).rglob('*.safetensors'))
-        if models:
-            models = import_checkpoint_list(models, gen, opt, completer)
-            model_name = models[0] if len(models) == 1 else None
-        else:
-            model_name = import_diffuser_model(Path(model_path), gen, opt, completer)
-    elif re.match(r'^[\w.+-]+/[\w.+-]+$', model_path):
-        model_name = import_diffuser_model(model_path, gen, opt, completer)
-    else:
-        print(f'** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can\'t import.')
-
-    if not model_name:
+    if Path(model_path).is_dir() and not (Path(model_path) / 'model_index.json').exists():
+        pass
+    else:
+        model_name, model_desc = _get_model_name_and_desc(
+            gen.model_manager,
+            completer,
+            model_name=default_name,
+        )
+    imported_name = gen.model_manager.heuristic_import(
+        model_path,
+        model_name=model_name,
+        description=model_desc,
+        convert=convert,
+    )
+    if not imported_name:
+        print('** model failed to load. Aborting')
         return

-    if not _verify_load(model_name, gen):
+    if not _verify_load(imported_name, gen):
         print('** model failed to load. Discarding configuration entry')
-        gen.model_manager.del_model(model_name)
+        gen.model_manager.del_model(imported_name)
         return
     if click.confirm('Make this the default model?', default=False):
-        gen.model_manager.set_default_model(model_name)
+        gen.model_manager.set_default_model(imported_name)

     gen.model_manager.commit(opt.conf)
     completer.update_models(gen.model_manager.list_models())
     print(f'>> {model_name} successfully installed')
-def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]:
-    '''
-    Does a mass import of all the checkpoint/safetensors on a path list
-    '''
-    model_names = list()
-    choice = input('** Directory of checkpoint/safetensors models detected. Install <a>ll or <s>elected models? [a] ') or 'a'
-    do_all = choice.startswith('a')
-    if do_all:
-        config_file = _ask_for_config_file(models[0], completer, plural=True)
-        manager = gen.model_manager
-        for model in sorted(models):
-            model_name = f'{model.stem}'
-            model_description = f'Imported model {model_name}'
-            if model_name in manager.model_names():
-                print(f'** {model_name} is already imported. Skipping.')
-            elif manager.import_ckpt_model(
-                    model,
-                    config = config_file,
-                    model_name = model_name,
-                    model_description = model_description,
-                    commit_to_conf = opt.conf):
-                model_names.append(model_name)
-                print(f'>> Model {model_name} imported successfully')
-            else:
-                print(f'** Model {model} failed to import')
-    else:
-        for model in sorted(models):
-            if click.confirm(f'Import {model.stem} ?', default=True):
-                if model_name := import_ckpt_model(model, gen, opt, completer):
-                    print(f'>> Model {model.stem} imported successfully')
-                    model_names.append(model_name)
-                else:
-                    printf('** Model {model} failed to import')
-    print()
-    return model_names
-def import_diffuser_model(
-        path_or_repo: Union[Path, str], gen, _, completer
-) -> Optional[str]:
-    path_or_repo = path_or_repo.replace('\\','/') # windows
-    manager = gen.model_manager
-    default_name = Path(path_or_repo).stem
-    default_description = f'Imported model {default_name}'
-    model_name, model_description = _get_model_name_and_desc(
-        manager,
-        completer,
-        model_name=default_name,
-        model_description=default_description
-    )
-    vae = None
-    if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
-        vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
-
-    if not manager.import_diffuser_model(
-        path_or_repo,
-        model_name = model_name,
-        vae = vae,
-        description = model_description):
-        print('** model failed to import')
-        return None
-    return model_name
-def import_ckpt_model(
-        path_or_url: Union[Path, str], gen, opt, completer
-) -> Optional[str]:
-    path_or_url = path_or_url.replace('\\','/')
-    manager = gen.model_manager
-    is_a_url = str(path_or_url).startswith(('http:','https:'))
-    base_name = Path(url_attachment_name(path_or_url)).name if is_a_url else Path(path_or_url).name
-    default_name = Path(base_name).stem
-    default_description = f"Imported model {default_name}"
-
-    model_name, model_description = _get_model_name_and_desc(
-        manager,
-        completer,
-        model_name=default_name,
-        model_description=default_description
-    )
-    completer.complete_extensions(('.ckpt','.safetensors'))
-    vae = None
-    default = Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt')
-    completer.set_line(str(default))
-    done = False
-    while not done:
-        vae = input('VAE file for this model (leave blank for none): ').strip() or None
-        done = (not vae) or os.path.exists(vae)
-    completer.complete_extensions(None)
-    config_file = _ask_for_config_file(path_or_url, completer)
-    if not manager.import_ckpt_model(
-            path_or_url,
-            config = config_file,
-            vae = vae,
-            model_name = model_name,
-            model_description = model_description,
-            commit_to_conf = opt.conf,
-    ):
-        print('** model failed to import')
-        return None
     return model_name

 def _verify_load(model_name:str, gen)->bool:
     print('>> Verifying that new model loads...')
     current_model = gen.model_name
     try:
-        if not gen.model_manager.get_model(model_name):
+        if not gen.set_model(model_name):
             return False
     except Exception as e:
         print(f'** model failed to load: {str(e)}')
@@ -743,46 +636,12 @@ def _verify_load(model_name:str, gen)->bool:

 def _get_model_name_and_desc(model_manager,completer,model_name:str='',model_description:str=''):
     model_name = _get_model_name(model_manager.list_models(),completer,model_name)
+    model_description = model_description or f'Imported model {model_name}'
     completer.set_line(model_description)
     model_description = input(f'Description for this model [{model_description}]: ').strip() or model_description
     return model_name, model_description
-def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=False)->Path:
-    default = '1'
-    if re.search('inpaint',str(model_path),flags=re.IGNORECASE):
-        default = '3'
-    choices={
-        '1': 'v1-inference.yaml',
-        '2': 'v2-inference-v.yaml',
-        '3': 'v1-inpainting-inference.yaml',
-    }
-    prompt = '''What type of models are these?:
-[1] Models based on Stable Diffusion 1.X
-[2] Models based on Stable Diffusion 2.X
-[3] Inpainting models based on Stable Diffusion 1.X
-[4] Something else''' if plural else '''What type of model is this?:
-[1] A model based on Stable Diffusion 1.X
-[2] A model based on Stable Diffusion 2.X
-[3] An inpainting model based on Stable Diffusion 1.X
-[4] Something else'''
-    print(prompt)
-    choice = input(f'Your choice: [{default}] ')
-    choice = choice.strip() or default
-    if config_file := choices.get(choice,None):
-        return Path('configs','stable-diffusion',config_file)
-
-    # otherwise ask user to select
-    done = False
-    completer.complete_extensions(('.yaml','.yml'))
-    completer.set_line(str(Path(Globals.root,'configs/stable-diffusion/')))
-    while not done:
-        config_path = input('Configuration file for this model (leave blank to abort): ').strip()
-        done = not config_path or os.path.exists(config_path)
-    return config_path
-
-def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
+def convert_model(model_name_or_path: Union[Path,str], gen, opt, completer)->str:
     model_name_or_path = model_name_or_path.replace('\\','/') # windows
     manager = gen.model_manager
     ckpt_path = None
@@ -796,58 +655,32 @@ def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
             original_config_file = Path(model_info['config'])
             model_name = model_name_or_path
             model_description = model_info['description']
+            vae = model_info['vae']
         else:
             print(f'** {model_name_or_path} is not a legacy .ckpt weights file')
             return
-    elif os.path.exists(model_name_or_path):
-        original_config_file = original_config_file or _ask_for_config_file(model_name_or_path, completer)
-        if not original_config_file:
-            return
-        ckpt_path = Path(model_name_or_path)
-        model_name, model_description = _get_model_name_and_desc(
-            manager,
-            completer,
-            ckpt_path.stem,
-            f'Converted model {ckpt_path.stem}'
-        )
-    else:
-        print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file')
-        return
-
-    if not ckpt_path.is_absolute():
-        ckpt_path = Path(Globals.root,ckpt_path)
-
-    if original_config_file and not original_config_file.is_absolute():
-        original_config_file = Path(Globals.root,original_config_file)
-
-    diffuser_path = Path(Globals.root, 'models',Globals.converted_ckpts_dir,model_name)
-    if diffuser_path.exists():
-        print(f'** {model_name_or_path} is already optimized. Will not overwrite. If this is an error, please remove the directory {diffuser_path} and try again.')
-        return
-
-    vae = None
-    if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
-        vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
-
-    new_config = gen.model_manager.convert_and_import(
-        ckpt_path,
-        diffuser_path,
-        model_name=model_name,
-        model_description=model_description,
-        vae = vae,
-        original_config_file = original_config_file,
-        commit_to_conf=opt.conf,
-    )
-    if not new_config:
-        return
-
-    completer.update_models(gen.model_manager.list_models())
-    if click.confirm(f'Load optimized model {model_name}?', default=True):
-        gen.set_model(model_name)
+        if vae_repo:= ldm.invoke.model_manager.VAE_TO_REPO_ID.get(Path(vae).stem):
+            vae_repo = dict(repo_id = vae_repo)
+        else:
+            vae_repo = None
+        model_name = gen.model_manager.convert_and_import(
+            ckpt_path,
+            diffusers_path=Path(Globals.root, 'models', Globals.converted_ckpts_dir, model_name_or_path),
+            model_name=model_name,
+            model_description=model_description,
+            original_config_file=original_config_file,
+            vae = vae_repo,
+        )
+    else:
+        model_name = import_model(model_name_or_path, gen, opt, completer, convert=True)
+        if not model_name:
+            print('** Conversion failed. Aborting.')
+            return

     if click.confirm(f'Delete the original .ckpt file at {ckpt_path}?',default=False):
         ckpt_path.unlink(missing_ok=True)
         print(f'{ckpt_path} deleted')
+    return model_name

 def del_config(model_name:str, gen, opt, completer):
     current_model = gen.model_name
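Taken together, the CLI changes in this file retire the bespoke import helpers
and route every import path through ModelManager.heuristic_import(). A minimal
sketch of the consolidated flow, assuming a configured Generate instance `gen`;
the URL, flags, and config path are placeholders, not values from this commit:

    # Hypothetical driver for the new flow.
    name = gen.model_manager.heuristic_import(
        'https://example.com/my-model.safetensors',  # URL, local path, directory, or repo_id
        convert=False,                               # True converts legacy ckpts to diffusers
        commit_to_conf=opt.conf,                     # persist the new entry to models.yaml
    )
    if name and _verify_load(name, gen):
        gen.model_manager.commit(opt.conf)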

View File

@@ -16,6 +16,7 @@ import sys
 import textwrap
 import time
 import warnings
+from enum import Enum
 from pathlib import Path
 from shutil import move, rmtree
 from typing import Any, Optional, Union
@@ -33,11 +34,16 @@ from picklescan.scanner import scan_file_path
 from ldm.invoke.generator.diffusers_pipeline import \
     StableDiffusionGeneratorPipeline
-from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
-                                global_models_dir)
+from ldm.invoke.globals import (Globals, global_cache_dir)
 from ldm.util import (ask_user, download_with_resume,
                       url_attachment_name, instantiate_from_config)

+class SDLegacyType(Enum):
+    V1 = 1
+    V1_INPAINT = 2
+    V2 = 3
+    UNKNOWN = 99
+
 DEFAULT_MAX_MODELS = 2
 VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
     "vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse",
@@ -467,19 +473,6 @@ class ModelManager(object):
         for module in model.modules():
             if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                 module._orig_padding_mode = module.padding_mode
-
-        # usage statistics
-        toc = time.time()
-        print(">> Model loaded in", "%4.2fs" % (toc - tic))
-        if self._has_cuda():
-            print(
-                ">> Max VRAM used to load the model:",
-                "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9),
-                "\n>> Current VRAM usage:"
-                "%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
-            )
-
         return model, width, height, model_hash

     def _load_diffusers_model(self, mconfig):
@@ -611,13 +604,13 @@ class ModelManager(object):
             print("### Exiting InvokeAI")
             sys.exit()
         else:
-            print(">> Model scanned ok!")
+            print(">> Model scanned ok")

     def import_diffuser_model(
         self,
         repo_or_path: Union[str, Path],
         model_name: str = None,
-        description: str = None,
+        model_description: str = None,
         vae: dict = None,
         commit_to_conf: Path = None,
     ) -> bool:
@@ -657,7 +650,7 @@ class ModelManager(object):
         model_name: str = None,
         model_description: str = None,
         commit_to_conf: Path = None,
-    ) -> bool:
+    ) -> str:
         """
         Attempts to install the indicated ckpt file and returns True if successful.
@@ -674,6 +667,8 @@ class ModelManager(object):
         then these will be derived from the weight file name. If you provide a commit_to_conf
         path to the configuration file, then the new entry will be committed to the
         models.yaml file.
+
+        Return value is the name of the imported file, or None if an error occurred.
         """
         if str(weights).startswith(("http:", "https:")):
             model_name = model_name or url_attachment_name(weights)
@@ -682,9 +677,9 @@ class ModelManager(object):
         config_path = self._resolve_path(config, "configs/stable-diffusion")

         if weights_path is None or not weights_path.exists():
-            return False
+            return
         if config_path is None or not config_path.exists():
-            return False
+            return

         model_name = model_name or Path(weights).stem # note this gives ugly pathnames if used on a URL without a Content-Disposition header
         model_description = (
@@ -703,17 +698,74 @@ class ModelManager(object):
         self.add_model(model_name, new_config, True)
         if commit_to_conf:
             self.commit(commit_to_conf)
-        return True
+        return model_name
+
+    @classmethod
+    def probe_model_type(self, checkpoint: dict)->SDLegacyType:
+        '''
+        Given a pickle or safetensors model object, probes contents
+        of the object and returns an SDLegacyType indicating its
+        format. Valid return values include:
+        SDLegacyType.V1
+        SDLegacyType.V1_INPAINT
+        SDLegacyType.V2
+        SDLegacyType.UNKNOWN
+        '''
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+            return SDLegacyType.V2
+        try:
+            state_dict = checkpoint.get('state_dict') or checkpoint
+            in_channels = state_dict['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
+            if in_channels == 9:
+                return SDLegacyType.V1_INPAINT
+            elif in_channels == 4:
+                return SDLegacyType.V1
+            else:
+                return SDLegacyType.UNKNOWN
+        except KeyError:
+            return SDLegacyType.UNKNOWN
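A minimal sketch of how the new probe is driven, mirroring the call site in
heuristic_import() below; the checkpoint path is illustrative:

    # Load a raw legacy checkpoint, then classify it by its tensor shapes.
    import torch
    import safetensors.torch
    from pathlib import Path

    model_path = Path('models/ldm/stable-diffusion-v1/my-model.safetensors')  # hypothetical
    checkpoint = (safetensors.torch.load_file(model_path)
                  if model_path.suffix == '.safetensors'
                  else torch.load(model_path, map_location='cpu'))
    model_type = ModelManager.probe_model_type(checkpoint)
    if model_type == SDLegacyType.V2:
        print('v2 checkpoint: will need conversion to diffusers')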
     def heuristic_import(
         self,
         path_url_or_repo: str,
         convert: bool= False,
+        model_name: str = None,
+        description: str = None,
         commit_to_conf: Path=None,
-    ):
+    )->str:
+        '''
+        Accept a string which could be:
+           - a HF diffusers repo_id
+           - a URL pointing to a legacy .ckpt or .safetensors file
+           - a local path pointing to a legacy .ckpt or .safetensors file
+           - a local directory containing .ckpt and .safetensors files
+           - a local directory containing a diffusers model
+
+        After determining the nature of the model and downloading it
+        (if necessary), the file is probed to determine the correct
+        configuration file (if needed) and it is imported.
+
+        The model_name and/or description can be provided. If not, they will
+        be generated automatically.
+
+        If convert is true, legacy models will be converted to diffusers
+        before importing.
+
+        If commit_to_conf is provided, the newly loaded model will be written
+        to the `models.yaml` file at the indicated path. Otherwise, the changes
+        will only remain in memory.
+
+        The (potentially derived) name of the model is returned on success, or None
+        on failure. When multiple models are added from a directory, only the last
+        imported one is returned.
+        '''
         model_path: Path = None
         thing = path_url_or_repo # to save typing
+
+        print(f'>> Probing {thing} for import')
+
         if thing.startswith(('http:','https:','ftp:')):
             print(f' | {thing} appears to be a URL')
             model_path = self._resolve_path(thing, 'models/ldm/stable-diffusion-v1') # _resolve_path does a download if needed
@@ -724,17 +776,20 @@ class ModelManager(object):
         elif Path(thing).is_dir() and Path(thing, 'model_index.json').exists():
             print(f' | {thing} appears to be a diffusers file on disk')
-            model_name = self.import_diffusers_model(
+            model_name = self.import_diffuser_model(
                 thing,
                 vae=dict(repo_id='stabilityai/sd-vae-ft-mse'),
+                model_name=model_name,
+                description=description,
                 commit_to_conf=commit_to_conf
             )
         elif Path(thing).is_dir():
             print(f'>> {thing} appears to be a directory. Will scan for models to import')
             for m in list(Path(thing).rglob('*.ckpt')) + list(Path(thing).rglob('*.safetensors')):
-                self.heuristic_import(str(m), convert, commit_to_conf=commit_to_conf)
-            return
+                if model_name := self.heuristic_import(str(m), convert, commit_to_conf=commit_to_conf):
+                    print(f' >> {model_name} successfully imported')
+            return model_name

         elif re.match(r'^[\w.+-]+/[\w.+-]+$', thing):
             print(f' | {thing} appears to be a HuggingFace diffusers repo_id')
@@ -750,53 +805,49 @@ class ModelManager(object):
             return

         if model_path.stem in self.config: #already imported
-            print(' > Already imported. Skipping')
+            print(' | Already imported. Skipping')
             return

         # another round of heuristics to guess the correct config file.
-        model_config_file = None
         checkpoint = safetensors.torch.load_file(model_path) if model_path.suffix == '.safetensors' else torch.load(model_path)
+        model_type = self.probe_model_type(checkpoint)

-        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
-        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
-            print(' > SD-v2 model detected; model will be converted to diffusers format')
+        model_config_file = None
+        if model_type == SDLegacyType.V1:
+            print(' | SD-v1 model detected')
+            model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
+        elif model_type == SDLegacyType.V1_INPAINT:
+            print(' | SD-v1 inpainting model detected')
+            model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml')
+        elif model_type == SDLegacyType.V2:
+            print(' | SD-v2 model detected; model will be converted to diffusers format')
             model_config_file = Path(Globals.root,'configs/stable-diffusion/v2-inference-v.yaml')
             convert = True
-
-        if not model_config_file: # still trying
-            in_channels = None
-            try:
-                state_dict = checkpoint.get('state_dict') or checkpoint
-                in_channels = state_dict['model.diffusion_model.input_blocks.0.0.weight'].shape[1]
-                if in_channels == 9:
-                    print(' > SD-v1 inpainting model detected')
-                    model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml')
-                elif in_channels == 4:
-                    print(' > SD-v1 model detected')
-                    model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
-                else:
-                    print(f'** {thing} does not have an expected number of in_channels ({in_channels}). It will probably break when loaded.')
-            except KeyError:
-                print(f'** {thing} does not have the expected SD-v1 model fields. It will probably break when loaded.')
-            model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
+        else:
+            print(f'** {thing} is a legacy checkpoint file of unknown format. Will treat as a regular v1.X model')
+            model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
         if convert:
             diffuser_path = Path(Globals.root, 'models',Globals.converted_ckpts_dir, model_path.stem)
-            self.convert_and_import(
+            model_name = self.convert_and_import(
                 model_path,
                 diffusers_path=diffuser_path,
                 vae=dict(repo_id='stabilityai/sd-vae-ft-mse'),
+                model_name=model_name,
+                model_description=description,
                 original_config_file=model_config_file,
                 commit_to_conf=commit_to_conf,
             )
         else:
-            self.import_ckpt_model(
+            model_name = self.import_ckpt_model(
                 model_path,
                 config=model_config_file,
+                model_name=model_name,
+                model_description=description,
                 vae=str(Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt')),
                 commit_to_conf=commit_to_conf,
             )
+        return model_name
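For the directory case documented in the docstring above, an illustrative call,
assuming an already-constructed ModelManager instance `mgr` and a hypothetical
folder of mixed checkpoint files:

    # Recursively imports every *.ckpt / *.safetensors under the folder.
    last_name = mgr.heuristic_import(
        '/downloads/checkpoints',                    # hypothetical directory
        convert=True,                                # convert each legacy file to diffusers
        commit_to_conf=Path('configs/models.yaml'),  # persist each entry
    )
    # Per the docstring, only the last successfully imported name is returned.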

     def convert_and_import(
         self,
@@ -812,6 +863,10 @@ class ModelManager(object):
         Convert a legacy ckpt weights file to diffuser model and import
         into models.yaml.
         """
+        ckpt_path = self._resolve_path(ckpt_path, 'models/ldm/stable-diffusion-v1')
+        if original_config_file:
+            original_config_file = self._resolve_path(original_config_file, 'configs/stable-diffusion')
+
         new_config = None

         from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
@@ -857,7 +912,7 @@ class ModelManager(object):
                 "** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
             )
-        return new_config
+        return model_name

     def search_models(self, search_folder):
         print(f">> Finding Models In: {search_folder}")