partial rewrite of checkpoint template creator

This commit is contained in:
Lincoln Stein
2023-08-16 21:21:42 -04:00
parent e83d00595d
commit 916cc26193

View File

@ -10,11 +10,21 @@ import json
from pathlib import Path
from invokeai.backend.model_management.models.base import read_checkpoint_meta
from invokeai.backend.model_management.models.base import read_checkpoint_meta, ModelType
parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model")
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument(type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file")
parser.add_argument("--base-type",
type=str,
choices=['sd-1','sd-2','sdxl'],
help="Base type of the model",
)
parser.add_argument("--model-type",
type=str,
choices=[x.value for x in ModelType],
help="Base type of the model",
)
opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
@ -26,9 +36,15 @@ tmpl = {}
for key, tensor in ckpt.items():
tmpl[key] = list(tensor.shape)
meta = {
'base_type': opt.base_type,
'model_type': opt.model_type,
'template': tmpl
}
try:
with open(opt.template, "w") as f:
json.dump(tmpl, f)
json.dump(meta, f)
print(f"Template written out as {opt.template}")
except Exception as e:
print(f"An exception occurred while writing template: {str(e)}")