Mirror of https://github.com/invoke-ai/InvokeAI

Commit b47786e846 ("First working TI draft")
Parent: 69ccd3a0b5
@@ -1,6 +1,7 @@
 from typing import Literal, Optional, Union
 from pydantic import BaseModel, Field
 from contextlib import ExitStack
+import re

 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig

@@ -9,7 +10,8 @@ from .model import ClipField
 from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
 from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
-from ...backend.model_management.lora import LoRAHelper
+from ...backend.model_management import SDModelType
+from ...backend.model_management.lora import ModelPatcher

 from compel import Compel
 from compel.prompt_parser import (
@@ -58,56 +60,61 @@ class CompelInvocation(BaseInvocation):

     def invoke(self, context: InvocationContext) -> CompelOutput:

-        text_encoder_info = context.services.model_manager.get_model(
-            **self.clip.text_encoder.dict(),
-        )
         tokenizer_info = context.services.model_manager.get_model(
             **self.clip.tokenizer.dict(),
         )
-        with text_encoder_info as text_encoder,\
-             tokenizer_info as tokenizer,\
+        text_encoder_info = context.services.model_manager.get_model(
+            **self.clip.text_encoder.dict(),
+        )
+
+        with tokenizer_info as orig_tokenizer,\
+             text_encoder_info as text_encoder,\
              ExitStack() as stack:

             loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]

-            # TODO: global? input?
-            #use_full_precision = precision == "float32" or precision == "autocast"
-            #use_full_precision = False
+            ti_list = []
+            for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
+                name = trigger[1:-1]
+                try:
+                    ti_list.append(
+                        stack.enter_context(
+                            context.services.model_manager.get_model(model_name=name, model_type=SDModelType.TextualInversion)
+                        )
+                    )
+                except Exception as e:
+                    #print(e)
+                    #import traceback
+                    #print(traceback.format_exc())
+                    print(f"Warn: trigger: \"{trigger}\" not found")

-            # TODO: redo TI when separate model loding implemented
-            #textual_inversion_manager = TextualInversionManager(
-            #    tokenizer=tokenizer,
-            #    text_encoder=text_encoder,
-            #    full_precision=use_full_precision,
-            #)
+            with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
+                 ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):

             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
-                textual_inversion_manager=None,
+                textual_inversion_manager=ti_manager,
                 dtype_for_device_getter=torch_dtype,
                 truncate_long_prompts=True, # TODO:
             )

-            # TODO: support legacy blend?
-
             conjunction = Compel.parse_prompt_string(self.prompt)
             prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]

             if context.services.configuration.log_tokenization:
                 log_tokenization_for_prompt_object(prompt, tokenizer)

-            with LoRAHelper.apply_lora_text_encoder(text_encoder, loras):
             c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)

             # TODO: long prompt support
             #if not self.truncate_long_prompts:
             #    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])

             ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
                 tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
                 cross_attention_control_args=options.get("cross_attention_control", None),
             )

             conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
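Note on the trigger scan added to invoke() above: it only has to pick out angle-bracketed names in the prompt and strip the brackets before asking the model manager for that embedding. A minimal standalone sketch of just that step, in which the prompt text and embedding names are made up for illustration:

    import re

    prompt = "a portrait of <my-object> in the style of <my-style>"
    # same pattern as the new code in CompelInvocation.invoke()
    for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
        name = trigger[1:-1]  # strip the surrounding angle brackets
        print(name)           # -> my-object, my-style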
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import copy
 from pathlib import Path
 from contextlib import contextmanager
 from typing import Optional, Dict, Tuple, Any
@@ -11,6 +12,8 @@ from torch.utils.hooks import RemovableHandle
 from diffusers.models import UNet2DConditionModel
 from transformers import CLIPTextModel

+from compel.embeddings_provider import BaseTextualInversionManager
+
 class LoRALayerBase:
     #rank: Optional[int]
     #alpha: Optional[float]
@@ -444,7 +447,7 @@ with LoRAHelper.apply_lora_unet(unet, loras):

 """
 # TODO: rename smth like ModelPatcher and add TI method?
-class LoRAHelper:
+class ModelPatcher:

     @staticmethod
     def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
@@ -539,3 +542,135 @@ class LoRAHelper:
         for module_key, hook in hooks.items():
             hook.remove()
         hooks.clear()
+
+    @classmethod
+    @contextmanager
+    def apply_ti(
+        cls,
+        tokenizer: CLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        ti_list: List[Any],
+    ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
+        init_tokens_count = None
+        new_tokens_added = None
+
+        try:
+            ti_manager = TextualInversionManager()
+            ti_tokenizer = copy.deepcopy(tokenizer)
+            init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
+
+            def _get_trigger(ti, index):
+                trigger = ti.name
+                if index > 0:
+                    trigger += f"-!pad-{i}"
+                return f"<{trigger}>"
+
+            # modify tokenizer
+            new_tokens_added = 0
+            for ti in ti_list:
+                for i in range(ti.embedding.shape[0]):
+                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
+
+            # modify text_encoder
+            text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
+            model_embeddings = text_encoder.get_input_embeddings()
+
+            for ti in ti_list:
+                ti_tokens = []
+                for i in range(ti.embedding.shape[0]):
+                    embedding = ti.embedding[i]
+                    trigger = _get_trigger(ti, i)
+
+                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
+                    if token_id == ti_tokenizer.unk_token_id:
+                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")
+
+                    if model_embeddings.weight.data[token_id].shape != embedding.shape:
+                        raise ValueError(
+                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
+                        )
+
+                    model_embeddings.weight.data[token_id] = embedding
+                    ti_tokens.append(token_id)
+
+                if len(ti_tokens) > 1:
+                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
+
+            yield ti_tokenizer, ti_manager
+
+        finally:
+            if init_tokens_count and new_tokens_added:
+                text_encoder.resize_token_embeddings(init_tokens_count)
+
+
+class TextualInversionModel:
+    name: str
+    embedding: torch.Tensor # [n, 768]|[n, 1280]
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        file_path: Union[str, Path],
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        if not isinstance(file_path, Path):
+            file_path = Path(file_path)
+
+        result = cls() # TODO:
+        result.name = file_path.stem # TODO:
+
+        if file_path.suffix == ".safetensors":
+            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
+        else:
+            state_dict = torch.load(file_path, map_location="cpu")
+
+        # both v1 and v2 format embeddings
+        # difference mostly in metadata
+        if "string_to_param" in state_dict:
+            if len(state_dict["string_to_param"]) > 1:
+                print(f"Warn: Embedding \"{file_path.name}\" contains multiple tokens, which is not supported. The first token will be used.")
+
+            result.embedding = next(iter(state_dict["string_to_param"].values()))
+
+        # v3 (easynegative)
+        elif "emb_params" in state_dict:
+            result.embedding = state_dict["emb_params"]
+
+        # v4(diffusers bin files)
+        else:
+            result.embedding = next(iter(state_dict.values()))

+            if not isinstance(result.embedding, torch.Tensor):
+                raise ValueError(f"Invalid embeddings file: {file_path.name}")
+
+        return result
+
+
+class TextualInversionManager(BaseTextualInversionManager):
+    pad_tokens: Dict[int, List[int]]
+
+    def __init__(self):
+        self.pad_tokens = dict()
+
+    def expand_textual_inversion_token_ids_if_necessary(
+        self, token_ids: list[int]
+    ) -> list[int]:
+
+        #if token_ids[0] == self.tokenizer.bos_token_id:
+        #    raise ValueError("token_ids must not start with bos_token_id")
+        #if token_ids[-1] == self.tokenizer.eos_token_id:
+        #    raise ValueError("token_ids must not end with eos_token_id")
+
+        if len(self.pad_tokens) == 0:
+            return token_ids
+
+        new_token_ids = []
+        for token_id in token_ids:
+            new_token_ids.append(token_id)
+            if token_id in self.pad_tokens:
+                new_token_ids.extend(self.pad_tokens[token_id])
+
+        return new_token_ids
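Note on the new TextualInversionManager above: the pad_tokens table is what lets a multi-vector embedding ride on a single trigger. apply_ti() registers one extra "-!pad-" token per additional vector, and the expansion hook splices the extra token ids back in after tokenization, before the text encoder sees the sequence. A minimal sketch of that expansion step in isolation, with hypothetical token ids:

    from typing import Dict, List

    # stand-in for the pad_tokens table built by ModelPatcher.apply_ti():
    # 49408 is the first vector of a hypothetical 3-vector embedding,
    # 49409 and 49410 are its pad companions
    pad_tokens: Dict[int, List[int]] = {49408: [49409, 49410]}

    def expand(token_ids: List[int]) -> List[int]:
        # same splice as expand_textual_inversion_token_ids_if_necessary()
        new_token_ids: List[int] = []
        for token_id in token_ids:
            new_token_ids.append(token_id)
            new_token_ids.extend(pad_tokens.get(token_id, []))
        return new_token_ids

    print(expand([320, 49408, 267]))  # -> [320, 49408, 49409, 49410, 267]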
@@ -37,7 +37,7 @@ from transformers import logging as transformers_logging
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import get_invokeai_config

-from .lora import LoRAModel
+from .lora import LoRAModel, TextualInversionModel

 def get_model_path(repo_id_or_path: str):
     globals = get_invokeai_config()
@@ -155,6 +155,7 @@ class SDModelType(str, Enum):
     Vae = "vae"
     Scheduler = "scheduler"
     Lora = "lora"
+    TextualInversion = "textual_inversion"


 class ModelInfoBase:
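With the new TextualInversion member, embeddings are requested through the same get_model() path as other model types; the call site added in compel.py earlier in this diff looks roughly like this, where model_manager stands in for context.services.model_manager and the name is hypothetical:

    ti_model = model_manager.get_model(
        model_name="my-style",                    # trigger text minus the angle brackets
        model_type=SDModelType.TextualInversion,  # enum member added in this commit
    )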
@@ -417,7 +418,7 @@ class LoRAModelInfo(ModelInfoBase):

     def get_size(self, child_type: Optional[SDModelType] = None):
         if child_type is not None:
-            raise Exception("There is no child models in lora model")
+            raise Exception("There is no child models in lora")
         return self.model_size

     def get_model(
@@ -426,7 +427,7 @@ class LoRAModelInfo(ModelInfoBase):
         torch_dtype: Optional[torch.dtype] = None,
     ):
         if child_type is not None:
-            raise Exception("There is no child models in lora model")
+            raise Exception("There is no child models in lora")

         model = LoRAModel.from_checkpoint(
             file_path=self.model_path,
@@ -437,11 +438,46 @@ class LoRAModelInfo(ModelInfoBase):
         return model


+class TextualInversionModelInfo(ModelInfoBase):
+    #model_size: int
+
+    def __init__(self, file_path: str, model_type: SDModelType):
+        assert model_type == SDModelType.TextualInversion
+        # check manualy as super().__init__ will try to resolve repo_id too
+        if not os.path.exists(file_path):
+            raise Exception("Model not found")
+        super().__init__(file_path, model_type)
+
+        self.model_size = os.path.getsize(file_path)
+
+    def get_size(self, child_type: Optional[SDModelType] = None):
+        if child_type is not None:
+            raise Exception("There is no child models in textual inversion")
+        return self.model_size
+
+    def get_model(
+        self,
+        child_type: Optional[SDModelType] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+    ):
+        if child_type is not None:
+            raise Exception("There is no child models in textual inversion")
+
+        model = TextualInversionModel.from_checkpoint(
+            file_path=self.model_path,
+            dtype=torch_dtype,
+        )
+
+        self.model_size = model.embedding.nelement() * model.embedding.element_size()
+        return model
+
+
 MODEL_TYPES = {
     SDModelType.Diffusers: DiffusersModelInfo,
     SDModelType.Classifier: ClassifierModelInfo,
     SDModelType.Vae: VaeModelInfo,
     SDModelType.Lora: LoRAModelInfo,
+    SDModelType.TextualInversion: TextualInversionModelInfo,
 }
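Note on the size bookkeeping in TextualInversionModelInfo.get_model() above: it is just element count times bytes per element of the loaded embedding tensor. A small sketch with a hypothetical two-vector SD-1.x embedding stored as float16:

    import torch

    embedding = torch.zeros(2, 768, dtype=torch.float16)    # hypothetical [n, 768] embedding
    print(embedding.nelement() * embedding.element_size())  # -> 3072 (bytes)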
@@ -332,7 +332,7 @@ class ModelManager(object):
             location = None

         revision = mconfig.get('revision')
-        if model_type in [SDModelType.Lora]:
+        if model_type in [SDModelType.Lora, SDModelType.TextualInversion]:
             hash = "<NO_HASH>" # TODO:
         else:
             hash = self.cache.model_hash(location, revision)