mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): use correctly-typed configuration service in upscale node
This commit is contained in:
parent
48a031dbaf
commit
5d59dd4b97
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||||
from pathlib import Path
|
from pathlib import Path, PosixPath
|
||||||
from typing import Literal, Union, cast
|
from typing import Literal, Union, cast
|
||||||
|
|
||||||
import cv2 as cv
|
import cv2 as cv
|
||||||
@ -33,12 +33,12 @@ class RealESRGANInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name) # type: ignore
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
models_dir = cast(Path, context.services.configuration.root_dir) / Path("models/") # type: ignore
|
models_path = context.services.configuration.models_path
|
||||||
|
|
||||||
rrdbnet_model = None
|
rrdbnet_model = None
|
||||||
netscale = None
|
netscale = None
|
||||||
model_path = None
|
esrgan_model_path = None
|
||||||
|
|
||||||
if self.model_name in [
|
if self.model_name in [
|
||||||
"RealESRGAN_x4plus.pth",
|
"RealESRGAN_x4plus.pth",
|
||||||
@ -54,13 +54,13 @@ class RealESRGANInvocation(BaseInvocation):
|
|||||||
scale=4,
|
scale=4,
|
||||||
)
|
)
|
||||||
netscale = 4
|
netscale = 4
|
||||||
elif self.model_name == "RealESRGAN_x4plus_anime_6B.pth":
|
elif self.model_name in ["RealESRGAN_x4plus_anime_6B.pth"]:
|
||||||
# x4 RRDBNet model, 6 blocks
|
# x4 RRDBNet model, 6 blocks
|
||||||
rrdbnet_model = RRDBNet(
|
rrdbnet_model = RRDBNet(
|
||||||
num_in_ch=3,
|
num_in_ch=3,
|
||||||
num_out_ch=3,
|
num_out_ch=3,
|
||||||
num_feat=64,
|
num_feat=64,
|
||||||
num_block=6, # 6 blocks
|
num_block=6, # 6 blocks
|
||||||
num_grow_ch=32,
|
num_grow_ch=32,
|
||||||
scale=4,
|
scale=4,
|
||||||
)
|
)
|
||||||
@ -83,11 +83,11 @@ class RealESRGANInvocation(BaseInvocation):
|
|||||||
context.services.logger.error(msg)
|
context.services.logger.error(msg)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
|
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
|
||||||
|
|
||||||
upsampler = RealESRGANer(
|
upsampler = RealESRGANer(
|
||||||
scale=netscale,
|
scale=netscale,
|
||||||
model_path=str(models_dir / model_path),
|
model_path=str(models_path / esrgan_model_path),
|
||||||
model=rrdbnet_model,
|
model=rrdbnet_model,
|
||||||
half=False,
|
half=False,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user