This commit is contained in:
Lincoln Stein 2023-08-07 21:01:59 -04:00
parent 4df581811e
commit 750f09fbed
2 changed files with 8 additions and 36 deletions

View File

@ -13,18 +13,8 @@ from pathlib import Path
from invokeai.backend.model_management.models.base import read_checkpoint_meta
parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model")
parser.add_argument(
"--checkpoint",
"--in",
type=Path,
help="Path to the input checkpoint/safetensors file"
)
parser.add_argument(
"--template",
"--out",
type=Path,
help="Path to the output .json file"
)
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file")
opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
@ -37,12 +27,8 @@ for key, tensor in ckpt.items():
tmpl[key] = list(tensor.shape)
try:
with open(opt.template,'w') as f:
with open(opt.template, "w") as f:
json.dump(tmpl, f)
print(f"Template written out as {opt.template}")
except Exception as e:
print(f"An exception occurred while writing template: {str(e)}")

View File

@ -13,18 +13,8 @@ from pathlib import Path
from invokeai.backend.model_management.models.base import read_checkpoint_meta
parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model")
parser.add_argument(
"--checkpoint",
"--in",
type=Path,
help="Path to the input checkpoint/safetensors file"
)
parser.add_argument(
"--template",
"--out",
type=Path,
help="Path to the template .json file to match against"
)
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the template .json file to match against")
opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
@ -36,16 +26,12 @@ checkpoint_metadata = {}
for key, tensor in ckpt.items():
checkpoint_metadata[key] = list(tensor.shape)
with open(opt.template,'r') as f:
with open(opt.template, "r") as f:
template = json.load(f)
if checkpoint_metadata == template:
print('True')
print("True")
sys.exit(0)
else:
print('False')
print("False")
sys.exit(-1)