diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py
index 50fec22994..f8358d1df5 100644
--- a/invokeai/app/invocations/infill.py
+++ b/invokeai/app/invocations/infill.py
@@ -133,10 +133,12 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""
 
     def infill(self, image: Image.Image, context: InvocationContext):
-        # Note that this accesses a protected attribute to get to the model manager service.
-        # Is there a better way?
-        lama = LaMA(context._services.model_manager)
-        return lama(image)
+        with context.models.load_ckpt_from_url(
+            source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
+            loader=LaMA.load_jit_model,
+        ) as model:
+            lama = LaMA(model)
+            return lama(image)
 
 
 @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py
index c7fea497ca..cd5838d1f2 100644
--- a/invokeai/backend/image_util/infill_methods/lama.py
+++ b/invokeai/backend/image_util/infill_methods/lama.py
@@ -1,13 +1,12 @@
-from typing import TYPE_CHECKING, Any
+from pathlib import Path
+from typing import Any
 
 import numpy as np
 import torch
 from PIL import Image
 
 import invokeai.backend.util.logging as logger
-
-if TYPE_CHECKING:
-    from invokeai.app.services.model_manager import ModelManagerServiceBase
+from invokeai.backend.model_manager.config import AnyModel
 
 
 def norm_img(np_img):
@@ -18,24 +17,11 @@ def norm_img(np_img):
     return np_img
 
 
-def load_jit_model(url_or_path, device) -> torch.nn.Module:
-    model_path = url_or_path
-    logger.info(f"Loading model from: {model_path}")
-    model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device)  # type: ignore
-    model.eval()
-    return model
-
-
 class LaMA:
-    def __init__(self, model_manager: "ModelManagerServiceBase"):
-        self._model_manager = model_manager
+    def __init__(self, model: AnyModel):
+        self._model = model
 
     def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
-        loaded_model = self._model_manager.load_ckpt_from_url(
-            source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
-            loader=lambda path: load_jit_model(path, "cpu"),
-        )
-
         image = np.asarray(input_image.convert("RGB"))
         image = norm_img(image)
 
@@ -45,16 +31,23 @@ class LaMA:
         mask = norm_img(mask)
         mask = (mask > 0) * 1
 
-        with loaded_model as model:
-            device = next(model.buffers()).device
-            image = torch.from_numpy(image).unsqueeze(0).to(device)
-            mask = torch.from_numpy(mask).unsqueeze(0).to(device)
+        device = next(self._model.buffers()).device
+        image = torch.from_numpy(image).unsqueeze(0).to(device)
+        mask = torch.from_numpy(mask).unsqueeze(0).to(device)
 
-            with torch.inference_mode():
-                infilled_image = model(image, mask)
+        with torch.inference_mode():
+            infilled_image = self._model(image, mask)
 
         infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
         infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
         infilled_image = Image.fromarray(infilled_image)
 
         return infilled_image
+
+    @staticmethod
+    def load_jit_model(url_or_path: str | Path, device: torch.device | str = "cpu") -> torch.nn.Module:
+        model_path = url_or_path
+        logger.info(f"Loading model from: {model_path}")
+        model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device)  # type: ignore
+        model.eval()
+        return model