Add docs to TextualInversionManager and improve types. No changes to functionality.

Ryan Dick 2024-05-27 10:32:49 -04:00 committed by Kent Keirsey
parent 21aa42627b
commit 994c61b67a


@@ -1,7 +1,7 @@
 """Textual Inversion wrapper class."""
 
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Optional, Union
 
 import torch
 from compel.embeddings_provider import BaseTextualInversionManager
@@ -66,33 +66,47 @@ class TextualInversionModelRaw(RawModel):
         return result
 
 
-# no type hints for BaseTextualInversionManager?
-class TextualInversionManager(BaseTextualInversionManager):  # type: ignore
-    pad_tokens: Dict[int, List[int]]
-    tokenizer: CLIPTokenizer
+class TextualInversionManager(BaseTextualInversionManager):
+    """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
 
     def __init__(self, tokenizer: CLIPTokenizer):
-        self.pad_tokens = {}
+        self.pad_tokens: dict[int, list[int]] = {}
         self.tokenizer = tokenizer
 
     def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
+        """Given a list of tokens ids, expand any TI tokens to their corresponding pad tokens.
+
+        For example, suppose we have a `<ti_dog>` TI with 4 vectors that was added to the tokenizer with the following
+        mapping of tokens to token_ids:
+        ```
+        <ti_dog>: 49408
+        <ti_dog-!pad-1>: 49409
+        <ti_dog-!pad-2>: 49410
+        <ti_dog-!pad-3>: 49411
+        ```
+        `self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`.
+        This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`.
+        """
+        # Short circuit if there are no pad tokens to save a little time.
         if len(self.pad_tokens) == 0:
             return token_ids
 
+        # This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify
+        # this assumption here.
         if token_ids[0] == self.tokenizer.bos_token_id:
             raise ValueError("token_ids must not start with bos_token_id")
         if token_ids[-1] == self.tokenizer.eos_token_id:
             raise ValueError("token_ids must not end with eos_token_id")
 
-        new_token_ids = []
+        # Expand any TI tokens to their corresponding pad tokens.
+        new_token_ids: list[int] = []
         for token_id in token_ids:
             new_token_ids.append(token_id)
             if token_id in self.pad_tokens:
                 new_token_ids.extend(self.pad_tokens[token_id])
 
-        # Do not exceed the max model input size
-        # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
-        # which first removes and then adds back the start and end tokens.
+        # Do not exceed the max model input size. The -2 here is compensating for
+        # compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens.
         max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
        if len(new_token_ids) > max_length:
             new_token_ids = new_token_ids[0:max_length]
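
For readers unfamiliar with the expansion step, here is a minimal sketch (not part of the commit) of how the documented behavior plays out. The import path, the stub tokenizer, and the exact contents of pad_tokens are assumptions for illustration: in the running application pad_tokens is populated elsewhere when a TI embedding is loaded, and only the three tokenizer attributes that the method reads are stubbed here.

from types import SimpleNamespace

# Assumed import path for the class shown in this diff.
from invokeai.backend.textual_inversion import TextualInversionManager

# Stand-in exposing only the tokenizer attributes the method reads. 77 is the usual CLIP
# text-encoder context length, so max_length works out to 77 - 2 = 75.
stub_tokenizer = SimpleNamespace(
    bos_token_id=49406,
    eos_token_id=49407,
    max_model_input_sizes={"openai/clip-vit-large-patch14": 77},
)

manager = TextualInversionManager(tokenizer=stub_tokenizer)  # type: ignore[arg-type]

# Assumed pad-token mapping for the 4-vector <ti_dog> example from the docstring: the trigger
# token id maps to the pad token ids that should follow it in the expanded sequence.
manager.pad_tokens = {49408: [49409, 49410, 49411]}

# A prompt fragment without BOS/EOS, containing the <ti_dog> trigger token (surrounding ids are arbitrary).
token_ids = [320, 49408, 525]
expanded = manager.expand_textual_inversion_token_ids_if_necessary(token_ids)
print(expanded)  # [320, 49408, 49409, 49410, 49411, 525]

If the expanded list grew past 75 ids, the final truncation step would clip it back down before compel re-attaches the start and end tokens.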