Load and unload clip/t5 encoders and run inference separately in text encoding

This commit is contained in:
Brandon Rising 2024-08-23 13:28:05 -04:00
parent d09269ce75
commit 5f5f535dc0

View File

@ -59,23 +59,28 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
with (
clip_text_encoder_info as clip_text_encoder,
t5_text_encoder_info as t5_text_encoder,
clip_tokenizer_info as clip_tokenizer,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)
assert isinstance(t5_tokenizer, T5Tokenizer)
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
prompt = [self.prompt]
prompt_embeds = t5_encoder(prompt)
with (
clip_text_encoder_info as clip_text_encoder,
clip_tokenizer_info as clip_tokenizer,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
pooled_prompt_embeds = clip_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)