fix long prompt weighting bug in ckpt codepath (#2382)

This commit is contained in:
Lincoln Stein 2023-01-21 15:14:14 -05:00 committed by GitHub
commit f169bb0020
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -654,14 +654,15 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
                 per_token_weights += [weight] * len(this_fragment_token_ids)
             # leave room for bos/eos
-            if len(all_token_ids) > self.max_length - 2:
-                excess_token_count = len(all_token_ids) - self.max_length - 2
+            max_token_count_without_bos_eos_markers = self.max_length - 2
+            if len(all_token_ids) > max_token_count_without_bos_eos_markers:
+                excess_token_count = len(all_token_ids) - max_token_count_without_bos_eos_markers
                 # TODO build nice description string of how the truncation was applied
                 # this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
                 # self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
                 print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
-                all_token_ids = all_token_ids[0:self.max_length]
-                per_token_weights = per_token_weights[0:self.max_length]
+                all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers]
+                per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers]
             # pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
             # (77 = self.max_length)