mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
diffusers: restore prompt weighting feature
This commit is contained in:
parent
05a1d68ef4
commit
e99faeb8d7
@ -1,5 +1,4 @@
|
||||
import secrets
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union, Callable
|
||||
|
||||
@ -11,6 +10,8 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
@ -76,6 +77,11 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
# InvokeAI's interface for text embeddings and whatnot
|
||||
self.clip_embedder = WeightedFrozenCLIPEmbedder(
|
||||
tokenizer=self.tokenizer,
|
||||
transformer=self.text_encoder
|
||||
)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
@ -312,27 +318,12 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
return text_embeddings
|
||||
|
||||
def get_learned_conditioning(self, c: List[List[str]], return_tokens=True,
|
||||
fragment_weights=None, **kwargs):
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
|
||||
"""
|
||||
Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
|
||||
"""
|
||||
assert return_tokens == True
|
||||
if fragment_weights:
|
||||
weights = fragment_weights[0]
|
||||
if any(weight != 1.0 for weight in weights):
|
||||
warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2)
|
||||
|
||||
if kwargs:
|
||||
warnings.warn(f"unsupported args {kwargs}", stacklevel=2)
|
||||
|
||||
text_fragments = c[0]
|
||||
text_input = self._tokenize(text_fragments)
|
||||
|
||||
with torch.inference_mode():
|
||||
token_ids = text_input.input_ids.to(self.text_encoder.device)
|
||||
text_embeddings = self.text_encoder(token_ids)[0]
|
||||
return text_embeddings, text_input.input_ids
|
||||
return self.clip_embedder.encode(c, return_tokens=return_tokens, fragment_weights=fragment_weights)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _tokenize(self, prompt: Union[str, List[str]]):
|
||||
|
@ -239,22 +239,22 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
def __init__(
|
||||
self,
|
||||
version='openai/clip-vit-large-patch14',
|
||||
device=choose_torch_device(),
|
||||
max_length=77,
|
||||
tokenizer=None,
|
||||
transformer=None,
|
||||
):
|
||||
super().__init__()
|
||||
cache = os.path.join(Globals.root,'models',version)
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(
|
||||
self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained(
|
||||
version,
|
||||
cache_dir=cache,
|
||||
local_files_only=True
|
||||
)
|
||||
self.transformer = CLIPTextModel.from_pretrained(
|
||||
self.transformer = transformer or CLIPTextModel.from_pretrained(
|
||||
version,
|
||||
cache_dir=cache,
|
||||
local_files_only=True
|
||||
)
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
self.freeze()
|
||||
|
||||
@ -460,6 +460,10 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
def encode(self, text, **kwargs):
|
||||
return self(text, **kwargs)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.transformer.device
|
||||
|
||||
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
||||
|
||||
fragment_weights_key = "fragment_weights"
|
||||
@ -548,7 +552,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
||||
|
||||
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
|
||||
|
||||
# append to batch
|
||||
# append to batch
|
||||
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
|
||||
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user