feat(nodes): use correctly-typed configuration service in upscale node

This commit is contained in:
psychedelicious 2023-07-16 10:54:52 +10:00
parent 48a031dbaf
commit 5d59dd4b97

View File

@ -1,5 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path from pathlib import Path, PosixPath
from typing import Literal, Union, cast from typing import Literal, Union, cast
import cv2 as cv import cv2 as cv
@ -33,12 +33,12 @@ class RealESRGANInvocation(BaseInvocation):
) )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) # type: ignore image = context.services.images.get_pil_image(self.image.image_name)
models_dir = cast(Path, context.services.configuration.root_dir) / Path("models/") # type: ignore models_path = context.services.configuration.models_path
rrdbnet_model = None rrdbnet_model = None
netscale = None netscale = None
model_path = None esrgan_model_path = None
if self.model_name in [ if self.model_name in [
"RealESRGAN_x4plus.pth", "RealESRGAN_x4plus.pth",
@ -54,13 +54,13 @@ class RealESRGANInvocation(BaseInvocation):
scale=4, scale=4,
) )
netscale = 4 netscale = 4
elif self.model_name == "RealESRGAN_x4plus_anime_6B.pth": elif self.model_name in ["RealESRGAN_x4plus_anime_6B.pth"]:
# x4 RRDBNet model, 6 blocks # x4 RRDBNet model, 6 blocks
rrdbnet_model = RRDBNet( rrdbnet_model = RRDBNet(
num_in_ch=3, num_in_ch=3,
num_out_ch=3, num_out_ch=3,
num_feat=64, num_feat=64,
num_block=6, # 6 blocks num_block=6, # 6 blocks
num_grow_ch=32, num_grow_ch=32,
scale=4, scale=4,
) )
@ -83,11 +83,11 @@ class RealESRGANInvocation(BaseInvocation):
context.services.logger.error(msg) context.services.logger.error(msg)
raise ValueError(msg) raise ValueError(msg)
model_path = Path(f"core/upscaling/realesrgan/{self.model_name}") esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
upsampler = RealESRGANer( upsampler = RealESRGANer(
scale=netscale, scale=netscale,
model_path=str(models_dir / model_path), model_path=str(models_path / esrgan_model_path),
model=rrdbnet_model, model=rrdbnet_model,
half=False, half=False,
) )