Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Update TI handling for compatibility with transformers 4.40.0 (#6449)
## Summary

- Updated the documentation for `TextualInversionManager`
- Updated the tokenizer max-length lookup to use `self.tokenizer.model_max_length`, for compatibility with the latest transformers version.

Thanks to @skunkworxdark for looking into this here: https://github.com/invoke-ai/InvokeAI/issues/6445#issuecomment-2133098342

## Related Issues / Discussions

Closes #6445

## QA Instructions

I tested with `transformers==4.41.1` and compared the results against a recent InvokeAI version before updating transformers - no change, as expected.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
This commit is contained in:

commit df91d1b849
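For reference, a minimal sketch of the tokenizer attribute the updated code relies on. This is not part of the commit; it assumes the stock `openai/clip-vit-large-patch14` tokenizer and only shows that `model_max_length` is an instance attribute populated from the tokenizer config, unlike the per-checkpoint `max_model_input_sizes` mapping that the old code indexed into and that broke with transformers 4.40.0.

```python
# Sketch only (not from this PR): reading the text encoder's max input length
# via the instance-level `model_max_length` attribute.
from transformers import CLIPTokenizer

# Assumes the standard SD-1.x CLIP tokenizer; any CLIPTokenizer works the same way.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# 77 for CLIP; the -2 leaves room for the BOS/EOS tokens that compel strips and
# re-adds around the prompt (see the diff below).
max_length = tokenizer.model_max_length - 2
print(max_length)  # 75
```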
```diff
@@ -1,7 +1,7 @@
 """Textual Inversion wrapper class."""
 
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Optional, Union
 
 import torch
 from compel.embeddings_provider import BaseTextualInversionManager
@@ -66,35 +66,52 @@ class TextualInversionModelRaw(RawModel):
         return result
 
 
-# no type hints for BaseTextualInversionManager?
-class TextualInversionManager(BaseTextualInversionManager):  # type: ignore
-    pad_tokens: Dict[int, List[int]]
-    tokenizer: CLIPTokenizer
+class TextualInversionManager(BaseTextualInversionManager):
+    """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
 
     def __init__(self, tokenizer: CLIPTokenizer):
-        self.pad_tokens = {}
+        self.pad_tokens: dict[int, list[int]] = {}
         self.tokenizer = tokenizer
 
     def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
+        """Given a list of tokens ids, expand any TI tokens to their corresponding pad tokens.
+
+        For example, suppose we have a `<ti_dog>` TI with 4 vectors that was added to the tokenizer with the following
+        mapping of tokens to token_ids:
+        ```
+        <ti_dog>: 49408
+        <ti_dog-!pad-1>: 49409
+        <ti_dog-!pad-2>: 49410
+        <ti_dog-!pad-3>: 49411
+        ```
+        `self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`.
+        This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`.
+        """
+        # Short circuit if there are no pad tokens to save a little time.
         if len(self.pad_tokens) == 0:
             return token_ids
 
+        # This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify
+        # this assumption here.
         if token_ids[0] == self.tokenizer.bos_token_id:
             raise ValueError("token_ids must not start with bos_token_id")
         if token_ids[-1] == self.tokenizer.eos_token_id:
             raise ValueError("token_ids must not end with eos_token_id")
 
-        new_token_ids = []
+        # Expand any TI tokens to their corresponding pad tokens.
+        new_token_ids: list[int] = []
         for token_id in token_ids:
             new_token_ids.append(token_id)
             if token_id in self.pad_tokens:
                 new_token_ids.extend(self.pad_tokens[token_id])
 
-        # Do not exceed the max model input size
-        # The -2 here is compensating for compel.embeddings_provider.get_token_ids(),
-        # which first removes and then adds back the start and end tokens.
-        max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
+        # Do not exceed the max model input size. The -2 here is compensating for
+        # compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens.
+        max_length = self.tokenizer.model_max_length - 2
         if len(new_token_ids) > max_length:
+            # HACK: If TI token expansion causes us to exceed the max text encoder input length, we silently discard
+            # tokens. Token expansion should happen in a way that is compatible with compel's default handling of long
+            # prompts.
             new_token_ids = new_token_ids[0:max_length]
 
         return new_token_ids
```
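For anyone reviewing the behaviour, here is a rough usage sketch of the updated expansion logic. It is not part of the commit: the import path and the `SimpleNamespace` stand-in for a real `CLIPTokenizer` are assumptions for illustration (the method only reads `bos_token_id`, `eos_token_id`, and `model_max_length`), and the `pad_tokens` mapping is normally filled in by InvokeAI's model patching code when a TI embedding is loaded.

```python
# Rough sketch, not from this PR. Assumed import path; the tokenizer stand-in
# only provides the attributes that expand_textual_inversion_token_ids_if_necessary reads.
from types import SimpleNamespace

from invokeai.backend.textual_inversion import TextualInversionManager

fake_tokenizer = SimpleNamespace(bos_token_id=49406, eos_token_id=49407, model_max_length=77)
manager = TextualInversionManager(fake_tokenizer)  # type: ignore[arg-type]

# Pretend a 4-vector TI was loaded. Here the trigger token maps to just its three
# pad tokens so each id appears once in the expansion; how InvokeAI actually
# populates this mapping is handled elsewhere in the codebase.
manager.pad_tokens[49408] = [49409, 49410, 49411]

# Token ids for a prompt fragment containing the trigger, without BOS/EOS
# (the method raises ValueError if either is present).
token_ids = [320, 49408, 525]
print(manager.expand_textual_inversion_token_ids_if_necessary(token_ids))
# -> [320, 49408, 49409, 49410, 49411, 525]
```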