diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index e9ab2639de..8d14c0a8fe 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -9,6 +9,7 @@ from PIL import Image, ImageOps from invokeai.app.invocations.fields import ColorField, ImageField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.app.util.misc import SEED_MAX from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint from invokeai.backend.image_util.lama import LaMA @@ -217,6 +218,13 @@ class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) + # Downloads the LaMa model if it doesn't already exist + download_with_progress_bar( + name="LaMa Inpainting Model", + url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", + dest_path=context.config.get().models_path / "core/misc/lama/lama.pt", + ) + infilled = infill_lama(image.copy()) image_dto = context.images.save(image=infilled) diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index b6d127118d..77c0be858f 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -11,6 +11,7 @@ from pydantic import ConfigDict from invokeai.app.invocations.fields import ImageField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.util.devices import choose_torch_device @@ -27,6 +28,13 @@ ESRGAN_MODELS = Literal[ "RealESRGAN_x2plus.pth", ] +ESRGAN_MODEL_URLS: dict[str, str] = { + "RealESRGAN_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + "RealESRGAN_x4plus_anime_6B.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", + "ESRGAN_SRx4_DF2KOST_official.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", + "RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", +} + if choose_torch_device() == torch.device("mps"): from torch import mps @@ -45,7 +53,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.images.get_pil(self.image.image_name) - models_path = context.config.get().models_path rrdbnet_model = None netscale = None @@ -92,11 +99,16 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): context.logger.error(msg) raise ValueError(msg) - esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}") + esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}") + + # Downloads the ESRGAN model if it doesn't already exist + download_with_progress_bar( + name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path + ) upscaler = RealESRGAN( scale=netscale, - model_path=models_path / esrgan_model_path, + model_path=esrgan_model_path, model=rrdbnet_model, half=False, tile=self.tile_size, diff --git a/invokeai/app/util/download_with_progress.py b/invokeai/app/util/download_with_progress.py new file mode 100644 index 0000000000..97a2abb2f6 --- /dev/null +++ b/invokeai/app/util/download_with_progress.py @@ -0,0 +1,51 @@ +from pathlib import Path +from urllib import request + +from tqdm import tqdm + +from invokeai.backend.util.logging import InvokeAILogger + + +class ProgressBar: + """Simple progress bar for urllib.request.urlretrieve using tqdm.""" + + def __init__(self, model_name: str = "file"): + self.pbar = None + self.name = model_name + + def __call__(self, block_num: int, block_size: int, total_size: int): + if not self.pbar: + self.pbar = tqdm( + desc=self.name, + initial=0, + unit="iB", + unit_scale=True, + unit_divisor=1000, + total=total_size, + ) + self.pbar.update(block_size) + + +def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool: + """Download a file from a URL to a destination path, with a progress bar. + If the file already exists, it will not be downloaded again. + + Exceptions are not caught. + + Args: + name (str): Name of the file being downloaded. + url (str): URL to download the file from. + dest_path (Path): Destination path to save the file to. + + Returns: + bool: True if the file was downloaded, False if it already existed. + """ + if dest_path.exists(): + return False # already downloaded + + InvokeAILogger.get_logger().info(f"Downloading {name}...") + + dest_path.parent.mkdir(parents=True, exist_ok=True) + request.urlretrieve(url, dest_path, ProgressBar(name)) + + return True