Merge branch 'main' into patch-1

psychedelicious authored on 2023-05-05 10:38:45 +10:00 (committed by GitHub)
329 changed files with 6233 additions and 5269 deletions

View File

@ -1,14 +1,12 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import os
from argparse import Namespace
from invokeai.app.services.metadata import PngMetadataService, MetadataServiceBase
import invokeai.backend.util.logging as logger
from typing import types
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ...backend import Globals
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
@ -19,6 +17,7 @@ from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.metadata import PngMetadataService
from .events import FastAPIEventService
@ -44,15 +43,16 @@ class ApiDependencies:
invoker: Invoker = None
@staticmethod
def initialize(config, event_handler_id: int):
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
Globals.try_patchmatch = config.patchmatch
Globals.always_use_cpu = config.always_use_cpu
Globals.internet_available = config.internet_available and check_internet()
Globals.disable_xformers = not config.xformers
Globals.ckpt_convert = config.ckpt_convert
# TODO: Use a logger
print(f">> Internet connectivity is {Globals.internet_available}")
# TODO: Use the config to select the logger rather than use the default
# invokeai logging module
logger.info(f"Internet connectivity is {Globals.internet_available}")
events = FastAPIEventService(event_handler_id)
@ -70,8 +70,9 @@ class ApiDependencies:
db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices(
model_manager=get_model_manager(config),
model_manager=get_model_manager(config,logger),
events=events,
logger=logger,
latents=latents,
images=images,
metadata=metadata,
@ -83,7 +84,7 @@ class ApiDependencies:
filename=db_location, table_name="graph_executions"
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config),
restoration=RestorationServices(config,logger),
)
create_system_graphs(services.graph_library)
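
The `logger` parameter introduced here is annotated as `types.ModuleType`: the diff passes the `invokeai.backend.util.logging` module itself rather than a `logging.Logger` instance. Below is a minimal sketch of a stand-in module that satisfies the same interface (only `info`/`warning`/`error` are used in this diff); the helper name is illustrative and not part of the codebase.

```python
import logging
import sys
import types

def make_stub_logger_module(name: str = "stub_logger") -> types.ModuleType:
    """Build a module object exposing info/warning/error, matching the
    logger: types.ModuleType parameters added in this diff."""
    backing = logging.getLogger(name)
    if not backing.handlers:  # avoid duplicate handlers on repeated calls
        backing.addHandler(logging.StreamHandler(sys.stdout))
    backing.setLevel(logging.INFO)

    mod = types.ModuleType(name)
    mod.info = backing.info
    mod.warning = backing.warning
    mod.error = backing.error
    return mod

logger = make_stub_logger_module()
logger.info("Internet connectivity is True")  # mirrors the message logged above
```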

View File

@ -8,10 +8,6 @@ from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from pathlib import Path
from ..dependencies import ApiDependencies
from invokeai.backend.globals import Globals, global_converted_ckpts_dir
from invokeai.backend.args import Args
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -112,19 +108,20 @@ async def update_model(
async def delete_model(model_name: str) -> None:
"""Delete Model"""
model_names = ApiDependencies.invoker.services.model_manager.model_names()
logger = ApiDependencies.invoker.services.logger
model_exists = model_name in model_names
# check if model exists
print(f">> Checking for model {model_name}...")
logger.info(f"Checking for model {model_name}...")
if model_exists:
print(f">> Deleting Model: {model_name}")
logger.info(f"Deleting Model: {model_name}")
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
print(f">> Model Deleted: {model_name}")
logger.info(f"Model Deleted: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
else:
print(f">> Model not found")
logger.error(f"Model not found")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
@ -248,4 +245,4 @@ async def delete_model(model_name: str) -> None:
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:
# except Exception as e:
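
A standalone sketch of the logging change in this router, using the stdlib `logging` module as a stand-in for `invokeai.backend.util.logging` (whose formatting is not shown in this diff): the old code embedded the `>>` prefix in each `print`, while the shared logger leaves prefixes and levels to the logging backend.

```python
import logging

# the format string is an assumption; the real backend's format is not shown here
logging.basicConfig(format=">> %(levelname)s: %(message)s", level=logging.INFO)
logger = logging.getLogger("invokeai")

model_name = "some-model"  # illustrative value
logger.info(f"Checking for model {model_name}...")
logger.error("Model not found")
```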

View File

@ -3,6 +3,7 @@ import asyncio
from inspect import signature
import uvicorn
import invokeai.backend.util.logging as logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
@ -16,7 +17,6 @@ from ..backend import Args
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions, models
from .api.sockets import SocketIO
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
# Create the app
@ -56,7 +56,7 @@ async def startup_event():
config.parse_args()
ApiDependencies.initialize(
config=config, event_handler_id=event_handler_id
config=config, event_handler_id=event_handler_id, logger=logger
)

View File

@ -2,14 +2,15 @@
from abc import ABC, abstractmethod
import argparse
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
from typing import Any, Callable, Iterable, Literal, Union, get_args, get_origin, get_type_hints
from pydantic import BaseModel, Field
import networkx as nx
import matplotlib.pyplot as plt
import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import BaseInvocation
from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
from ..services.graph import GraphExecutionState, LibraryGraph, Edge
from ..services.invoker import Invoker
@ -229,7 +230,7 @@ class HistoryCommand(BaseCommand):
for i in range(min(self.count, len(history))):
entry_id = history[-1 - i]
entry = context.get_session().graph.get_node(entry_id)
print(f"{entry_id}: {get_invocation_command(entry)}")
logger.info(f"{entry_id}: {get_invocation_command(entry)}")
class SetDefaultCommand(BaseCommand):

View File

@ -10,6 +10,7 @@ import shlex
from pathlib import Path
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
import invokeai.backend.util.logging as logger
from ...backend import ModelManager, Globals
from ..invocations.baseinvocation import BaseInvocation
from .commands import BaseCommand
@ -160,8 +161,8 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
pass
except OSError: # file likely corrupted
newname = f"{histfile}.old"
print(
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
logger.error(
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
)
histfile.replace(Path(newname))
atexit.register(readline.write_history_file, histfile)

View File

@ -13,21 +13,20 @@ from typing import (
from pydantic import BaseModel
from pydantic.fields import Field
import invokeai.backend.util.logging as logger
from invokeai.app.services.metadata import PngMetadataService
from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
from .cli.completer import set_autocompleter
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.default_graphs import default_text_to_image_graph_id
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
@ -182,7 +181,7 @@ def invoke_all(context: CliContext):
# Print any errors
if context.session.has_error():
for n in context.session.errors:
print(
context.invoker.services.logger.error(
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
)
@ -192,13 +191,13 @@ def invoke_all(context: CliContext):
def invoke_cli():
config = Args()
config.parse_args()
model_manager = get_model_manager(config)
model_manager = get_model_manager(config,logger=logger)
# This initializes the autocompleter and returns it.
# Currently nothing is done with the returned Completer
# object, but the object can be used to change autocompletion
# behavior on the fly, if desired.
completer = set_autocompleter(model_manager)
set_autocompleter(model_manager)
events = EventServiceBase()
@ -225,7 +224,8 @@ def invoke_cli():
filename=db_location, table_name="graph_executions"
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config),
restoration=RestorationServices(config,logger=logger),
logger=logger,
)
system_graphs = create_system_graphs(services.graph_library)
@ -365,12 +365,12 @@ def invoke_cli():
invoke_all(context)
except InvalidArgs:
print('Invalid command, use "help" to list commands')
invoker.services.logger.warning('Invalid command, use "help" to list commands')
continue
except SessionError:
# Start a new session
print("Session error: creating a new session")
invoker.services.logger.warning("Session error: creating a new session")
context.reset()
except ExitCli:

View File

@ -46,8 +46,8 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
@ -150,6 +150,9 @@ class ImageToImageInvocation(TextToImageInvocation):
)
mask = None
if self.fit:
image = image.resize((self.width, self.height))
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
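
The width/height constraint in this file is loosened from `multiple_of=64` to `multiple_of=8`. Here is a hedged illustration, with a stand-in pydantic model that is not part of the diff, of what the new validation accepts and rejects.

```python
from pydantic import BaseModel, Field, ValidationError

class Dims(BaseModel):  # stand-in model, not from the codebase
    width: int = Field(default=512, multiple_of=8, gt=0)
    height: int = Field(default=512, multiple_of=8, gt=0)

print(Dims(width=520, height=512))   # 520 is a multiple of 8 (was rejected under multiple_of=64)

try:
    Dims(width=516, height=512)      # 516 is not a multiple of 8
except ValidationError as err:
    print(err)
```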

View File

@ -113,8 +113,8 @@ class NoiseInvocation(BaseInvocation):
# Inputs
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
# Schema customisation
@ -146,11 +146,8 @@ class TextToLatentsInvocation(BaseInvocation):
# TODO: consider making prompt optional to enable providing prompt through a link
# fmt: off
prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
@ -363,9 +360,74 @@ class LatentsToImageInvocation(BaseInvocation):
session_id=context.graph_execution_state_id, node=self
)
torch.cuda.empty_cache()
context.services.images.save(image_type, image_name, image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image
image_type=image_type, image_name=image_name, image=image
)
LATENTS_INTERPOLATION_MODE = Literal[
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
]
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
type: Literal["lresize"] = "lresize"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
resized_latents = torch.nn.functional.interpolate(
latents,
size=(self.height // 8, self.width // 8),
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
return LatentsOutput(latents=LatentsField(latents_name=name))
class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor."""
type: Literal["lscale"] = "lscale"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
# resizing
resized_latents = torch.nn.functional.interpolate(
latents,
scale_factor=self.scale_factor,
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
return LatentsOutput(latents=LatentsField(latents_name=name))
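
Both new nodes rely on the 8x downsampling factor between pixel space and latent space: `lresize` floor-divides a pixel-space target by 8, while `lscale` applies a scale factor directly. A self-contained sketch of that math follows; shapes, mode, and values are illustrative.

```python
import torch

latents = torch.randn(1, 4, 64, 64)           # latent tensor for a 512x512 image
target_w, target_h = 768, 512                  # requested output size in pixels

# lresize-style call: pixel dimensions floor-divided by 8
resized = torch.nn.functional.interpolate(
    latents,
    size=(target_h // 8, target_w // 8),       # (64, 96) in latent units
    mode="bilinear",
    antialias=False,
)
print(resized.shape)                           # torch.Size([1, 4, 64, 96])

# lscale-style call: scale factor applied directly to the latent tensor
scaled = torch.nn.functional.interpolate(
    latents, scale_factor=0.5, mode="bilinear", antialias=False
)
print(scaled.shape)                            # torch.Size([1, 4, 32, 32])
```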

View File

@ -3,12 +3,11 @@ from invokeai.backend.model_management.model_manager import ModelManager
def choose_model(model_manager: ModelManager, model_name: str):
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
logger = model_manager.logger
if model_manager.valid_model(model_name):
model = model_manager.get_model(model_name)
else:
model = model_manager.get_model()
print(
f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
)
logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
return model

View File

@ -23,8 +23,6 @@ def create_text_to_image() -> LibraryGraph:
edges=[
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
]

View File

@ -1,4 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import types
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.backend import ModelManager
@ -29,6 +31,7 @@ class InvocationServices:
self,
model_manager: ModelManager,
events: EventServiceBase,
logger: types.ModuleType,
latents: LatentsStorageBase,
images: ImageStorageBase,
metadata: MetadataServiceBase,
@ -40,6 +43,7 @@ class InvocationServices:
):
self.model_manager = model_manager
self.events = events
self.logger = logger
self.latents = latents
self.images = images
self.metadata = metadata

View File

@ -49,7 +49,7 @@ class Invoker:
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state)
return new_state
def cancel(self, graph_execution_state_id: str) -> None:
"""Cancels the given execution state"""
self.services.queue.cancel(graph_execution_state_id)
@ -71,18 +71,12 @@ class Invoker:
for service in vars(self.services):
self.__start_service(getattr(self.services, service))
for service in vars(self.services):
self.__start_service(getattr(self.services, service))
def stop(self) -> None:
"""Stops the invoker. A new invoker will have to be created to execute further."""
# First stop all services
for service in vars(self.services):
self.__stop_service(getattr(self.services, service))
for service in vars(self.services):
self.__stop_service(getattr(self.services, service))
self.services.queue.put(None)

View File

@ -5,6 +5,7 @@ from argparse import Namespace
from invokeai.backend import Args
from omegaconf import OmegaConf
from pathlib import Path
from typing import types
import invokeai.version
from ...backend import ModelManager
@ -12,16 +13,16 @@ from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals
# TODO: Replace with an abstract class base ModelManagerBase
def get_model_manager(config: Args) -> ModelManager:
def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
if not config.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
if not os.path.exists(config_file):
report_model_error(
config, FileNotFoundError(f"The file {config_file} could not be found.")
config, FileNotFoundError(f"The file {config_file} could not be found."), logger
)
print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
@ -62,11 +63,12 @@ def get_model_manager(config: Args) -> ModelManager:
device_type=device,
max_loaded_models=config.max_loaded_models,
embedding_path = Path(embedding_path),
logger = logger,
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(config, e)
report_model_error(config, e, logger)
except (IOError, KeyError) as e:
print(f"{e}. Aborting.")
logger.error(f"{e}. Aborting.")
sys.exit(-1)
# try to autoconvert new models
@ -76,18 +78,18 @@ def get_model_manager(config: Args) -> ModelManager:
conf_path=config.conf,
weights_directory=path,
)
logger.info('Model manager initialized')
return model_manager
def report_model_error(opt: Namespace, e: Exception):
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
print(
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
logger.error(
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
)
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
if yes_to_all:
print(
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
logger.warning(
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
)
else:
response = input(
@ -96,13 +98,12 @@ def report_model_error(opt: Namespace, e: Exception):
if response.startswith(("n", "N")):
return
print("invokeai-configure is launching....\n")
logger.info("invokeai-configure is launching....\n")
# Match arguments that were set on the CLI
# only the arguments accepted by the configuration script are parsed
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
config = ["--config", opt.conf] if opt.conf is not None else []
previous_config = sys.argv
sys.argv = ["invokeai-configure"]
sys.argv.extend(root_dir)
sys.argv.extend(config)

View File

@ -1,5 +1,5 @@
import traceback
from threading import Event, Thread
from threading import Event, Thread, BoundedSemaphore
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
@ -10,8 +10,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
__stop_event: Event
__invoker: Invoker
__threadLimit: BoundedSemaphore
def start(self, invoker) -> None:
# if we do want multithreading at some point, we could make this configurable
self.__threadLimit = BoundedSemaphore(1)
self.__invoker = invoker
self.__stop_event = Event()
self.__invoker_thread = Thread(
@ -20,7 +23,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
kwargs=dict(stop_event=self.__stop_event),
)
self.__invoker_thread.daemon = (
True # TODO: probably better to just not use threads?
True # TODO: make async and do not use threads
)
self.__invoker_thread.start()
@ -29,6 +32,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
def __process(self, stop_event: Event):
try:
self.__threadLimit.acquire()
while not stop_event.is_set():
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
if not queue_item: # Probably stopping
@ -110,7 +114,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
)
pass
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(
graph_execution_state.id
@ -127,4 +131,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
)
except KeyboardInterrupt:
... # Log something?
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
finally:
self.__threadLimit.release()
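
The `BoundedSemaphore(1)` added here serializes invocation processing: a worker acquires the semaphore before entering its loop and releases it in `finally`, so at most one processing thread runs at a time. A standalone sketch of that pattern follows; names and timings are illustrative.

```python
import time
from threading import BoundedSemaphore, Thread

thread_limit = BoundedSemaphore(1)   # same limit the processor uses

def worker(worker_id: int) -> None:
    thread_limit.acquire()
    try:
        # only one worker is inside this block at any moment
        print(f"worker {worker_id} processing")
        time.sleep(0.1)
    finally:
        thread_limit.release()

threads = [Thread(target=worker, args=(i,)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```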

View File

@ -1,6 +1,7 @@
import sys
import traceback
import torch
from typing import types
from ...backend.restoration import Restoration
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
@ -10,7 +11,7 @@ from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
class RestorationServices:
'''Face restoration and upscaling'''
def __init__(self,args):
def __init__(self,args,logger:types.ModuleType):
try:
gfpgan, codeformer, esrgan = None, None, None
if args.restore or args.esrgan:
@ -20,20 +21,22 @@ class RestorationServices:
args.gfpgan_model_path
)
else:
print(">> Face restoration disabled")
logger.info("Face restoration disabled")
if args.esrgan:
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
else:
print(">> Upscaling disabled")
logger.info("Upscaling disabled")
else:
print(">> Face restoration and upscaling disabled")
logger.info("Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
self.device = torch.device(choose_torch_device())
self.gfpgan = gfpgan
self.codeformer = codeformer
self.esrgan = esrgan
self.logger = logger
self.logger.info('Face restoration initialized')
# note that this one method does gfpgan and codepath reconstruction, as well as
# esrgan upscaling
@ -58,15 +61,15 @@ class RestorationServices:
if self.gfpgan is not None or self.codeformer is not None:
if facetool == "gfpgan":
if self.gfpgan is None:
print(
">> GFPGAN not found. Face restoration is disabled."
self.logger.info(
"GFPGAN not found. Face restoration is disabled."
)
else:
image = self.gfpgan.process(image, strength, seed)
if facetool == "codeformer":
if self.codeformer is None:
print(
">> CodeFormer not found. Face restoration is disabled."
self.logger.info(
"CodeFormer not found. Face restoration is disabled."
)
else:
cf_device = (
@ -80,7 +83,7 @@ class RestorationServices:
fidelity=codeformer_fidelity,
)
else:
print(">> Face Restoration is disabled.")
self.logger.info("Face Restoration is disabled.")
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
@ -93,10 +96,10 @@ class RestorationServices:
denoise_str=upscale_denoise_str,
)
else:
print(">> ESRGAN is disabled. Image not upscaled.")
self.logger.info("ESRGAN is disabled. Image not upscaled.")
except Exception as e:
print(
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
self.logger.info(
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
)
if image_callback is not None: