Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
diffusers: restore prompt weighting feature
This commit is contained in:
parent 05a1d68ef4
commit e99faeb8d7
@@ -1,5 +1,4 @@
|
|||||||
import secrets
|
import secrets
|
||||||
import warnings
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Union, Callable
|
from typing import List, Optional, Union, Callable
|
||||||
|
|
||||||
@@ -11,6 +10,8 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
|
|||||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineIntermediateState:
|
class PipelineIntermediateState:
|
||||||
@@ -76,6 +77,11 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
|
|||||||
safety_checker=safety_checker,
|
safety_checker=safety_checker,
|
||||||
feature_extractor=feature_extractor,
|
feature_extractor=feature_extractor,
|
||||||
)
|
)
|
||||||
|
# InvokeAI's interface for text embeddings and whatnot
|
||||||
|
self.clip_embedder = WeightedFrozenCLIPEmbedder(
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
transformer=self.text_encoder
|
||||||
|
)
|
||||||
|
|
||||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||||
r"""
|
r"""
|
||||||
@@ -312,27 +318,12 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
|
|||||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||||
return text_embeddings
|
return text_embeddings
|
||||||
|
|
||||||
def get_learned_conditioning(self, c: List[List[str]], return_tokens=True,
|
@torch.inference_mode()
|
||||||
fragment_weights=None, **kwargs):
|
def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
|
||||||
"""
|
"""
|
||||||
Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
|
Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
|
||||||
"""
|
"""
|
||||||
assert return_tokens == True
|
return self.clip_embedder.encode(c, return_tokens=return_tokens, fragment_weights=fragment_weights)
|
||||||
if fragment_weights:
|
|
||||||
weights = fragment_weights[0]
|
|
||||||
if any(weight != 1.0 for weight in weights):
|
|
||||||
warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2)
|
|
||||||
|
|
||||||
if kwargs:
|
|
||||||
warnings.warn(f"unsupported args {kwargs}", stacklevel=2)
|
|
||||||
|
|
||||||
text_fragments = c[0]
|
|
||||||
text_input = self._tokenize(text_fragments)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
token_ids = text_input.input_ids.to(self.text_encoder.device)
|
|
||||||
text_embeddings = self.text_encoder(token_ids)[0]
|
|
||||||
return text_embeddings, text_input.input_ids
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _tokenize(self, prompt: Union[str, List[str]]):
|
def _tokenize(self, prompt: Union[str, List[str]]):
|
||||||
|
@@ -239,22 +239,22 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
version='openai/clip-vit-large-patch14',
|
version='openai/clip-vit-large-patch14',
|
||||||
device=choose_torch_device(),
|
|
||||||
max_length=77,
|
max_length=77,
|
||||||
|
tokenizer=None,
|
||||||
|
transformer=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
cache = os.path.join(Globals.root,'models',version)
|
cache = os.path.join(Globals.root,'models',version)
|
||||||
self.tokenizer = CLIPTokenizer.from_pretrained(
|
self.tokenizer = tokenizer or CLIPTokenizer.from_pretrained(
|
||||||
version,
|
version,
|
||||||
cache_dir=cache,
|
cache_dir=cache,
|
||||||
local_files_only=True
|
local_files_only=True
|
||||||
)
|
)
|
||||||
self.transformer = CLIPTextModel.from_pretrained(
|
self.transformer = transformer or CLIPTextModel.from_pretrained(
|
||||||
version,
|
version,
|
||||||
cache_dir=cache,
|
cache_dir=cache,
|
||||||
local_files_only=True
|
local_files_only=True
|
||||||
)
|
)
|
||||||
self.device = device
|
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.freeze()
|
self.freeze()
|
||||||
|
|
||||||
@@ -460,6 +460,10 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
|||||||
def encode(self, text, **kwargs):
|
def encode(self, text, **kwargs):
|
||||||
return self(text, **kwargs)
|
return self(text, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self.transformer.device
|
||||||
|
|
||||||
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
||||||
|
|
||||||
fragment_weights_key = "fragment_weights"
|
fragment_weights_key = "fragment_weights"
|
||||||
@@ -548,7 +552,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
|||||||
|
|
||||||
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
|
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
|
||||||
|
|
||||||
# append to batch
|
# append to batch
|
||||||
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
|
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
|
||||||
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
|
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user