Split T5 encoding and CLIP encoding into separate functions so that all model references are locally scoped and the two models don't have to be held in memory at the same time.

Ryan Dick 2024-08-28 14:31:08 +00:00
parent 29fe1533f2
commit c738fe051f


@@ -40,7 +40,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
-        t5_embeddings, clip_embeddings = self._encode_prompt(context)
+        # Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
+        # scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
+        t5_embeddings = self._t5_encode(context)
+        clip_embeddings = self._clip_encode(context)
 
         conditioning_data = ConditioningFieldData(
             conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
         )
@@ -48,12 +51,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
         conditioning_name = context.conditioning.save(conditioning_data)
         return FluxConditioningOutput.build(conditioning_name)
 
-    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
-        # Load CLIP.
-        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
-        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
-
-        # Load T5.
+    def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
         t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
         t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
@@ -70,6 +68,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
             prompt_embeds = t5_encoder(prompt)
 
+        assert isinstance(prompt_embeds, torch.Tensor)
+        return prompt_embeds
+
+    def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
+        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
+        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+
+        prompt = [self.prompt]
+
         with (
             clip_text_encoder_info as clip_text_encoder,
             clip_tokenizer_info as clip_tokenizer,
@@ -81,6 +88,5 @@ class FluxTextEncoderInvocation(BaseInvocation):
             pooled_prompt_embeds = clip_encoder(prompt)
 
-        assert isinstance(prompt_embeds, torch.Tensor)
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
-        return prompt_embeds, pooled_prompt_embeds
+        return pooled_prompt_embeds
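
The refactor above is an instance of a general Python memory-management pattern: keep each large object's only reference inside its own function frame, so the object becomes unreachable as soon as that function returns. The following is a minimal, self-contained sketch of the idea, not InvokeAI code; DummyEncoder and the helper functions are hypothetical stand-ins for the real model loaders.

import gc

import torch


class DummyEncoder:
    """Hypothetical stand-in for a large text encoder such as T5 or CLIP."""

    def __init__(self, name: str) -> None:
        self.name = name
        # Pretend these are multi-GB model weights.
        self.weights = torch.zeros(256, 256)

    def encode(self, prompt: str) -> torch.Tensor:
        # Return a fake embedding; a real encoder would tokenize and run the model.
        return torch.randn(1, 8)


def t5_encode(prompt: str) -> torch.Tensor:
    # The only reference to the "T5" model lives in this frame, so the model
    # becomes unreachable (and collectible) the moment this function returns.
    t5 = DummyEncoder("t5")
    return t5.encode(prompt)


def clip_encode(prompt: str) -> torch.Tensor:
    clip = DummyEncoder("clip")
    return clip.encode(prompt)


def encode_prompt(prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
    t5_embeds = t5_encode(prompt)
    # CPython's reference counting already freed the T5 stand-in when
    # t5_encode returned; gc.collect() additionally breaks any reference
    # cycles before the next model is created.
    gc.collect()
    clip_embeds = clip_encode(prompt)
    return t5_embeds, clip_embeds


if __name__ == "__main__":
    t5_embeds, clip_embeds = encode_prompt("a photo of an astronaut")
    print(t5_embeds.shape, clip_embeds.shape)

By contrast, in the removed _encode_prompt both the CLIP and T5 handles were live in the same frame for the whole method, so the T5 model could not be reclaimed before the CLIP model was loaded; splitting the two encodes into separate functions is what makes the early free possible.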