2023-08-23 19:25:24 +00:00
|
|
|
import gc
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
from PIL import Image
|
|
|
|
|
2023-08-31 21:17:41 +00:00
|
|
|
import invokeai.backend.util.logging as logger
|
2024-03-11 12:01:48 +00:00
|
|
|
from invokeai.app.services.config.config_default import get_config
|
2024-03-20 03:17:16 +00:00
|
|
|
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
2024-04-15 13:12:49 +00:00
|
|
|
from invokeai.backend.util.devices import TorchDevice
|
2023-08-23 19:25:24 +00:00
|
|
|
|
|
|
|
|
|
|
|
def norm_img(np_img):
    """Return *np_img* as a channel-first float32 array divided by 255.

    A 2-D (H, W) input is treated as a single-channel image and gets a
    trailing channel axis before the (H, W, C) -> (C, H, W) transpose.
    Values are presumably in the 0-255 range (e.g. uint8 pixel data), so
    the result lands in [0, 1] — confirm at the call site.
    """
    # Promote grayscale (H, W) to (H, W, 1) so the transpose below is uniform.
    arr = np_img[..., None] if np_img.ndim == 2 else np_img
    channels_first = arr.transpose(2, 0, 1)
    return channels_first.astype(np.float32) / 255
|
|
|
|
|
|
|
|
|
|
|
|
def load_jit_model(url_or_path, device):
    """Load a TorchScript model and return it on *device* in eval mode.

    Despite its name, *url_or_path* is handed directly to ``torch.jit.load``;
    nothing is downloaded here, so it must already be a loadable local path.
    The model is deserialized onto the CPU first and only then moved to
    *device*.
    """
    logger.info(f"Loading model from: {url_or_path}")
    # Deserialize on CPU so loading does not depend on the device the
    # checkpoint was saved from; move to the requested device afterwards.
    jit_model = torch.jit.load(url_or_path, map_location="cpu")
    jit_model = jit_model.to(device)
    jit_model.eval()
    return jit_model
|
|
|
|
|
|
|
|
|
|
|
|
class LaMA:
    """Infill the non-opaque regions of an image with the Big-LaMa TorchScript model.

    The weights are downloaded to the configured models directory on first use.
    The model is loaded for every call and released again afterwards.
    """

    def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
        """Return a new image in which the non-opaque areas of *input_image* are infilled.

        Args:
            input_image: Source image. Its alpha channel defines the mask: pixels
                with alpha 255 are unmasked, and every pixel with alpha < 255 is
                marked for the model to fill. Images without an alpha channel are
                treated as fully opaque.
            *args: Accepted for call-signature compatibility; ignored.
            **kwds: Accepted for call-signature compatibility; ignored.

        Returns:
            A ``PIL.Image.Image`` built from the model output (converted to uint8).
        """
        device = TorchDevice.choose_torch_device()
        model_location = get_config().models_path / "core/misc/lama/lama.pt"

        # Fetch the weights on first use; later calls reuse the file on disk.
        if not model_location.exists():
            download_with_progress_bar(
                name="LaMa Inpainting Model",
                url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
                dest_path=model_location,
            )

        model = load_jit_model(model_location, device)
        try:
            return self._infill(model, input_image, device)
        finally:
            # Runs even if inference raises, so cleanup is never skipped.
            # Because _infill has returned, its device-side input tensors are
            # already unreferenced when the allocator cache is flushed.
            del model
            gc.collect()
            torch.cuda.empty_cache()

    @staticmethod
    def _infill(model: torch.jit.ScriptModule, input_image: Image.Image, device) -> Image.Image:
        """Run *model* on *input_image* (mask taken from its alpha) and return a PIL image."""
        image = norm_img(np.asarray(input_image.convert("RGB")))

        # Normalise to RGBA so the mask always comes from a real alpha band;
        # ``split()[-1]`` on an RGB/L/P image would pick a colour, luma or
        # palette band instead. For RGBA inputs this is identical to before.
        alpha = np.asarray(input_image.convert("RGBA").getchannel("A"))
        # Invert alpha: opaque (255) -> 0 (keep), anything less -> > 0 (fill).
        mask = norm_img(np.invert(alpha))
        mask = (mask > 0) * 1

        image_t = torch.from_numpy(image).unsqueeze(0).to(device)
        mask_t = torch.from_numpy(mask).unsqueeze(0).to(device)

        with torch.inference_mode():
            output = model(image_t, mask_t)

        # (1, C, H, W) float in [0, 1] presumably -> (H, W, C) uint8 for PIL.
        result = output[0].permute(1, 2, 0).detach().cpu().numpy()
        result = np.clip(result * 255, 0, 255).astype("uint8")
        return Image.fromarray(result)
|