feat(nodes): remove references to restoration services

- remove restoration services
- remove the restore faces nodes
- update tests
This commit is contained in:
psychedelicious 2023-07-16 01:12:39 +10:00
parent 8a1b9d1001
commit c7b547ea3e
8 changed files with 25 additions and 203 deletions

View File

@ -20,7 +20,6 @@ from invokeai.version.invokeai_version import __version__
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_file_storage import DiskImageFileStorage
from ..services.invocation_queue import MemoryInvocationQueue
@ -58,7 +57,7 @@ class ApiDependencies:
@staticmethod
def initialize(config, event_handler_id: int, logger: Logger = logger):
logger.debug(f'InvokeAI version {__version__}')
logger.debug(f"InvokeAI version {__version__}")
logger.debug(f"Internet connectivity is {config.internet_available}")
events = FastAPIEventService(event_handler_id)
@ -117,7 +116,7 @@ class ApiDependencies:
)
services = InvocationServices(
model_manager=ModelManagerService(config,logger),
model_manager=ModelManagerService(config, logger),
events=events,
latents=latents,
images=images,
@ -129,7 +128,6 @@ class ApiDependencies:
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config, logger),
configuration=config,
logger=logger,
)

View File

@ -54,7 +54,6 @@ from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor
from .services.restoration_services import RestorationServices
from .services.sqlite import SqliteItemStorage
import torch
@ -295,7 +294,6 @@ def invoke_cli():
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger=logger),
logger=logger,
configuration=config,
)

View File

@ -1,55 +0,0 @@
from typing import Literal, Optional
from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
# fmt: off
type: Literal["restore_face"] = "restore_face"
# Inputs
image: Optional[ImageField] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["restoration", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=None,
strength=self.strength, # GFPGAN strength
save_original=False,
image_callback=None,
)
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -10,7 +10,6 @@ if TYPE_CHECKING:
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.restoration_services import RestorationServices
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.config import InvokeAISettings
@ -34,7 +33,6 @@ class InvocationServices:
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC"
queue: "InvocationQueueABC"
restoration: "RestorationServices"
def __init__(
self,
@ -50,7 +48,6 @@ class InvocationServices:
model_manager: "ModelManagerServiceBase",
processor: "InvocationProcessorABC",
queue: "InvocationQueueABC",
restoration: "RestorationServices",
):
self.board_images = board_images
self.boards = boards
@ -65,4 +62,3 @@ class InvocationServices:
self.model_manager = model_manager
self.processor = processor
self.queue = queue
self.restoration = restoration

View File

@ -1,113 +0,0 @@
import sys
import traceback
import torch
from typing import types
from ...backend.restoration import Restoration
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
# This should be a real base class for postprocessing functions,
# but right now we just instantiate the existing gfpgan, esrgan
# and codeformer functions.
class RestorationServices:
'''Face restoration and upscaling'''
def __init__(self,args,logger:types.ModuleType):
try:
gfpgan, codeformer, esrgan = None, None, None
if args.restore or args.esrgan:
restoration = Restoration()
# TODO: redo for new model structure
if False and args.restore:
gfpgan, codeformer = restoration.load_face_restore_models(
args.gfpgan_model_path
)
else:
logger.info("Face restoration disabled")
if False and args.esrgan:
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
else:
logger.info("Upscaling disabled")
else:
logger.info("Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
self.device = torch.device(choose_torch_device())
self.gfpgan = gfpgan
self.codeformer = codeformer
self.esrgan = esrgan
self.logger = logger
self.logger.info('Face restoration initialized')
# note that this one method does gfpgan and codepath reconstruction, as well as
# esrgan upscaling
# TO DO: refactor into separate methods
def upscale_and_reconstruct(
self,
image_list,
facetool="gfpgan",
upscale=None,
upscale_denoise_str=0.75,
strength=0.0,
codeformer_fidelity=0.75,
save_original=False,
image_callback=None,
prefix=None,
):
results = []
for r in image_list:
image, seed = r
try:
if strength > 0:
if self.gfpgan is not None or self.codeformer is not None:
if facetool == "gfpgan":
if self.gfpgan is None:
self.logger.info(
"GFPGAN not found. Face restoration is disabled."
)
else:
image = self.gfpgan.process(image, strength, seed)
if facetool == "codeformer":
if self.codeformer is None:
self.logger.info(
"CodeFormer not found. Face restoration is disabled."
)
else:
cf_device = (
CPU_DEVICE if self.device == MPS_DEVICE else self.device
)
image = self.codeformer.process(
image=image,
strength=strength,
device=cf_device,
seed=seed,
fidelity=codeformer_fidelity,
)
else:
self.logger.info("Face Restoration is disabled.")
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
upscale.append(0.75)
image = self.esrgan.process(
image,
upscale[1],
seed,
int(upscale[0]),
denoise_str=upscale_denoise_str,
)
else:
self.logger.info("ESRGAN is disabled. Image not upscaled.")
except Exception as e:
self.logger.info(
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
)
if image_callback is not None:
image_callback(image, seed, upscaled=True, use_prefix=prefix)
else:
r[0] = image
results.append([image, seed])
return results

View File

@ -55,7 +55,6 @@ def mock_services() -> InvocationServices:
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)

