correctly detect when an embedding is incompatible with the current model

- Fixed the test for token length; tested on several .pt and .bin files
- Also added a __main__ entrypoint for CLI.py, to make pdb debugging a bit
  more convenient.
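
For context, here is a minimal sketch of the dimension check this commit repairs. The standalone helper and its name are illustrative only, not the repository's API; the attribute chain on the text encoder is the one used in the diff below.

    import torch

    def embedding_is_compatible(text_encoder, embedding: torch.Tensor) -> bool:
        # The width of one row of the text encoder's input-embedding matrix
        # is the per-token dimension (e.g. 768 for SD-1.x, 1024 for SD-2.x).
        model_token_dim = text_encoder.get_input_embeddings().weight.data[0].shape[0]
        # A textual-inversion embedding has shape (num_vectors, token_dim):
        # compare the row width, not the vector count.
        return embedding.shape[-1] == model_token_dim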
Lincoln Stein 2023-02-19 22:30:57 -05:00
parent 7eafcd47a6
commit 172ce3dc25
2 changed files with 5 additions and 1 deletion

CLI.py

@@ -1387,3 +1387,7 @@ def check_internet() -> bool:
         return True
     except:
         return False
+
+
+if __name__ == '__main__':
+    main()
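
With the guard in place, main() runs only when the file is executed directly, so the CLI can be stepped through with, for example, python -m pdb CLI.py (exact path depends on the checkout) instead of attaching a debugger to an installed entry point.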

textual_inversion_manager.py

@@ -84,7 +84,7 @@ class TextualInversionManager:
             return
         elif (
             self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
-            != embedding_info["embedding"].shape[0]
+            != embedding_info['token_dim']
         ):
             print(
                 f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model."