tweak error checking

This commit is contained in:
Damian Stewart 2022-12-16 02:07:49 +01:00
parent 009f32ed39
commit 05fac594ea

View File

@@ -157,8 +157,10 @@ class TextualInversionManager():
""" """
if prompt_embeddings.shape[0] != self.clip_embedder.max_length: # typically 77 if prompt_embeddings.shape[0] != self.clip_embedder.max_length: # typically 77
raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})") raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})")
if len(prompt_token_ids) != self.clip_embedder.max_length: if len(prompt_token_ids) > self.clip_embedder.max_length:
raise ValueError(f"prompt_token_ids must be fully padded out to {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})") raise ValueError(f"prompt_token_ids is too long (has {len(prompt_token_ids)} token ids, should have {self.clip_embedder.max_length})")
if len(prompt_token_ids) < self.clip_embedder.max_length:
raise ValueError(f"prompt_token_ids is too short (has {len(prompt_token_ids)} token ids, it must be fully padded out to {self.clip_embedder.max_length} entries)")
if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id: if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id:
raise ValueError("prompt_token_ids must start with with bos token id and end with the eos token id") raise ValueError("prompt_token_ids must start with with bos token id and end with the eos token id")