mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into patch-1
This commit is contained in:
@ -1,14 +1,12 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import os
|
||||
from argparse import Namespace
|
||||
|
||||
from invokeai.app.services.metadata import PngMetadataService, MetadataServiceBase
|
||||
import invokeai.backend.util.logging as logger
|
||||
from typing import types
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ...backend import Globals
|
||||
from ..services.model_manager_initializer import get_model_manager
|
||||
from ..services.restoration_services import RestorationServices
|
||||
@ -19,6 +17,7 @@ from ..services.invocation_services import InvocationServices
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.metadata import PngMetadataService
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@ -44,15 +43,16 @@ class ApiDependencies:
|
||||
invoker: Invoker = None
|
||||
|
||||
@staticmethod
|
||||
def initialize(config, event_handler_id: int):
|
||||
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
|
||||
Globals.try_patchmatch = config.patchmatch
|
||||
Globals.always_use_cpu = config.always_use_cpu
|
||||
Globals.internet_available = config.internet_available and check_internet()
|
||||
Globals.disable_xformers = not config.xformers
|
||||
Globals.ckpt_convert = config.ckpt_convert
|
||||
|
||||
# TODO: Use a logger
|
||||
print(f">> Internet connectivity is {Globals.internet_available}")
|
||||
# TO DO: Use the config to select the logger rather than use the default
|
||||
# invokeai logging module
|
||||
logger.info(f"Internet connectivity is {Globals.internet_available}")
|
||||
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
|
||||
@ -70,8 +70,9 @@ class ApiDependencies:
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=get_model_manager(config),
|
||||
model_manager=get_model_manager(config,logger),
|
||||
events=events,
|
||||
logger=logger,
|
||||
latents=latents,
|
||||
images=images,
|
||||
metadata=metadata,
|
||||
@ -83,7 +84,7 @@ class ApiDependencies:
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config),
|
||||
restoration=RestorationServices(config,logger),
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
@ -8,10 +8,6 @@ from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from pathlib import Path
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend.globals import Globals, global_converted_ckpts_dir
|
||||
from invokeai.backend.args import Args
|
||||
|
||||
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
@ -112,19 +108,20 @@ async def update_model(
|
||||
async def delete_model(model_name: str) -> None:
|
||||
"""Delete Model"""
|
||||
model_names = ApiDependencies.invoker.services.model_manager.model_names()
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_exists = model_name in model_names
|
||||
|
||||
# check if model exists
|
||||
print(f">> Checking for model {model_name}...")
|
||||
logger.info(f"Checking for model {model_name}...")
|
||||
|
||||
if model_exists:
|
||||
print(f">> Deleting Model: {model_name}")
|
||||
logger.info(f"Deleting Model: {model_name}")
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
|
||||
print(f">> Model Deleted: {model_name}")
|
||||
logger.info(f"Model Deleted: {model_name}")
|
||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||
|
||||
else:
|
||||
print(f">> Model not found")
|
||||
logger.error(f"Model not found")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
|
||||
@ -248,4 +245,4 @@ async def delete_model(model_name: str) -> None:
|
||||
# )
|
||||
# print(f">> Models Merged: {models_to_merge}")
|
||||
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||
# except Exception as e:
|
||||
# except Exception as e:
|
||||
|
@ -3,6 +3,7 @@ import asyncio
|
||||
from inspect import signature
|
||||
|
||||
import uvicorn
|
||||
import invokeai.backend.util.logging as logger
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
@ -16,7 +17,6 @@ from ..backend import Args
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import images, sessions, models
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
# Create the app
|
||||
@ -56,7 +56,7 @@ async def startup_event():
|
||||
config.parse_args()
|
||||
|
||||
ApiDependencies.initialize(
|
||||
config=config, event_handler_id=event_handler_id
|
||||
config=config, event_handler_id=event_handler_id, logger=logger
|
||||
)
|
||||
|
||||
|
||||
|
@ -2,14 +2,15 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import argparse
|
||||
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Callable, Iterable, Literal, Union, get_args, get_origin, get_type_hints
|
||||
from pydantic import BaseModel, Field
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from ..invocations.image import ImageField
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph, Edge
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
||||
@ -229,7 +230,7 @@ class HistoryCommand(BaseCommand):
|
||||
for i in range(min(self.count, len(history))):
|
||||
entry_id = history[-1 - i]
|
||||
entry = context.get_session().graph.get_node(entry_id)
|
||||
print(f"{entry_id}: {get_invocation_command(entry)}")
|
||||
logger.info(f"{entry_id}: {get_invocation_command(entry)}")
|
||||
|
||||
|
||||
class SetDefaultCommand(BaseCommand):
|
||||
|
@ -10,6 +10,7 @@ import shlex
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ...backend import ModelManager, Globals
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from .commands import BaseCommand
|
||||
@ -160,8 +161,8 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
|
||||
pass
|
||||
except OSError: # file likely corrupted
|
||||
newname = f"{histfile}.old"
|
||||
print(
|
||||
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||
logger.error(
|
||||
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||
)
|
||||
histfile.replace(Path(newname))
|
||||
atexit.register(readline.write_history_file, histfile)
|
||||
|
@ -13,21 +13,20 @@ from typing import (
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.metadata import PngMetadataService
|
||||
|
||||
from .services.default_graphs import create_system_graphs
|
||||
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ..backend import Args
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.default_graphs import default_text_to_image_graph_id
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
@ -182,7 +181,7 @@ def invoke_all(context: CliContext):
|
||||
# Print any errors
|
||||
if context.session.has_error():
|
||||
for n in context.session.errors:
|
||||
print(
|
||||
context.invoker.services.logger.error(
|
||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
||||
)
|
||||
|
||||
@ -192,13 +191,13 @@ def invoke_all(context: CliContext):
|
||||
def invoke_cli():
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
model_manager = get_model_manager(config)
|
||||
model_manager = get_model_manager(config,logger=logger)
|
||||
|
||||
# This initializes the autocompleter and returns it.
|
||||
# Currently nothing is done with the returned Completer
|
||||
# object, but the object can be used to change autocompletion
|
||||
# behavior on the fly, if desired.
|
||||
completer = set_autocompleter(model_manager)
|
||||
set_autocompleter(model_manager)
|
||||
|
||||
events = EventServiceBase()
|
||||
|
||||
@ -225,7 +224,8 @@ def invoke_cli():
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config),
|
||||
restoration=RestorationServices(config,logger=logger),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
@ -365,12 +365,12 @@ def invoke_cli():
|
||||
invoke_all(context)
|
||||
|
||||
except InvalidArgs:
|
||||
print('Invalid command, use "help" to list commands')
|
||||
invoker.services.logger.warning('Invalid command, use "help" to list commands')
|
||||
continue
|
||||
|
||||
except SessionError:
|
||||
# Start a new session
|
||||
print("Session error: creating a new session")
|
||||
invoker.services.logger.warning("Session error: creating a new session")
|
||||
context.reset()
|
||||
|
||||
except ExitCli:
|
||||
|
@ -46,8 +46,8 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
@ -150,6 +150,9 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
)
|
||||
mask = None
|
||||
|
||||
if self.fit:
|
||||
image = image.resize((self.width, self.height))
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
|
@ -113,8 +113,8 @@ class NoiseInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
|
||||
|
||||
|
||||
# Schema customisation
|
||||
@ -146,11 +146,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
@ -363,9 +360,74 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
context.services.images.save(image_type, image_name, image, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image
|
||||
image_type=image_type, image_name=image_name, image=image
|
||||
)
|
||||
|
||||
|
||||
LATENTS_INTERPOLATION_MODE = Literal[
|
||||
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
|
||||
]
|
||||
|
||||
|
||||
class ResizeLatentsInvocation(BaseInvocation):
|
||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||
|
||||
type: Literal["lresize"] = "lresize"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents,
|
||||
size=(self.height // 8, self.width // 8),
|
||||
mode=self.mode,
|
||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, resized_latents)
|
||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
||||
|
||||
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
"""Scales latents by a given factor."""
|
||||
|
||||
type: Literal["lscale"] = "lscale"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
||||
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# resizing
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents,
|
||||
scale_factor=self.scale_factor,
|
||||
mode=self.mode,
|
||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, resized_latents)
|
||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
||||
|
@ -3,12 +3,11 @@ from invokeai.backend.model_management.model_manager import ModelManager
|
||||
|
||||
def choose_model(model_manager: ModelManager, model_name: str):
|
||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
||||
logger = model_manager.logger
|
||||
if model_manager.valid_model(model_name):
|
||||
model = model_manager.get_model(model_name)
|
||||
else:
|
||||
model = model_manager.get_model()
|
||||
print(
|
||||
f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
|
||||
)
|
||||
logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
|
||||
|
||||
return model
|
||||
|
@ -23,8 +23,6 @@ def create_text_to_image() -> LibraryGraph:
|
||||
edges=[
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
|
||||
Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
|
||||
]
|
||||
|
@ -1,4 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from typing import types
|
||||
from invokeai.app.services.metadata import MetadataServiceBase
|
||||
from invokeai.backend import ModelManager
|
||||
|
||||
@ -29,6 +31,7 @@ class InvocationServices:
|
||||
self,
|
||||
model_manager: ModelManager,
|
||||
events: EventServiceBase,
|
||||
logger: types.ModuleType,
|
||||
latents: LatentsStorageBase,
|
||||
images: ImageStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
@ -40,6 +43,7 @@ class InvocationServices:
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.events = events
|
||||
self.logger = logger
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.metadata = metadata
|
||||
|
@ -49,7 +49,7 @@ class Invoker:
|
||||
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
|
||||
self.services.graph_execution_manager.set(new_state)
|
||||
return new_state
|
||||
|
||||
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
"""Cancels the given execution state"""
|
||||
self.services.queue.cancel(graph_execution_state_id)
|
||||
@ -71,18 +71,12 @@ class Invoker:
|
||||
for service in vars(self.services):
|
||||
self.__start_service(getattr(self.services, service))
|
||||
|
||||
for service in vars(self.services):
|
||||
self.__start_service(getattr(self.services, service))
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stops the invoker. A new invoker will have to be created to execute further."""
|
||||
# First stop all services
|
||||
for service in vars(self.services):
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
for service in vars(self.services):
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
self.services.queue.put(None)
|
||||
|
||||
|
||||
|
@ -5,6 +5,7 @@ from argparse import Namespace
|
||||
from invokeai.backend import Args
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from typing import types
|
||||
|
||||
import invokeai.version
|
||||
from ...backend import ModelManager
|
||||
@ -12,16 +13,16 @@ from ...backend.util import choose_precision, choose_torch_device
|
||||
from ...backend import Globals
|
||||
|
||||
# TODO: Replace with an abstract class base ModelManagerBase
|
||||
def get_model_manager(config: Args) -> ModelManager:
|
||||
def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
|
||||
if not config.conf:
|
||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||
if not os.path.exists(config_file):
|
||||
report_model_error(
|
||||
config, FileNotFoundError(f"The file {config_file} could not be found.")
|
||||
config, FileNotFoundError(f"The file {config_file} could not be found."), logger
|
||||
)
|
||||
|
||||
print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
# when the frozen CLIP tokenizer is imported
|
||||
@ -62,11 +63,12 @@ def get_model_manager(config: Args) -> ModelManager:
|
||||
device_type=device,
|
||||
max_loaded_models=config.max_loaded_models,
|
||||
embedding_path = Path(embedding_path),
|
||||
logger = logger,
|
||||
)
|
||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||
report_model_error(config, e)
|
||||
report_model_error(config, e, logger)
|
||||
except (IOError, KeyError) as e:
|
||||
print(f"{e}. Aborting.")
|
||||
logger.error(f"{e}. Aborting.")
|
||||
sys.exit(-1)
|
||||
|
||||
# try to autoconvert new models
|
||||
@ -76,18 +78,18 @@ def get_model_manager(config: Args) -> ModelManager:
|
||||
conf_path=config.conf,
|
||||
weights_directory=path,
|
||||
)
|
||||
|
||||
logger.info('Model manager initialized')
|
||||
return model_manager
|
||||
|
||||
def report_model_error(opt: Namespace, e: Exception):
|
||||
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
print(
|
||||
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
|
||||
logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||
logger.error(
|
||||
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||
)
|
||||
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
||||
if yes_to_all:
|
||||
print(
|
||||
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
logger.warning(
|
||||
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||
)
|
||||
else:
|
||||
response = input(
|
||||
@ -96,13 +98,12 @@ def report_model_error(opt: Namespace, e: Exception):
|
||||
if response.startswith(("n", "N")):
|
||||
return
|
||||
|
||||
print("invokeai-configure is launching....\n")
|
||||
logger.info("invokeai-configure is launching....\n")
|
||||
|
||||
# Match arguments that were set on the CLI
|
||||
# only the arguments accepted by the configuration script are parsed
|
||||
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
|
||||
config = ["--config", opt.conf] if opt.conf is not None else []
|
||||
previous_config = sys.argv
|
||||
sys.argv = ["invokeai-configure"]
|
||||
sys.argv.extend(root_dir)
|
||||
sys.argv.extend(config.to_dict())
|
||||
|
@ -1,5 +1,5 @@
|
||||
import traceback
|
||||
from threading import Event, Thread
|
||||
from threading import Event, Thread, BoundedSemaphore
|
||||
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
@ -10,8 +10,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
__stop_event: Event
|
||||
__invoker: Invoker
|
||||
__threadLimit: BoundedSemaphore
|
||||
|
||||
def start(self, invoker) -> None:
|
||||
# if we do want multithreading at some point, we could make this configurable
|
||||
self.__threadLimit = BoundedSemaphore(1)
|
||||
self.__invoker = invoker
|
||||
self.__stop_event = Event()
|
||||
self.__invoker_thread = Thread(
|
||||
@ -20,7 +23,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
kwargs=dict(stop_event=self.__stop_event),
|
||||
)
|
||||
self.__invoker_thread.daemon = (
|
||||
True # TODO: probably better to just not use threads?
|
||||
True # TODO: make async and do not use threads
|
||||
)
|
||||
self.__invoker_thread.start()
|
||||
|
||||
@ -29,6 +32,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
|
||||
def __process(self, stop_event: Event):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
while not stop_event.is_set():
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
if not queue_item: # Probably stopping
|
||||
@ -110,7 +114,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(
|
||||
graph_execution_state.id
|
||||
@ -127,4 +131,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
... # Log something?
|
||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||
finally:
|
||||
self.__threadLimit.release()
|
||||
|
@ -1,6 +1,7 @@
|
||||
import sys
|
||||
import traceback
|
||||
import torch
|
||||
from typing import types
|
||||
from ...backend.restoration import Restoration
|
||||
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||
|
||||
@ -10,7 +11,7 @@ from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||
class RestorationServices:
|
||||
'''Face restoration and upscaling'''
|
||||
|
||||
def __init__(self,args):
|
||||
def __init__(self,args,logger:types.ModuleType):
|
||||
try:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
if args.restore or args.esrgan:
|
||||
@ -20,20 +21,22 @@ class RestorationServices:
|
||||
args.gfpgan_model_path
|
||||
)
|
||||
else:
|
||||
print(">> Face restoration disabled")
|
||||
logger.info("Face restoration disabled")
|
||||
if args.esrgan:
|
||||
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
||||
else:
|
||||
print(">> Upscaling disabled")
|
||||
logger.info("Upscaling disabled")
|
||||
else:
|
||||
print(">> Face restoration and upscaling disabled")
|
||||
logger.info("Face restoration and upscaling disabled")
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
self.device = torch.device(choose_torch_device())
|
||||
self.gfpgan = gfpgan
|
||||
self.codeformer = codeformer
|
||||
self.esrgan = esrgan
|
||||
self.logger = logger
|
||||
self.logger.info('Face restoration initialized')
|
||||
|
||||
# note that this one method does gfpgan and codepath reconstruction, as well as
|
||||
# esrgan upscaling
|
||||
@ -58,15 +61,15 @@ class RestorationServices:
|
||||
if self.gfpgan is not None or self.codeformer is not None:
|
||||
if facetool == "gfpgan":
|
||||
if self.gfpgan is None:
|
||||
print(
|
||||
">> GFPGAN not found. Face restoration is disabled."
|
||||
self.logger.info(
|
||||
"GFPGAN not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
image = self.gfpgan.process(image, strength, seed)
|
||||
if facetool == "codeformer":
|
||||
if self.codeformer is None:
|
||||
print(
|
||||
">> CodeFormer not found. Face restoration is disabled."
|
||||
self.logger.info(
|
||||
"CodeFormer not found. Face restoration is disabled."
|
||||
)
|
||||
else:
|
||||
cf_device = (
|
||||
@ -80,7 +83,7 @@ class RestorationServices:
|
||||
fidelity=codeformer_fidelity,
|
||||
)
|
||||
else:
|
||||
print(">> Face Restoration is disabled.")
|
||||
self.logger.info("Face Restoration is disabled.")
|
||||
if upscale is not None:
|
||||
if self.esrgan is not None:
|
||||
if len(upscale) < 2:
|
||||
@ -93,10 +96,10 @@ class RestorationServices:
|
||||
denoise_str=upscale_denoise_str,
|
||||
)
|
||||
else:
|
||||
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||
self.logger.info("ESRGAN is disabled. Image not upscaled.")
|
||||
except Exception as e:
|
||||
print(
|
||||
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
self.logger.info(
|
||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||
)
|
||||
|
||||
if image_callback is not None:
|
||||
|
Reference in New Issue
Block a user