fix blend tokenization reporting; fix LDM checkpoint support

Damian Stewart 2023-02-22 10:28:54 +01:00
parent cedbe8fcd7
commit 97eac58a50
3 changed files with 56 additions and 13 deletions

@@ -25,7 +25,8 @@ from invokeai.backend.modules.parameters import parameters_to_command
import invokeai.frontend.dist as frontend
from ldm.generate import Generate
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure, split_weighted_subprompts
from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, split_weighted_subprompts, \
get_tokenizer
from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
from ldm.invoke.generator.inpaint import infill_methods
from ldm.invoke.globals import Globals, global_converted_ckpts_dir
@@ -1258,7 +1259,7 @@ class InvokeAIWebServer:
parsed_prompt, _ = get_prompt_structure(
generation_parameters["prompt"])
tokens = None if type(parsed_prompt) is Blend else \
get_tokens_for_prompt(self.generate.model.tokenizer, parsed_prompt)
get_tokens_for_prompt_object(get_tokenizer(self.generate.model), parsed_prompt)
attention_maps_image_base64_url = None if attention_maps_image is None \
else image_to_dataURL(attention_maps_image)
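
The web-server side of the fix: the prompt is parsed with get_prompt_structure(), per-token reporting is skipped for Blend prompts, and the tokenizer is resolved through the new get_tokenizer() helper so reporting also works for legacy LDM checkpoints rather than only diffusers pipelines. A minimal sketch of that path; tokens_for_reporting is an illustrative wrapper, not a function from this commit:

    from compel.prompt_parser import Blend
    from ldm.invoke.conditioning import (
        get_prompt_structure,
        get_tokenizer,
        get_tokens_for_prompt_object,
    )

    def tokens_for_reporting(model, prompt_string: str):
        """Return the prompt's token list, or None when it parses to a Blend."""
        parsed_prompt, _ = get_prompt_structure(prompt_string)
        if type(parsed_prompt) is Blend:
            # a Blend wraps several child prompts, so there is no single token list to report
            return None
        # get_tokenizer() falls back to model.cond_stage_model.tokenizer for LDM checkpoints
        return get_tokens_for_prompt_object(get_tokenizer(model), parsed_prompt)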

@@ -7,7 +7,9 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
'''
import re
from typing import Union, Optional
from typing import Union, Optional, Any
from transformers import CLIPTokenizer, CLIPTextModel
from compel import Compel
from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser
@@ -15,14 +17,38 @@ from .devices import torch_dtype
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.invoke.globals import Globals
def get_tokenizer(model) -> CLIPTokenizer:
# TODO remove legacy ckpt fallback handling
return (getattr(model, 'tokenizer', None) # diffusers
or model.cond_stage_model.tokenizer) # ldm
def get_text_encoder(model) -> Any:
# TODO remove legacy ckpt fallback handling
return (getattr(model, 'text_encoder', None) # diffusers
or UnsqueezingLDMTransformer(model.cond_stage_model.transformer)) # ldm
class UnsqueezingLDMTransformer:
def __init__(self, ldm_transformer):
self.ldm_transformer = ldm_transformer
@property
def device(self):
return self.ldm_transformer.device
def __call__(self, *args, **kwargs):
insufficiently_unsqueezed_tensor = self.ldm_transformer(*args, **kwargs)
return insufficiently_unsqueezed_tensor.unsqueeze(0)
def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
# lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used.
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
compel = Compel(tokenizer=model.tokenizer,
text_encoder=model.text_encoder,
tokenizer = get_tokenizer(model)
text_encoder = get_text_encoder(model)
compel = Compel(tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype)
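
get_tokenizer() and get_text_encoder() prefer the diffusers attributes (model.tokenizer, model.text_encoder) and fall back to the legacy layout under model.cond_stage_model, which is the "LDM checkpoint support" half of the commit; UnsqueezingLDMTransformer adapts the legacy transformer by unsqueezing a leading dimension (presumably the batch axis compel expects) onto its output. A toy, self-contained illustration of that adapter idea; FakeLDMTransformer and UnsqueezingWrapper are stand-ins, not code from this commit:

    import torch

    class FakeLDMTransformer:
        # stand-in for model.cond_stage_model.transformer: returns embeddings
        # shaped (sequence_length, hidden_size), with no batch dimension
        device = torch.device("cpu")

        def __call__(self, token_ids: torch.Tensor) -> torch.Tensor:
            return torch.zeros(token_ids.shape[-1], 768)

    class UnsqueezingWrapper:
        # same idea as UnsqueezingLDMTransformer: forward the call and add a
        # leading dimension to the result
        def __init__(self, inner):
            self.inner = inner

        @property
        def device(self):
            return self.inner.device

        def __call__(self, *args, **kwargs):
            return self.inner(*args, **kwargs).unsqueeze(0)

    wrapped = UnsqueezingWrapper(FakeLDMTransformer())
    print(wrapped(torch.zeros(77, dtype=torch.long)).shape)  # torch.Size([1, 77, 768])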
@@ -32,16 +58,16 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
if legacy_blend is not None:
positive_prompt = legacy_blend
else:
positive_prompt = compel.parse_prompt_string(positive_prompt_string)
negative_prompt: FlattenedPrompt|Blend = compel.parse_prompt_string(negative_prompt_string)
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
if log_tokens or getattr(Globals, "log_tokenization", False):
log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
tokens_count = get_tokens_for_prompt(tokenizer=model.tokenizer, parsed_prompt=positive_prompt)
tokens_count = get_max_token_count(tokenizer, positive_prompt)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
cross_attention_control_args=options.get(
@@ -55,13 +81,29 @@ def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = Fals
parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
"""
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt = Compel.parse_prompt_string(negative_prompt_string)
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
positive_prompt: FlattenedPrompt|Blend
if legacy_blend is not None:
positive_prompt = legacy_blend
else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
return positive_prompt, negative_prompt
def get_max_token_count(tokenizer, prompt: FlattenedPrompt|Blend, truncate_if_too_long=True) -> int:
if type(prompt) is Blend:
blend: Blend = prompt
return max([get_max_token_count(tokenizer, c, truncate_if_too_long) for c in blend.prompts])
else:
return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
if type(parsed_prompt) is Blend:
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
def get_tokens_for_prompt(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
text_fragments = [x.text if type(x) is Fragment else
(" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
str(x))
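
These additions are the "blend tokenization reporting" half of the commit: get_max_token_count() recurses into a Blend's child prompts and reports the largest count, while the renamed get_tokens_for_prompt_object() now raises a clear error when handed a Blend directly. A usage sketch, assuming an InvokeAI checkout where ldm.invoke.conditioning is importable; the CLIP tokenizer is pulled from the Hugging Face hub here only to keep the example self-contained:

    from transformers import CLIPTokenizer
    from compel.prompt_parser import Blend
    from ldm.invoke.conditioning import (
        get_max_token_count,
        get_prompt_structure,
        get_tokens_for_prompt_object,
    )

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # A .blend() prompt parses to a Blend; the reported count is the max over its children.
    positive, _negative = get_prompt_structure('("a cat playing", "a dog sleeping").blend(0.7, 0.3)')
    print(get_max_token_count(tokenizer, positive))

    # Per-token lists are only available for non-Blend prompts.
    if type(positive) is not Blend:
        print(get_tokens_for_prompt_object(tokenizer, positive))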

@@ -38,7 +38,7 @@ dependencies = [
"albumentations",
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel",
"compel>=0.1.6",
"datasets",
"diffusers[torch]~=0.13",
"dnspython==2.2.1",