mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into refactor/rename-get-logger
This commit is contained in:
@ -1,19 +1,19 @@
|
||||
import typing
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import Body
|
||||
from fastapi.routing import APIRouter
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
|
||||
from invokeai.backend.util.logging import logging
|
||||
from invokeai.version import __version__
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend.util.logging import logging
|
||||
|
||||
|
||||
class LogLevel(int, Enum):
|
||||
@ -55,7 +55,7 @@ async def get_version() -> AppVersion:
|
||||
|
||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||
async def get_config() -> AppConfig:
|
||||
infill_methods = ["tile", "lama"]
|
||||
infill_methods = ["tile", "lama", "cv2"]
|
||||
if PatchMatch.patchmatch_available():
|
||||
infill_methods.append("patchmatch")
|
||||
|
||||
|
@ -563,7 +563,7 @@ class MaskEdgeInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask = context.services.images.get_pil_image(self.image.image_name)
|
||||
mask = context.services.images.get_pil_image(self.image.image_name).convert("L")
|
||||
|
||||
npimg = numpy.asarray(mask, dtype=numpy.uint8)
|
||||
npgradient = numpy.uint8(255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0)))
|
||||
@ -700,8 +700,13 @@ class ColorCorrectInvocation(BaseInvocation):
|
||||
# Blur the mask out (into init image) by specified amount
|
||||
if self.mask_blur_radius > 0:
|
||||
nm = numpy.asarray(pil_init_mask, dtype=numpy.uint8)
|
||||
inverted_nm = 255 - nm
|
||||
dilation_size = int(round(self.mask_blur_radius) + 20)
|
||||
dilating_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_size, dilation_size))
|
||||
inverted_dilated_nm = cv2.dilate(inverted_nm, dilating_kernel)
|
||||
dilated_nm = 255 - inverted_dilated_nm
|
||||
nmd = cv2.erode(
|
||||
nm,
|
||||
dilated_nm,
|
||||
kernel=numpy.ones((3, 3), dtype=numpy.uint8),
|
||||
iterations=int(self.mask_blur_radius / 2),
|
||||
)
|
||||
@ -773,39 +778,95 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
COLOR_CHANNELS = Literal[
|
||||
"Red (RGBA)",
|
||||
"Green (RGBA)",
|
||||
"Blue (RGBA)",
|
||||
"Alpha (RGBA)",
|
||||
"Cyan (CMYK)",
|
||||
"Magenta (CMYK)",
|
||||
"Yellow (CMYK)",
|
||||
"Black (CMYK)",
|
||||
"Hue (HSV)",
|
||||
"Saturation (HSV)",
|
||||
"Value (HSV)",
|
||||
"Luminosity (LAB)",
|
||||
"A (LAB)",
|
||||
"B (LAB)",
|
||||
"Y (YCbCr)",
|
||||
"Cb (YCbCr)",
|
||||
"Cr (YCbCr)",
|
||||
]
|
||||
|
||||
CHANNEL_FORMATS = {
|
||||
"Red (RGBA)": ("RGBA", 0),
|
||||
"Green (RGBA)": ("RGBA", 1),
|
||||
"Blue (RGBA)": ("RGBA", 2),
|
||||
"Alpha (RGBA)": ("RGBA", 3),
|
||||
"Cyan (CMYK)": ("CMYK", 0),
|
||||
"Magenta (CMYK)": ("CMYK", 1),
|
||||
"Yellow (CMYK)": ("CMYK", 2),
|
||||
"Black (CMYK)": ("CMYK", 3),
|
||||
"Hue (HSV)": ("HSV", 0),
|
||||
"Saturation (HSV)": ("HSV", 1),
|
||||
"Value (HSV)": ("HSV", 2),
|
||||
"Luminosity (LAB)": ("LAB", 0),
|
||||
"A (LAB)": ("LAB", 1),
|
||||
"B (LAB)": ("LAB", 2),
|
||||
"Y (YCbCr)": ("YCbCr", 0),
|
||||
"Cb (YCbCr)": ("YCbCr", 1),
|
||||
"Cr (YCbCr)": ("YCbCr", 2),
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_luminosity_adjust",
|
||||
title="Adjust Image Luminosity",
|
||||
tags=["image", "luminosity", "hsl"],
|
||||
"img_channel_offset",
|
||||
title="Offset Image Channel",
|
||||
tags=[
|
||||
"image",
|
||||
"offset",
|
||||
"red",
|
||||
"green",
|
||||
"blue",
|
||||
"alpha",
|
||||
"cyan",
|
||||
"magenta",
|
||||
"yellow",
|
||||
"black",
|
||||
"hue",
|
||||
"saturation",
|
||||
"luminosity",
|
||||
"value",
|
||||
],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Luminosity (Value) of an image."""
|
||||
class ImageChannelOffsetInvocation(BaseInvocation):
|
||||
"""Add or subtract a value from a specific color channel of an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
luminosity: float = InputField(
|
||||
default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
|
||||
)
|
||||
channel: COLOR_CHANNELS = InputField(description="Which channel to adjust")
|
||||
offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# Convert PIL image to OpenCV format (numpy array), note color channel
|
||||
# ordering is changed from RGB to BGR
|
||||
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
|
||||
# extract the channel and mode from the input and reference tuple
|
||||
mode = CHANNEL_FORMATS[self.channel][0]
|
||||
channel_number = CHANNEL_FORMATS[self.channel][1]
|
||||
|
||||
# Convert image to HSV color space
|
||||
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||
# Convert PIL image to new format
|
||||
converted_image = numpy.array(pil_image.convert(mode)).astype(int)
|
||||
image_channel = converted_image[:, :, channel_number]
|
||||
|
||||
# Adjust the luminosity (value)
|
||||
hsv_image[:, :, 2] = numpy.clip(hsv_image[:, :, 2] * self.luminosity, 0, 255)
|
||||
# Adjust the value, clipping to 0..255
|
||||
image_channel = numpy.clip(image_channel + self.offset, 0, 255)
|
||||
|
||||
# Convert image back to BGR color space
|
||||
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
|
||||
# Put the channel back into the image
|
||||
converted_image[:, :, channel_number] = image_channel
|
||||
|
||||
# Convert back to PIL format and to original color mode
|
||||
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
|
||||
# Convert back to RGBA format and output
|
||||
pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA")
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
@ -827,36 +888,60 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_saturation_adjust",
|
||||
title="Adjust Image Saturation",
|
||||
tags=["image", "saturation", "hsl"],
|
||||
"img_channel_multiply",
|
||||
title="Multiply Image Channel",
|
||||
tags=[
|
||||
"image",
|
||||
"invert",
|
||||
"scale",
|
||||
"multiply",
|
||||
"red",
|
||||
"green",
|
||||
"blue",
|
||||
"alpha",
|
||||
"cyan",
|
||||
"magenta",
|
||||
"yellow",
|
||||
"black",
|
||||
"hue",
|
||||
"saturation",
|
||||
"luminosity",
|
||||
"value",
|
||||
],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Saturation of an image."""
|
||||
class ImageChannelMultiplyInvocation(BaseInvocation):
|
||||
"""Scale a specific color channel of an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
|
||||
channel: COLOR_CHANNELS = InputField(description="Which channel to adjust")
|
||||
scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.")
|
||||
invert_channel: bool = InputField(default=False, description="Invert the channel after scaling")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# Convert PIL image to OpenCV format (numpy array), note color channel
|
||||
# ordering is changed from RGB to BGR
|
||||
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
|
||||
# extract the channel and mode from the input and reference tuple
|
||||
mode = CHANNEL_FORMATS[self.channel][0]
|
||||
channel_number = CHANNEL_FORMATS[self.channel][1]
|
||||
|
||||
# Convert image to HSV color space
|
||||
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||
# Convert PIL image to new format
|
||||
converted_image = numpy.array(pil_image.convert(mode)).astype(float)
|
||||
image_channel = converted_image[:, :, channel_number]
|
||||
|
||||
# Adjust the saturation
|
||||
hsv_image[:, :, 1] = numpy.clip(hsv_image[:, :, 1] * self.saturation, 0, 255)
|
||||
# Adjust the value, clipping to 0..255
|
||||
image_channel = numpy.clip(image_channel * self.scale, 0, 255)
|
||||
|
||||
# Convert image back to BGR color space
|
||||
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
|
||||
# Invert the channel if requested
|
||||
if self.invert_channel:
|
||||
image_channel = 255 - image_channel
|
||||
|
||||
# Convert back to PIL format and to original color mode
|
||||
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
|
||||
# Put the channel back into the image
|
||||
converted_image[:, :, channel_number] = image_channel
|
||||
|
||||
# Convert back to RGBA format and output
|
||||
pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA")
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
|
@ -8,19 +8,17 @@ from PIL import Image, ImageOps
|
||||
|
||||
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||
|
||||
|
||||
def infill_methods() -> list[str]:
|
||||
methods = [
|
||||
"tile",
|
||||
"solid",
|
||||
"lama",
|
||||
]
|
||||
methods = ["tile", "solid", "lama", "cv2"]
|
||||
if PatchMatch.patchmatch_available():
|
||||
methods.insert(0, "patchmatch")
|
||||
return methods
|
||||
@ -49,6 +47,10 @@ def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
return im_patched
|
||||
|
||||
|
||||
def infill_cv2(im: Image.Image) -> Image.Image:
|
||||
return cv2_inpaint(im)
|
||||
|
||||
|
||||
def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
@ -194,15 +196,35 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
infill_image = image.copy()
|
||||
width = int(image.width / self.downscale)
|
||||
height = int(image.height / self.downscale)
|
||||
infill_image = infill_image.resize(
|
||||
(width, height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
if PatchMatch.patchmatch_available():
|
||||
infilled = infill_patchmatch(image.copy())
|
||||
infilled = infill_patchmatch(infill_image)
|
||||
else:
|
||||
raise ValueError("PatchMatch is not available on this system")
|
||||
|
||||
infilled = infilled.resize(
|
||||
(image.width, image.height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
infilled.paste(image, (0, 0), mask=image.split()[-1])
|
||||
# image.paste(infilled, (0, 0), mask=image.split()[-1])
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
@ -245,3 +267,30 @@ class LaMaInfillInvocation(BaseInvocation):
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint")
|
||||
class CV2InfillInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
infilled = infill_cv2(image.copy())
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
@ -182,7 +182,7 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
@invocation("iterate")
|
||||
@invocation("iterate", version="1.0.0")
|
||||
class IterateInvocation(BaseInvocation):
|
||||
"""Iterates over a list of items"""
|
||||
|
||||
@ -203,7 +203,7 @@ class CollectInvocationOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@invocation("collect")
|
||||
@invocation("collect", version="1.0.0")
|
||||
class CollectInvocation(BaseInvocation):
|
||||
"""Collects values into a collection"""
|
||||
|
||||
|
Reference in New Issue
Block a user