mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Tidy names and locations of modules
- Rename old "model_management" directory to "model_management_OLD" in order to catch dangling references to original model manager. - Caught and fixed most dangling references (still checking) - Rename lora, textual_inversion and model_patcher modules - Introduce a RawModel base class to simplfy the Union returned by the model loaders. - Tidy up the model manager 2-related tests. Add useful fixtures, and a finalizer to the queue and installer fixtures that will stop the services and release threads.
This commit is contained in:
committed by
psychedelicious
parent
ba1f8878dd
commit
2ad0752582
98
invokeai/backend/textual_inversion.py
Normal file
98
invokeai/backend/textual_inversion.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""Textual Inversion wrapper class."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from safetensors.torch import load_file
|
||||
from transformers import CLIPTokenizer
|
||||
from typing_extensions import Self
|
||||
from .raw_model import RawModel
|
||||
|
||||
class TextualInversionModelRaw(RawModel):
|
||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
file_path: Union[str, Path],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> Self:
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
|
||||
result = cls() # TODO:
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(file_path, map_location="cpu")
|
||||
|
||||
# both v1 and v2 format embeddings
|
||||
# difference mostly in metadata
|
||||
if "string_to_param" in state_dict:
|
||||
if len(state_dict["string_to_param"]) > 1:
|
||||
print(
|
||||
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
|
||||
" token will be used.",
|
||||
)
|
||||
|
||||
result.embedding = next(iter(state_dict["string_to_param"].values()))
|
||||
|
||||
# v3 (easynegative)
|
||||
elif "emb_params" in state_dict:
|
||||
result.embedding = state_dict["emb_params"]
|
||||
|
||||
# v5(sdxl safetensors file)
|
||||
elif "clip_g" in state_dict and "clip_l" in state_dict:
|
||||
result.embedding = state_dict["clip_g"]
|
||||
result.embedding_2 = state_dict["clip_l"]
|
||||
|
||||
# v4(diffusers bin files)
|
||||
else:
|
||||
result.embedding = next(iter(state_dict.values()))
|
||||
|
||||
if len(result.embedding.shape) == 1:
|
||||
result.embedding = result.embedding.unsqueeze(0)
|
||||
|
||||
if not isinstance(result.embedding, torch.Tensor):
|
||||
raise ValueError(f"Invalid embeddings file: {file_path.name}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# no type hints for BaseTextualInversionManager?
|
||||
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
|
||||
pad_tokens: Dict[int, List[int]]
|
||||
tokenizer: CLIPTokenizer
|
||||
|
||||
def __init__(self, tokenizer: CLIPTokenizer):
|
||||
self.pad_tokens = {}
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
||||
if len(self.pad_tokens) == 0:
|
||||
return token_ids
|
||||
|
||||
if token_ids[0] == self.tokenizer.bos_token_id:
|
||||
raise ValueError("token_ids must not start with bos_token_id")
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
raise ValueError("token_ids must not end with eos_token_id")
|
||||
|
||||
new_token_ids = []
|
||||
for token_id in token_ids:
|
||||
new_token_ids.append(token_id)
|
||||
if token_id in self.pad_tokens:
|
||||
new_token_ids.extend(self.pad_tokens[token_id])
|
||||
|
||||
# Do not exceed the max model input size
|
||||
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
|
||||
# which first removes and then adds back the start and end tokens.
|
||||
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
|
||||
if len(new_token_ids) > max_length:
|
||||
new_token_ids = new_token_ids[0:max_length]
|
||||
|
||||
return new_token_ids
|
Reference in New Issue
Block a user