Update TI handling for compatibility with transformers 4.40.0 (#6449)

## Summary

- Updated the documentation for `TextualInversionManager`
- Updated the tokenizer max-length lookup to use `self.tokenizer.model_max_length`, which works with the
latest transformers version (a before/after sketch follows this list). Thanks to @skunkworxdark for looking
into this here:
https://github.com/invoke-ai/InvokeAI/issues/6445#issuecomment-2133098342
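
A minimal before/after sketch of that lookup, lifted from the diff below (`self.tokenizer` is the `CLIPTokenizer` held by `TextualInversionManager`; the details of why the old access breaks are in the linked issue comment):

```python
# Before: indexes the per-checkpoint max_model_input_sizes table, which no longer
# resolves as expected with transformers>=4.40.0 (see #6445).
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2

# After: model_max_length is set on every PreTrainedTokenizerBase instance
# (77 for the SD CLIP tokenizer), so it works across transformers versions.
max_length = self.tokenizer.model_max_length - 2
```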

## Related Issues / Discussions

Closes #6445 

## QA Instructions

I tested with `transformers==4.41.1` and compared the results against a recent InvokeAI version from before
the transformers update - no change, as expected.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
Commit `df91d1b849` by Ryan Dick, 2024-05-28 08:32:02 -04:00.

```diff
@@ -1,7 +1,7 @@
 """Textual Inversion wrapper class."""

 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Optional, Union

 import torch
 from compel.embeddings_provider import BaseTextualInversionManager
@@ -66,35 +66,52 @@ class TextualInversionModelRaw(RawModel):
         return result


-# no type hints for BaseTextualInversionManager?
-class TextualInversionManager(BaseTextualInversionManager):  # type: ignore
-    pad_tokens: Dict[int, List[int]]
-    tokenizer: CLIPTokenizer
+class TextualInversionManager(BaseTextualInversionManager):
+    """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""

     def __init__(self, tokenizer: CLIPTokenizer):
-        self.pad_tokens = {}
+        self.pad_tokens: dict[int, list[int]] = {}
         self.tokenizer = tokenizer

     def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
+        """Given a list of tokens ids, expand any TI tokens to their corresponding pad tokens.
+
+        For example, suppose we have a `<ti_dog>` TI with 4 vectors that was added to the tokenizer with the following
+        mapping of tokens to token_ids:
+        ```
+        <ti_dog>: 49408
+        <ti_dog-!pad-1>: 49409
+        <ti_dog-!pad-2>: 49410
+        <ti_dog-!pad-3>: 49411
+        ```
+        `self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`.
+        This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`.
+        """
+        # Short circuit if there are no pad tokens to save a little time.
         if len(self.pad_tokens) == 0:
             return token_ids

+        # This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify
+        # this assumption here.
         if token_ids[0] == self.tokenizer.bos_token_id:
             raise ValueError("token_ids must not start with bos_token_id")
         if token_ids[-1] == self.tokenizer.eos_token_id:
             raise ValueError("token_ids must not end with eos_token_id")

-        new_token_ids = []
+        # Expand any TI tokens to their corresponding pad tokens.
+        new_token_ids: list[int] = []
         for token_id in token_ids:
             new_token_ids.append(token_id)
             if token_id in self.pad_tokens:
                 new_token_ids.extend(self.pad_tokens[token_id])

-        # Do not exceed the max model input size
-        # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
-        # which first removes and then adds back the start and end tokens.
-        max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
+        # Do not exceed the max model input size. The -2 here is compensating for
+        # compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens.
+        max_length = self.tokenizer.model_max_length - 2
         if len(new_token_ids) > max_length:
+            # HACK: If TI token expansion causes us to exceed the max text encoder input length, we silently discard
+            # tokens. Token expansion should happen in a way that is compatible with compel's default handling of long
+            # prompts.
            new_token_ids = new_token_ids[0:max_length]

         return new_token_ids
```
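
For illustration, a minimal usage sketch of the expansion logic above. The import path and the hand-built `pad_tokens` entry are assumptions made for this example (in InvokeAI the mapping is populated elsewhere when a TI is loaded), and the token ids are arbitrary:

```python
from transformers import CLIPTokenizer

from invokeai.backend.textual_inversion import TextualInversionManager  # assumed module path

# The SD1.x CLIP tokenizer; model_max_length is 77, so expanded prompts are capped at 75 ids.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
manager = TextualInversionManager(tokenizer)

# Hypothetical 4-vector TI: trigger token id 49408 followed by three pad token ids.
# This mapping is purely illustrative; how InvokeAI fills pad_tokens is outside this diff.
manager.pad_tokens[49408] = [49409, 49410, 49411]

# compel passes token ids without BOS/EOS; 49408 stands in for the TI trigger token.
token_ids = [320, 1125, 539, 49408]
expanded = manager.expand_textual_inversion_token_ids_if_necessary(token_ids)
print(expanded)  # [320, 1125, 539, 49408, 49409, 49410, 49411]
```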