port dw_openpose, depth_anything, and lama processors to new model download scheme

Lincoln Stein 2024-04-12 21:05:23 -04:00
parent 3a26c7bb9e
commit 41b909cbe3
7 changed files with 72 additions and 105 deletions
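
In outline, each ported processor stops resolving a hard-coded path under models_path and calling download_with_progress_bar() itself; instead, the InvocationContext is threaded through run_processor()/infill(), and the context's model manager downloads and caches the checkpoint on demand. A minimal sketch of the new pattern, using names that appear in the diffs below (the detector class here is a hypothetical stand-in, not an exact excerpt):

from invokeai.app.services.shared.invocation_context import InvocationContext

DEPTH_ANYTHING_MODELS = {
    "small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
}

class ExampleDetector:  # hypothetical stand-in for DepthAnythingDetector, Wholebody, LaMA, ...
    def __init__(self, context: InvocationContext) -> None:
        # the invocation that owns the context injects it here
        self.context = context

    def load_model(self, model_size: str = "small"):
        # download_and_cache_ckpt() fetches the file on first use and returns its cached local path
        return self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size])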

View File

@@ -137,7 +137,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
     image: ImageField = InputField(description="The image to process")
-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         # superclass just passes through image without processing
         return image

@@ -148,7 +148,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         raw_image = self.load_image(context)
         # image type should be PIL.PngImagePlugin.PngImageFile ?
-        processed_image = self.run_processor(raw_image)
+        processed_image = self.run_processor(raw_image, context)
         # currently can't see processed image in node UI without a showImage node,
         # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery

@@ -189,7 +189,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
         # Keep alpha channel for Canny processing to detect edges of transparent areas
         return context.images.get_pil(self.image.image_name, "RGBA")
-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         processed_image = get_canny_edges(
             image,
             self.low_threshold,

@@ -216,7 +216,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
     # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         hed_processor = HEDProcessor()
         processed_image = hed_processor.run(
             image,

@@ -243,7 +243,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
     coarse: bool = InputField(default=False, description="Whether to use coarse mode")
-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         lineart_processor = LineartProcessor()
         processed_image = lineart_processor.run(
             image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse

@@ -264,7 +264,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         processor = LineartAnimeProcessor()
         processed_image = processor.run(
             image,

@@ -291,7 +291,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     # depth_and_normal not supported in controlnet_aux v0.0.3
     # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+        # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
         midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = midas_processor(
             image,

@@ -318,9 +319,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = normalbae_processor(
+        processed_image: Image.Image = normalbae_processor(
             image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
         )
         return processed_image

@@ -337,7 +338,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
     thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
     thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
         processed_image = mlsd_processor(
             image,

@@ -360,7 +361,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
     safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = pidi_processor(
             image,

@@ -388,7 +389,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
     w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
     f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         content_shuffle_processor = ContentShuffleDetector()
         processed_image = content_shuffle_processor(
             image,

@@ -412,7 +413,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
 class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Zoe depth processing to image"""
-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = zoe_depth_processor(image)
         return processed_image

@@ -433,7 +434,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         mediapipe_face_processor = MediapipeFaceDetector()
         processed_image = mediapipe_face_processor(
             image,

@@ -461,7 +462,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = leres_processor(
             image,

@@ -503,7 +504,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
         np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
         return np_img
-    def run_processor(self, img):
+    def run_processor(self, img: Image.Image, context: InvocationContext) -> Image.Image:
         np_img = np.array(img, dtype=np.uint8)
         processed_np_image = self.tile_resample(
             np_img,

@@ -527,7 +528,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
         segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
             "ybelkada/segment-anything", subfolder="checkpoints"

@@ -573,7 +574,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
     color_map_tile_size: int = InputField(default=64, ge=0, description=FieldDescriptions.tile_size)
-    def run_processor(self, image: Image.Image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         np_image = np.array(image, dtype=np.uint8)
         height, width = np_image.shape[:2]

@@ -608,8 +609,8 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
     )
     resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
-    def run_processor(self, image: Image.Image):
-        depth_anything_detector = DepthAnythingDetector()
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+        depth_anything_detector = DepthAnythingDetector(context)
         depth_anything_detector.load_model(model_size=self.model_size)
         processed_image = depth_anything_detector(image=image, resolution=self.resolution)

@@ -631,8 +632,8 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
     draw_hands: bool = InputField(default=False)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
-    def run_processor(self, image: Image.Image):
-        dw_openpose = DWOpenposeDetector()
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+        dw_openpose = DWOpenposeDetector(context)
         processed_image = dw_openpose(
             image,
             draw_face=self.draw_face,

View File

@@ -38,7 +38,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
     image: ImageField = InputField(description="The image to process")
     @abstractmethod
-    def infill(self, image: Image.Image) -> Image.Image:
+    def infill(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         """Infill the image with the specified method"""
         pass

@@ -57,7 +57,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
             return ImageOutput.build(context.images.get_dto(self.image.image_name))
         # Perform Infill action
-        infilled_image = self.infill(input_image)
+        infilled_image = self.infill(input_image, context)
         # Create ImageDTO for Infilled Image
         infilled_image_dto = context.images.save(image=infilled_image)

@@ -75,7 +75,7 @@ class InfillColorInvocation(InfillImageProcessorInvocation):
         description="The color to use to infill",
     )
-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, context: InvocationContext):
         solid_bg = Image.new("RGBA", image.size, self.color.tuple())
         infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
         infilled.paste(image, (0, 0), image.split()[-1])

@@ -94,7 +94,7 @@ class InfillTileInvocation(InfillImageProcessorInvocation):
         description="The seed to use for tile generation (omit for random)",
     )
-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, context: InvocationContext):
         output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
         return output.infilled

@@ -108,7 +108,7 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
     downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
     resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, context: InvocationContext):
         resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
         width = int(image.width / self.downscale)

@@ -132,8 +132,8 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
 class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""
-    def infill(self, image: Image.Image):
-        lama = LaMA()
+    def infill(self, image: Image.Image, context: InvocationContext):
+        lama = LaMA(context)
         return lama(image)

@@ -141,7 +141,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
 class CV2InfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using OpenCV Inpainting"""
-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, context: InvocationContext):
         return cv2_inpaint(image)

@@ -163,5 +163,5 @@ class MosaicInfillInvocation(InfillImageProcessorInvocation):
         description="The max threshold for color",
     )
-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, context: InvocationContext):
         return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())

View File

@@ -534,10 +534,10 @@ class ModelsInterface(InvocationContextInterface):
         loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None,
     ) -> LoadedModel:
         """
-        Load and cache the model file located at the indicated URL.
+        Download, cache, and load the model file located at the indicated URL.
         This will check the model download cache for the model designated
-        by the provided URL and download it if needed using download_and_cache_model().
+        by the provided URL and download it if needed using download_and_cache_ckpt().
         It will then load the model into the RAM cache. If the optional loader
         argument is provided, the loader will be invoked to load the model into
         memory. Otherwise the method will call safetensors.torch.load_file() or
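
The docstring change above matches how the two helpers are used elsewhere in this commit. A rough usage sketch, assuming an InvocationContext named context and the load_jit_model() helper from the LaMa module further down (signatures abbreviated, not an exact excerpt):

# download_and_cache_ckpt(): download if needed and return the cached local Path;
# the caller loads the weights itself (the DepthAnything and DW Openpose ports).
weights_path = context.models.download_and_cache_ckpt(
    "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true"
)

# load_ckpt_from_url(): download if needed, load into the RAM cache (via the optional
# loader) and return a LoadedModel to be used as a context manager (the LaMa port).
loaded = context.models.load_ckpt_from_url(
    source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
    loader=lambda path: load_jit_model(path, "cpu"),
)
with loaded as model:
    ...  # run inference; the cache decides where the model lives and when it is evicted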

View File

@@ -1,5 +1,4 @@
-import pathlib
-from typing import Literal, Union
+from typing import Literal, Optional, Union
 import cv2
 import numpy as np

@@ -10,7 +9,7 @@ from PIL import Image
 from torchvision.transforms import Compose
 from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
 from invokeai.backend.util.devices import choose_torch_device

@@ -20,18 +19,9 @@ config = get_config()
 logger = InvokeAILogger.get_logger(config=config)
 DEPTH_ANYTHING_MODELS = {
-    "large": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
-    },
-    "base": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
-    },
-    "small": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vits14.pth",
-    },
+    "large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
+    "base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
+    "small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
 }

@@ -53,18 +43,14 @@ transform = Compose(
 class DepthAnythingDetector:
-    def __init__(self) -> None:
-        self.model = None
+    def __init__(self, context: InvocationContext) -> None:
+        self.context = context
+        self.model: Optional[DPT_DINOv2] = None
         self.model_size: Union[Literal["large", "base", "small"], None] = None
         self.device = choose_torch_device()
-    def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
-        DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
-        download_with_progress_bar(
-            pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name,
-            DEPTH_ANYTHING_MODELS[model_size]["url"],
-            DEPTH_ANYTHING_MODEL_PATH,
-        )
+    def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2:
+        depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size])
         if not self.model or model_size != self.model_size:
             del self.model

@@ -78,7 +64,8 @@ class DepthAnythingDetector:
             case "large":
                 self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
-        self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
+        assert self.model is not None
+        self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu"))
         self.model.eval()
         self.model.to(choose_torch_device())
View File

@@ -3,6 +3,7 @@ import torch
 from controlnet_aux.util import resize_image
 from PIL import Image
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
 from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody

@@ -39,8 +40,8 @@ class DWOpenposeDetector:
     Credits: https://github.com/IDEA-Research/DWPose
     """
-    def __init__(self) -> None:
-        self.pose_estimation = Wholebody()
+    def __init__(self, context: InvocationContext) -> None:
+        self.pose_estimation = Wholebody(context)
     def __call__(
         self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512

View File

@@ -4,44 +4,31 @@
 import numpy as np
 import onnxruntime as ort
+import torch
 from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.util.devices import choose_torch_device
 from .onnxdet import inference_detector
 from .onnxpose import inference_pose
 DWPOSE_MODELS = {
-    "yolox_l.onnx": {
-        "local": "any/annotators/dwpose/yolox_l.onnx",
-        "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
-    },
-    "dw-ll_ucoco_384.onnx": {
-        "local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
-        "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
-    },
+    "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
+    "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
 }
 config = get_config()
 class Wholebody:
-    def __init__(self):
+    def __init__(self, context: InvocationContext):
         device = choose_torch_device()
-        providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
+        providers = ["CUDAExecutionProvider"] if device == torch.device("cuda") else ["CPUExecutionProvider"]
-        DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
-        download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
-        POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"]
-        download_with_progress_bar(
-            "dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH
-        )
-        onnx_det = DET_MODEL_PATH
-        onnx_pose = POSE_MODEL_PATH
+        onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
+        onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
         self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
         self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)

View File

@@ -1,4 +1,3 @@
-import gc
 from typing import Any
 import numpy as np

@@ -6,9 +5,7 @@ import torch
 from PIL import Image
 import invokeai.backend.util.logging as logger
-from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.app.services.shared.invocation_context import InvocationContext
 def norm_img(np_img):

@@ -28,19 +25,15 @@ def load_jit_model(url_or_path, device):
 class LaMA:
+    def __init__(self, context: InvocationContext):
+        self._context = context
     def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
-        device = choose_torch_device()
-        model_location = get_config().models_path / "core/misc/lama/lama.pt"
-        if not model_location.exists():
-            download_with_progress_bar(
-                name="LaMa Inpainting Model",
-                url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
-                dest_path=model_location,
-            )
-        model = load_jit_model(model_location, device)
+        loaded_model = self._context.models.load_ckpt_from_url(
+            source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
+            loader=lambda path: load_jit_model(path, "cpu"),
+        )
         image = np.asarray(input_image.convert("RGB"))
         image = norm_img(image)

@@ -48,8 +41,10 @@ class LaMA:
         mask = np.asarray(mask)
         mask = np.invert(mask)
         mask = norm_img(mask)
         mask = (mask > 0) * 1
-        image = torch.from_numpy(image).unsqueeze(0).to(device)
-        mask = torch.from_numpy(mask).unsqueeze(0).to(device)
+        with loaded_model as model:
+            device = next(model.buffers()).device
+            image = torch.from_numpy(image).unsqueeze(0).to(device)
+            mask = torch.from_numpy(mask).unsqueeze(0).to(device)

@@ -60,8 +55,4 @@ class LaMA:
         infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
         infilled_image = Image.fromarray(infilled_image)
-        del model
-        gc.collect()
-        torch.cuda.empty_cache()
         return infilled_image
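
Worth noting in the last two hunks: the explicit del model / gc.collect() / torch.cuda.empty_cache() teardown disappears. Because the JIT model is now obtained through load_ckpt_from_url(), its lifetime belongs to the model manager's RAM cache; the inpainting code only borrows it inside the with loaded_model as model: block and reads the execution device off the loaded module (next(model.buffers()).device) instead of choosing one up front with choose_torch_device().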