backend..conditioning: remove code for legacy model

This commit is contained in:
Kevin Turner
2023-03-09 18:01:37 -08:00
parent ad7b1fa6fb
commit 9d339e94f2
3 changed files with 4 additions and 38 deletions

View File

@ -3,7 +3,6 @@ Initialization file for invokeai.backend.prompting
""" """
from .conditioning import ( from .conditioning import (
get_prompt_structure, get_prompt_structure,
get_tokenizer,
get_tokens_for_prompt_object, get_tokens_for_prompt_object,
get_uc_and_c_and_ec, get_uc_and_c_and_ec,
split_weighted_subprompts, split_weighted_subprompts,

View File

@ -7,7 +7,7 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
""" """
import re import re
from typing import Any, Optional, Union from typing import Optional, Union
from compel import Compel from compel import Compel
from compel.prompt_parser import ( from compel.prompt_parser import (
@ -17,7 +17,6 @@ from compel.prompt_parser import (
Fragment, Fragment,
PromptParser, PromptParser,
) )
from transformers import CLIPTokenizer
from invokeai.backend.globals import Globals from invokeai.backend.globals import Globals
@ -25,36 +24,6 @@ from ..stable_diffusion import InvokeAIDiffuserComponent
from ..util import torch_dtype from ..util import torch_dtype
def get_tokenizer(model) -> CLIPTokenizer:
# TODO remove legacy ckpt fallback handling
return (
getattr(model, "tokenizer", None) # diffusers
or model.cond_stage_model.tokenizer
) # ldm
def get_text_encoder(model) -> Any:
# TODO remove legacy ckpt fallback handling
return getattr(
model, "text_encoder", None
) or UnsqueezingLDMTransformer( # diffusers
model.cond_stage_model.transformer
) # ldm
class UnsqueezingLDMTransformer:
def __init__(self, ldm_transformer):
self.ldm_transformer = ldm_transformer
@property
def device(self):
return self.ldm_transformer.device
def __call__(self, *args, **kwargs):
insufficiently_unsqueezed_tensor = self.ldm_transformer(*args, **kwargs)
return insufficiently_unsqueezed_tensor.unsqueeze(0)
def get_uc_and_c_and_ec( def get_uc_and_c_and_ec(
prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
): ):
@ -64,11 +33,10 @@ def get_uc_and_c_and_ec(
prompt_string prompt_string
) )
tokenizer = get_tokenizer(model) tokenizer = model.tokenizer
text_encoder = get_text_encoder(model)
compel = Compel( compel = Compel(
tokenizer=tokenizer, tokenizer=tokenizer,
text_encoder=text_encoder, text_encoder=model.text_encoder,
textual_inversion_manager=model.textual_inversion_manager, textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype, dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False truncate_long_prompts=False

View File

@ -29,7 +29,6 @@ from ..image_util import PngWriter, retrieve_metadata
from ...frontend.merge.merge_diffusers import merge_diffusion_models from ...frontend.merge.merge_diffusers import merge_diffusion_models
from ..prompting import ( from ..prompting import (
get_prompt_structure, get_prompt_structure,
get_tokenizer,
get_tokens_for_prompt_object, get_tokens_for_prompt_object,
) )
from ..stable_diffusion import PipelineIntermediateState from ..stable_diffusion import PipelineIntermediateState
@ -1274,7 +1273,7 @@ class InvokeAIWebServer:
None None
if type(parsed_prompt) is Blend if type(parsed_prompt) is Blend
else get_tokens_for_prompt_object( else get_tokens_for_prompt_object(
get_tokenizer(self.generate.model), parsed_prompt self.generate.model.tokenizer, parsed_prompt
) )
) )
attention_maps_image_base64_url = ( attention_maps_image_base64_url = (