diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index f229542a9a..0b992909ab 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -43,41 +43,14 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
 
-        clip_embeddings = self._run_clip_text_encoder(context, model_path)
-        t5_embeddings = self._run_t5_text_encoder(context, model_path)
+        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
         latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
         image = self._run_vae_decoding(context, model_path, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
 
-    def _run_clip_text_encoder(self, context: InvocationContext, flux_model_dir: Path) -> torch.Tensor:
-        """Run the CLIP text encoder."""
-        tokenizer_path = flux_model_dir / "tokenizer"
-        tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
-        assert isinstance(tokenizer, CLIPTokenizer)
-
-        text_encoder_path = flux_model_dir / "text_encoder"
-        with context.models.load_local_model(
-            model_path=text_encoder_path, loader=self._load_flux_text_encoder
-        ) as text_encoder:
-            assert isinstance(text_encoder, CLIPTextModel)
-            flux_pipeline_with_te = FluxPipeline(
-                scheduler=None,
-                vae=None,
-                text_encoder=text_encoder,
-                tokenizer=tokenizer,
-                text_encoder_2=None,
-                tokenizer_2=None,
-                transformer=None,
-            )
-
-            return flux_pipeline_with_te._get_clip_prompt_embeds(
-                prompt=self.positive_prompt, device=TorchDevice.choose_torch_device()
-            )
-
-    def _run_t5_text_encoder(self, context: InvocationContext, flux_model_dir: Path) -> torch.Tensor:
-        """Run the T5 text encoder."""
-
+    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
+        # Determine the T5 max sequence length based on the model.
         if self.model == "flux-schnell":
             max_seq_len = 256
         # elif self.model == "flux-dev":
@@ -85,28 +58,51 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         else:
             raise ValueError(f"Unknown model: {self.model}")
 
-        tokenizer_path = flux_model_dir / "tokenizer_2"
-        tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_path, local_files_only=True)
-        assert isinstance(tokenizer_2, T5TokenizerFast)
+        # Load the CLIP tokenizer.
+        clip_tokenizer_path = flux_model_dir / "tokenizer"
+        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
+        assert isinstance(clip_tokenizer, CLIPTokenizer)
 
-        text_encoder_path = flux_model_dir / "text_encoder_2"
-        with context.models.load_local_model(
-            model_path=text_encoder_path, loader=self._load_flux_text_encoder_2
-        ) as text_encoder_2:
-            flux_pipeline_with_te2 = FluxPipeline(
+        # Load the T5 tokenizer.
+        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
+        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
+        assert isinstance(t5_tokenizer, T5TokenizerFast)
+
+        clip_text_encoder_path = flux_model_dir / "text_encoder"
+        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
+        with (
+            context.models.load_local_model(
+                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
+            ) as clip_text_encoder,
+            context.models.load_local_model(
+                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
+            ) as t5_text_encoder,
+        ):
+            assert isinstance(clip_text_encoder, CLIPTextModel)
+            assert isinstance(t5_text_encoder, T5EncoderModel)
+            pipeline = FluxPipeline(
                 scheduler=None,
                 vae=None,
-                text_encoder=None,
-                tokenizer=None,
-                text_encoder_2=text_encoder_2,
-                tokenizer_2=tokenizer_2,
+                text_encoder=clip_text_encoder,
+                tokenizer=clip_tokenizer,
+                text_encoder_2=t5_text_encoder,
+                tokenizer_2=t5_tokenizer,
                 transformer=None,
             )
 
-            return flux_pipeline_with_te2._get_t5_prompt_embeds(
-                prompt=self.positive_prompt, max_sequence_length=max_seq_len, device=TorchDevice.choose_torch_device()
+            # prompt_embeds: T5 embeddings
+            # pooled_prompt_embeds: CLIP embeddings
+            prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+                prompt=self.positive_prompt,
+                prompt_2=self.positive_prompt,
+                device=TorchDevice.choose_torch_device(),
+                max_sequence_length=max_seq_len,
             )
+            assert isinstance(prompt_embeds, torch.Tensor)
+            assert isinstance(pooled_prompt_embeds, torch.Tensor)
+            return prompt_embeds, pooled_prompt_embeds
+
 
     def _run_diffusion(
         self,
         context: InvocationContext,
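
For context, a minimal standalone sketch of the pattern this diff introduces: build a diffusers FluxPipeline that holds only the tokenizers and text encoders (scheduler, VAE, and transformer left as None) and call its encode_prompt() once to get both the T5 embeddings (prompt_embeds) and the pooled CLIP embeddings (pooled_prompt_embeds). The model directory, device handling, prompt, and the direct from_pretrained loading below are illustrative assumptions; InvokeAI itself loads the encoders through context.models.load_local_model as shown in the diff.

    from pathlib import Path

    import torch
    from diffusers import FluxPipeline
    from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

    # Hypothetical local FLUX checkpoint in diffusers layout; the submodule
    # names ("tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2")
    # match the directories used in the diff above.
    flux_model_dir = Path("/models/flux-schnell")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    clip_tokenizer = CLIPTokenizer.from_pretrained(flux_model_dir / "tokenizer", local_files_only=True)
    t5_tokenizer = T5TokenizerFast.from_pretrained(flux_model_dir / "tokenizer_2", local_files_only=True)
    clip_text_encoder = CLIPTextModel.from_pretrained(flux_model_dir / "text_encoder", local_files_only=True).to(device)
    t5_text_encoder = T5EncoderModel.from_pretrained(flux_model_dir / "text_encoder_2", local_files_only=True).to(device)

    # Text-encoding-only pipeline: all other components stay None.
    pipeline = FluxPipeline(
        scheduler=None,
        vae=None,
        text_encoder=clip_text_encoder,
        tokenizer=clip_tokenizer,
        text_encoder_2=t5_text_encoder,
        tokenizer_2=t5_tokenizer,
        transformer=None,
    )

    # prompt_embeds come from T5, pooled_prompt_embeds from CLIP; text_ids
    # is unused here, as in the diff.
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt="a photo of a cat",
        prompt_2="a photo of a cat",
        device=device,
        max_sequence_length=256,  # 256 for flux-schnell, per the diff
    )
    print(prompt_embeds.shape, pooled_prompt_embeds.shape)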