mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Split T5 encoding and CLIP encoding into separate functions to ensure that all model references are locally-scoped so that the two models don't have to be help in memory at the same time.
This commit is contained in:
parent
29fe1533f2
commit
c738fe051f
@ -40,7 +40,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
|
||||
t5_embeddings, clip_embeddings = self._encode_prompt(context)
|
||||
# Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
|
||||
# scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
|
||||
t5_embeddings = self._t5_encode(context)
|
||||
clip_embeddings = self._clip_encode(context)
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
||||
)
|
||||
@ -48,12 +51,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return FluxConditioningOutput.build(conditioning_name)
|
||||
|
||||
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Load CLIP.
|
||||
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
|
||||
# Load T5.
|
||||
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||
|
||||
@ -70,6 +68,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
prompt_embeds = t5_encoder(prompt)
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds
|
||||
|
||||
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
|
||||
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
|
||||
prompt = [self.prompt]
|
||||
|
||||
with (
|
||||
clip_text_encoder_info as clip_text_encoder,
|
||||
clip_tokenizer_info as clip_tokenizer,
|
||||
@ -81,6 +88,5 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
pooled_prompt_embeds = clip_encoder(prompt)
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
return pooled_prompt_embeds
|
||||
|
Loading…
Reference in New Issue
Block a user