backend..conditioning: remove code for legacy model

This commit is contained in:
Kevin Turner 2023-03-09 18:01:37 -08:00
parent ad7b1fa6fb
commit 9d339e94f2
3 changed files with 4 additions and 38 deletions

View File

@ -3,7 +3,6 @@ Initialization file for invokeai.backend.prompting
"""
from .conditioning import (
get_prompt_structure,
get_tokenizer,
get_tokens_for_prompt_object,
get_uc_and_c_and_ec,
split_weighted_subprompts,

View File

@ -7,7 +7,7 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
"""
import re
from typing import Any, Optional, Union
from typing import Optional, Union
from compel import Compel
from compel.prompt_parser import (
@ -17,7 +17,6 @@ from compel.prompt_parser import (
Fragment,
PromptParser,
)
from transformers import CLIPTokenizer
from invokeai.backend.globals import Globals
@ -25,36 +24,6 @@ from ..stable_diffusion import InvokeAIDiffuserComponent
from ..util import torch_dtype
def get_tokenizer(model) -> CLIPTokenizer:
# TODO remove legacy ckpt fallback handling
return (
getattr(model, "tokenizer", None) # diffusers
or model.cond_stage_model.tokenizer
) # ldm
def get_text_encoder(model) -> Any:
# TODO remove legacy ckpt fallback handling
return getattr(
model, "text_encoder", None
) or UnsqueezingLDMTransformer( # diffusers
model.cond_stage_model.transformer
) # ldm
class UnsqueezingLDMTransformer:
def __init__(self, ldm_transformer):
self.ldm_transformer = ldm_transformer
@property
def device(self):
return self.ldm_transformer.device
def __call__(self, *args, **kwargs):
insufficiently_unsqueezed_tensor = self.ldm_transformer(*args, **kwargs)
return insufficiently_unsqueezed_tensor.unsqueeze(0)
def get_uc_and_c_and_ec(
prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
):
@ -64,11 +33,10 @@ def get_uc_and_c_and_ec(
prompt_string
)
tokenizer = get_tokenizer(model)
text_encoder = get_text_encoder(model)
tokenizer = model.tokenizer
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
text_encoder=model.text_encoder,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False

View File

@ -29,7 +29,6 @@ from ..image_util import PngWriter, retrieve_metadata
from ...frontend.merge.merge_diffusers import merge_diffusion_models
from ..prompting import (
get_prompt_structure,
get_tokenizer,
get_tokens_for_prompt_object,
)
from ..stable_diffusion import PipelineIntermediateState
@ -1274,7 +1273,7 @@ class InvokeAIWebServer:
None
if type(parsed_prompt) is Blend
else get_tokens_for_prompt_object(
get_tokenizer(self.generate.model), parsed_prompt
self.generate.model.tokenizer, parsed_prompt
)
)
attention_maps_image_base64_url = (