From 05fac594eaf79d0058e3c48deee93df603f136c2 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Fri, 16 Dec 2022 02:07:49 +0100 Subject: [PATCH] tweak error checking --- ldm/modules/embedding_manager.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py index 603c23a94a..97b98dc5e7 100644 --- a/ldm/modules/embedding_manager.py +++ b/ldm/modules/embedding_manager.py @@ -157,8 +157,10 @@ class TextualInversionManager(): """ if prompt_embeddings.shape[0] != self.clip_embedder.max_length: # typically 77 raise ValueError(f"prompt_embeddings must have {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})") - if len(prompt_token_ids) != self.clip_embedder.max_length: - raise ValueError(f"prompt_token_ids must be fully padded out to {self.clip_embedder.max_length} entries (has: {prompt_embeddings.shape[0]})") + if len(prompt_token_ids) > self.clip_embedder.max_length: + raise ValueError(f"prompt_token_ids is too long (has {len(prompt_token_ids)} token ids, should have {self.clip_embedder.max_length})") + if len(prompt_token_ids) < self.clip_embedder.max_length: + raise ValueError(f"prompt_token_ids is too short (has {len(prompt_token_ids)} token ids, it must be fully padded out to {self.clip_embedder.max_length} entries)") if prompt_token_ids[0] != self.clip_embedder.tokenizer.bos_token_id or prompt_token_ids[-1] != self.clip_embedder.tokenizer.eos_token_id: raise ValueError("prompt_token_ids must start with with bos token id and end with the eos token id")