Load and unload clip/t5 encoders and run inference separately in text encoding

Brandon Rising 2024-08-23 13:28:05 -04:00 committed by Brandon
parent 012864ceb1
commit 6764dcfdaa


@@ -59,23 +59,28 @@ class FluxTextEncoderInvocation(BaseInvocation):
         t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
         t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
 
-        prompt = [self.prompt]
         with (
-            clip_text_encoder_info as clip_text_encoder,
             t5_text_encoder_info as t5_text_encoder,
-            clip_tokenizer_info as clip_tokenizer,
             t5_tokenizer_info as t5_tokenizer,
         ):
-            assert isinstance(clip_text_encoder, CLIPTextModel)
             assert isinstance(t5_text_encoder, T5EncoderModel)
-            assert isinstance(clip_tokenizer, CLIPTokenizer)
             assert isinstance(t5_tokenizer, T5Tokenizer)
 
-            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
             t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
 
+            prompt = [self.prompt]
             prompt_embeds = t5_encoder(prompt)
 
+        with (
+            clip_text_encoder_info as clip_text_encoder,
+            clip_tokenizer_info as clip_tokenizer,
+        ):
+            assert isinstance(clip_text_encoder, CLIPTextModel)
+            assert isinstance(clip_tokenizer, CLIPTokenizer)
+
+            clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
+
             pooled_prompt_embeds = clip_encoder(prompt)
 
         assert isinstance(prompt_embeds, torch.Tensor)
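
The point of the change: with a single `with` block, both text encoders must be resident at once, so peak memory is roughly their sum. Splitting the encoders into two sequential `with` blocks lets the model manager release T5 before CLIP is brought in, so the peak drops to roughly the larger of the two. Below is a minimal, self-contained sketch of that load/use/release pattern; `load_model` is a hypothetical stand-in for the handle returned by `context.models.load(...)`, which keeps a model resident only while its `with` block is active, and the string results are placeholders for the real encoder calls.

from contextlib import contextmanager

@contextmanager
def load_model(name: str):
    # Hypothetical stand-in for the model-manager handle: entering the
    # block makes the model resident; leaving it lets the cache release it.
    print(f"loaded {name}")
    try:
        yield name
    finally:
        print(f"released {name}")

prompt = ["a photo of a cat"]

# T5 is loaded, run, and released before CLIP is ever loaded...
with load_model("t5_text_encoder") as t5:
    prompt_embeds = f"{t5}({prompt[0]})"  # placeholder for t5_encoder(prompt)

# ...so peak memory is max(T5, CLIP) rather than T5 + CLIP.
with load_model("clip_text_encoder") as clip:
    pooled_prompt_embeds = f"{clip}({prompt[0]})"  # placeholder for clip_encoder(prompt)

Note that `prompt = [self.prompt]` moves inside the first block in the diff only because it is first needed there; the list stays in scope for the CLIP block afterwards.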