InvokeAI/scripts/verify_checkpoint_template.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

38 lines
1.0 KiB
Python
Raw Normal View History

2023-08-08 01:01:48 +00:00
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and compare it to a template .json.
Returns True if their metadata match.
"""
import sys
import argparse
import json
from pathlib import Path
from invokeai.backend.model_management.models.base import read_checkpoint_meta
def checkpoint_shapes(ckpt) -> dict:
    """Map every entry of a checkpoint mapping to its shape as a list of ints.

    Nested ``"state_dict"`` wrappers are unwrapped first, repeatedly, until the
    innermost mapping is reached.

    :param ckpt: mapping of key -> object exposing a ``.shape`` attribute
        (presumably tensors or tensor metadata returned by
        ``read_checkpoint_meta`` — TODO confirm for non-tensor entries).
    :returns: ``{key: list(shape)}``; lists are used (not tuples/``torch.Size``)
        so the result compares equal to arrays loaded from JSON.
    """
    # Some checkpoints wrap their weights in one or more "state_dict" levels.
    while "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]
    return {key: list(tensor.shape) for key, tensor in ckpt.items()}


def load_template(path) -> dict:
    """Load and return the JSON metadata template stored at *path*.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def main() -> None:
    """Compare a checkpoint's key->shape metadata against a JSON template.

    Prints ``True`` and exits with status 0 when they match exactly; otherwise
    prints ``False`` and exits with status 1.
    """
    parser = argparse.ArgumentParser(description="Compare a checkpoint/safetensors file to a JSON metadata template.")
    # Both paths are mandatory; without them the script can only crash later.
    # The "--in"/"--out" aliases are kept for backward compatibility.
    parser.add_argument(
        "--checkpoint", "--in", type=Path, required=True, help="Path to the input checkpoint/safetensors file"
    )
    parser.add_argument(
        "--template", "--out", type=Path, required=True, help="Path to the template .json file to match against"
    )
    opt = parser.parse_args()

    checkpoint_metadata = checkpoint_shapes(read_checkpoint_meta(opt.checkpoint))
    template = load_template(opt.template)

    if checkpoint_metadata == template:
        print("True")
        sys.exit(0)
    else:
        print("False")
        # Exit statuses are portable only in 0..255; -1 would surface as 255.
        sys.exit(1)


if __name__ == "__main__":
    main()