diffusers: restore prompt weighting feature

Kevin Turner 2022-11-11 13:16:09 -08:00
parent 05a1d68ef4
commit e99faeb8d7
2 changed files with 19 additions and 24 deletions

Changed file 1 of 2 (defines StableDiffusionGeneratorPipeline):

@@ -1,5 +1,4 @@
 import secrets
-import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union, Callable
@@ -11,6 +10,8 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder
 
 @dataclass
 class PipelineIntermediateState:
@@ -76,6 +77,11 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
+        # InvokeAI's interface for text embeddings and whatnot
+        self.clip_embedder = WeightedFrozenCLIPEmbedder(
+            tokenizer=self.tokenizer,
+            transformer=self.text_encoder
+        )
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -312,27 +318,12 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
         return text_embeddings
 
-    def get_learned_conditioning(self, c: List[List[str]], return_tokens=True,
-                                 fragment_weights=None, **kwargs):
+    @torch.inference_mode()
+    def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
         """
         Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
         """
-        assert return_tokens == True
-        if fragment_weights:
-            weights = fragment_weights[0]
-            if any(weight != 1.0 for weight in weights):
-                warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2)
-        if kwargs:
-            warnings.warn(f"unsupported args {kwargs}", stacklevel=2)
-        text_fragments = c[0]
-        text_input = self._tokenize(text_fragments)
-        with torch.inference_mode():
-            token_ids = text_input.input_ids.to(self.text_encoder.device)
-            text_embeddings = self.text_encoder(token_ids)[0]
-        return text_embeddings, text_input.input_ids
+        return self.clip_embedder.encode(c, return_tokens=return_tokens, fragment_weights=fragment_weights)
 
     @torch.inference_mode()
     def _tokenize(self, prompt: Union[str, List[str]]):
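With this change, prompt weighting is delegated rather than stubbed out: the pipeline hands the (fragments, weights) structure straight to WeightedFrozenCLIPEmbedder.encode instead of warning that weights are "not implemented yet". A hedged usage sketch follows; `pipeline` is an assumed, already constructed StableDiffusionGeneratorPipeline, and the fragment texts and weights are illustrative.

# Hedged sketch, not part of the commit: `pipeline` is an assumed, ready-to-use
# StableDiffusionGeneratorPipeline instance.
fragments = [["a castle on a hill", "dramatic sunset lighting"]]  # c: List[List[str]]
weights = [[1.0, 0.6]]                                            # one weight per fragment
text_embeddings, tokens = pipeline.get_learned_conditioning(
    fragments,
    return_tokens=True,
    fragment_weights=weights,
)
# Previously the 0.6 weight only produced a warning; now it is applied by
# WeightedFrozenCLIPEmbedder when the embeddings are built.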

Changed file 2 of 2 (ldm/modules/encoders/modules.py):

@@ -239,22 +239,22 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def __init__(
         self,
         version='openai/clip-vit-large-patch14',
-        device=choose_torch_device(),
         max_length=77,
+        tokenizer=None,
+        transformer=None,
     ):
         super().__init__()
         cache = os.path.join(Globals.root,'models',version)
-        self.tokenizer = CLIPTokenizer.from_pretrained(
+        self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained(
             version,
             cache_dir=cache,
             local_files_only=True
         )
-        self.transformer = CLIPTextModel.from_pretrained(
+        self.transformer = transformer or CLIPTextModel.from_pretrained(
             version,
             cache_dir=cache,
             local_files_only=True
         )
-        self.device = device
         self.max_length = max_length
         self.freeze()
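The constructor changes above let a caller inject an existing tokenizer and text encoder, so the pipeline's already-loaded CLIP components are reused instead of a second copy being read from the local model cache. A minimal sketch, assuming the stock openai/clip-vit-large-patch14 weights are obtainable:

from transformers import CLIPTextModel, CLIPTokenizer
from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder

# Components that, inside the pipeline, already exist as self.tokenizer / self.text_encoder.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# The `or` fallback in __init__ means no second load happens when both are supplied.
embedder = WeightedFrozenCLIPEmbedder(tokenizer=tokenizer, transformer=text_encoder)
assert embedder.transformer is text_encoder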
@@ -460,6 +460,10 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def encode(self, text, **kwargs):
         return self(text, **kwargs)
 
+    @property
+    def device(self):
+        return self.transformer.device
+
 class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
 
     fragment_weights_key = "fragment_weights"
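Dropping the stored `device` attribute in favour of a property means the embedder always reports the device its (possibly shared) transformer currently lives on. Continuing the sketch above, with the caveat that the CUDA move assumes a GPU is present:

embedder = WeightedFrozenCLIPEmbedder(tokenizer=tokenizer, transformer=text_encoder)
print(embedder.device)   # follows the transformer, e.g. device(type='cpu')

text_encoder.to("cuda")  # move the shared text encoder (assumes CUDA is available)
print(embedder.device)   # now reports the CUDA device with no manual bookkeeping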
@@ -548,7 +552,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
             #print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
 
             # append to batch
             batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
             batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
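The `lerped_embeddings` batched here are the weighted fragments themselves. As a rough illustration of the idea only (the actual WeightedFrozenCLIPEmbedder does per-token bookkeeping and baseline handling that this omits), fragment weighting amounts to interpolating between a baseline embedding and the fragment's embedding by the fragment's weight:

import torch

def lerp_toward_fragment(baseline_z: torch.Tensor, fragment_z: torch.Tensor, weight: float) -> torch.Tensor:
    """weight 1.0 keeps the fragment's embedding; lower weights pull it toward the baseline."""
    return torch.lerp(baseline_z, fragment_z, weight)

# Shapes mirror CLIP ViT-L/14 text embeddings: 77 tokens x 768 channels.
z_fragment = torch.randn(77, 768)
z_baseline = torch.randn(77, 768)
assert torch.allclose(lerp_toward_fragment(z_baseline, z_fragment, 1.0), z_fragment)
assert torch.allclose(lerp_toward_fragment(z_baseline, z_fragment, 0.0), z_baseline)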