mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add invocation_context.load_ckpt_from_url() method
This commit is contained in:
@ -1,5 +1,4 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import cv2
|
||||
@ -11,7 +10,6 @@ from pydantic import ConfigDict
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
@ -56,7 +54,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
rrdbnet_model = None
|
||||
netscale = None
|
||||
esrgan_model_path = None
|
||||
|
||||
if self.model_name in [
|
||||
"RealESRGAN_x4plus.pth",
|
||||
@ -99,16 +96,13 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
context.logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
|
||||
|
||||
# Downloads the ESRGAN model if it doesn't already exist
|
||||
download_with_progress_bar(
|
||||
name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
|
||||
loadnet = context.models.load_ckpt_from_url(
|
||||
source=ESRGAN_MODEL_URLS[self.model_name],
|
||||
)
|
||||
|
||||
upscaler = RealESRGAN(
|
||||
scale=netscale,
|
||||
model_path=esrgan_model_path,
|
||||
loadnet=loadnet.model,
|
||||
model=rrdbnet_model,
|
||||
half=False,
|
||||
tile=self.tile_size,
|
||||
@ -118,6 +112,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# TODO: This strips the alpha... is that okay?
|
||||
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
||||
upscaled_image = upscaler.upscale(cv2_image)
|
||||
|
||||
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
Reference in New Issue
Block a user