First working TI draft

Sergey Borisov 2023-05-31 02:12:27 +03:00
parent 69ccd3a0b5
commit b47786e846
4 changed files with 219 additions and 41 deletions

View File

@@ -1,6 +1,7 @@
 from typing import Literal, Optional, Union
 from pydantic import BaseModel, Field
 from contextlib import ExitStack
+import re
 
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
@@ -9,7 +10,8 @@ from .model import ClipField
 from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
 from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
-from ...backend.model_management.lora import LoRAHelper
+from ...backend.model_management import SDModelType
+from ...backend.model_management.lora import ModelPatcher
 
 from compel import Compel
 from compel.prompt_parser import (
@@ -58,56 +60,61 @@ class CompelInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> CompelOutput:
-        text_encoder_info = context.services.model_manager.get_model(
-            **self.clip.text_encoder.dict(),
-        )
         tokenizer_info = context.services.model_manager.get_model(
             **self.clip.tokenizer.dict(),
         )
-        with text_encoder_info as text_encoder,\
-             tokenizer_info as tokenizer,\
+        text_encoder_info = context.services.model_manager.get_model(
+            **self.clip.text_encoder.dict(),
+        )
+        with tokenizer_info as orig_tokenizer,\
+             text_encoder_info as text_encoder,\
              ExitStack() as stack:
             loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
 
-            # TODO: global? input?
-            #use_full_precision = precision == "float32" or precision == "autocast"
-            #use_full_precision = False
+            ti_list = []
+            for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
+                name = trigger[1:-1]
+                try:
+                    ti_list.append(
+                        stack.enter_context(
+                            context.services.model_manager.get_model(model_name=name, model_type=SDModelType.TextualInversion)
+                        )
+                    )
+                except Exception as e:
+                    #print(e)
+                    #import traceback
+                    #print(traceback.format_exc())
+                    print(f"Warn: trigger: \"{trigger}\" not found")
 
-            # TODO: redo TI when separate model loding implemented
-            #textual_inversion_manager = TextualInversionManager(
-            #    tokenizer=tokenizer,
-            #    text_encoder=text_encoder,
-            #    full_precision=use_full_precision,
-            #)
+            with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
+                 ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
                 compel = Compel(
                     tokenizer=tokenizer,
                     text_encoder=text_encoder,
-                    textual_inversion_manager=None,
+                    textual_inversion_manager=ti_manager,
                     dtype_for_device_getter=torch_dtype,
                     truncate_long_prompts=True,  # TODO:
                 )
 
-            # TODO: support legacy blend?
-
-            conjunction = Compel.parse_prompt_string(self.prompt)
-            prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
+                conjunction = Compel.parse_prompt_string(self.prompt)
+                prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
 
-            if context.services.configuration.log_tokenization:
-                log_tokenization_for_prompt_object(prompt, tokenizer)
+                if context.services.configuration.log_tokenization:
+                    log_tokenization_for_prompt_object(prompt, tokenizer)
 
-            with LoRAHelper.apply_lora_text_encoder(text_encoder, loras):
                 c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
 
                 # TODO: long prompt support
                 #if not self.truncate_long_prompts:
                 #    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
 
                 ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
                     tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
                     cross_attention_control_args=options.get("cross_attention_control", None),
                 )
 
         conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
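A quick illustration of the trigger scan this hunk adds to invoke(): any angle-bracketed name in the prompt is treated as a potential textual-inversion trigger and looked up in the model manager. The snippet below is a standalone sketch; the prompt text and embedding names are made up, only the regular expression comes from the diff.

import re

prompt = "a portrait photo, <easynegative>, <my-style_v2>, highly detailed"

# same pattern the new invoke() uses to spot textual-inversion triggers
triggers = re.findall(r"<[a-zA-Z0-9., _-]+>", prompt)
names = [t[1:-1] for t in triggers]  # strip the surrounding angle brackets

print(triggers)  # ['<easynegative>', '<my-style_v2>']
print(names)     # ['easynegative', 'my-style_v2']

Each recovered name is then passed to model_manager.get_model(model_name=name, model_type=SDModelType.TextualInversion); unknown names only produce the "Warn: trigger ... not found" message instead of failing the invocation.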

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations
+import copy
 
 from pathlib import Path
 from contextlib import contextmanager
 from typing import Optional, Dict, Tuple, Any
@@ -11,6 +12,8 @@ from torch.utils.hooks import RemovableHandle
 from diffusers.models import UNet2DConditionModel
 from transformers import CLIPTextModel
+from compel.embeddings_provider import BaseTextualInversionManager
 
 
 class LoRALayerBase:
     #rank: Optional[int]
     #alpha: Optional[float]
@@ -444,7 +447,7 @@ with LoRAHelper.apply_lora_unet(unet, loras):
 """
 
 # TODO: rename smth like ModelPatcher and add TI method?
-class LoRAHelper:
+class ModelPatcher:
     @staticmethod
     def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
@@ -539,3 +542,135 @@ class LoRAHelper:
         for module_key, hook in hooks.items():
             hook.remove()
         hooks.clear()
+
+    @classmethod
+    @contextmanager
+    def apply_ti(
+        cls,
+        tokenizer: CLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        ti_list: List[Any],
+    ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
+        init_tokens_count = None
+        new_tokens_added = None
+        try:
+            ti_manager = TextualInversionManager()
+            ti_tokenizer = copy.deepcopy(tokenizer)
+            init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
+
+            def _get_trigger(ti, index):
+                trigger = ti.name
+                if index > 0:
+                    trigger += f"-!pad-{index}"
+                return f"<{trigger}>"
+
+            # modify tokenizer
+            new_tokens_added = 0
+            for ti in ti_list:
+                for i in range(ti.embedding.shape[0]):
+                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
+
+            # modify text_encoder
+            text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
+            model_embeddings = text_encoder.get_input_embeddings()
+
+            for ti in ti_list:
+                ti_tokens = []
+                for i in range(ti.embedding.shape[0]):
+                    embedding = ti.embedding[i]
+                    trigger = _get_trigger(ti, i)
+
+                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
+                    if token_id == ti_tokenizer.unk_token_id:
+                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")
+
+                    if model_embeddings.weight.data[token_id].shape != embedding.shape:
+                        raise ValueError(
+                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
+                        )
+
+                    model_embeddings.weight.data[token_id] = embedding
+                    ti_tokens.append(token_id)
+
+                if len(ti_tokens) > 1:
+                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
+
+            yield ti_tokenizer, ti_manager
+
+        finally:
+            if init_tokens_count and new_tokens_added:
+                text_encoder.resize_token_embeddings(init_tokens_count)
+
+
+class TextualInversionModel:
+    name: str
+    embedding: torch.Tensor  # [n, 768]|[n, 1280]
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        file_path: Union[str, Path],
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        if not isinstance(file_path, Path):
+            file_path = Path(file_path)
+
+        result = cls()  # TODO:
+        result.name = file_path.stem  # TODO:
+
+        if file_path.suffix == ".safetensors":
+            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
+        else:
+            state_dict = torch.load(file_path, map_location="cpu")
+
+        # both v1 and v2 format embeddings
+        # difference mostly in metadata
+        if "string_to_param" in state_dict:
+            if len(state_dict["string_to_param"]) > 1:
+                print(f"Warn: Embedding \"{file_path.name}\" contains multiple tokens, which is not supported. The first token will be used.")
+            result.embedding = next(iter(state_dict["string_to_param"].values()))
+
+        # v3 (easynegative)
+        elif "emb_params" in state_dict:
+            result.embedding = state_dict["emb_params"]
+
+        # v4 (diffusers bin files)
+        else:
+            result.embedding = next(iter(state_dict.values()))
+
+            if not isinstance(result.embedding, torch.Tensor):
+                raise ValueError(f"Invalid embeddings file: {file_path.name}")
+
+        return result
+
+
+class TextualInversionManager(BaseTextualInversionManager):
+    pad_tokens: Dict[int, List[int]]
+
+    def __init__(self):
+        self.pad_tokens = dict()
+
+    def expand_textual_inversion_token_ids_if_necessary(
+        self, token_ids: list[int]
+    ) -> list[int]:
+        #if token_ids[0] == self.tokenizer.bos_token_id:
+        #    raise ValueError("token_ids must not start with bos_token_id")
+        #if token_ids[-1] == self.tokenizer.eos_token_id:
+        #    raise ValueError("token_ids must not end with eos_token_id")
+
+        if len(self.pad_tokens) == 0:
+            return token_ids
+
+        new_token_ids = []
+        for token_id in token_ids:
+            new_token_ids.append(token_id)
+            if token_id in self.pad_tokens:
+                new_token_ids.extend(self.pad_tokens[token_id])
+
+        return new_token_ids
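The pad-token bookkeeping above is the subtle part of apply_ti: a multi-vector embedding is registered as one real trigger token plus "-!pad-N" aliases, and TextualInversionManager later re-expands the lead token id into the full run. A minimal standalone sketch of that expansion, using made-up token ids (in the real flow the ids come from the copied tokenizer):

# hypothetical ids for <name>, <name-!pad-1>, <name-!pad-2> of a 3-vector embedding
pad_tokens = {49408: [49409, 49410]}

def expand(token_ids):
    # mirrors expand_textual_inversion_token_ids_if_necessary: keep every id and
    # splice the pad ids in right after the lead token of a multi-vector embedding
    out = []
    for token_id in token_ids:
        out.append(token_id)
        out.extend(pad_tokens.get(token_id, []))
    return out

print(expand([320, 49408, 267]))  # -> [320, 49408, 49409, 49410, 267]

Compel calls this hook on the tokenized prompt, so a single trigger word ends up occupying as many token positions as the embedding has vectors.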

View File

@@ -37,7 +37,7 @@ from transformers import logging as transformers_logging
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import get_invokeai_config
-from .lora import LoRAModel
+from .lora import LoRAModel, TextualInversionModel
 
 def get_model_path(repo_id_or_path: str):
     globals = get_invokeai_config()
@@ -155,6 +155,7 @@ class SDModelType(str, Enum):
     Vae = "vae"
     Scheduler = "scheduler"
     Lora = "lora"
+    TextualInversion = "textual_inversion"
 
 
 class ModelInfoBase:
@@ -417,7 +418,7 @@ class LoRAModelInfo(ModelInfoBase):
     def get_size(self, child_type: Optional[SDModelType] = None):
         if child_type is not None:
-            raise Exception("There is no child models in lora model")
+            raise Exception("There is no child models in lora")
         return self.model_size
 
     def get_model(
@@ -426,7 +427,7 @@ class LoRAModelInfo(ModelInfoBase):
         torch_dtype: Optional[torch.dtype] = None,
     ):
         if child_type is not None:
-            raise Exception("There is no child models in lora model")
+            raise Exception("There is no child models in lora")
 
         model = LoRAModel.from_checkpoint(
             file_path=self.model_path,
@@ -437,11 +438,46 @@ class LoRAModelInfo(ModelInfoBase):
         return model
 
+
+class TextualInversionModelInfo(ModelInfoBase):
+    #model_size: int
+
+    def __init__(self, file_path: str, model_type: SDModelType):
+        assert model_type == SDModelType.TextualInversion
+        # check manually as super().__init__ will try to resolve repo_id too
+        if not os.path.exists(file_path):
+            raise Exception("Model not found")
+        super().__init__(file_path, model_type)
+
+        self.model_size = os.path.getsize(file_path)
+
+    def get_size(self, child_type: Optional[SDModelType] = None):
+        if child_type is not None:
+            raise Exception("There is no child models in textual inversion")
+        return self.model_size
+
+    def get_model(
+        self,
+        child_type: Optional[SDModelType] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+    ):
+        if child_type is not None:
+            raise Exception("There is no child models in textual inversion")
+
+        model = TextualInversionModel.from_checkpoint(
+            file_path=self.model_path,
+            dtype=torch_dtype,
+        )
+
+        self.model_size = model.embedding.nelement() * model.embedding.element_size()
+        return model
+
+
 MODEL_TYPES = {
     SDModelType.Diffusers: DiffusersModelInfo,
     SDModelType.Classifier: ClassifierModelInfo,
     SDModelType.Vae: VaeModelInfo,
     SDModelType.Lora: LoRAModelInfo,
+    SDModelType.TextualInversion: TextualInversionModelInfo,
 }
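For the new TextualInversionModelInfo, the reported size starts as the file size on disk and is refined after loading to the in-memory size of the embedding tensor. A rough standalone illustration of that second number, assuming a hypothetical 2-vector SD-1.x embedding kept in fp16:

import torch

embedding = torch.zeros(2, 768, dtype=torch.float16)  # [n, 768], as in the comment above
size_bytes = embedding.nelement() * embedding.element_size()
print(size_bytes)  # 2 * 768 * 2 bytes = 3072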

View File

@@ -332,7 +332,7 @@ class ModelManager(object):
             location = None
 
         revision = mconfig.get('revision')
-        if model_type in [SDModelType.Lora]:
+        if model_type in [SDModelType.Lora, SDModelType.TextualInversion]:
             hash = "<NO_HASH>"  # TODO:
         else:
             hash = self.cache.model_hash(location, revision)
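This last hunk simply widens the no-hash branch to cover textual inversions as well, since in this draft both LoRA and TI models are single local files rather than diffusers repos. A small stand-in sketch of the branch (the enum here is a trimmed local copy and pick_hash is a hypothetical helper, not the ModelManager code):

from enum import Enum

class SDModelType(str, Enum):  # trimmed stand-in for the real enum
    Diffusers = "diffusers"
    Lora = "lora"
    TextualInversion = "textual_inversion"

def pick_hash(model_type: SDModelType) -> str:
    # file-based model types skip repo hashing for now (TODO in the diff)
    if model_type in [SDModelType.Lora, SDModelType.TextualInversion]:
        return "<NO_HASH>"
    return "hash-computed-from-repo"  # placeholder for self.cache.model_hash(location, revision)

print(pick_hash(SDModelType.TextualInversion))  # <NO_HASH>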