From 916cc261932256c35ffbbbd6ea11a3c80e47b141 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 16 Aug 2023 21:21:42 -0400 Subject: [PATCH] partial rewrite of checkpoint template creator --- scripts/create_checkpoint_template.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/scripts/create_checkpoint_template.py b/scripts/create_checkpoint_template.py index 7ff201c841..4226c46d8a 100755 --- a/scripts/create_checkpoint_template.py +++ b/scripts/create_checkpoint_template.py @@ -10,11 +10,21 @@ import json from pathlib import Path -from invokeai.backend.model_management.models.base import read_checkpoint_meta +from invokeai.backend.model_management.models.base import read_checkpoint_meta, ModelType parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model") -parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file") +parser.add_argument(type=Path, help="Path to the input checkpoint/safetensors file") parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file") +parser.add_argument("--base-type", + type=str, + choices=['sd-1','sd-2','sdxl'], + help="Base type of the model", + ) +parser.add_argument("--model-type", + type=str, + choices=[x.value for x in ModelType], + help="Base type of the model", + ) opt = parser.parse_args() ckpt = read_checkpoint_meta(opt.checkpoint) @@ -26,9 +36,15 @@ tmpl = {} for key, tensor in ckpt.items(): tmpl[key] = list(tensor.shape) +meta = { + 'base_type': opt.base_type, + 'model_type': opt.model_type, + 'template': tmpl +} + try: with open(opt.template, "w") as f: - json.dump(tmpl, f) + json.dump(meta, f) print(f"Template written out as {opt.template}") except Exception as e: print(f"An exception occurred while writing template: {str(e)}")