Mirror of https://github.com/invoke-ai/InvokeAI
fully functional and ready for review

- quashed multiple bugs in model conversion and importing
- found an old issue in the handling of resuming interrupted downloads
- will require extensive testing
Commit b1341bc611 (parent 07be605dcb)
@@ -142,12 +142,11 @@ def main():
         report_model_error(opt, e)
 
     # try to autoconvert new models
-    # autoimport new .ckpt files
+    if path := opt.autoimport:
+        gen.model_manager.heuristic_import(str(path), convert=False, commit_to_conf=opt.conf)
+
     if path := opt.autoconvert:
-        gen.model_manager.autoconvert_weights(
-            conf_path=opt.conf,
-            weights_directory=path,
-        )
+        gen.model_manager.heuristic_import(str(path), convert=True, commit_to_conf=opt.conf)
 
     # web server loops forever
     if opt.web or opt.gui:
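Note on the new startup flow: --autoimport and --autoconvert now share one entry point, heuristic_import(), and differ only in the convert flag. A minimal sketch of the dispatch, with a stub class and a plain dict standing in for the real ModelManager and parsed options:

    # Both flags funnel into the same importer; only `convert` differs.
    class StubManager:                        # illustrative stand-in only
        def heuristic_import(self, path, convert=False, commit_to_conf=None):
            print(f"import {path!r} convert={convert} conf={commit_to_conf!r}")

    def startup_imports(opt: dict, manager):
        if path := opt.get("autoimport"):     # import checkpoints as-is
            manager.heuristic_import(str(path), convert=False, commit_to_conf=opt.get("conf"))
        if path := opt.get("autoconvert"):    # import and convert to diffusers
            manager.heuristic_import(str(path), convert=True, commit_to_conf=opt.get("conf"))

    startup_imports({"autoconvert": "/data/weights", "conf": "models.yaml"}, StubManager())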
@@ -581,7 +580,7 @@ def import_model(model_path: str, gen, opt, completer):
     (3) a huggingface repository id; or (4) a local directory containing a
     diffusers model.
     """
-    model.path = model_path.replace('\\','/') # windows
+    model_path = model_path.replace('\\','/') # windows
     model_name = None
 
     if model_path.startswith(('http:','https:','ftp:')):
@@ -653,7 +652,7 @@ def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]:
             print(f'>> Model {model.stem} imported successfully')
             model_names.append(model_name)
         else:
-            printf('** Model {model} failed to import')
+            print(f'** Model {model} failed to import')
     print()
     return model_names
 
@@ -709,7 +708,8 @@ def import_ckpt_model(
         vae = input('VAE file for this model (leave blank for none): ').strip() or None
         done = (not vae) or os.path.exists(vae)
     completer.complete_extensions(None)
+    config_file = _ask_for_config_file(path_or_url, completer)
 
     if not manager.import_ckpt_model(
         path_or_url,
         config = config_file,
@@ -786,7 +786,8 @@ def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
     model_name_or_path = model_name_or_path.replace('\\','/') # windows
     manager = gen.model_manager
     ckpt_path = None
+    original_config_file = None
 
     if model_name_or_path == gen.model_name:
         print("** Can't convert the active model. !switch to another model first. **")
         return
@@ -527,11 +527,17 @@ class Args(object):
             default=False,
             help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
         )
+        model_group.add_argument(
+            '--autoimport',
+            default=None,
+            type=str,
+            help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
+        )
         model_group.add_argument(
             '--autoconvert',
             default=None,
             type=str,
-            help='Check the indicated directory for .ckpt weights files at startup and import as optimized diffuser models',
+            help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
         )
         model_group.add_argument(
             '--patchmatch',
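For reference, both new options parse as plain strings; a self-contained argparse sketch (separate from the Args class above, help text abbreviated) shows the resulting behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--autoimport', default=None, type=str,
                        help='directory to scan at startup and import directly')
    parser.add_argument('--autoconvert', default=None, type=str,
                        help='directory to scan at startup and import as diffusers')

    opt = parser.parse_args(['--autoimport', '/data/weights'])
    print(opt.autoimport)     # '/data/weights'
    print(opt.autoconvert)    # None (flag not given)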
@@ -31,10 +31,8 @@ from transformers import (
 )
 
 import invokeai.configs as configs
-from ldm.invoke.config.model_install import (
-    download_from_hf,
-    select_and_download_models,
-)
+from ldm.invoke.config.model_install_backend import download_from_hf
+from ldm.invoke.config.model_install import select_and_download_models
 from ldm.invoke.globals import Globals, global_config_dir
 from ldm.invoke.readline import generic_completer
 
@@ -100,7 +100,7 @@ class addModelsForm(npyscreen.FormMultiPageAction):
         )
         self.add_widget_intelligent(
             npyscreen.TitleFixedText,
-            name="== UNINSTALLED STARTER MODELS (recommended models selected) ==",
+            name="== STARTER MODELS (recommended ones selected) ==",
             value="Select from a starter set of Stable Diffusion models from HuggingFace:",
             begin_entry_at=2,
             editable=False,
@@ -221,6 +221,7 @@ class addModelsForm(npyscreen.FormMultiPageAction):
         '''
         # starter models to install/remove
         starter_models = dict(map(lambda x: (self.starter_model_list[x], True), self.models_selected.value))
+        self.parentApp.purge_deleted_models=False
         if hasattr(self,'previously_installed_models'):
             unchecked = [
                 self.previously_installed_models.values[x]
@@ -243,7 +244,7 @@ class addModelsForm(npyscreen.FormMultiPageAction):
 
         # URLs and the like
         self.parentApp.import_model_paths = self.import_model_paths.value.split()
-        self.parentApp.convert_to_diffusers = self.convert_models.value == 1
+        self.parentApp.convert_to_diffusers = self.convert_models.value[0] == 1
 
         # big chunk of dead code
         # was intended to be a status area in which output of installation steps (including tqdm) was logged in real time
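The one-character fix to convert_to_diffusers matters because npyscreen selection widgets such as SelectOne report their state as a list of selected indices rather than a scalar, so the old comparison was always False and conversion was silently never requested:

    value = [1]            # what a single-select widget's .value holds
    print(value == 1)      # False: the old, broken comparison
    print(value[0] == 1)   # True: compare the selected index itself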
@@ -69,6 +69,9 @@ def install_requested_models(
     config_file_path: Path = None,
 ):
     config_file_path=config_file_path or default_config_file()
+    if not config_file_path.exists():
+        open(config_file_path,'w')
+
     model_manager= ModelManager(OmegaConf.load(config_file_path),precision=precision)
 
     if remove_models and len(remove_models) > 0:
@@ -84,12 +87,20 @@ def install_requested_models(
             models=install_initial_models,
             access_token=None,
             precision=precision,
-        ) # for historical reasons, we don't use model manager here
+        ) # FIX: for historical reasons, we don't use model manager here
         update_config_file(successfully_downloaded, config_file_path)
         if len(successfully_downloaded) < len(install_initial_models):
             print("** Some of the model downloads were not successful")
 
-    if external_models and len(external_models)>0:
+    # due to above, we have to reload the model manager because conf file
+    # was changed behind its back
+    model_manager= ModelManager(OmegaConf.load(config_file_path),precision=precision)
+
+    external_models = external_models or list()
+    if scan_directory:
+        external_models.append(str(scan_directory))
+
+    if len(external_models)>0:
         print("== INSTALLING EXTERNAL MODELS ==")
         for path_url_or_repo in external_models:
             try:
@@ -102,6 +113,18 @@ def install_requested_models(
             sys.exit(-1)
         except Exception:
             pass
 
+    if scan_at_startup and scan_directory.is_dir():
+        argument = '--autoconvert' if convert_to_diffusers else '--autoimport'
+        initfile = Path(Globals.root, Globals.initfile)
+        replacement = Path(Globals.root, f'{Globals.initfile}.new')
+        with open(initfile,'r') as input:
+            with open(replacement,'w') as output:
+                while line := input.readline():
+                    if not line.startswith(argument):
+                        output.writelines([line])
+                output.writelines([f'{argument} {str(scan_directory)}'])
+        os.replace(replacement,initfile)
+
 # -------------------------------------
 def yes_or_no(prompt: str, default_yes=True):
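The block appended to install_requested_models() persists the scan directory by rewriting the launch init file: every line except a stale --autoimport/--autoconvert setting is copied to a sibling file, the fresh setting is appended, and the copy is swapped into place. A standalone sketch of the same rewrite-and-swap pattern (function name and paths are illustrative):

    import os
    from pathlib import Path

    def persist_flag(initfile: Path, argument: str, value: str) -> None:
        replacement = initfile.with_suffix('.new')
        with open(initfile) as src, open(replacement, 'w') as dst:
            for line in src:
                if not line.startswith(argument):   # drop the stale setting
                    dst.write(line)
            dst.write(f'{argument} {value}\n')      # append the fresh one
        os.replace(replacement, initfile)           # atomic swap on the same filesystem

    # usage (hypothetical path):
    # persist_flag(Path('invokeai.init'), '--autoimport', '/data/weights')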
@@ -707,21 +707,19 @@ class ModelManager(object):
         convert: bool= False,
         commit_to_conf: Path=None,
     ):
-        model_path = None
+        model_path: Path = None
         thing = path_url_or_repo # to save typing
 
-        print(f'here i am; thing={thing}, convert={convert}')
-
         if thing.startswith(('http:','https:','ftp:')):
-            print(f'* {thing} appears to be a URL')
+            print(f'>> {thing} appears to be a URL')
             model_path = self._resolve_path(thing, 'models/ldm/stable-diffusion-v1') # _resolve_path does a download if needed
 
         elif Path(thing).is_file() and thing.endswith(('.ckpt','.safetensors')):
-            print(f'* {thing} appears to be a checkpoint file on disk')
+            print(f'>> {thing} appears to be a checkpoint file on disk')
             model_path = self._resolve_path(thing, 'models/ldm/stable-diffusion-v1')
 
         elif Path(thing).is_dir() and Path(thing, 'model_index.json').exists():
-            print(f'* {thing} appears to be a diffusers file on disk')
+            print(f'>> {thing} appears to be a diffusers file on disk')
             model_name = self.import_diffusers_model(
                 thing,
                 vae=dict(repo_id='stabilityai/sd-vae-ft-mse'),
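heuristic_import() classifies its argument through a chain of cheap tests whose order matters: URL prefix first, then an on-disk checkpoint, then a diffusers directory (marked by model_index.json), then a plain directory to recurse into, and finally (in the next hunk) an owner/name pattern treated as a HuggingFace repo_id. A minimal sketch of the same classification, returning a label instead of importing:

    import re
    from pathlib import Path

    def classify(thing: str) -> str:
        if thing.startswith(('http:', 'https:', 'ftp:')):
            return 'url'
        if Path(thing).is_file() and thing.endswith(('.ckpt', '.safetensors')):
            return 'checkpoint file'
        if Path(thing).is_dir() and Path(thing, 'model_index.json').exists():
            return 'diffusers directory'
        if Path(thing).is_dir():
            return 'directory to scan'
        if re.match(r'^[\w.+-]+/[\w.+-]+$', thing):
            return 'HuggingFace repo_id'
        return 'unknown'

    print(classify('stabilityai/sd-vae-ft-mse'))   # HuggingFace repo_id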
@@ -729,39 +727,44 @@ class ModelManager(object):
             )
 
         elif Path(thing).is_dir():
-            print(f'* {thing} appears to be a directory. Will scan for models to import')
+            print(f'>> {thing} appears to be a directory. Will scan for models to import')
             for m in list(Path(thing).rglob('*.ckpt')) + list(Path(thing).rglob('*.safetensors')):
                 print('***',m)
                 self.heuristic_import(str(m), convert, commit_to_conf=commit_to_conf)
             return
 
         elif re.match(r'^[\w.+-]+/[\w.+-]+$', thing):
-            print(f'* {thing} appears to be a HuggingFace diffusers repo_id')
+            print(f'>> {thing} appears to be a HuggingFace diffusers repo_id')
             model_name = self.import_diffuser_model(thing, commit_to_conf=commit_to_conf)
             pipeline,_,_,_ = self._load_diffusers_model(self.config[model_name])
 
         else:
-            print(f"* {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
+            print(f">> {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
 
         # Model_path is set in the event of a legacy checkpoint file.
         # If not set, we're all done
         if not model_path:
             return
 
+        if model_path.stem in self.config: #already imported
+            return
+
         # another round of heuristics to guess the correct config file.
-        model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml')
+        model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
 
         checkpoint = safetensors.torch.load_file(model_path) if model_path.suffix == '.safetensors' else torch.load(model_path)
         key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
         if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
-            print(f'* {thing} appears to be an SD-v2 model; model will be converted to diffusers format')
+            print(f'>> {thing} appears to be an SD-v2 model; model will be converted to diffusers format')
             model_config_file = Path(Globals.root,'configs/stable-diffusion/v2-inference-v.yaml')
             convert = True
 
         elif re.search('inpaint', str(model_path), flags=re.IGNORECASE):
-            print(f'* {thing} appears to be an SD-v1 inpainting model')
+            print(f'>> {thing} appears to be an SD-v1 inpainting model')
             model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml')
 
         else:
-            print(f'* {thing} appears to be an SD-v1 model')
+            print(f'>> {thing} appears to be an SD-v1 model')
 
         if convert:
             diffuser_path = Path(Globals.root, 'models',Globals.converted_ckpts_dir, model_path.stem)
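The SD-v2 probe works because the two families embed text differently: SD-v1 feeds 768-dimensional CLIP embeddings into the U-Net's cross-attention, while SD-v2 uses 1024-dimensional OpenCLIP embeddings, so the width of any attn2.to_k weight identifies the family. A minimal sketch of the same check, assuming ckpt_path points at a local checkpoint (the state_dict unwrapping covers the common nesting):

    from pathlib import Path
    import safetensors.torch
    import torch

    def guess_sd_version(ckpt_path: Path) -> str:
        if ckpt_path.suffix == '.safetensors':
            sd = safetensors.torch.load_file(ckpt_path)
        else:
            sd = torch.load(ckpt_path, map_location='cpu')
            sd = sd.get('state_dict', sd)          # many checkpoints nest weights
        key = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        return 'v2' if key in sd and sd[key].shape[-1] == 1024 else 'v1'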
@@ -776,10 +779,12 @@ class ModelManager(object):
             self.import_ckpt_model(
                 model_path,
                 config=model_config_file,
-                vae=Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt'),
+                vae=str(Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt')),
                 commit_to_conf=commit_to_conf,
             )
 
+    # this is a defunct method, superseded by heuristic_import()
+    # left here during transition
     def autoconvert_weights (
         self,
         conf_path: Path,
@@ -799,7 +804,7 @@ class ModelManager(object):
         ckpt_files = dict()
         for root, dirs, files in os.walk(weights_directory):
             for f in files:
-                if not f.endswith(".ckpt"):
+                if not f.endswith((".ckpt",".safetensors")):
                     continue
                 basename = Path(f).stem
                 dest = Path(dest_directory, basename)
--- ldm/util.py
+++ ldm/util.py
@@ -306,8 +306,12 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
     dest/filename
     :param access_token: Access token to access this resource
     '''
-    resp = requests.get(url, stream=True)
-    total = int(resp.headers.get("content-length", 0))
+    header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
+    open_mode = "wb"
+    exist_size = 0
+
+    resp = requests.get(url, header, stream=True)
+    content_length = int(resp.headers.get("content-length", 0))
 
     if dest.is_dir():
         try:
@@ -318,39 +322,42 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
     else:
         dest.parent.mkdir(parents=True, exist_ok=True)
 
-    header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
-    open_mode = "wb"
-    exist_size = 0
-
     if dest.exists():
         exist_size = dest.stat().st_size
         header["Range"] = f"bytes={exist_size}-"
         open_mode = "ab"
+        resp = requests.get(url, headers=header, stream=True) # new request with range
+
+    if exist_size > content_length:
+        print('* corrupt existing file found. re-downloading')
+        os.remove(dest)
+        exist_size = 0
 
     if (
-        resp.status_code == 416
-    ): # "range not satisfiable", which means nothing to return
+        resp.status_code == 416 or exist_size == content_length
+    ):
         print(f"* {dest}: complete file found. Skipping.")
         return dest
+    elif resp.status_code == 206 or exist_size > 0:
+        print(f"* {dest}: partial file found. Resuming...")
     elif resp.status_code != 200:
         print(f"** An error occurred during downloading {dest}: {resp.reason}")
-    elif exist_size > 0:
-        print(f"* {dest}: partial file found. Resuming...")
     else:
         print(f"* {dest}: Downloading...")
 
     try:
-        if total < 2000:
+        if content_length < 2000:
             print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
             return None
 
         with open(dest, open_mode) as file, tqdm(
             desc=str(dest),
             initial=exist_size,
-            total=total + exist_size,
+            total=content_length,
             unit="iB",
             unit_scale=True,
             unit_divisor=1000,
         ) as bar:
             for data in resp.iter_content(chunk_size=1024):
                 size = file.write(data)
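Taken together, the two ldm/util.py hunks restructure download_with_resume() around the HTTP Range protocol: fetch the content length, re-request with a Range header when a partial file already exists, treat 416 (range not satisfiable) or an exact size match as complete, treat 206 (Partial Content) as resumable, and discard a local file larger than the server's copy as corrupt. A condensed sketch of that protocol, assuming a server that honors Range requests (progress bar and error handling omitted):

    import os
    from pathlib import Path
    import requests

    def fetch_with_resume(url: str, dest: Path) -> Path:
        resp = requests.get(url, stream=True)              # probe full length
        content_length = int(resp.headers.get("content-length", 0))
        exist_size = dest.stat().st_size if dest.exists() else 0
        if exist_size > content_length:                    # corrupt local file
            os.remove(dest)
            exist_size = 0
        if content_length and exist_size == content_length:
            return dest                                    # already complete
        if exist_size:                                     # re-request the missing tail
            resp = requests.get(url, headers={"Range": f"bytes={exist_size}-"}, stream=True)
        mode = "ab" if resp.status_code == 206 else "wb"   # 206 = Partial Content
        with open(dest, mode) as f:
            for chunk in resp.iter_content(chunk_size=1024):
                f.write(chunk)
        return dest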