Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
reformat with black and isort
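The hunks below are formatting-only changes: isort orders import blocks (not visible in the hunks excerpted here) and black normalizes string quotes, spacing around operators and keyword arguments, trailing commas, and line wrapping at its default 88-column limit (assuming both tools run with default settings, which this diff does not show). As a minimal, hedged sketch of how these tools produce exactly this kind of change through their public Python APIs:

# Sketch only: run isort, then black, over a snippet shaped like the
# pre-commit code in the hunks below. Assumes the "black" and "isort"
# packages are installed; all options are left at their defaults.
import black
import isort

snippet = (
    "import sys\n"
    "import argparse\n"
    "\n"
    "if precision=='float16':\n"
    "    args.update(\n"
    "        torch_dtype=torch.float16,\n"
    "        revision='fp16'\n"
    "    )\n"
)

sorted_code = isort.code(snippet)  # argparse now sorts before sys
formatted = black.format_str(sorted_code, mode=black.Mode())
print(formatted)
# black rewrites the conditional to:
#   if precision == "float16":
#       args.update(torch_dtype=torch.float16, revision="fp16")
# which is the same shape as the change in the download_vaes() hunk below.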
@@ -177,6 +177,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
         print(f"Error downloading {label} model", file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)
 
+
 # ---------------------------------------------
 # this will preload the Bert tokenizer files
 def download_bert():
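download_bert() is only named here; its body is not part of this hunk. As a hedged, generic sketch of what preloading tokenizer files usually means in practice (pulling them into the local Hugging Face cache so later runs need no network), assuming the transformers package is installed; the helper name and model id are illustrative, not InvokeAI's:

# Hypothetical illustration only -- not the InvokeAI implementation.
# Instantiating the tokenizer once forces transformers to download and
# cache its vocabulary files under the Hugging Face cache directory.
from transformers import BertTokenizerFast

def preload_bert_tokenizer(model_name: str = "bert-base-uncased") -> None:
    BertTokenizerFast.from_pretrained(model_name)  # files are cached as a side effect

if __name__ == "__main__":
    preload_bert_tokenizer()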
@@ -284,37 +285,36 @@ def download_safety_checker():
         download_from_hf(StableDiffusionSafetyChecker, safety_model_id)
         print("...success", file=sys.stderr)
 
+
 # -------------------------------------
 def download_vaes(precision: str):
     print("Installing stabilityai VAE...", file=sys.stderr)
     try:
         # first the diffusers version
-        repo_id = 'stabilityai/sd-vae-ft-mse'
+        repo_id = "stabilityai/sd-vae-ft-mse"
         args = dict(
-            cache_dir=global_cache_dir('diffusers'),
+            cache_dir=global_cache_dir("diffusers"),
         )
-        if precision=='float16':
-            args.update(
-                torch_dtype=torch.float16,
-                revision='fp16'
-            )
+        if precision == "float16":
+            args.update(torch_dtype=torch.float16, revision="fp16")
         if not AutoencoderKL.from_pretrained(repo_id, **args):
-            raise Exception(f'download of {repo_id} failed')
+            raise Exception(f"download of {repo_id} failed")
 
-        repo_id = 'stabilityai/sd-vae-ft-mse-original'
-        model_name = 'vae-ft-mse-840000-ema-pruned.ckpt'
+        repo_id = "stabilityai/sd-vae-ft-mse-original"
+        model_name = "vae-ft-mse-840000-ema-pruned.ckpt"
         # next the legacy checkpoint version
         if not hf_download_with_resume(
-            repo_id = repo_id,
-            model_name = model_name,
-            model_dir = str(Globals.root / Model_dir / Weights_dir)
+            repo_id=repo_id,
+            model_name=model_name,
+            model_dir=str(Globals.root / Model_dir / Weights_dir),
         ):
-            raise Exception(f'download of {model_name} failed')
+            raise Exception(f"download of {model_name} failed")
         print("...downloaded successfully", file=sys.stderr)
     except Exception as e:
         print(f"Error downloading StabilityAI standard VAE: {str(e)}", file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)
 
+
 # -------------------------------------
 def get_root(root: str = None) -> str:
     if root:
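Besides the quote normalization, black collapses the multi-line args.update(...) call onto one line and adds a trailing comma inside the exploded hf_download_with_resume(...) call; the download logic itself is untouched. As a standalone, hedged sketch of the same fp16-conditional kwargs pattern, assuming the torch and diffusers packages are installed (the helper name and cache_dir argument are illustrative, not InvokeAI's):

# Sketch of the fp16-conditional kwargs pattern used by download_vaes() above.
import torch
from diffusers import AutoencoderKL

def load_sd_vae(precision="float32", cache_dir=None):
    repo_id = "stabilityai/sd-vae-ft-mse"
    args = dict(cache_dir=cache_dir)
    if precision == "float16":
        # the fp16 branch requests the Hub's half-precision revision
        args.update(torch_dtype=torch.float16, revision="fp16")
    return AutoencoderKL.from_pretrained(repo_id, **args)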
@@ -329,7 +329,7 @@ class editOptsForm(npyscreen.FormMultiPage):
     def create(self):
         program_opts = self.parentApp.program_opts
         old_opts = self.parentApp.invokeai_opts
-        with open('log.txt','w') as f:
+        with open("log.txt", "w") as f:
             f.write(str(old_opts))
         first_time = not (Globals.root / Globals.initfile).exists()
         access_token = HfFolder.get_token()
@@ -576,14 +576,14 @@ class editOptsForm(npyscreen.FormMultiPage):
         new_opts = Namespace()
 
         for attr in [
-                "outdir",
-                "safety_checker",
-                "free_gpu_mem",
-                "max_loaded_models",
-                "xformers",
-                "always_use_cpu",
-                "embedding_path",
-                "ckpt_convert",
+            "outdir",
+            "safety_checker",
+            "free_gpu_mem",
+            "max_loaded_models",
+            "xformers",
+            "always_use_cpu",
+            "embedding_path",
+            "ckpt_convert",
         ]:
             setattr(new_opts, attr, getattr(self, attr).value)
 
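Only the indentation of the attribute list changes here; the loop that copies npyscreen widget values onto an argparse Namespace is untouched. A self-contained toy version of that copy pattern, using a stand-in widget class instead of npyscreen, behaves like this:

# Toy illustration of the setattr/getattr copy loop above. FakeWidget stands
# in for an npyscreen widget, which exposes the user's input as .value.
from argparse import Namespace
from types import SimpleNamespace

class FakeWidget:
    def __init__(self, value):
        self.value = value

form = SimpleNamespace(outdir=FakeWidget("outputs"), safety_checker=FakeWidget(True))
new_opts = Namespace()
for attr in ["outdir", "safety_checker"]:
    setattr(new_opts, attr, getattr(form, attr).value)

print(new_opts)  # Namespace(outdir='outputs', safety_checker=True)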
@@ -672,7 +672,9 @@ def initialize_rootdir(root: str, yes_to_all: bool = False):
 
 
 # -------------------------------------
-def run_console_ui(program_opts: Namespace, initfile: Path=None) -> (Namespace, Namespace):
+def run_console_ui(
+    program_opts: Namespace, initfile: Path = None
+) -> (Namespace, Namespace):
     # parse_args() will read from init file if present
     invokeai_opts = default_startup_options(initfile)
     editApp = EditOptApplication(program_opts, invokeai_opts)
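One thing black deliberately leaves alone is the return annotation: "-> (Namespace, Namespace)" is a plain tuple of classes, which type checkers do not read as "a tuple of two Namespaces". For comparison only (not part of this commit), the typing-friendly spelling of the same signature would be:

# For comparison only -- the conventional annotation uses Tuple[...]
# (or tuple[...] on Python 3.9+); black would wrap it the same way.
from argparse import Namespace
from pathlib import Path
from typing import Tuple

def run_console_ui(
    program_opts: Namespace, initfile: Path = None
) -> Tuple[Namespace, Namespace]:
    ...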
@@ -747,6 +749,7 @@ def write_default_options(program_opts: Namespace, initfile: Path):
     opt.hf_token = HfFolder.get_token()
     write_opts(opt, initfile)
 
+
 # -------------------------------------
 def main():
     parser = argparse.ArgumentParser(description="InvokeAI model downloader")
@@ -816,7 +819,9 @@ def main():
 
     if opt.yes_to_all:
         write_default_options(opt, init_file)
-        init_options = Namespace(precision='float32' if opt.full_precision else 'float16')
+        init_options = Namespace(
+            precision="float32" if opt.full_precision else "float16"
+        )
     else:
         init_options, models_to_download = run_console_ui(opt, init_file)
         if init_options: