diff --git a/invokeai/backend/textual_inversion.py b/invokeai/backend/textual_inversion.py index f7390979bb..368736617b 100644 --- a/invokeai/backend/textual_inversion.py +++ b/invokeai/backend/textual_inversion.py @@ -1,7 +1,7 @@ """Textual Inversion wrapper class.""" from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Optional, Union import torch from compel.embeddings_provider import BaseTextualInversionManager @@ -66,33 +66,47 @@ class TextualInversionModelRaw(RawModel): return result -# no type hints for BaseTextualInversionManager? -class TextualInversionManager(BaseTextualInversionManager): # type: ignore - pad_tokens: Dict[int, List[int]] - tokenizer: CLIPTokenizer +class TextualInversionManager(BaseTextualInversionManager): + """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library.""" def __init__(self, tokenizer: CLIPTokenizer): - self.pad_tokens = {} + self.pad_tokens: dict[int, list[int]] = {} self.tokenizer = tokenizer def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: + """Given a list of tokens ids, expand any TI tokens to their corresponding pad tokens. + + For example, suppose we have a `` TI with 4 vectors that was added to the tokenizer with the following + mapping of tokens to token_ids: + ``` + : 49408 + : 49409 + : 49410 + : 49411 + ``` + `self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`. + This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`. + """ + # Short circuit if there are no pad tokens to save a little time. if len(self.pad_tokens) == 0: return token_ids + # This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify + # this assumption here. if token_ids[0] == self.tokenizer.bos_token_id: raise ValueError("token_ids must not start with bos_token_id") if token_ids[-1] == self.tokenizer.eos_token_id: raise ValueError("token_ids must not end with eos_token_id") - new_token_ids = [] + # Expand any TI tokens to their corresponding pad tokens. + new_token_ids: list[int] = [] for token_id in token_ids: new_token_ids.append(token_id) if token_id in self.pad_tokens: new_token_ids.extend(self.pad_tokens[token_id]) - # Do not exceed the max model input size - # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(), - # which first removes and then adds back the start and end tokens. + # Do not exceed the max model input size. The -2 here is compensating for + # compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens. max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2 if len(new_token_ids) > max_length: new_token_ids = new_token_ids[0:max_length]