InvokeAI/invokeai/backend/image_util/infill_methods/lama.py
Lincoln Stein e93f4d632d
[util] Add generic torch device class (#6174)
* introduce new abstraction layer for GPU devices

* add unit test for device abstraction

* fix ruff

* convert TorchDeviceSelect into a stateless class

* move logic to select context-specific execution device into context API

* add mock hardware environments to pytest

* remove dangling mocker fixture

* fix unit test for running on non-CUDA systems

* remove unimplemented get_execution_device() call

* remove autocast precision

* Multiple changes:

1. Remove TorchDeviceSelect.get_execution_device(), as well as calls to
   context.models.get_execution_device().
2. Rename TorchDeviceSelect to TorchDevice
3. Added back the legacy public API defined in `invocation_api`, including
   choose_precision().
4. Added a config file migration script to accommodate removal of precision=autocast.

* add deprecation warnings to choose_torch_device() and choose_precision()

* fix test crash

* remove app_config argument from choose_torch_device() and choose_torch_dtype()

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
2024-04-15 13:12:49 +00:00

68 lines
2.1 KiB
Python

import gc
from typing import Any
import numpy as np
import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice
def norm_img(np_img):
    """Rearrange an image array to channels-first float32 and divide it by 255.

    A 2-D array (no channel axis) first gains a trailing channel axis, so the
    result always has three dimensions ordered (channel, row, column).
    """
    needs_channel_axis = len(np_img.shape) == 2
    if needs_channel_axis:
        np_img = np_img[..., np.newaxis]

    # Move the trailing channel axis to the front: (H, W, C) -> (C, H, W).
    channels_first = np.transpose(np_img, (2, 0, 1))
    return channels_first.astype("float32") / 255
def load_jit_model(url_or_path, device):
    """Load a TorchScript model and return it on ``device`` in eval mode.

    Args:
        url_or_path: Location handed straight to ``torch.jit.load``; despite
            the parameter name, nothing is downloaded here.
        device: Target device the loaded module is moved to.

    Returns:
        The loaded module, moved to ``device`` and switched to eval mode.
    """
    logger.info(f"Loading model from: {url_or_path}")

    # Deserialize onto the CPU first, then move the module to the target device.
    cpu_module = torch.jit.load(url_or_path, map_location="cpu")
    jit_module = cpu_module.to(device)
    jit_module.eval()
    return jit_module
class LaMA:
    """Callable that infills an image with the LaMa TorchScript inpainting model."""

    def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
        """Infill ``input_image`` and return the result as a new PIL image.

        The model file is looked up under the configured models directory and
        downloaded if it is not present. The model is loaded for this call only
        and released again before returning, including when an error is raised.

        Args:
            input_image: Image to infill. Its last band is used to build the
                mask: for 8-bit bands, every pixel whose value is below 255 is
                marked for infilling.
                NOTE(review): this presumably expects an RGBA image whose last
                band is alpha -- confirm against callers.
            *args: Accepted for call compatibility; ignored.
            **kwds: Accepted for call compatibility; ignored.

        Returns:
            A PIL image built from the model output, scaled by 255 and clipped
            to the uint8 range.
        """
        device = TorchDevice.choose_torch_device()
        model_location = get_config().models_path / "core/misc/lama/lama.pt"

        if not model_location.exists():
            download_with_progress_bar(
                name="LaMa Inpainting Model",
                url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
                dest_path=model_location,
            )

        model = load_jit_model(model_location, device)
        try:
            image = norm_img(np.asarray(input_image.convert("RGB")))

            # Invert the last band so that non-opaque pixels become non-zero,
            # then binarize: 1 = infill this pixel, 0 = keep it.
            mask = np.asarray(input_image.split()[-1])
            mask = np.invert(mask)
            mask = norm_img(mask)
            mask = (mask > 0) * 1

            # Add a leading batch axis and move both inputs to the model's device.
            image = torch.from_numpy(image).unsqueeze(0).to(device)
            mask = torch.from_numpy(mask).unsqueeze(0).to(device)

            with torch.inference_mode():
                infilled_image = model(image, mask)

            # Rebind the same name so the device-side output tensor is no longer
            # referenced by the time the cache is emptied below.
            infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
            infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
            return Image.fromarray(infilled_image)
        finally:
            # Always drop the model and return cached device memory, even when
            # preprocessing or inference fails part-way through.
            del model
            gc.collect()
            torch.cuda.empty_cache()