mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix long prompt weighting bug in ckpt codepath (#2382)
This commit is contained in:
commit
f169bb0020
@ -654,14 +654,15 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
|||||||
per_token_weights += [weight] * len(this_fragment_token_ids)
|
per_token_weights += [weight] * len(this_fragment_token_ids)
|
||||||
|
|
||||||
# leave room for bos/eos
|
# leave room for bos/eos
|
||||||
if len(all_token_ids) > self.max_length - 2:
|
max_token_count_without_bos_eos_markers = self.max_length - 2
|
||||||
excess_token_count = len(all_token_ids) - self.max_length - 2
|
if len(all_token_ids) > max_token_count_without_bos_eos_markers:
|
||||||
|
excess_token_count = len(all_token_ids) - max_token_count_without_bos_eos_markers
|
||||||
# TODO build nice description string of how the truncation was applied
|
# TODO build nice description string of how the truncation was applied
|
||||||
# this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
|
# this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
|
||||||
# self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
|
# self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
|
||||||
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
|
print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
|
||||||
all_token_ids = all_token_ids[0:self.max_length]
|
all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers]
|
||||||
per_token_weights = per_token_weights[0:self.max_length]
|
per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers]
|
||||||
|
|
||||||
# pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
|
# pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
|
||||||
# (77 = self.max_length)
|
# (77 = self.max_length)
|
||||||
|
Loading…
Reference in New Issue
Block a user