mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): download upscaling & lama models as they are requested
This commit is contained in:
parent
97f16b2b7e
commit
fabef8b45b
@ -9,6 +9,7 @@ from PIL import Image, ImageOps
|
||||
from invokeai.app.invocations.fields import ColorField, ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
@ -217,6 +218,13 @@ class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Downloads the LaMa model if it doesn't already exist
|
||||
download_with_progress_bar(
|
||||
name="LaMa Inpainting Model",
|
||||
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
dest_path=context.config.get().models_path / "core/misc/lama/lama.pt",
|
||||
)
|
||||
|
||||
infilled = infill_lama(image.copy())
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
@ -11,6 +11,7 @@ from pydantic import ConfigDict
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
@ -27,6 +28,13 @@ ESRGAN_MODELS = Literal[
|
||||
"RealESRGAN_x2plus.pth",
|
||||
]
|
||||
|
||||
ESRGAN_MODEL_URLS: dict[str, str] = {
|
||||
"RealESRGAN_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
"RealESRGAN_x4plus_anime_6B.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
"ESRGAN_SRx4_DF2KOST_official.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
"RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
}
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
@ -45,7 +53,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
models_path = context.config.get().models_path
|
||||
|
||||
rrdbnet_model = None
|
||||
netscale = None
|
||||
@ -92,11 +99,16 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
context.logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
|
||||
esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
|
||||
|
||||
# Downloads the ESRGAN model if it doesn't already exist
|
||||
download_with_progress_bar(
|
||||
name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
|
||||
)
|
||||
|
||||
upscaler = RealESRGAN(
|
||||
scale=netscale,
|
||||
model_path=models_path / esrgan_model_path,
|
||||
model_path=esrgan_model_path,
|
||||
model=rrdbnet_model,
|
||||
half=False,
|
||||
tile=self.tile_size,
|
||||
|
51
invokeai/app/util/download_with_progress.py
Normal file
51
invokeai/app/util/download_with_progress.py
Normal file
@ -0,0 +1,51 @@
|
||||
from pathlib import Path
|
||||
from urllib import request
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
class ProgressBar:
|
||||
"""Simple progress bar for urllib.request.urlretrieve using tqdm."""
|
||||
|
||||
def __init__(self, model_name: str = "file"):
|
||||
self.pbar = None
|
||||
self.name = model_name
|
||||
|
||||
def __call__(self, block_num: int, block_size: int, total_size: int):
|
||||
if not self.pbar:
|
||||
self.pbar = tqdm(
|
||||
desc=self.name,
|
||||
initial=0,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
total=total_size,
|
||||
)
|
||||
self.pbar.update(block_size)
|
||||
|
||||
|
||||
def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool:
|
||||
"""Download a file from a URL to a destination path, with a progress bar.
|
||||
If the file already exists, it will not be downloaded again.
|
||||
|
||||
Exceptions are not caught.
|
||||
|
||||
Args:
|
||||
name (str): Name of the file being downloaded.
|
||||
url (str): URL to download the file from.
|
||||
dest_path (Path): Destination path to save the file to.
|
||||
|
||||
Returns:
|
||||
bool: True if the file was downloaded, False if it already existed.
|
||||
"""
|
||||
if dest_path.exists():
|
||||
return False # already downloaded
|
||||
|
||||
InvokeAILogger.get_logger().info(f"Downloading {name}...")
|
||||
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
request.urlretrieve(url, dest_path, ProgressBar(name))
|
||||
|
||||
return True
|
Loading…
Reference in New Issue
Block a user