Mirror of https://github.com/invoke-ai/InvokeAI
fix blend tokenization reporting; fix LDM checkpoint support
parent cedbe8fcd7
commit 97eac58a50
@@ -25,7 +25,8 @@ from invokeai.backend.modules.parameters import parameters_to_command
 import invokeai.frontend.dist as frontend
 from ldm.generate import Generate
 from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
-from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure, split_weighted_subprompts
+from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, split_weighted_subprompts, \
+    get_tokenizer
 from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
 from ldm.invoke.generator.inpaint import infill_methods
 from ldm.invoke.globals import Globals, global_converted_ckpts_dir
@@ -1258,7 +1259,7 @@ class InvokeAIWebServer:
                 parsed_prompt, _ = get_prompt_structure(
                     generation_parameters["prompt"])
                 tokens = None if type(parsed_prompt) is Blend else \
-                    get_tokens_for_prompt(self.generate.model.tokenizer, parsed_prompt)
+                    get_tokens_for_prompt_object(get_tokenizer(self.generate.model), parsed_prompt)
                 attention_maps_image_base64_url = None if attention_maps_image is None \
                     else image_to_dataURL(attention_maps_image)
@@ -7,7 +7,9 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
 '''
 import re
-from typing import Union, Optional
+from typing import Union, Optional, Any
+
+from transformers import CLIPTokenizer, CLIPTextModel
 
 from compel import Compel
 from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser
@@ -15,14 +17,38 @@ from .devices import torch_dtype
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ldm.invoke.globals import Globals
 
+def get_tokenizer(model) -> CLIPTokenizer:
+    # TODO remove legacy ckpt fallback handling
+    return (getattr(model, 'tokenizer', None) # diffusers
+            or model.cond_stage_model.tokenizer) # ldm
+
+def get_text_encoder(model) -> Any:
+    # TODO remove legacy ckpt fallback handling
+    return (getattr(model, 'text_encoder', None) # diffusers
+            or UnsqueezingLDMTransformer(model.cond_stage_model.transformer)) # ldm
+
+class UnsqueezingLDMTransformer:
+    def __init__(self, ldm_transformer):
+        self.ldm_transformer = ldm_transformer
+
+    @property
+    def device(self):
+        return self.ldm_transformer.device
+
+    def __call__(self, *args, **kwargs):
+        insufficiently_unsqueezed_tensor = self.ldm_transformer(*args, **kwargs)
+        return insufficiently_unsqueezed_tensor.unsqueeze(0)
+
+
 def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
     # lazy-load any deferred textual inversions.
     # this might take a couple of seconds the first time a textual inversion is used.
     model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
 
-    compel = Compel(tokenizer=model.tokenizer,
-                    text_encoder=model.text_encoder,
+    tokenizer = get_tokenizer(model)
+    text_encoder = get_text_encoder(model)
+    compel = Compel(tokenizer=tokenizer,
+                    text_encoder=text_encoder,
                     textual_inversion_manager=model.textual_inversion_manager,
                     dtype_for_device_getter=torch_dtype)
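The additions above are the core of the LDM-checkpoint fix: a diffusers-style model exposes `tokenizer` and `text_encoder` directly, while a legacy `.ckpt` model keeps them under `cond_stage_model`, and the raw LDM transformer output needs an extra leading dimension added before compel can use it, hence the unsqueezing wrapper. Below is a minimal, self-contained sketch of that fallback pattern; `DummyTransformer`, `LDMCondStageModel`, `LDMLikeModel` and `DiffusersLikeModel` are invented stand-ins (not InvokeAI classes), and the sketch assumes torch is installed:

import torch  # assumption: torch is available, as it is for InvokeAI itself

# Invented stand-ins, only to exercise the fallback logic from the hunk above.
class DummyTransformer:
    """Pretend LDM text transformer: returns an un-batched (77, 768) tensor."""
    device = "cpu"

    def __call__(self, *args, **kwargs):
        return torch.zeros(77, 768)

class LDMCondStageModel:
    tokenizer = "ldm-tokenizer"          # placeholder for a real CLIPTokenizer
    transformer = DummyTransformer()

class LDMLikeModel:                      # legacy .ckpt-style model: no .tokenizer / .text_encoder
    cond_stage_model = LDMCondStageModel()

class DiffusersLikeModel:                # diffusers-style model: attributes exist directly
    tokenizer = "diffusers-tokenizer"
    text_encoder = "diffusers-text-encoder"

class UnsqueezingLDMTransformer:         # same wrapper as in the diff above
    def __init__(self, ldm_transformer):
        self.ldm_transformer = ldm_transformer

    @property
    def device(self):
        return self.ldm_transformer.device

    def __call__(self, *args, **kwargs):
        # add the leading dimension the raw LDM transformer output lacks
        return self.ldm_transformer(*args, **kwargs).unsqueeze(0)

def get_tokenizer(model):
    # prefer the diffusers attribute, fall back to the legacy LDM location
    return getattr(model, 'tokenizer', None) or model.cond_stage_model.tokenizer

def get_text_encoder(model):
    return (getattr(model, 'text_encoder', None)
            or UnsqueezingLDMTransformer(model.cond_stage_model.transformer))

print(get_tokenizer(DiffusersLikeModel()))        # diffusers-tokenizer
print(get_tokenizer(LDMLikeModel()))              # ldm-tokenizer
print(get_text_encoder(LDMLikeModel())().shape)   # torch.Size([1, 77, 768])

The getattr-with-None fallback keeps a single code path for both model types without explicit isinstance checks.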
@@ -32,16 +58,16 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
     if legacy_blend is not None:
         positive_prompt = legacy_blend
     else:
-        positive_prompt = compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: FlattenedPrompt|Blend = compel.parse_prompt_string(negative_prompt_string)
+        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
+    negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
 
     if log_tokens or getattr(Globals, "log_tokenization", False):
-        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
+        log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
 
     c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
     uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
 
-    tokens_count = get_tokens_for_prompt(tokenizer=model.tokenizer, parsed_prompt=positive_prompt)
+    tokens_count = get_max_token_count(tokenizer, positive_prompt)
 
     ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
                                                          cross_attention_control_args=options.get(
@@ -55,13 +81,29 @@ def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = Fals
     parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
     """
     positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
-    positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt = Compel.parse_prompt_string(negative_prompt_string)
+    legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
+    positive_prompt: FlattenedPrompt|Blend
+    if legacy_blend is not None:
+        positive_prompt = legacy_blend
+    else:
+        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
+    negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
 
     return positive_prompt, negative_prompt
 
+
+def get_max_token_count(tokenizer, prompt: FlattenedPrompt|Blend, truncate_if_too_long=True) -> int:
+    if type(prompt) is Blend:
+        blend: Blend = prompt
+        return max([get_max_token_count(tokenizer, c, truncate_if_too_long) for c in blend.prompts])
+    else:
+        return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
+
+
-def get_tokens_for_prompt(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
+def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
+    if type(parsed_prompt) is Blend:
+        raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
+
     text_fragments = [x.text if type(x) is Fragment else
                       (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
                        str(x))
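The new `get_max_token_count` is what fixes blend tokenization reporting: a Blend bundles several prompts, so instead of tokenizing it as a single string (which `get_tokens_for_prompt_object` now explicitly refuses), the reported count is the maximum over its child prompts. A minimal sketch of that recursion under invented stand-ins (`ToyTokenizer`, `ToyFragment`, `ToyFlattenedPrompt`, `ToyBlend` are not compel classes; only the recursion pattern mirrors the code above):

# Minimal, self-contained sketch of the max-over-children counting rule.
class ToyFragment:
    def __init__(self, text): self.text = text

class ToyFlattenedPrompt:
    def __init__(self, *texts): self.children = [ToyFragment(t) for t in texts]

class ToyBlend:
    def __init__(self, *prompts): self.prompts = list(prompts)

class ToyTokenizer:
    def tokenize(self, text): return text.split()   # whitespace "tokens"

def toy_tokens_for_prompt(tokenizer, parsed_prompt):
    if type(parsed_prompt) is ToyBlend:
        raise ValueError("Blend is not supported here - tokenize each of its children")
    text = " ".join(f.text for f in parsed_prompt.children)
    return tokenizer.tokenize(text)

def toy_max_token_count(tokenizer, prompt):
    if type(prompt) is ToyBlend:
        # a blend has no single token sequence; report the longest child
        return max(toy_max_token_count(tokenizer, c) for c in prompt.prompts)
    return len(toy_tokens_for_prompt(tokenizer, prompt))

blend = ToyBlend(ToyFlattenedPrompt("a cat"), ToyFlattenedPrompt("a very fluffy dog"))
print(toy_max_token_count(ToyTokenizer(), blend))   # 4  (longest child wins)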
@@ -38,7 +38,7 @@ dependencies = [
   "albumentations",
   "click",
   "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
-  "compel",
+  "compel>=0.1.6",
   "datasets",
   "diffusers[torch]~=0.13",
   "dnspython==2.2.1",
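The floor on compel is raised to 0.1.6, presumably because the conditioning code above relies on behaviour available from that release. A small hedged sketch of a runtime guard for that floor (not part of the commit; it assumes the third-party packaging library is installed):

# Hedged sketch: confirm the installed compel meets the new ">=0.1.6" floor
# declared in pyproject.toml above.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("compel"))
if installed < Version("0.1.6"):
    raise RuntimeError(f"compel {installed} is too old; this code expects >= 0.1.6")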