mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
port dw_openpose, depth_anything, and lama processors to new model download scheme
This commit is contained in:
@ -1,4 +1,3 @@
|
||||
import gc
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
@ -6,9 +5,7 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
def norm_img(np_img):
|
||||
@ -28,18 +25,14 @@ def load_jit_model(url_or_path, device):
|
||||
|
||||
|
||||
class LaMA:
|
||||
def __init__(self, context: InvocationContext):
|
||||
self._context = context
|
||||
|
||||
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
||||
device = choose_torch_device()
|
||||
model_location = get_config().models_path / "core/misc/lama/lama.pt"
|
||||
|
||||
if not model_location.exists():
|
||||
download_with_progress_bar(
|
||||
name="LaMa Inpainting Model",
|
||||
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
dest_path=model_location,
|
||||
)
|
||||
|
||||
model = load_jit_model(model_location, device)
|
||||
loaded_model = self._context.models.load_ckpt_from_url(
|
||||
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
loader=lambda path: load_jit_model(path, "cpu"),
|
||||
)
|
||||
|
||||
image = np.asarray(input_image.convert("RGB"))
|
||||
image = norm_img(image)
|
||||
@ -48,20 +41,18 @@ class LaMA:
|
||||
mask = np.asarray(mask)
|
||||
mask = np.invert(mask)
|
||||
mask = norm_img(mask)
|
||||
|
||||
mask = (mask > 0) * 1
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
infilled_image = model(image, mask)
|
||||
with loaded_model as model:
|
||||
device = next(model.buffers()).device
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
infilled_image = model(image, mask)
|
||||
|
||||
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
|
||||
infilled_image = Image.fromarray(infilled_image)
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return infilled_image
|
||||
|
Reference in New Issue
Block a user