From 3ead827d61fb3c935e3396d43455c12b6b59018f Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 12 Apr 2024 21:05:23 -0400 Subject: [PATCH] port dw_openpose, depth_anything, and lama processors to new model download scheme --- .../controlnet_image_processors.py | 45 ++++++++++--------- invokeai/app/invocations/infill.py | 18 ++++---- .../app/services/shared/invocation_context.py | 4 +- .../image_util/depth_anything/__init__.py | 37 +++++---------- .../image_util/dw_openpose/__init__.py | 5 ++- .../image_util/dw_openpose/wholebody.py | 29 ++++-------- .../backend/image_util/infill_methods/lama.py | 39 +++++++--------- 7 files changed, 72 insertions(+), 105 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index a49c910eeb..12a2ae9c96 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -137,7 +137,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): image: ImageField = InputField(description="The image to process") - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: # superclass just passes through image without processing return image @@ -148,7 +148,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): def invoke(self, context: InvocationContext) -> ImageOutput: raw_image = self.load_image(context) # image type should be PIL.PngImagePlugin.PngImageFile ? - processed_image = self.run_processor(raw_image) + processed_image = self.run_processor(raw_image, context) # currently can't see processed image in node UI without a showImage node, # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery @@ -189,7 +189,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation): # Keep alpha channel for Canny processing to detect edges of transparent areas return context.images.get_pil(self.image.image_name, "RGBA") - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: processed_image = get_canny_edges( image, self.low_threshold, @@ -216,7 +216,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation): # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: hed_processor = HEDProcessor() processed_image = hed_processor.run( image, @@ -243,7 +243,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation): image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) coarse: bool = InputField(default=False, description="Whether to use coarse mode") - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: lineart_processor = LineartProcessor() processed_image = lineart_processor.run( image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse @@ -264,7 +264,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, 
description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image) -> Image.Image: + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: processor = LineartAnimeProcessor() processed_image = processor.run( image, @@ -291,7 +291,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): # depth_and_normal not supported in controlnet_aux v0.0.3 # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode") - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar) midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") processed_image = midas_processor( image, @@ -318,9 +319,9 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") - processed_image = normalbae_processor( + processed_image: Image.Image = normalbae_processor( image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution ) return processed_image @@ -337,7 +338,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation): thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`") thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`") - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") processed_image = mlsd_processor( image, @@ -360,7 +361,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation): safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators") processed_image = pidi_processor( image, @@ -388,7 +389,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter") f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter") - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: content_shuffle_processor = ContentShuffleDetector() processed_image = content_shuffle_processor( image, @@ -412,7 +413,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") processed_image = zoe_depth_processor(image) 
return processed_image @@ -433,7 +434,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: mediapipe_face_processor = MediapipeFaceDetector() processed_image = mediapipe_face_processor( image, @@ -461,7 +462,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators") processed_image = leres_processor( image, @@ -503,7 +504,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation): np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA) return np_img - def run_processor(self, img): + def run_processor(self, img: Image.Image, context: InvocationContext) -> Image.Image: np_img = np.array(img, dtype=np.uint8) processed_np_image = self.tile_resample( np_img, @@ -527,7 +528,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") segment_anything_processor = SamDetectorReproducibleColors.from_pretrained( "ybelkada/segment-anything", subfolder="checkpoints" @@ -573,7 +574,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation): color_map_tile_size: int = InputField(default=64, ge=0, description=FieldDescriptions.tile_size) - def run_processor(self, image: Image.Image): + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: np_image = np.array(image, dtype=np.uint8) height, width = np_image.shape[:2] @@ -608,8 +609,8 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation): ) resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image): - depth_anything_detector = DepthAnythingDetector() + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + depth_anything_detector = DepthAnythingDetector(context) depth_anything_detector.load_model(model_size=self.model_size) processed_image = depth_anything_detector(image=image, resolution=self.resolution) @@ -631,8 +632,8 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation): draw_hands: bool = InputField(default=False) image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image): - dw_openpose = DWOpenposeDetector() + def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image: + dw_openpose = DWOpenposeDetector(context) processed_image = dw_openpose( image, 
draw_face=self.draw_face, diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 418bc62fdc..edee275e72 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -38,7 +38,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): image: ImageField = InputField(description="The image to process") @abstractmethod - def infill(self, image: Image.Image) -> Image.Image: + def infill(self, image: Image.Image, context: InvocationContext) -> Image.Image: """Infill the image with the specified method""" pass @@ -57,7 +57,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): return ImageOutput.build(context.images.get_dto(self.image.image_name)) # Perform Infill action - infilled_image = self.infill(input_image) + infilled_image = self.infill(input_image, context) # Create ImageDTO for Infilled Image infilled_image_dto = context.images.save(image=infilled_image) @@ -75,7 +75,7 @@ class InfillColorInvocation(InfillImageProcessorInvocation): description="The color to use to infill", ) - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): solid_bg = Image.new("RGBA", image.size, self.color.tuple()) infilled = Image.alpha_composite(solid_bg, image.convert("RGBA")) infilled.paste(image, (0, 0), image.split()[-1]) @@ -94,7 +94,7 @@ class InfillTileInvocation(InfillImageProcessorInvocation): description="The seed to use for tile generation (omit for random)", ) - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): output = infill_tile(image, seed=self.seed, tile_size=self.tile_size) return output.infilled @@ -108,7 +108,7 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation): downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] width = int(image.width / self.downscale) @@ -132,8 +132,8 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation): class LaMaInfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using the LaMa model""" - def infill(self, image: Image.Image): - lama = LaMA() + def infill(self, image: Image.Image, context: InvocationContext): + lama = LaMA(context) return lama(image) @@ -141,7 +141,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation): class CV2InfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using OpenCV Inpainting""" - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): return cv2_inpaint(image) @@ -163,5 +163,5 @@ class MosaicInfillInvocation(InfillImageProcessorInvocation): description="The max threshold for color", ) - def infill(self, image: Image.Image): + def infill(self, image: Image.Image, context: InvocationContext): return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple()) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index e97d29d308..0d27b2520b 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ 
b/invokeai/app/services/shared/invocation_context.py @@ -534,10 +534,10 @@ class ModelsInterface(InvocationContextInterface): loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None, ) -> LoadedModel: """ - Load and cache the model file located at the indicated URL. + Download, cache, and Load the model file located at the indicated URL. This will check the model download cache for the model designated - by the provided URL and download it if needed using download_and_cache_model(). + by the provided URL and download it if needed using download_and_cache_ckpt(). It will then load the model into the RAM cache. If the optional loader argument is provided, the loader will be invoked to load the model into memory. Otherwise the method will call safetensors.torch.load_file() or diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py index ccac2ba949..560d977b55 100644 --- a/invokeai/backend/image_util/depth_anything/__init__.py +++ b/invokeai/backend/image_util/depth_anything/__init__.py @@ -1,5 +1,4 @@ -import pathlib -from typing import Literal, Union +from typing import Literal, Optional, Union import cv2 import numpy as np @@ -10,7 +9,7 @@ from PIL import Image from torchvision.transforms import Compose from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.download_with_progress import download_with_progress_bar +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize from invokeai.backend.util.devices import choose_torch_device @@ -20,18 +19,9 @@ config = get_config() logger = InvokeAILogger.get_logger(config=config) DEPTH_ANYTHING_MODELS = { - "large": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vitl14.pth", - }, - "base": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vitb14.pth", - }, - "small": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vits14.pth", - }, + "large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true", + "base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true", + "small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true", } @@ -53,18 +43,14 @@ transform = Compose( class DepthAnythingDetector: - def __init__(self) -> None: - self.model = None + def __init__(self, context: InvocationContext) -> None: + self.context = context + self.model: Optional[DPT_DINOv2] = None self.model_size: Union[Literal["large", "base", "small"], None] = None self.device = choose_torch_device() - def load_model(self, model_size: Literal["large", "base", "small"] = "small"): - DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"] - download_with_progress_bar( - 
pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name, - DEPTH_ANYTHING_MODELS[model_size]["url"], - DEPTH_ANYTHING_MODEL_PATH, - ) + def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2: + depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size]) if not self.model or model_size != self.model_size: del self.model @@ -78,7 +64,8 @@ class DepthAnythingDetector: case "large": self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]) - self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu")) + assert self.model is not None + self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu")) self.model.eval() self.model.to(choose_torch_device()) diff --git a/invokeai/backend/image_util/dw_openpose/__init__.py b/invokeai/backend/image_util/dw_openpose/__init__.py index c258ef2c78..17ca0233c8 100644 --- a/invokeai/backend/image_util/dw_openpose/__init__.py +++ b/invokeai/backend/image_util/dw_openpose/__init__.py @@ -3,6 +3,7 @@ import torch from controlnet_aux.util import resize_image from PIL import Image +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody @@ -39,8 +40,8 @@ class DWOpenposeDetector: Credits: https://github.com/IDEA-Research/DWPose """ - def __init__(self) -> None: - self.pose_estimation = Wholebody() + def __init__(self, context: InvocationContext) -> None: + self.pose_estimation = Wholebody(context) def __call__( self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512 diff --git a/invokeai/backend/image_util/dw_openpose/wholebody.py b/invokeai/backend/image_util/dw_openpose/wholebody.py index 35d340640d..3628b0abd5 100644 --- a/invokeai/backend/image_util/dw_openpose/wholebody.py +++ b/invokeai/backend/image_util/dw_openpose/wholebody.py @@ -4,44 +4,31 @@ import numpy as np import onnxruntime as ort +import torch from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.download_with_progress import download_with_progress_bar +from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.util.devices import choose_torch_device from .onnxdet import inference_detector from .onnxpose import inference_pose DWPOSE_MODELS = { - "yolox_l.onnx": { - "local": "any/annotators/dwpose/yolox_l.onnx", - "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true", - }, - "dw-ll_ucoco_384.onnx": { - "local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx", - "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true", - }, + "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true", + "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true", } config = get_config() class Wholebody: - def __init__(self): + def __init__(self, context: InvocationContext): device = choose_torch_device() - providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"] + providers = ["CUDAExecutionProvider"] if device == torch.device("cuda") else ["CPUExecutionProvider"] - DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"] - 
download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH) - - POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"] - download_with_progress_bar( - "dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH - ) - - onnx_det = DET_MODEL_PATH - onnx_pose = POSE_MODEL_PATH + onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"]) + onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]) self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py index fa354aeed1..8c3f33efad 100644 --- a/invokeai/backend/image_util/infill_methods/lama.py +++ b/invokeai/backend/image_util/infill_methods/lama.py @@ -1,4 +1,3 @@ -import gc from typing import Any import numpy as np @@ -6,9 +5,7 @@ import torch from PIL import Image import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.download_with_progress import download_with_progress_bar -from invokeai.backend.util.devices import choose_torch_device +from invokeai.app.services.shared.invocation_context import InvocationContext def norm_img(np_img): @@ -28,18 +25,14 @@ def load_jit_model(url_or_path, device): class LaMA: + def __init__(self, context: InvocationContext): + self._context = context + def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: - device = choose_torch_device() - model_location = get_config().models_path / "core/misc/lama/lama.pt" - - if not model_location.exists(): - download_with_progress_bar( - name="LaMa Inpainting Model", - url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", - dest_path=model_location, - ) - - model = load_jit_model(model_location, device) + loaded_model = self._context.models.load_ckpt_from_url( + source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", + loader=lambda path: load_jit_model(path, "cpu"), + ) image = np.asarray(input_image.convert("RGB")) image = norm_img(image) @@ -48,20 +41,18 @@ class LaMA: mask = np.asarray(mask) mask = np.invert(mask) mask = norm_img(mask) - mask = (mask > 0) * 1 - image = torch.from_numpy(image).unsqueeze(0).to(device) - mask = torch.from_numpy(mask).unsqueeze(0).to(device) - with torch.inference_mode(): - infilled_image = model(image, mask) + with loaded_model as model: + device = next(model.buffers()).device + image = torch.from_numpy(image).unsqueeze(0).to(device) + mask = torch.from_numpy(mask).unsqueeze(0).to(device) + + with torch.inference_mode(): + infilled_image = model(image, mask) infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy() infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8") infilled_image = Image.fromarray(infilled_image) - del model - gc.collect() - torch.cuda.empty_cache() - return infilled_image