Merge branch 'main' into feat/compel_node

StAlKeR7779 2023-05-04 00:28:33 +03:00 committed by GitHub
commit 56d3cbead0
329 changed files with 6233 additions and 5270 deletions

View File

@ -1,14 +1,12 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import os
from argparse import Namespace
from invokeai.app.services.metadata import PngMetadataService, MetadataServiceBase
import invokeai.backend.util.logging as logger
from typing import types
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ...backend import Globals
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
@ -19,6 +17,7 @@ from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.metadata import PngMetadataService
from .events import FastAPIEventService
@ -44,15 +43,16 @@ class ApiDependencies:
invoker: Invoker = None
@staticmethod
def initialize(config, event_handler_id: int):
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
Globals.try_patchmatch = config.patchmatch
Globals.always_use_cpu = config.always_use_cpu
Globals.internet_available = config.internet_available and check_internet()
Globals.disable_xformers = not config.xformers
Globals.ckpt_convert = config.ckpt_convert
# TODO: Use a logger
print(f">> Internet connectivity is {Globals.internet_available}")
# TO DO: Use the config to select the logger rather than use the default
# invokeai logging module
logger.info(f"Internet connectivity is {Globals.internet_available}")
events = FastAPIEventService(event_handler_id)
@ -70,8 +70,9 @@ class ApiDependencies:
db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices(
model_manager=get_model_manager(config),
model_manager=get_model_manager(config,logger),
events=events,
logger=logger,
latents=latents,
images=images,
metadata=metadata,
@ -83,7 +84,7 @@ class ApiDependencies:
filename=db_location, table_name="graph_executions"
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config),
restoration=RestorationServices(config,logger),
)
create_system_graphs(services.graph_library)
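
The hunk above threads a logging module through the service constructors instead of calling print(). A minimal sketch of that injection pattern, using the stdlib logging module as a stand-in for invokeai.backend.util.logging (the Services class and names below are illustrative, not the real InvocationServices API):

import types
import logging as default_logger  # stand-in for invokeai.backend.util.logging

class Services:
    def __init__(self, logger: types.ModuleType = default_logger):
        self.logger = logger  # callers use self.logger.info(...) instead of print()

svc = Services()
svc.logger.info("Internet connectivity is %s", True)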

View File

@ -8,10 +8,6 @@ from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from pathlib import Path
from ..dependencies import ApiDependencies
from invokeai.backend.globals import Globals, global_converted_ckpts_dir
from invokeai.backend.args import Args
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -112,19 +108,20 @@ async def update_model(
async def delete_model(model_name: str) -> None:
"""Delete Model"""
model_names = ApiDependencies.invoker.services.model_manager.model_names()
logger = ApiDependencies.invoker.services.logger
model_exists = model_name in model_names
# check if model exists
print(f">> Checking for model {model_name}...")
logger.info(f"Checking for model {model_name}...")
if model_exists:
print(f">> Deleting Model: {model_name}")
logger.info(f"Deleting Model: {model_name}")
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
print(f">> Model Deleted: {model_name}")
logger.info(f"Model Deleted: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
else:
print(f">> Model not found")
logger.error(f"Model not found")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
@ -248,4 +245,4 @@ async def delete_model(model_name: str) -> None:
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:

View File

@ -3,6 +3,7 @@ import asyncio
from inspect import signature
import uvicorn
import invokeai.backend.util.logging as logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
@ -16,7 +17,6 @@ from ..backend import Args
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions, models
from .api.sockets import SocketIO
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
# Create the app
@ -56,7 +56,7 @@ async def startup_event():
config.parse_args()
ApiDependencies.initialize(
config=config, event_handler_id=event_handler_id
config=config, event_handler_id=event_handler_id, logger=logger
)

View File

@ -2,14 +2,15 @@
from abc import ABC, abstractmethod
import argparse
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
from typing import Any, Callable, Iterable, Literal, Union, get_args, get_origin, get_type_hints
from pydantic import BaseModel, Field
import networkx as nx
import matplotlib.pyplot as plt
import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import BaseInvocation
from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
from ..services.graph import GraphExecutionState, LibraryGraph, Edge
from ..services.invoker import Invoker
@ -229,7 +230,7 @@ class HistoryCommand(BaseCommand):
for i in range(min(self.count, len(history))):
entry_id = history[-1 - i]
entry = context.get_session().graph.get_node(entry_id)
print(f"{entry_id}: {get_invocation_command(entry)}")
logger.info(f"{entry_id}: {get_invocation_command(entry)}")
class SetDefaultCommand(BaseCommand):

View File

@ -10,6 +10,7 @@ import shlex
from pathlib import Path
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
import invokeai.backend.util.logging as logger
from ...backend import ModelManager, Globals
from ..invocations.baseinvocation import BaseInvocation
from .commands import BaseCommand
@ -160,8 +161,8 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
pass
except OSError: # file likely corrupted
newname = f"{histfile}.old"
print(
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
logger.error(
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
)
histfile.replace(Path(newname))
atexit.register(readline.write_history_file, histfile)

View File

@ -13,21 +13,20 @@ from typing import (
from pydantic import BaseModel
from pydantic.fields import Field
import invokeai.backend.util.logging as logger
from invokeai.app.services.metadata import PngMetadataService
from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
from .cli.completer import set_autocompleter
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.default_graphs import default_text_to_image_graph_id
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
@ -182,7 +181,7 @@ def invoke_all(context: CliContext):
# Print any errors
if context.session.has_error():
for n in context.session.errors:
print(
context.invoker.services.logger.error(
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
)
@ -192,13 +191,13 @@ def invoke_all(context: CliContext):
def invoke_cli():
config = Args()
config.parse_args()
model_manager = get_model_manager(config)
model_manager = get_model_manager(config,logger=logger)
# This initializes the autocompleter and returns it.
# Currently nothing is done with the returned Completer
# object, but the object can be used to change autocompletion
# behavior on the fly, if desired.
completer = set_autocompleter(model_manager)
set_autocompleter(model_manager)
events = EventServiceBase()
@ -225,7 +224,8 @@ def invoke_cli():
filename=db_location, table_name="graph_executions"
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config),
restoration=RestorationServices(config,logger=logger),
logger=logger,
)
system_graphs = create_system_graphs(services.graph_library)
@ -365,12 +365,12 @@ def invoke_cli():
invoke_all(context)
except InvalidArgs:
print('Invalid command, use "help" to list commands')
invoker.services.logger.warning('Invalid command, use "help" to list commands')
continue
except SessionError:
# Start a new session
print("Session error: creating a new session")
invoker.services.logger.warning("Session error: creating a new session")
context.reset()
except ExitCli:

View File

@ -46,8 +46,8 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
@ -150,6 +150,9 @@ class ImageToImageInvocation(TextToImageInvocation):
)
mask = None
if self.fit:
image = image.resize((self.width, self.height))
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
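
The width/height fields in this hunk relax the constraint from multiples of 64 to multiples of 8. A small hedged sketch of how such a pydantic Field constraint behaves (the Dims model below is illustrative only):

from pydantic import BaseModel, Field, ValidationError

class Dims(BaseModel):
    width: int = Field(default=512, multiple_of=8, gt=0)
    height: int = Field(default=512, multiple_of=8, gt=0)

print(Dims(width=520, height=584))           # accepted: both are multiples of 8
try:
    Dims(width=521)                          # 521 is not a multiple of 8
except ValidationError as err:
    print("rejected:", err.errors()[0]["msg"])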

View File

@ -113,8 +113,8 @@ class NoiseInvocation(BaseInvocation):
# Inputs
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
# Schema customisation
@ -149,8 +149,6 @@ class TextToLatentsInvocation(BaseInvocation):
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
@ -365,9 +363,74 @@ class LatentsToImageInvocation(BaseInvocation):
session_id=context.graph_execution_state_id, node=self
)
torch.cuda.empty_cache()
context.services.images.save(image_type, image_name, image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image
image_type=image_type, image_name=image_name, image=image
)
LATENTS_INTERPOLATION_MODE = Literal[
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
]
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
type: Literal["lresize"] = "lresize"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
resized_latents = torch.nn.functional.interpolate(
latents,
size=(self.height // 8, self.width // 8),
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
return LatentsOutput(latents=LatentsField(latents_name=name))
class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor."""
type: Literal["lscale"] = "lscale"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
# resizing
resized_latents = torch.nn.functional.interpolate(
latents,
scale_factor=self.scale_factor,
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
return LatentsOutput(latents=LatentsField(latents_name=name))
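
Both new nodes delegate to torch.nn.functional.interpolate on the latent tensor, working in latent space at 1/8 of the pixel dimensions. A standalone sketch of that call, with illustrative tensor sizes:

import torch

latents = torch.randn(1, 4, 64, 64)          # latents for a 512x512 image (512 / 8 = 64)
resized = torch.nn.functional.interpolate(
    latents,
    size=(768 // 8, 768 // 8),               # target 768x768 image -> 96x96 latent grid
    mode="bilinear",
    antialias=False,
)
print(resized.shape)                          # torch.Size([1, 4, 96, 96])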

View File

@ -3,12 +3,11 @@ from invokeai.backend.model_management.model_manager import ModelManager
def choose_model(model_manager: ModelManager, model_name: str):
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
logger = model_manager.logger
if model_manager.valid_model(model_name):
model = model_manager.get_model(model_name)
else:
model = model_manager.get_model()
print(
f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
)
logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
return model

View File

@ -27,10 +27,6 @@ def create_text_to_image() -> LibraryGraph:
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
# TODO: remove, when updated TextToLatents merged
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='5', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='5', field='height')),
Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='5', field='seed')),
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='5', field='noise')),
Edge(source=EdgeConnection(node_id='5', field='latents'), destination=EdgeConnection(node_id='6', field='latents')),
Edge(source=EdgeConnection(node_id='4', field='positive'), destination=EdgeConnection(node_id='5', field='positive')),

View File

@ -1,4 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import types
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.backend import ModelManager
@ -29,6 +31,7 @@ class InvocationServices:
self,
model_manager: ModelManager,
events: EventServiceBase,
logger: types.ModuleType,
latents: LatentsStorageBase,
images: ImageStorageBase,
metadata: MetadataServiceBase,
@ -40,6 +43,7 @@ class InvocationServices:
):
self.model_manager = model_manager
self.events = events
self.logger = logger
self.latents = latents
self.images = images
self.metadata = metadata

View File

@ -49,7 +49,7 @@ class Invoker:
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state)
return new_state
def cancel(self, graph_execution_state_id: str) -> None:
"""Cancels the given execution state"""
self.services.queue.cancel(graph_execution_state_id)
@ -71,18 +71,12 @@ class Invoker:
for service in vars(self.services):
self.__start_service(getattr(self.services, service))
for service in vars(self.services):
self.__start_service(getattr(self.services, service))
def stop(self) -> None:
"""Stops the invoker. A new invoker will have to be created to execute further."""
# First stop all services
for service in vars(self.services):
self.__stop_service(getattr(self.services, service))
for service in vars(self.services):
self.__stop_service(getattr(self.services, service))
self.services.queue.put(None)
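
The start and stop methods above now iterate the service container once per lifecycle event. A hedged sketch of that iteration pattern (the _Services class below is illustrative; the real Invoker passes itself to each service's start hook):

class _Services:
    def __init__(self, **services):
        for name, service in services.items():
            setattr(self, name, service)

def start_all(services: _Services) -> None:
    for name in vars(services):               # one pass over every registered service
        service = getattr(services, name)
        start = getattr(service, "start", None)
        if callable(start):
            start()                            # the real hook also receives the invoker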

View File

@ -5,6 +5,7 @@ from argparse import Namespace
from invokeai.backend import Args
from omegaconf import OmegaConf
from pathlib import Path
from typing import types
import invokeai.version
from ...backend import ModelManager
@ -12,16 +13,16 @@ from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals
# TODO: Replace with an abstract class base ModelManagerBase
def get_model_manager(config: Args) -> ModelManager:
def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
if not config.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
if not os.path.exists(config_file):
report_model_error(
config, FileNotFoundError(f"The file {config_file} could not be found.")
config, FileNotFoundError(f"The file {config_file} could not be found."), logger
)
print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
@ -62,11 +63,12 @@ def get_model_manager(config: Args) -> ModelManager:
device_type=device,
max_loaded_models=config.max_loaded_models,
embedding_path = Path(embedding_path),
logger = logger,
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(config, e)
report_model_error(config, e, logger)
except (IOError, KeyError) as e:
print(f"{e}. Aborting.")
logger.error(f"{e}. Aborting.")
sys.exit(-1)
# try to autoconvert new models
@ -76,18 +78,18 @@ def get_model_manager(config: Args) -> ModelManager:
conf_path=config.conf,
weights_directory=path,
)
logger.info('Model manager initialized')
return model_manager
def report_model_error(opt: Namespace, e: Exception):
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
print(
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
logger.error(
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
)
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
if yes_to_all:
print(
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
logger.warning(
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
)
else:
response = input(
@ -96,13 +98,12 @@ def report_model_error(opt: Namespace, e: Exception):
if response.startswith(("n", "N")):
return
print("invokeai-configure is launching....\n")
logger.info("invokeai-configure is launching....\n")
# Match arguments that were set on the CLI
# only the arguments accepted by the configuration script are parsed
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
config = ["--config", opt.conf] if opt.conf is not None else []
previous_config = sys.argv
sys.argv = ["invokeai-configure"]
sys.argv.extend(root_dir)
sys.argv.extend(config.to_dict())

View File

@ -1,5 +1,5 @@
import traceback
from threading import Event, Thread
from threading import Event, Thread, BoundedSemaphore
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
@ -10,8 +10,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
__stop_event: Event
__invoker: Invoker
__threadLimit: BoundedSemaphore
def start(self, invoker) -> None:
# if we do want multithreading at some point, we could make this configurable
self.__threadLimit = BoundedSemaphore(1)
self.__invoker = invoker
self.__stop_event = Event()
self.__invoker_thread = Thread(
@ -20,7 +23,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
kwargs=dict(stop_event=self.__stop_event),
)
self.__invoker_thread.daemon = (
True # TODO: probably better to just not use threads?
True # TODO: make async and do not use threads
)
self.__invoker_thread.start()
@ -29,6 +32,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
def __process(self, stop_event: Event):
try:
self.__threadLimit.acquire()
while not stop_event.is_set():
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
if not queue_item: # Probably stopping
@ -110,7 +114,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
)
pass
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(
graph_execution_state.id
@ -127,4 +131,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
)
except KeyboardInterrupt:
... # Log something?
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
finally:
self.__threadLimit.release()
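
The BoundedSemaphore(1) added above caps the processor at a single worker thread and is released in the finally block. A self-contained sketch of that guard (names are illustrative):

from threading import BoundedSemaphore, Thread

limit = BoundedSemaphore(1)                   # single worker, as in the processor above

def worker(name: str) -> None:
    limit.acquire()
    try:
        print(f"{name} holds the slot")
    finally:
        limit.release()                        # mirrors the finally: release() above

threads = [Thread(target=worker, args=(f"t{i}",), daemon=True) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()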

View File

@ -1,6 +1,7 @@
import sys
import traceback
import torch
from typing import types
from ...backend.restoration import Restoration
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
@ -10,7 +11,7 @@ from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
class RestorationServices:
'''Face restoration and upscaling'''
def __init__(self,args):
def __init__(self,args,logger:types.ModuleType):
try:
gfpgan, codeformer, esrgan = None, None, None
if args.restore or args.esrgan:
@ -20,20 +21,22 @@ class RestorationServices:
args.gfpgan_model_path
)
else:
print(">> Face restoration disabled")
logger.info("Face restoration disabled")
if args.esrgan:
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
else:
print(">> Upscaling disabled")
logger.info("Upscaling disabled")
else:
print(">> Face restoration and upscaling disabled")
logger.info("Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
self.device = torch.device(choose_torch_device())
self.gfpgan = gfpgan
self.codeformer = codeformer
self.esrgan = esrgan
self.logger = logger
self.logger.info('Face restoration initialized')
# note that this one method does gfpgan and codepath reconstruction, as well as
# esrgan upscaling
@ -58,15 +61,15 @@ class RestorationServices:
if self.gfpgan is not None or self.codeformer is not None:
if facetool == "gfpgan":
if self.gfpgan is None:
print(
">> GFPGAN not found. Face restoration is disabled."
self.logger.info(
"GFPGAN not found. Face restoration is disabled."
)
else:
image = self.gfpgan.process(image, strength, seed)
if facetool == "codeformer":
if self.codeformer is None:
print(
">> CodeFormer not found. Face restoration is disabled."
self.logger.info(
"CodeFormer not found. Face restoration is disabled."
)
else:
cf_device = (
@ -80,7 +83,7 @@ class RestorationServices:
fidelity=codeformer_fidelity,
)
else:
print(">> Face Restoration is disabled.")
self.logger.info("Face Restoration is disabled.")
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
@ -93,10 +96,10 @@ class RestorationServices:
denoise_str=upscale_denoise_str,
)
else:
print(">> ESRGAN is disabled. Image not upscaled.")
self.logger.info("ESRGAN is disabled. Image not upscaled.")
except Exception as e:
print(
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
self.logger.info(
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
)
if image_callback is not None:

View File

@ -96,6 +96,7 @@ from pathlib import Path
from typing import List
import invokeai.version
import invokeai.backend.util.logging as logger
from invokeai.backend.image_util import retrieve_metadata
from .globals import Globals
@ -189,7 +190,7 @@ class Args(object):
print(f"{APP_NAME} {APP_VERSION}")
sys.exit(0)
print("* Initializing, be patient...")
logger.info("Initializing, be patient...")
Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root))
Globals.try_patchmatch = switches.patchmatch
@ -197,14 +198,13 @@ class Args(object):
initfile = os.path.expanduser(os.path.join(Globals.root, Globals.initfile))
legacyinit = os.path.expanduser("~/.invokeai")
if os.path.exists(initfile):
print(
f">> Initialization file {initfile} found. Loading...",
file=sys.stderr,
logger.info(
f"Initialization file {initfile} found. Loading...",
)
sysargs.insert(0, f"@{initfile}")
elif os.path.exists(legacyinit):
print(
f">> WARNING: Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
logger.warning(
f"Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
)
sysargs.insert(0, f"@{legacyinit}")
Globals.log_tokenization = self._arg_parser.parse_args(
@ -214,7 +214,7 @@ class Args(object):
self._arg_switches = self._arg_parser.parse_args(sysargs)
return self._arg_switches
except Exception as e:
print(f"An exception has occurred: {e}")
logger.error(f"An exception has occurred: {e}")
return None
def parse_cmd(self, cmd_string):
@ -1154,7 +1154,7 @@ class Args(object):
def format_metadata(**kwargs):
print("format_metadata() is deprecated. Please use metadata_dumps()")
logger.warning("format_metadata() is deprecated. Please use metadata_dumps()")
return metadata_dumps(kwargs)
@ -1326,7 +1326,7 @@ def metadata_loads(metadata) -> list:
import sys
import traceback
print(">> could not read metadata", file=sys.stderr)
logger.error("Could not read metadata")
print(traceback.format_exc(), file=sys.stderr)
return results

View File

@ -27,6 +27,7 @@ from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from pathlib import Path
import invokeai.backend.util.logging as logger
from .args import metadata_from_png
from .generator import infill_methods
from .globals import Globals, global_cache_dir
@ -195,12 +196,12 @@ class Generate:
# device to Generate(). However the device was then ignored, so
# it wasn't actually doing anything. This logic could be reinstated.
self.device = torch.device(choose_torch_device())
print(f">> Using device_type {self.device.type}")
logger.info(f"Using device_type {self.device.type}")
if full_precision:
if self.precision != "auto":
raise ValueError("Remove --full_precision / -F if using --precision")
print("Please remove deprecated --full_precision / -F")
print("If auto config does not work you can use --precision=float32")
logger.warning("Please remove deprecated --full_precision / -F")
logger.warning("If auto config does not work you can use --precision=float32")
self.precision = "float32"
if self.precision == "auto":
self.precision = choose_precision(self.device)
@ -208,13 +209,13 @@ class Generate:
if is_xformers_available():
if torch.cuda.is_available() and not Globals.disable_xformers:
print(">> xformers memory-efficient attention is available and enabled")
logger.info("xformers memory-efficient attention is available and enabled")
else:
print(
">> xformers memory-efficient attention is available but disabled"
logger.info(
"xformers memory-efficient attention is available but disabled"
)
else:
print(">> xformers not installed")
logger.info("xformers not installed")
# model caching system for fast switching
self.model_manager = ModelManager(
@ -229,8 +230,8 @@ class Generate:
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
model = model or fallback
if not self.model_manager.valid_model(model):
print(
f'** "{model}" is not a known model name; falling back to {fallback}.'
logger.warning(
f'"{model}" is not a known model name; falling back to {fallback}.'
)
model = None
self.model_name = model or fallback
@ -246,10 +247,10 @@ class Generate:
# load safety checker if requested
if safety_checker:
print(">> Initializing NSFW checker")
logger.info("Initializing NSFW checker")
self.safety_checker = SafetyChecker(self.device)
else:
print(">> NSFW checker is disabled")
logger.info("NSFW checker is disabled")
def prompt2png(self, prompt, outdir, **kwargs):
"""
@ -567,7 +568,7 @@ class Generate:
self.clear_cuda_cache()
if catch_interrupts:
print("**Interrupted** Partial results will be returned.")
logger.warning("Interrupted** Partial results will be returned.")
else:
raise KeyboardInterrupt
except RuntimeError:
@ -575,11 +576,11 @@ class Generate:
self.clear_cuda_cache()
print(traceback.format_exc(), file=sys.stderr)
print(">> Could not generate image.")
logger.info("Could not generate image.")
toc = time.time()
print("\n>> Usage stats:")
print(f">> {len(results)} image(s) generated in", "%4.2fs" % (toc - tic))
logger.info("Usage stats:")
logger.info(f"{len(results)} image(s) generated in "+"%4.2fs" % (toc - tic))
self.print_cuda_stats()
return results
@ -609,16 +610,16 @@ class Generate:
def print_cuda_stats(self):
if self._has_cuda():
self.gather_cuda_stats()
print(
">> Max VRAM used for this generation:",
"%4.2fG." % (self.max_memory_allocated / 1e9),
"Current VRAM utilization:",
"%4.2fG" % (self.memory_allocated / 1e9),
logger.info(
"Max VRAM used for this generation: "+
"%4.2fG. " % (self.max_memory_allocated / 1e9)+
"Current VRAM utilization: "+
"%4.2fG" % (self.memory_allocated / 1e9)
)
print(
">> Max VRAM used since script start: ",
"%4.2fG" % (self.session_peakmem / 1e9),
logger.info(
"Max VRAM used since script start: " +
"%4.2fG" % (self.session_peakmem / 1e9)
)
# this needs to be generalized to all sorts of postprocessors, which should be wrapped
@ -647,7 +648,7 @@ class Generate:
seed = random.randrange(0, np.iinfo(np.uint32).max)
prompt = opt.prompt or args.prompt or ""
print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}')
logger.info(f'using seed {seed} and prompt "{prompt}" for {image_path}')
# try to reuse the same filename prefix as the original file.
# we take everything up to the first period
@ -696,8 +697,8 @@ class Generate:
try:
extend_instructions[direction] = int(pixels)
except ValueError:
print(
'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
logger.warning(
'invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
)
opt.seed = seed
@ -720,8 +721,8 @@ class Generate:
# fetch the metadata from the image
generator = self.select_generator(embiggen=True)
opt.strength = opt.embiggen_strength or 0.40
print(
f">> Setting img2img strength to {opt.strength} for happy embiggening"
logger.info(
f"Setting img2img strength to {opt.strength} for happy embiggening"
)
generator.generate(
prompt,
@ -748,12 +749,12 @@ class Generate:
return restorer.process(opt, args, image_callback=callback, prefix=prefix)
elif tool is None:
print(
"* please provide at least one postprocessing option, such as -G or -U"
logger.warning(
"please provide at least one postprocessing option, such as -G or -U"
)
return None
else:
print(f"* postprocessing tool {tool} is not yet supported")
logger.warning(f"postprocessing tool {tool} is not yet supported")
return None
def select_generator(
@ -797,8 +798,8 @@ class Generate:
image = self._load_img(img)
if image.width < self.width and image.height < self.height:
print(
f">> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
logger.warning(
f"img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
)
# if image has a transparent area and no mask was provided, then try to generate mask
@ -809,8 +810,8 @@ class Generate:
if (image.width * image.height) > (
self.width * self.height
) and self.size_matters:
print(
">> This input is larger than your defaults. If you run out of memory, please use a smaller image."
logger.info(
"This input is larger than your defaults. If you run out of memory, please use a smaller image."
)
self.size_matters = False
@ -891,11 +892,11 @@ class Generate:
try:
model_data = cache.get_model(model_name)
except Exception as e:
print(f"** model {model_name} could not be loaded: {str(e)}")
logger.warning(f"model {model_name} could not be loaded: {str(e)}")
print(traceback.format_exc(), file=sys.stderr)
if previous_model_name is None:
raise e
print("** trying to reload previous model")
logger.warning("trying to reload previous model")
model_data = cache.get_model(previous_model_name) # load previous
if model_data is None:
raise e
@ -962,15 +963,15 @@ class Generate:
if self.gfpgan is not None or self.codeformer is not None:
if facetool == "gfpgan":
if self.gfpgan is None:
print(
">> GFPGAN not found. Face restoration is disabled."
logger.info(
"GFPGAN not found. Face restoration is disabled."
)
else:
image = self.gfpgan.process(image, strength, seed)
if facetool == "codeformer":
if self.codeformer is None:
print(
">> CodeFormer not found. Face restoration is disabled."
logger.info(
"CodeFormer not found. Face restoration is disabled."
)
else:
cf_device = (
@ -984,7 +985,7 @@ class Generate:
fidelity=codeformer_fidelity,
)
else:
print(">> Face Restoration is disabled.")
logger.info("Face Restoration is disabled.")
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
@ -997,10 +998,10 @@ class Generate:
denoise_str=upscale_denoise_str,
)
else:
print(">> ESRGAN is disabled. Image not upscaled.")
logger.info("ESRGAN is disabled. Image not upscaled.")
except Exception as e:
print(
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
logger.info(
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
)
if image_callback is not None:
@ -1066,17 +1067,17 @@ class Generate:
if self.sampler_name in scheduler_map:
sampler_class = scheduler_map[self.sampler_name]
msg = (
f">> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
)
self.sampler = sampler_class.from_config(self.model.scheduler.config)
else:
msg = (
f">> Unsupported Sampler: {self.sampler_name} "
f" Unsupported Sampler: {self.sampler_name} "+
f"Defaulting to {default}"
)
self.sampler = default
print(msg)
logger.info(msg)
if not hasattr(self.sampler, "uses_inpainting_model"):
# FIXME: terrible kludge!
@ -1085,17 +1086,17 @@ class Generate:
def _load_img(self, img) -> Image:
if isinstance(img, Image.Image):
image = img
print(f">> using provided input image of size {image.width}x{image.height}")
logger.info(f"using provided input image of size {image.width}x{image.height}")
elif isinstance(img, str):
assert os.path.exists(img), f">> {img}: File not found"
assert os.path.exists(img), f"{img}: File not found"
image = Image.open(img)
print(
f">> loaded input image of size {image.width}x{image.height} from {img}"
logger.info(
f"loaded input image of size {image.width}x{image.height} from {img}"
)
else:
image = Image.open(img)
print(f">> loaded input image of size {image.width}x{image.height}")
logger.info(f"loaded input image of size {image.width}x{image.height}")
image = ImageOps.exif_transpose(image)
return image
@ -1183,14 +1184,14 @@ class Generate:
def _transparency_check_and_warning(self, image, mask, force_outpaint=False):
if not mask:
print(
">> Initial image has transparent areas. Will inpaint in these regions."
logger.info(
"Initial image has transparent areas. Will inpaint in these regions."
)
if (not force_outpaint) and self._check_for_erasure(image):
print(
">> WARNING: Colors underneath the transparent region seem to have been erased.\n",
">> Inpainting will be suboptimal. Please preserve the colors when making\n",
">> a transparency mask, or provide mask explicitly using --init_mask (-M).",
if (not force_outpaint) and self._check_for_erasure(image):
logger.info(
"Colors underneath the transparent region seem to have been erased.\n" +
"Inpainting will be suboptimal. Please preserve the colors when making\n" +
"a transparency mask, or provide mask explicitly using --init_mask (-M)."
)
def _squeeze_image(self, image):
@ -1201,11 +1202,11 @@ class Generate:
def _fit_image(self, image, max_dimensions):
w, h = max_dimensions
print(f">> image will be resized to fit inside a box {w}x{h} in size.")
logger.info(f"image will be resized to fit inside a box {w}x{h} in size.")
# note that InitImageResizer does the multiple of 64 truncation internally
image = InitImageResizer(image).resize(width=w, height=h)
print(
f">> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
logger.info(
f"after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
)
return image
@ -1216,8 +1217,8 @@ class Generate:
) # resize to integer multiple of 64
if h != height or w != width:
if log:
print(
f">> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
logger.info(
f"Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
)
height = h
width = w

View File

@ -25,6 +25,7 @@ from typing import Callable, List, Iterator, Optional, Type
from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler
import invokeai.backend.util.logging as logger
from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
@ -372,7 +373,7 @@ class Generator:
try:
x_T = self.get_noise(width, height)
except:
print("** An error occurred while getting initial noise **")
logger.error("An error occurred while getting initial noise")
print(traceback.format_exc())
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
@ -607,7 +608,7 @@ class Generator:
image = self.sample_to_image(sample)
dirname = os.path.dirname(filepath) or "."
if not os.path.exists(dirname):
print(f"** creating directory {dirname}")
logger.info(f"creating directory {dirname}")
os.makedirs(dirname, exist_ok=True)
image.save(filepath, "PNG")

View File

@ -8,10 +8,11 @@ import torch
from PIL import Image
from tqdm import trange
import invokeai.backend.util.logging as logger
from .base import Generator
from .img2img import Img2Img
class Embiggen(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
@ -72,22 +73,22 @@ class Embiggen(Generator):
embiggen = [1.0] # If not specified, assume no scaling
elif embiggen[0] < 0:
embiggen[0] = 1.0
print(
">> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
logger.warning(
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
)
if len(embiggen) < 2:
embiggen.append(0.75)
elif embiggen[1] > 1.0 or embiggen[1] < 0:
embiggen[1] = 0.75
print(
">> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
logger.warning(
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
)
if len(embiggen) < 3:
embiggen.append(0.25)
elif embiggen[2] < 0:
embiggen[2] = 0.25
print(
">> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
logger.warning(
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
)
# Convert tiles from their user-friendly count-from-one to count-from-zero, because we need to do modulo math
@ -97,8 +98,8 @@ class Embiggen(Generator):
embiggen_tiles.sort()
if strength >= 0.5:
print(
f"* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
logger.warning(
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
)
# Prep img2img generator, since we wrap over it
@ -121,8 +122,8 @@ class Embiggen(Generator):
from ..restoration.realesrgan import ESRGAN
esrgan = ESRGAN()
print(
f">> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
logger.info(
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
)
if embiggen[0] > 2:
initsuperimage = esrgan.process(
@ -312,10 +313,10 @@ class Embiggen(Generator):
def make_image():
# Make main tiles -------------------------------------------------
if embiggen_tiles:
print(f">> Making {len(embiggen_tiles)} Embiggen tiles...")
logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
else:
print(
f">> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
logger.info(
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
)
emb_tile_store = []
@ -361,11 +362,11 @@ class Embiggen(Generator):
# newinitimage.save(newinitimagepath)
if embiggen_tiles:
print(
logger.debug(
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
)
else:
print(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
logger.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
# create a torch tensor from an Image
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
@ -547,8 +548,8 @@ class Embiggen(Generator):
# Layer tile onto final image
outputsuperimage.alpha_composite(intileimage, (left, top))
else:
print(
"Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
logger.error(
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
)
# after internal loops and patching up return Embiggen image

View File

@ -14,6 +14,8 @@ from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeli
from ..stable_diffusion.diffusers_pipeline import ConditioningData
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
import invokeai.backend.util.logging as logger
class Txt2Img2Img(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
@ -77,8 +79,8 @@ class Txt2Img2Img(Generator):
# the message below is accurate.
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
print(
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
logger.info(
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
)
# resizing

View File

@ -5,10 +5,9 @@ wraps the actual patchmatch object. It respects the global
be suppressed or deferred
"""
import numpy as np
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
class PatchMatch:
"""
Thin class wrapper around the patchmatch function.
@ -28,12 +27,12 @@ class PatchMatch:
from patchmatch import patch_match as pm
if pm.patchmatch_available:
print(">> Patchmatch initialized")
logger.info("Patchmatch initialized")
else:
print(">> Patchmatch not loaded (nonfatal)")
logger.info("Patchmatch not loaded (nonfatal)")
self.patch_match = pm
else:
print(">> Patchmatch loading disabled")
logger.info("Patchmatch loading disabled")
self.tried_load = True
@classmethod

View File

@ -30,9 +30,9 @@ work fine.
import numpy as np
import torch
from PIL import Image, ImageOps
from torchvision import transforms
from transformers import AutoProcessor, CLIPSegForImageSegmentation
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import global_cache_dir
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
@ -83,7 +83,7 @@ class Txt2Mask(object):
"""
def __init__(self, device="cpu", refined=False):
print(">> Initializing clipseg model for text to mask inference")
logger.info("Initializing clipseg model for text to mask inference")
# BUG: we are not doing anything with the device option at this time
self.device = device
@ -101,18 +101,6 @@ class Txt2Mask(object):
provided image and returns a SegmentedGrayscale object in which the brighter
pixels indicate where the object is inferred to be.
"""
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
transforms.Resize(
(CLIPSEG_SIZE, CLIPSEG_SIZE)
), # must be multiple of 64...
]
)
if type(image) is str:
image = Image.open(image).convert("RGB")

View File

@ -25,6 +25,7 @@ from typing import Union
import torch
from safetensors.torch import load_file
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import global_cache_dir, global_config_dir
from .model_manager import ModelManager, SDLegacyType
@ -372,9 +373,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
unet_key = "model.diffusion_model."
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100:
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
logger.debug(f"Checkpoint {path} has both EMA and non-EMA weights.")
if extract_ema:
print(" | Extracting EMA weights (usually better for inference)")
logger.debug("Extracting EMA weights (usually better for inference)")
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
@ -392,8 +393,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
key
)
else:
print(
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
logger.debug(
"Extracting only the non-EMA weights (usually better for fine-tuning)"
)
for key in keys:
@ -1115,7 +1116,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if "global_step" in checkpoint:
global_step = checkpoint["global_step"]
else:
print(" | global_step key not found in model")
logger.debug("global_step key not found in model")
global_step = None
# sometimes there is a state_dict key and sometimes not
@ -1229,15 +1230,15 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
# If a replacement VAE path was specified, we'll incorporate that into
# the checkpoint model and then convert it
if vae_path:
print(f" | Converting VAE {vae_path}")
logger.debug(f"Converting VAE {vae_path}")
replace_checkpoint_vae(checkpoint,vae_path)
# otherwise we use the original VAE, provided that
# an externally loaded diffusers VAE was not passed
elif not vae:
print(" | Using checkpoint model's original VAE")
logger.debug("Using checkpoint model's original VAE")
if vae:
print(" | Using replacement diffusers VAE")
logger.debug("Using replacement diffusers VAE")
else: # convert the original or replacement VAE
vae_config = create_vae_diffusers_config(
original_config, image_size=image_size

View File

@ -18,12 +18,13 @@ import warnings
from enum import Enum, auto
from pathlib import Path
from shutil import move, rmtree
from typing import Any, Optional, Union, Callable
from typing import Any, Optional, Union, Callable, types
import safetensors
import safetensors.torch
import torch
import transformers
import invokeai.backend.util.logging as logger
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
@ -75,6 +76,8 @@ class ModelManager(object):
Model manager handles loading, caching, importing, deleting, converting, and editing models.
"""
logger: types.ModuleType = logger
def __init__(
self,
config: OmegaConf | Path,
@ -83,6 +86,7 @@ class ModelManager(object):
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload=False,
embedding_path: Path = None,
logger: types.ModuleType = logger,
):
"""
Initialize with the path to the models.yaml config file or
@ -104,6 +108,7 @@ class ModelManager(object):
self.current_model = None
self.sequential_offload = sequential_offload
self.embedding_path = embedding_path
self.logger = logger
def valid_model(self, model_name: str) -> bool:
"""
@ -132,8 +137,8 @@ class ModelManager(object):
)
if not self.valid_model(model_name):
print(
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
self.logger.error(
f'"{model_name}" is not a known model name. Please check your models.yaml file'
)
return self.current_model
@ -144,7 +149,7 @@ class ModelManager(object):
if model_name in self.models:
requested_model = self.models[model_name]["model"]
print(f">> Retrieving model {model_name} from system RAM cache")
self.logger.info(f"Retrieving model {model_name} from system RAM cache")
requested_model.ready()
width = self.models[model_name]["width"]
height = self.models[model_name]["height"]
@ -379,7 +384,7 @@ class ModelManager(object):
"""
omega = self.config
if model_name not in omega:
print(f"** Unknown model {model_name}")
self.logger.error(f"Unknown model {model_name}")
return
# save these for use in deletion later
conf = omega[model_name]
@ -392,13 +397,13 @@ class ModelManager(object):
self.stack.remove(model_name)
if delete_files:
if weights:
print(f"** Deleting file {weights}")
self.logger.info(f"Deleting file {weights}")
Path(weights).unlink(missing_ok=True)
elif path:
print(f"** Deleting directory {path}")
self.logger.info(f"Deleting directory {path}")
rmtree(path, ignore_errors=True)
elif repo_id:
print(f"** Deleting the cached model directory for {repo_id}")
self.logger.info(f"Deleting the cached model directory for {repo_id}")
self._delete_model_from_cache(repo_id)
def add_model(
@ -439,7 +444,7 @@ class ModelManager(object):
def _load_model(self, model_name: str):
"""Load and initialize the model from configuration variables passed at object creation time"""
if model_name not in self.config:
print(
self.logger.error(
f'"{model_name}" is not a known model name. Please check your models.yaml file'
)
return
@ -457,7 +462,7 @@ class ModelManager(object):
model_format = mconfig.get("format", "ckpt")
if model_format == "ckpt":
weights = mconfig.weights
print(f">> Loading {model_name} from {weights}")
self.logger.info(f"Loading {model_name} from {weights}")
model, width, height, model_hash = self._load_ckpt_model(
model_name, mconfig
)
@ -473,13 +478,15 @@ class ModelManager(object):
# usage statistics
toc = time.time()
print(">> Model loaded in", "%4.2fs" % (toc - tic))
self.logger.info("Model loaded in " + "%4.2fs" % (toc - tic))
if self._has_cuda():
print(
">> Max VRAM used to load the model:",
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9),
"\n>> Current VRAM usage:"
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
self.logger.info(
"Max VRAM used to load the model: "+
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)
)
self.logger.info(
"Current VRAM usage: "+
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
)
return model, width, height, model_hash
@ -487,11 +494,11 @@ class ModelManager(object):
name_or_path = self.model_name_or_path(mconfig)
using_fp16 = self.precision == "float16"
print(f">> Loading diffusers model from {name_or_path}")
self.logger.info(f"Loading diffusers model from {name_or_path}")
if using_fp16:
print(" | Using faster float16 precision")
self.logger.debug("Using faster float16 precision")
else:
print(" | Using more accurate float32 precision")
self.logger.debug("Using more accurate float32 precision")
# TODO: scan weights maybe?
pipeline_args: dict[str, Any] = dict(
@ -523,8 +530,8 @@ class ModelManager(object):
if str(e).startswith("fp16 is not a valid"):
pass
else:
print(
f"** An unexpected error occurred while downloading the model: {e})"
self.logger.error(
f"An unexpected error occurred while downloading the model: {e})"
)
if pipeline:
break
@ -542,7 +549,7 @@ class ModelManager(object):
# square images???
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
height = width
print(f" | Default image dimensions = {width} x {height}")
self.logger.debug(f"Default image dimensions = {width} x {height}")
return pipeline, width, height, model_hash
@ -559,14 +566,14 @@ class ModelManager(object):
weights = os.path.normpath(os.path.join(Globals.root, weights))
# Convert to diffusers and return a diffusers pipeline
print(f">> Converting legacy checkpoint {model_name} into a diffusers model...")
self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
from . import load_pipeline_from_original_stable_diffusion_ckpt
try:
if self.list_models()[self.current_model]["status"] == "active":
self.offload_model(self.current_model)
except Exception as e:
except Exception:
pass
vae_path = None
@ -624,7 +631,7 @@ class ModelManager(object):
if model_name not in self.models:
return
print(f">> Offloading {model_name} to CPU")
self.logger.info(f"Offloading {model_name} to CPU")
model = self.models[model_name]["model"]
model.offload_all()
self.current_model = None
@ -640,30 +647,26 @@ class ModelManager(object):
and option to exit if an infected file is identified.
"""
# scan model
print(f" | Scanning Model: {model_name}")
self.logger.debug(f"Scanning Model: {model_name}")
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if scan_result.infected_files == 1:
print(f"\n### Issues Found In Model: {scan_result.issues_count}")
print(
"### WARNING: The model you are trying to load seems to be infected."
)
print("### For your safety, InvokeAI will not load this model.")
print("### Please use checkpoints from trusted sources.")
print("### Exiting InvokeAI")
self.logger.critical(f"Issues Found In Model: {scan_result.issues_count}")
self.logger.critical("The model you are trying to load seems to be infected.")
self.logger.critical("For your safety, InvokeAI will not load this model.")
self.logger.critical("Please use checkpoints from trusted sources.")
self.logger.critical("Exiting InvokeAI")
sys.exit()
else:
print(
"\n### WARNING: InvokeAI was unable to scan the model you are using."
)
self.logger.warning("InvokeAI was unable to scan the model you are using.")
model_safe_check_fail = ask_user(
"Do you want to to continue loading the model?", ["y", "n"]
)
if model_safe_check_fail.lower() != "y":
print("### Exiting InvokeAI")
self.logger.critical("Exiting InvokeAI")
sys.exit()
else:
print(" | Model scanned ok")
self.logger.debug("Model scanned ok")
def import_diffuser_model(
self,
@ -780,26 +783,24 @@ class ModelManager(object):
model_path: Path = None
thing = path_url_or_repo # to save typing
print(f">> Probing {thing} for import")
self.logger.info(f"Probing {thing} for import")
if thing.startswith(("http:", "https:", "ftp:")):
print(f" | {thing} appears to be a URL")
self.logger.info(f"{thing} appears to be a URL")
model_path = self._resolve_path(
thing, "models/ldm/stable-diffusion-v1"
) # _resolve_path does a download if needed
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
print(
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
)
self.logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
return
else:
print(f" | {thing} appears to be a checkpoint file on disk")
self.logger.debug(f"{thing} appears to be a checkpoint file on disk")
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
print(f" | {thing} appears to be a diffusers file on disk")
self.logger.debug(f"{thing} appears to be a diffusers file on disk")
model_name = self.import_diffuser_model(
thing,
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
@ -810,34 +811,30 @@ class ModelManager(object):
elif Path(thing).is_dir():
if (Path(thing) / "model_index.json").exists():
print(f" | {thing} appears to be a diffusers model.")
self.logger.debug(f"{thing} appears to be a diffusers model.")
model_name = self.import_diffuser_model(
thing, commit_to_conf=commit_to_conf
)
else:
print(
f" |{thing} appears to be a directory. Will scan for models to import"
)
self.logger.debug(f"{thing} appears to be a directory. Will scan for models to import")
for m in list(Path(thing).rglob("*.ckpt")) + list(
Path(thing).rglob("*.safetensors")
):
if model_name := self.heuristic_import(
str(m), commit_to_conf=commit_to_conf
):
print(f" >> {model_name} successfully imported")
self.logger.info(f"{model_name} successfully imported")
return model_name
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
self.logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
model_name = self.import_diffuser_model(
thing, commit_to_conf=commit_to_conf
)
pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
return model_name
else:
print(
f"** {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id"
)
self.logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
# Model_path is set in the event of a legacy checkpoint file.
# If not set, we're all done
@ -845,7 +842,7 @@ class ModelManager(object):
return
if model_path.stem in self.config: # already imported
print(" | Already imported. Skipping")
self.logger.debug("Already imported. Skipping")
return model_path.stem
# another round of heuristics to guess the correct config file.
@ -861,39 +858,39 @@ class ModelManager(object):
# look for a like-named .yaml file in same directory
if model_path.with_suffix(".yaml").exists():
model_config_file = model_path.with_suffix(".yaml")
print(f" | Using config file {model_config_file.name}")
self.logger.debug(f"Using config file {model_config_file.name}")
else:
model_type = self.probe_model_type(checkpoint)
if model_type == SDLegacyType.V1:
print(" | SD-v1 model detected")
self.logger.debug("SD-v1 model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
)
elif model_type == SDLegacyType.V1_INPAINT:
print(" | SD-v1 inpainting model detected")
self.logger.debug("SD-v1 inpainting model detected")
model_config_file = Path(
Globals.root,
"configs/stable-diffusion/v1-inpainting-inference.yaml",
)
elif model_type == SDLegacyType.V2_v:
print(" | SD-v2-v model detected")
self.logger.debug("SD-v2-v model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
)
elif model_type == SDLegacyType.V2_e:
print(" | SD-v2-e model detected")
self.logger.debug("SD-v2-e model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
)
elif model_type == SDLegacyType.V2:
print(
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
self.logger.warning(
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
)
return
else:
print(
f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
self.logger.warning(
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
)
return
@ -909,7 +906,7 @@ class ModelManager(object):
for suffix in ["pt", "ckpt", "safetensors"]:
if (model_path.with_suffix(f".vae.{suffix}")).exists():
vae_path = model_path.with_suffix(f".vae.{suffix}")
print(f" | Using VAE file {vae_path.name}")
self.logger.debug(f"Using VAE file {vae_path.name}")
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
diffuser_path = Path(
@ -955,14 +952,14 @@ class ModelManager(object):
from . import convert_ckpt_to_diffusers
if diffusers_path.exists():
print(
f"ERROR: The path {str(diffusers_path)} already exists. Please move or remove it and try again."
self.logger.error(
f"The path {str(diffusers_path)} already exists. Please move or remove it and try again."
)
return
model_name = model_name or diffusers_path.name
model_description = model_description or f"Converted version of {model_name}"
print(f" | Converting {model_name} to diffusers (30-60s)")
self.logger.debug(f"Converting {model_name} to diffusers (30-60s)")
try:
# By passing the specified VAE to the conversion function, the autoencoder
# will be built into the model rather than tacked on afterward via the config file
@ -979,10 +976,10 @@ class ModelManager(object):
vae_path=vae_path,
scan_needed=scan_needed,
)
print(
f" | Success. Converted model is now located at {str(diffusers_path)}"
self.logger.debug(
f"Success. Converted model is now located at {str(diffusers_path)}"
)
print(f" | Writing new config file entry for {model_name}")
self.logger.debug(f"Writing new config file entry for {model_name}")
new_config = dict(
path=str(diffusers_path),
description=model_description,
@ -993,17 +990,17 @@ class ModelManager(object):
self.add_model(model_name, new_config, True)
if commit_to_conf:
self.commit(commit_to_conf)
print(" | Conversion succeeded")
self.logger.debug("Conversion succeeded")
except Exception as e:
print(f"** Conversion failed: {str(e)}")
print(
"** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
self.logger.warning(f"Conversion failed: {str(e)}")
self.logger.warning(
"If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
)
return model_name
def search_models(self, search_folder):
print(f">> Finding Models In: {search_folder}")
self.logger.info(f"Finding Models In: {search_folder}")
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
@ -1027,8 +1024,8 @@ class ModelManager(object):
num_loaded_models = len(self.models)
if num_loaded_models >= self.max_loaded_models:
least_recent_model = self._pop_oldest_model()
print(
f">> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
self.logger.info(
f"Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
)
if least_recent_model is not None:
del self.models[least_recent_model]
@ -1036,8 +1033,8 @@ class ModelManager(object):
def print_vram_usage(self) -> None:
if self._has_cuda:
print(
">> Current VRAM usage: ",
self.logger.info(
"Current VRAM usage:"+
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
)
@ -1126,10 +1123,10 @@ class ModelManager(object):
dest = hub / model.stem
if dest.exists() and not source.exists():
continue
print(f"** {source} => {dest}")
cls.logger.info(f"{source} => {dest}")
if source.exists():
if dest.is_symlink():
print(f"** Found symlink at {dest.name}. Not migrating.")
logger.warning(f"Found symlink at {dest.name}. Not migrating.")
elif dest.exists():
if source.is_dir():
rmtree(source)
@ -1146,7 +1143,7 @@ class ModelManager(object):
]
for d in empty:
os.rmdir(d)
print("** Migration is done. Continuing...")
cls.logger.info("Migration is done. Continuing...")
def _resolve_path(
self, source: Union[str, Path], dest_directory: str
@ -1189,15 +1186,15 @@ class ModelManager(object):
def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
if self.embedding_path is not None:
print(f">> Loading embeddings from {self.embedding_path}")
self.logger.info(f"Loading embeddings from {self.embedding_path}")
for root, _, files in os.walk(self.embedding_path):
for name in files:
ti_path = os.path.join(root, name)
model.textual_inversion_manager.load_textual_inversion(
ti_path, defer_injecting_tokens=True
)
print(
f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
self.logger.info(
f'Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
)
def _has_cuda(self) -> bool:
@ -1219,7 +1216,7 @@ class ModelManager(object):
with open(hashpath) as f:
hash = f.read()
return hash
print(" | Calculating sha256 hash of model files")
self.logger.debug("Calculating sha256 hash of model files")
tic = time.time()
sha = hashlib.sha256()
count = 0
@ -1231,7 +1228,7 @@ class ModelManager(object):
sha.update(chunk)
hash = sha.hexdigest()
toc = time.time()
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
with open(hashpath, "w") as f:
f.write(hash)
return hash
@ -1249,13 +1246,13 @@ class ModelManager(object):
hash = f.read()
return hash
print(" | Calculating sha256 hash of weights file")
self.logger.debug("Calculating sha256 hash of weights file")
tic = time.time()
sha = hashlib.sha256()
sha.update(data)
hash = sha.hexdigest()
toc = time.time()
print(f">> sha256 = {hash}", "(%4.2fs)" % (toc - tic))
self.logger.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))
with open(hashpath, "w") as f:
f.write(hash)
@ -1276,12 +1273,12 @@ class ModelManager(object):
local_files_only=not Globals.internet_available,
)
print(f" | Loading diffusers VAE from {name_or_path}")
self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
if using_fp16:
vae_args.update(torch_dtype=torch.float16)
fp_args_list = [{"revision": "fp16"}, {}]
else:
print(" | Using more accurate float32 precision")
self.logger.debug("Using more accurate float32 precision")
fp_args_list = [{}]
vae = None
@ -1305,12 +1302,12 @@ class ModelManager(object):
break
if not vae and deferred_error:
print(f"** Could not load VAE {name_or_path}: {str(deferred_error)}")
self.logger.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")
return vae
@staticmethod
def _delete_model_from_cache(repo_id):
@classmethod
def _delete_model_from_cache(cls,repo_id):
cache_info = scan_cache_dir(global_cache_dir("hub"))
# I'm sure there is a way to do this with comprehensions
@ -1321,8 +1318,8 @@ class ModelManager(object):
for revision in repo.revisions:
hashes_to_delete.add(revision.commit_hash)
strategy = cache_info.delete_revisions(*hashes_to_delete)
print(
f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
logger.warning(
f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
)
strategy.execute()
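A minimal sketch of the huggingface_hub cache-deletion pattern used by _delete_model_from_cache, assuming the default Hugging Face cache location instead of InvokeAI's global_cache_dir; the repo id and function name are examples.

# Hedged sketch: free the disk space held by a repo's cached revisions.
from huggingface_hub import scan_cache_dir
import invokeai.backend.util.logging as logger

def delete_repo_from_cache(repo_id: str = "stabilityai/sd-vae-ft-mse") -> None:
    cache_info = scan_cache_dir()  # InvokeAI points this at its own hub directory
    hashes_to_delete = {
        revision.commit_hash
        for repo in cache_info.repos
        if repo.repo_id == repo_id
        for revision in repo.revisions
    }
    strategy = cache_info.delete_revisions(*hashes_to_delete)
    logger.warning(f"Deletion of this model is expected to free {strategy.expected_freed_size_str}")
    strategy.execute()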

View File

@ -18,6 +18,7 @@ from compel.prompt_parser import (
PromptParser,
)
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from ..stable_diffusion import InvokeAIDiffuserComponent
@ -162,8 +163,8 @@ def log_tokenization(
negative_prompt: Union[Blend, FlattenedPrompt],
tokenizer,
):
print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
logger.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
logger.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
log_tokenization_for_prompt_object(
@ -237,12 +238,12 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
usedTokens += 1
if usedTokens > 0:
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
print(f"{tokenized}\x1b[0m")
logger.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
logger.debug(f"{tokenized}\x1b[0m")
if discarded != "":
print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
print(f"{discarded}\x1b[0m")
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
logger.debug(f"{discarded}\x1b[0m")
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
@ -295,8 +296,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0:
print(
"* Warning: Subprompt weights add up to zero. Discarding and using even weights instead."
logger.warning(
"Subprompt weights add up to zero. Discarding and using even weights instead."
)
equal_weight = 1 / max(len(parsed_prompts), 1)
return [(x[0], equal_weight) for x in parsed_prompts]

View File

@ -1,3 +1,5 @@
import invokeai.backend.util.logging as logger
class Restoration:
def __init__(self) -> None:
pass
@ -8,17 +10,17 @@ class Restoration:
# Load GFPGAN
gfpgan = self.load_gfpgan(gfpgan_model_path)
if gfpgan.gfpgan_model_exists:
print(">> GFPGAN Initialized")
logger.info("GFPGAN Initialized")
else:
print(">> GFPGAN Disabled")
logger.info("GFPGAN Disabled")
gfpgan = None
# Load CodeFormer
codeformer = self.load_codeformer()
if codeformer.codeformer_model_exists:
print(">> CodeFormer Initialized")
logger.info("CodeFormer Initialized")
else:
print(">> CodeFormer Disabled")
logger.info("CodeFormer Disabled")
codeformer = None
return gfpgan, codeformer
@ -39,5 +41,5 @@ class Restoration:
from .realesrgan import ESRGAN
esrgan = ESRGAN(esrgan_bg_tile)
print(">> ESRGAN Initialized")
logger.info("ESRGAN Initialized")
return esrgan

View File

@ -5,6 +5,7 @@ import warnings
import numpy as np
import torch
import invokeai.backend.util.logging as logger
from ..globals import Globals
pretrained_model_url = (
@ -23,12 +24,12 @@ class CodeFormerRestoration:
self.codeformer_model_exists = os.path.isfile(self.model_path)
if not self.codeformer_model_exists:
print("## NOT FOUND: CodeFormer model not found at " + self.model_path)
logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
sys.path.append(os.path.abspath(codeformer_dir))
def process(self, image, strength, device, seed=None, fidelity=0.75):
if seed is not None:
print(f">> CodeFormer - Restoring Faces for image seed:{seed}")
logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
@ -97,7 +98,7 @@ class CodeFormerRestoration:
del output
torch.cuda.empty_cache()
except RuntimeError as error:
print(f"\tFailed inference for CodeFormer: {error}.")
logger.error(f"Failed inference for CodeFormer: {error}.")
restored_face = cropped_face
restored_face = restored_face.astype("uint8")

View File

@ -6,9 +6,9 @@ import numpy as np
import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
class GFPGAN:
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
if not os.path.isabs(gfpgan_model_path):
@ -19,7 +19,7 @@ class GFPGAN:
self.gfpgan_model_exists = os.path.isfile(self.model_path)
if not self.gfpgan_model_exists:
print("## NOT FOUND: GFPGAN model not found at " + self.model_path)
logger.error("NOT FOUND: GFPGAN model not found at " + self.model_path)
return None
def model_exists(self):
@ -27,7 +27,7 @@ class GFPGAN:
def process(self, image, strength: float, seed: str = None):
if seed is not None:
print(f">> GFPGAN - Restoring Faces for image seed:{seed}")
logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
@ -47,14 +47,14 @@ class GFPGAN:
except Exception:
import traceback
print(">> Error loading GFPGAN:", file=sys.stderr)
logger.error("Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
os.chdir(cwd)
if self.gfpgan is None:
print(f">> WARNING: GFPGAN not initialized.")
print(
f">> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
logger.warning("WARNING: GFPGAN not initialized.")
logger.warning(
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
)
image = image.convert("RGB")

View File

@ -1,7 +1,7 @@
import math
from PIL import Image
import invokeai.backend.util.logging as logger
class Outcrop(object):
def __init__(
@ -82,7 +82,7 @@ class Outcrop(object):
pixels = extents[direction]
# round pixels up to the nearest 64
pixels = math.ceil(pixels / 64) * 64
print(f">> extending image {direction}ward by {pixels} pixels")
logger.info(f"extending image {direction}ward by {pixels} pixels")
image = self._rotate(image, direction)
image = self._extend(image, pixels)
image = self._rotate(image, direction, reverse=True)

View File

@ -6,18 +6,13 @@ import torch
from PIL import Image
from PIL.Image import Image as ImageType
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
class ESRGAN:
def __init__(self, bg_tile_size=400) -> None:
self.bg_tile_size = bg_tile_size
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
def load_esrgan_bg_upsampler(self, denoise_str):
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
@ -74,16 +69,16 @@ class ESRGAN:
import sys
import traceback
print(">> Error loading Real-ESRGAN:", file=sys.stderr)
logger.error("Error loading Real-ESRGAN:")
print(traceback.format_exc(), file=sys.stderr)
if upsampler_scale == 0:
print(">> Real-ESRGAN: Invalid scaling option. Image not upscaled.")
logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
return image
if seed is not None:
print(
f">> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
logger.info(
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
)
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
image = image.convert("RGB")

View File

@ -14,6 +14,7 @@ from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets
import invokeai.backend.util.logging as logger
from .globals import global_cache_dir
from .util import CPU_DEVICE
@ -40,8 +41,8 @@ class SafetyChecker(object):
cache_dir=safety_model_path,
)
except Exception:
print(
"** An error was encountered while installing the safety checker:"
logger.error(
"An error was encountered while installing the safety checker:"
)
print(traceback.format_exc())
@ -65,8 +66,8 @@ class SafetyChecker(object):
)
self.safety_checker.to(CPU_DEVICE) # offload
if has_nsfw_concept[0]:
print(
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
logger.warning(
"An image with potential non-safe content has been detected. A blurred image will be returned."
)
return self.blur(image)
else:

View File

@ -17,6 +17,7 @@ from huggingface_hub import (
hf_hub_url,
)
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
@ -66,11 +67,11 @@ class HuggingFaceConceptsLibrary(object):
# when init, add all in dir. when not init, add only concepts added between init and now
self.concept_list.extend(list(local_concepts_to_add))
except Exception as e:
print(
f" ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
logger.warning(
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
)
print(
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
logger.warning(
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
)
return self.concept_list
else:
@ -83,7 +84,7 @@ class HuggingFaceConceptsLibrary(object):
be downloaded.
"""
if not concept_name in self.list_concepts():
print(
logger.warning(
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
)
return None
@ -221,7 +222,7 @@ class HuggingFaceConceptsLibrary(object):
if chunk == 0:
bytes += total
print(f">> Downloading {repo_id}...", end="")
logger.info(f"Downloading {repo_id}...", end="")
try:
for file in (
"README.md",
@ -235,22 +236,22 @@ class HuggingFaceConceptsLibrary(object):
)
except ul_error.HTTPError as e:
if e.code == 404:
print(
logger.warning(
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
)
else:
print(
logger.warning(
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
)
os.rmdir(dest)
return False
except ul_error.URLError as e:
print(
f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
logger.error(
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
)
os.rmdir(dest)
return False
print("...{:.2f}Kb".format(bytes / 1024))
logger.info("...{:.2f}Kb".format(bytes / 1024))
return succeeded
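To make the download loop above easier to follow in isolation, here is a hedged sketch of fetching a single concept file; hf_hub_url and the HTTPError handling mirror the code above, while the function name and directory layout are illustrative.

# Hedged sketch: download one file of a Hugging Face concept with urllib.
import os
from urllib import request, error as ul_error
from huggingface_hub import hf_hub_url
import invokeai.backend.util.logging as logger

def fetch_concept_file(repo_id: str, filename: str, dest_dir: str) -> bool:
    os.makedirs(dest_dir, exist_ok=True)
    url = hf_hub_url(repo_id, filename)
    try:
        request.urlretrieve(url, os.path.join(dest_dir, filename))
        return True
    except ul_error.HTTPError as e:
        if e.code == 404:
            logger.warning(f"Concept {repo_id} is not known to the Hugging Face library.")
        else:
            logger.warning(f"Failed to download {repo_id}/{filename} ({str(e)}).")
        return False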
def _concept_id(self, concept_name: str) -> str:

View File

@ -13,9 +13,9 @@ from compel.cross_attention_control import Arguments
from diffusers.models.attention_processor import AttentionProcessor
from torch import nn
import invokeai.backend.util.logging as logger
from ...util import torch_dtype
class CrossAttentionType(enum.Enum):
SELF = 1
TOKENS = 2
@ -421,7 +421,7 @@ def get_cross_attention_modules(
expected_count = 16
if cross_attention_modules_in_model_count != expected_count:
# non-fatal error but .swap() won't work.
print(
logger.error(
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "

View File

@ -8,6 +8,7 @@ import torch
from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from .cross_attention_control import (
@ -466,10 +467,14 @@ class InvokeAIDiffuserComponent:
outside = torch.count_nonzero(
(latents < -current_threshold) | (latents > current_threshold)
)
print(
f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
f" | {outside / latents.numel() * 100:.2f}% values outside threshold"
logger.info(
f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
)
logger.debug(
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
)
logger.debug(
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
)
if maxval < current_threshold and minval > -current_threshold:
@ -496,9 +501,11 @@ class InvokeAIDiffuserComponent:
)
if self.debug_thresholding:
print(
f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
f" | {num_altered / latents.numel() * 100:.2f}% values altered"
logger.debug(
f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
)
logger.debug(
f"{num_altered / latents.numel() * 100:.2f}% values altered"
)
return latents
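For clarity, a small hedged sketch of the statistics that the new logger.debug lines report; it assumes a latents tensor and a scalar threshold and uses only the torch ops that appear above.

# Hedged sketch: compute the values logged by the thresholding debug output.
import torch
import invokeai.backend.util.logging as logger

def log_threshold_stats(latents: torch.Tensor, current_threshold: float) -> None:
    minval, maxval = latents.min().item(), latents.max().item()
    mean, std = latents.mean().item(), latents.std().item()
    outside = torch.count_nonzero(
        (latents < -current_threshold) | (latents > current_threshold)
    ).item()
    logger.debug(f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std:.3f}")
    logger.debug(f"{outside / latents.numel() * 100:.2f}% values outside threshold")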

View File

@ -10,7 +10,7 @@ from torchvision.utils import make_grid
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
import invokeai.backend.util.logging as logger
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
@ -191,7 +191,7 @@ def mkdirs(paths):
def mkdir_and_rename(path):
if os.path.exists(path):
new_name = path + "_archived_" + get_timestamp()
print("Path already exists. Rename it to [{:s}]".format(new_name))
logger.error("Path already exists. Rename it to [{:s}]".format(new_name))
os.replace(path, new_name)
os.makedirs(path)

View File

@ -10,6 +10,7 @@ from compel.embeddings_provider import BaseTextualInversionManager
from picklescan.scanner import scan_file_path
from transformers import CLIPTextModel, CLIPTokenizer
import invokeai.backend.util.logging as logger
from .concepts_lib import HuggingFaceConceptsLibrary
@dataclass
@ -59,12 +60,12 @@ class TextualInversionManager(BaseTextualInversionManager):
or self.has_textual_inversion_for_trigger_string(concept_name)
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
): # in case a token with literal angle brackets encountered
print(f">> Loaded local embedding for trigger {concept_name}")
logger.info(f"Loaded local embedding for trigger {concept_name}")
continue
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
if not bin_file:
continue
print(f">> Loaded remote embedding for trigger {concept_name}")
logger.info(f"Loaded remote embedding for trigger {concept_name}")
self.load_textual_inversion(bin_file)
self.hf_concepts_library.concepts_loaded[concept_name] = True
@ -85,8 +86,8 @@ class TextualInversionManager(BaseTextualInversionManager):
embedding_list = self._parse_embedding(str(ckpt_path))
for embedding_info in embedding_list:
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
print(
f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
logger.warning(
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
)
continue
@ -105,8 +106,8 @@ class TextualInversionManager(BaseTextualInversionManager):
if ckpt_path.name == "learned_embeds.bin"
else f"<{ckpt_path.stem}>"
)
print(
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
logger.info(
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
)
trigger_str = replacement_trigger_str
@ -120,8 +121,8 @@ class TextualInversionManager(BaseTextualInversionManager):
self.trigger_to_sourcefile[trigger_str] = sourcefile
except ValueError as e:
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
print(f" | The error was {str(e)}")
logger.debug(f'Ignoring incompatible embedding {embedding_info["name"]}')
logger.debug(f"The error was {str(e)}")
def _add_textual_inversion(
self, trigger_str, embedding, defer_injecting_tokens=False
@ -133,8 +134,8 @@ class TextualInversionManager(BaseTextualInversionManager):
:return: The token id for the added embedding, either existing or newly-added.
"""
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
print(
f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
logger.warning(
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
)
return
if not self.full_precision:
@ -155,11 +156,11 @@ class TextualInversionManager(BaseTextualInversionManager):
except ValueError as e:
if str(e).startswith("Warning"):
print(f">> {str(e)}")
logger.warning(f"{str(e)}")
else:
traceback.print_exc()
print(
f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
logger.error(
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
)
raise
@ -219,16 +220,16 @@ class TextualInversionManager(BaseTextualInversionManager):
for ti in self.textual_inversions:
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
if ti.embedding_vector_length > 1:
print(
f">> Preparing tokens for textual inversion {ti.trigger_string}..."
logger.info(
f"Preparing tokens for textual inversion {ti.trigger_string}..."
)
try:
self._inject_tokens_and_assign_embeddings(ti)
except ValueError as e:
print(
f" | Ignoring incompatible embedding trigger {ti.trigger_string}"
logger.debug(
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
)
print(f" | The error was {str(e)}")
logger.debug(f"The error was {str(e)}")
continue
injected_token_ids.append(ti.trigger_token_id)
injected_token_ids.extend(ti.pad_token_ids)
@ -306,16 +307,16 @@ class TextualInversionManager(BaseTextualInversionManager):
if suffix in [".pt",".ckpt",".bin"]:
scan_result = scan_file_path(embedding_file)
if scan_result.infected_files > 0:
print(
f" ** Security Issues Found in Model: {scan_result.issues_count}"
logger.critical(
f"Security Issues Found in Model: {scan_result.issues_count}"
)
print(" ** For your safety, InvokeAI will not load this embed.")
logger.critical("For your safety, InvokeAI will not load this embed.")
return list()
ckpt = torch.load(embedding_file,map_location="cpu")
else:
ckpt = safetensors.torch.load_file(embedding_file)
except Exception as e:
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
logger.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
return list()
# try to figure out what kind of embedding file it is and parse accordingly
@ -334,7 +335,7 @@ class TextualInversionManager(BaseTextualInversionManager):
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
basename = Path(file_path).stem
print(f' | Loading v1 embedding file: {basename}')
logger.debug(f'Loading v1 embedding file: {basename}')
embeddings = list()
token_counter = -1
@ -342,7 +343,7 @@ class TextualInversionManager(BaseTextualInversionManager):
if token_counter < 0:
trigger = embedding_ckpt["name"]
elif token_counter == 0:
trigger = f'<basename>'
trigger = '<basename>'
else:
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
token_counter += 1
@ -365,7 +366,7 @@ class TextualInversionManager(BaseTextualInversionManager):
This handles embedding .pt file variant #2.
"""
basename = Path(file_path).stem
print(f' | Loading v2 embedding file: {basename}')
logger.debug(f'Loading v2 embedding file: {basename}')
embeddings = list()
if isinstance(
@ -384,7 +385,7 @@ class TextualInversionManager(BaseTextualInversionManager):
)
embeddings.append(embedding_info)
else:
print(f" ** {basename}: Unrecognized embedding format")
logger.warning(f"{basename}: Unrecognized embedding format")
return embeddings
@ -393,7 +394,7 @@ class TextualInversionManager(BaseTextualInversionManager):
Parse 'version 3' of the .pt textual inversion embedding files.
"""
basename = Path(file_path).stem
print(f' | Loading v3 embedding file: {basename}')
logger.debug(f'Loading v3 embedding file: {basename}')
embedding = embedding_ckpt['emb_params']
embedding_info = EmbeddingInfo(
name = f'<{basename}>',
@ -411,11 +412,11 @@ class TextualInversionManager(BaseTextualInversionManager):
basename = Path(filepath).stem
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
print(f' | Loading v4 embedding file: {short_path}')
logger.debug(f'Loading v4 embedding file: {short_path}')
embeddings = list()
if list(embedding_ckpt.keys()) == 0:
print(f" ** Invalid embeddings file: {short_path}")
logger.warning(f"Invalid embeddings file: {short_path}")
else:
for token,embedding in embedding_ckpt.items():
embedding_info = EmbeddingInfo(

View File

@ -0,0 +1,109 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
"""invokeai.util.logging
Logging class for InvokeAI that produces console messages that follow
the conventions established in InvokeAI 1.X through 2.X.
One way to use it:
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.getLogger(__name__)
logger.critical('this is critical')
logger.error('this is an error')
logger.warning('this is a warning')
logger.info('this is info')
logger.debug('this is debugging')
Console messages:
### this is critical
*** this is an error ***
** this is a warning
>> this is info
| this is debugging
Another way:
import invokeai.backend.util.logging as ialog
ialog.debug('this is a debugging message')
"""
import logging
# module level functions
def debug(msg, *args, **kwargs):
InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
def info(msg, *args, **kwargs):
InvokeAILogger.getLogger().info(msg, *args, **kwargs)
def warning(msg, *args, **kwargs):
InvokeAILogger.getLogger().warning(msg, *args, **kwargs)
def error(msg, *args, **kwargs):
InvokeAILogger.getLogger().error(msg, *args, **kwargs)
def critical(msg, *args, **kwargs):
InvokeAILogger.getLogger().critical(msg, *args, **kwargs)
def log(level, msg, *args, **kwargs):
InvokeAILogger.getLogger().log(level, msg, *args, **kwargs)
def disable(level=logging.CRITICAL):
logging.disable(level)
def basicConfig(**kwargs):
logging.basicConfig(**kwargs)
def getLogger(name: str=None)->logging.Logger:
return InvokeAILogger.getLogger(name)
class InvokeAILogFormatter(logging.Formatter):
'''
Repurposed from:
https://stackoverflow.com/questions/14844970/modifying-logging-message-format-based-on-message-logging-level-in-python3
'''
crit_fmt = "### %(msg)s"
err_fmt = "*** %(msg)s"
warn_fmt = "** %(msg)s"
info_fmt = ">> %(msg)s"
dbg_fmt = " | %(msg)s"
def __init__(self):
super().__init__(fmt="%(levelno)d: %(msg)s", datefmt=None, style='%')
def format(self, record):
# Remember the format used when the logging module
# was installed (in the event that this formatter is
# used with the vanilla logging module.
format_orig = self._style._fmt
if record.levelno == logging.DEBUG:
self._style._fmt = InvokeAILogFormatter.dbg_fmt
if record.levelno == logging.INFO:
self._style._fmt = InvokeAILogFormatter.info_fmt
if record.levelno == logging.WARNING:
self._style._fmt = InvokeAILogFormatter.warn_fmt
if record.levelno == logging.ERROR:
self._style._fmt = InvokeAILogFormatter.err_fmt
if record.levelno == logging.CRITICAL:
self._style._fmt = InvokeAILogFormatter.crit_fmt
# parent class does the work
result = super().format(record)
self._style._fmt = format_orig
return result
class InvokeAILogger(object):
loggers = dict()
@classmethod
def getLogger(cls, name: str = 'invokeai') -> logging.Logger:
if name not in cls.loggers:
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
fmt = InvokeAILogFormatter()
ch.setFormatter(fmt)
logger.addHandler(ch)
cls.loggers[name] = logger
return cls.loggers[name]
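A short usage sketch of the module above, assuming the InvokeAILogFormatter prefixes defined earlier; the messages are examples.

# Hedged sketch of the two usage styles described in the module docstring.
import invokeai.backend.util.logging as ialog
from invokeai.backend.util.logging import InvokeAILogger

ialog.info("loading model")               # rendered as ">> loading model"
ialog.warning("VRAM is running low")      # rendered as "** VRAM is running low"

log = InvokeAILogger.getLogger(__name__)  # per-name logger, cached in InvokeAILogger.loggers
log.debug("sampler details")              # rendered as " | sampler details"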

View File

@ -18,6 +18,7 @@ import torch
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
import invokeai.backend.util.logging as logger
from .devices import torch_dtype
@ -38,7 +39,7 @@ def log_txt_as_img(wh, xc, size=10):
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
logger.warning("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@ -80,8 +81,8 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(
f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
logger.debug(
f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
)
return total_params
@ -132,8 +133,8 @@ def parallel_data_prefetch(
raise ValueError("list expected but function got ndarray.")
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
print(
'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
logger.warning(
'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data = list(data.values())
if target_data_type == "ndarray":
@ -175,7 +176,7 @@ def parallel_data_prefetch(
processes += [p]
# start processes
print("Start prefetching...")
logger.info("Start prefetching...")
import time
start = time.time()
@ -194,7 +195,7 @@ def parallel_data_prefetch(
gather_res[res[0]] = res[1]
except Exception as e:
print("Exception: ", e)
logger.error("Exception: ", e)
for p in processes:
p.terminate()
@ -202,7 +203,7 @@ def parallel_data_prefetch(
finally:
for p in processes:
p.join()
print(f"Prefetching complete. [{time.time() - start} sec.]")
logger.info(f"Prefetching complete. [{time.time() - start} sec.]")
if target_data_type == "ndarray":
if not isinstance(gather_res[0], np.ndarray):
@ -318,23 +319,23 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
resp = requests.get(url, headers=header, stream=True) # new request with range
if exist_size > content_length:
print("* corrupt existing file found. re-downloading")
logger.warning("corrupt existing file found. re-downloading")
os.remove(dest)
exist_size = 0
if resp.status_code == 416 or exist_size == content_length:
print(f"* {dest}: complete file found. Skipping.")
logger.warning(f"{dest}: complete file found. Skipping.")
return dest
elif resp.status_code == 206 or exist_size > 0:
print(f"* {dest}: partial file found. Resuming...")
logger.warning(f"{dest}: partial file found. Resuming...")
elif resp.status_code != 200:
print(f"** An error occurred during downloading {dest}: {resp.reason}")
logger.error(f"An error occurred during downloading {dest}: {resp.reason}")
else:
print(f"* {dest}: Downloading...")
logger.error(f"{dest}: Downloading...")
try:
if content_length < 2000:
print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
return None
with open(dest, open_mode) as file, tqdm(
@ -349,7 +350,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
size = file.write(data)
bar.update(size)
except Exception as e:
print(f"An error occurred while downloading {dest}: {str(e)}")
logger.error(f"An error occurred while downloading {dest}: {str(e)}")
return None
return dest
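A hedged sketch of the Range-header resume pattern that download_with_resume implements; the status codes 206 and 416 are the ones handled above, while the helper name and chunk size are illustrative.

# Hedged sketch: resume a partial download with an HTTP Range request.
from pathlib import Path
from typing import Optional
import requests
import invokeai.backend.util.logging as logger

def resume_download(url: str, dest: Path) -> Optional[Path]:
    exist_size = dest.stat().st_size if dest.exists() else 0
    header = {"Range": f"bytes={exist_size}-"} if exist_size > 0 else {}
    resp = requests.get(url, headers=header, stream=True)
    if resp.status_code == 416:  # range not satisfiable: file already complete
        logger.warning(f"{dest}: complete file found. Skipping.")
        return dest
    if resp.status_code not in (200, 206):
        logger.error(f"An error occurred while downloading {dest}: {resp.reason}")
        return None
    open_mode = "ab" if resp.status_code == 206 else "wb"
    with open(dest, open_mode) as file:
        for chunk in resp.iter_content(chunk_size=1 << 20):
            file.write(chunk)
    return dest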

View File

@ -19,6 +19,7 @@ from PIL import Image
from PIL.Image import Image as ImageType
from werkzeug.utils import secure_filename
import invokeai.backend.util.logging as logger
import invokeai.frontend.web.dist as frontend
from .. import Generate
@ -213,7 +214,7 @@ class InvokeAIWebServer:
self.load_socketio_listeners(self.socketio)
if args.gui:
print(">> Launching Invoke AI GUI")
logger.info("Launching Invoke AI GUI")
try:
from flaskwebgui import FlaskUI
@ -231,17 +232,17 @@ class InvokeAIWebServer:
sys.exit(0)
else:
useSSL = args.certfile or args.keyfile
print(">> Started Invoke AI Web Server")
logger.info("Started Invoke AI Web Server")
if self.host == "0.0.0.0":
print(
logger.info(
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
)
else:
print(
">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
logger.info(
"Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
)
print(
f">> Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
logger.info(
f"Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
)
if not useSSL:
self.socketio.run(app=self.app, host=self.host, port=self.port)
@ -273,7 +274,7 @@ class InvokeAIWebServer:
# path for thumbnail images
self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
# txt log
self.log_path = os.path.join(self.result_path, "invoke_log.txt")
self.log_path = os.path.join(self.result_path, "invoke_logger.txt")
# make all output paths
[
os.makedirs(path, exist_ok=True)
@ -290,7 +291,7 @@ class InvokeAIWebServer:
def load_socketio_listeners(self, socketio):
@socketio.on("requestSystemConfig")
def handle_request_capabilities():
print(">> System config requested")
logger.info("System config requested")
config = self.get_system_config()
config["model_list"] = self.generate.model_manager.list_models()
config["infill_methods"] = infill_methods()
@ -330,7 +331,7 @@ class InvokeAIWebServer:
if model_name in current_model_list:
update = True
print(f">> Adding New Model: {model_name}")
logger.info(f"Adding New Model: {model_name}")
self.generate.model_manager.add_model(
model_name=model_name,
@ -348,14 +349,14 @@ class InvokeAIWebServer:
"update": update,
},
)
print(f">> New Model Added: {model_name}")
logger.info(f"New Model Added: {model_name}")
except Exception as e:
self.handle_exceptions(e)
@socketio.on("deleteModel")
def handle_delete_model(model_name: str):
try:
print(f">> Deleting Model: {model_name}")
logger.info(f"Deleting Model: {model_name}")
self.generate.model_manager.del_model(model_name)
self.generate.model_manager.commit(opt.conf)
updated_model_list = self.generate.model_manager.list_models()
@ -366,14 +367,14 @@ class InvokeAIWebServer:
"model_list": updated_model_list,
},
)
print(f">> Model Deleted: {model_name}")
logger.info(f"Model Deleted: {model_name}")
except Exception as e:
self.handle_exceptions(e)
@socketio.on("requestModelChange")
def handle_set_model(model_name: str):
try:
print(f">> Model change requested: {model_name}")
logger.info(f"Model change requested: {model_name}")
model = self.generate.set_model(model_name)
model_list = self.generate.model_manager.list_models()
if model is None:
@ -454,7 +455,7 @@ class InvokeAIWebServer:
"update": True,
},
)
print(f">> Model Converted: {model_name}")
logger.info(f"Model Converted: {model_name}")
except Exception as e:
self.handle_exceptions(e)
@ -490,7 +491,7 @@ class InvokeAIWebServer:
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
"vae", None
):
print(f">> Using configured VAE assigned to {models_to_merge[0]}")
logger.info(f"Using configured VAE assigned to {models_to_merge[0]}")
merged_model_config.update(vae=vae)
self.generate.model_manager.import_diffuser_model(
@ -507,8 +508,8 @@ class InvokeAIWebServer:
"update": True,
},
)
print(f">> Models Merged: {models_to_merge}")
print(f">> New Model Added: {model_merge_info['merged_model_name']}")
logger.info(f"Models Merged: {models_to_merge}")
logger.info(f"New Model Added: {model_merge_info['merged_model_name']}")
except Exception as e:
self.handle_exceptions(e)
@ -698,7 +699,7 @@ class InvokeAIWebServer:
}
)
except Exception as e:
print(f">> Unable to load {path}")
logger.info(f"Unable to load {path}")
socketio.emit(
"error", {"message": f"Unable to load {path}: {str(e)}"}
)
@ -735,9 +736,9 @@ class InvokeAIWebServer:
printable_parameters["init_mask"][:64] + "..."
)
print(f"\n>> Image Generation Parameters:\n\n{printable_parameters}\n")
print(f">> ESRGAN Parameters: {esrgan_parameters}")
print(f">> Facetool Parameters: {facetool_parameters}")
logger.info(f"Image Generation Parameters:\n\n{printable_parameters}\n")
logger.info(f"ESRGAN Parameters: {esrgan_parameters}")
logger.info(f"Facetool Parameters: {facetool_parameters}")
self.generate_images(
generation_parameters,
@ -750,8 +751,8 @@ class InvokeAIWebServer:
@socketio.on("runPostprocessing")
def handle_run_postprocessing(original_image, postprocessing_parameters):
try:
print(
f'>> Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
logger.info(
f'Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
)
progress = Progress()
@ -861,14 +862,14 @@ class InvokeAIWebServer:
@socketio.on("cancel")
def handle_cancel():
print(">> Cancel processing requested")
logger.info("Cancel processing requested")
self.canceled.set()
# TODO: I think this needs a safety mechanism.
@socketio.on("deleteImage")
def handle_delete_image(url, thumbnail, uuid, category):
try:
print(f'>> Delete requested "{url}"')
logger.info(f'Delete requested "{url}"')
from send2trash import send2trash
path = self.get_image_path_from_url(url)
@ -1263,7 +1264,7 @@ class InvokeAIWebServer:
image, os.path.basename(path), self.thumbnail_image_path
)
print(f'\n\n>> Image generated: "{path}"\n')
logger.info(f'Image generated: "{path}"\n')
self.write_log_message(f'[Generated] "{path}": {command}')
if progress.total_iterations > progress.current_iteration:
@ -1329,7 +1330,7 @@ class InvokeAIWebServer:
except Exception as e:
# Clear the CUDA cache on an exception
self.empty_cuda_cache()
print(e)
logger.error(e)
self.handle_exceptions(e)
def empty_cuda_cache(self):

View File

@ -16,6 +16,7 @@ if sys.platform == "darwin":
import pyparsing # type: ignore
import invokeai.version as invokeai
import invokeai.backend.util.logging as logger
from ...backend import Generate, ModelManager
from ...backend.args import Args, dream_cmd_from_png, metadata_dumps, metadata_from_png
@ -69,7 +70,7 @@ def main():
# run any post-install patches needed
run_patches()
print(f">> Internet connectivity is {Globals.internet_available}")
logger.info(f"Internet connectivity is {Globals.internet_available}")
if not args.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
@ -78,8 +79,8 @@ def main():
opt, FileNotFoundError(f"The file {config_file} could not be found.")
)
print(f">> {invokeai.__app_name__}, version {invokeai.__version__}")
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
logger.info(f"{invokeai.__app_name__}, version {invokeai.__version__}")
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
# loading here to avoid long delays on startup
# these two lines prevent a horrible warning message from appearing
@ -121,7 +122,7 @@ def main():
else:
raise FileNotFoundError(f"{opt.infile} not found.")
except (FileNotFoundError, IOError) as e:
print(f"{e}. Aborting.")
logger.critical('Aborted',exc_info=True)
sys.exit(-1)
# creating a Generate object:
@ -142,12 +143,12 @@ def main():
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(opt, e)
except (IOError, KeyError) as e:
print(f"{e}. Aborting.")
except (IOError, KeyError):
logger.critical("Aborted",exc_info=True)
sys.exit(-1)
if opt.seamless:
print(">> changed to seamless tiling mode")
logger.info("Changed to seamless tiling mode")
# preload the model
try:
@ -180,9 +181,7 @@ def main():
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
)
except Exception:
print(">> An error occurred:")
traceback.print_exc()
logger.error("An error occurred",exc_info=True)
# TODO: main_loop() has gotten busy. Needs to be refactored.
def main_loop(gen, opt):
@ -248,7 +247,7 @@ def main_loop(gen, opt):
if not opt.prompt:
oldargs = metadata_from_png(opt.init_img)
opt.prompt = oldargs.prompt
print(f'>> Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
logger.info(f'Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
except (OSError, AttributeError, KeyError):
pass
@ -265,9 +264,9 @@ def main_loop(gen, opt):
if opt.init_img is not None and re.match("^-\\d+$", opt.init_img):
try:
opt.init_img = last_results[int(opt.init_img)][0]
print(f">> Reusing previous image {opt.init_img}")
logger.info(f"Reusing previous image {opt.init_img}")
except IndexError:
print(f">> No previous initial image at position {opt.init_img} found")
logger.info(f"No previous initial image at position {opt.init_img} found")
opt.init_img = None
continue
@ -288,9 +287,9 @@ def main_loop(gen, opt):
if opt.seed is not None and opt.seed < 0 and operation != "postprocess":
try:
opt.seed = last_results[opt.seed][1]
print(f">> Reusing previous seed {opt.seed}")
logger.info(f"Reusing previous seed {opt.seed}")
except IndexError:
print(f">> No previous seed at position {opt.seed} found")
logger.info(f"No previous seed at position {opt.seed} found")
opt.seed = None
continue
@ -309,7 +308,7 @@ def main_loop(gen, opt):
subdir = subdir[: (path_max - 39 - len(os.path.abspath(opt.outdir)))]
current_outdir = os.path.join(opt.outdir, subdir)
print('Writing files to directory: "' + current_outdir + '"')
logger.info('Writing files to directory: "' + current_outdir + '"')
# make sure the output directory exists
if not os.path.exists(current_outdir):
@ -438,15 +437,14 @@ def main_loop(gen, opt):
catch_interrupts=catch_ctrl_c,
**vars(opt),
)
except (PromptParser.ParsingException, pyparsing.ParseException) as e:
print("** An error occurred while processing your prompt **")
print(f"** {str(e)} **")
except (PromptParser.ParsingException, pyparsing.ParseException):
logger.error("An error occurred while processing your prompt",exc_info=True)
elif operation == "postprocess":
print(f">> fixing {opt.prompt}")
logger.info(f"fixing {opt.prompt}")
opt.last_operation = do_postprocess(gen, opt, image_writer)
elif operation == "mask":
print(f">> generating masks from {opt.prompt}")
logger.info(f"generating masks from {opt.prompt}")
do_textmask(gen, opt, image_writer)
if opt.grid and len(grid_images) > 0:
@ -469,12 +467,12 @@ def main_loop(gen, opt):
)
results = [[path, formatted_dream_prompt]]
except AssertionError as e:
print(e)
except AssertionError as e:
logger.error(e)
continue
except OSError as e:
print(e)
logger.error(e)
continue
print("Outputs:")
@ -513,7 +511,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
gen.set_model(model_name)
add_embedding_terms(gen, completer)
except KeyError as e:
print(str(e))
logger.error(e)
except Exception as e:
report_model_error(opt, e)
completer.add_history(command)
@ -527,8 +525,8 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
elif command.startswith("!import"):
path = shlex.split(command)
if len(path) < 2:
print(
"** please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
logger.warning(
"please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
)
else:
try:
@ -541,7 +539,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
elif command.startswith(("!convert", "!optimize")):
path = shlex.split(command)
if len(path) < 2:
print("** please provide the path to a .ckpt or .safetensors model")
logger.warning("please provide the path to a .ckpt or .safetensors model")
else:
try:
convert_model(path[1], gen, opt, completer)
@ -553,7 +551,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
elif command.startswith("!edit"):
path = shlex.split(command)
if len(path) < 2:
print("** please provide the name of a model")
logger.warning("please provide the name of a model")
else:
edit_model(path[1], gen, opt, completer)
completer.add_history(command)
@ -562,7 +560,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
elif command.startswith("!del"):
path = shlex.split(command)
if len(path) < 2:
print("** please provide the name of a model")
logger.warning("please provide the name of a model")
else:
del_config(path[1], gen, opt, completer)
completer.add_history(command)
@ -642,8 +640,8 @@ def import_model(model_path: str, gen, opt, completer):
try:
default_name = url_attachment_name(model_path)
default_name = Path(default_name).stem
except Exception as e:
print(f"** URL: {str(e)}")
except Exception:
logger.warning(f"A problem occurred while assigning the name of the downloaded model",exc_info=True)
model_name, model_desc = _get_model_name_and_desc(
gen.model_manager,
completer,
@ -664,11 +662,11 @@ def import_model(model_path: str, gen, opt, completer):
model_config_file=config_file,
)
if not imported_name:
print("** Aborting import.")
logger.error("Aborting import.")
return
if not _verify_load(imported_name, gen):
print("** model failed to load. Discarding configuration entry")
logger.error("model failed to load. Discarding configuration entry")
gen.model_manager.del_model(imported_name)
return
if click.confirm("Make this the default model?", default=False):
@ -676,7 +674,7 @@ def import_model(model_path: str, gen, opt, completer):
gen.model_manager.commit(opt.conf)
completer.update_models(gen.model_manager.list_models())
print(f">> {imported_name} successfully installed")
logger.info(f"{imported_name} successfully installed")
def _pick_configuration_file(completer)->Path:
print(
@ -720,21 +718,21 @@ Please select the type of this model:
return choice
def _verify_load(model_name: str, gen) -> bool:
print(">> Verifying that new model loads...")
logger.info("Verifying that new model loads...")
current_model = gen.model_name
try:
if not gen.set_model(model_name):
return
except Exception as e:
print(f"** model failed to load: {str(e)}")
print(
logger.warning(f"model failed to load: {str(e)}")
logger.warning(
"** note that importing 2.X checkpoints is not supported. Please use !convert_model instead."
)
return False
if click.confirm("Keep model loaded?", default=True):
gen.set_model(model_name)
else:
print(">> Restoring previous model")
logger.info("Restoring previous model")
gen.set_model(current_model)
return True
@ -757,7 +755,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
ckpt_path = None
original_config_file = None
if model_name_or_path == gen.model_name:
print("** Can't convert the active model. !switch to another model first. **")
logger.warning("Can't convert the active model. !switch to another model first. **")
return
elif model_info := manager.model_info(model_name_or_path):
if "weights" in model_info:
@ -767,7 +765,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
model_description = model_info["description"]
vae_path = model_info.get("vae")
else:
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
logger.warning(f"{model_name_or_path} is not a legacy .ckpt weights file")
return
model_name = manager.convert_and_import(
ckpt_path,
@ -788,16 +786,16 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
manager.commit(opt.conf)
if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
ckpt_path.unlink(missing_ok=True)
print(f"{ckpt_path} deleted")
logger.warning(f"{ckpt_path} deleted")
def del_config(model_name: str, gen, opt, completer):
current_model = gen.model_name
if model_name == current_model:
print("** Can't delete active model. !switch to another model first. **")
logger.warning("Can't delete active model. !switch to another model first. **")
return
if model_name not in gen.model_manager.config:
print(f"** Unknown model {model_name}")
logger.warning(f"Unknown model {model_name}")
return
if not click.confirm(
@ -810,17 +808,17 @@ def del_config(model_name: str, gen, opt, completer):
)
gen.model_manager.del_model(model_name, delete_files=delete_completely)
gen.model_manager.commit(opt.conf)
print(f"** {model_name} deleted")
logger.warning(f"{model_name} deleted")
completer.update_models(gen.model_manager.list_models())
def edit_model(model_name: str, gen, opt, completer):
manager = gen.model_manager
if not (info := manager.model_info(model_name)):
print(f"** Unknown model {model_name}")
logger.warning(f"** Unknown model {model_name}")
return
print(f"\n>> Editing model {model_name} from configuration file {opt.conf}")
print()
logger.info(f"Editing model {model_name} from configuration file {opt.conf}")
new_name = _get_model_name(manager.list_models(), completer, model_name)
for attribute in info.keys():
@ -858,7 +856,7 @@ def edit_model(model_name: str, gen, opt, completer):
manager.set_default_model(new_name)
manager.commit(opt.conf)
completer.update_models(manager.list_models())
print(">> Model successfully updated")
logger.info("Model successfully updated")
def _get_model_name(existing_names, completer, default_name: str = "") -> str:
@ -869,11 +867,11 @@ def _get_model_name(existing_names, completer, default_name: str = "") -> str:
if len(model_name) == 0:
model_name = default_name
if not re.match("^[\w._+:/-]+$", model_name):
print(
'** model name must contain only words, digits and the characters "._+:/-" **'
logger.warning(
'model name must contain only words, digits and the characters "._+:/-"'
)
elif model_name != default_name and model_name in existing_names:
print(f"** the name {model_name} is already in use. Pick another.")
logger.warning(f"the name {model_name} is already in use. Pick another.")
else:
done = True
return model_name
@ -940,11 +938,10 @@ def do_postprocess(gen, opt, callback):
opt=opt,
)
except OSError:
print(traceback.format_exc(), file=sys.stderr)
print(f"** {file_path}: file could not be read")
logger.error(f"{file_path}: file could not be read",exc_info=True)
return
except (KeyError, AttributeError):
print(traceback.format_exc(), file=sys.stderr)
logger.error(f"an error occurred while applying the {tool} postprocessor",exc_info=True)
return
return opt.last_operation
@ -999,13 +996,13 @@ def prepare_image_metadata(
try:
filename = opt.fnformat.format(**wildcards)
except KeyError as e:
print(
f"** The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead"
logger.error(
f"The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead"
)
filename = f"{prefix}.{seed}.png"
except IndexError:
print(
"** The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead"
logger.error(
"The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead"
)
filename = f"{prefix}.{seed}.png"
@ -1094,14 +1091,14 @@ def split_variations(variations_string) -> list:
for part in variations_string.split(","):
seed_and_weight = part.split(":")
if len(seed_and_weight) != 2:
print(f'** Could not parse with_variation part "{part}"')
logger.warning(f'Could not parse with_variation part "{part}"')
broken = True
break
try:
seed = int(seed_and_weight[0])
weight = float(seed_and_weight[1])
except ValueError:
print(f'** Could not parse with_variation part "{part}"')
logger.warning(f'Could not parse with_variation part "{part}"')
broken = True
break
parts.append([seed, weight])
@ -1125,23 +1122,23 @@ def load_face_restoration(opt):
opt.gfpgan_model_path
)
else:
print(">> Face restoration disabled")
logger.info("Face restoration disabled")
if opt.esrgan:
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
else:
print(">> Upscaling disabled")
logger.info("Upscaling disabled")
else:
print(">> Face restoration and upscaling disabled")
logger.info("Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
return gfpgan, codeformer, esrgan
def make_step_callback(gen, opt, prefix):
destination = os.path.join(opt.outdir, "intermediates", prefix)
os.makedirs(destination, exist_ok=True)
print(f">> Intermediate images will be written into {destination}")
logger.info(f"Intermediate images will be written into {destination}")
def callback(state: PipelineIntermediateState):
latents = state.latents
@ -1183,21 +1180,20 @@ def retrieve_dream_command(opt, command, completer):
try:
cmd = dream_cmd_from_png(path)
except OSError:
print(f"## {tokens[0]}: file could not be read")
logger.error(f"{tokens[0]}: file could not be read")
except (KeyError, AttributeError, IndexError):
print(f"## {tokens[0]}: file has no metadata")
logger.error(f"{tokens[0]}: file has no metadata")
except:
print(f"## {tokens[0]}: file could not be processed")
logger.error(f"{tokens[0]}: file could not be processed")
if len(cmd) > 0:
completer.set_line(cmd)
def write_commands(opt, file_path: str, outfilepath: str):
dir, basename = os.path.split(file_path)
try:
paths = sorted(list(Path(dir).glob(basename)))
except ValueError:
print(f'## "{basename}": unacceptable pattern')
logger.error(f'"{basename}": unacceptable pattern')
return
commands = []
@ -1206,9 +1202,9 @@ def write_commands(opt, file_path: str, outfilepath: str):
try:
cmd = dream_cmd_from_png(path)
except (KeyError, AttributeError, IndexError):
print(f"## {path}: file has no metadata")
logger.error(f"{path}: file has no metadata")
except:
print(f"## {path}: file could not be processed")
logger.error(f"{path}: file could not be processed")
if cmd:
commands.append(f"# {path}")
commands.append(cmd)
@ -1218,18 +1214,18 @@ def write_commands(opt, file_path: str, outfilepath: str):
outfilepath = os.path.join(opt.outdir, basename)
with open(outfilepath, "w", encoding="utf-8") as f:
f.write("\n".join(commands))
print(f">> File {outfilepath} with commands created")
logger.info(f"File {outfilepath} with commands created")
def report_model_error(opt: Namespace, e: Exception):
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
print(
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
logger.warning(f'An error occurred while attempting to initialize the model: "{str(e)}"')
logger.warning(
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
)
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
if yes_to_all:
print(
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
logger.warning(
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
)
else:
if not click.confirm(
@ -1238,7 +1234,7 @@ def report_model_error(opt: Namespace, e: Exception):
):
return
print("invokeai-configure is launching....\n")
logger.info("invokeai-configure is launching....\n")
# Match arguments that were set on the CLI
# only the arguments accepted by the configuration script are parsed
@ -1255,7 +1251,7 @@ def report_model_error(opt: Namespace, e: Exception):
from ..install import invokeai_configure
invokeai_configure()
print("** InvokeAI will now restart")
logger.warning("InvokeAI will now restart")
sys.argv = previous_args
main() # would rather do a os.exec(), but doesn't exist?
sys.exit(0)

View File

@ -22,6 +22,7 @@ import torch
from npyscreen import widget
from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals, global_config_dir
from ...backend.config.model_install_backend import (
@ -455,8 +456,8 @@ def main():
Globals.root = os.path.expanduser(get_root(opt.root) or "")
if not global_config_dir().exists():
print(
">> Your InvokeAI root directory is not set up. Calling invokeai-configure."
logger.info(
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
)
from invokeai.frontend.install import invokeai_configure
@ -466,18 +467,18 @@ def main():
try:
select_and_download_models(opt)
except AssertionError as e:
print(str(e))
logger.error(e)
sys.exit(-1)
except KeyboardInterrupt:
print("\nGoodbye! Come back soon.")
logger.info("Goodbye! Come back soon.")
except widget.NotEnoughSpaceForWidget as e:
if str(e).startswith("Height of 1 allocated"):
print(
"** Insufficient vertical space for the interface. Please make your window taller and try again"
logger.error(
"Insufficient vertical space for the interface. Please make your window taller and try again"
)
elif str(e).startswith("addwstr"):
print(
"** Insufficient horizontal space for the interface. Please make your window wider and try again."
logger.error(
"Insufficient horizontal space for the interface. Please make your window wider and try again."
)

View File

@ -27,6 +27,8 @@ from ...backend.globals import (
global_models_dir,
global_set_root,
)
import invokeai.backend.util.logging as logger
from ...backend.model_management import ModelManager
from ...frontend.install.widgets import FloatTitleSlider
@ -113,7 +115,7 @@ def merge_diffusion_models_and_commit(
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
)
if vae := model_manager.config[models[0]].get("vae", None):
print(f">> Using configured VAE assigned to {models[0]}")
logger.info(f"Using configured VAE assigned to {models[0]}")
import_args.update(vae=vae)
model_manager.import_diffuser_model(dump_path, **import_args)
model_manager.commit(config_file)
@ -391,10 +393,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
for name in self.model_manager.model_names()
if self.model_manager.model_info(name).get("format") == "diffusers"
]
print(model_names)
return sorted(model_names)
class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self):
super().__init__()
@ -414,7 +414,7 @@ def run_gui(args: Namespace):
args = mergeapp.merge_arguments
merge_diffusion_models_and_commit(**args)
print(f'>> Models merged into new model: "{args["merged_model_name"]}".')
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
def run_cli(args: Namespace):
@ -425,8 +425,8 @@ def run_cli(args: Namespace):
if not args.merged_model_name:
args.merged_model_name = "+".join(args.models)
print(
f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
logger.info(
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
)
model_manager = ModelManager(OmegaConf.load(global_config_file()))
@ -435,7 +435,7 @@ def run_cli(args: Namespace):
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
merge_diffusion_models_and_commit(**vars(args))
print(f'>> Models merged into new model: "{args.merged_model_name}".')
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
def main():
@ -455,17 +455,16 @@ def main():
run_cli(args)
except widget.NotEnoughSpaceForWidget as e:
if str(e).startswith("Height of 1 allocated"):
print(
"** You need to have at least two diffusers models defined in models.yaml in order to merge"
logger.error(
"You need to have at least two diffusers models defined in models.yaml in order to merge"
)
else:
print(
"** Not enough room for the user interface. Try making this window larger."
logger.error(
"Not enough room for the user interface. Try making this window larger."
)
sys.exit(-1)
except Exception:
print(">> An error occurred:")
traceback.print_exc()
except Exception as e:
logger.error(e)
sys.exit(-1)
except KeyboardInterrupt:
sys.exit(-1)

View File

@ -20,6 +20,7 @@ import npyscreen
from npyscreen import widget
from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals, global_set_root
from ...backend.training import do_textual_inversion_training, parse_args
@ -368,14 +369,14 @@ def copy_to_embeddings_folder(args: dict):
dest_dir_name = args["placeholder_token"].strip("<>")
destination = Path(Globals.root, "embeddings", dest_dir_name)
os.makedirs(destination, exist_ok=True)
print(f">> Training completed. Copying learned_embeds.bin into {str(destination)}")
logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
shutil.copy(source, destination)
if (
input("Delete training logs and intermediate checkpoints? [y] ") or "y"
).startswith(("y", "Y")):
shutil.rmtree(Path(args["output_dir"]))
else:
print(f'>> Keeping {args["output_dir"]}')
logger.info(f'Keeping {args["output_dir"]}')
def save_args(args: dict):
@ -422,10 +423,10 @@ def do_front_end(args: Namespace):
do_textual_inversion_training(**args)
copy_to_embeddings_folder(args)
except Exception as e:
print("** An exception occurred during training. The exception was:")
print(str(e))
print("** DETAILS:")
print(traceback.format_exc())
logger.error("An exception occurred during training. The exception was:")
logger.error(str(e))
logger.error("DETAILS:")
logger.error(traceback.format_exc())
def main():
@ -437,21 +438,21 @@ def main():
else:
do_textual_inversion_training(**vars(args))
except AssertionError as e:
print(str(e))
logger.error(e)
sys.exit(-1)
except KeyboardInterrupt:
pass
except (widget.NotEnoughSpaceForWidget, Exception) as e:
if str(e).startswith("Height of 1 allocated"):
print(
"** You need to have at least one diffusers models defined in models.yaml in order to train"
logger.error(
"You need to have at least one diffusers models defined in models.yaml in order to train"
)
elif str(e).startswith("addwstr"):
print(
"** Not enough window space for the interface. Please make your window larger and try again."
logger.error(
"Not enough window space for the interface. Please make your window larger and try again."
)
else:
print(f"** An error has occurred: {str(e)}")
logger.error(e)
sys.exit(-1)

View File

@ -0,0 +1,40 @@
import react from '@vitejs/plugin-react-swc';
import { visualizer } from 'rollup-plugin-visualizer';
import { PluginOption, UserConfig } from 'vite';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
export const appConfig: UserConfig = {
base: './',
plugins: [
react(),
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
],
build: {
chunkSizeWarningLimit: 1500,
},
server: {
// Proxy HTTP requests to the flask server
proxy: {
// Proxy socket.io to the nodes socketio server
'/ws/socket.io': {
target: 'ws://127.0.0.1:9090',
ws: true,
},
// Proxy openapi schema definition
'/openapi.json': {
target: 'http://127.0.0.1:9090/openapi.json',
rewrite: (path) => path.replace(/^\/openapi.json/, ''),
changeOrigin: true,
},
// proxy nodes api
'/api/v1': {
target: 'http://127.0.0.1:9090/api/v1',
rewrite: (path) => path.replace(/^\/api\/v1/, ''),
changeOrigin: true,
},
},
},
};
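For context, a brief sketch (not part of this diff) of what the dev proxy above buys the UI: a relative request from the browser is forwarded by Vite to the local API server on port 9090, so no CORS setup is needed during development. The helper name and endpoint below are illustrative assumptions, not project code.

// Hypothetical dev-time request; with the '/api/v1' proxy rule above, Vite
// rewrites and forwards '/api/v1/models' to 'http://127.0.0.1:9090/api/v1/models'.
async function listModels(): Promise<unknown> {
  const res = await fetch('/api/v1/models');
  if (!res.ok) {
    throw new Error(`Request failed: ${res.status}`);
  }
  return res.json();
}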

View File

@ -0,0 +1,47 @@
import react from '@vitejs/plugin-react-swc';
import path from 'path';
import { visualizer } from 'rollup-plugin-visualizer';
import { PluginOption, UserConfig } from 'vite';
import dts from 'vite-plugin-dts';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
export const packageConfig: UserConfig = {
base: './',
plugins: [
react(),
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
dts({
insertTypesEntry: true,
}),
],
build: {
chunkSizeWarningLimit: 1500,
lib: {
entry: path.resolve(__dirname, '../src/index.ts'),
name: 'InvokeAIUI',
fileName: (format) => `invoke-ai-ui.${format}.js`,
},
rollupOptions: {
external: ['react', 'react-dom', '@emotion/react'],
output: {
globals: {
react: 'React',
'react-dom': 'ReactDOM',
},
},
},
},
resolve: {
alias: {
app: path.resolve(__dirname, '../src/app'),
assets: path.resolve(__dirname, '../src/assets'),
common: path.resolve(__dirname, '../src/common'),
features: path.resolve(__dirname, '../src/features'),
services: path.resolve(__dirname, '../src/services'),
theme: path.resolve(__dirname, '../src/theme'),
},
},
};
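A minimal sketch of how the two configs might be selected from a single Vite entry config, assuming an environment flag distinguishes the app build from the library build; the file paths and the UI_LIB_BUILD variable are assumptions for illustration, not taken from this diff.

// vite.config.ts (hypothetical): pick the app or package config per build.
import { defineConfig } from 'vite';
import { appConfig } from './config/vite.app.config';
import { packageConfig } from './config/vite.package.config';

export default defineConfig(() => {
  // UI_LIB_BUILD is an assumed env flag; any mode/env mechanism would do.
  return process.env.UI_LIB_BUILD ? packageConfig : appConfig;
});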

View File

@ -1,98 +0,0 @@
import React, { PropsWithChildren } from 'react';
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
import { InvokeTabName } from 'features/ui/store/tabMap';
export {};
declare module 'redux-socket.io-middleware';
declare global {
/* eslint-disable @typescript-eslint/no-explicit-any */
interface Array<T> {
/**
* Returns the value of the last element in the array where predicate is true, and undefined
* otherwise.
* @param predicate findLast calls predicate once for each element of the array, in descending
* order, until it finds one where predicate returns true. If such an element is found, findLast
* immediately returns that element value. Otherwise, findLast returns undefined.
* @param thisArg If provided, it will be used as the this value for each invocation of
* predicate. If it is not provided, undefined is used instead.
*/
findLast<S extends T>(
predicate: (value: T, index: number, array: T[]) => value is S,
thisArg?: any
): S | undefined;
findLast(
predicate: (value: T, index: number, array: T[]) => unknown,
thisArg?: any
): T | undefined;
/**
* Returns the index of the last element in the array where predicate is true, and -1
* otherwise.
* @param predicate findLastIndex calls predicate once for each element of the array, in descending
* order, until it finds one where predicate returns true. If such an element is found,
* findLastIndex immediately returns that element index. Otherwise, findLastIndex returns -1.
* @param thisArg If provided, it will be used as the this value for each invocation of
* predicate. If it is not provided, undefined is used instead.
*/
findLastIndex(
predicate: (value: T, index: number, array: T[]) => unknown,
thisArg?: any
): number;
}
/* eslint-enable @typescript-eslint/no-explicit-any */
}
declare module '@invoke-ai/invoke-ai-ui' {
declare class ThemeChanger extends React.Component<ThemeChangerProps> {
public constructor(props: ThemeChangerProps);
}
declare class InvokeAiLogoComponent extends React.Component<InvokeAILogoComponentProps> {
public constructor(props: InvokeAILogoComponentProps);
}
declare class IAIPopover extends React.Component<IAIPopoverProps> {
public constructor(props: IAIPopoverProps);
}
declare class IAIIconButton extends React.Component<IAIIconButtonProps> {
public constructor(props: IAIIconButtonProps);
}
declare class SettingsModal extends React.Component<SettingsModalProps> {
public constructor(props: SettingsModalProps);
}
declare class StatusIndicator extends React.Component<StatusIndicatorProps> {
public constructor(props: StatusIndicatorProps);
}
declare class ModelSelect extends React.Component<ModelSelectProps> {
public constructor(props: ModelSelectProps);
}
}
interface InvokeProps extends PropsWithChildren {
apiUrl?: string;
disabledPanels?: string[];
disabledTabs?: InvokeTabName[];
token?: string;
shouldTransformUrls?: boolean;
shouldFetchImages?: boolean;
}
declare function Invoke(props: InvokeProps): JSX.Element;
export {
ThemeChanger,
InvokeAiLogoComponent,
IAIPopover,
IAIIconButton,
SettingsModal,
StatusIndicator,
ModelSelect,
};
export = Invoke;

View File

@ -1,7 +1,23 @@
{
"name": "invoke-ai-ui",
"name": "@invoke-ai/invoke-ai-ui",
"private": true,
"version": "0.0.1",
"publishConfig": {
"access": "restricted",
"registry": "https://npm.pkg.github.com"
},
"main": "./dist/invoke-ai-ui.umd.js",
"module": "./dist/invoke-ai-ui.es.js",
"exports": {
".": {
"import": "./dist/invoke-ai-ui.es.js",
"require": "./dist/invoke-ai-ui.umd.js"
}
},
"types": "./dist/index.d.ts",
"files": [
"dist"
],
"scripts": {
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
@ -40,81 +56,96 @@
},
"dependencies": {
"@chakra-ui/anatomy": "^2.1.1",
"@chakra-ui/cli": "^2.3.0",
"@chakra-ui/icons": "^2.0.17",
"@chakra-ui/react": "^2.5.1",
"@chakra-ui/styled-system": "^2.6.1",
"@chakra-ui/icons": "^2.0.19",
"@chakra-ui/react": "^2.6.0",
"@chakra-ui/styled-system": "^2.9.0",
"@chakra-ui/theme-tools": "^2.0.16",
"@dagrejs/graphlib": "^2.1.12",
"@emotion/react": "^11.10.6",
"@emotion/styled": "^11.10.6",
"@fontsource/inter": "^4.5.15",
"@reduxjs/toolkit": "^1.9.3",
"@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5",
"chakra-ui-contextmenu": "^1.0.5",
"dateformat": "^5.0.3",
"formik": "^2.2.9",
"framer-motion": "^9.0.4",
"framer-motion": "^10.12.4",
"fuse.js": "^6.6.2",
"i18next": "^22.4.10",
"i18next": "^22.4.15",
"i18next-browser-languagedetector": "^7.0.1",
"i18next-http-backend": "^2.1.1",
"konva": "^8.4.2",
"lodash": "^4.17.21",
"patch-package": "^6.5.1",
"i18next-http-backend": "^2.2.0",
"konva": "^9.0.1",
"lodash-es": "^4.17.21",
"overlayscrollbars": "^2.1.1",
"overlayscrollbars-react": "^0.5.0",
"patch-package": "^7.0.0",
"re-resizable": "^6.9.9",
"react": "^18.2.0",
"react-colorful": "^5.6.1",
"react-dom": "^18.2.0",
"react-dropzone": "^14.2.3",
"react-hotkeys-hook": "4.3.5",
"react-i18next": "^12.1.5",
"react-hotkeys-hook": "4.4.0",
"react-i18next": "^12.2.2",
"react-icons": "^4.7.1",
"react-konva": "^18.2.4",
"react-konva-utils": "^0.3.2",
"react-konva": "^18.2.7",
"react-konva-utils": "^1.0.4",
"react-redux": "^8.0.5",
"react-rnd": "^10.4.1",
"react-transition-group": "^4.4.5",
"react-zoom-pan-pinch": "^2.6.1",
"react-use": "^17.4.0",
"react-virtuoso": "^4.3.5",
"react-zoom-pan-pinch": "^3.0.7",
"reactflow": "^11.7.0",
"redux-deep-persist": "^1.0.7",
"redux-dynamic-middlewares": "^2.2.0",
"redux-persist": "^6.0.0",
"roarr": "^7.15.0",
"serialize-error": "^11.0.0",
"socket.io-client": "^4.6.0",
"use-image": "^1.1.0",
"uuid": "^9.0.0"
},
"peerDependencies": {
"@chakra-ui/cli": "^2.4.0",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"ts-toolbelt": "^9.6.0"
},
"devDependencies": {
"@chakra-ui/cli": "^2.4.0",
"@types/dateformat": "^5.0.0",
"@types/lodash": "^4.14.194",
"@types/react": "^18.0.28",
"@types/react-dom": "^18.0.11",
"@types/lodash-es": "^4.14.194",
"@types/node": "^18.16.2",
"@types/react": "^18.2.0",
"@types/react-dom": "^18.2.1",
"@types/react-transition-group": "^4.4.5",
"@types/uuid": "^9.0.0",
"@typescript-eslint/eslint-plugin": "^5.52.0",
"@typescript-eslint/parser": "^5.52.0",
"@vitejs/plugin-react-swc": "^3.2.0",
"axios": "^1.3.4",
"@typescript-eslint/eslint-plugin": "^5.59.1",
"@typescript-eslint/parser": "^5.59.1",
"@vitejs/plugin-react-swc": "^3.3.0",
"axios": "^1.4.0",
"babel-plugin-transform-imports": "^2.0.0",
"concurrently": "^7.6.0",
"eslint": "^8.34.0",
"eslint-config-prettier": "^8.6.0",
"concurrently": "^8.0.1",
"eslint": "^8.39.0",
"eslint-config-prettier": "^8.8.0",
"eslint-plugin-prettier": "^4.2.1",
"eslint-plugin-react": "^7.32.2",
"eslint-plugin-react-hooks": "^4.6.0",
"form-data": "^4.0.0",
"husky": "^8.0.3",
"lint-staged": "^13.1.2",
"lint-staged": "^13.2.2",
"madge": "^6.0.0",
"openapi-types": "^12.1.0",
"openapi-typescript-codegen": "^0.23.0",
"openapi-typescript-codegen": "^0.24.0",
"postinstall-postinstall": "^2.1.0",
"prettier": "^2.8.4",
"prettier": "^2.8.8",
"rollup-plugin-visualizer": "^5.9.0",
"terser": "^5.16.4",
"terser": "^5.17.1",
"ts-toolbelt": "^9.6.0",
"typescript": "4.9.5",
"vite": "^4.1.2",
"vite": "^4.3.3",
"vite-plugin-dts": "^2.3.0",
"vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.0.5",
"vite-tsconfig-paths": "^4.2.0",
"yarn": "^1.22.19"
}
}
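With the scoped name, "exports" map and dist entry points above, a consumer of the published package would import it roughly as sketched below; the assumption that the package index re-exports the UI component as the default export is illustrative and not guaranteed by this diff.

// Hypothetical consumer; the "exports" map resolves this specifier to
// dist/invoke-ai-ui.es.js (ESM) or dist/invoke-ai-ui.umd.js (CJS), and
// react/react-dom must be supplied by the host app per peerDependencies.
import InvokeAIUI from '@invoke-ai/invoke-ai-ui';

console.log(typeof InvokeAIUI); // expected to be 'function' if the default export is the component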

View File

@ -527,10 +527,15 @@
"useCanvasBeta": "Use Canvas Beta Layout",
"enableImageDebugging": "Enable Image Debugging",
"useSlidersForAll": "Use Sliders For All Options",
"autoShowProgress": "Auto Show Progress Images",
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
"resetComplete": "Web UI has been reset. Refresh the page to reload."
"resetComplete": "Web UI has been reset. Refresh the page to reload.",
"consoleLogLevel": "Log Level",
"shouldLogToConsole": "Console Logging",
"developer": "Developer",
"general": "General"
},
"toast": {
"serverError": "Server Error",
@ -641,5 +646,9 @@
"betaDarkenOutside": "Darken Outside",
"betaLimitToBox": "Limit To Box",
"betaPreserveMasked": "Preserve Masked"
},
"ui": {
"showProgressImages": "Show Progress Images",
"hideProgressImages": "Hide Progress Images"
}
}

View File

@ -1,9 +1,7 @@
import ImageUploader from 'common/components/ImageUploader';
import Console from 'features/system/components/Console';
import ProgressBar from 'features/system/components/ProgressBar';
import SiteHeader from 'features/system/components/SiteHeader';
import InvokeTabs from 'features/ui/components/InvokeTabs';
import { keepGUIAlive } from './utils';
import useToastWatcher from 'features/system/hooks/useToastWatcher';
@ -13,25 +11,34 @@ import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox';
import { useAppDispatch, useAppSelector } from './storeHooks';
import { PropsWithChildren, useCallback, useEffect, useState } from 'react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
memo,
PropsWithChildren,
useCallback,
useEffect,
useState,
} from 'react';
import { motion, AnimatePresence } from 'framer-motion';
import Loading from 'common/components/Loading/Loading';
import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady';
import { PartialAppConfig } from './invokeai';
import { PartialAppConfig } from 'app/types/invokeai';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import { configChanged } from 'features/system/store/configSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useLogger } from 'app/logging/useLogger';
import ProgressImagePreview from 'features/parameters/components/ProgressImagePreview';
keepGUIAlive();
const DEFAULT_CONFIG = {};
interface Props extends PropsWithChildren {
config?: PartialAppConfig;
}
const App = ({ config = {}, children }: Props) => {
const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
useToastWatcher();
useGlobalHotkeys();
const log = useLogger();
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
@ -45,9 +52,9 @@ const App = ({ config = {}, children }: Props) => {
const dispatch = useAppDispatch();
useEffect(() => {
console.log('Received config: ', config);
log.info({ namespace: 'App', data: config }, 'Received config');
dispatch(configChanged(config));
}, [dispatch, config]);
}, [dispatch, config, log]);
useEffect(() => {
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
@ -58,7 +65,7 @@ const App = ({ config = {}, children }: Props) => {
}, []);
return (
<Grid w="100vw" h="100vh" position="relative">
<Grid w="100vw" h="100vh" position="relative" overflow="hidden">
{isLightboxEnabled && <Lightbox />}
<ImageUploader>
<ProgressBar />
@ -114,11 +121,9 @@ const App = ({ config = {}, children }: Props) => {
<Portal>
<FloatingGalleryButton />
</Portal>
<Portal>
<Console />
</Portal>
<ProgressImagePreview />
</Grid>
);
};
export default App;
export default memo(App);
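The DEFAULT_CONFIG constant plus memo above is a common pattern for keeping a default prop referentially stable. A minimal sketch of why, using assumed names unrelated to this codebase:

// With an inline default (`config = {}`), every parent render creates a new
// object, so a memoized child still re-renders and effects keyed on `config`
// re-run. Hoisting the default to module scope keeps the reference stable.
import { memo, useEffect } from 'react';

const STABLE_DEFAULT = {}; // assumed name; same idea as DEFAULT_CONFIG

const Child = memo(({ config = STABLE_DEFAULT }: { config?: object }) => {
  useEffect(() => {
    // runs only when `config` actually changes identity
  }, [config]);
  return null;
});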

View File

@ -1,8 +1,8 @@
import React, { lazy, PropsWithChildren, useEffect } from 'react';
import React, { lazy, memo, PropsWithChildren, useEffect } from 'react';
import { Provider } from 'react-redux';
import { PersistGate } from 'redux-persist/integration/react';
import { buildMiddleware, store } from './app/store';
import { persistor } from './persistor';
import { store } from 'app/store/store';
import { persistor } from '../store/persistor';
import { OpenAPI } from 'services/api';
import '@fontsource/inter/100.css';
import '@fontsource/inter/200.css';
@ -14,14 +14,15 @@ import '@fontsource/inter/700.css';
import '@fontsource/inter/800.css';
import '@fontsource/inter/900.css';
import Loading from './common/components/Loading/Loading';
import Loading from '../../common/components/Loading/Loading';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
import { PartialAppConfig } from 'app/invokeai';
import { PartialAppConfig } from 'app/types/invokeai';
import './i18n';
import '../../i18n';
import { socketMiddleware } from 'services/events/middleware';
const App = lazy(() => import('./app/App'));
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
interface Props extends PropsWithChildren {
apiUrl?: string;
@ -29,7 +30,7 @@ interface Props extends PropsWithChildren {
config?: PartialAppConfig;
}
export default function Component({ apiUrl, token, config, children }: Props) {
const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
useEffect(() => {
// configure API client token
if (token) {
@ -50,7 +51,7 @@ export default function Component({ apiUrl, token, config, children }: Props) {
// the `apiUrl`/`token` dynamically.
// rebuild socket middleware with token and apiUrl
addMiddleware(buildMiddleware());
addMiddleware(socketMiddleware());
}, [apiUrl, token]);
return (
@ -66,4 +67,6 @@ export default function Component({ apiUrl, token, config, children }: Props) {
</Provider>
</React.StrictMode>
);
}
};
export default memo(InvokeAIUI);
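A sketch of how the renamed InvokeAIUI component might be mounted by a host application, assuming React 18's createRoot, a container element id of 'root', and a placeholder apiUrl; none of these are specified by this diff.

// Hypothetical host entry point (.tsx).
import { createRoot } from 'react-dom/client';
import InvokeAIUI from '@invoke-ai/invoke-ai-ui';

const container = document.getElementById('root');
if (container) {
  createRoot(container).render(<InvokeAIUI apiUrl="http://localhost:9090" />);
}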

View File

@ -2,8 +2,8 @@ import { ChakraProvider, extendTheme } from '@chakra-ui/react';
import { ReactNode, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { theme as invokeAITheme } from 'theme/theme';
import { RootState } from './store';
import { useAppSelector } from './storeHooks';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { greenTeaThemeColors } from 'theme/colors/greenTea';
import { invokeAIThemeColors } from 'theme/colors/invokeAI';
@ -18,6 +18,8 @@ import '@fontsource/inter/600.css';
import '@fontsource/inter/700.css';
import '@fontsource/inter/800.css';
import '@fontsource/inter/900.css';
import 'overlayscrollbars/overlayscrollbars.css';
import 'theme/css/overlayscrollbars.css';
type ThemeLocaleProviderProps = {
children: ReactNode;

View File

@ -1,23 +1,6 @@
// TODO: use Enums?
import { InProgressImageType } from 'features/system/store/systemSlice';
// Valid samplers
export const SAMPLERS: Array<string> = [
'ddim',
'plms',
'k_lms',
'k_dpm_2',
'k_dpm_2_a',
'k_dpmpp_2',
'k_dpmpp_2_a',
'k_euler',
'k_euler_a',
'k_heun',
];
// Valid Diffusers Samplers
export const DIFFUSERS_SAMPLERS: Array<string> = [
export const DIFFUSERS_SCHEDULERS: Array<string> = [
'ddim',
'plms',
'k_lms',
@ -48,17 +31,8 @@ export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
export const NUMPY_RAND_MIN = 0;
export const NUMPY_RAND_MAX = 4294967295;
export const NUMPY_RAND_MAX = 2147483647;
export const FACETOOL_TYPES = ['gfpgan', 'codeformer'] as const;
export const IN_PROGRESS_IMAGE_TYPES: Array<{
key: string;
value: InProgressImageType;
}> = [
{ key: 'None', value: 'none' },
{ key: 'Fast', value: 'latents' },
{ key: 'Accurate', value: 'full-res' },
];
export const NODE_MIN_WIDTH = 250;
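The NUMPY_RAND_MAX change above lowers the ceiling from 2**32 - 1 to 2**31 - 1 (the signed 32-bit maximum). A hedged sketch of how such a bound is typically used to draw a random seed; the helper name is an assumption, not code from this diff.

// Hypothetical helper: draw an integer seed in [NUMPY_RAND_MIN, NUMPY_RAND_MAX].
const NUMPY_RAND_MIN = 0;
const NUMPY_RAND_MAX = 2147483647; // 2**31 - 1

export const generateRandomSeed = (): number =>
  Math.floor(Math.random() * (NUMPY_RAND_MAX - NUMPY_RAND_MIN + 1)) + NUMPY_RAND_MIN;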

View File

@ -0,0 +1,94 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { isEqual } from 'lodash-es';
import { useEffect } from 'react';
import { LogLevelName, ROARR, Roarr } from 'roarr';
import { createLogWriter } from '@roarr/browser-log-writer';
// Base logging context includes only the package name
const baseContext = { package: '@invoke-ai/invoke-ai-ui' };
// Create browser log writer
ROARR.write = createLogWriter();
// Module-scoped logger - can be imported and used anywhere
export let log = Roarr.child(baseContext);
// Translate human-readable log levels to numbers, used for log filtering
export const LOG_LEVEL_MAP: Record<LogLevelName, number> = {
trace: 10,
debug: 20,
info: 30,
warn: 40,
error: 50,
fatal: 60,
};
export const VALID_LOG_LEVELS = [
'trace',
'debug',
'info',
'warn',
'error',
'fatal',
] as const;
export type InvokeLogLevel = (typeof VALID_LOG_LEVELS)[number];
const selector = createSelector(
systemSelector,
(system) => {
const { app_version, consoleLogLevel, shouldLogToConsole } = system;
return {
version: app_version,
consoleLogLevel,
shouldLogToConsole,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const useLogger = () => {
const { version, consoleLogLevel, shouldLogToConsole } =
useAppSelector(selector);
// The provided Roarr browser log writer uses localStorage to configure logging to the console
useEffect(() => {
if (shouldLogToConsole) {
// Enable console log output
localStorage.setItem('ROARR_LOG', 'true');
// Use a filter to show only logs of the given level
localStorage.setItem(
'ROARR_FILTER',
`context.logLevel:>=${LOG_LEVEL_MAP[consoleLogLevel]}`
);
} else {
// Disable console log output
localStorage.setItem('ROARR_LOG', 'false');
}
ROARR.write = createLogWriter();
}, [consoleLogLevel, shouldLogToConsole]);
// Update the module-scoped logger context as needed
useEffect(() => {
const newContext: Record<string, any> = {
...baseContext,
};
if (version) {
newContext.version = version;
}
log = Roarr.child(newContext);
}, [version]);
// Use the logger within components - no different than just importing it directly
return log;
};
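A brief sketch of using the hook and the module-scoped log from the file above; the component name and messages are illustrative, and the context-first call signature follows the Roarr child logger used here.

// Hypothetical component: `useLogger` keeps console output in sync with the
// Redux-held log level, while the module-scoped `log` can be used anywhere.
import { memo, useEffect } from 'react';
import { log, useLogger } from 'app/logging/useLogger';

const ExampleWidget = memo(() => {
  const logger = useLogger();
  useEffect(() => {
    // Roarr signature: (context, message); the context is merged into the entry.
    logger.debug({ namespace: 'ExampleWidget' }, 'mounted');
  }, [logger]);
  return null;
});

export default ExampleWidget;

// Outside React, the module-scoped logger works the same way:
export const logStartup = () => log.info({ namespace: 'startup' }, 'app booting');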

View File

@ -4,7 +4,7 @@ import { initialCanvasImageSelector } from 'features/canvas/store/canvasSelector
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
export const readinessSelector = createSelector(
[

View File

@ -1,65 +1,67 @@
import { createAction } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { GalleryCategory } from 'features/gallery/store/gallerySlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
// import { createAction } from '@reduxjs/toolkit';
// import * as InvokeAI from 'app/types/invokeai';
// import { GalleryCategory } from 'features/gallery/store/gallerySlice';
// import { InvokeTabName } from 'features/ui/store/tabMap';
/**
* We can't use redux-toolkit's createSlice() to make these actions,
* because they have no associated reducer. They only exist to dispatch
* requests to the server via socketio. These actions will be handled
* by the middleware.
*/
// /**
// * We can't use redux-toolkit's createSlice() to make these actions,
// * because they have no associated reducer. They only exist to dispatch
// * requests to the server via socketio. These actions will be handled
// * by the middleware.
// */
export const generateImage = createAction<InvokeTabName>(
'socketio/generateImage'
);
export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
export const runFacetool = createAction<InvokeAI._Image>(
'socketio/runFacetool'
);
export const deleteImage = createAction<InvokeAI._Image>(
'socketio/deleteImage'
);
export const requestImages = createAction<GalleryCategory>(
'socketio/requestImages'
);
export const requestNewImages = createAction<GalleryCategory>(
'socketio/requestNewImages'
);
export const cancelProcessing = createAction<undefined>(
'socketio/cancelProcessing'
);
// export const generateImage = createAction<InvokeTabName>(
// 'socketio/generateImage'
// );
// export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
// export const runFacetool = createAction<InvokeAI._Image>(
// 'socketio/runFacetool'
// );
// export const deleteImage = createAction<InvokeAI._Image>(
// 'socketio/deleteImage'
// );
// export const requestImages = createAction<GalleryCategory>(
// 'socketio/requestImages'
// );
// export const requestNewImages = createAction<GalleryCategory>(
// 'socketio/requestNewImages'
// );
// export const cancelProcessing = createAction<undefined>(
// 'socketio/cancelProcessing'
// );
export const requestSystemConfig = createAction<undefined>(
'socketio/requestSystemConfig'
);
// export const requestSystemConfig = createAction<undefined>(
// 'socketio/requestSystemConfig'
// );
export const searchForModels = createAction<string>('socketio/searchForModels');
// export const searchForModels = createAction<string>('socketio/searchForModels');
export const addNewModel = createAction<
InvokeAI.InvokeModelConfigProps | InvokeAI.InvokeDiffusersModelConfigProps
>('socketio/addNewModel');
// export const addNewModel = createAction<
// InvokeAI.InvokeModelConfigProps | InvokeAI.InvokeDiffusersModelConfigProps
// >('socketio/addNewModel');
export const deleteModel = createAction<string>('socketio/deleteModel');
// export const deleteModel = createAction<string>('socketio/deleteModel');
export const convertToDiffusers =
createAction<InvokeAI.InvokeModelConversionProps>(
'socketio/convertToDiffusers'
);
// export const convertToDiffusers =
// createAction<InvokeAI.InvokeModelConversionProps>(
// 'socketio/convertToDiffusers'
// );
export const mergeDiffusersModels =
createAction<InvokeAI.InvokeModelMergingProps>(
'socketio/mergeDiffusersModels'
);
// export const mergeDiffusersModels =
// createAction<InvokeAI.InvokeModelMergingProps>(
// 'socketio/mergeDiffusersModels'
// );
export const requestModelChange = createAction<string>(
'socketio/requestModelChange'
);
// export const requestModelChange = createAction<string>(
// 'socketio/requestModelChange'
// );
export const saveStagingAreaImageToGallery = createAction<string>(
'socketio/saveStagingAreaImageToGallery'
);
// export const saveStagingAreaImageToGallery = createAction<string>(
// 'socketio/saveStagingAreaImageToGallery'
// );
export const emptyTempFolder = createAction<undefined>(
'socketio/requestEmptyTempFolder'
);
// export const emptyTempFolder = createAction<undefined>(
// 'socketio/requestEmptyTempFolder'
// );
export default {};

View File

@ -1,208 +1,209 @@
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import type { RootState } from 'app/store';
import {
frontendToBackendParameters,
FrontendToBackendParametersConfig,
} from 'common/util/parameterTranslation';
import dateFormat from 'dateformat';
import {
GalleryCategory,
GalleryState,
removeImage,
} from 'features/gallery/store/gallerySlice';
import {
addLogEntry,
generationRequested,
modelChangeRequested,
modelConvertRequested,
modelMergingRequested,
setIsProcessing,
} from 'features/system/store/systemSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { Socket } from 'socket.io-client';
// import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
// import * as InvokeAI from 'app/types/invokeai';
// import type { RootState } from 'app/store/store';
// import {
// frontendToBackendParameters,
// FrontendToBackendParametersConfig,
// } from 'common/util/parameterTranslation';
// import dateFormat from 'dateformat';
// import {
// GalleryCategory,
// GalleryState,
// removeImage,
// } from 'features/gallery/store/gallerySlice';
// import {
// generationRequested,
// modelChangeRequested,
// modelConvertRequested,
// modelMergingRequested,
// setIsProcessing,
// } from 'features/system/store/systemSlice';
// import { InvokeTabName } from 'features/ui/store/tabMap';
// import { Socket } from 'socket.io-client';
/**
* Returns an object containing all functions which use `socketio.emit()`.
* i.e. those which make server requests.
*/
const makeSocketIOEmitters = (
store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
socketio: Socket
) => {
// We need to dispatch actions to redux and get pieces of state from the store.
const { dispatch, getState } = store;
// /**
// * Returns an object containing all functions which use `socketio.emit()`.
// * i.e. those which make server requests.
// */
// const makeSocketIOEmitters = (
// store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
// socketio: Socket
// ) => {
// // We need to dispatch actions to redux and get pieces of state from the store.
// const { dispatch, getState } = store;
return {
emitGenerateImage: (generationMode: InvokeTabName) => {
dispatch(setIsProcessing(true));
// return {
// emitGenerateImage: (generationMode: InvokeTabName) => {
// dispatch(setIsProcessing(true));
const state: RootState = getState();
// const state: RootState = getState();
const {
generation: generationState,
postprocessing: postprocessingState,
system: systemState,
canvas: canvasState,
} = state;
// const {
// generation: generationState,
// postprocessing: postprocessingState,
// system: systemState,
// canvas: canvasState,
// } = state;
const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
{
generationMode,
generationState,
postprocessingState,
canvasState,
systemState,
};
// const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
// {
// generationMode,
// generationState,
// postprocessingState,
// canvasState,
// systemState,
// };
dispatch(generationRequested());
// dispatch(generationRequested());
const { generationParameters, esrganParameters, facetoolParameters } =
frontendToBackendParameters(frontendToBackendParametersConfig);
// const { generationParameters, esrganParameters, facetoolParameters } =
// frontendToBackendParameters(frontendToBackendParametersConfig);
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
facetoolParameters
);
// socketio.emit(
// 'generateImage',
// generationParameters,
// esrganParameters,
// facetoolParameters
// );
// we need to truncate the init_mask base64 else it takes up the whole log
// TODO: handle maintaining masks for reproducibility in future
if (generationParameters.init_mask) {
generationParameters.init_mask = generationParameters.init_mask
.substr(0, 64)
.concat('...');
}
if (generationParameters.init_img) {
generationParameters.init_img = generationParameters.init_img
.substr(0, 64)
.concat('...');
}
// // we need to truncate the init_mask base64 else it takes up the whole log
// // TODO: handle maintaining masks for reproducibility in future
// if (generationParameters.init_mask) {
// generationParameters.init_mask = generationParameters.init_mask
// .substr(0, 64)
// .concat('...');
// }
// if (generationParameters.init_img) {
// generationParameters.init_img = generationParameters.init_img
// .substr(0, 64)
// .concat('...');
// }
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...facetoolParameters,
})}`,
})
);
},
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Image generation requested: ${JSON.stringify({
// ...generationParameters,
// ...esrganParameters,
// ...facetoolParameters,
// })}`,
// })
// );
// },
// emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
// dispatch(setIsProcessing(true));
const {
postprocessing: {
upscalingLevel,
upscalingDenoising,
upscalingStrength,
},
} = getState();
// const {
// postprocessing: {
// upscalingLevel,
// upscalingDenoising,
// upscalingStrength,
// },
// } = getState();
const esrganParameters = {
upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
};
socketio.emit('runPostprocessing', imageToProcess, {
type: 'esrgan',
...esrganParameters,
});
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`,
})
);
},
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
// const esrganParameters = {
// upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
// };
// socketio.emit('runPostprocessing', imageToProcess, {
// type: 'esrgan',
// ...esrganParameters,
// });
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `ESRGAN upscale requested: ${JSON.stringify({
// file: imageToProcess.url,
// ...esrganParameters,
// })}`,
// })
// );
// },
// emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
// dispatch(setIsProcessing(true));
const {
postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
} = getState();
// const {
// postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
// } = getState();
const facetoolParameters: Record<string, unknown> = {
facetool_strength: facetoolStrength,
};
// const facetoolParameters: Record<string, unknown> = {
// facetool_strength: facetoolStrength,
// };
if (facetoolType === 'codeformer') {
facetoolParameters.codeformer_fidelity = codeformerFidelity;
}
// if (facetoolType === 'codeformer') {
// facetoolParameters.codeformer_fidelity = codeformerFidelity;
// }
socketio.emit('runPostprocessing', imageToProcess, {
type: facetoolType,
...facetoolParameters,
});
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
{
file: imageToProcess.url,
...facetoolParameters,
}
)}`,
})
);
},
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
const { url, uuid, category, thumbnail } = imageToDelete;
dispatch(removeImage(imageToDelete));
socketio.emit('deleteImage', url, thumbnail, uuid, category);
},
emitRequestImages: (category: GalleryCategory) => {
const gallery: GalleryState = getState().gallery;
const { earliest_mtime } = gallery.categories[category];
socketio.emit('requestImages', category, earliest_mtime);
},
emitRequestNewImages: (category: GalleryCategory) => {
const gallery: GalleryState = getState().gallery;
const { latest_mtime } = gallery.categories[category];
socketio.emit('requestLatestImages', category, latest_mtime);
},
emitCancelProcessing: () => {
socketio.emit('cancel');
},
emitRequestSystemConfig: () => {
socketio.emit('requestSystemConfig');
},
emitSearchForModels: (modelFolder: string) => {
socketio.emit('searchForModels', modelFolder);
},
emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
socketio.emit('addNewModel', modelConfig);
},
emitDeleteModel: (modelName: string) => {
socketio.emit('deleteModel', modelName);
},
emitConvertToDiffusers: (
modelToConvert: InvokeAI.InvokeModelConversionProps
) => {
dispatch(modelConvertRequested());
socketio.emit('convertToDiffusers', modelToConvert);
},
emitMergeDiffusersModels: (
modelMergeInfo: InvokeAI.InvokeModelMergingProps
) => {
dispatch(modelMergingRequested());
socketio.emit('mergeDiffusersModels', modelMergeInfo);
},
emitRequestModelChange: (modelName: string) => {
dispatch(modelChangeRequested());
socketio.emit('requestModelChange', modelName);
},
emitSaveStagingAreaImageToGallery: (url: string) => {
socketio.emit('requestSaveStagingAreaImageToGallery', url);
},
emitRequestEmptyTempFolder: () => {
socketio.emit('requestEmptyTempFolder');
},
};
};
// socketio.emit('runPostprocessing', imageToProcess, {
// type: facetoolType,
// ...facetoolParameters,
// });
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
// {
// file: imageToProcess.url,
// ...facetoolParameters,
// }
// )}`,
// })
// );
// },
// emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
// const { url, uuid, category, thumbnail } = imageToDelete;
// dispatch(removeImage(imageToDelete));
// socketio.emit('deleteImage', url, thumbnail, uuid, category);
// },
// emitRequestImages: (category: GalleryCategory) => {
// const gallery: GalleryState = getState().gallery;
// const { earliest_mtime } = gallery.categories[category];
// socketio.emit('requestImages', category, earliest_mtime);
// },
// emitRequestNewImages: (category: GalleryCategory) => {
// const gallery: GalleryState = getState().gallery;
// const { latest_mtime } = gallery.categories[category];
// socketio.emit('requestLatestImages', category, latest_mtime);
// },
// emitCancelProcessing: () => {
// socketio.emit('cancel');
// },
// emitRequestSystemConfig: () => {
// socketio.emit('requestSystemConfig');
// },
// emitSearchForModels: (modelFolder: string) => {
// socketio.emit('searchForModels', modelFolder);
// },
// emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
// socketio.emit('addNewModel', modelConfig);
// },
// emitDeleteModel: (modelName: string) => {
// socketio.emit('deleteModel', modelName);
// },
// emitConvertToDiffusers: (
// modelToConvert: InvokeAI.InvokeModelConversionProps
// ) => {
// dispatch(modelConvertRequested());
// socketio.emit('convertToDiffusers', modelToConvert);
// },
// emitMergeDiffusersModels: (
// modelMergeInfo: InvokeAI.InvokeModelMergingProps
// ) => {
// dispatch(modelMergingRequested());
// socketio.emit('mergeDiffusersModels', modelMergeInfo);
// },
// emitRequestModelChange: (modelName: string) => {
// dispatch(modelChangeRequested());
// socketio.emit('requestModelChange', modelName);
// },
// emitSaveStagingAreaImageToGallery: (url: string) => {
// socketio.emit('requestSaveStagingAreaImageToGallery', url);
// },
// emitRequestEmptyTempFolder: () => {
// socketio.emit('requestEmptyTempFolder');
// },
// };
// };
export default makeSocketIOEmitters;
// export default makeSocketIOEmitters;
export default {};

View File

@ -1,501 +1,502 @@
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
import dateFormat from 'dateformat';
import i18n from 'i18n';
import { v4 as uuidv4 } from 'uuid';
// import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
// import dateFormat from 'dateformat';
// import i18n from 'i18n';
// import { v4 as uuidv4 } from 'uuid';
import * as InvokeAI from 'app/invokeai';
// import * as InvokeAI from 'app/types/invokeai';
import {
addLogEntry,
addToast,
errorOccurred,
processingCanceled,
setCurrentStatus,
setFoundModels,
setIsCancelable,
setIsConnected,
setIsProcessing,
setModelList,
setSearchFolder,
setSystemConfig,
setSystemStatus,
} from 'features/system/store/systemSlice';
// import {
// addToast,
// errorOccurred,
// processingCanceled,
// setCurrentStatus,
// setFoundModels,
// setIsCancelable,
// setIsConnected,
// setIsProcessing,
// setModelList,
// setSearchFolder,
// setSystemConfig,
// setSystemStatus,
// } from 'features/system/store/systemSlice';
import {
addGalleryImages,
addImage,
clearIntermediateImage,
GalleryState,
removeImage,
setIntermediateImage,
} from 'features/gallery/store/gallerySlice';
// import {
// addGalleryImages,
// addImage,
// clearIntermediateImage,
// GalleryState,
// removeImage,
// setIntermediateImage,
// } from 'features/gallery/store/gallerySlice';
import type { RootState } from 'app/store';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import {
clearInitialImage,
initialImageSelected,
setInfillMethod,
// setInitialImage,
setMaskPath,
} from 'features/parameters/store/generationSlice';
import { tabMap } from 'features/ui/store/tabMap';
import {
requestImages,
requestNewImages,
requestSystemConfig,
} from './actions';
// import type { RootState } from 'app/store/store';
// import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
// import {
// clearInitialImage,
// initialImageSelected,
// setInfillMethod,
// // setInitialImage,
// setMaskPath,
// } from 'features/parameters/store/generationSlice';
// import { tabMap } from 'features/ui/store/tabMap';
// import {
// requestImages,
// requestNewImages,
// requestSystemConfig,
// } from './actions';
/**
* Returns an object containing listener callbacks for socketio events.
* TODO: This file is large, but simple. Should it be split up further?
*/
const makeSocketIOListeners = (
store: MiddlewareAPI<Dispatch<AnyAction>, RootState>
) => {
const { dispatch, getState } = store;
// /**
// * Returns an object containing listener callbacks for socketio events.
// * TODO: This file is large, but simple. Should it be split up further?
// */
// const makeSocketIOListeners = (
// store: MiddlewareAPI<Dispatch<AnyAction>, RootState>
// ) => {
// const { dispatch, getState } = store;
return {
/**
* Callback to run when we receive a 'connect' event.
*/
onConnect: () => {
try {
dispatch(setIsConnected(true));
dispatch(setCurrentStatus(i18n.t('common.statusConnected')));
dispatch(requestSystemConfig());
const gallery: GalleryState = getState().gallery;
// return {
// /**
// * Callback to run when we receive a 'connect' event.
// */
// onConnect: () => {
// try {
// dispatch(setIsConnected(true));
// dispatch(setCurrentStatus(i18n.t('common.statusConnected')));
// dispatch(requestSystemConfig());
// const gallery: GalleryState = getState().gallery;
if (gallery.categories.result.latest_mtime) {
dispatch(requestNewImages('result'));
} else {
dispatch(requestImages('result'));
}
// if (gallery.categories.result.latest_mtime) {
// dispatch(requestNewImages('result'));
// } else {
// dispatch(requestImages('result'));
// }
if (gallery.categories.user.latest_mtime) {
dispatch(requestNewImages('user'));
} else {
dispatch(requestImages('user'));
}
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'disconnect' event.
*/
onDisconnect: () => {
try {
dispatch(setIsConnected(false));
dispatch(setCurrentStatus(i18n.t('common.statusDisconnected')));
// if (gallery.categories.user.latest_mtime) {
// dispatch(requestNewImages('user'));
// } else {
// dispatch(requestImages('user'));
// }
// } catch (e) {
// console.error(e);
// }
// },
// /**
// * Callback to run when we receive a 'disconnect' event.
// */
// onDisconnect: () => {
// try {
// dispatch(setIsConnected(false));
// dispatch(setCurrentStatus(i18n.t('common.statusDisconnected')));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Disconnected from server`,
level: 'warning',
})
);
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'generationResult' event.
*/
onGenerationResult: (data: InvokeAI.ImageResultResponse) => {
try {
const state = getState();
const { activeTab } = state.ui;
const { shouldLoopback } = state.postprocessing;
const { boundingBox: _, generationMode, ...rest } = data;
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Disconnected from server`,
// level: 'warning',
// })
// );
// } catch (e) {
// console.error(e);
// }
// },
// /**
// * Callback to run when we receive a 'generationResult' event.
// */
// onGenerationResult: (data: InvokeAI.ImageResultResponse) => {
// try {
// const state = getState();
// const { activeTab } = state.ui;
// const { shouldLoopback } = state.postprocessing;
// const { boundingBox: _, generationMode, ...rest } = data;
const newImage = {
uuid: uuidv4(),
...rest,
};
// const newImage = {
// uuid: uuidv4(),
// ...rest,
// };
if (['txt2img', 'img2img'].includes(generationMode)) {
dispatch(
addImage({
category: 'result',
image: { ...newImage, category: 'result' },
})
);
}
// if (['txt2img', 'img2img'].includes(generationMode)) {
// dispatch(
// addImage({
// category: 'result',
// image: { ...newImage, category: 'result' },
// })
// );
// }
if (generationMode === 'unifiedCanvas' && data.boundingBox) {
const { boundingBox } = data;
dispatch(
addImageToStagingArea({
image: { ...newImage, category: 'temp' },
boundingBox,
})
);
// if (generationMode === 'unifiedCanvas' && data.boundingBox) {
// const { boundingBox } = data;
// dispatch(
// addImageToStagingArea({
// image: { ...newImage, category: 'temp' },
// boundingBox,
// })
// );
if (state.canvas.shouldAutoSave) {
dispatch(
addImage({
image: { ...newImage, category: 'result' },
category: 'result',
})
);
}
}
// if (state.canvas.shouldAutoSave) {
// dispatch(
// addImage({
// image: { ...newImage, category: 'result' },
// category: 'result',
// })
// );
// }
// }
// TODO: fix
// if (shouldLoopback) {
// const activeTabName = tabMap[activeTab];
// switch (activeTabName) {
// case 'img2img': {
// dispatch(initialImageSelected(newImage.uuid));
// // dispatch(setInitialImage(newImage));
// break;
// }
// }
// }
// // TODO: fix
// // if (shouldLoopback) {
// // const activeTabName = tabMap[activeTab];
// // switch (activeTabName) {
// // case 'img2img': {
// // dispatch(initialImageSelected(newImage.uuid));
// // // dispatch(setInitialImage(newImage));
// // break;
// // }
// // }
// // }
dispatch(clearIntermediateImage());
// dispatch(clearIntermediateImage());
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generated: ${data.url}`,
})
);
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'intermediateResult' event.
*/
onIntermediateResult: (data: InvokeAI.ImageResultResponse) => {
try {
dispatch(
setIntermediateImage({
uuid: uuidv4(),
...data,
category: 'result',
})
);
if (!data.isBase64) {
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Intermediate image generated: ${data.url}`,
})
);
}
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive an 'esrganResult' event.
*/
onPostprocessingResult: (data: InvokeAI.ImageResultResponse) => {
try {
dispatch(
addImage({
category: 'result',
image: {
uuid: uuidv4(),
...data,
category: 'result',
},
})
);
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Image generated: ${data.url}`,
// })
// );
// } catch (e) {
// console.error(e);
// }
// },
// /**
// * Callback to run when we receive an 'intermediateResult' event.
// */
// onIntermediateResult: (data: InvokeAI.ImageResultResponse) => {
// try {
// dispatch(
// setIntermediateImage({
// uuid: uuidv4(),
// ...data,
// category: 'result',
// })
// );
// if (!data.isBase64) {
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Intermediate image generated: ${data.url}`,
// })
// );
// }
// } catch (e) {
// console.error(e);
// }
// },
// /**
// * Callback to run when we receive a 'postprocessingResult' event.
// */
// onPostprocessingResult: (data: InvokeAI.ImageResultResponse) => {
// try {
// dispatch(
// addImage({
// category: 'result',
// image: {
// uuid: uuidv4(),
// ...data,
// category: 'result',
// },
// })
// );
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Postprocessed: ${data.url}`,
})
);
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'progressUpdate' event.
* TODO: Add additional progress phases
*/
onProgressUpdate: (data: InvokeAI.SystemStatus) => {
try {
dispatch(setIsProcessing(true));
dispatch(setSystemStatus(data));
} catch (e) {
console.error(e);
}
},
/**
 * Callback to run when we receive an 'error' event.
*/
onError: (data: InvokeAI.ErrorResponse) => {
const { message, additionalData } = data;
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Postprocessed: ${data.url}`,
// })
// );
// } catch (e) {
// console.error(e);
// }
// },
// /**
// * Callback to run when we receive a 'progressUpdate' event.
// * TODO: Add additional progress phases
// */
// onProgressUpdate: (data: InvokeAI.SystemStatus) => {
// try {
// dispatch(setIsProcessing(true));
// dispatch(setSystemStatus(data));
// } catch (e) {
// console.error(e);
// }
// },
// /**
// * Callback to run when we receive an 'error' event.
// */
// onError: (data: InvokeAI.ErrorResponse) => {
// const { message, additionalData } = data;
if (additionalData) {
// TODO: handle more data than short message
}
// if (additionalData) {
// // TODO: handle more data than short message
// }
try {
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Server error: ${message}`,
level: 'error',
})
);
dispatch(errorOccurred());
dispatch(clearIntermediateImage());
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'galleryImages' event.
*/
onGalleryImages: (data: InvokeAI.GalleryImagesResponse) => {
const { images, areMoreImagesAvailable, category } = data;
// try {
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Server error: ${message}`,
// level: 'error',
// })
// );
// dispatch(errorOccurred());
// dispatch(clearIntermediateImage());
// } catch (e) {
// console.error(e);
// }
// },
// /**
// * Callback to run when we receive a 'galleryImages' event.
// */
// onGalleryImages: (data: InvokeAI.GalleryImagesResponse) => {
// const { images, areMoreImagesAvailable, category } = data;
/**
 * The logic here would ideally live in the reducer, but generating a UUID is a
 * side effect, so it has to happen here, outside Redux.
*/
// /**
// * The logic here would ideally live in the reducer, but generating a UUID is a
// * side effect, so it has to happen here, outside Redux.
// */
// Generate a UUID for each image
const preparedImages = images.map((image): InvokeAI._Image => {
return {
uuid: uuidv4(),
...image,
};
});
// // Generate a UUID for each image
// const preparedImages = images.map((image): InvokeAI._Image => {
// return {
// uuid: uuidv4(),
// ...image,
// };
// });
dispatch(
addGalleryImages({
images: preparedImages,
areMoreImagesAvailable,
category,
})
);
// dispatch(
// addGalleryImages({
// images: preparedImages,
// areMoreImagesAvailable,
// category,
// })
// );
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Loaded ${images.length} images`,
})
);
},
/**
* Callback to run when we receive a 'processingCanceled' event.
*/
onProcessingCanceled: () => {
dispatch(processingCanceled());
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Loaded ${images.length} images`,
// })
// );
// },
// /**
// * Callback to run when we receive a 'processingCanceled' event.
// */
// onProcessingCanceled: () => {
// dispatch(processingCanceled());
const { intermediateImage } = getState().gallery;
// const { intermediateImage } = getState().gallery;
if (intermediateImage) {
if (!intermediateImage.isBase64) {
dispatch(
addImage({
category: 'result',
image: intermediateImage,
})
);
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Intermediate image saved: ${intermediateImage.url}`,
})
);
}
dispatch(clearIntermediateImage());
}
// if (intermediateImage) {
// if (!intermediateImage.isBase64) {
// dispatch(
// addImage({
// category: 'result',
// image: intermediateImage,
// })
// );
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Intermediate image saved: ${intermediateImage.url}`,
// })
// );
// }
// dispatch(clearIntermediateImage());
// }
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Processing canceled`,
level: 'warning',
})
);
},
/**
 * Callback to run when we receive an 'imageDeleted' event.
*/
onImageDeleted: (data: InvokeAI.ImageDeletedResponse) => {
const { url } = data;
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Processing canceled`,
// level: 'warning',
// })
// );
// },
// /**
// * Callback to run when we receive an 'imageDeleted' event.
// */
// onImageDeleted: (data: InvokeAI.ImageDeletedResponse) => {
// const { url } = data;
// remove image from gallery
dispatch(removeImage(data));
// // remove image from gallery
// dispatch(removeImage(data));
// remove references to image in options
const {
generation: { initialImage, maskPath },
} = getState();
// // remove references to image in options
// const {
// generation: { initialImage, maskPath },
// } = getState();
if (
initialImage === url ||
(initialImage as InvokeAI._Image)?.url === url
) {
dispatch(clearInitialImage());
}
// if (
// initialImage === url ||
// (initialImage as InvokeAI._Image)?.url === url
// ) {
// dispatch(clearInitialImage());
// }
if (maskPath === url) {
dispatch(setMaskPath(''));
}
// if (maskPath === url) {
// dispatch(setMaskPath(''));
// }
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image deleted: ${url}`,
})
);
},
onSystemConfig: (data: InvokeAI.SystemConfig) => {
dispatch(setSystemConfig(data));
if (!data.infill_methods.includes('patchmatch')) {
dispatch(setInfillMethod(data.infill_methods[0]));
}
},
onFoundModels: (data: InvokeAI.FoundModelResponse) => {
const { search_folder, found_models } = data;
dispatch(setSearchFolder(search_folder));
dispatch(setFoundModels(found_models));
},
onNewModelAdded: (data: InvokeAI.ModelAddedResponse) => {
const { new_model_name, model_list, update } = data;
dispatch(setModelList(model_list));
dispatch(setIsProcessing(false));
dispatch(setCurrentStatus(i18n.t('modelManager.modelAdded')));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Model Added: ${new_model_name}`,
level: 'info',
})
);
dispatch(
addToast({
title: !update
? `${i18n.t('modelManager.modelAdded')}: ${new_model_name}`
: `${i18n.t('modelManager.modelUpdated')}: ${new_model_name}`,
status: 'success',
duration: 2500,
isClosable: true,
})
);
},
onModelDeleted: (data: InvokeAI.ModelDeletedResponse) => {
const { deleted_model_name, model_list } = data;
dispatch(setModelList(model_list));
dispatch(setIsProcessing(false));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `${i18n.t(
'modelManager.modelEntryDeleted'
)}: ${deleted_model_name}`,
level: 'info',
})
);
dispatch(
addToast({
title: `${i18n.t(
'modelManager.modelEntryDeleted'
)}: ${deleted_model_name}`,
status: 'success',
duration: 2500,
isClosable: true,
})
);
},
onModelConverted: (data: InvokeAI.ModelConvertedResponse) => {
const { converted_model_name, model_list } = data;
dispatch(setModelList(model_list));
dispatch(setCurrentStatus(i18n.t('common.statusModelConverted')));
dispatch(setIsProcessing(false));
dispatch(setIsCancelable(true));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Model converted: ${converted_model_name}`,
level: 'info',
})
);
dispatch(
addToast({
title: `${i18n.t(
'modelManager.modelConverted'
)}: ${converted_model_name}`,
status: 'success',
duration: 2500,
isClosable: true,
})
);
},
onModelsMerged: (data: InvokeAI.ModelsMergedResponse) => {
const { merged_models, merged_model_name, model_list } = data;
dispatch(setModelList(model_list));
dispatch(setCurrentStatus(i18n.t('common.statusMergedModels')));
dispatch(setIsProcessing(false));
dispatch(setIsCancelable(true));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Models merged: ${merged_models}`,
level: 'info',
})
);
dispatch(
addToast({
title: `${i18n.t('modelManager.modelsMerged')}: ${merged_model_name}`,
status: 'success',
duration: 2500,
isClosable: true,
})
);
},
onModelChanged: (data: InvokeAI.ModelChangeResponse) => {
const { model_name, model_list } = data;
dispatch(setModelList(model_list));
dispatch(setCurrentStatus(i18n.t('common.statusModelChanged')));
dispatch(setIsProcessing(false));
dispatch(setIsCancelable(true));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Model changed: ${model_name}`,
level: 'info',
})
);
},
onModelChangeFailed: (data: InvokeAI.ModelChangeResponse) => {
const { model_name, model_list } = data;
dispatch(setModelList(model_list));
dispatch(setIsProcessing(false));
dispatch(setIsCancelable(true));
dispatch(errorOccurred());
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Model change failed: ${model_name}`,
level: 'error',
})
);
},
onTempFolderEmptied: () => {
dispatch(
addToast({
title: i18n.t('toast.tempFoldersEmptied'),
status: 'success',
duration: 2500,
isClosable: true,
})
);
},
};
};
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Image deleted: ${url}`,
// })
// );
// },
// onSystemConfig: (data: InvokeAI.SystemConfig) => {
// dispatch(setSystemConfig(data));
// if (!data.infill_methods.includes('patchmatch')) {
// dispatch(setInfillMethod(data.infill_methods[0]));
// }
// },
// onFoundModels: (data: InvokeAI.FoundModelResponse) => {
// const { search_folder, found_models } = data;
// dispatch(setSearchFolder(search_folder));
// dispatch(setFoundModels(found_models));
// },
// onNewModelAdded: (data: InvokeAI.ModelAddedResponse) => {
// const { new_model_name, model_list, update } = data;
// dispatch(setModelList(model_list));
// dispatch(setIsProcessing(false));
// dispatch(setCurrentStatus(i18n.t('modelManager.modelAdded')));
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Model Added: ${new_model_name}`,
// level: 'info',
// })
// );
// dispatch(
// addToast({
// title: !update
// ? `${i18n.t('modelManager.modelAdded')}: ${new_model_name}`
// : `${i18n.t('modelManager.modelUpdated')}: ${new_model_name}`,
// status: 'success',
// duration: 2500,
// isClosable: true,
// })
// );
// },
// onModelDeleted: (data: InvokeAI.ModelDeletedResponse) => {
// const { deleted_model_name, model_list } = data;
// dispatch(setModelList(model_list));
// dispatch(setIsProcessing(false));
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `${i18n.t(
// 'modelManager.modelEntryDeleted'
// )}: ${deleted_model_name}`,
// level: 'info',
// })
// );
// dispatch(
// addToast({
// title: `${i18n.t(
// 'modelManager.modelEntryDeleted'
// )}: ${deleted_model_name}`,
// status: 'success',
// duration: 2500,
// isClosable: true,
// })
// );
// },
// onModelConverted: (data: InvokeAI.ModelConvertedResponse) => {
// const { converted_model_name, model_list } = data;
// dispatch(setModelList(model_list));
// dispatch(setCurrentStatus(i18n.t('common.statusModelConverted')));
// dispatch(setIsProcessing(false));
// dispatch(setIsCancelable(true));
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Model converted: ${converted_model_name}`,
// level: 'info',
// })
// );
// dispatch(
// addToast({
// title: `${i18n.t(
// 'modelManager.modelConverted'
// )}: ${converted_model_name}`,
// status: 'success',
// duration: 2500,
// isClosable: true,
// })
// );
// },
// onModelsMerged: (data: InvokeAI.ModelsMergedResponse) => {
// const { merged_models, merged_model_name, model_list } = data;
// dispatch(setModelList(model_list));
// dispatch(setCurrentStatus(i18n.t('common.statusMergedModels')));
// dispatch(setIsProcessing(false));
// dispatch(setIsCancelable(true));
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Models merged: ${merged_models}`,
// level: 'info',
// })
// );
// dispatch(
// addToast({
// title: `${i18n.t('modelManager.modelsMerged')}: ${merged_model_name}`,
// status: 'success',
// duration: 2500,
// isClosable: true,
// })
// );
// },
// onModelChanged: (data: InvokeAI.ModelChangeResponse) => {
// const { model_name, model_list } = data;
// dispatch(setModelList(model_list));
// dispatch(setCurrentStatus(i18n.t('common.statusModelChanged')));
// dispatch(setIsProcessing(false));
// dispatch(setIsCancelable(true));
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Model changed: ${model_name}`,
// level: 'info',
// })
// );
// },
// onModelChangeFailed: (data: InvokeAI.ModelChangeResponse) => {
// const { model_name, model_list } = data;
// dispatch(setModelList(model_list));
// dispatch(setIsProcessing(false));
// dispatch(setIsCancelable(true));
// dispatch(errorOccurred());
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Model change failed: ${model_name}`,
// level: 'error',
// })
// );
// },
// onTempFolderEmptied: () => {
// dispatch(
// addToast({
// title: i18n.t('toast.tempFoldersEmptied'),
// status: 'success',
// duration: 2500,
// isClosable: true,
// })
// );
// },
// };
// };
export default makeSocketIOListeners;
// export default makeSocketIOListeners;
export default {};

View File

@ -1,246 +1,248 @@
import { Middleware } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
// import { Middleware } from '@reduxjs/toolkit';
// import { io } from 'socket.io-client';
import makeSocketIOEmitters from './emitters';
import makeSocketIOListeners from './listeners';
// import makeSocketIOEmitters from './emitters';
// import makeSocketIOListeners from './listeners';
import * as InvokeAI from 'app/invokeai';
// import * as InvokeAI from 'app/types/invokeai';
/**
* Creates a socketio middleware to handle communication with server.
*
* Special `socketio/actionName` actions are created in actions.ts and
* exported for use by the application, which treats them like any old
* action, using `dispatch` to dispatch them.
*
* These actions are intercepted here, where `socketio.emit()` calls are
* made on their behalf - see `emitters.ts`. The emitter functions
* are the outbound communication to the server.
*
* Listeners are also established here - see `listeners.ts`. The listener
* functions receive communication from the server and usually dispatch
* some new action to handle whatever data was sent from the server.
*/
export const socketioMiddleware = () => {
const { origin } = new URL(window.location.href);
// /**
// * Creates a socketio middleware to handle communication with server.
// *
// * Special `socketio/actionName` actions are created in actions.ts and
// * exported for use by the application, which treats them like any old
// * action, using `dispatch` to dispatch them.
// *
// * These actions are intercepted here, where `socketio.emit()` calls are
// * made on their behalf - see `emitters.ts`. The emitter functions
// * are the outbound communication to the server.
// *
// * Listeners are also established here - see `listeners.ts`. The listener
// * functions receive communication from the server and usually dispatch
// * some new action to handle whatever data was sent from the server.
// */
// export const socketioMiddleware = () => {
// const { origin } = new URL(window.location.href);
const socketio = io(origin, {
timeout: 60000,
path: `${window.location.pathname}socket.io`,
});
// const socketio = io(origin, {
// timeout: 60000,
// path: `${window.location.pathname}socket.io`,
// });
socketio.disconnect();
// socketio.disconnect();
let areListenersSet = false;
// let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {
const {
onConnect,
onDisconnect,
onError,
onPostprocessingResult,
onGenerationResult,
onIntermediateResult,
onProgressUpdate,
onGalleryImages,
onProcessingCanceled,
onImageDeleted,
onSystemConfig,
onModelChanged,
onFoundModels,
onNewModelAdded,
onModelDeleted,
onModelConverted,
onModelsMerged,
onModelChangeFailed,
onTempFolderEmptied,
} = makeSocketIOListeners(store);
// const middleware: Middleware = (store) => (next) => (action) => {
// const {
// onConnect,
// onDisconnect,
// onError,
// onPostprocessingResult,
// onGenerationResult,
// onIntermediateResult,
// onProgressUpdate,
// onGalleryImages,
// onProcessingCanceled,
// onImageDeleted,
// onSystemConfig,
// onModelChanged,
// onFoundModels,
// onNewModelAdded,
// onModelDeleted,
// onModelConverted,
// onModelsMerged,
// onModelChangeFailed,
// onTempFolderEmptied,
// } = makeSocketIOListeners(store);
const {
emitGenerateImage,
emitRunESRGAN,
emitRunFacetool,
emitDeleteImage,
emitRequestImages,
emitRequestNewImages,
emitCancelProcessing,
emitRequestSystemConfig,
emitSearchForModels,
emitAddNewModel,
emitDeleteModel,
emitConvertToDiffusers,
emitMergeDiffusersModels,
emitRequestModelChange,
emitSaveStagingAreaImageToGallery,
emitRequestEmptyTempFolder,
} = makeSocketIOEmitters(store, socketio);
// const {
// emitGenerateImage,
// emitRunESRGAN,
// emitRunFacetool,
// emitDeleteImage,
// emitRequestImages,
// emitRequestNewImages,
// emitCancelProcessing,
// emitRequestSystemConfig,
// emitSearchForModels,
// emitAddNewModel,
// emitDeleteModel,
// emitConvertToDiffusers,
// emitMergeDiffusersModels,
// emitRequestModelChange,
// emitSaveStagingAreaImageToGallery,
// emitRequestEmptyTempFolder,
// } = makeSocketIOEmitters(store, socketio);
/**
* If this is the first time the middleware has been called (e.g. during store setup),
* initialize all our socket.io listeners.
*/
if (!areListenersSet) {
socketio.on('connect', () => onConnect());
// /**
// * If this is the first time the middleware has been called (e.g. during store setup),
// * initialize all our socket.io listeners.
// */
// if (!areListenersSet) {
// socketio.on('connect', () => onConnect());
socketio.on('disconnect', () => onDisconnect());
// socketio.on('disconnect', () => onDisconnect());
socketio.on('error', (data: InvokeAI.ErrorResponse) => onError(data));
// socketio.on('error', (data: InvokeAI.ErrorResponse) => onError(data));
socketio.on('generationResult', (data: InvokeAI.ImageResultResponse) =>
onGenerationResult(data)
);
// socketio.on('generationResult', (data: InvokeAI.ImageResultResponse) =>
// onGenerationResult(data)
// );
socketio.on(
'postprocessingResult',
(data: InvokeAI.ImageResultResponse) => onPostprocessingResult(data)
);
// socketio.on(
// 'postprocessingResult',
// (data: InvokeAI.ImageResultResponse) => onPostprocessingResult(data)
// );
socketio.on('intermediateResult', (data: InvokeAI.ImageResultResponse) =>
onIntermediateResult(data)
);
// socketio.on('intermediateResult', (data: InvokeAI.ImageResultResponse) =>
// onIntermediateResult(data)
// );
socketio.on('progressUpdate', (data: InvokeAI.SystemStatus) =>
onProgressUpdate(data)
);
// socketio.on('progressUpdate', (data: InvokeAI.SystemStatus) =>
// onProgressUpdate(data)
// );
socketio.on('galleryImages', (data: InvokeAI.GalleryImagesResponse) =>
onGalleryImages(data)
);
// socketio.on('galleryImages', (data: InvokeAI.GalleryImagesResponse) =>
// onGalleryImages(data)
// );
socketio.on('processingCanceled', () => {
onProcessingCanceled();
});
// socketio.on('processingCanceled', () => {
// onProcessingCanceled();
// });
socketio.on('imageDeleted', (data: InvokeAI.ImageDeletedResponse) => {
onImageDeleted(data);
});
// socketio.on('imageDeleted', (data: InvokeAI.ImageDeletedResponse) => {
// onImageDeleted(data);
// });
socketio.on('systemConfig', (data: InvokeAI.SystemConfig) => {
onSystemConfig(data);
});
// socketio.on('systemConfig', (data: InvokeAI.SystemConfig) => {
// onSystemConfig(data);
// });
socketio.on('foundModels', (data: InvokeAI.FoundModelResponse) => {
onFoundModels(data);
});
// socketio.on('foundModels', (data: InvokeAI.FoundModelResponse) => {
// onFoundModels(data);
// });
socketio.on('newModelAdded', (data: InvokeAI.ModelAddedResponse) => {
onNewModelAdded(data);
});
// socketio.on('newModelAdded', (data: InvokeAI.ModelAddedResponse) => {
// onNewModelAdded(data);
// });
socketio.on('modelDeleted', (data: InvokeAI.ModelDeletedResponse) => {
onModelDeleted(data);
});
// socketio.on('modelDeleted', (data: InvokeAI.ModelDeletedResponse) => {
// onModelDeleted(data);
// });
socketio.on('modelConverted', (data: InvokeAI.ModelConvertedResponse) => {
onModelConverted(data);
});
// socketio.on('modelConverted', (data: InvokeAI.ModelConvertedResponse) => {
// onModelConverted(data);
// });
socketio.on('modelsMerged', (data: InvokeAI.ModelsMergedResponse) => {
onModelsMerged(data);
});
// socketio.on('modelsMerged', (data: InvokeAI.ModelsMergedResponse) => {
// onModelsMerged(data);
// });
socketio.on('modelChanged', (data: InvokeAI.ModelChangeResponse) => {
onModelChanged(data);
});
// socketio.on('modelChanged', (data: InvokeAI.ModelChangeResponse) => {
// onModelChanged(data);
// });
socketio.on('modelChangeFailed', (data: InvokeAI.ModelChangeResponse) => {
onModelChangeFailed(data);
});
// socketio.on('modelChangeFailed', (data: InvokeAI.ModelChangeResponse) => {
// onModelChangeFailed(data);
// });
socketio.on('tempFolderEmptied', () => {
onTempFolderEmptied();
});
// socketio.on('tempFolderEmptied', () => {
// onTempFolderEmptied();
// });
areListenersSet = true;
}
// areListenersSet = true;
// }
/**
* Handle redux actions caught by middleware.
*/
switch (action.type) {
case 'socketio/generateImage': {
emitGenerateImage(action.payload);
break;
}
// /**
// * Handle redux actions caught by middleware.
// */
// switch (action.type) {
// case 'socketio/generateImage': {
// emitGenerateImage(action.payload);
// break;
// }
case 'socketio/runESRGAN': {
emitRunESRGAN(action.payload);
break;
}
// case 'socketio/runESRGAN': {
// emitRunESRGAN(action.payload);
// break;
// }
case 'socketio/runFacetool': {
emitRunFacetool(action.payload);
break;
}
// case 'socketio/runFacetool': {
// emitRunFacetool(action.payload);
// break;
// }
case 'socketio/deleteImage': {
emitDeleteImage(action.payload);
break;
}
// case 'socketio/deleteImage': {
// emitDeleteImage(action.payload);
// break;
// }
case 'socketio/requestImages': {
emitRequestImages(action.payload);
break;
}
// case 'socketio/requestImages': {
// emitRequestImages(action.payload);
// break;
// }
case 'socketio/requestNewImages': {
emitRequestNewImages(action.payload);
break;
}
// case 'socketio/requestNewImages': {
// emitRequestNewImages(action.payload);
// break;
// }
case 'socketio/cancelProcessing': {
emitCancelProcessing();
break;
}
// case 'socketio/cancelProcessing': {
// emitCancelProcessing();
// break;
// }
case 'socketio/requestSystemConfig': {
emitRequestSystemConfig();
break;
}
// case 'socketio/requestSystemConfig': {
// emitRequestSystemConfig();
// break;
// }
case 'socketio/searchForModels': {
emitSearchForModels(action.payload);
break;
}
// case 'socketio/searchForModels': {
// emitSearchForModels(action.payload);
// break;
// }
case 'socketio/addNewModel': {
emitAddNewModel(action.payload);
break;
}
// case 'socketio/addNewModel': {
// emitAddNewModel(action.payload);
// break;
// }
case 'socketio/deleteModel': {
emitDeleteModel(action.payload);
break;
}
// case 'socketio/deleteModel': {
// emitDeleteModel(action.payload);
// break;
// }
case 'socketio/convertToDiffusers': {
emitConvertToDiffusers(action.payload);
break;
}
// case 'socketio/convertToDiffusers': {
// emitConvertToDiffusers(action.payload);
// break;
// }
case 'socketio/mergeDiffusersModels': {
emitMergeDiffusersModels(action.payload);
break;
}
// case 'socketio/mergeDiffusersModels': {
// emitMergeDiffusersModels(action.payload);
// break;
// }
case 'socketio/requestModelChange': {
emitRequestModelChange(action.payload);
break;
}
// case 'socketio/requestModelChange': {
// emitRequestModelChange(action.payload);
// break;
// }
case 'socketio/saveStagingAreaImageToGallery': {
emitSaveStagingAreaImageToGallery(action.payload);
break;
}
// case 'socketio/saveStagingAreaImageToGallery': {
// emitSaveStagingAreaImageToGallery(action.payload);
// break;
// }
case 'socketio/requestEmptyTempFolder': {
emitRequestEmptyTempFolder();
break;
}
}
// case 'socketio/requestEmptyTempFolder': {
// emitRequestEmptyTempFolder();
// break;
// }
// }
next(action);
};
// next(action);
// };
return middleware;
};
// return middleware;
// };
export default {};
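For orientation, a minimal hedged sketch of the middleware pattern described in the doc comment above: inbound socket events are translated into plain redux actions, special 'socketio/*' actions are intercepted and emitted to the server, and every action still flows through next. The event and action names below are illustrative placeholders, not the full set wired up in this file.

import { Middleware } from '@reduxjs/toolkit';
import { io, Socket } from 'socket.io-client';

// Minimal sketch only; 'system/progressUpdateReceived' is an illustrative action type.
export const makeExampleSocketioMiddleware = (): Middleware => {
  const socket: Socket = io(window.location.origin);

  return (store) => {
    // Inbound: server events become plain redux actions.
    socket.on('progressUpdate', (data: unknown) => {
      store.dispatch({ type: 'system/progressUpdateReceived', payload: data });
    });

    return (next) => (action) => {
      // Outbound: special 'socketio/*' actions are emitted to the server.
      if (action.type === 'socketio/generateImage') {
        socket.emit('generateImage', action.payload);
      }
      // All actions, including 'socketio/*' ones, still reach the reducers.
      return next(action);
    };
  };
};

In the real file the listener and emitter maps come from makeSocketIOListeners and makeSocketIOEmitters, and the listeners are registered once behind the areListenersSet guard.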

View File

@ -1,4 +1,4 @@
import { store } from 'app/store';
import { store } from 'app/store/store';
import { persistStore } from 'redux-persist';
export const persistor = persistStore(store);

View File

@ -19,8 +19,6 @@ import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import { socketioMiddleware } from './socketio/middleware';
import { socketMiddleware } from 'services/events/middleware';
import { canvasDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { galleryDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { generationDenylist } from 'features/parameters/store/generationPersistDenylist';
@ -28,8 +26,10 @@ import { lightboxDenylist } from 'features/lightbox/store/lightboxPersistDenylis
import { modelsDenylist } from 'features/system/store/modelsPersistDenylist';
import { nodesDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { postprocessingDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { systemDenylist } from 'features/system/store/systemPersistsDenylist';
import { systemDenylist } from 'features/system/store/systemPersistDenylist';
import { uiDenylist } from 'features/ui/store/uiPersistDenylist';
import { resultsDenylist } from 'features/gallery/store/resultsPersistDenylist';
import { uploadsDenylist } from 'features/gallery/store/uploadsPersistDenylist';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
@ -82,19 +82,18 @@ const rootPersistConfig = getPersistConfig({
'hotkeys',
'config',
],
debounce: 300,
});
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
// TODO: rip the old middleware out when nodes is complete
export function buildMiddleware() {
if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
return socketMiddleware();
} else {
return socketioMiddleware();
}
}
// export function buildMiddleware() {
// if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
// return socketMiddleware();
// } else {
// return socketioMiddleware();
// }
// }
export const store = configureStore({
reducer: persistedReducer,
@ -114,6 +113,7 @@ export const store = configureStore({
'canvas/setBoundingBoxDimensions',
'canvas/setIsDrawing',
'canvas/addPointToCurrentLine',
'socket/generatorProgress',
],
},
});

View File

@ -1,5 +1,5 @@
import { TypedUseSelectorHook, useDispatch, useSelector } from 'react-redux';
import { AppDispatch, RootState } from './store';
import { AppDispatch, RootState } from 'app/store/store';
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch;

View File

@ -1,5 +1,5 @@
import { createAsyncThunk } from '@reduxjs/toolkit';
import { AppDispatch, RootState } from './store';
import { AppDispatch, RootState } from 'app/store/store';
// https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk
export const createAppAsyncThunk = createAsyncThunk.withTypes<{

View File

@ -12,10 +12,11 @@
* 'gfpgan'.
*/
import { GalleryCategory } from 'features/gallery/store/gallerySlice';
import { FacetoolType } from 'features/parameters/store/postprocessingSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { ImageMetadata, ImageType } from 'services/api';
import { ImageResponseMetadata, ImageType } from 'services/api';
import { AnyInvocation } from 'services/events/types';
import { O } from 'ts-toolbelt';
@ -28,24 +29,24 @@ import { O } from 'ts-toolbelt';
* TODO: Better documentation of types.
*/
export declare type PromptItem = {
export type PromptItem = {
prompt: string;
weight: number;
};
// TECHDEBT: We need to retain compatibility with plain prompt strings and the structured Prompt type
export declare type Prompt = Array<PromptItem> | string;
export type Prompt = Array<PromptItem> | string;
export declare type SeedWeightPair = {
export type SeedWeightPair = {
seed: number;
weight: number;
};
export declare type SeedWeights = Array<SeedWeightPair>;
export type SeedWeights = Array<SeedWeightPair>;
// All generated images contain these metadata.
export declare type CommonGeneratedImageMetadata = {
postprocessing: null | Array<ESRGANMetadata | GFPGANMetadata>;
export type CommonGeneratedImageMetadata = {
postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>;
sampler:
| 'ddim'
| 'k_dpm_2_a'
@ -70,11 +71,11 @@ export declare type CommonGeneratedImageMetadata = {
};
// txt2img and img2img images have some unique attributes.
export declare type Txt2ImgMetadata = GeneratedImageMetadata & {
export type Txt2ImgMetadata = CommonGeneratedImageMetadata & {
type: 'txt2img';
};
export declare type Img2ImgMetadata = GeneratedImageMetadata & {
export type Img2ImgMetadata = CommonGeneratedImageMetadata & {
type: 'img2img';
orig_hash: string;
strength: number;
@ -84,102 +85,80 @@ export declare type Img2ImgMetadata = GeneratedImageMetadata & {
};
// Superset of generated image metadata types.
export declare type GeneratedImageMetadata = Txt2ImgMetadata | Img2ImgMetadata;
export type GeneratedImageMetadata = Txt2ImgMetadata | Img2ImgMetadata;
// All post processed images contain these metadata.
export declare type CommonPostProcessedImageMetadata = {
export type CommonPostProcessedImageMetadata = {
orig_path: string;
orig_hash: string;
};
// esrgan and gfpgan images have some unique attributes.
export declare type ESRGANMetadata = CommonPostProcessedImageMetadata & {
export type ESRGANMetadata = CommonPostProcessedImageMetadata & {
type: 'esrgan';
scale: 2 | 4;
strength: number;
denoise_str: number;
};
export declare type FacetoolMetadata = CommonPostProcessedImageMetadata & {
export type FacetoolMetadata = CommonPostProcessedImageMetadata & {
type: 'gfpgan' | 'codeformer';
strength: number;
fidelity?: number;
};
// Superset of all postprocessed image metadata types.
export declare type PostProcessedImageMetadata =
| ESRGANMetadata
| FacetoolMetadata;
export type PostProcessedImageMetadata = ESRGANMetadata | FacetoolMetadata;
// Metadata includes the system config and image metadata.
export declare type Metadata = SystemGenerationMetadata & {
image: GeneratedImageMetadata | PostProcessedImageMetadata;
};
// An Image has a UUID, url, modified timestamp, width, height and maybe metadata
export declare type _Image = {
uuid: string;
url: string;
thumbnail: string;
mtime: number;
metadata?: Metadata;
width: number;
height: number;
category: GalleryCategory;
isBase64?: boolean;
dreamPrompt?: 'string';
name?: string;
};
// export type Metadata = SystemGenerationMetadata & {
// image: GeneratedImageMetadata | PostProcessedImageMetadata;
// };
/**
* ResultImage
*/
export declare type Image = {
export type Image = {
name: string;
type: ImageType;
url: string;
thumbnail: string;
metadata: ImageMetadata;
};
// GalleryImages is an array of Image.
export declare type GalleryImages = {
images: Array<_Image>;
metadata: ImageResponseMetadata;
};
/**
* Types related to the system status.
*/
// This represents the processing status of the backend.
export declare type SystemStatus = {
isProcessing: boolean;
currentStep: number;
totalSteps: number;
currentIteration: number;
totalIterations: number;
currentStatus: string;
currentStatusHasSteps: boolean;
hasError: boolean;
};
// // This represents the processing status of the backend.
// export type SystemStatus = {
// isProcessing: boolean;
// currentStep: number;
// totalSteps: number;
// currentIteration: number;
// totalIterations: number;
// currentStatus: string;
// currentStatusHasSteps: boolean;
// hasError: boolean;
// };
export declare type SystemGenerationMetadata = {
model: string;
model_weights?: string;
model_id?: string;
model_hash: string;
app_id: string;
app_version: string;
};
// export type SystemGenerationMetadata = {
// model: string;
// model_weights?: string;
// model_id?: string;
// model_hash: string;
// app_id: string;
// app_version: string;
// };
export declare type SystemConfig = SystemGenerationMetadata & {
model_list: ModelList;
infill_methods: string[];
};
// export type SystemConfig = SystemGenerationMetadata & {
// model_list: ModelList;
// infill_methods: string[];
// };
export declare type ModelStatus = 'active' | 'cached' | 'not loaded';
export type ModelStatus = 'active' | 'cached' | 'not loaded';
export declare type Model = {
export type Model = {
status: ModelStatus;
description: string;
weights: string;
@ -191,7 +170,7 @@ export declare type Model = {
format?: string;
};
export declare type DiffusersModel = {
export type DiffusersModel = {
status: ModelStatus;
description: string;
repo_id?: string;
@ -204,14 +183,14 @@ export declare type DiffusersModel = {
default?: boolean;
};
export declare type ModelList = Record<string, Model & DiffusersModel>;
export type ModelList = Record<string, Model & DiffusersModel>;
export declare type FoundModel = {
export type FoundModel = {
name: string;
location: string;
};
export declare type InvokeModelConfigProps = {
export type InvokeModelConfigProps = {
name: string | undefined;
description: string | undefined;
config: string | undefined;
@ -223,7 +202,7 @@ export declare type InvokeModelConfigProps = {
format: string | undefined;
};
export declare type InvokeDiffusersModelConfigProps = {
export type InvokeDiffusersModelConfigProps = {
name: string | undefined;
description: string | undefined;
repo_id: string | undefined;
@ -236,13 +215,13 @@ export declare type InvokeDiffusersModelConfigProps = {
};
};
export declare type InvokeModelConversionProps = {
export type InvokeModelConversionProps = {
model_name: string;
save_location: string;
custom_location: string | null;
};
export declare type InvokeModelMergingProps = {
export type InvokeModelMergingProps = {
models_to_merge: string[];
alpha: number;
interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
@ -255,48 +234,48 @@ export declare type InvokeModelMergingProps = {
 * These types describe data received from the server via socketio.
*/
export declare type ModelChangeResponse = {
export type ModelChangeResponse = {
model_name: string;
model_list: ModelList;
};
export declare type ModelConvertedResponse = {
export type ModelConvertedResponse = {
converted_model_name: string;
model_list: ModelList;
};
export declare type ModelsMergedResponse = {
export type ModelsMergedResponse = {
merged_models: string[];
merged_model_name: string;
model_list: ModelList;
};
export declare type ModelAddedResponse = {
export type ModelAddedResponse = {
new_model_name: string;
model_list: ModelList;
update: boolean;
};
export declare type ModelDeletedResponse = {
export type ModelDeletedResponse = {
deleted_model_name: string;
model_list: ModelList;
};
export declare type FoundModelResponse = {
export type FoundModelResponse = {
search_folder: string;
found_models: FoundModel[];
};
export declare type SystemStatusResponse = SystemStatus;
// export type SystemStatusResponse = SystemStatus;
export declare type SystemConfigResponse = SystemConfig;
// export type SystemConfigResponse = SystemConfig;
export declare type ImageResultResponse = Omit<_Image, 'uuid'> & {
export type ImageResultResponse = Omit<_Image, 'uuid'> & {
boundingBox?: IRect;
generationMode: InvokeTabName;
};
export declare type ImageUploadResponse = {
export type ImageUploadResponse = {
// image: Omit<Image, 'uuid' | 'metadata' | 'category'>;
url: string;
mtime: number;
@ -306,33 +285,16 @@ export declare type ImageUploadResponse = {
// bbox: [number, number, number, number];
};
export declare type ErrorResponse = {
export type ErrorResponse = {
message: string;
additionalData?: string;
};
export declare type GalleryImagesResponse = {
images: Array<Omit<_Image, 'uuid'>>;
areMoreImagesAvailable: boolean;
category: GalleryCategory;
};
export declare type ImageDeletedResponse = {
uuid: string;
url: string;
category: GalleryCategory;
};
export declare type ImageUrlResponse = {
export type ImageUrlResponse = {
url: string;
};
export declare type UploadImagePayload = {
file: File;
destination?: ImageUploadDestination;
};
export declare type UploadOutpaintingMergeImagePayload = {
export type UploadOutpaintingMergeImagePayload = {
dataURL: string;
name: string;
};
@ -340,7 +302,7 @@ export declare type UploadOutpaintingMergeImagePayload = {
/**
* A disable-able application feature
*/
export declare type AppFeature =
export type AppFeature =
| 'faceRestore'
| 'upscaling'
| 'lightbox'
@ -353,7 +315,7 @@ export declare type AppFeature =
/**
* A disable-able Stable Diffusion feature
*/
export declare type StableDiffusionFeature =
export type StableDiffusionFeature =
| 'noiseConfig'
| 'variations'
| 'symmetry'
@ -364,7 +326,7 @@ export declare type StableDiffusionFeature =
* Configuration options for the InvokeAI UI.
* Distinct from system settings which may be changed inside the app.
*/
export declare type AppConfig = {
export type AppConfig = {
/**
* Whether or not URLs should be transformed to use a different host
*/
@ -428,4 +390,4 @@ export declare type AppConfig = {
};
};
export declare type PartialAppConfig = O.Partial<AppConfig, 'deep'>;
export type PartialAppConfig = O.Partial<AppConfig, 'deep'>;

View File

@ -1,25 +0,0 @@
export function keepGUIAlive() {
async function getRequest(url = '') {
const response = await fetch(url, {
method: 'GET',
cache: 'no-cache',
});
return response;
}
const keepAliveServer = () => {
const url = document.location;
const route = '/flaskwebgui-keep-server-alive';
getRequest(url + route).then((data) => {
return data;
});
};
if (!import.meta.env.NODE_ENV || import.meta.env.NODE_ENV === 'production') {
document.addEventListener('DOMContentLoaded', () => {
const intervalRequest = 3 * 1000;
keepAliveServer();
setInterval(keepAliveServer, intervalRequest);
});
}
}

View File

@ -8,7 +8,7 @@ import {
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { Feature, useFeatureHelpInfo } from 'app/features';
import { useAppSelector } from 'app/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { SystemState } from 'features/system/store/systemSlice';
import { memo, ReactElement } from 'react';

View File

@ -14,7 +14,7 @@ import {
Tooltip,
TooltipProps,
} from '@chakra-ui/react';
import { clamp } from 'lodash';
import { clamp } from 'lodash-es';
import { FocusEvent, memo, useEffect, useState } from 'react';

View File

@ -16,13 +16,23 @@ type IAISelectProps = SelectProps & {
validValues:
| Array<number | string>
| Array<{ key: string; value: string | number }>;
horizontal?: boolean;
spaceEvenly?: boolean;
};
/**
* Customized Chakra FormControl + Select multi-part component.
*/
const IAISelect = (props: IAISelectProps) => {
const { label, isDisabled, validValues, tooltip, tooltipProps, ...rest } =
props;
const {
label,
isDisabled,
validValues,
tooltip,
tooltipProps,
horizontal,
spaceEvenly,
...rest
} = props;
return (
<FormControl
isDisabled={isDisabled}
@ -32,10 +42,28 @@ const IAISelect = (props: IAISelectProps) => {
e.nativeEvent.stopPropagation();
e.nativeEvent.cancelBubble = true;
}}
sx={
horizontal
? {
display: 'flex',
flexDirection: 'row',
alignItems: 'center',
justifyContent: 'space-between',
gap: 4,
}
: {}
}
>
{label && <FormLabel>{label}</FormLabel>}
{label && (
<FormLabel sx={spaceEvenly ? { flexBasis: 0, flexGrow: 1 } : {}}>
{label}
</FormLabel>
)}
<Tooltip label={tooltip} {...tooltipProps}>
<Select {...rest}>
<Select
{...rest}
rootProps={{ sx: spaceEvenly ? { flexBasis: 0, flexGrow: 1 } : {} }}
>
{validValues.map((opt) => {
return typeof opt === 'string' || typeof opt === 'number' ? (
<IAIOption key={opt} value={opt}>
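A hedged usage sketch of the extended IAISelect above; the option values, state, and component name are placeholders, not code from this diff.

import { ChangeEvent, useState } from 'react';
import IAISelect from 'common/components/IAISelect';

// Hypothetical consumer exercising the new horizontal / spaceEvenly layout props.
const SamplerSelectExample = () => {
  const [sampler, setSampler] = useState('ddim');
  return (
    <IAISelect
      label="Sampler"
      validValues={['ddim', 'k_euler', 'k_dpm_2']}
      value={sampler}
      onChange={(e: ChangeEvent<HTMLSelectElement>) => setSampler(e.target.value)}
      horizontal
      spaceEvenly
    />
  );
};

export default SamplerSelectExample;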

View File

@ -23,7 +23,7 @@ import {
Tooltip,
TooltipProps,
} from '@chakra-ui/react';
import { clamp } from 'lodash';
import { clamp } from 'lodash-es';
import { useTranslation } from 'react-i18next';
import {
@ -233,7 +233,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
hidden={hideTooltip}
{...sliderTooltipProps}
>
<SliderThumb {...sliderThumbProps} />
<SliderThumb {...sliderThumbProps} zIndex={0} />
</Tooltip>
</Slider>

View File

@ -1,32 +1,11 @@
import { Badge, Box, ButtonGroup, Flex } from '@chakra-ui/react';
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { useCallback } from 'react';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaUndo, FaUpload } from 'react-icons/fa';
import { useTranslation } from 'react-i18next';
import { Image } from 'app/invokeai';
import { Badge, Box, Flex } from '@chakra-ui/react';
import { Image } from 'app/types/invokeai';
type ImageToImageOverlayProps = {
setIsLoaded: (isLoaded: boolean) => void;
image: Image;
};
const ImageToImageOverlay = ({
setIsLoaded,
image,
}: ImageToImageOverlayProps) => {
const isImageToImageEnabled = useAppSelector(
(state: RootState) => state.generation.isImageToImageEnabled
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleResetInitialImage = useCallback(() => {
dispatch(clearInitialImage());
setIsLoaded(false);
}, [dispatch, setIsLoaded]);
const ImageToImageOverlay = ({ image }: ImageToImageOverlayProps) => {
return (
<Box
sx={{

View File

@ -1,34 +1,13 @@
import {
Box,
ButtonGroup,
Collapse,
Flex,
Heading,
HStack,
Image,
Spacer,
Text,
useDisclosure,
VStack,
} from '@chakra-ui/react';
import { motion } from 'framer-motion';
import IAIButton from 'common/components/IAIButton';
import ImageFit from 'features/parameters/components/AdvancedParameters/ImageToImage/ImageFit';
import ImageToImageStrength from 'features/parameters/components/AdvancedParameters/ImageToImage/ImageToImageStrength';
import { ButtonGroup, Flex, Spacer, Text } from '@chakra-ui/react';
import IAIIconButton from 'common/components/IAIIconButton';
import { useTranslation } from 'react-i18next';
import { FaUndo, FaUpload } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
const ImageToImageSettingsHeader = () => {
const isImageToImageEnabled = useAppSelector(
(state: RootState) => state.generation.isImageToImageEnabled
);
const dispatch = useAppDispatch();
const { t } = useTranslation();

View File

@ -1,6 +1,6 @@
import { Box, useToast } from '@chakra-ui/react';
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import useImageUploader from 'common/hooks/useImageUploader';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { ResourceKey } from 'i18next';

View File

@ -1,5 +1,6 @@
import { Flex, Image, Spinner } from '@chakra-ui/react';
import InvokeAILogoImage from 'assets/images/logo.png';
import { memo } from 'react';
// This component loads before the theme so we cannot use theme tokens here
@ -29,4 +30,4 @@ const Loading = () => {
);
};
export default Loading;
export default memo(Loading);

View File

@ -1,8 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
const globalHotkeysSelector = createSelector(

View File

@ -1,4 +1,4 @@
import * as InvokeAI from 'app/invokeai';
import * as InvokeAI from 'app/types/invokeai';
import promptToString from './promptToString';
export function getPromptAndNegative(inputPrompt: InvokeAI.Prompt) {

View File

@ -1,5 +1,6 @@
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { OpenAPI } from 'services/api';
export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
@ -15,14 +16,19 @@ export const useGetUrl = () => {
(state: RootState) => state.config.shouldTransformUrls
);
return {
shouldTransformUrls,
getUrl: (url?: string) => {
const getUrl = useCallback(
(url?: string) => {
if (OpenAPI.BASE && shouldTransformUrls) {
return [OpenAPI.BASE, url].join('/');
}
return url;
},
[shouldTransformUrls]
);
return {
shouldTransformUrls,
getUrl,
};
};
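A hedged usage sketch of the memoized getUrl hook above; the component and props are placeholders, not code from this diff.

import { Image } from '@chakra-ui/react';
import { useGetUrl } from 'common/util/getUrl';

// Hypothetical consumer: URLs are prefixed with OpenAPI.BASE only when
// shouldTransformUrls is enabled, and getUrl stays referentially stable.
const ResultThumbnail = (props: { thumbnailUrl: string; fullUrl: string }) => {
  const { getUrl } = useGetUrl();
  return (
    <Image src={getUrl(props.thumbnailUrl)} fallbackSrc={getUrl(props.fullUrl)} />
  );
};

export default ResultThumbnail;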

View File

@ -1,4 +1,4 @@
import { forEach, size } from 'lodash';
import { forEach, size } from 'lodash-es';
import { ImageField, LatentsField, ConditioningField } from 'services/api';
const OBJECT_TYPESTRING = '[object Object]';

View File

@ -1,4 +1,4 @@
import * as InvokeAI from 'app/invokeai';
import * as InvokeAI from 'app/types/invokeai';
const promptToString = (prompt: InvokeAI.Prompt): string => {
if (typeof prompt === 'string') {

View File

@ -1,4 +1,4 @@
import * as InvokeAI from 'app/invokeai';
import * as InvokeAI from 'app/types/invokeai';
export const stringToSeedWeights = (
string: string

View File

@ -1,20 +0,0 @@
import Component from './component';
import InvokeAiLogoComponent from './features/system/components/InvokeAILogoComponent';
import ThemeChanger from './features/system/components/ThemeChanger';
import IAIPopover from './common/components/IAIPopover';
import IAIIconButton from './common/components/IAIIconButton';
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
import StatusIndicator from './features/system/components/StatusIndicator';
import ModelSelect from 'features/system/components/ModelSelect';
export default Component;
export {
InvokeAiLogoComponent,
ThemeChanger,
IAIPopover,
IAIIconButton,
SettingsModal,
StatusIndicator,
ModelSelect,
};

View File

@ -1,4 +1,4 @@
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import { clearCanvasHistory } from 'features/canvas/store/canvasSlice';

View File

@ -1,6 +1,6 @@
import { Box, chakra, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import {
canvasSelector,
isStagingSelector,
@ -8,7 +8,7 @@ import {
import Konva from 'konva';
import { KonvaEventObject } from 'konva/lib/Node';
import { Vector2d } from 'konva/lib/types';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { useCallback, useRef } from 'react';
import { Layer, Stage } from 'react-konva';

View File

@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { isEqual } from 'lodash';
import { useAppSelector } from 'app/store/storeHooks';
import { isEqual } from 'lodash-es';
import { Group, Rect } from 'react-konva';
import { canvasSelector } from '../store/canvasSelectors';

View File

@ -2,10 +2,10 @@
import { useToken } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { isEqual, range } from 'lodash';
import { isEqual, range } from 'lodash-es';
import { ReactNode, useCallback, useLayoutEffect, useState } from 'react';
import { Group, Line as KonvaLine } from 'react-konva';

View File

@ -1,10 +1,10 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { ImageConfig } from 'konva/lib/shapes/Image';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { useEffect, useState } from 'react';
import { Image as KonvaImage } from 'react-konva';

View File

@ -1,12 +1,12 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { RectConfig } from 'konva/lib/shapes/Rect';
import { Rect } from 'react-konva';
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import Konva from 'konva';
import { isNumber } from 'lodash';
import { isNumber } from 'lodash-es';
import { useCallback, useEffect, useRef, useState } from 'react';
export const canvasMaskCompositerSelector = createSelector(

View File

@ -1,8 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { GroupConfig } from 'konva/lib/Group';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { Group, Line } from 'react-konva';
import { isCanvasMaskLine } from '../store/canvasTypes';

View File

@ -1,9 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { Group, Line, Rect } from 'react-konva';
import {

View File

@ -1,6 +1,6 @@
import { Flex, Spinner } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
canvasSelector,
initialCanvasImageSelector,

View File

@ -1,9 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { GroupConfig } from 'konva/lib/Group';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { Group, Rect } from 'react-konva';
import IAICanvasImage from './IAICanvasImage';

View File

@ -1,7 +1,7 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { saveStagingAreaImageToGallery } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
// import { saveStagingAreaImageToGallery } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import {
@ -12,7 +12,7 @@ import {
setShouldShowStagingImage,
setShouldShowStagingOutline,
} from 'features/canvas/store/canvasSlice';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';

View File

@ -1,8 +1,8 @@
import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
import roundToHundreth from '../util/roundToHundreth';

View File

@ -1,9 +1,9 @@
import { Box } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import roundToHundreth from 'features/canvas/util/roundToHundreth';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';

View File

@ -1,9 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import { GroupConfig } from 'konva/lib/Group';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { Circle, Group } from 'react-konva';
import {

View File

@ -1,5 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
roundDownToMultiple,
roundToMultiple,
@ -16,7 +16,7 @@ import Konva from 'konva';
import { GroupConfig } from 'konva/lib/Group';
import { KonvaEventObject } from 'konva/lib/Node';
import { Vector2d } from 'konva/lib/types';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { useCallback, useEffect, useRef, useState } from 'react';
import { Group, Rect, Transformer } from 'react-konva';

View File

@ -1,6 +1,6 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAICheckbox from 'common/components/IAICheckbox';
import IAIColorPicker from 'common/components/IAIColorPicker';
@ -18,7 +18,7 @@ import {
setShouldPreserveMaskedArea,
} from 'features/canvas/store/canvasSlice';
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';

View File

@ -1,5 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
@ -9,7 +9,7 @@ import { FaRedo } from 'react-icons/fa';
import { redo } from 'features/canvas/store/canvasSlice';
import { systemSelector } from 'features/system/store/systemSelectors';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
const canvasRedoSelector = createSelector(

View File

@ -1,6 +1,6 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAICheckbox from 'common/components/IAICheckbox';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
@ -16,7 +16,7 @@ import {
setShouldSnapToGrid,
} from 'features/canvas/store/canvasSlice';
import EmptyTempFolderButtonModal from 'features/system/components/ClearTempFolderButtonModal';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { ChangeEvent } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';

View File

@ -1,6 +1,6 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIColorPicker from 'common/components/IAIColorPicker';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
@ -17,7 +17,7 @@ import {
setTool,
} from 'features/canvas/store/canvasSlice';
import { systemSelector } from 'features/system/store/systemSelectors';
import { clamp, isEqual } from 'lodash';
import { clamp, isEqual } from 'lodash-es';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';

View File

@ -1,6 +1,6 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISelect from 'common/components/IAISelect';
import useImageUploader from 'common/hooks/useImageUploader';
@ -24,7 +24,7 @@ import {
import { mergeAndUploadCanvas } from 'features/canvas/store/thunks/mergeAndUploadCanvas';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { systemSelector } from 'features/system/store/systemSelectors';
import { isEqual } from 'lodash';
import { isEqual } from 'lodash-es';
import { ChangeEvent } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';

Some files were not shown because too many files have changed in this diff.