mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Create new data structures for a captioned image and for a list of captioned images. Create the auto_caption_image node, which can take a single image or a list of images to caption.
This commit is contained in:
parent
a18d7adad4
commit
59327e827b
@ -1,10 +1,11 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, List, Union
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
from invokeai.app.invocations.fields import (
|
||||
@ -15,7 +16,7 @@ from invokeai.app.invocations.fields import (
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.primitives import ImageOutput, CaptionImageOutputs, CaptionImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
@ -66,6 +67,56 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
    "auto_caption_image",
    title="Automatically Caption Image",
    tags=["image", "caption"],
    category="image",
    version="1.2.2",
)
class CaptionImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates a text caption for an image, or for each image in a list of images"""

    # Accepts either a single image or a list; `invoke` normalizes both to a list.
    images: Union[ImageField, List[ImageField]] = InputField(description="The image to caption")
    # The same prompt is sent to the model once per image.
    prompt: str = InputField(
        default="Describe this list of images in 20 words or less",
        description="Describe how you would like the image to be captioned.",
    )

    def invoke(self, context: InvocationContext) -> CaptionImageOutputs:
        """Caption every input image with the moondream2 model.

        Args:
            context: Invocation context; used to load each input image by name.

        Returns:
            A CaptionImageOutputs holding one CaptionImageOutput per input image,
            in input order, each carrying the image field, its pixel dimensions
            and the generated caption.

        Raises:
            TypeError: If the model does not return a list of answers.
            ValueError: If the number of answers differs from the number of images.
        """
        image_fields = self.images if isinstance(self.images, list) else [self.images]

        # Nothing to caption: skip loading the model entirely.
        if not image_fields:
            return CaptionImageOutputs(images=[])

        # Load the images first so a bad image name fails before the (expensive)
        # model load is attempted.
        pil_images: List[Image.Image] = [context.images.get_pil(field.image_name) for field in image_fields]

        model_id = "vikhyatk/moondream2"
        model_revision = "2024-04-02"

        tokenizer = None
        moondream_model = None
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision)
            # SECURITY NOTE: trust_remote_code=True executes Python code shipped in the
            # model repository. The revision is pinned above; keep it pinned and review
            # the remote code before bumping it.
            moondream_model = AutoModelForCausalLM.from_pretrained(
                model_id, trust_remote_code=True, revision=model_revision
            )
            # `batch_answer` is provided by the model's remote code, not by transformers
            # itself -- presumably it returns one answer string per image; verified below.
            answers = moondream_model.batch_answer(
                images=pil_images,
                prompts=[self.prompt] * len(pil_images),
                tokenizer=tokenizer,
            )
        finally:
            # Drop the references even when an exception propagates, so the traceback's
            # frame does not keep the model and tokenizer alive.
            del moondream_model
            del tokenizer

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if not isinstance(answers, list):
            raise TypeError(f"Expected a list of captions from the model, got {type(answers).__name__}")
        if len(answers) != len(pil_images):
            raise ValueError(f"Model returned {len(answers)} captions for {len(pil_images)} images")

        return CaptionImageOutputs(
            images=[
                CaptionImageOutput(
                    image=field,
                    width=pil_image.width,
                    height=pil_image.height,
                    caption=answer,
                )
                for field, pil_image, answer in zip(image_fields, pil_images, answers)
            ]
        )
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_crop",
|
||||
title="Crop Image",
|
||||
@ -194,7 +245,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
|
||||
image: ImageField = InputField(description="The image to create the mask from")
|
||||
image: List[ImageField] = InputField(description="The image to create the mask from")
|
||||
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
|
||||
@ -247,6 +247,17 @@ class ImageOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("captioned_image_output")
class CaptionImageOutput(ImageOutput):
    """Output for a single image together with its caption"""

    # Text caption associated with the image fields inherited from ImageOutput.
    caption: str = OutputField(description="Caption for given image")
|
||||
|
||||
|
||||
|
||||
@invocation_output("captioned_image_outputs")
class CaptionImageOutputs(BaseInvocationOutput):
    """Output for a list of captioned images"""

    # NOTE(review): the default is a literal list. This is only safe if the field
    # machinery copies defaults per instance (pydantic does for model fields) --
    # confirm before relying on in-place `.append` on a default-constructed output.
    images: List[CaptionImageOutput] = OutputField(description="List of captioned images", default=[])
|
||||
|
||||
|
||||
@invocation_output("image_collection_output")
|
||||
class ImageCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of images"""
|
||||
|
Loading…
Reference in New Issue
Block a user