mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add restoration services to nodes
This commit is contained in:
parent
3aa1ee1218
commit
8ca91b1774
@ -5,6 +5,7 @@ from argparse import Namespace
|
|||||||
|
|
||||||
from ...backend import Globals
|
from ...backend import Globals
|
||||||
from ..services.model_manager_initializer import get_model_manager
|
from ..services.model_manager_initializer import get_model_manager
|
||||||
|
from ..services.restoration_services import RestorationServices
|
||||||
from ..services.graph import GraphExecutionState
|
from ..services.graph import GraphExecutionState
|
||||||
from ..services.image_storage import DiskImageStorage
|
from ..services.image_storage import DiskImageStorage
|
||||||
from ..services.invocation_queue import MemoryInvocationQueue
|
from ..services.invocation_queue import MemoryInvocationQueue
|
||||||
@ -67,6 +68,7 @@ class ApiDependencies:
|
|||||||
filename=db_location, table_name="graph_executions"
|
filename=db_location, table_name="graph_executions"
|
||||||
),
|
),
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
|
restoration=RestorationServices(config),
|
||||||
)
|
)
|
||||||
|
|
||||||
ApiDependencies.invoker = Invoker(services)
|
ApiDependencies.invoker = Invoker(services)
|
||||||
|
@ -18,6 +18,7 @@ from .invocations import *
|
|||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
from .services.events import EventServiceBase
|
from .services.events import EventServiceBase
|
||||||
from .services.model_manager_initializer import get_model_manager
|
from .services.model_manager_initializer import get_model_manager
|
||||||
|
from .services.restoration_services import RestorationServices
|
||||||
from .services.graph import EdgeConnection, GraphExecutionState
|
from .services.graph import EdgeConnection, GraphExecutionState
|
||||||
from .services.image_storage import DiskImageStorage
|
from .services.image_storage import DiskImageStorage
|
||||||
from .services.invocation_queue import MemoryInvocationQueue
|
from .services.invocation_queue import MemoryInvocationQueue
|
||||||
@ -148,6 +149,7 @@ def invoke_cli():
|
|||||||
filename=db_location, table_name="graph_executions"
|
filename=db_location, table_name="graph_executions"
|
||||||
),
|
),
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
|
restoration=RestorationServices(config),
|
||||||
)
|
)
|
||||||
|
|
||||||
invoker = Invoker(services)
|
invoker = Invoker(services)
|
||||||
|
@ -8,7 +8,6 @@ from ..services.invocation_services import InvocationServices
|
|||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput
|
from .image import ImageField, ImageOutput
|
||||||
|
|
||||||
|
|
||||||
class RestoreFaceInvocation(BaseInvocation):
|
class RestoreFaceInvocation(BaseInvocation):
|
||||||
"""Restores faces in an image."""
|
"""Restores faces in an image."""
|
||||||
#fmt: off
|
#fmt: off
|
||||||
@ -23,7 +22,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
|||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
results = context.services.generate.upscale_and_reconstruct(
|
results = context.services.restoration.upscale_and_reconstruct(
|
||||||
image_list=[[image, 0]],
|
image_list=[[image, 0]],
|
||||||
upscale=None,
|
upscale=None,
|
||||||
strength=self.strength, # GFPGAN strength
|
strength=self.strength, # GFPGAN strength
|
||||||
|
@ -26,7 +26,7 @@ class UpscaleInvocation(BaseInvocation):
|
|||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
results = context.services.generate.upscale_and_reconstruct(
|
results = context.services.restoration.upscale_and_reconstruct(
|
||||||
image_list=[[image, 0]],
|
image_list=[[image, 0]],
|
||||||
upscale=(self.level, self.strength),
|
upscale=(self.level, self.strength),
|
||||||
strength=0.0, # GFPGAN strength
|
strength=0.0, # GFPGAN strength
|
||||||
|
@ -3,17 +3,18 @@ from invokeai.backend import ModelManager
|
|||||||
|
|
||||||
from .events import EventServiceBase
|
from .events import EventServiceBase
|
||||||
from .image_storage import ImageStorageBase
|
from .image_storage import ImageStorageBase
|
||||||
|
from .restoration_services import RestorationServices
|
||||||
from .invocation_queue import InvocationQueueABC
|
from .invocation_queue import InvocationQueueABC
|
||||||
from .item_storage import ItemStorageABC
|
from .item_storage import ItemStorageABC
|
||||||
|
|
||||||
|
|
||||||
class InvocationServices:
|
class InvocationServices:
|
||||||
"""Services that can be used by invocations"""
|
"""Services that can be used by invocations"""
|
||||||
|
|
||||||
model_manager: ModelManager
|
|
||||||
events: EventServiceBase
|
events: EventServiceBase
|
||||||
images: ImageStorageBase
|
images: ImageStorageBase
|
||||||
queue: InvocationQueueABC
|
queue: InvocationQueueABC
|
||||||
|
model_manager: ModelManager
|
||||||
|
restoration: RestorationServices
|
||||||
|
|
||||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||||
@ -27,6 +28,7 @@ class InvocationServices:
|
|||||||
queue: InvocationQueueABC,
|
queue: InvocationQueueABC,
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
|
restoration: RestorationServices,
|
||||||
):
|
):
|
||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
self.events = events
|
self.events = events
|
||||||
@ -34,3 +36,4 @@ class InvocationServices:
|
|||||||
self.queue = queue
|
self.queue = queue
|
||||||
self.graph_execution_manager = graph_execution_manager
|
self.graph_execution_manager = graph_execution_manager
|
||||||
self.processor = processor
|
self.processor = processor
|
||||||
|
self.restoration = restoration
|
||||||
|
109
invokeai/app/services/restoration_services.py
Normal file
109
invokeai/app/services/restoration_services.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import torch
|
||||||
|
from ...backend.restoration import Restoration
|
||||||
|
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||||
|
|
||||||
|
# This should be a real base class for postprocessing functions,
|
||||||
|
# but right now we just instantiate the existing gfpgan, esrgan
|
||||||
|
# and codeformer functions.
|
||||||
|
class RestorationServices:
|
||||||
|
'''Face restoration and upscaling'''
|
||||||
|
|
||||||
|
def __init__(self,args):
|
||||||
|
try:
|
||||||
|
gfpgan, codeformer, esrgan = None, None, None
|
||||||
|
if args.restore or args.esrgan:
|
||||||
|
restoration = Restoration()
|
||||||
|
if args.restore:
|
||||||
|
gfpgan, codeformer = restoration.load_face_restore_models(
|
||||||
|
args.gfpgan_model_path
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(">> Face restoration disabled")
|
||||||
|
if args.esrgan:
|
||||||
|
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
||||||
|
else:
|
||||||
|
print(">> Upscaling disabled")
|
||||||
|
else:
|
||||||
|
print(">> Face restoration and upscaling disabled")
|
||||||
|
except (ModuleNotFoundError, ImportError):
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
||||||
|
self.device = torch.device(choose_torch_device())
|
||||||
|
self.gfpgan = gfpgan
|
||||||
|
self.codeformer = codeformer
|
||||||
|
self.esrgan = esrgan
|
||||||
|
|
||||||
|
# note that this one method does gfpgan and codepath reconstruction, as well as
|
||||||
|
# esrgan upscaling
|
||||||
|
# TO DO: refactor into separate methods
|
||||||
|
def upscale_and_reconstruct(
|
||||||
|
self,
|
||||||
|
image_list,
|
||||||
|
facetool="gfpgan",
|
||||||
|
upscale=None,
|
||||||
|
upscale_denoise_str=0.75,
|
||||||
|
strength=0.0,
|
||||||
|
codeformer_fidelity=0.75,
|
||||||
|
save_original=False,
|
||||||
|
image_callback=None,
|
||||||
|
prefix=None,
|
||||||
|
):
|
||||||
|
results = []
|
||||||
|
for r in image_list:
|
||||||
|
image, seed = r
|
||||||
|
try:
|
||||||
|
if strength > 0:
|
||||||
|
if self.gfpgan is not None or self.codeformer is not None:
|
||||||
|
if facetool == "gfpgan":
|
||||||
|
if self.gfpgan is None:
|
||||||
|
print(
|
||||||
|
">> GFPGAN not found. Face restoration is disabled."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
image = self.gfpgan.process(image, strength, seed)
|
||||||
|
if facetool == "codeformer":
|
||||||
|
if self.codeformer is None:
|
||||||
|
print(
|
||||||
|
">> CodeFormer not found. Face restoration is disabled."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cf_device = (
|
||||||
|
CPU_DEVICE if self.device == MPS_DEVICE else self.device
|
||||||
|
)
|
||||||
|
image = self.codeformer.process(
|
||||||
|
image=image,
|
||||||
|
strength=strength,
|
||||||
|
device=cf_device,
|
||||||
|
seed=seed,
|
||||||
|
fidelity=codeformer_fidelity,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(">> Face Restoration is disabled.")
|
||||||
|
if upscale is not None:
|
||||||
|
if self.esrgan is not None:
|
||||||
|
if len(upscale) < 2:
|
||||||
|
upscale.append(0.75)
|
||||||
|
image = self.esrgan.process(
|
||||||
|
image,
|
||||||
|
upscale[1],
|
||||||
|
seed,
|
||||||
|
int(upscale[0]),
|
||||||
|
denoise_str=upscale_denoise_str,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if image_callback is not None:
|
||||||
|
image_callback(image, seed, upscaled=True, use_prefix=prefix)
|
||||||
|
else:
|
||||||
|
r[0] = image
|
||||||
|
|
||||||
|
results.append([image, seed])
|
||||||
|
|
||||||
|
return results
|
@ -956,7 +956,7 @@ class Generate:
|
|||||||
):
|
):
|
||||||
results = []
|
results = []
|
||||||
for r in image_list:
|
for r in image_list:
|
||||||
image, seed = r
|
image, seed, _ = r
|
||||||
try:
|
try:
|
||||||
if strength > 0:
|
if strength > 0:
|
||||||
if self.gfpgan is not None or self.codeformer is not None:
|
if self.gfpgan is not None or self.codeformer is not None:
|
||||||
|
@ -26,7 +26,8 @@ def mock_services():
|
|||||||
images = None,
|
images = None,
|
||||||
queue = MemoryInvocationQueue(),
|
queue = MemoryInvocationQueue(),
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
processor = DefaultInvocationProcessor()
|
processor = DefaultInvocationProcessor(),
|
||||||
|
restoration = None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
|
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
|
||||||
|
@ -26,7 +26,8 @@ def mock_services() -> InvocationServices:
|
|||||||
images = None, # type: ignore
|
images = None, # type: ignore
|
||||||
queue = MemoryInvocationQueue(),
|
queue = MemoryInvocationQueue(),
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||||
processor = DefaultInvocationProcessor()
|
processor = DefaultInvocationProcessor(),
|
||||||
|
restoration = None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
|
Loading…
Reference in New Issue
Block a user