First working TI draft

This commit is contained in:
Sergey Borisov
2023-05-31 02:12:27 +03:00
parent 69ccd3a0b5
commit b47786e846
4 changed files with 219 additions and 41 deletions

View File

@ -1,6 +1,7 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from contextlib import ExitStack
import re
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
@ -9,7 +10,8 @@ from .model import ClipField
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
from ...backend.model_management.lora import LoRAHelper
from ...backend.model_management import SDModelType
from ...backend.model_management.lora import ModelPatcher
from compel import Compel
from compel.prompt_parser import (
@ -58,56 +60,61 @@ class CompelInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> CompelOutput:
text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(),
)
tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(),
)
with text_encoder_info as text_encoder,\
tokenizer_info as tokenizer,\
text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(),
)
with tokenizer_info as orig_tokenizer,\
text_encoder_info as text_encoder,\
ExitStack() as stack:
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
# TODO: global? input?
#use_full_precision = precision == "float32" or precision == "autocast"
#use_full_precision = False
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
stack.enter_context(
context.services.model_manager.get_model(model_name=name, model_type=SDModelType.TextualInversion)
)
)
except Exception as e:
#print(e)
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
# TODO: redo TI when separate model loding implemented
#textual_inversion_manager = TextualInversionManager(
# tokenizer=tokenizer,
# text_encoder=text_encoder,
# full_precision=use_full_precision,
#)
with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=None,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
)
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
)
# TODO: support legacy blend?
conjunction = Compel.parse_prompt_string(self.prompt)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
conjunction = Compel.parse_prompt_string(self.prompt)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
with LoRAHelper.apply_lora_text_encoder(text_encoder, loras):
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
# TODO: long prompt support
#if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
# TODO: long prompt support
#if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
cross_attention_control_args=options.get("cross_attention_control", None),
)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
cross_attention_control_args=options.get("cross_attention_control", None),
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"