diffusers: restore prompt weighting feature

Kevin Turner 2022-11-11 13:16:09 -08:00
parent 05a1d68ef4
commit e99faeb8d7
2 changed files with 19 additions and 24 deletions
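For context, a minimal sketch of the call path this commit restores. Names are taken from the diff below; the prompt fragments, the weight values, and the `pipeline` instance of StableDiffusionGeneratorPipeline are illustrative assumptions, not part of the change:

    # Sketch: per-fragment prompt weighting now flows through the shared
    # WeightedFrozenCLIPEmbedder instead of the removed warn-and-ignore path.
    fragments = [["a castle on a hill", "dramatic lighting"]]  # assumed example prompt fragments
    weights = [[1.0, 1.5]]                                     # >1.0 emphasizes a fragment
    embeddings, tokens = pipeline.get_learned_conditioning(
        fragments, return_tokens=True, fragment_weights=weights
    )
    # which now simply delegates to:
    # pipeline.clip_embedder.encode(fragments, return_tokens=True, fragment_weights=weights)

The two-value unpack mirrors what the removed implementation returned (embeddings plus token ids) when return_tokens is true.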

View File

@@ -1,5 +1,4 @@
import secrets
import warnings
from dataclasses import dataclass
from typing import List, Optional, Union, Callable
@@ -11,6 +10,8 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder
@dataclass
class PipelineIntermediateState:
@@ -76,6 +77,11 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
# InvokeAI's interface for text embeddings and whatnot
self.clip_embedder = WeightedFrozenCLIPEmbedder(
tokenizer=self.tokenizer,
transformer=self.text_encoder
)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
@@ -312,27 +318,12 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
def get_learned_conditioning(self, c: List[List[str]], return_tokens=True,
fragment_weights=None, **kwargs):
@torch.inference_mode()
def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
"""
Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
"""
assert return_tokens == True
if fragment_weights:
weights = fragment_weights[0]
if any(weight != 1.0 for weight in weights):
warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2)
if kwargs:
warnings.warn(f"unsupported args {kwargs}", stacklevel=2)
text_fragments = c[0]
text_input = self._tokenize(text_fragments)
with torch.inference_mode():
token_ids = text_input.input_ids.to(self.text_encoder.device)
text_embeddings = self.text_encoder(token_ids)[0]
return text_embeddings, text_input.input_ids
return self.clip_embedder.encode(c, return_tokens=return_tokens, fragment_weights=fragment_weights)
@torch.inference_mode()
def _tokenize(self, prompt: Union[str, List[str]]):

View File

@@ -239,22 +239,22 @@ class FrozenCLIPEmbedder(AbstractEncoder):
def __init__(
self,
version='openai/clip-vit-large-patch14',
device=choose_torch_device(),
max_length=77,
tokenizer=None,
transformer=None,
):
super().__init__()
cache = os.path.join(Globals.root,'models',version)
self.tokenizer = CLIPTokenizer.from_pretrained(
self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained(
version,
cache_dir=cache,
local_files_only=True
)
self.transformer = CLIPTextModel.from_pretrained(
self.transformer = transformer or CLIPTextModel.from_pretrained(
version,
cache_dir=cache,
local_files_only=True
)
self.device = device
self.max_length = max_length
self.freeze()
@@ -460,6 +460,10 @@ class FrozenCLIPEmbedder(AbstractEncoder):
def encode(self, text, **kwargs):
return self(text, **kwargs)
@property
def device(self):
return self.transformer.device
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
fragment_weights_key = "fragment_weights"
@@ -548,7 +552,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
# append to batch
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
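A side effect worth noting in the second file: FrozenCLIPEmbedder no longer takes a fixed device at construction; `device` is now a read-only property that reports wherever the wrapped transformer currently lives. A minimal illustration, where the `pipe` instance and the `.to(...)` move are assumed and not part of this diff:

    # Sketch: the embedder wraps the pipeline's own tokenizer and text encoder,
    # so its reported device follows the transformer automatically.
    embedder = WeightedFrozenCLIPEmbedder(tokenizer=pipe.tokenizer, transformer=pipe.text_encoder)
    pipe.text_encoder.to("cuda")  # assumed move; not shown in the diff
    print(embedder.device)        # tracks the transformer's device rather than a value captured at __init__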