mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
experimental: LaMa Infill
This commit is contained in:
committed by
Kent Keirsey
parent
7bb876a79b
commit
49892faee4
56
invokeai/backend/image_util/lama.py
Normal file
56
invokeai/backend/image_util/lama.py
Normal file
@ -0,0 +1,56 @@
|
||||
import gc
|
||||
import pathlib
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
|
||||
def norm_img(np_img):
|
||||
if len(np_img.shape) == 2:
|
||||
np_img = np_img[:, :, np.newaxis]
|
||||
np_img = np.transpose(np_img, (2, 0, 1))
|
||||
np_img = np_img.astype("float32") / 255
|
||||
return np_img
|
||||
|
||||
|
||||
def load_jit_model(url_or_path, device):
|
||||
model_path = url_or_path
|
||||
print(f"Loading model from: {model_path}")
|
||||
model = torch.jit.load(model_path, map_location="cpu").to(device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class LaMA:
|
||||
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
||||
device = choose_torch_device()
|
||||
model_location = get_invokeai_config().models_path / 'core/misc/lama/lama.pt'
|
||||
model = load_jit_model(model_location, device)
|
||||
|
||||
image = np.asarray(input_image.convert("RGB"))
|
||||
image = norm_img(image)
|
||||
|
||||
mask = input_image.split()[-1]
|
||||
mask = np.asarray(mask)
|
||||
mask = np.invert(mask)
|
||||
mask = norm_img(mask)
|
||||
|
||||
mask = (mask > 0) * 1
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||
|
||||
infilled_image = model(image, mask)
|
||||
|
||||
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
|
||||
infilled_image = Image.fromarray(infilled_image)
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
|
||||
return infilled_image
|
Reference in New Issue
Block a user