feat(nodes): use correctly-typed configuration service in upscale node

This commit is contained in:
psychedelicious 2023-07-16 10:54:52 +10:00
parent 48a031dbaf
commit 5d59dd4b97

View File

@ -1,5 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path
from pathlib import Path, PosixPath
from typing import Literal, Union, cast
import cv2 as cv
@ -33,12 +33,12 @@ class RealESRGANInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) # type: ignore
models_dir = cast(Path, context.services.configuration.root_dir) / Path("models/") # type: ignore
image = context.services.images.get_pil_image(self.image.image_name)
models_path = context.services.configuration.models_path
rrdbnet_model = None
netscale = None
model_path = None
esrgan_model_path = None
if self.model_name in [
"RealESRGAN_x4plus.pth",
@ -54,13 +54,13 @@ class RealESRGANInvocation(BaseInvocation):
scale=4,
)
netscale = 4
elif self.model_name == "RealESRGAN_x4plus_anime_6B.pth":
elif self.model_name in ["RealESRGAN_x4plus_anime_6B.pth"]:
# x4 RRDBNet model, 6 blocks
rrdbnet_model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=6, # 6 blocks
num_block=6, # 6 blocks
num_grow_ch=32,
scale=4,
)
@ -83,11 +83,11 @@ class RealESRGANInvocation(BaseInvocation):
context.services.logger.error(msg)
raise ValueError(msg)
model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
upsampler = RealESRGANer(
scale=netscale,
model_path=str(models_dir / model_path),
model_path=str(models_path / esrgan_model_path),
model=rrdbnet_model,
half=False,
)