wip upscale node

This commit is contained in:
psychedelicious 2023-07-15 21:13:44 +10:00
parent ee7d700ae4
commit 3aca35c932

View File

@ -65,9 +65,7 @@ class UpscaleInvocation(BaseInvocation):
REALESRGAN_MODELS = Literal[
"RealESRGAN_x4plus",
"RealESRNet_x4plus",
"RealESRGAN_x4plus_anime_6B",
"RealESRGAN_x2plus",
"ESRGAN_SRx4_DF2KOST_official-ff704c30",
]
@ -79,10 +77,7 @@ class RealESRGANInvocation(BaseInvocation):
type: Literal["realesrgan"] = "realesrgan"
image: Union[ImageField, None] = Field(default=None, description="The input image" )
model_name: REALESRGAN_MODELS = Field(default="RealESRGAN_x4plus", description="The Real-ESRGAN model to use")
scale: float = Field(default=4, description="The final upsampling scale")
tile: int = Field(default=400, description="The tile size (px)")
tile_pad: int = Field(default=10, description="The tile padding size (px)")
pre_pad: int = Field(default=0, description="The pre padding size at each border (px)")
scale: Literal[2, 4] = Field(default=4, description="The final upsampling scale")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -91,22 +86,21 @@ class RealESRGANInvocation(BaseInvocation):
netscale = None
model_path = None
if self.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']: # x4 RRDBNet model
if self.model_name == 'RealESRGAN x4 Plus': # x4 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
model_path = f'/home/bat/invokeai/models/upscale/{self.model_name}.pth'
elif self.model_name in ['RealESRGAN_x4plus_anime_6B']: # x4 RRDBNet model with 6 blocks
model_path = f'core/upscaling/realesrgan/RealESRGAN_x4plus.pth'
elif self.model_name == 'RealESRGAN x4 Plus (Anime 6B)': # x4 RRDBNet model with 6 blocks
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
netscale = 4
model_path = f'/home/bat/invokeai/models/upscale/{self.model_name}.pth'
elif self.model_name in ['RealESRGAN_x2plus']: # x2 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
netscale = 2
elif self.model_name in ['ESRGAN_SRx4_DF2KOST_official-ff704c30']: # x2 RRDBNet model
model_path = f'core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth'
# elif self.model_name in ['RealESRGAN_x2plus']: # x2 RRDBNet model
# model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
# netscale = 2
elif self.model_name in ['ESRGAN x4']: # x2 RRDBNet model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
model_path = f'/home/bat/invokeai/models/upscale/{self.model_name}.pth'
model_path = f'core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth'
if not model or not netscale or not model_path:
raise Exception(f"Invalid model {self.model_name}")
@ -115,9 +109,6 @@ class RealESRGANInvocation(BaseInvocation):
scale=netscale,
model_path=model_path,
model=model,
tile=self.tile,
tile_pad=self.tile_pad,
pre_pad=self.pre_pad,
half=False,
)