mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
port dw_openpose, depth_anything, and lama processors to new model download scheme
This commit is contained in:
parent
c140d3b1df
commit
3ead827d61
@ -137,7 +137,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
# superclass just passes through image without processing
|
||||
return image
|
||||
|
||||
@ -148,7 +148,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raw_image = self.load_image(context)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
processed_image = self.run_processor(raw_image, context)
|
||||
|
||||
# currently can't see processed image in node UI without a showImage node,
|
||||
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||
@ -189,7 +189,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
# Keep alpha channel for Canny processing to detect edges of transparent areas
|
||||
return context.images.get_pil(self.image.image_name, "RGBA")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
processed_image = get_canny_edges(
|
||||
image,
|
||||
self.low_threshold,
|
||||
@ -216,7 +216,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
hed_processor = HEDProcessor()
|
||||
processed_image = hed_processor.run(
|
||||
image,
|
||||
@ -243,7 +243,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
lineart_processor = LineartProcessor()
|
||||
processed_image = lineart_processor.run(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
|
||||
@ -264,7 +264,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
processor = LineartAnimeProcessor()
|
||||
processed_image = processor.run(
|
||||
image,
|
||||
@ -291,7 +291,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
|
||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = midas_processor(
|
||||
image,
|
||||
@ -318,9 +319,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = normalbae_processor(
|
||||
processed_image: Image.Image = normalbae_processor(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
|
||||
)
|
||||
return processed_image
|
||||
@ -337,7 +338,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = mlsd_processor(
|
||||
image,
|
||||
@ -360,7 +361,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = pidi_processor(
|
||||
image,
|
||||
@ -388,7 +389,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
processed_image = content_shuffle_processor(
|
||||
image,
|
||||
@ -412,7 +413,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = zoe_depth_processor(image)
|
||||
return processed_image
|
||||
@ -433,7 +434,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
mediapipe_face_processor = MediapipeFaceDetector()
|
||||
processed_image = mediapipe_face_processor(
|
||||
image,
|
||||
@ -461,7 +462,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = leres_processor(
|
||||
image,
|
||||
@ -503,7 +504,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
|
||||
return np_img
|
||||
|
||||
def run_processor(self, img):
|
||||
def run_processor(self, img: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
np_img = np.array(img, dtype=np.uint8)
|
||||
processed_np_image = self.tile_resample(
|
||||
np_img,
|
||||
@ -527,7 +528,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
||||
"ybelkada/segment-anything", subfolder="checkpoints"
|
||||
@ -573,7 +574,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
color_map_tile_size: int = InputField(default=64, ge=0, description=FieldDescriptions.tile_size)
|
||||
|
||||
def run_processor(self, image: Image.Image):
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
np_image = np.array(image, dtype=np.uint8)
|
||||
height, width = np_image.shape[:2]
|
||||
|
||||
@ -608,8 +609,8 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
)
|
||||
resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image):
|
||||
depth_anything_detector = DepthAnythingDetector()
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
depth_anything_detector = DepthAnythingDetector(context)
|
||||
depth_anything_detector.load_model(model_size=self.model_size)
|
||||
|
||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||
@ -631,8 +632,8 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
draw_hands: bool = InputField(default=False)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image):
|
||||
dw_openpose = DWOpenposeDetector()
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
dw_openpose = DWOpenposeDetector(context)
|
||||
processed_image = dw_openpose(
|
||||
image,
|
||||
draw_face=self.draw_face,
|
||||
|
@ -38,7 +38,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
@abstractmethod
|
||||
def infill(self, image: Image.Image) -> Image.Image:
|
||||
def infill(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
"""Infill the image with the specified method"""
|
||||
pass
|
||||
|
||||
@ -57,7 +57,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return ImageOutput.build(context.images.get_dto(self.image.image_name))
|
||||
|
||||
# Perform Infill action
|
||||
infilled_image = self.infill(input_image)
|
||||
infilled_image = self.infill(input_image, context)
|
||||
|
||||
# Create ImageDTO for Infilled Image
|
||||
infilled_image_dto = context.images.save(image=infilled_image)
|
||||
@ -75,7 +75,7 @@ class InfillColorInvocation(InfillImageProcessorInvocation):
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
def infill(self, image: Image.Image, context: InvocationContext):
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
@ -94,7 +94,7 @@ class InfillTileInvocation(InfillImageProcessorInvocation):
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
)
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
def infill(self, image: Image.Image, context: InvocationContext):
|
||||
output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
|
||||
return output.infilled
|
||||
|
||||
@ -108,7 +108,7 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
|
||||
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
def infill(self, image: Image.Image, context: InvocationContext):
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
width = int(image.width / self.downscale)
|
||||
@ -132,8 +132,8 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
|
||||
class LaMaInfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
lama = LaMA()
|
||||
def infill(self, image: Image.Image, context: InvocationContext):
|
||||
lama = LaMA(context)
|
||||
return lama(image)
|
||||
|
||||
|
||||
@ -141,7 +141,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
|
||||
class CV2InfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
def infill(self, image: Image.Image, context: InvocationContext):
|
||||
return cv2_inpaint(image)
|
||||
|
||||
|
||||
@ -163,5 +163,5 @@ class MosaicInfillInvocation(InfillImageProcessorInvocation):
|
||||
description="The max threshold for color",
|
||||
)
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
def infill(self, image: Image.Image, context: InvocationContext):
|
||||
return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())
|
||||
|
@ -534,10 +534,10 @@ class ModelsInterface(InvocationContextInterface):
|
||||
loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Load and cache the model file located at the indicated URL.
|
||||
Download, cache, and Load the model file located at the indicated URL.
|
||||
|
||||
This will check the model download cache for the model designated
|
||||
by the provided URL and download it if needed using download_and_cache_model().
|
||||
by the provided URL and download it if needed using download_and_cache_ckpt().
|
||||
It will then load the model into the RAM cache. If the optional loader
|
||||
argument is provided, the loader will be invoked to load the model into
|
||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||
|
@ -1,5 +1,4 @@
|
||||
import pathlib
|
||||
from typing import Literal, Union
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -10,7 +9,7 @@ from PIL import Image
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
@ -20,18 +19,9 @@ config = get_config()
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": {
|
||||
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
|
||||
"local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
|
||||
},
|
||||
"base": {
|
||||
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
|
||||
"local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
|
||||
},
|
||||
"small": {
|
||||
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
|
||||
"local": "any/annotators/depth_anything/depth_anything_vits14.pth",
|
||||
},
|
||||
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
|
||||
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
|
||||
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
|
||||
}
|
||||
|
||||
|
||||
@ -53,18 +43,14 @@ transform = Compose(
|
||||
|
||||
|
||||
class DepthAnythingDetector:
|
||||
def __init__(self) -> None:
|
||||
self.model = None
|
||||
def __init__(self, context: InvocationContext) -> None:
|
||||
self.context = context
|
||||
self.model: Optional[DPT_DINOv2] = None
|
||||
self.model_size: Union[Literal["large", "base", "small"], None] = None
|
||||
self.device = choose_torch_device()
|
||||
|
||||
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
|
||||
DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
|
||||
download_with_progress_bar(
|
||||
pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name,
|
||||
DEPTH_ANYTHING_MODELS[model_size]["url"],
|
||||
DEPTH_ANYTHING_MODEL_PATH,
|
||||
)
|
||||
def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2:
|
||||
depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size])
|
||||
|
||||
if not self.model or model_size != self.model_size:
|
||||
del self.model
|
||||
@ -78,7 +64,8 @@ class DepthAnythingDetector:
|
||||
case "large":
|
||||
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
|
||||
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
|
||||
assert self.model is not None
|
||||
self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu"))
|
||||
self.model.eval()
|
||||
|
||||
self.model.to(choose_torch_device())
|
||||
|
@ -3,6 +3,7 @@ import torch
|
||||
from controlnet_aux.util import resize_image
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
|
||||
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
|
||||
|
||||
@ -39,8 +40,8 @@ class DWOpenposeDetector:
|
||||
Credits: https://github.com/IDEA-Research/DWPose
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.pose_estimation = Wholebody()
|
||||
def __init__(self, context: InvocationContext) -> None:
|
||||
self.pose_estimation = Wholebody(context)
|
||||
|
||||
def __call__(
|
||||
self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512
|
||||
|
@ -4,44 +4,31 @@
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
from .onnxdet import inference_detector
|
||||
from .onnxpose import inference_pose
|
||||
|
||||
DWPOSE_MODELS = {
|
||||
"yolox_l.onnx": {
|
||||
"local": "any/annotators/dwpose/yolox_l.onnx",
|
||||
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
|
||||
},
|
||||
"dw-ll_ucoco_384.onnx": {
|
||||
"local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
|
||||
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
|
||||
},
|
||||
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
|
||||
"dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
|
||||
}
|
||||
|
||||
config = get_config()
|
||||
|
||||
|
||||
class Wholebody:
|
||||
def __init__(self):
|
||||
def __init__(self, context: InvocationContext):
|
||||
device = choose_torch_device()
|
||||
|
||||
providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
|
||||
providers = ["CUDAExecutionProvider"] if device == torch.device("cuda") else ["CPUExecutionProvider"]
|
||||
|
||||
DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
|
||||
download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
|
||||
|
||||
POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"]
|
||||
download_with_progress_bar(
|
||||
"dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH
|
||||
)
|
||||
|
||||
onnx_det = DET_MODEL_PATH
|
||||
onnx_pose = POSE_MODEL_PATH
|
||||
onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
|
||||
onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
|
||||
|
||||
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
|
||||
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
|
||||
|
@ -1,4 +1,3 @@
|
||||
import gc
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
@ -6,9 +5,7 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
def norm_img(np_img):
|
||||
@ -28,18 +25,14 @@ def load_jit_model(url_or_path, device):
|
||||
|
||||
|
||||
class LaMA:
|
||||
def __init__(self, context: InvocationContext):
|
||||
self._context = context
|
||||
|
||||
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
||||
device = choose_torch_device()
|
||||
model_location = get_config().models_path / "core/misc/lama/lama.pt"
|
||||
|
||||
if not model_location.exists():
|
||||
download_with_progress_bar(
|
||||
name="LaMa Inpainting Model",
|
||||
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
dest_path=model_location,
|
||||
)
|
||||
|
||||
model = load_jit_model(model_location, device)
|
||||
loaded_model = self._context.models.load_ckpt_from_url(
|
||||
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
loader=lambda path: load_jit_model(path, "cpu"),
|
||||
)
|
||||
|
||||
image = np.asarray(input_image.convert("RGB"))
|
||||
image = norm_img(image)
|
||||
@ -48,20 +41,18 @@ class LaMA:
|
||||
mask = np.asarray(mask)
|
||||
mask = np.invert(mask)
|
||||
mask = norm_img(mask)
|
||||
|
||||
mask = (mask > 0) * 1
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
infilled_image = model(image, mask)
|
||||
with loaded_model as model:
|
||||
device = next(model.buffers()).device
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
infilled_image = model(image, mask)
|
||||
|
||||
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
|
||||
infilled_image = Image.fromarray(infilled_image)
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return infilled_image
|
||||
|
Loading…
Reference in New Issue
Block a user