Merge branch 'main' into fix/sd2-padding-token

This commit is contained in:
Kevin Turner 2023-01-21 13:11:02 -08:00 committed by GitHub
commit 87f3da92e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -652,14 +652,15 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
per_token_weights += [weight] * len(this_fragment_token_ids)
# leave room for bos/eos
if len(all_token_ids) > self.max_length - 2:
excess_token_count = len(all_token_ids) - self.max_length - 2
max_token_count_without_bos_eos_markers = self.max_length - 2
if len(all_token_ids) > max_token_count_without_bos_eos_markers:
excess_token_count = len(all_token_ids) - max_token_count_without_bos_eos_markers
# TODO build nice description string of how the truncation was applied
# this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
# self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
all_token_ids = all_token_ids[0:self.max_length]
per_token_weights = per_token_weights[0:self.max_length]
all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers]
per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers]
# pad out to a 77-entry array: [bos_token, <prompt tokens>, eos_token, pad_token…]
# (77 = self.max_length)