Ryan Dick 2024-07-24 15:01:49 -04:00
parent ca5a4ee59d
commit 2387a5d686
2 changed files with 92 additions and 1 deletion

@@ -0,0 +1,85 @@
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
    InputField,
)
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice

AUGMENT_PROMPT_INSTRUCTION = """Your task is to translate a short image caption and a style caption to a more detailed caption for the same image. The detailed caption should adhere to the following:
- be 1 sentence long
- use descriptive language that relates to the subject of interest
- it may add new details, but shouldn't change the subject of the original caption

Here are some examples:

Original caption: "A cat on a table"
Detailed caption: "A fluffy cat with a curious expression, sitting on a wooden table next to a vase of flowers."

Original caption: "medieval armor"
Detailed caption: "The gleaming suit of medieval armor stands proudly in the museum, its intricate engravings telling tales of long-forgotten battles and chivalry."

Original caption: "A panda bear as a mad scientist"
Detailed caption: "Clad in a tiny lab coat and goggles, the panda bear feverishly mixes colorful potions, embodying the eccentricity of a mad scientist in its whimsical laboratory."

Here is the prompt to translate:

Original caption: "{}"
Detailed caption:"""
@invocation("promp_augment", title="Prompt Augmentation", tags=["prompt"], category="conditioning", version="1.0.0")
class PrompAugmentationInvocation(BaseInvocation):
"""Use an LLM to augment a text prompt."""
prompt: str = InputField(description="The text prompt to augment.")
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> StringOutput:
# TODO(ryand): Address the following situations in the input prompt:
# - Prompt contains a TI embeddings.
# - Prompt contains .and() compel syntax. (Is ther any other compel syntax we need to handle?)
# - Prompt contains quotation marks that could cause confusion when embedded in an LLM instruct prompt.

        # Load the model and tokenizer.
        model_source = "microsoft/Phi-3-mini-4k-instruct"

        def model_loader(model_path: Path):
            return AutoModelForCausalLM.from_pretrained(
                model_path, torch_dtype=TorchDevice.choose_torch_dtype(), local_files_only=True
            )

        def tokenizer_loader(model_path: Path):
            return AutoTokenizer.from_pretrained(model_path, local_files_only=True)

        with (
            context.models.load_remote_model(source=model_source, loader=model_loader) as model,
            context.models.load_remote_model(source=model_source, loader=tokenizer_loader) as tokenizer,
        ):
            # Run the instruct model to produce the augmented prompt.
            augmented_prompt = self._run_instruct_model(model, tokenizer, self.prompt)

        return StringOutput(value=augmented_prompt)

    def _run_instruct_model(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, prompt: str) -> str:
        # Embed the user's prompt in the instruct template and tokenize it with the model's chat template.
        messages = [
            {
                "role": "user",
                "content": AUGMENT_PROMPT_INSTRUCTION.format(prompt),
            }
        ]
        inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
        inputs = inputs.to(model.device)

        outputs = model.generate(
            inputs,
            max_new_tokens=200,
            temperature=0.9,
            do_sample=True,
        )
        text = tokenizer.batch_decode(outputs)[0]
        assert isinstance(text, str)

        # The decoded text contains the full chat transcript. Keep only the assistant's reply.
        output = text.split("<|assistant|>")[-1].strip()
        output = output.split("<|end|>")[0].strip()
        return output
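
A quick sanity check of the transcript parsing above, run against a hand-written stand-in for the decoded model output. The sample transcript is hypothetical, but <|user|>, <|assistant|>, and <|end|> are the special tokens that Phi-3's chat template emits:

sample = (
    "<|user|> Here is the prompt to translate...<|end|>"
    "<|assistant|> A fluffy cat with a curious expression, sitting on a wooden table.<|end|>"
)
parsed = sample.split("<|assistant|>")[-1].strip()
parsed = parsed.split("<|end|>")[0].strip()
assert parsed == "A fluffy cat with a curious expression, sitting on a wooden table."
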

@@ -62,7 +62,13 @@ def filter_files(
        # downloading random checkpoints that might also be in the repo. However there is no guarantee
        # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
        # will adhere to this naming convention, so this is an area to be careful of.
        #
        # On July 24, 2024, this regex filter was modified to support downloading the `microsoft/Phi-3-mini-4k-instruct`
        # model. I am making this note in case it is relevant as we continue to improve this logic and make it less
        # brittle.
        # - Before: r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
        # - After: r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
        elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
            paths.append(file)

    # limit search to subfolder if requested
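
The regex change can be verified directly: microsoft/Phi-3-mini-4k-instruct ships its weights as sharded files named like model-00001-of-00002.safetensors, which the old pattern rejects because it only permits "model" or "model.<suffix>" as the stem. A minimal check (the first filename mirrors the Hugging Face repo layout; the last is a hypothetical name illustrating the brittleness the comment warns about):

import re

OLD_PATTERN = r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
NEW_PATTERN = r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"

# Sharded checkpoint name, as used by microsoft/Phi-3-mini-4k-instruct.
assert re.search(OLD_PATTERN, "model-00001-of-00002.safetensors") is None
assert re.search(NEW_PATTERN, "model-00001-of-00002.safetensors") is not None

# Names that the old pattern accepted still match.
assert re.search(NEW_PATTERN, "model.fp16.safetensors") is not None

# But the new pattern is more permissive: any stem containing "model" now matches.
assert re.search(NEW_PATTERN, "text_model_checkpoint.ckpt") is not None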