Add docs to TextualInversionManager and improve types. No changes to functionality.

Ryan Dick 2024-05-27 10:32:49 -04:00 committed by Kent Keirsey
parent 21aa42627b
commit 994c61b67a


@@ -1,7 +1,7 @@
 """Textual Inversion wrapper class."""
 
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Optional, Union
 
 import torch
 from compel.embeddings_provider import BaseTextualInversionManager
@@ -66,33 +66,47 @@ class TextualInversionModelRaw(RawModel):
         return result
 
 
-# no type hints for BaseTextualInversionManager?
-class TextualInversionManager(BaseTextualInversionManager):  # type: ignore
-    pad_tokens: Dict[int, List[int]]
-    tokenizer: CLIPTokenizer
+class TextualInversionManager(BaseTextualInversionManager):
+    """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
 
     def __init__(self, tokenizer: CLIPTokenizer):
-        self.pad_tokens = {}
+        self.pad_tokens: dict[int, list[int]] = {}
         self.tokenizer = tokenizer
 
     def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
+        """Given a list of tokens ids, expand any TI tokens to their corresponding pad tokens.
+
+        For example, suppose we have a `<ti_dog>` TI with 4 vectors that was added to the tokenizer with the following
+        mapping of tokens to token_ids:
+        ```
+        <ti_dog>: 49408
+        <ti_dog-!pad-1>: 49409
+        <ti_dog-!pad-2>: 49410
+        <ti_dog-!pad-3>: 49411
+        ```
+        `self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`.
+
+        This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`.
+        """
+        # Short circuit if there are no pad tokens to save a little time.
         if len(self.pad_tokens) == 0:
             return token_ids
 
+        # This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify
+        # this assumption here.
         if token_ids[0] == self.tokenizer.bos_token_id:
             raise ValueError("token_ids must not start with bos_token_id")
         if token_ids[-1] == self.tokenizer.eos_token_id:
             raise ValueError("token_ids must not end with eos_token_id")
 
-        new_token_ids = []
+        # Expand any TI tokens to their corresponding pad tokens.
+        new_token_ids: list[int] = []
         for token_id in token_ids:
             new_token_ids.append(token_id)
             if token_id in self.pad_tokens:
                 new_token_ids.extend(self.pad_tokens[token_id])
 
-        # Do not exceed the max model input size
-        # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
-        # which first removes and then adds back the start and end tokens.
+        # Do not exceed the max model input size. The -2 here is compensating for
+        # compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens.
         max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
         if len(new_token_ids) > max_length:
             new_token_ids = new_token_ids[0:max_length]
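
The new docstring describes the pad-token expansion abstractly; the sketch below shows roughly how it behaves at runtime. It is illustrative only and not part of this commit: the `invokeai.backend.textual_inversion` import path is an assumption, the `pad_tokens` mapping is filled by hand here (in InvokeAI it is built by the model-patching code when a multi-vector TI embedding is applied), and the stock `openai/clip-vit-large-patch14` tokenizer stands in for the pipeline's tokenizer.

```
# Illustrative sketch only -- not part of this commit.
from transformers import CLIPTokenizer

from invokeai.backend.textual_inversion import TextualInversionManager  # assumed module path

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Register a hypothetical 4-vector TI ("<ti_dog>") and its three pad tokens with the tokenizer.
ti_tokens = ["<ti_dog>", "<ti_dog-!pad-1>", "<ti_dog-!pad-2>", "<ti_dog-!pad-3>"]
tokenizer.add_tokens(ti_tokens)
ti_token_ids = tokenizer.convert_tokens_to_ids(ti_tokens)

manager = TextualInversionManager(tokenizer)
# Hand-filled for this sketch: map the trigger token id to the ids it should be padded with.
manager.pad_tokens[ti_token_ids[0]] = ti_token_ids[1:]

# compel hands over token ids without BOS/EOS, which is the precondition the method checks.
prompt_ids = tokenizer("a photo of a <ti_dog>", add_special_tokens=False)["input_ids"]
expanded = manager.expand_textual_inversion_token_ids_if_necessary(prompt_ids)

# The single <ti_dog> id at the end of the prompt is now followed by its pad-token ids,
# and the result is capped at the tokenizer's max input size minus 2 (room for BOS/EOS).
print(prompt_ids)
print(expanded)
```

With a stock CLIP vocabulary the added tokens land at ids 49408-49411, matching the example in the new docstring.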