mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Address change requests in first round of PR reviews.
Pending: - Move model install calls into model manager and create passthrus in invocation_context. - Consider splitting load_model_from_url() into a call to get the path and a call to load the path.
This commit is contained in:
@ -11,7 +11,6 @@ from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, WithBoard, WithMetadata
|
||||
@ -96,22 +95,21 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
source=ESRGAN_MODEL_URLS[self.model_name],
|
||||
)
|
||||
|
||||
upscaler = RealESRGAN(
|
||||
scale=netscale,
|
||||
loadnet=loadnet.model,
|
||||
model=rrdbnet_model,
|
||||
half=False,
|
||||
tile=self.tile_size,
|
||||
)
|
||||
with loadnet as loadnet_model:
|
||||
upscaler = RealESRGAN(
|
||||
scale=netscale,
|
||||
loadnet=loadnet_model,
|
||||
model=rrdbnet_model,
|
||||
half=False,
|
||||
tile=self.tile_size,
|
||||
)
|
||||
|
||||
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
||||
# TODO: This strips the alpha... is that okay?
|
||||
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
||||
upscaled_image = upscaler.upscale(cv2_image)
|
||||
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
||||
# TODO: This strips the alpha... is that okay?
|
||||
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
||||
upscaled_image = upscaler.upscale(cv2_image)
|
||||
|
||||
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
|
||||
|
Reference in New Issue
Block a user