Create new data structures for captioned images, and a list of captioned images. Create auto_caption_image node which can take a single image or list of images to caption

This commit is contained in:
brandonrising 2024-05-17 14:31:33 -04:00
parent a18d7adad4
commit 59327e827b
2 changed files with 66 additions and 4 deletions

View File

@ -1,10 +1,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional from typing import Literal, Optional, List, Union
import cv2 import cv2
import numpy import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps from PIL import Image, ImageChops, ImageFilter, ImageOps
from transformers import AutoModelForCausalLM, AutoTokenizer
from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
@ -15,7 +16,7 @@ from invokeai.app.invocations.fields import (
WithBoard, WithBoard,
WithMetadata, WithMetadata,
) )
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput, CaptionImageOutputs, CaptionImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
@ -66,6 +67,56 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)
@invocation(
"auto_caption_image",
title="Automatically Caption Image",
tags=["image", "caption"],
category="image",
version="1.2.2",
)
class CaptionImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Adds a caption to an image"""
images: Union[ImageField,List[ImageField]] = InputField(description="The image to caption")
prompt: str = InputField(default="Describe this list of images in 20 words or less", description="Describe how you would like the image to be captioned.")
def invoke(self, context: InvocationContext) -> CaptionImageOutputs:
model_id = "vikhyatk/moondream2"
model_revision = "2024-04-02"
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision)
moondream_model = AutoModelForCausalLM.from_pretrained(
model_id, trust_remote_code=True, revision=model_revision
)
output: CaptionImageOutputs = CaptionImageOutputs()
try:
from PIL.Image import Image
images: List[Image] = []
image_fields = self.images if isinstance(self.images, list) else [self.images]
for image in image_fields:
images.append(context.images.get_pil(image.image_name))
answers: List[str] = moondream_model.batch_answer(
images=images,
prompts=[self.prompt] * len(images),
tokenizer=tokenizer,
)
assert isinstance(answers, list)
for i, answer in enumerate(answers):
output.images.append(CaptionImageOutput(
image=image_fields[i],
width=images[i].width,
height=images[i].height,
caption=answer
))
except:
raise
finally:
del moondream_model
del tokenizer
return output
@invocation( @invocation(
"img_crop", "img_crop",
title="Crop Image", title="Crop Image",
@ -194,7 +245,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard): class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Extracts the alpha channel of an image as a mask.""" """Extracts the alpha channel of an image as a mask."""
image: ImageField = InputField(description="The image to create the mask from") image: List[ImageField] = InputField(description="The image to create the mask from")
invert: bool = InputField(default=False, description="Whether or not to invert the mask") invert: bool = InputField(default=False, description="Whether or not to invert the mask")
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:

View File

@ -1,6 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Optional from typing import Optional, List
import torch import torch
@ -247,6 +247,17 @@ class ImageOutput(BaseInvocationOutput):
) )
@invocation_output("captioned_image_output")
class CaptionImageOutput(ImageOutput):
caption: str = OutputField(description="Caption for given image")
@invocation_output("captioned_image_outputs")
class CaptionImageOutputs(BaseInvocationOutput):
images: List[CaptionImageOutput] = OutputField(description="List of captioned images", default=[])
@invocation_output("image_collection_output") @invocation_output("image_collection_output")
class ImageCollectionOutput(BaseInvocationOutput): class ImageCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of images""" """Base class for nodes that output a collection of images"""