View File

@ -48,7 +48,6 @@ def mock_services() -> InvocationServices:
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)

View File

@ -1,6 +1,6 @@
from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
from invokeai.app.invocations.upscale import UpscaleInvocation
from invokeai.app.invocations.upscale import RealESRGANInvocation
from invokeai.app.invocations.image import *
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
from invokeai.app.invocations.params import ParamIntInvocation
@ -19,7 +19,7 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
def test_connections_are_compatible():
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "image"
to_node = UpscaleInvocation(id = "2")
to_node = RealESRGANInvocation(id = "2")
to_field = "image"
result = are_connections_compatible(from_node, from_field, to_node, to_field)
@ -29,7 +29,7 @@ def test_connections_are_compatible():
def test_connections_are_incompatible():
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "image"
to_node = UpscaleInvocation(id = "2")
to_node = RealESRGANInvocation(id = "2")
to_field = "strength"
result = are_connections_compatible(from_node, from_field, to_node, to_field)
@ -39,7 +39,7 @@ def test_connections_are_incompatible():
def test_connections_incompatible_with_invalid_fields():
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
from_field = "invalid_field"
to_node = UpscaleInvocation(id = "2")
to_node = RealESRGANInvocation(id = "2")
to_field = "image"
# From field is invalid
@ -86,10 +86,10 @@ def test_graph_fails_to_update_node_if_type_changes():
g = Graph()
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n2)
nu = UpscaleInvocation(id = "1")
nu = RealESRGANInvocation(id = "1")
with pytest.raises(TypeError):
g.update_node("1", nu)
@ -98,7 +98,7 @@ def test_graph_allows_non_conflicting_id_change():
g = Graph()
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
g.add_node(n)
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n2)
e1 = create_edge(n.id,"image",n2.id,"image")
g.add_edge(e1)
@ -128,7 +128,7 @@ def test_graph_fails_to_update_node_id_if_conflict():
def test_graph_adds_edge():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")
@ -139,7 +139,7 @@ def test_graph_adds_edge():
def test_graph_fails_to_add_edge_with_cycle():
g = Graph()
n1 = UpscaleInvocation(id = "1")
n1 = RealESRGANInvocation(id = "1")
g.add_node(n1)
e = create_edge(n1.id,"image",n1.id,"image")
with pytest.raises(InvalidEdgeError):
@ -148,8 +148,8 @@ def test_graph_fails_to_add_edge_with_cycle():
def test_graph_fails_to_add_edge_with_long_cycle():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n3 = UpscaleInvocation(id = "3")
n2 = RealESRGANInvocation(id = "2")
n3 = RealESRGANInvocation(id = "3")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
@ -164,7 +164,7 @@ def test_graph_fails_to_add_edge_with_long_cycle():
def test_graph_fails_to_add_edge_with_missing_node_id():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e1 = create_edge("1","image","3","image")
@ -177,8 +177,8 @@ def test_graph_fails_to_add_edge_with_missing_node_id():
def test_graph_fails_to_add_edge_when_destination_exists():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n3 = UpscaleInvocation(id = "3")
n2 = RealESRGANInvocation(id = "2")
n3 = RealESRGANInvocation(id = "3")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
@ -194,7 +194,7 @@ def test_graph_fails_to_add_edge_when_destination_exists():
def test_graph_fails_to_add_edge_with_mismatched_types():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e1 = create_edge("1","image","2","strength")
@ -344,7 +344,7 @@ def test_graph_iterator_invalid_if_output_and_input_types_different():
def test_graph_validates():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e1 = create_edge("1","image","2","image")
@ -377,8 +377,8 @@ def test_graph_invalid_if_subgraph_invalid():
def test_graph_invalid_if_has_cycle():
g = Graph()
n1 = UpscaleInvocation(id = "1")
n2 = UpscaleInvocation(id = "2")
n1 = RealESRGANInvocation(id = "1")
n2 = RealESRGANInvocation(id = "2")
g.nodes[n1.id] = n1
g.nodes[n2.id] = n2
e1 = create_edge("1","image","2","image")
@ -391,7 +391,7 @@ def test_graph_invalid_if_has_cycle():
def test_graph_invalid_with_invalid_connection():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.nodes[n1.id] = n1
g.nodes[n2.id] = n2
e1 = create_edge("1","image","2","strength")
@ -503,7 +503,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
g.add_node(n1)
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n2)
with pytest.raises(NodeNotFoundError):
@ -512,7 +512,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
def test_graph_gets_networkx_graph():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")
@ -529,7 +529,7 @@ def test_graph_gets_networkx_graph():
def test_graph_can_serialize():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")
@ -541,7 +541,7 @@ def test_graph_can_serialize():
def test_graph_can_deserialize():
g = Graph()
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
n2 = UpscaleInvocation(id = "2")
n2 = RealESRGANInvocation(id = "2")
g.add_node(n1)
g.add_node(n2)
e = create_edge(n1.id,"image",n2.id,"image")