create an embedding_manager for diffusers

This commit is contained in:
Kevin Turner 2022-11-28 18:26:52 -08:00
parent f9dcc9a9b4
commit ca1f76b7ba

View File

@ -16,6 +16,7 @@ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMSchedu
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.modules.embedding_manager import EmbeddingManager
from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder from ldm.modules.encoders.modules import WeightedFrozenCLIPEmbedder
@ -28,6 +29,16 @@ class PipelineIntermediateState:
predicted_original: Optional[torch.Tensor] = None predicted_original: Optional[torch.Tensor] = None
# copied from configs/stable-diffusion/v1-inference.yaml
_default_personalization_config_params = dict(
placeholder_strings=["*"],
initializer_wods=["sculpture"],
per_image_tokens=False,
num_vectors_per_token=8,
progressive_words=False
)
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
r""" r"""
Pipeline for text-to-image generation using Stable Diffusion. Pipeline for text-to-image generation using Stable Diffusion.
@ -89,6 +100,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
transformer=self.text_encoder transformer=self.text_encoder
) )
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward) self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self.embedding_manager = EmbeddingManager(self.clip_embedder, **_default_personalization_config_params)
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int, def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor, text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,