Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
feat(backend): lift managed model loading out of lama class
This commit is contained in:
parent 57c831442e
commit fcb071f30c
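In short: LaMA previously reached through the invocation context's protected _services attribute to grab the model manager and loaded its checkpoint on every call. The invocation now loads the managed checkpoint itself via context.models.load_ckpt_from_url and hands the already-loaded torch module to LaMA, with load_jit_model moved onto the class as a staticmethod so it can serve as the loader callback.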
@@ -133,10 +133,12 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""

     def infill(self, image: Image.Image, context: InvocationContext):
-        # Note that this accesses a protected attribute to get to the model manager service.
-        # Is there a better way?
-        lama = LaMA(context._services.model_manager)
-        return lama(image)
+        with context.models.load_ckpt_from_url(
+            source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
+            loader=LaMA.load_jit_model,
+        ) as model:
+            lama = LaMA(model)
+            return lama(image)


 @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
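For readers unfamiliar with the pattern, here is a minimal sketch of the context-manager contract the new infill() relies on. fake_load_ckpt_from_url is a hypothetical stand-in, not the real model-manager API; it only illustrates that the loaded model is valid inside the with block and released afterwards:

from contextlib import contextmanager

@contextmanager
def fake_load_ckpt_from_url(source: str, loader):
    # Hypothetical stand-in for context.models.load_ckpt_from_url: the real
    # service downloads/caches the checkpoint and locks it into memory.
    model = loader(source)
    try:
        yield model  # model is guaranteed loaded for the body of the `with`
    finally:
        pass  # the real manager would unlock/offload the model here

# Usage mirrors the new infill() body:
# with fake_load_ckpt_from_url(source=..., loader=LaMA.load_jit_model) as model:
#     lama = LaMA(model)
#     return lama(image)

This keeps the load/release lifecycle in the invocation layer, so LaMA itself stays a thin wrapper around an already-resident module.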
@@ -1,13 +1,12 @@
-from typing import TYPE_CHECKING, Any
+from pathlib import Path
+from typing import Any

 import numpy as np
 import torch
 from PIL import Image

 import invokeai.backend.util.logging as logger
-
-if TYPE_CHECKING:
-    from invokeai.app.services.model_manager import ModelManagerServiceBase
+from invokeai.backend.model_manager.config import AnyModel


 def norm_img(np_img):
@@ -18,24 +17,11 @@ def norm_img(np_img):
     return np_img


-def load_jit_model(url_or_path, device) -> torch.nn.Module:
-    model_path = url_or_path
-    logger.info(f"Loading model from: {model_path}")
-    model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device)  # type: ignore
-    model.eval()
-    return model
-
-
 class LaMA:
-    def __init__(self, model_manager: "ModelManagerServiceBase"):
-        self._model_manager = model_manager
+    def __init__(self, model: AnyModel):
+        self._model = model

     def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
-        loaded_model = self._model_manager.load_ckpt_from_url(
-            source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
-            loader=lambda path: load_jit_model(path, "cpu"),
-        )
-
         image = np.asarray(input_image.convert("RGB"))
         image = norm_img(image)

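With the constructor now taking a loaded model rather than the model manager, the class can be exercised in isolation. A minimal sketch, assuming LaMA is importable from its backend module; the import path and both file paths below are assumptions, not part of this commit:

from PIL import Image

from invokeai.backend.image_util.infill_methods.lama import LaMA  # module path assumed

# Load the TorchScript checkpoint once, outside any service (local path is hypothetical).
model = LaMA.load_jit_model("/models/big-lama.pt", device="cpu")
lama = LaMA(model)

# Any RGBA image whose transparent pixels should be infilled.
image = Image.open("photo_with_holes.png")
infilled = lama(image)
infilled.save("infilled.png")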
@@ -45,16 +31,23 @@ class LaMA:
         mask = norm_img(mask)
         mask = (mask > 0) * 1

-        with loaded_model as model:
-            device = next(model.buffers()).device
-            image = torch.from_numpy(image).unsqueeze(0).to(device)
-            mask = torch.from_numpy(mask).unsqueeze(0).to(device)
+        device = next(self._model.buffers()).device
+        image = torch.from_numpy(image).unsqueeze(0).to(device)
+        mask = torch.from_numpy(mask).unsqueeze(0).to(device)

-            with torch.inference_mode():
-                infilled_image = model(image, mask)
+        with torch.inference_mode():
+            infilled_image = self._model(image, mask)

         infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
         infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
         infilled_image = Image.fromarray(infilled_image)

         return infilled_image
+
+    @staticmethod
+    def load_jit_model(url_or_path: str | Path, device: torch.device | str = "cpu") -> torch.nn.Module:
+        model_path = url_or_path
+        logger.info(f"Loading model from: {model_path}")
+        model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device)  # type: ignore
+        model.eval()
+        return model
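The staticmethod's signature matches the loader callback the invocation passes to load_ckpt_from_url: it receives the resolved checkpoint path and returns an eval-mode torch.nn.Module. A self-contained sketch of the same torch.jit round trip, using a tiny scripted module as a stand-in for big-lama.pt (file path hypothetical):

import torch

# Stand-in checkpoint: script a trivial module and save it as TorchScript.
torch.jit.save(torch.jit.script(torch.nn.Linear(4, 4)), "/tmp/fake-lama.pt")

# Equivalent of LaMA.load_jit_model("/tmp/fake-lama.pt"): load onto CPU, move to
# the target device, and switch to eval mode before inference.
model: torch.nn.Module = torch.jit.load("/tmp/fake-lama.pt", map_location="cpu").to("cpu")
model.eval()
assert not model.training

In __call__, the target device is then discovered from the loaded module's own buffers (next(self._model.buffers()).device) rather than tracked separately.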