This commit is contained in:
Ryan Dick 2024-07-23 16:52:35 -04:00
parent b7a1086325
commit 8c09b345ec

View File

@ -41,7 +41,7 @@ Here are the captions describing the image:
Simple caption: "{initial_prompt}" Simple caption: "{initial_prompt}"
Detailed caption: "{detailed_prompt}" Detailed caption: "{detailed_prompt}"
Now imagine a possible style for the image and generate a list of words that describe the style. Now imagine a possible style for the image and generate a list of five words separated by commas that describe the style.
Style words:""" Style words:"""
output = run_model(model, tokenizer, instruction) output = run_model(model, tokenizer, instruction)
@ -49,7 +49,7 @@ Style words:"""
def generate_augmented_prompt(model, tokenizer, initial_prompt): def generate_augmented_prompt(model, tokenizer, initial_prompt):
instruction = f"""Your task is to translate a short image caption to a more detailed caption for the same image. The detailed caption should adhere to the following: instruction = f"""Your task is to translate a short image caption and a style caption to a more detailed caption for the same image. The detailed caption should adhere to the following:
- be 1 sentence long - be 1 sentence long
- use descriptive language that relates to the subject of interest - use descriptive language that relates to the subject of interest
- it may add new details, but shouldn't change the subject of the original caption - it may add new details, but shouldn't change the subject of the original caption
@ -71,6 +71,41 @@ Detailed caption:"""
return output_prompt return output_prompt
# Style descriptors for prompt augmentation. Each active entry is a single
# string of comma-separated style keywords.
STYLES = [
"photography, RAW, high resolution, masterpiece, film grain, 8k, dynamic lighting",
"2D animation, cartoon, digital art, vibrant colors",
"3D animation, animated, cartoon, render, soft shadows, subsurface scattering",
"painting, brush strokes, impressionism, oil painting",
"sci-fi, futuristic, neon, cyberpunk, dystopian, gritty",
# Disabled candidate styles. NOTE(review): these are written as
# `"Name": "keywords"` pairs, which is dict syntax; they cannot simply be
# uncommented inside this list. Either drop the `"Name":` prefix or turn
# STYLES into a dict before enabling any of them.
# "Pop Art": "vibrant, bold, commercial, graphic, iconic",
# "Impressionism": "soft, brushstroke, atmospheric, pastel, fleeting",
# "Minimalism": "clean, simple, monochrome, sparse, understated",
# "Surrealism": "dreamlike, bizarre, illogical, subconscious, imaginative",
# "Cyberpunk": "neon, dystopian, futuristic, urban, gritty",
# "Gothic Art": "dark, medieval, ornate, macabre, dramatic",
# "Art Nouveau": "organic, flowing, decorative, floral, curvilinear",
# "Baroque": "ornate, dramatic, dynamic, extravagant, rich",
# "Renaissance Art": "classical, realistic, balanced, harmonious, humanistic",
# "Street Art": "graffiti, urban, bold, rebellious, contemporary",
# "Impressionist Photography": "blurry, atmospheric, evocative, impressionistic, dreamy",
# "Retro Futurism": "nostalgic, sci-fi, retro, utopian, colorful",
# "Anime Style": "exaggerated, cel-shaded, expressive, stylized, colorful",
# "Manga Style": "graphic, dynamic, expressive, inked, stylized",
# "Watercolor Painting": "translucent, flowing, wash, gradient, ethereal",
# "Pixel Art": "retro, low-res, grid-based, nostalgic, 8-bit",
# "Fantasy Art": "mythical, epic, magical, imaginative, detailed",
# "Comic Book Style": "bold, inked, dynamic, panel, narrative",
# "Photorealism": "meticulous, detailed, lifelike, precise, high-definition",
# "Vintage Poster": "retro, nostalgic, graphic, bold, stylized",
# "Fantasy Illustration": "enchanting, whimsical, colorful, imaginative, ethereal",
# "Anime Realism": "hybrid, stylized, polished, detailed, emotive",
# "Hyperrealism": "ultra-detailed, uncanny, lifelike, precise, meticulous",
# "Sci-Fi Illustration": "futuristic, imaginative, detailed, cosmic, speculative",
# "Anime Chibi": "cute, exaggerated, simplified, colorful, whimsical",
# "Vintage Photography": "sepia, nostalgic, timeless, film, classic",
]
def main(): def main():
device = torch.device("cuda") device = torch.device("cuda")
model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.float16) model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.float16)
@ -119,12 +154,14 @@ def main():
# torch.random.manual_seed(1234) # torch.random.manual_seed(1234)
for initial_prompt in test_prompts: for initial_prompt in test_prompts:
# Randomly select a style.
style = STYLES[0]
print("----------------------") print("----------------------")
detailed_prompt = generate_augmented_prompt(model, tokenizer, initial_prompt) detailed_prompt = generate_augmented_prompt(model, tokenizer, initial_prompt)
style_prompt = generate_style_prompt(model, tokenizer, initial_prompt, detailed_prompt) # style_prompt = generate_style_prompt(model, tokenizer, initial_prompt, detailed_prompt)
print(f"Original Prompt: '{initial_prompt}'\n\n") print(f"Original Prompt: '{initial_prompt}'\n\n")
print(f"Detailed Prompt: '{detailed_prompt}'\n\n") print(f"Detailed Prompt: '{detailed_prompt}'\n\n")
print(f"Style Prompt: '{style_prompt}'\n\n") # print(f"Style Prompt: '{style_prompt}'\n\n")
if __name__ == "__main__": if __name__ == "__main__":