mirror of https://github.com/invoke-ai/InvokeAI
Split T5 encoding and CLIP encoding into separate functions to ensure that all model references are locally scoped, so that the two models don't have to be held in memory at the same time.
This commit is contained in:
parent 29fe1533f2
commit c738fe051f
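The memory win comes from a general Python property: a function's locals become unreachable when the call returns, so moving each encoder's loading and use into its own method guarantees that no reference to the T5 model survives by the time the CLIP model is loaded. Below is a minimal, self-contained sketch of that pattern — illustrative only; _StubEncoder and the explicit gc.collect() are assumptions for the sketch, not InvokeAI code:

import gc

import torch


class _StubEncoder:
    """Hypothetical stand-in for a large text encoder (T5 or CLIP)."""

    def __init__(self, dim: int):
        self.weights = torch.zeros(dim, dim)  # pretend this is GBs of weights

    def encode(self, prompt: str) -> torch.Tensor:
        return self.weights.sum(dim=0)  # trivial stand-in "encoding"


def _t5_encode(prompt: str) -> torch.Tensor:
    # Every reference to the "T5" model lives only in this frame...
    t5 = _StubEncoder(dim=4096)
    return t5.encode(prompt)


def _clip_encode(prompt: str) -> torch.Tensor:
    clip = _StubEncoder(dim=768)
    return clip.encode(prompt)


def invoke(prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
    t5_embeddings = _t5_encode(prompt)
    # ...so once _t5_encode returns, the stub model is unreachable and can be
    # reclaimed before the "CLIP" model is constructed.
    gc.collect()
    clip_embeddings = _clip_encode(prompt)
    return t5_embeddings, clip_embeddings

With the old single _encode_prompt, both loaded-model handles stayed live in one frame for the whole call, so both models could end up resident simultaneously.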
@@ -40,7 +40,10 @@ class FluxTextEncoderInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
-        t5_embeddings, clip_embeddings = self._encode_prompt(context)
+        # Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
+        # scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
+        t5_embeddings = self._t5_encode(context)
+        clip_embeddings = self._clip_encode(context)
 
         conditioning_data = ConditioningFieldData(
             conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
         )
@@ -48,12 +51,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
         conditioning_name = context.conditioning.save(conditioning_data)
         return FluxConditioningOutput.build(conditioning_name)
 
-    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
-        # Load CLIP.
-        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
-        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
-
-        # Load T5.
+    def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
         t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
         t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
 
@@ -70,6 +68,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
 
             prompt_embeds = t5_encoder(prompt)
 
+        assert isinstance(prompt_embeds, torch.Tensor)
+        return prompt_embeds
+
+    def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
+        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
+        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+
+        prompt = [self.prompt]
+
         with (
             clip_text_encoder_info as clip_text_encoder,
             clip_tokenizer_info as clip_tokenizer,
@@ -81,6 +88,5 @@ class FluxTextEncoderInvocation(BaseInvocation):
 
             pooled_prompt_embeds = clip_encoder(prompt)
 
-        assert isinstance(prompt_embeds, torch.Tensor)
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
-        return prompt_embeds, pooled_prompt_embeds
+        return pooled_prompt_embeds
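The hunks above also show the other half of the scoping story: the loaded-model handles (clip_text_encoder_info, clip_tokenizer_info, etc.) are entered as context managers, so a model is pinned only while its with-block runs, and _clip_encode returns only the result tensor. A hedged sketch of that shape, assuming a loader that yields the model and releases it on exit — the names and release mechanics here are illustrative, not the actual InvokeAI model-manager API:

import gc
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def loaded_model(name: str) -> Iterator[dict]:
    """Hypothetical loader: the model is guaranteed resident only inside the with-block."""
    model = {"name": name, "weights": bytearray(1024)}  # stand-in for real weights
    try:
        yield model
    finally:
        # Dropping this reference on exit lets the memory be reclaimed before
        # the next model is loaded.
        del model
        gc.collect()


def clip_encode(prompt: str) -> bytes:
    with loaded_model("clip") as clip:
        pooled = bytes(clip["weights"][: len(prompt)])  # trivial stand-in "encoding"
    # Only the result escapes the block; no model reference leaks out of the function.
    return pooled

This mirrors why the new _clip_encode returns only pooled_prompt_embeds: the function boundary plus the with-block together ensure nothing but the embeddings outlives the encode step.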