mirror of https://github.com/invoke-ai/InvokeAI
fully functional and ready for review

- quashed multiple bugs in model conversion and importing
- found old issue in handling of resume of interrupted downloads
- will require extensive testing
parent 07be605dcb
commit b1341bc611
```diff
@@ -142,12 +142,11 @@ def main():
         report_model_error(opt, e)

-    # try to autoconvert new models
+    # autoimport new .ckpt files
+    if path := opt.autoimport:
+        gen.model_manager.heuristic_import(str(path), convert=False, commit_to_conf=opt.conf)

     if path := opt.autoconvert:
-        gen.model_manager.autoconvert_weights(
-            conf_path=opt.conf,
-            weights_directory=path,
-        )
+        gen.model_manager.heuristic_import(str(path), convert=True, commit_to_conf=opt.conf)

     # web server loops forever
     if opt.web or opt.gui:
```
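For orientation, here is a minimal runnable sketch (not InvokeAI source; the `Opt` stub stands in for the parsed CLI namespace) of the startup flow this hunk produces: both flags now funnel into the same `heuristic_import()` entry point and differ only in the `convert` argument.

```python
# Illustrative sketch: --autoimport and --autoconvert share one code path.
from dataclasses import dataclass

@dataclass
class Opt:                       # stand-in for the parsed CLI namespace
    autoimport: str = None       # directory to import as-is
    autoconvert: str = None      # directory to import and convert to diffusers
    conf: str = 'models.yaml'

def startup_imports(opt, model_manager):
    # import new weights files without conversion
    if path := opt.autoimport:
        model_manager.heuristic_import(str(path), convert=False, commit_to_conf=opt.conf)
    # import and convert to optimized diffusers format
    if path := opt.autoconvert:
        model_manager.heuristic_import(str(path), convert=True, commit_to_conf=opt.conf)
```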
```diff
@@ -581,7 +580,7 @@ def import_model(model_path: str, gen, opt, completer):
     (3) a huggingface repository id; or (4) a local directory containing a
     diffusers model.
     """
-    model.path = model_path.replace('\\','/') # windows
+    model_path = model_path.replace('\\','/') # windows
     model_name = None

     if model_path.startswith(('http:','https:','ftp:')):
```
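The fix above corrects a typo: `model.path` assigned an attribute on an undefined `model` object instead of rebinding the local `model_path`. A small illustration of why the backslash normalization matters (the sample path is hypothetical):

```python
# Windows shells and tab-completers hand back backslash paths; normalizing to
# forward slashes keeps the later scheme checks and regexes uniform.
model_path = r'C:\Users\me\models\analog-diffusion-1.0.ckpt'   # hypothetical path
model_path = model_path.replace('\\', '/')
assert model_path == 'C:/Users/me/models/analog-diffusion-1.0.ckpt'
assert not model_path.startswith(('http:', 'https:', 'ftp:'))
```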
```diff
@@ -653,7 +652,7 @@ def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]:
             print(f'>> Model {model.stem} imported successfully')
             model_names.append(model_name)
         else:
-            printf('** Model {model} failed to import')
+            print(f'** Model {model} failed to import')
     print()
     return model_names
```
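The removed line had two defects at once: `printf` is not a Python builtin (it raises NameError), and the string lacked the `f` prefix, so `{model}` would have printed literally even if the call had succeeded. Illustration:

```python
model = 'analog-diffusion-1.0'               # hypothetical model name
print('** Model {model} failed to import')   # no f-prefix: braces print verbatim
print(f'** Model {model} failed to import')  # f-string: interpolates the name
```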
```diff
@@ -709,7 +708,8 @@ def import_ckpt_model(
         vae = input('VAE file for this model (leave blank for none): ').strip() or None
         done = (not vae) or os.path.exists(vae)
     completer.complete_extensions(None)

+    config_file = _ask_for_config_file(path_or_url, completer)
     if not manager.import_ckpt_model(
         path_or_url,
         config = config_file,
```
```diff
@@ -786,7 +786,8 @@ def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
     model_name_or_path = model_name_or_path.replace('\\','/') # windows
     manager = gen.model_manager
     ckpt_path = None
+    original_config_file = None

     if model_name_or_path == gen.model_name:
         print("** Can't convert the active model. !switch to another model first. **")
         return
```
```diff
@@ -527,11 +527,17 @@ class Args(object):
             default=False,
             help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
         )
+        model_group.add_argument(
+            '--autoimport',
+            default=None,
+            type=str,
+            help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
+        )
         model_group.add_argument(
             '--autoconvert',
             default=None,
             type=str,
-            help='Check the indicated directory for .ckpt weights files at startup and import as optimized diffuser models',
+            help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
         )
         model_group.add_argument(
             '--patchmatch',
```
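A self-contained sketch of the two options as this hunk defines them, using a plain argparse group in place of the `Args` wrapper class:

```python
import argparse

parser = argparse.ArgumentParser()
model_group = parser.add_argument_group('model')
model_group.add_argument(
    '--autoimport', default=None, type=str,
    help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
)
model_group.add_argument(
    '--autoconvert', default=None, type=str,
    help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
)

opt = parser.parse_args(['--autoimport', '~/new_models'])
assert opt.autoimport == '~/new_models' and opt.autoconvert is None
```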
```diff
@@ -31,10 +31,8 @@ from transformers import (
 )

 import invokeai.configs as configs
-from ldm.invoke.config.model_install import (
-    download_from_hf,
-    select_and_download_models,
-)
+from ldm.invoke.config.model_install_backend import download_from_hf
+from ldm.invoke.config.model_install import select_and_download_models
 from ldm.invoke.globals import Globals, global_config_dir
 from ldm.invoke.readline import generic_completer
```
```diff
@@ -100,7 +100,7 @@ class addModelsForm(npyscreen.FormMultiPageAction):
         )
         self.add_widget_intelligent(
             npyscreen.TitleFixedText,
-            name="== UNINSTALLED STARTER MODELS (recommended models selected) ==",
+            name="== STARTER MODELS (recommended ones selected) ==",
             value="Select from a starter set of Stable Diffusion models from HuggingFace:",
             begin_entry_at=2,
             editable=False,
```
```diff
@@ -221,6 +221,7 @@ class addModelsForm(npyscreen.FormMultiPageAction):
         '''
         # starter models to install/remove
         starter_models = dict(map(lambda x: (self.starter_model_list[x], True), self.models_selected.value))
+        self.parentApp.purge_deleted_models=False
         if hasattr(self,'previously_installed_models'):
             unchecked = [
                 self.previously_installed_models.values[x]
```
```diff
@@ -243,7 +244,7 @@ class addModelsForm(npyscreen.FormMultiPageAction):

         # URLs and the like
         self.parentApp.import_model_paths = self.import_model_paths.value.split()
-        self.parentApp.convert_to_diffusers = self.convert_models.value == 1
+        self.parentApp.convert_to_diffusers = self.convert_models.value[0] == 1

         # big chunk of dead code
         # was intended to be a status area in which output of installation steps (including tqdm) was logged in real time
```
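The appended `[0]` is needed because npyscreen selection widgets expose `.value` as a list of selected indices, not a scalar, so the old list-vs-int comparison was always False. A stand-in illustration (`FakeSelectOne` is not an npyscreen class):

```python
class FakeSelectOne:
    """Mimics the relevant npyscreen behavior: .value is a list of selected indices."""
    def __init__(self, selected_index):
        self.value = [selected_index]

convert_models = FakeSelectOne(1)             # user picked option 1 ("convert")
assert (convert_models.value == 1) is False   # old comparison: [1] == 1 is always False
assert convert_models.value[0] == 1           # fixed comparison
```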
```diff
@@ -69,6 +69,9 @@ def install_requested_models(
     config_file_path: Path = None,
 ):
     config_file_path=config_file_path or default_config_file()
+    if not config_file_path.exists():
+        open(config_file_path,'w')
+
     model_manager= ModelManager(OmegaConf.load(config_file_path),precision=precision)

     if remove_models and len(remove_models) > 0:
```
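The added guard creates an empty config file so the `OmegaConf.load()` on the next line does not fail on a first run. Note that a bare `open(..., 'w')` leaves the handle to be closed by garbage collection; an equivalent sketch without that wrinkle (the path is hypothetical) would be:

```python
from pathlib import Path

config_file_path = Path('configs/models.yaml')             # hypothetical location
config_file_path.parent.mkdir(parents=True, exist_ok=True)
config_file_path.touch(exist_ok=True)                      # create empty file, no dangling handle
```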
```diff
@@ -84,12 +87,20 @@ def install_requested_models(
             models=install_initial_models,
             access_token=None,
             precision=precision,
-        ) # for historical reasons, we don't use model manager here
+        ) # FIX: for historical reasons, we don't use model manager here
         update_config_file(successfully_downloaded, config_file_path)
         if len(successfully_downloaded) < len(install_initial_models):
             print("** Some of the model downloads were not successful")

-    if external_models and len(external_models)>0:
+    # due to above, we have to reload the model manager because conf file
+    # was changed behind its back
+    model_manager= ModelManager(OmegaConf.load(config_file_path),precision=precision)
+
+    external_models = external_models or list()
+    if scan_directory:
+        external_models.append(str(scan_directory))
+
+    if len(external_models)>0:
         print("== INSTALLING EXTERNAL MODELS ==")
         for path_url_or_repo in external_models:
             try:
```
```diff
@@ -102,6 +113,18 @@ def install_requested_models(
                 sys.exit(-1)
             except Exception:
                 pass

+    if scan_at_startup and scan_directory.is_dir():
+        argument = '--autoconvert' if convert_to_diffusers else '--autoimport'
+        initfile = Path(Globals.root, Globals.initfile)
+        replacement = Path(Globals.root, f'{Globals.initfile}.new')
+        with open(initfile,'r') as input:
+            with open(replacement,'w') as output:
+                while line := input.readline():
+                    if not line.startswith(argument):
+                        output.writelines([line])
+                output.writelines([f'{argument} {str(scan_directory)}'])
+        os.replace(replacement,initfile)
+
 # -------------------------------------
 def yes_or_no(prompt: str, default_yes=True):
```
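The new block persists the scan directory into the init file: copy every line except any stale `--autoimport`/`--autoconvert` entry into a temp file, append the fresh setting, then swap the temp file into place with `os.replace`, which is atomic when both paths sit on the same filesystem. The same pattern as a standalone sketch (function name and paths are hypothetical):

```python
import os
from pathlib import Path

def persist_argument(initfile: Path, argument: str, value: str) -> None:
    """Replace any existing `argument` line in initfile with `argument value`."""
    replacement = initfile.with_name(initfile.name + '.new')  # temp file beside the original
    with open(initfile, 'r') as src, open(replacement, 'w') as dst:
        for line in src:
            if not line.startswith(argument):   # drop the stale setting
                dst.write(line)
        dst.write(f'{argument} {value}\n')      # append the fresh one
    os.replace(replacement, initfile)           # atomic swap

# hypothetical usage:
# persist_argument(Path.home() / 'invokeai' / 'invokeai.init', '--autoimport', '~/new_models')
```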
```diff
@@ -707,21 +707,19 @@ class ModelManager(object):
         convert: bool= False,
         commit_to_conf: Path=None,
     ):
-        model_path = None
+        model_path: Path = None
         thing = path_url_or_repo # to save typing

-        print(f'here i am; thing={thing}, convert={convert}')
-
         if thing.startswith(('http:','https:','ftp:')):
-            print(f'* {thing} appears to be a URL')
+            print(f'>> {thing} appears to be a URL')
             model_path = self._resolve_path(thing, 'models/ldm/stable-diffusion-v1') # _resolve_path does a download if needed

         elif Path(thing).is_file() and thing.endswith(('.ckpt','.safetensors')):
-            print(f'* {thing} appears to be a checkpoint file on disk')
+            print(f'>> {thing} appears to be a checkpoint file on disk')
             model_path = self._resolve_path(thing, 'models/ldm/stable-diffusion-v1')

         elif Path(thing).is_dir() and Path(thing, 'model_index.json').exists():
-            print(f'* {thing} appears to be a diffusers file on disk')
+            print(f'>> {thing} appears to be a diffusers file on disk')
             model_name = self.import_diffusers_model(
                 thing,
                 vae=dict(repo_id='stabilityai/sd-vae-ft-mse'),
```
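Distilled from the branches above and in the next hunk, the order in which `heuristic_import` classifies its argument, as a sketch with stub return values in place of the real import calls:

```python
import re
from pathlib import Path

def classify(thing: str) -> str:
    # Order matters: scheme check first, then concrete filesystem checks,
    # then the bare `owner/name` pattern for a HuggingFace repo_id.
    if thing.startswith(('http:', 'https:', 'ftp:')):
        return 'url to download'
    if Path(thing).is_file() and thing.endswith(('.ckpt', '.safetensors')):
        return 'checkpoint file on disk'
    if Path(thing).is_dir() and Path(thing, 'model_index.json').exists():
        return 'diffusers model on disk'
    if Path(thing).is_dir():
        return 'directory to scan recursively'
    if re.match(r'^[\w.+-]+/[\w.+-]+$', thing):
        return 'huggingface repo_id'
    return 'unknown'

assert classify('https://example.com/model.ckpt') == 'url to download'
assert classify('stabilityai/sd-vae-ft-mse') == 'huggingface repo_id'
```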
```diff
@@ -729,39 +727,44 @@ class ModelManager(object):
             )

         elif Path(thing).is_dir():
-            print(f'* {thing} appears to be a directory. Will scan for models to import')
+            print(f'>> {thing} appears to be a directory. Will scan for models to import')
             for m in list(Path(thing).rglob('*.ckpt')) + list(Path(thing).rglob('*.safetensors')):
                 print('***',m)
                 self.heuristic_import(str(m), convert, commit_to_conf=commit_to_conf)
             return

         elif re.match(r'^[\w.+-]+/[\w.+-]+$', thing):
-            print(f'* {thing} appears to be a HuggingFace diffusers repo_id')
+            print(f'>> {thing} appears to be a HuggingFace diffusers repo_id')
             model_name = self.import_diffuser_model(thing, commit_to_conf=commit_to_conf)
             pipeline,_,_,_ = self._load_diffusers_model(self.config[model_name])

         else:
-            print(f"* {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
+            print(f">> {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")

         # Model_path is set in the event of a legacy checkpoint file.
         # If not set, we're all done
         if not model_path:
             return

         if model_path.stem in self.config: #already imported
             return

         # another round of heuristics to guess the correct config file.
-        model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml')
+        model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')

+        checkpoint = safetensors.torch.load_file(model_path) if model_path.suffix == '.safetensors' else torch.load(model_path)
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
-            print(f'* {thing} appears to be an SD-v2 model; model will be converted to diffusers format')
+            print(f'>> {thing} appears to be an SD-v2 model; model will be converted to diffusers format')
+            model_config_file = Path(Globals.root,'configs/stable-diffusion/v2-inference-v.yaml')
+            convert = True

+        elif re.search('inpaint', str(model_path), flags=re.IGNORECASE):
-            print(f'* {thing} appears to be an SD-v1 inpainting model')
+            print(f'>> {thing} appears to be an SD-v1 inpainting model')
+            model_config_file = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml')

         else:
-            print(f'* {thing} appears to be an SD-v1 model')
+            print(f'>> {thing} appears to be an SD-v1 model')

         if convert:
             diffuser_path = Path(Globals.root, 'models',Globals.converted_ckpts_dir, model_path.stem)
```
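The key-shape probe added above works because SD-2.x conditions cross-attention on 1024-dim OpenCLIP text embeddings while SD-1.x uses 768-dim CLIP embeddings, so the last dimension of any `attn2.to_k.weight` identifies the model family. The probe in isolation, as a sketch (not the method itself; `map_location` added for safety):

```python
from pathlib import Path

import safetensors.torch
import torch

KEY = 'model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight'

def looks_like_sd2(model_path: Path) -> bool:
    # SD-2.x: to_k projects from a 1024-dim OpenCLIP text context.
    # SD-1.x: to_k projects from a 768-dim CLIP text context.
    if model_path.suffix == '.safetensors':
        checkpoint = safetensors.torch.load_file(model_path)
    else:
        checkpoint = torch.load(model_path, map_location='cpu')
    return KEY in checkpoint and checkpoint[KEY].shape[-1] == 1024
```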
```diff
@@ -776,10 +779,12 @@ class ModelManager(object):
             self.import_ckpt_model(
                 model_path,
                 config=model_config_file,
-                vae=Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt'),
+                vae=str(Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt')),
                 commit_to_conf=commit_to_conf,
             )

+    # this is a defunct method, superseded by heuristic_import()
+    # left here during transition
     def autoconvert_weights (
         self,
         conf_path: Path,
```
```diff
@@ -799,7 +804,7 @@ class ModelManager(object):
         ckpt_files = dict()
         for root, dirs, files in os.walk(weights_directory):
             for f in files:
-                if not f.endswith(".ckpt"):
+                if not f.endswith((".ckpt",".safetensors")):
                     continue
                 basename = Path(f).stem
                 dest = Path(dest_directory, basename)
```
ldm/util.py (39 changed lines)
```diff
@@ -306,8 +306,12 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
         dest/filename
     :param access_token: Access token to access this resource
     '''
-    resp = requests.get(url, stream=True)
-    total = int(resp.headers.get("content-length", 0))
+    header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
+    open_mode = "wb"
+    exist_size = 0
+
+    resp = requests.get(url, header, stream=True)
+    content_length = int(resp.headers.get("content-length", 0))

     if dest.is_dir():
         try:
```
```diff
@@ -318,39 +322,42 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
     else:
         dest.parent.mkdir(parents=True, exist_ok=True)

-    header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
-    open_mode = "wb"
-    exist_size = 0
-
     if dest.exists():
         exist_size = dest.stat().st_size
         header["Range"] = f"bytes={exist_size}-"
         open_mode = "ab"
+        resp = requests.get(url, headers=header, stream=True) # new request with range
+
+    if exist_size > content_length:
+        print('* corrupt existing file found. re-downloading')
+        os.remove(dest)
+        exist_size = 0

     if (
-        resp.status_code == 416
-    ):  # "range not satisfiable", which means nothing to return
+        resp.status_code == 416 or exist_size == content_length
+    ):
         print(f"* {dest}: complete file found. Skipping.")
         return dest
+    elif resp.status_code == 206 or exist_size > 0:
+        print(f"* {dest}: partial file found. Resuming...")
     elif resp.status_code != 200:
         print(f"** An error occurred during downloading {dest}: {resp.reason}")
-    elif exist_size > 0:
-        print(f"* {dest}: partial file found. Resuming...")
     else:
         print(f"* {dest}: Downloading...")

     try:
-        if total < 2000:
+        if content_length < 2000:
             print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
             return None

         with open(dest, open_mode) as file, tqdm(
             desc=str(dest),
             initial=exist_size,
-            total=total + exist_size,
+            total=content_length,
             unit="iB",
             unit_scale=True,
             unit_divisor=1000,
         ) as bar:
             for data in resp.iter_content(chunk_size=1024):
                 size = file.write(data)
```
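This last hunk is the resume fix the commit message mentions. In the old code the single `requests.get()` was issued before the Range header was even assembled, so a "resume" re-fetched the file from byte 0 and appended it to the partial file, and the progress total (`total + exist_size`) papered over the mismatch; a 206 Partial Content response also fell through to the generic error branch. A condensed, self-contained sketch of the corrected pattern (function name and error handling are illustrative, not the InvokeAI routine):

```python
import os
from pathlib import Path

import requests
from tqdm import tqdm

def download_with_resume_sketch(url: str, dest: Path) -> Path:
    header, open_mode, exist_size = {}, 'wb', 0
    resp = requests.get(url, stream=True)
    content_length = int(resp.headers.get('content-length', 0))   # full remote size

    if dest.exists():
        exist_size = dest.stat().st_size
        header['Range'] = f'bytes={exist_size}-'   # request only the missing tail
        open_mode = 'ab'                           # append to the partial file
        resp = requests.get(url, headers=header, stream=True)

    if exist_size > content_length:                # partial file larger than remote: corrupt
        os.remove(dest)
        exist_size, open_mode = 0, 'wb'
        resp = requests.get(url, stream=True)      # start over from byte 0

    if resp.status_code == 416 or exist_size == content_length:
        return dest                                # nothing left to fetch
    if resp.status_code not in (200, 206):         # 206 = resumable partial content
        raise RuntimeError(f'download of {url} failed: {resp.reason}')

    # On a 206 response, content-length covers only the remaining bytes, so the
    # bar starts at exist_size and the total is exist_size + remaining.
    remaining = int(resp.headers.get('content-length', 0))
    with open(dest, open_mode) as file, tqdm(
        desc=str(dest), initial=exist_size, total=exist_size + remaining,
        unit='iB', unit_scale=True, unit_divisor=1000,
    ) as bar:
        for data in resp.iter_content(chunk_size=1024):
            bar.update(file.write(data))
    return dest
```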