From eb70bc2ae4c12f128148147e02b6aefd45874bff Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 7 Aug 2023 21:00:47 -0400 Subject: [PATCH] add scripts to create model templates and check whether they match --- scripts/create_checkpoint_template.py | 48 +++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100755 scripts/create_checkpoint_template.py diff --git a/scripts/create_checkpoint_template.py b/scripts/create_checkpoint_template.py new file mode 100755 index 0000000000..5b8fca8b58 --- /dev/null +++ b/scripts/create_checkpoint_template.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +""" +Read a checkpoint/safetensors file and write out a template .json file containing +its metadata for use in fast model probing. +""" + +import sys +import argparse +import json + +from pathlib import Path + +from invokeai.backend.model_management.models.base import read_checkpoint_meta + +parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model") +parser.add_argument( + "--checkpoint", + "--in", + type=Path, + help="Path to the input checkpoint/safetensors file" +) +parser.add_argument( + "--template", + "--out", + type=Path, + help="Path to the output .json file" +) + +opt = parser.parse_args() +ckpt = read_checkpoint_meta(opt.checkpoint) +while "state_dict" in ckpt: + ckpt = ckpt["state_dict"] + +tmpl = {} + +for key, tensor in ckpt.items(): + tmpl[key] = list(tensor.shape) + +try: + with open(opt.template,'w') as f: + json.dump(tmpl, f) + print(f"Template written out as {opt.template}") +except Exception as e: + print(f"An exception occurred while writing template: {str(e)}") + + + +