port new code for detecting sdxl-based embeddings

This commit is contained in:
Lincoln Stein 2023-12-10 15:48:02 -05:00
parent 3b1ff4a7f4
commit de2879f602

View File

@ -416,6 +416,8 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
elif "emb_params" in checkpoint:
token_dim = checkpoint["emb_params"].shape[-1]
elif "clip_g" in checkpoint:
token_dim = checkpoint["clip_g"].shape[-1]
else:
token_dim = list(checkpoint.values())[0].shape[0]
if token_dim == 768: