port dw_openpose, depth_anything, and lama processors to new model download scheme
This commit is contained in: parent 3a26c7bb9e, commit 41b909cbe3
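Every hunk below follows the same pattern: a processor that used to resolve a hard-coded path under config.models_path and fetch its weights with download_with_progress_bar() now receives an InvocationContext and asks the context's model service for the file. A minimal sketch of the new call shape, using the yolox_l.onnx URL from the DWPOSE_MODELS table in this diff (the helper function here is hypothetical; download_and_cache_ckpt() is the method as it appears in this commit):

from invokeai.app.services.shared.invocation_context import InvocationContext

# URL copied from the DWPOSE_MODELS table below; any checkpoint URL works the same way.
YOLOX_URL = "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true"

def get_detector_weights(context: InvocationContext):
    # The model service downloads the file on first use, stores it in the shared
    # model download cache, and returns the local path on every later call.
    return context.models.download_and_cache_ckpt(YOLOX_URL)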
@@ -137,7 +137,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):

     image: ImageField = InputField(description="The image to process")

-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         # superclass just passes through image without processing
         return image

@@ -148,7 +148,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         raw_image = self.load_image(context)
         # image type should be PIL.PngImagePlugin.PngImageFile ?
-        processed_image = self.run_processor(raw_image)
+        processed_image = self.run_processor(raw_image, context)

         # currently can't see processed image in node UI without a showImage node,
         # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
@@ -189,7 +189,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
         # Keep alpha channel for Canny processing to detect edges of transparent areas
         return context.images.get_pil(self.image.image_name, "RGBA")

-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         processed_image = get_canny_edges(
             image,
             self.low_threshold,
@@ -216,7 +216,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
     # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         hed_processor = HEDProcessor()
         processed_image = hed_processor.run(
             image,
@@ -243,7 +243,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
     coarse: bool = InputField(default=False, description="Whether to use coarse mode")

-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         lineart_processor = LineartProcessor()
         processed_image = lineart_processor.run(
             image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
@@ -264,7 +264,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

-    def run_processor(self, image: Image.Image) -> Image.Image:
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         processor = LineartAnimeProcessor()
         processed_image = processor.run(
             image,
@@ -291,7 +291,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     # depth_and_normal not supported in controlnet_aux v0.0.3
     # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+        # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
         midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = midas_processor(
             image,
@@ -318,9 +319,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = normalbae_processor(
+        processed_image: Image.Image = normalbae_processor(
             image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
         )
         return processed_image
@@ -337,7 +338,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
     thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
     thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
         processed_image = mlsd_processor(
             image,
@@ -360,7 +361,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
     safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = pidi_processor(
             image,
@@ -388,7 +389,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
     w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
     f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         content_shuffle_processor = ContentShuffleDetector()
         processed_image = content_shuffle_processor(
             image,
@@ -412,7 +413,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
 class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Zoe depth processing to image"""

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = zoe_depth_processor(image)
         return processed_image
@@ -433,7 +434,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         mediapipe_face_processor = MediapipeFaceDetector()
         processed_image = mediapipe_face_processor(
             image,
@@ -461,7 +462,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = leres_processor(
             image,
@@ -503,7 +504,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
         np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
         return np_img

-    def run_processor(self, img):
+    def run_processor(self, img: Image.Image, context: InvocationContext) -> Image.Image:
         np_img = np.array(img, dtype=np.uint8)
         processed_np_image = self.tile_resample(
             np_img,
@@ -527,7 +528,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
         segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
             "ybelkada/segment-anything", subfolder="checkpoints"
@@ -573,7 +574,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):

     color_map_tile_size: int = InputField(default=64, ge=0, description=FieldDescriptions.tile_size)

-    def run_processor(self, image: Image.Image):
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         np_image = np.array(image, dtype=np.uint8)
         height, width = np_image.shape[:2]

@@ -608,8 +609,8 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
     )
     resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)

-    def run_processor(self, image: Image.Image):
-        depth_anything_detector = DepthAnythingDetector()
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+        depth_anything_detector = DepthAnythingDetector(context)
         depth_anything_detector.load_model(model_size=self.model_size)

         processed_image = depth_anything_detector(image=image, resolution=self.resolution)
@@ -631,8 +632,8 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
     draw_hands: bool = InputField(default=False)
     image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

-    def run_processor(self, image: Image.Image):
-        dw_openpose = DWOpenposeDetector()
+    def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
+        dw_openpose = DWOpenposeDetector(context)
         processed_image = dw_openpose(
             image,
             draw_face=self.draw_face,
@@ -38,7 +38,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
     image: ImageField = InputField(description="The image to process")

     @abstractmethod
-    def infill(self, image: Image.Image) -> Image.Image:
+    def infill(self, image: Image.Image, context: InvocationContext) -> Image.Image:
         """Infill the image with the specified method"""
         pass

@@ -57,7 +57,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
             return ImageOutput.build(context.images.get_dto(self.image.image_name))

         # Perform Infill action
-        infilled_image = self.infill(input_image)
+        infilled_image = self.infill(input_image, context)

         # Create ImageDTO for Infilled Image
         infilled_image_dto = context.images.save(image=infilled_image)
@@ -75,7 +75,7 @@ class InfillColorInvocation(InfillImageProcessorInvocation):
         description="The color to use to infill",
     )

-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, context: InvocationContext):
         solid_bg = Image.new("RGBA", image.size, self.color.tuple())
         infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
         infilled.paste(image, (0, 0), image.split()[-1])
@@ -94,7 +94,7 @@ class InfillTileInvocation(InfillImageProcessorInvocation):
         description="The seed to use for tile generation (omit for random)",
     )

-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, context: InvocationContext):
         output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
         return output.infilled

@@ -108,7 +108,7 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
     downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
     resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")

-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, context: InvocationContext):
         resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]

         width = int(image.width / self.downscale)
@@ -132,8 +132,8 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
 class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""

-    def infill(self, image: Image.Image):
-        lama = LaMA()
+    def infill(self, image: Image.Image, context: InvocationContext):
+        lama = LaMA(context)
         return lama(image)


@@ -141,7 +141,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
 class CV2InfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using OpenCV Inpainting"""

-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, context: InvocationContext):
         return cv2_inpaint(image)


@@ -163,5 +163,5 @@ class MosaicInfillInvocation(InfillImageProcessorInvocation):
         description="The max threshold for color",
     )

-    def infill(self, image: Image.Image):
+    def infill(self, image: Image.Image, context: InvocationContext):
         return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())
@@ -534,10 +534,10 @@ class ModelsInterface(InvocationContextInterface):
         loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None,
     ) -> LoadedModel:
         """
-        Load and cache the model file located at the indicated URL.
+        Download, cache, and Load the model file located at the indicated URL.

         This will check the model download cache for the model designated
-        by the provided URL and download it if needed using download_and_cache_model().
+        by the provided URL and download it if needed using download_and_cache_ckpt().
         It will then load the model into the RAM cache. If the optional loader
         argument is provided, the loader will be invoked to load the model into
         memory. Otherwise the method will call safetensors.torch.load_file() or
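The LaMa port later in this commit is the concrete use of the loader hook documented above; a minimal usage sketch, assuming the same checkpoint URL and the load_jit_model() helper that appear in the lama.py hunks below (context is the invocation's InvocationContext):

# Sketch only: load_ckpt_from_url() downloads the file on a cache miss, then loads
# it into the RAM cache, delegating deserialization to the loader callable.
loaded_model = context.models.load_ckpt_from_url(
    source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
    loader=lambda path: load_jit_model(path, "cpu"),
)
with loaded_model as model:
    ...  # run inference while the model is checked out of the RAM cache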
@@ -1,5 +1,4 @@
-import pathlib
-from typing import Literal, Union
+from typing import Literal, Optional, Union

 import cv2
 import numpy as np
@@ -10,7 +9,7 @@ from PIL import Image
 from torchvision.transforms import Compose

 from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
 from invokeai.backend.util.devices import choose_torch_device
@@ -20,18 +19,9 @@ config = get_config()
 logger = InvokeAILogger.get_logger(config=config)

 DEPTH_ANYTHING_MODELS = {
-    "large": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
-    },
-    "base": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
-    },
-    "small": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vits14.pth",
-    },
+    "large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
+    "base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
+    "small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
 }


@@ -53,18 +43,14 @@ transform = Compose(


 class DepthAnythingDetector:
-    def __init__(self) -> None:
-        self.model = None
+    def __init__(self, context: InvocationContext) -> None:
+        self.context = context
+        self.model: Optional[DPT_DINOv2] = None
         self.model_size: Union[Literal["large", "base", "small"], None] = None
         self.device = choose_torch_device()

-    def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
-        DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
-        download_with_progress_bar(
-            pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name,
-            DEPTH_ANYTHING_MODELS[model_size]["url"],
-            DEPTH_ANYTHING_MODEL_PATH,
-        )
+    def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2:
+        depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size])

         if not self.model or model_size != self.model_size:
             del self.model
@@ -78,7 +64,8 @@ class DepthAnythingDetector:
                 case "large":
                     self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])

-        self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
+        assert self.model is not None
+        self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu"))
         self.model.eval()

         self.model.to(choose_torch_device())
@@ -3,6 +3,7 @@ import torch
 from controlnet_aux.util import resize_image
 from PIL import Image

+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
 from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody

@@ -39,8 +40,8 @@ class DWOpenposeDetector:
     Credits: https://github.com/IDEA-Research/DWPose
     """

-    def __init__(self) -> None:
-        self.pose_estimation = Wholebody()
+    def __init__(self, context: InvocationContext) -> None:
+        self.pose_estimation = Wholebody(context)

     def __call__(
         self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512
@@ -4,44 +4,31 @@

 import numpy as np
 import onnxruntime as ort
+import torch

 from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.util.devices import choose_torch_device

 from .onnxdet import inference_detector
 from .onnxpose import inference_pose

 DWPOSE_MODELS = {
-    "yolox_l.onnx": {
-        "local": "any/annotators/dwpose/yolox_l.onnx",
-        "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
-    },
-    "dw-ll_ucoco_384.onnx": {
-        "local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
-        "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
-    },
+    "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
+    "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
 }

 config = get_config()


 class Wholebody:
-    def __init__(self):
+    def __init__(self, context: InvocationContext):
         device = choose_torch_device()

-        providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
+        providers = ["CUDAExecutionProvider"] if device == torch.device("cuda") else ["CPUExecutionProvider"]

-        DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
-        download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
-
-        POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"]
-        download_with_progress_bar(
-            "dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH
-        )
-
-        onnx_det = DET_MODEL_PATH
-        onnx_pose = POSE_MODEL_PATH
+        onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
+        onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])

         self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
         self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
@@ -1,4 +1,3 @@
-import gc
 from typing import Any

 import numpy as np
@@ -6,9 +5,7 @@ import torch
 from PIL import Image

 import invokeai.backend.util.logging as logger
-from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.app.services.shared.invocation_context import InvocationContext


 def norm_img(np_img):
@@ -28,19 +25,15 @@ def load_jit_model(url_or_path, device):


 class LaMA:
+    def __init__(self, context: InvocationContext):
+        self._context = context
+
     def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
-        device = choose_torch_device()
-        model_location = get_config().models_path / "core/misc/lama/lama.pt"
-
-        if not model_location.exists():
-            download_with_progress_bar(
-                name="LaMa Inpainting Model",
-                url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
-                dest_path=model_location,
-            )
-
-        model = load_jit_model(model_location, device)
+        loaded_model = self._context.models.load_ckpt_from_url(
+            source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
+            loader=lambda path: load_jit_model(path, "cpu"),
+        )

         image = np.asarray(input_image.convert("RGB"))
         image = norm_img(image)
|
|||||||
mask = np.asarray(mask)
|
mask = np.asarray(mask)
|
||||||
mask = np.invert(mask)
|
mask = np.invert(mask)
|
||||||
mask = norm_img(mask)
|
mask = norm_img(mask)
|
||||||
|
|
||||||
mask = (mask > 0) * 1
|
mask = (mask > 0) * 1
|
||||||
|
|
||||||
|
with loaded_model as model:
|
||||||
|
device = next(model.buffers()).device
|
||||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||||
|
|
||||||
@ -60,8 +55,4 @@ class LaMA:
|
|||||||
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
|
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
|
||||||
infilled_image = Image.fromarray(infilled_image)
|
infilled_image = Image.fromarray(infilled_image)
|
||||||
|
|
||||||
del model
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
return infilled_image
|
return infilled_image
|
||||||
|