feat(backend): lift managed model loading out of lama class

This commit is contained in:
psychedelicious 2024-04-29 08:12:51 +10:00
parent 57c831442e
commit fcb071f30c
2 changed files with 24 additions and 29 deletions

View File

@ -133,9 +133,11 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model""" """Infills transparent areas of an image using the LaMa model"""
def infill(self, image: Image.Image, context: InvocationContext): def infill(self, image: Image.Image, context: InvocationContext):
# Note that this accesses a protected attribute to get to the model manager service. with context.models.load_ckpt_from_url(
# Is there a better way? source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
lama = LaMA(context._services.model_manager) loader=LaMA.load_jit_model,
) as model:
lama = LaMA(model)
return lama(image) return lama(image)

View File

@ -1,13 +1,12 @@
from typing import TYPE_CHECKING, Any from pathlib import Path
from typing import Any
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager.config import AnyModel
if TYPE_CHECKING:
from invokeai.app.services.model_manager import ModelManagerServiceBase
def norm_img(np_img): def norm_img(np_img):
@ -18,24 +17,11 @@ def norm_img(np_img):
return np_img return np_img
def load_jit_model(url_or_path, device) -> torch.nn.Module:
model_path = url_or_path
logger.info(f"Loading model from: {model_path}")
model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore
model.eval()
return model
class LaMA: class LaMA:
def __init__(self, model_manager: "ModelManagerServiceBase"): def __init__(self, model: AnyModel):
self._model_manager = model_manager self._model = model
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
loaded_model = self._model_manager.load_ckpt_from_url(
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
loader=lambda path: load_jit_model(path, "cpu"),
)
image = np.asarray(input_image.convert("RGB")) image = np.asarray(input_image.convert("RGB"))
image = norm_img(image) image = norm_img(image)
@ -45,16 +31,23 @@ class LaMA:
mask = norm_img(mask) mask = norm_img(mask)
mask = (mask > 0) * 1 mask = (mask > 0) * 1
with loaded_model as model: device = next(self._model.buffers()).device
device = next(model.buffers()).device
image = torch.from_numpy(image).unsqueeze(0).to(device) image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device) mask = torch.from_numpy(mask).unsqueeze(0).to(device)
with torch.inference_mode(): with torch.inference_mode():
infilled_image = model(image, mask) infilled_image = self._model(image, mask)
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy() infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8") infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
infilled_image = Image.fromarray(infilled_image) infilled_image = Image.fromarray(infilled_image)
return infilled_image return infilled_image
@staticmethod
def load_jit_model(url_or_path: str | Path, device: torch.device | str = "cpu") -> torch.nn.Module:
model_path = url_or_path
logger.info(f"Loading model from: {model_path}")
model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore
model.eval()
return model