mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use the FluxPipeline.encode_prompt() api rather than trying to run the two text encoders separately.
This commit is contained in:
parent
3599a4a3e4
commit
b227b9059d
@ -43,41 +43,14 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
|
model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
|
||||||
|
|
||||||
clip_embeddings = self._run_clip_text_encoder(context, model_path)
|
t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
|
||||||
t5_embeddings = self._run_t5_text_encoder(context, model_path)
|
|
||||||
latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
|
latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
|
||||||
image = self._run_vae_decoding(context, model_path, latents)
|
image = self._run_vae_decoding(context, model_path, latents)
|
||||||
image_dto = context.images.save(image=image)
|
image_dto = context.images.save(image=image)
|
||||||
return ImageOutput.build(image_dto)
|
return ImageOutput.build(image_dto)
|
||||||
|
|
||||||
def _run_clip_text_encoder(self, context: InvocationContext, flux_model_dir: Path) -> torch.Tensor:
|
def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Run the CLIP text encoder."""
|
# Determine the T5 max sequence lenght based on the model.
|
||||||
tokenizer_path = flux_model_dir / "tokenizer"
|
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
|
|
||||||
assert isinstance(tokenizer, CLIPTokenizer)
|
|
||||||
|
|
||||||
text_encoder_path = flux_model_dir / "text_encoder"
|
|
||||||
with context.models.load_local_model(
|
|
||||||
model_path=text_encoder_path, loader=self._load_flux_text_encoder
|
|
||||||
) as text_encoder:
|
|
||||||
assert isinstance(text_encoder, CLIPTextModel)
|
|
||||||
flux_pipeline_with_te = FluxPipeline(
|
|
||||||
scheduler=None,
|
|
||||||
vae=None,
|
|
||||||
text_encoder=text_encoder,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
text_encoder_2=None,
|
|
||||||
tokenizer_2=None,
|
|
||||||
transformer=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
return flux_pipeline_with_te._get_clip_prompt_embeds(
|
|
||||||
prompt=self.positive_prompt, device=TorchDevice.choose_torch_device()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_t5_text_encoder(self, context: InvocationContext, flux_model_dir: Path) -> torch.Tensor:
|
|
||||||
"""Run the T5 text encoder."""
|
|
||||||
|
|
||||||
if self.model == "flux-schnell":
|
if self.model == "flux-schnell":
|
||||||
max_seq_len = 256
|
max_seq_len = 256
|
||||||
# elif self.model == "flux-dev":
|
# elif self.model == "flux-dev":
|
||||||
@ -85,28 +58,51 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown model: {self.model}")
|
raise ValueError(f"Unknown model: {self.model}")
|
||||||
|
|
||||||
tokenizer_path = flux_model_dir / "tokenizer_2"
|
# Load the CLIP tokenizer.
|
||||||
tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_path, local_files_only=True)
|
clip_tokenizer_path = flux_model_dir / "tokenizer"
|
||||||
assert isinstance(tokenizer_2, T5TokenizerFast)
|
clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
|
||||||
|
assert isinstance(clip_tokenizer, CLIPTokenizer)
|
||||||
|
|
||||||
text_encoder_path = flux_model_dir / "text_encoder_2"
|
# Load the T5 tokenizer.
|
||||||
with context.models.load_local_model(
|
t5_tokenizer_path = flux_model_dir / "tokenizer_2"
|
||||||
model_path=text_encoder_path, loader=self._load_flux_text_encoder_2
|
t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
|
||||||
) as text_encoder_2:
|
assert isinstance(t5_tokenizer, T5TokenizerFast)
|
||||||
flux_pipeline_with_te2 = FluxPipeline(
|
|
||||||
|
clip_text_encoder_path = flux_model_dir / "text_encoder"
|
||||||
|
t5_text_encoder_path = flux_model_dir / "text_encoder_2"
|
||||||
|
with (
|
||||||
|
context.models.load_local_model(
|
||||||
|
model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
|
||||||
|
) as clip_text_encoder,
|
||||||
|
context.models.load_local_model(
|
||||||
|
model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
|
||||||
|
) as t5_text_encoder,
|
||||||
|
):
|
||||||
|
assert isinstance(clip_text_encoder, CLIPTextModel)
|
||||||
|
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||||
|
pipeline = FluxPipeline(
|
||||||
scheduler=None,
|
scheduler=None,
|
||||||
vae=None,
|
vae=None,
|
||||||
text_encoder=None,
|
text_encoder=clip_text_encoder,
|
||||||
tokenizer=None,
|
tokenizer=clip_tokenizer,
|
||||||
text_encoder_2=text_encoder_2,
|
text_encoder_2=t5_text_encoder,
|
||||||
tokenizer_2=tokenizer_2,
|
tokenizer_2=t5_tokenizer,
|
||||||
transformer=None,
|
transformer=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
return flux_pipeline_with_te2._get_t5_prompt_embeds(
|
# prompt_embeds: T5 embeddings
|
||||||
prompt=self.positive_prompt, max_sequence_length=max_seq_len, device=TorchDevice.choose_torch_device()
|
# pooled_prompt_embeds: CLIP embeddings
|
||||||
|
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
|
||||||
|
prompt=self.positive_prompt,
|
||||||
|
prompt_2=self.positive_prompt,
|
||||||
|
device=TorchDevice.choose_torch_device(),
|
||||||
|
max_sequence_length=max_seq_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert isinstance(prompt_embeds, torch.Tensor)
|
||||||
|
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
||||||
|
return prompt_embeds, pooled_prompt_embeds
|
||||||
|
|
||||||
def _run_diffusion(
|
def _run_diffusion(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
|
Loading…
Reference in New Issue
Block a user