fix blend tokenization reporting; fix LDM checkpoint support
commit 97eac58a50 (parent cedbe8fcd7)
@@ -25,7 +25,8 @@ from invokeai.backend.modules.parameters import parameters_to_command
 import invokeai.frontend.dist as frontend
 from ldm.generate import Generate
 from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
-from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure, split_weighted_subprompts
+from ldm.invoke.conditioning import get_tokens_for_prompt_object, get_prompt_structure, split_weighted_subprompts, \
+    get_tokenizer
 from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
 from ldm.invoke.generator.inpaint import infill_methods
 from ldm.invoke.globals import Globals, global_converted_ckpts_dir
@@ -1258,7 +1259,7 @@ class InvokeAIWebServer:
             parsed_prompt, _ = get_prompt_structure(
                 generation_parameters["prompt"])
             tokens = None if type(parsed_prompt) is Blend else \
-                get_tokens_for_prompt(self.generate.model.tokenizer, parsed_prompt)
+                get_tokens_for_prompt_object(get_tokenizer(self.generate.model), parsed_prompt)
             attention_maps_image_base64_url = None if attention_maps_image is None \
                 else image_to_dataURL(attention_maps_image)
 
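The hunk above is the "blend tokenization reporting" half of the fix: instead of reaching for self.generate.model.tokenizer, which only exists on diffusers models, the web server resolves the tokenizer through get_tokenizer() and calls the renamed get_tokens_for_prompt_object(), skipping token reporting entirely for Blend prompts. A minimal sketch of that reporting path (the report_tokens_for_ui() helper is invented for illustration; only the imported functions come from this commit):

from compel.prompt_parser import Blend

from ldm.invoke.conditioning import (get_prompt_structure, get_tokenizer,
                                     get_tokens_for_prompt_object)


def report_tokens_for_ui(model, prompt_string):
    """Return the token strings to show in the UI, or None for blend prompts."""
    parsed_prompt, _ = get_prompt_structure(prompt_string)
    if type(parsed_prompt) is Blend:
        # A blend combines several sub-prompts, so a single flat token list is not meaningful.
        return None
    # get_tokenizer() works for diffusers pipelines and legacy LDM checkpoints alike.
    return get_tokens_for_prompt_object(get_tokenizer(model), parsed_prompt)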
@@ -7,7 +7,9 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
 
 '''
 import re
-from typing import Union, Optional
+from typing import Union, Optional, Any
 
+from transformers import CLIPTokenizer, CLIPTextModel
+
 from compel import Compel
 from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser
@@ -15,14 +17,38 @@ from .devices import torch_dtype
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ldm.invoke.globals import Globals
 
+def get_tokenizer(model) -> CLIPTokenizer:
+    # TODO remove legacy ckpt fallback handling
+    return (getattr(model, 'tokenizer', None) # diffusers
+            or model.cond_stage_model.tokenizer) # ldm
+
+def get_text_encoder(model) -> Any:
+    # TODO remove legacy ckpt fallback handling
+    return (getattr(model, 'text_encoder', None) # diffusers
+            or UnsqueezingLDMTransformer(model.cond_stage_model.transformer)) # ldm
+
+class UnsqueezingLDMTransformer:
+    def __init__(self, ldm_transformer):
+        self.ldm_transformer = ldm_transformer
+
+    @property
+    def device(self):
+        return self.ldm_transformer.device
+
+    def __call__(self, *args, **kwargs):
+        insufficiently_unsqueezed_tensor = self.ldm_transformer(*args, **kwargs)
+        return insufficiently_unsqueezed_tensor.unsqueeze(0)
+
 
 def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
     # lazy-load any deferred textual inversions.
     # this might take a couple of seconds the first time a textual inversion is used.
     model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
 
-    compel = Compel(tokenizer=model.tokenizer,
-                    text_encoder=model.text_encoder,
+    tokenizer = get_tokenizer(model)
+    text_encoder = get_text_encoder(model)
+    compel = Compel(tokenizer=tokenizer,
+                    text_encoder=text_encoder,
                     textual_inversion_manager=model.textual_inversion_manager,
                     dtype_for_device_getter=torch_dtype)
 
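These helpers are the "LDM checkpoint support" half of the fix: a diffusers pipeline exposes .tokenizer and .text_encoder directly, while a legacy ckpt model keeps both under model.cond_stage_model, and, as the wrapper's name suggests, its transformer output lacks the leading dimension Compel expects, so UnsqueezingLDMTransformer adds it with .unsqueeze(0). A rough illustration of how the fallbacks resolve; the two stand-in model objects and the tensor shape are invented for the example and are not InvokeAI types:

from types import SimpleNamespace

import torch

from ldm.invoke.conditioning import (UnsqueezingLDMTransformer,
                                     get_text_encoder, get_tokenizer)

# Stand-in for a diffusers pipeline: tokenizer/text_encoder live on the model itself.
diffusers_like = SimpleNamespace(tokenizer="clip-tokenizer", text_encoder="clip-text-model")

# Stand-in for a legacy LDM checkpoint: both live under cond_stage_model, and the
# transformer here returns a tensor without a leading batch-like dimension.
ldm_like = SimpleNamespace(
    cond_stage_model=SimpleNamespace(tokenizer="clip-tokenizer",
                                     transformer=lambda *a, **kw: torch.zeros(77, 768)))

assert get_tokenizer(diffusers_like) == "clip-tokenizer"   # taken straight off the pipeline
assert get_tokenizer(ldm_like) == "clip-tokenizer"         # falls back to cond_stage_model

wrapped = get_text_encoder(ldm_like)                       # wrapped for the legacy case
assert isinstance(wrapped, UnsqueezingLDMTransformer)
assert wrapped().shape == (1, 77, 768)                     # unsqueeze(0) restores the leading dim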
@@ -32,16 +58,16 @@ def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_l
     if legacy_blend is not None:
         positive_prompt = legacy_blend
     else:
-        positive_prompt = compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: FlattenedPrompt|Blend = compel.parse_prompt_string(negative_prompt_string)
+        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
+    negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
 
     if log_tokens or getattr(Globals, "log_tokenization", False):
-        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
+        log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
 
     c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
     uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
 
-    tokens_count = get_tokens_for_prompt(tokenizer=model.tokenizer, parsed_prompt=positive_prompt)
+    tokens_count = get_max_token_count(tokenizer, positive_prompt)
 
     ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
                                                          cross_attention_control_args=options.get(
@@ -55,13 +81,29 @@ def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = Fals
     parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
     """
     positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
-    positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt = Compel.parse_prompt_string(negative_prompt_string)
+    legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
+    positive_prompt: FlattenedPrompt|Blend
+    if legacy_blend is not None:
+        positive_prompt = legacy_blend
+    else:
+        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
+    negative_prompt: FlattenedPrompt|Blend = Compel.parse_prompt_string(negative_prompt_string)
+
     return positive_prompt, negative_prompt
 
 
-def get_tokens_for_prompt(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
+def get_max_token_count(tokenizer, prompt: FlattenedPrompt|Blend, truncate_if_too_long=True) -> int:
+    if type(prompt) is Blend:
+        blend: Blend = prompt
+        return max([get_max_token_count(tokenizer, c, truncate_if_too_long) for c in blend.prompts])
+    else:
+        return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
+
+
+def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
+    if type(parsed_prompt) is Blend:
+        raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
 
     text_fragments = [x.text if type(x) is Fragment else
                       (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
                        str(x))
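get_max_token_count() is what corrects the reported count for blends: for a Blend it recurses into each child prompt and returns the maximum, while a plain FlattenedPrompt is simply measured with get_tokens_for_prompt_object(), which now raises instead of silently mishandling a Blend. A usage sketch, assuming an already-loaded InvokeAI model object bound to `model` and compel's ("...", "...").blend(...) prompt syntax:

from compel import Compel
from compel.prompt_parser import Blend

from ldm.invoke.conditioning import (get_max_token_count, get_tokenizer,
                                     get_tokens_for_prompt_object)

tokenizer = get_tokenizer(model)   # `model` is assumed to be a loaded InvokeAI model

single = Compel.parse_prompt_string("a photograph of an astronaut riding a horse")
blended = Compel.parse_prompt_string('("a cat playing with a ball", "an oil painting").blend(0.7, 0.3)')

print(get_max_token_count(tokenizer, single))    # token count of the single FlattenedPrompt
print(get_max_token_count(tokenizer, blended))   # max token count across the blend's children

# get_tokens_for_prompt_object() only accepts FlattenedPrompt; for a Blend, walk .prompts yourself.
if type(blended) is Blend:
    per_child = [get_tokens_for_prompt_object(tokenizer, child) for child in blended.prompts]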
@@ -38,7 +38,7 @@ dependencies = [
   "albumentations",
   "click",
   "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
-  "compel",
+  "compel>=0.1.6",
   "datasets",
   "diffusers[torch]~=0.13",
   "dnspython==2.2.1",