fully functional and ready for review

- quashed multiple bugs in model conversion and importing
- found old issue in handling of resume of interrupted downloads
- will require extensive testing
This commit is contained in:
Lincoln Stein 2023-02-16 03:22:25 -05:00
parent 07be605dcb
commit b1341bc611
7 changed files with 91 additions and 50 deletions

View File

@ -142,12 +142,11 @@ def main():
report_model_error(opt, e)
# try to autoconvert new models
# autoimport new .ckpt files
if path := opt.autoimport:
gen.model_manager.heuristic_import(str(path), convert=False, commit_to_conf=opt.conf)
if path := opt.autoconvert:
gen.model_manager.autoconvert_weights(
conf_path=opt.conf,
weights_directory=path,
)
gen.model_manager.heuristic_import(str(path), convert=True, commit_to_conf=opt.conf)
# web server loops forever
if opt.web or opt.gui:
@ -581,7 +580,7 @@ def import_model(model_path: str, gen, opt, completer):
(3) a huggingface repository id; or (4) a local directory containing a
diffusers model.
"""
model.path = model_path.replace('\\','/') # windows
model_path = model_path.replace('\\','/') # windows
model_name = None
if model_path.startswith(('http:','https:','ftp:')):
@ -653,7 +652,7 @@ def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]:
print(f'>> Model {model.stem} imported successfully')
model_names.append(model_name)
else:
printf('** Model {model} failed to import')
print(f'** Model {model} failed to import')
print()
return model_names
@ -709,7 +708,8 @@ def import_ckpt_model(
vae = input('VAE file for this model (leave blank for none): ').strip() or None
done = (not vae) or os.path.exists(vae)
completer.complete_extensions(None)
config_file = _ask_for_config_file(path_or_url, completer)
if not manager.import_ckpt_model(
path_or_url,
config = config_file,
@ -786,7 +786,8 @@ def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
model_name_or_path = model_name_or_path.replace('\\','/') # windows
manager = gen.model_manager
ckpt_path = None
original_config_file = None
if model_name_or_path == gen.model_name:
print("** Can't convert the active model. !switch to another model first. **")
return

View File

@ -527,11 +527,17 @@ class Args(object):
default=False,
help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
)
model_group.add_argument(
'--autoimport',
default=None,
type=str,
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
)
model_group.add_argument(
'--autoconvert',
default=None,
type=str,
help='Check the indicated directory for .ckpt weights files at startup and import as optimized diffuser models',
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
)
model_group.add_argument(
'--patchmatch',

View File

@ -31,10 +31,8 @@ from transformers import (
)
import invokeai.configs as configs
from ldm.invoke.config.model_install import (
download_from_hf,
select_and_download_models,
)
from ldm.invoke.config.model_install_backend import download_from_hf
from ldm.invoke.config.model_install import select_and_download_models
from ldm.invoke.globals import Globals, global_config_dir
from ldm.invoke.readline import generic_completer

View File

@ -100,7 +100,7 @@ class addModelsForm(npyscreen.FormMultiPageAction):
)
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="== UNINSTALLED STARTER MODELS (recommended models selected) ==",
name="== STARTER MODELS (recommended ones selected) ==",
value="Select from a starter set of Stable Diffusion models from HuggingFace:",
begin_entry_at=2,
editable=False,
@ -221,6 +221,7 @@ class addModelsForm(npyscreen.FormMultiPageAction):
'''
# starter models to install/remove
starter_models = dict(map(lambda x: (self.starter_model_list[x], True), self.models_selected.value))
self.parentApp.purge_deleted_models=False
if hasattr(self,'previously_installed_models'):
unchecked = [
self.previously_installed_models.values[x]
@ -243,7 +244,7 @@ class addModelsForm(npyscreen.FormMultiPageAction):
# URLs and the like
self.parentApp.import_model_paths = self.import_model_paths.value.split()
self.parentApp.convert_to_diffusers = self.convert_models.value == 1
self.parentApp.convert_to_diffusers = self.convert_models.value[0] == 1
# big chunk of dead code
# was intended to be a status area in which output of installation steps (including tqdm) was logged in real time

View File

@ -69,6 +69,9 @@ def install_requested_models(
config_file_path: Path = None,
):
config_file_path=config_file_path or default_config_file()
if not config_file_path.exists():
open(config_file_path,'w')
model_manager= ModelManager(OmegaConf.load(config_file_path),precision=precision)
if remove_models and len(remove_models) > 0:
@ -84,12 +87,20 @@ def install_requested_models(
models=install_initial_models,
access_token=None,
precision=precision,
) # for historical reasons, we don't use model manager here
) # FIX: for historical reasons, we don't use model manager here
update_config_file(successfully_downloaded, config_file_path)
if len(successfully_downloaded) < len(install_initial_models):
print("** Some of the model downloads were not successful")
if external_models and len(external_models)>0:
# due to above, we have to reload the model manager because conf file
# was changed behind its back
model_manager= ModelManager(OmegaConf.load(config_file_path),precision=precision)
external_models = external_models or list()
if scan_directory:
external_models.append(str(scan_directory))
if len(external_models)>0:
print("== INSTALLING EXTERNAL MODELS ==")
for path_url_or_repo in external_models:
try:
@ -102,6 +113,18 @@ def install_requested_models(
sys.exit(-1)
except Exception:
pass
if scan_at_startup and scan_directory.is_dir():
argument = '--autoconvert' if convert_to_diffusers else '--autoimport'
initfile = Path(Globals.root, Globals.initfile)
replacement = Path(Globals.root, f'{Globals.initfile}.new')
with open(initfile,'r') as input:
with open(replacement,'w') as output:
while line := input.readline():
if not line.startswith(argument):
output.writelines([line])
output.writelines([f'{argument} {str(scan_directory)}'])
os.replace(replacement,initfile)
# -------------------------------------
def yes_or_no(prompt: str, default_yes=True):

View File

@ -707,21 +707,19 @@ class ModelManager(object):
convert: bool= False,
commit_to_conf: Path=None,
):
model_path = None
model_path: Path = None
thing = path_url_or_repo # to save typing
print(f'here i am; thing={thing}, convert={convert}')
if thing.startswith(('http:','https:','ftp:')):
print(f'* {thing} appears to be a URL')
print(f'>> {thing} appears to be a URL')
model_path = self._resolve_path(thing, 'models/ldm/stable-diffusion-v1') # _resolve_path does a download if needed
elif Path(thing).is_file() and thing.endswith(('.ckpt','.safetensors')):
print(f'* {thing} appears to be a checkpoint file on disk')
print(f'>> {thing} appears to be a checkpoint file on disk')
model_path = self._resolve_path(thing, 'models/ldm/stable-diffusion-v1')
elif Path(thing).is_dir() and Path(thing, 'model_index.json').exists():
print(f'* {thing} appears to be a diffusers file on disk')
print(f'>> {thing} appears to be a diffusers file on disk')
model_name = self.import_diffusers_model(
thing,
vae=dict(repo_id='stabilityai/sd-vae-ft-mse'),
@ -729,39 +727,44 @@ class ModelManager(object):
)
elif Path(thing).is_dir():
print(f'* {thing} appears to be a directory. Will scan for models to import')
print(f'>> {thing} appears to be a directory. Will scan for models to import')
for m in list(Path(thing).rglob('*.ckpt')) + list(Path(thing).rglob('*.safetensors')):
print('***',m)
self.heuristic_import(str(m), convert, commit_to_conf=commit_to_conf)
return
elif re.match(r'^[\w.+-]+/[\w.+-]+$', thing):
print(f'* {thing} appears to be a HuggingFace diffusers repo_id')
print(f'>> {thing} appears to be a HuggingFace diffusers repo_id')
model_name = self.import_diffuser_model(thing, commit_to_conf=commit_to_conf)
pipeline,_,_,_ = self._load_diffusers_model(self.config[model_name])
else:
print(f"* {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
print(f">> {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
# Model_path is set in the event of a legacy checkpoint file.
# If not set, we're all done
if not model_path:
return
if model_path.stem in self.config: #already imported
return
# another round of heuristics to guess the correct config file.
model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml')
model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
checkpoint = safetensors.torch.load_file(model_path) if model_path.suffix == '.safetensors' else torch.load(model_path)
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
print(f'* {thing} appears to be an SD-v2 model; model will be converted to diffusers format')
print(f'>> {thing} appears to be an SD-v2 model; model will be converted to diffusers format')
model_config_file = Path(Globals.root,'configs/stable-diffusion/v2-inference-v.yaml')
convert = True
elif re.search('inpaint', str(model_path), flags=re.IGNORECASE):
print(f'* {thing} appears to be an SD-v1 inpainting model')
print(f'>> {thing} appears to be an SD-v1 inpainting model')
model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml')
else:
print(f'* {thing} appears to be an SD-v1 model')
print(f'>> {thing} appears to be an SD-v1 model')
if convert:
diffuser_path = Path(Globals.root, 'models',Globals.converted_ckpts_dir, model_path.stem)
@ -776,10 +779,12 @@ class ModelManager(object):
self.import_ckpt_model(
model_path,
config=model_config_file,
vae=Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt'),
vae=str(Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt')),
commit_to_conf=commit_to_conf,
)
# this is a defunct method, superseded by heuristic_import()
# left here during transition
def autoconvert_weights (
self,
conf_path: Path,
@ -799,7 +804,7 @@ class ModelManager(object):
ckpt_files = dict()
for root, dirs, files in os.walk(weights_directory):
for f in files:
if not f.endswith(".ckpt"):
if not f.endswith((".ckpt",".safetensors")):
continue
basename = Path(f).stem
dest = Path(dest_directory, basename)

View File

@ -306,8 +306,12 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
dest/filename
:param access_token: Access token to access this resource
'''
resp = requests.get(url, stream=True)
total = int(resp.headers.get("content-length", 0))
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
open_mode = "wb"
exist_size = 0
resp = requests.get(url, header, stream=True)
content_length = int(resp.headers.get("content-length", 0))
if dest.is_dir():
try:
@ -318,39 +322,42 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
else:
dest.parent.mkdir(parents=True, exist_ok=True)
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
open_mode = "wb"
exist_size = 0
if dest.exists():
exist_size = dest.stat().st_size
header["Range"] = f"bytes={exist_size}-"
open_mode = "ab"
resp = requests.get(url, headers=header, stream=True) # new request with range
if exist_size > content_length:
print('* corrupt existing file found. re-downloading')
os.remove(dest)
exist_size = 0
if (
resp.status_code == 416
): # "range not satisfiable", which means nothing to return
resp.status_code == 416 or exist_size == content_length
):
print(f"* {dest}: complete file found. Skipping.")
return dest
elif resp.status_code == 206 or exist_size > 0:
print(f"* {dest}: partial file found. Resuming...")
elif resp.status_code != 200:
print(f"** An error occurred during downloading {dest}: {resp.reason}")
elif exist_size > 0:
print(f"* {dest}: partial file found. Resuming...")
else:
print(f"* {dest}: Downloading...")
try:
if total < 2000:
if content_length < 2000:
print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
return None
with open(dest, open_mode) as file, tqdm(
desc=str(dest),
initial=exist_size,
total=total + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
desc=str(dest),
initial=exist_size,
total=content_length,
unit="iB",
unit_scale=True,
unit_divisor=1000,
) as bar:
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)