port dw_openpose, depth_anything, and lama processors to new model download scheme

This commit is contained in:
Lincoln Stein
2024-04-12 21:05:23 -04:00
parent 3a26c7bb9e
commit 41b909cbe3
7 changed files with 72 additions and 105 deletions

View File

@ -1,4 +1,3 @@
import gc
from typing import Any
import numpy as np
@ -6,9 +5,7 @@ import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import choose_torch_device
from invokeai.app.services.shared.invocation_context import InvocationContext
def norm_img(np_img):
@ -28,18 +25,14 @@ def load_jit_model(url_or_path, device):
class LaMA:
def __init__(self, context: InvocationContext):
self._context = context
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
device = choose_torch_device()
model_location = get_config().models_path / "core/misc/lama/lama.pt"
if not model_location.exists():
download_with_progress_bar(
name="LaMa Inpainting Model",
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
dest_path=model_location,
)
model = load_jit_model(model_location, device)
loaded_model = self._context.models.load_ckpt_from_url(
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
loader=lambda path: load_jit_model(path, "cpu"),
)
image = np.asarray(input_image.convert("RGB"))
image = norm_img(image)
@ -48,20 +41,18 @@ class LaMA:
mask = np.asarray(mask)
mask = np.invert(mask)
mask = norm_img(mask)
mask = (mask > 0) * 1
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
with torch.inference_mode():
infilled_image = model(image, mask)
with loaded_model as model:
device = next(model.buffers()).device
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
with torch.inference_mode():
infilled_image = model(image, mask)
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
infilled_image = Image.fromarray(infilled_image)
del model
gc.collect()
torch.cuda.empty_cache()
return infilled_image