mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): remove references to restoration services
- remove restoration services - remove the restore faces nodes - update tests
This commit is contained in:
parent
8a1b9d1001
commit
c7b547ea3e
@ -20,7 +20,6 @@ from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from ..services.restoration_services import RestorationServices
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_file_storage import DiskImageFileStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
@ -58,7 +57,7 @@ class ApiDependencies:
|
||||
|
||||
@staticmethod
|
||||
def initialize(config, event_handler_id: int, logger: Logger = logger):
|
||||
logger.debug(f'InvokeAI version {__version__}')
|
||||
logger.debug(f"InvokeAI version {__version__}")
|
||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
@ -117,7 +116,7 @@ class ApiDependencies:
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=ModelManagerService(config,logger),
|
||||
model_manager=ModelManagerService(config, logger),
|
||||
events=events,
|
||||
latents=latents,
|
||||
images=images,
|
||||
@ -129,7 +128,6 @@ class ApiDependencies:
|
||||
),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config, logger),
|
||||
configuration=config,
|
||||
logger=logger,
|
||||
)
|
||||
|
@ -54,7 +54,6 @@ from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
from .services.model_manager_service import ModelManagerService
|
||||
from .services.processor import DefaultInvocationProcessor
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
|
||||
import torch
|
||||
@ -295,7 +294,6 @@ def invoke_cli():
|
||||
),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config,logger=logger),
|
||||
logger=logger,
|
||||
configuration=config,
|
||||
)
|
||||
|
@ -1,55 +0,0 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
|
||||
|
||||
class RestoreFaceInvocation(BaseInvocation):
|
||||
"""Restores faces in an image."""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["restore_face"] = "restore_face"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = Field(description="The input image")
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["restoration", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
upscale=None,
|
||||
strength=self.strength, # GFPGAN strength
|
||||
save_original=False,
|
||||
image_callback=None,
|
||||
)
|
||||
|
||||
# Results are image and seed, unwrap for now
|
||||
# TODO: can this return multiple results?
|
||||
image_dto = context.services.images.create(
|
||||
image=results[0][0],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
@ -10,7 +10,6 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
from invokeai.app.services.restoration_services import RestorationServices
|
||||
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
||||
from invokeai.app.services.item_storage import ItemStorageABC
|
||||
from invokeai.app.services.config import InvokeAISettings
|
||||
@ -34,7 +33,6 @@ class InvocationServices:
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
queue: "InvocationQueueABC"
|
||||
restoration: "RestorationServices"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -50,7 +48,6 @@ class InvocationServices:
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
queue: "InvocationQueueABC",
|
||||
restoration: "RestorationServices",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.boards = boards
|
||||
@ -65,4 +62,3 @@ class InvocationServices:
|
||||
self.model_manager = model_manager
|
||||
self.processor = processor
|
||||
self.queue = queue
|
||||
self.restoration = restoration
|
||||
|
@ -1,113 +0,0 @@
|
||||
import sys
|
||||
import traceback
|
||||
import torch
|
||||
from typing import types
|
||||
from ...backend.restoration import Restoration
|
||||
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||
|
||||
# This should be a real base class for postprocessing functions,
|
||||
# but right now we just instantiate the existing gfpgan, esrgan
|
||||
# and codeformer functions.
|
||||
class RestorationServices:
|
||||
'''Face restoration and upscaling'''
|
||||
|
||||
def __init__(self,args,logger:types.ModuleType):
|
||||
try:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if args.restore or args.esrgan:
|
||||
restoration = Restoration()
|
||||
# TODO: redo for new model structure
|
||||
if False and args.restore:
|
||||
gfpgan, codeformer = restoration.load_face_restore_models(
|
||||
args.gfpgan_model_path
|
||||
)
|
||||
else:
|
||||
logger.info("Face restoration disabled")
|
||||
if False and args.esrgan:
|
||||
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
||||
else:
|
||||
logger.info("Upscaling disabled")
|
||||
else:
|
||||
logger.info("Face restoration and upscaling disabled")
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
self.device = torch.device(choose_torch_device())
|
||||
self.gfpgan = gfpgan
|
||||
self.codeformer = codeformer
|
||||
self.esrgan = esrgan
|
||||
self.logger = logger
|
||||
self.logger.info('Face restoration initialized')
|
||||
|
||||
# note that this one method does gfpgan and codepath reconstruction, as well as
|
||||
# esrgan upscaling
|
||||
# TO DO: refactor into separate methods
|
||||
def upscale_and_reconstruct(
|
||||
self,
|
||||
image_list,
|
||||
facetool="gfpgan",
|
||||
upscale=None,
|
||||
upscale_denoise_str=0.75,
|
||||
strength=0.0,
|
||||
codeformer_fidelity=0.75,
|
||||
save_original=False,
|
||||
image_callback=None,
|
||||
prefix=None,
|
||||
):
|
||||
results = []
|
||||
for r in image_list:
|
||||
image, seed = r
|
||||
try:
|
||||
if strength > 0:
|
||||
if self.gfpgan is not None or self.codeformer is not None:
|
||||
if facetool == "gfpgan":
|
||||
if self.gfpgan is None:
|
||||
self.logger.info(
|
||||
"GFPGAN not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
image = self.gfpgan.process(image, strength, seed)
|
||||
if facetool == "codeformer":
|
||||
if self.codeformer is None:
|
||||
self.logger.info(
|
||||
"CodeFormer not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
cf_device = (
|
||||
CPU_DEVICE if self.device == MPS_DEVICE else self.device
|
||||
)
|
||||
image = self.codeformer.process(
|
||||
image=image,
|
||||
strength=strength,
|
||||
device=cf_device,
|
||||
seed=seed,
|
||||
fidelity=codeformer_fidelity,
|
||||
)
|
||||
else:
|
||||
self.logger.info("Face Restoration is disabled.")
|
||||
if upscale is not None:
|
||||
if self.esrgan is not None:
|
||||
if len(upscale) < 2:
|
||||
upscale.append(0.75)
|
||||
image = self.esrgan.process(
|
||||
image,
|
||||
upscale[1],
|
||||
seed,
|
||||
int(upscale[0]),
|
||||
denoise_str=upscale_denoise_str,
|
||||
)
|
||||
else:
|
||||
self.logger.info("ESRGAN is disabled. Image not upscaled.")
|
||||
except Exception as e:
|
||||
self.logger.info(
|
||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
)
|
||||
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, upscaled=True, use_prefix=prefix)
|
||||
else:
|
||||
r[0] = image
|
||||
|
||||
results.append([image, seed])
|
||||
|
||||
return results
|
@ -55,7 +55,6 @@ def mock_services() -> InvocationServices:
|
||||
),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor(),
|
||||
restoration = None, # type: ignore
|
||||
configuration = None, # type: ignore
|
||||
)
|
||||
|
||||
|
@ -48,7 +48,6 @@ def mock_services() -> InvocationServices:
|
||||
),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor(),
|
||||
restoration = None, # type: ignore
|
||||
configuration = None, # type: ignore
|
||||
)
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation
|
||||
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
||||
from invokeai.app.invocations.upscale import UpscaleInvocation
|
||||
from invokeai.app.invocations.upscale import RealESRGANInvocation
|
||||
from invokeai.app.invocations.image import *
|
||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||
from invokeai.app.invocations.params import ParamIntInvocation
|
||||
@ -19,7 +19,7 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
|
||||
def test_connections_are_compatible():
|
||||
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
from_field = "image"
|
||||
to_node = UpscaleInvocation(id = "2")
|
||||
to_node = RealESRGANInvocation(id = "2")
|
||||
to_field = "image"
|
||||
|
||||
result = are_connections_compatible(from_node, from_field, to_node, to_field)
|
||||
@ -29,7 +29,7 @@ def test_connections_are_compatible():
|
||||
def test_connections_are_incompatible():
|
||||
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
from_field = "image"
|
||||
to_node = UpscaleInvocation(id = "2")
|
||||
to_node = RealESRGANInvocation(id = "2")
|
||||
to_field = "strength"
|
||||
|
||||
result = are_connections_compatible(from_node, from_field, to_node, to_field)
|
||||
@ -39,7 +39,7 @@ def test_connections_are_incompatible():
|
||||
def test_connections_incompatible_with_invalid_fields():
|
||||
from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
from_field = "invalid_field"
|
||||
to_node = UpscaleInvocation(id = "2")
|
||||
to_node = RealESRGANInvocation(id = "2")
|
||||
to_field = "image"
|
||||
|
||||
# From field is invalid
|
||||
@ -86,10 +86,10 @@ def test_graph_fails_to_update_node_if_type_changes():
|
||||
g = Graph()
|
||||
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
g.add_node(n)
|
||||
n2 = UpscaleInvocation(id = "2")
|
||||
n2 = RealESRGANInvocation(id = "2")
|
||||
g.add_node(n2)
|
||||
|
||||
nu = UpscaleInvocation(id = "1")
|
||||
nu = RealESRGANInvocation(id = "1")
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
g.update_node("1", nu)
|
||||
@ -98,7 +98,7 @@ def test_graph_allows_non_conflicting_id_change():
|
||||
g = Graph()
|
||||
n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
g.add_node(n)
|
||||
n2 = UpscaleInvocation(id = "2")
|
||||
n2 = RealESRGANInvocation(id = "2")
|
||||
g.add_node(n2)
|
||||
e1 = create_edge(n.id,"image",n2.id,"image")
|
||||
g.add_edge(e1)
|
||||
@ -128,7 +128,7 @@ def test_graph_fails_to_update_node_id_if_conflict():
|
||||
def test_graph_adds_edge():
|
||||
g = Graph()
|
||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
n2 = UpscaleInvocation(id = "2")
|
||||
n2 = RealESRGANInvocation(id = "2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e = create_edge(n1.id,"image",n2.id,"image")
|
||||
@ -139,7 +139,7 @@ def test_graph_adds_edge():
|
||||
|
||||
def test_graph_fails_to_add_edge_with_cycle():
|
||||
g = Graph()
|
||||
n1 = UpscaleInvocation(id = "1")
|
||||
n1 = RealESRGANInvocation(id = "1")
|
||||
g.add_node(n1)
|
||||
e = create_edge(n1.id,"image",n1.id,"image")
|
||||
with pytest.raises(InvalidEdgeError):
|
||||
@ -148,8 +148,8 @@ def test_graph_fails_to_add_edge_with_cycle():
|
||||
def test_graph_fails_to_add_edge_with_long_cycle():
|
||||
g = Graph()
|
||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
n2 = UpscaleInvocation(id = "2")
|
||||
n3 = UpscaleInvocation(id = "3")
|
||||
n2 = RealESRGANInvocation(id = "2")
|
||||
n3 = RealESRGANInvocation(id = "3")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
@ -164,7 +164,7 @@ def test_graph_fails_to_add_edge_with_long_cycle():
|
||||
def test_graph_fails_to_add_edge_with_missing_node_id():
|
||||
g = Graph()
|
||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
n2 = UpscaleInvocation(id = "2")
|
||||
n2 = RealESRGANInvocation(id = "2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e1 = create_edge("1","image","3","image")
|
||||
@ -177,8 +177,8 @@ def test_graph_fails_to_add_edge_with_missing_node_id():
|
||||
def test_graph_fails_to_add_edge_when_destination_exists():
|
||||
g = Graph()
|
||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
n2 = UpscaleInvocation(id = "2")
|
||||
n3 = UpscaleInvocation(id = "3")
|
||||
n2 = RealESRGANInvocation(id = "2")
|
||||
n3 = RealESRGANInvocation(id = "3")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
@ -194,7 +194,7 @@ def test_graph_fails_to_add_edge_when_destination_exists():
|
||||
def test_graph_fails_to_add_edge_with_mismatched_types():
|
||||
g = Graph()
|
||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
n2 = UpscaleInvocation(id = "2")
|
||||
n2 = RealESRGANInvocation(id = "2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e1 = create_edge("1","image","2","strength")
|
||||
@ -344,7 +344,7 @@ def test_graph_iterator_invalid_if_output_and_input_types_different():
|
||||
def test_graph_validates():
|
||||
g = Graph()
|
||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
n2 = UpscaleInvocation(id = "2")
|
||||
n2 = RealESRGANInvocation(id = "2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e1 = create_edge("1","image","2","image")
|
||||
@ -377,8 +377,8 @@ def test_graph_invalid_if_subgraph_invalid():
|
||||
|
||||
def test_graph_invalid_if_has_cycle():
|
||||
g = Graph()
|
||||
n1 = UpscaleInvocation(id = "1")
|
||||
n2 = UpscaleInvocation(id = "2")
|
||||
n1 = RealESRGANInvocation(id = "1")
|
||||
n2 = RealESRGANInvocation(id = "2")
|
||||
g.nodes[n1.id] = n1
|
||||
g.nodes[n2.id] = n2
|
||||
e1 = create_edge("1","image","2","image")
|
||||
@ -391,7 +391,7 @@ def test_graph_invalid_if_has_cycle():
|
||||
def test_graph_invalid_with_invalid_connection():
|
||||
g = Graph()
|
||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
n2 = UpscaleInvocation(id = "2")
|
||||
n2 = RealESRGANInvocation(id = "2")
|
||||
g.nodes[n1.id] = n1
|
||||
g.nodes[n2.id] = n2
|
||||
e1 = create_edge("1","image","2","strength")
|
||||
@ -503,7 +503,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
|
||||
|
||||
g.add_node(n1)
|
||||
|
||||
n2 = UpscaleInvocation(id = "2")
|
||||
n2 = RealESRGANInvocation(id = "2")
|
||||
g.add_node(n2)
|
||||
|
||||
with pytest.raises(NodeNotFoundError):
|
||||
@ -512,7 +512,7 @@ def test_graph_fails_to_enumerate_non_subgraph_node():
|
||||
def test_graph_gets_networkx_graph():
|
||||
g = Graph()
|
||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
n2 = UpscaleInvocation(id = "2")
|
||||
n2 = RealESRGANInvocation(id = "2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e = create_edge(n1.id,"image",n2.id,"image")
|
||||
@ -529,7 +529,7 @@ def test_graph_gets_networkx_graph():
|
||||
def test_graph_can_serialize():
|
||||
g = Graph()
|
||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
n2 = UpscaleInvocation(id = "2")
|
||||
n2 = RealESRGANInvocation(id = "2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e = create_edge(n1.id,"image",n2.id,"image")
|
||||
@ -541,7 +541,7 @@ def test_graph_can_serialize():
|
||||
def test_graph_can_deserialize():
|
||||
g = Graph()
|
||||
n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi")
|
||||
n2 = UpscaleInvocation(id = "2")
|
||||
n2 = RealESRGANInvocation(id = "2")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
e = create_edge(n1.id,"image",n2.id,"image")
|
||||
|
Loading…
Reference in New Issue
Block a user