diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index 0e7ebd6d69..a19dda30b8 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -40,7 +40,10 @@ class FluxTextEncoderInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> FluxConditioningOutput: - t5_embeddings, clip_embeddings = self._encode_prompt(context) + # Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally + # scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary). + t5_embeddings = self._t5_encode(context) + clip_embeddings = self._clip_encode(context) conditioning_data = ConditioningFieldData( conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)] ) @@ -48,12 +51,7 @@ class FluxTextEncoderInvocation(BaseInvocation): conditioning_name = context.conditioning.save(conditioning_data) return FluxConditioningOutput.build(conditioning_name) - def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]: - # Load CLIP. - clip_tokenizer_info = context.models.load(self.clip.tokenizer) - clip_text_encoder_info = context.models.load(self.clip.text_encoder) - - # Load T5. + def _t5_encode(self, context: InvocationContext) -> torch.Tensor: t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer) t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder) @@ -70,6 +68,15 @@ class FluxTextEncoderInvocation(BaseInvocation): prompt_embeds = t5_encoder(prompt) + assert isinstance(prompt_embeds, torch.Tensor) + return prompt_embeds + + def _clip_encode(self, context: InvocationContext) -> torch.Tensor: + clip_tokenizer_info = context.models.load(self.clip.tokenizer) + clip_text_encoder_info = context.models.load(self.clip.text_encoder) + + prompt = [self.prompt] + with ( clip_text_encoder_info as clip_text_encoder, clip_tokenizer_info as clip_tokenizer, @@ -81,6 +88,5 @@ class FluxTextEncoderInvocation(BaseInvocation): pooled_prompt_embeds = clip_encoder(prompt) - assert isinstance(prompt_embeds, torch.Tensor) assert isinstance(pooled_prompt_embeds, torch.Tensor) - return prompt_embeds, pooled_prompt_embeds + return pooled_prompt_embeds