2023-08-08 01:01:48 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
"""
|
|
|
|
Read a checkpoint/safetensors file and compare it to a template .json.
|
|
|
|
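Prints "True" and exits with status 0 if the tensor names and shapes match the template; otherwise prints "False" and exits with a non-zero status.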
|
|
|
|
"""
|
|
|
|
|
|
|
|
import sys
|
|
|
|
import argparse
|
|
|
|
import json
|
|
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
from invokeai.backend.model_management.models.base import read_checkpoint_meta
|
|
|
|
|
2023-08-10 20:00:33 +00:00
|
|
|
parser = argparse.ArgumentParser(description="Compare a checkpoint/safetensors file to a JSON metadata template.")
# Both paths are mandatory. Without required=True a missing flag silently
# becomes None and the script fails later with an opaque error when None is
# handed to the checkpoint loader or to open(); argparse now reports it with a
# proper usage message instead. The "--in"/"--out" aliases are kept so existing
# invocations keep working.
parser.add_argument(
    "--checkpoint",
    "--in",
    type=Path,
    required=True,
    help="Path to the input checkpoint/safetensors file",
)
parser.add_argument(
    "--template",
    "--out",
    type=Path,
    required=True,
    help="Path to the template .json file to match against",
)
|
2023-08-08 01:01:48 +00:00
|
|
|
|
|
|
|
opt = parser.parse_args()

# Load the checkpoint, then descend through any nesting of "state_dict"
# wrappers until the mapping that holds the tensors themselves is reached.
ckpt = read_checkpoint_meta(opt.checkpoint)
while "state_dict" in ckpt:
    ckpt = ckpt["state_dict"]

# Map every entry name to its shape. Shapes are stored as plain lists so the
# result can be compared directly with arrays decoded from the JSON template.
# NOTE(review): assumes every remaining value exposes `.shape` (i.e. is a
# tensor); a non-tensor entry would raise AttributeError here.
checkpoint_metadata = {name: list(value.shape) for name, value in ckpt.items()}
|
|
|
|
|
2023-08-08 01:01:59 +00:00
|
|
|
# Load the reference template. A match requires it to be a JSON object whose
# keys are the checkpoint's entry names and whose values are shape lists.
# JSON text is UTF-8, so decode it explicitly rather than relying on the
# platform/locale default encoding.
with open(opt.template, "r", encoding="utf-8") as f:
    template = json.load(f)

# Report the verdict on stdout and through the exit status. Exit codes outside
# 0-127 are not portable (the previous -1 surfaced as 255 on POSIX and is
# undefined elsewhere), so a mismatch is signalled with the conventional 1.
if checkpoint_metadata == template:
    print("True")
    sys.exit(0)
else:
    print("False")
    sys.exit(1)
|