reformat with black and isort

This commit is contained in:
Lincoln Stein
2023-02-21 14:12:57 -05:00
parent 4878c7a2d5
commit 5a4967582e
3 changed files with 307 additions and 225 deletions

View File

@ -177,6 +177,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
print(f"Error downloading {label} model", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
# ---------------------------------------------
# this will preload the Bert tokenizer fles
def download_bert():
@ -284,37 +285,36 @@ def download_safety_checker():
download_from_hf(StableDiffusionSafetyChecker, safety_model_id)
print("...success", file=sys.stderr)
# -------------------------------------
def download_vaes(precision: str):
print("Installing stabilityai VAE...", file=sys.stderr)
try:
# first the diffusers version
repo_id = 'stabilityai/sd-vae-ft-mse'
repo_id = "stabilityai/sd-vae-ft-mse"
args = dict(
cache_dir=global_cache_dir('diffusers'),
cache_dir=global_cache_dir("diffusers"),
)
if precision=='float16':
args.update(
torch_dtype=torch.float16,
revision='fp16'
)
if precision == "float16":
args.update(torch_dtype=torch.float16, revision="fp16")
if not AutoencoderKL.from_pretrained(repo_id, **args):
raise Exception(f'download of {repo_id} failed')
raise Exception(f"download of {repo_id} failed")
repo_id = 'stabilityai/sd-vae-ft-mse-original'
model_name = 'vae-ft-mse-840000-ema-pruned.ckpt'
repo_id = "stabilityai/sd-vae-ft-mse-original"
model_name = "vae-ft-mse-840000-ema-pruned.ckpt"
# next the legacy checkpoint version
if not hf_download_with_resume(
repo_id = repo_id,
model_name = model_name,
model_dir = str(Globals.root / Model_dir / Weights_dir)
repo_id=repo_id,
model_name=model_name,
model_dir=str(Globals.root / Model_dir / Weights_dir),
):
raise Exception(f'download of {model_name} failed')
raise Exception(f"download of {model_name} failed")
print("...downloaded successfully", file=sys.stderr)
except Exception as e:
print(f"Error downloading StabilityAI standard VAE: {str(e)}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
# -------------------------------------
def get_root(root: str = None) -> str:
if root:
@ -329,7 +329,7 @@ class editOptsForm(npyscreen.FormMultiPage):
def create(self):
program_opts = self.parentApp.program_opts
old_opts = self.parentApp.invokeai_opts
with open('log.txt','w') as f:
with open("log.txt", "w") as f:
f.write(str(old_opts))
first_time = not (Globals.root / Globals.initfile).exists()
access_token = HfFolder.get_token()
@ -576,14 +576,14 @@ class editOptsForm(npyscreen.FormMultiPage):
new_opts = Namespace()
for attr in [
"outdir",
"safety_checker",
"free_gpu_mem",
"max_loaded_models",
"xformers",
"always_use_cpu",
"embedding_path",
"ckpt_convert",
"outdir",
"safety_checker",
"free_gpu_mem",
"max_loaded_models",
"xformers",
"always_use_cpu",
"embedding_path",
"ckpt_convert",
]:
setattr(new_opts, attr, getattr(self, attr).value)
@ -672,7 +672,9 @@ def initialize_rootdir(root: str, yes_to_all: bool = False):
# -------------------------------------
def run_console_ui(program_opts: Namespace, initfile: Path=None) -> (Namespace, Namespace):
def run_console_ui(
program_opts: Namespace, initfile: Path = None
) -> (Namespace, Namespace):
# parse_args() will read from init file if present
invokeai_opts = default_startup_options(initfile)
editApp = EditOptApplication(program_opts, invokeai_opts)
@ -747,6 +749,7 @@ def write_default_options(program_opts: Namespace, initfile: Path):
opt.hf_token = HfFolder.get_token()
write_opts(opt, initfile)
# -------------------------------------
def main():
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
@ -816,7 +819,9 @@ def main():
if opt.yes_to_all:
write_default_options(opt, init_file)
init_options = Namespace(precision='float32' if opt.full_precision else 'float16')
init_options = Namespace(
precision="float32" if opt.full_precision else "float16"
)
else:
init_options, models_to_download = run_console_ui(opt, init_file)
if init_options: