diff --git a/.github/workflows/test-invoke-pip.yml b/.github/workflows/test-invoke-pip.yml index c5e4d10bfd..21fda2d191 100644 --- a/.github/workflows/test-invoke-pip.yml +++ b/.github/workflows/test-invoke-pip.yml @@ -80,12 +80,7 @@ jobs: uses: actions/checkout@v3 - name: set test prompt to main branch validation - if: ${{ github.ref == 'refs/heads/main' }} - run: echo "TEST_PROMPTS=tests/preflight_prompts.txt" >> ${{ matrix.github-env }} - - - name: set test prompt to Pull Request validation - if: ${{ github.ref != 'refs/heads/main' }} - run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }} + run:echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }} - name: setup python uses: actions/setup-python@v4 @@ -105,12 +100,6 @@ jobs: id: run-pytest run: pytest - - name: set INVOKEAI_OUTDIR - run: > - python -c - "import os;from invokeai.backend.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')" - >> ${{ matrix.github-env }} - - name: run invokeai-configure id: run-preload-models env: @@ -129,15 +118,20 @@ jobs: HF_HUB_OFFLINE: 1 HF_DATASETS_OFFLINE: 1 TRANSFORMERS_OFFLINE: 1 + INVOKEAI_OUTDIR: ${{ github.workspace }}/results run: > invokeai --no-patchmatch --no-nsfw_checker - --from_file ${{ env.TEST_PROMPTS }} + --precision=float32 + --always_use_cpu --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }} + --from_file ${{ env.TEST_PROMPTS }} - name: Archive results id: archive-results + env: + INVOKEAI_OUTDIR: ${{ github.workspace }}/results uses: actions/upload-artifact@v3 with: name: results diff --git a/.gitignore b/.gitignore index e9918d4fb5..7f3b1278df 100644 --- a/.gitignore +++ b/.gitignore @@ -201,6 +201,8 @@ checkpoints # If it's a Mac .DS_Store +invokeai/frontend/web/dist/* + # Let the frontend manage its own gitignore !invokeai/frontend/web/* diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 0e0d7481a7..d1c0f11b09 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -7,7 +7,6 @@ from typing import types from ..services.default_graphs import create_system_graphs from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage -from ...backend import Globals from ..services.restoration_services import RestorationServices from ..services.graph import GraphExecutionState, LibraryGraph from ..services.image_storage import DiskImageStorage @@ -42,17 +41,8 @@ class ApiDependencies: invoker: Invoker = None - @staticmethod def initialize(config, event_handler_id: int, logger: types.ModuleType=logger): - Globals.try_patchmatch = config.patchmatch - Globals.always_use_cpu = config.always_use_cpu - Globals.internet_available = config.internet_available and check_internet() - Globals.disable_xformers = not config.xformers - Globals.ckpt_convert = config.ckpt_convert - - # TO DO: Use the config to select the logger rather than use the default - # invokeai logging module - logger.info(f"Internet connectivity is {Globals.internet_available}") + logger.info(f"Internet connectivity is {config.internet_available}") events = FastAPIEventService(event_handler_id) @@ -72,7 +62,6 @@ class ApiDependencies: services = InvocationServices( model_manager=ModelManagerService(config,logger), events=events, - logger=logger, latents=latents, images=images, metadata=metadata, @@ -85,6 +74,8 @@ class ApiDependencies: ), processor=DefaultInvocationProcessor(), 
restoration=RestorationServices(config,logger), + configuration=config, + logger=logger, ) create_system_graphs(services.graph_library) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 2dc97df273..33714f1057 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -13,11 +13,11 @@ from fastapi_events.handlers.local import local_handler from fastapi_events.middleware import EventHandlerASGIMiddleware from pydantic.schema import schema -from ..backend import Args from .api.dependencies import ApiDependencies from .api.routers import images, sessions, models from .api.sockets import SocketIO from .invocations.baseinvocation import BaseInvocation +from .services.config import InvokeAIAppConfig # Create the app # TODO: create this all in a method so configuration/etc. can be passed in? @@ -33,30 +33,25 @@ app.add_middleware( middleware_id=event_handler_id, ) -# Add CORS -# TODO: use configuration for this -origins = [] -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - socket_io = SocketIO(app) -config = {} - +# initialize config +# this is a module global +app_config = InvokeAIAppConfig() # Add startup event to load dependencies @app.on_event("startup") async def startup_event(): - config = Args() - config.parse_args() + app.add_middleware( + CORSMiddleware, + allow_origins=app_config.allow_origins, + allow_credentials=app_config.allow_credentials, + allow_methods=app_config.allow_methods, + allow_headers=app_config.allow_headers, + ) ApiDependencies.initialize( - config=config, event_handler_id=event_handler_id, logger=logger + config=app_config, event_handler_id=event_handler_id, logger=logger ) @@ -148,14 +143,11 @@ app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), n def invoke_api(): # Start our own event loop for eventing usage - # TODO: determine if there's a better way to do this loop = asyncio.new_event_loop() - config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop) + config = uvicorn.Config(app=app, host=app_config.host, port=app_config.port, loop=loop) # Use access_log to turn off logging - server = uvicorn.Server(config) loop.run_until_complete(server.serve()) - if __name__ == "__main__": invoke_api() diff --git a/invokeai/app/cli/commands.py b/invokeai/app/cli/commands.py index 01cd99bc35..10d1ead677 100644 --- a/invokeai/app/cli/commands.py +++ b/invokeai/app/cli/commands.py @@ -285,3 +285,19 @@ class DrawExecutionGraphCommand(BaseCommand): nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif") plt.axis("off") plt.show() + +class SortedHelpFormatter(argparse.HelpFormatter): + def _iter_indented_subactions(self, action): + try: + get_subactions = action._get_subactions + except AttributeError: + pass + else: + self._indent() + if isinstance(action, argparse._SubParsersAction): + for subaction in sorted(get_subactions(), key=lambda x: x.dest): + yield subaction + else: + for subaction in get_subactions(): + yield subaction + self._dedent() diff --git a/invokeai/app/cli/completer.py b/invokeai/app/cli/completer.py index c84c430bd7..79274dab8c 100644 --- a/invokeai/app/cli/completer.py +++ b/invokeai/app/cli/completer.py @@ -11,9 +11,10 @@ from pathlib import Path from typing import List, Dict, Literal, get_args, get_type_hints, get_origin import invokeai.backend.util.logging as logger -from ...backend import ModelManager, Globals +from ...backend import ModelManager from 
..invocations.baseinvocation import BaseInvocation from .commands import BaseCommand +from ..services.invocation_services import InvocationServices # singleton object, class variable completer = None @@ -131,13 +132,13 @@ class Completer(object): readline.redisplay() self.linebuffer = None -def set_autocompleter(model_manager: ModelManager) -> Completer: +def set_autocompleter(services: InvocationServices) -> Completer: global completer if completer: return completer - completer = Completer(model_manager) + completer = Completer(services.model_manager) readline.set_completer(completer.complete) # pyreadline3 does not have a set_auto_history() method @@ -153,7 +154,7 @@ def set_autocompleter(model_manager: ModelManager) -> Completer: readline.parse_and_bind("set skip-completed-text on") readline.parse_and_bind("set show-all-if-ambiguous on") - histfile = Path(Globals.root, ".invoke_history") + histfile = Path(services.configuration.root_dir / ".invoke_history") try: readline.read_history_file(histfile) readline.set_history_length(1000) diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index 89bef69e11..b63936b525 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -4,13 +4,14 @@ import argparse import os import re import shlex +import sys import time from typing import ( Union, get_type_hints, ) -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from pydantic.fields import Field import invokeai.backend.util.logging as logger @@ -20,10 +21,7 @@ from invokeai.app.services.metadata import PngMetadataService from .services.default_graphs import create_system_graphs from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage -from ..backend import Args -from ..backend import Globals # this should go when pr 3340 merged - -from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers +from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, SortedHelpFormatter from .cli.completer import set_autocompleter from .invocations.baseinvocation import BaseInvocation from .services.events import EventServiceBase @@ -37,6 +35,7 @@ from .services.invoker import Invoker from .services.processor import DefaultInvocationProcessor from .services.sqlite import SqliteItemStorage from .services.model_manager_service import ModelManagerService +from .services.config import get_invokeai_config class CliCommand(BaseModel): command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore @@ -66,7 +65,7 @@ def add_invocation_args(command_parser): def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser: # Create invocation parser - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter) def exit(*args, **kwargs): raise InvalidArgs @@ -191,28 +190,26 @@ def invoke_all(context: CliContext): def invoke_cli(): - config = Args() - config.parse_args() + # this gets the basic configuration + config = get_invokeai_config() - logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}") - logger.info(f'InvokeAI runtime directory is "{Globals.root}"') + # get the optional list of invocations to execute on the command line + parser = config.get_parser() + parser.add_argument('commands',nargs='*') + invocation_commands = parser.parse_args().commands + # get the optional file to read commands from. 
+ # Simplest is to use it for STDIN + if infile := config.from_file: + sys.stdin = open(infile,"r") + model_manager = ModelManagerService(config,logger) - - # This initializes the autocompleter and returns it. - # Currently nothing is done with the returned Completer - # object, but the object can be used to change autocompletion - # behavior on the fly, if desired. set_autocompleter(model_manager) events = EventServiceBase() - + output_folder = config.output_path metadata = PngMetadataService() - output_folder = os.path.abspath( - os.path.join(os.path.dirname(__file__), "../../../outputs") - ) - # TODO: build a file/path manager? db_location = os.path.join(output_folder, "invokeai.db") @@ -232,6 +229,7 @@ def invoke_cli(): processor=DefaultInvocationProcessor(), restoration=RestorationServices(config,logger=logger), logger=logger, + configuration=config, ) system_graphs = create_system_graphs(services.graph_library) @@ -247,10 +245,18 @@ def invoke_cli(): # print(services.session_manager.list()) context = CliContext(invoker, session, parser) + set_autocompleter(services) - while True: + command_line_args_exist = len(invocation_commands) > 0 + done = False + + while not done: try: - cmd_input = input("invoke> ") + if command_line_args_exist: + cmd_input = invocation_commands.pop(0) + done = len(invocation_commands) == 0 + else: + cmd_input = input("invoke> ") except (KeyboardInterrupt, EOFError): # Ctrl-c exits break @@ -374,6 +380,9 @@ def invoke_cli(): invoker.services.logger.warning('Invalid command, use "help" to list commands') continue + except ValidationError: + invoker.services.logger.warning('Invalid command arguments, run " --help" for summary') + except SessionError: # Start a new session invoker.services.logger.warning("Session error: creating a new session") diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 633b53accd..dff6c58e88 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -5,10 +5,8 @@ from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationCont from .model import ClipField -from ...backend.util.devices import choose_torch_device, torch_dtype +from ...backend.util.devices import torch_dtype from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent -from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager -from ...backend.model_management import SDModelType from compel import Compel from compel.prompt_parser import ( @@ -18,8 +16,6 @@ from compel.prompt_parser import ( Fragment, ) -from invokeai.backend.globals import Globals - class ConditioningField(BaseModel): conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data") @@ -91,7 +87,7 @@ class CompelInvocation(BaseInvocation): prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(self.prompt) - if getattr(Globals, "log_tokenization", False): + if context.services.configuration.log_tokenization: log_tokenization_for_prompt_object(prompt, tokenizer) c, options = compel.build_conditioning_tensor_for_prompt_object(prompt) diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 98f87d2dd4..2ce58c016b 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -5,7 +5,12 @@ from typing import Literal from pydantic import BaseModel, Field import numpy as np -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig +from 
.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InvocationContext, + InvocationConfig, +) class MathInvocationConfig(BaseModel): @@ -22,19 +27,21 @@ class MathInvocationConfig(BaseModel): class IntOutput(BaseInvocationOutput): """An integer output""" - #fmt: off + + # fmt: off type: Literal["int_output"] = "int_output" a: int = Field(default=None, description="The output integer") - #fmt: on + # fmt: on class AddInvocation(BaseInvocation, MathInvocationConfig): """Adds two numbers""" - #fmt: off + + # fmt: off type: Literal["add"] = "add" a: int = Field(default=0, description="The first number") b: int = Field(default=0, description="The second number") - #fmt: on + # fmt: on def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=self.a + self.b) @@ -42,11 +49,12 @@ class AddInvocation(BaseInvocation, MathInvocationConfig): class SubtractInvocation(BaseInvocation, MathInvocationConfig): """Subtracts two numbers""" - #fmt: off + + # fmt: off type: Literal["sub"] = "sub" a: int = Field(default=0, description="The first number") b: int = Field(default=0, description="The second number") - #fmt: on + # fmt: on def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=self.a - self.b) @@ -54,11 +62,12 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig): class MultiplyInvocation(BaseInvocation, MathInvocationConfig): """Multiplies two numbers""" - #fmt: off + + # fmt: off type: Literal["mul"] = "mul" a: int = Field(default=0, description="The first number") b: int = Field(default=0, description="The second number") - #fmt: on + # fmt: on def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=self.a * self.b) @@ -66,11 +75,12 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig): class DivideInvocation(BaseInvocation, MathInvocationConfig): """Divides two numbers""" - #fmt: off + + # fmt: off type: Literal["div"] = "div" a: int = Field(default=0, description="The first number") b: int = Field(default=0, description="The second number") - #fmt: on + # fmt: on def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=int(self.a / self.b)) @@ -78,8 +88,13 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig): class RandomIntInvocation(BaseInvocation): """Outputs a single random integer.""" - #fmt: off + + # fmt: off type: Literal["rand_int"] = "rand_int" - #fmt: on + low: int = Field(default=0, description="The inclusive low value") + high: int = Field( + default=np.iinfo(np.int32).max, description="The exclusive high value" + ) + # fmt: on def invoke(self, context: InvocationContext) -> IntOutput: - return IntOutput(a=np.random.randint(0, np.iinfo(np.int32).max)) + return IntOutput(a=np.random.randint(self.low, self.high)) diff --git a/invokeai/app/services/config.py b/invokeai/app/services/config.py new file mode 100644 index 0000000000..824690c525 --- /dev/null +++ b/invokeai/app/services/config.py @@ -0,0 +1,528 @@ +# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team + +'''Invokeai configuration system. + +Arguments and fields are taken from the pydantic definition of the +model. Defaults can be set by creating a yaml configuration file that +has a top-level key of "InvokeAI" and subheadings for each of the +categories returned by `invokeai --help`. 
The file looks like this:
+
+[file: invokeai.yaml]
+
+InvokeAI:
+  Paths:
+    root: /home/lstein/invokeai-main
+    conf_path: configs/models.yaml
+    legacy_conf_dir: configs/stable-diffusion
+    outdir: outputs
+    embedding_dir: embeddings
+    lora_dir: loras
+    autoconvert_dir: null
+    gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth
+  Models:
+    model: stable-diffusion-1.5
+    embeddings: true
+  Memory/Performance:
+    xformers_enabled: false
+    sequential_guidance: false
+    precision: float16
+    max_loaded_models: 4
+    always_use_cpu: false
+    free_gpu_mem: false
+  Features:
+    nsfw_checker: true
+    restore: true
+    esrgan: true
+    patchmatch: true
+    internet_available: true
+    log_tokenization: false
+  Web Server:
+    host: 127.0.0.1
+    port: 8081
+    allow_origins: []
+    allow_credentials: true
+    allow_methods:
+    - '*'
+    allow_headers:
+    - '*'
+
+The default name of the configuration file is `invokeai.yaml`, located
+in INVOKEAI_ROOT. You can supersede this by providing any
+OmegaConf dictionary object at initialization time:
+
+  omegaconf = OmegaConf.load('/tmp/init.yaml')
+  conf = InvokeAIAppConfig(conf=omegaconf)
+
+By default, InvokeAIAppConfig will parse the contents of `sys.argv` at
+initialization time. You may pass a list of strings in the optional
+`argv` argument to use instead of the system argv:
+
+  conf = InvokeAIAppConfig(argv=['--xformers_enabled'])
+
+It is also possible to set a value at initialization time. This value
+has highest priority.
+
+  conf = InvokeAIAppConfig(xformers_enabled=True)
+
+Any setting can be overwritten by setting an environment variable of
+the form "INVOKEAI_<setting>", as in:
+
+  export INVOKEAI_port=8080
+
+Order of precedence (from highest):
+  1) initialization options
+  2) command line options
+  3) environment variable options
+  4) config file options
+  5) pydantic defaults
+
+Typical usage:
+
+  from invokeai.app.services.config import InvokeAIAppConfig
+  from invokeai.invocations.generate import TextToImageInvocation
+
+  # get global configuration and print its nsfw_checker value
+  conf = InvokeAIAppConfig()
+  print(conf.nsfw_checker)
+
+  # get the text2image invocation and print its step value
+  text2image = TextToImageInvocation()
+  print(text2image.steps)
+
+Computed properties:
+
+The InvokeAIAppConfig object has a series of properties that
+resolve paths relative to the runtime root directory. They each return
+a Path object:
+
+  root_path - path to InvokeAI root
+  output_path - path to default outputs directory
+  model_conf_path - path to models.yaml
+  conf - alias for the above
+  embedding_path - path to the embeddings directory
+  lora_path - path to the LoRA directory
+
+In most cases, you will want to create a single InvokeAIAppConfig
+object for the entire application. The get_invokeai_config() function
+does this:
+
+  config = get_invokeai_config()
+  print(config.root)
+
+# Subclassing
+
+If you wish to create a similar class, please subclass the
+`InvokeAISettings` class and define a Literal field named "type",
+which is set to the desired top-level name. For example, to create a
+"InvokeBatch" configuration, define it like this:
+
+  class InvokeBatch(InvokeAISettings):
+      type: Literal["InvokeBatch"] = "InvokeBatch"
+      node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources')
+      cpu_count : int = Field(default=8, description="Number of CPUs to run on per node", category='Resources')
+
+This will now read and write from the "InvokeBatch" section of the
+config file, look for environment variables named INVOKEBATCH_*, and
+accept the command-line arguments `--node_count` and `--cpu_count`. The
+two configs are kept in separate sections of the config file:
+
+  # invokeai.yaml
+
+  InvokeBatch:
+    Resources:
+      node_count: 1
+      cpu_count: 8
+
+  InvokeAI:
+    Paths:
+      root: /home/lstein/invokeai-main
+      conf_path: configs/models.yaml
+      legacy_conf_dir: configs/stable-diffusion
+      outdir: outputs
+      ...
+'''
+import argparse
+import pydoc
+import typing
+import os
+import sys
+from argparse import ArgumentParser
+from omegaconf import OmegaConf, DictConfig
+from pathlib import Path
+from pydantic import BaseSettings, Field, parse_obj_as
+from typing import Any, ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args
+
+INIT_FILE = Path('invokeai.yaml')
+LEGACY_INIT_FILE = Path('invokeai.init')
+
+# This global stores a singleton InvokeAIAppConfig configuration object
+global_config = None
+
+class InvokeAISettings(BaseSettings):
+    '''
+    Runtime configuration settings in which default values are
+    read from an omegaconf .yaml file.
+    '''
+    initconf : ClassVar[DictConfig] = None
+    argparse_groups : ClassVar[Dict] = {}
+
+    def parse_args(self, argv: list=sys.argv[1:]):
+        parser = self.get_parser()
+        opt, _ = parser.parse_known_args(argv)
+        for name in self.__fields__:
+            if name not in self._excluded():
+                setattr(self, name, getattr(opt,name))
+
+    def to_yaml(self)->str:
+        """
+        Return a YAML string representing our settings. This can be used
+        as the contents of `invokeai.yaml` to restore settings later.
+ """ + cls = self.__class__ + type = get_args(get_type_hints(cls)['type'])[0] + field_dict = dict({type:dict()}) + for name,field in self.__fields__.items(): + if name in cls._excluded(): + continue + category = field.field_info.extra.get("category") or "Uncategorized" + value = getattr(self,name) + if category not in field_dict[type]: + field_dict[type][category] = dict() + # keep paths as strings to make it easier to read + field_dict[type][category][name] = str(value) if isinstance(value,Path) else value + conf = OmegaConf.create(field_dict) + return OmegaConf.to_yaml(conf) + + @classmethod + def add_parser_arguments(cls, parser): + if 'type' in get_type_hints(cls): + settings_stanza = get_args(get_type_hints(cls)['type'])[0] + else: + settings_stanza = "Uncategorized" + + env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper() + + initconf = cls.initconf.get(settings_stanza) \ + if cls.initconf and settings_stanza in cls.initconf \ + else OmegaConf.create() + + # create an upcase version of the environment in + # order to achieve case-insensitive environment + # variables (the way Windows does) + upcase_environ = dict() + for key,value in os.environ.items(): + upcase_environ[key.upper()] = value + + fields = cls.__fields__ + cls.argparse_groups = {} + + for name, field in fields.items(): + if name not in cls._excluded(): + current_default = field.default + + category = field.field_info.extra.get("category","Uncategorized") + env_name = env_prefix + '_' + name + if category in initconf and name in initconf.get(category): + field.default = initconf.get(category).get(name) + if env_name.upper() in upcase_environ: + field.default = upcase_environ[env_name.upper()] + cls.add_field_argument(parser, name, field) + + field.default = current_default + + @classmethod + def cmd_name(self, command_field: str='type')->str: + hints = get_type_hints(self) + if command_field in hints: + return get_args(hints[command_field])[0] + else: + return 'Uncategorized' + + @classmethod + def get_parser(cls)->ArgumentParser: + parser = PagingArgumentParser( + prog=cls.cmd_name(), + description=cls.__doc__, + ) + cls.add_parser_arguments(parser) + return parser + + @classmethod + def add_subparser(cls, parser: argparse.ArgumentParser): + parser.add_parser(cls.cmd_name(), help=cls.__doc__) + + @classmethod + def _excluded(self)->List[str]: + return ['type','initconf'] + + class Config: + env_file_encoding = 'utf-8' + arbitrary_types_allowed = True + case_sensitive = True + + @classmethod + def add_field_argument(cls, command_parser, name: str, field, default_override = None): + field_type = get_type_hints(cls).get(name) + default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory() + if category := field.field_info.extra.get("category"): + if category not in cls.argparse_groups: + cls.argparse_groups[category] = command_parser.add_argument_group(category) + argparse_group = cls.argparse_groups[category] + else: + argparse_group = command_parser + + if get_origin(field_type) == Literal: + allowed_values = get_args(field.type_) + allowed_types = set() + for val in allowed_values: + allowed_types.add(type(val)) + allowed_types_list = list(allowed_types) + field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore + + argparse_group.add_argument( + f"--{name}", + dest=name, + type=field_type, + default=default, + choices=allowed_values, + 
help=field.field_info.description, + ) + + elif get_origin(field_type) == list: + argparse_group.add_argument( + f"--{name}", + dest=name, + nargs='*', + type=field.type_, + default=default, + action=argparse.BooleanOptionalAction if field.type_==bool else 'store', + help=field.field_info.description, + ) + else: + argparse_group.add_argument( + f"--{name}", + dest=name, + type=field.type_, + default=default, + action=argparse.BooleanOptionalAction if field.type_==bool else 'store', + help=field.field_info.description, + ) +def _find_root()->Path: + if os.environ.get("INVOKEAI_ROOT"): + root = Path(os.environ.get("INVOKEAI_ROOT")).resolve() + elif ( + os.environ.get("VIRTUAL_ENV") + and (Path(os.environ.get("VIRTUAL_ENV"), "..", INIT_FILE).exists() + or + Path(os.environ.get("VIRTUAL_ENV"), "..", LEGACY_INIT_FILE).exists() + ) + ): + root = Path(os.environ.get("VIRTUAL_ENV"), "..").resolve() + else: + root = Path("~/invokeai").expanduser().resolve() + return root + +class InvokeAIAppConfig(InvokeAISettings): + ''' +Generate images using Stable Diffusion. Use "invokeai" to launch +the command-line client (recommended for experts only), or +"invokeai-web" to launch the web server. Global options +can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by +setting environment variables INVOKEAI_. + ''' + #fmt: off + type: Literal["InvokeAI"] = "InvokeAI" + host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server') + port : int = Field(default=9090, description="Port to bind to", category='Web Server') + allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server') + allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Web Server') + allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server') + allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server') + + esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features') + internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features') + log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features') + nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features') + patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features') + restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features') + + always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance') + free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance') + max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance') + precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance') + sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", 
category='Memory/Performance') + xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') + + root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths') + autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths') + conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths') + embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI textual inversion aembeddings directory', category='Paths') + gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.', category='Paths') + legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths') + lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths') + outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths') + from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths') + + model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models') + embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models') + #fmt: on + + def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs): + ''' + Initialize InvokeAIAppconfig. + :param conf: alternate Omegaconf dictionary object + :param argv: aternate sys.argv list + :param **kwargs: attributes to initialize with + ''' + super().__init__(**kwargs) + + # Set the runtime root directory. We parse command-line switches here + # in order to pick up the --root_dir option. + self.parse_args(argv) + if conf is None: + try: + conf = OmegaConf.load(self.root_dir / INIT_FILE) + except: + pass + InvokeAISettings.initconf = conf + + # parse args again in order to pick up settings in configuration file + self.parse_args(argv) + + # restore initialization values + hints = get_type_hints(self) + for k in kwargs: + setattr(self,k,parse_obj_as(hints[k],kwargs[k])) + + @property + def root_path(self)->Path: + ''' + Path to the runtime root directory + ''' + if self.root: + return Path(self.root).expanduser() + else: + return self.find_root() + + @property + def root_dir(self)->Path: + ''' + Alias for above. + ''' + return self.root_path + + def _resolve(self,partial_path:Path)->Path: + return (self.root_path / partial_path).resolve() + + @property + def output_path(self)->Path: + ''' + Path to defaults outputs directory. + ''' + return self._resolve(self.outdir) + + @property + def model_conf_path(self)->Path: + ''' + Path to models configuration file. + ''' + return self._resolve(self.conf_path) + + @property + def legacy_conf_path(self)->Path: + ''' + Path to directory of legacy configuration files (e.g. 
v1-inference.yaml) + ''' + return self._resolve(self.legacy_conf_dir) + + @property + def cache_dir(self)->Path: + ''' + Path to the global cache directory for HuggingFace hub-managed models + ''' + return self.models_dir / "hub" + + @property + def models_dir(self)->Path: + ''' + Path to the models directory + ''' + return self._resolve("models") + + @property + def converted_ckpts_dir(self)->Path: + ''' + Path to the converted models + ''' + return self._resolve("models/converted_ckpts") + + @property + def embedding_path(self)->Path: + ''' + Path to the textual inversion embeddings directory. + ''' + return self._resolve(self.embedding_dir) if self.embedding_dir else None + + @property + def lora_path(self)->Path: + ''' + Path to the LoRA models directory. + ''' + return self._resolve(self.lora_dir) if self.lora_dir else None + + @property + def autoconvert_path(self)->Path: + ''' + Path to the directory containing models to be imported automatically at startup. + ''' + return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None + + @property + def gfpgan_model_path(self)->Path: + ''' + Path to the GFPGAN model. + ''' + return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None + + # the following methods support legacy calls leftover from the Globals era + @property + def full_precision(self)->bool: + """Return true if precision set to float32""" + return self.precision=='float32' + + @property + def disable_xformers(self)->bool: + """Return true if xformers_enabled is false""" + return not self.xformers_enabled + + @property + def try_patchmatch(self)->bool: + """Return true if patchmatch true""" + return self.patchmatch + + @staticmethod + def find_root()->Path: + ''' + Choose the runtime root directory when not specified on command line or + init file. + ''' + return _find_root() + + +class PagingArgumentParser(argparse.ArgumentParser): + ''' + A custom ArgumentParser that uses pydoc to page its output. + It also supports reading defaults from an init file. + ''' + def print_help(self, file=None): + text = self.format_help() + pydoc.pager(text) + +def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAISettings: + ''' + This returns a singleton InvokeAIAppConfig configuration object. 
+ ''' + global global_config + if global_config is None or type(global_config)!=cls: + global_config = cls(**kwargs) + return global_config diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 7ed65015d0..ab6e4ed49d 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -135,6 +135,7 @@ class GraphInvocationOutput(BaseInvocationOutput): # TODO: Fill this out and move to invocations class GraphInvocation(BaseInvocation): + """Execute a graph""" type: Literal["graph"] = "graph" # TODO: figure out how to create a default here @@ -162,6 +163,7 @@ class IterateInvocationOutput(BaseInvocationOutput): # TODO: Fill this out and move to invocations class IterateInvocation(BaseInvocation): + """Iterates over a list of items""" type: Literal["iterate"] = "iterate" collection: list[Any] = Field( diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 47b3b6cf07..d4c0c06b65 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -10,6 +10,7 @@ from .image_storage import ImageStorageBase from .restoration_services import RestorationServices from .invocation_queue import InvocationQueueABC from .item_storage import ItemStorageABC +from .config import InvokeAISettings class InvocationServices: """Services that can be used by invocations""" @@ -21,7 +22,8 @@ class InvocationServices: queue: InvocationQueueABC model_manager: ModelManager restoration: RestorationServices - + configuration: InvokeAISettings + # NOTE: we must forward-declare any types that include invocations, since invocations can use services graph_library: ItemStorageABC["LibraryGraph"] graph_execution_manager: ItemStorageABC["GraphExecutionState"] @@ -40,6 +42,7 @@ class InvocationServices: graph_execution_manager: ItemStorageABC["GraphExecutionState"], processor: "InvocationProcessorABC", restoration: RestorationServices, + configuration: InvokeAISettings=None, ): self.model_manager = model_manager self.events = events @@ -52,3 +55,4 @@ class InvocationServices: self.graph_execution_manager = graph_execution_manager self.processor = processor self.restoration = restoration + self.configuration = configuration diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index 448bd59f00..2575a92a5c 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -1,7 +1,6 @@ """ Initialization file for invokeai.backend """ -from .generate import Generate from .generator import ( InvokeAIGeneratorBasicParams, InvokeAIGenerator, @@ -12,5 +11,3 @@ from .generator import ( ) from .model_management import ModelManager, ModelCache, SDModelType, SDModelInfo from .safety_checker import SafetyChecker -from .args import Args -from .globals import Globals diff --git a/invokeai/backend/args.py b/invokeai/backend/args.py deleted file mode 100644 index 7404dbeede..0000000000 --- a/invokeai/backend/args.py +++ /dev/null @@ -1,1397 +0,0 @@ -"""Helper class for dealing with image generation arguments. - -The Args class parses both the command line (shell) arguments, as well as the -command string passed at the invoke> prompt. 
It serves as the definitive repository -of all the arguments used by Generate and their default values, and implements the -preliminary metadata standards discussed here: - -https://github.com/lstein/stable-diffusion/issues/266 - -To use: - opt = Args() - - # Read in the command line options: - # this returns a namespace object like the underlying argparse library) - # You do not have to use the return value, but you can check it against None - # to detect illegal arguments on the command line. - args = opt.parse_args() - if not args: - print('oops') - sys.exit(-1) - - # read in a command passed to the invoke> prompt: - opts = opt.parse_cmd('do androids dream of electric sheep? -H256 -W1024 -n4') - - # The Args object acts like a namespace object - print(opt.model) - -You can set attributes in the usual way, use vars(), etc.: - - opt.model = 'something-else' - do_something(**vars(a)) - -It is helpful in saving metadata: - - # To get a json representation of all the values, allowing - # you to override any values dynamically - j = opt.json(seed=42) - - # To get the prompt string with the switches, allowing you - # to override any values dynamically - j = opt.dream_prompt_str(seed=42) - -If you want to access the namespace objects from the shell args or the -parsed command directly, you may use the values returned from the -original calls to parse_args() and parse_cmd(), or get them later -using the _arg_switches and _cmd_switches attributes. This can be -useful if both the args and the command contain the same attribute and -you wish to apply logic as to which one to use. For example: - - a = Args() - args = a.parse_args() - opts = a.parse_cmd(string) - do_grid = args.grid or opts.grid - -To add new attributes, edit the _create_arg_parser() and -_create_dream_cmd_parser() methods. - -**Generating and retrieving sd-metadata** - -To generate a dict representing RFC266 metadata: - - metadata = metadata_dumps(opt,) - -This will generate an RFC266 dictionary that can then be turned into a JSON -and written to the PNG file. The optional seeds, weights, model_hash and -postprocesser arguments are not available to the opt object and so must be -provided externally. See how invoke.py does it. - -Note that this function was originally called format_metadata() and a wrapper -is provided that issues a deprecation notice. - -To retrieve a (series of) opt objects corresponding to the metadata, do this: - - opt_list = metadata_loads(metadata) - -The metadata should be pulled out of the PNG image. 
pngwriter has a method -retrieve_metadata that will do this, or you can do it in one swell foop -with metadata_from_png(): - - opt_list = metadata_from_png('/path/to/image_file.png') -""" - -import argparse -import base64 -import copy -import functools -import hashlib -import json -import os -import pydoc -import re -import shlex -import sys -from argparse import Namespace -from pathlib import Path -from typing import List - -import invokeai.version -import invokeai.backend.util.logging as logger -from invokeai.backend.image_util import retrieve_metadata - -from .globals import Globals -from .prompting import split_weighted_subprompts - -APP_ID = invokeai.version.__app_id__ -APP_NAME = invokeai.version.__app_name__ -APP_VERSION = invokeai.version.__version__ - -SAMPLER_CHOICES = [ - "ddim", - "ddpm", - "deis", - "lms", - "pndm", - "heun", - "euler", - "euler_k", - "euler_a", - "kdpm_2", - "kdpm_2_a", - "dpmpp_2s", - "dpmpp_2m", - "dpmpp_2m_k", - "unipc", -] - -PRECISION_CHOICES = [ - "auto", - "float32", - "autocast", - "float16", -] - - -class ArgFormatter(argparse.RawTextHelpFormatter): - # use defined argument order to display usage - def _format_usage(self, usage, actions, groups, prefix): - if prefix is None: - prefix = "usage: " - - # if usage is specified, use that - if usage is not None: - usage = usage % dict(prog=self._prog) - - # if no optionals or positionals are available, usage is just prog - elif usage is None and not actions: - usage = "invoke>" - elif usage is None: - prog = "invoke>" - # build full usage string - action_usage = self._format_actions_usage(actions, groups) # NEW - usage = " ".join([s for s in [prog, action_usage] if s]) - # omit the long line wrapping code - # prefix with 'usage:' - return "%s%s\n\n" % (prefix, usage) - - -class PagingArgumentParser(argparse.ArgumentParser): - """ - A custom ArgumentParser that uses pydoc to page its output. - It also supports reading defaults from an init file. - """ - - def print_help(self, file=None): - text = self.format_help() - pydoc.pager(text) - - def convert_arg_line_to_args(self, arg_line): - return shlex.split(arg_line, comments=True) - - -class Args(object): - def __init__(self, arg_parser=None, cmd_parser=None): - """ - Initialize new Args class. It takes two optional arguments, an argparse - parser for switches given on the shell command line, and an argparse - parser for switches given on the invoke> CLI line. If one or both are - missing, it creates appropriate parsers internally. 
- """ - self._arg_parser = arg_parser or self._create_arg_parser() - self._cmd_parser = cmd_parser or self._create_dream_cmd_parser() - self._arg_switches = self.parse_cmd("") # fill in defaults - self._cmd_switches = self.parse_cmd("") # fill in defaults - - def parse_args(self, args: List[str] = None): - """Parse the shell switches and store.""" - sysargs = args if args is not None else sys.argv[1:] - try: - # pre-parse before we do any initialization to get root directory - # and intercept --version request - switches = self._arg_parser.parse_args(sysargs) - if switches.version: - print(f"{APP_NAME} {APP_VERSION}") - sys.exit(0) - - logger.info("Initializing, be patient...") - Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root)) - Globals.try_patchmatch = switches.patchmatch - - # now use root directory to find the init file - initfile = os.path.expanduser(os.path.join(Globals.root, Globals.initfile)) - legacyinit = os.path.expanduser("~/.invokeai") - if os.path.exists(initfile): - logger.info( - f"Initialization file {initfile} found. Loading...", - ) - sysargs.insert(0, f"@{initfile}") - elif os.path.exists(legacyinit): - logger.warning( - f"Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init." - ) - sysargs.insert(0, f"@{legacyinit}") - Globals.log_tokenization = self._arg_parser.parse_args( - sysargs - ).log_tokenization - - self._arg_switches = self._arg_parser.parse_args(sysargs) - return self._arg_switches - except Exception as e: - logger.error(f"An exception has occurred: {e}") - return None - - def parse_cmd(self, cmd_string): - """Parse a invoke>-style command string""" - # handle the case in which the first token is a switch - if cmd_string.startswith("-"): - prompt = "" - switches = cmd_string - # handle the case in which the prompt is enclosed by quotes - elif cmd_string.startswith('"'): - a = shlex.split(cmd_string, comments=True) - prompt = a[0] - switches = shlex.join(a[1:]) - else: - # no initial quote, so get everything up to the first thing - # that looks like a switch - if cmd_string.startswith("-"): - prompt = "" - switches = cmd_string - else: - match = re.match("^(.+?)\s(--?[a-zA-Z].+)", cmd_string) - if match: - prompt, switches = match.groups() - else: - prompt = cmd_string - switches = "" - try: - self._cmd_switches = self._cmd_parser.parse_args( - shlex.split(switches, comments=True) - ) - if not getattr(self._cmd_switches, "prompt"): - setattr(self._cmd_switches, "prompt", prompt) - return self._cmd_switches - except: - return None - - def json(self, **kwargs): - return json.dumps(self.to_dict(**kwargs)) - - def to_dict(self, **kwargs): - a = vars(self) - a.update(kwargs) - return a - - # Isn't there a more automated way of doing this? - # Ideally we get the switch strings out of the argparse objects, - # but I don't see a documented API for this. 
- def dream_prompt_str(self, **kwargs): - """Normalized dream_prompt.""" - a = vars(self) - a.update(kwargs) - switches = list() - prompt = a["prompt"] - prompt.replace('"', '\\"') - switches.append(prompt) - switches.append(f'-s {a["steps"]}') - switches.append(f'-S {a["seed"]}') - switches.append(f'-W {a["width"]}') - switches.append(f'-H {a["height"]}') - switches.append(f'-C {a["cfg_scale"]}') - if a["karras_max"] is not None: - switches.append(f'--karras_max {a["karras_max"]}') - if a["perlin"] > 0: - switches.append(f'--perlin {a["perlin"]}') - if a["threshold"] > 0: - switches.append(f'--threshold {a["threshold"]}') - if a["grid"]: - switches.append("--grid") - if a["seamless"]: - switches.append("--seamless") - if a["hires_fix"]: - switches.append("--hires_fix") - if a["h_symmetry_time_pct"]: - switches.append(f'--h_symmetry_time_pct {a["h_symmetry_time_pct"]}') - if a["v_symmetry_time_pct"]: - switches.append(f'--v_symmetry_time_pct {a["v_symmetry_time_pct"]}') - - # img2img generations have parameters relevant only to them and have special handling - if a["init_img"] and len(a["init_img"]) > 0: - switches.append(f'-I {a["init_img"]}') - switches.append(f'-A {a["sampler_name"]}') - if a["fit"]: - switches.append("--fit") - if a["init_mask"] and len(a["init_mask"]) > 0: - switches.append(f'-M {a["init_mask"]}') - if a["init_color"] and len(a["init_color"]) > 0: - switches.append(f'--init_color {a["init_color"]}') - if a["strength"] and a["strength"] > 0: - switches.append(f'-f {a["strength"]}') - if a["inpaint_replace"]: - switches.append("--inpaint_replace") - if a["text_mask"]: - switches.append(f'-tm {" ".join([str(u) for u in a["text_mask"]])}') - else: - switches.append(f'-A {a["sampler_name"]}') - - # facetool-specific parameters, only print if running facetool - if a["facetool_strength"]: - switches.append(f'-G {a["facetool_strength"]}') - switches.append(f'-ft {a["facetool"]}') - if a["facetool"] == "codeformer": - switches.append(f'-cf {a["codeformer_fidelity"]}') - - if a["outcrop"]: - switches.append(f'-c {" ".join([str(u) for u in a["outcrop"]])}') - - # esrgan-specific parameters - if a["upscale"]: - switches.append(f'-U {" ".join([str(u) for u in a["upscale"]])}') - - # embiggen parameters - if a["embiggen"]: - switches.append(f'--embiggen {" ".join([str(u) for u in a["embiggen"]])}') - if a["embiggen_tiles"]: - switches.append( - f'--embiggen_tiles {" ".join([str(u) for u in a["embiggen_tiles"]])}' - ) - if a["embiggen_strength"]: - switches.append(f'--embiggen_strength {a["embiggen_strength"]}') - - # outpainting parameters - if a["out_direction"]: - switches.append(f'-D {" ".join([str(u) for u in a["out_direction"]])}') - - # LS: slight semantic drift which needs addressing in the future: - # 1. Variations come out of the stored metadata as a packed string with the keyword "variations" - # 2. However, they come out of the CLI (and probably web) with the keyword "with_variations" and - # in broken-out form. Variation (1) should be changed to comply with (2) - if a["with_variations"] and len(a["with_variations"]) > 0: - formatted_variations = ",".join( - f"{seed}:{weight}" for seed, weight in (a["with_variations"]) - ) - switches.append(f"-V {formatted_variations}") - if "variations" in a and len(a["variations"]) > 0: - switches.append(f'-V {a["variations"]}') - return " ".join(switches) - - def __getattribute__(self, name): - """ - Returns union of command-line arguments and dream_prompt arguments, - with the latter superseding the former. 
- """ - cmd_switches = None - arg_switches = None - try: - cmd_switches = object.__getattribute__(self, "_cmd_switches") - arg_switches = object.__getattribute__(self, "_arg_switches") - except AttributeError: - pass - - if cmd_switches and arg_switches and name == "__dict__": - return self._merge_dict( - arg_switches.__dict__, - cmd_switches.__dict__, - ) - try: - return object.__getattribute__(self, name) - except AttributeError: - pass - - if not hasattr(cmd_switches, name) and not hasattr(arg_switches, name): - raise AttributeError - - value_arg, value_cmd = (None, None) - try: - value_cmd = getattr(cmd_switches, name) - except AttributeError: - pass - try: - value_arg = getattr(arg_switches, name) - except AttributeError: - pass - - # here is where we can pick and choose which to use - # default behavior is to choose the dream_command value over - # the arg value. For example, the --grid and --individual options are a little - # funny because of their push/pull relationship. This is how to handle it. - if name == "grid": - if cmd_switches.individual: - return False - else: - return value_cmd or value_arg - return value_cmd if value_cmd is not None else value_arg - - def __setattr__(self, name, value): - if name.startswith("_"): - object.__setattr__(self, name, value) - else: - self._cmd_switches.__dict__[name] = value - - def _merge_dict(self, dict1, dict2): - new_dict = {} - for k in set(list(dict1.keys()) + list(dict2.keys())): - value1 = dict1.get(k, None) - value2 = dict2.get(k, None) - new_dict[k] = value2 if value2 is not None else value1 - return new_dict - - def _create_init_file(self, initfile: str): - with open(initfile, mode="w", encoding="utf-8") as f: - f.write( - """# InvokeAI initialization file -# Put frequently-used startup commands here, one or more per line -# Examples: -# --web --host=0.0.0.0 -# --steps 20 -# -Ak_euler_a -C10.0 -""" - ) - - def _create_arg_parser(self): - """ - This defines all the arguments used on the command line when you launch - the CLI or web backend. - """ - parser = PagingArgumentParser( - description=""" - Generate images using Stable Diffusion. - Use --web to launch the web interface. - Use --from_file to load prompts from a file path or standard input ("-"). - Otherwise you will be dropped into an interactive command prompt (type -h for help.) - Other command-line arguments are defaults that can usually be overridden - prompt the command prompt. - """, - fromfile_prefix_chars="@", - ) - general_group = parser.add_argument_group("General") - model_group = parser.add_argument_group("Model selection") - file_group = parser.add_argument_group("Input/output") - web_server_group = parser.add_argument_group("Web server") - render_group = parser.add_argument_group("Rendering") - postprocessing_group = parser.add_argument_group("Postprocessing") - deprecated_group = parser.add_argument_group("Deprecated options") - - deprecated_group.add_argument("--laion400m") - deprecated_group.add_argument("--weights") # deprecated - deprecated_group.add_argument( - "--ckpt_convert", - action=argparse.BooleanOptionalAction, - dest="ckpt_convert", - default=True, - help="Load legacy ckpt files as diffusers (deprecated; always true now).", - ) - - general_group.add_argument( - "--version", "-V", action="store_true", help="Print InvokeAI version number" - ) - model_group.add_argument( - "--root_dir", - default=None, - help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. 
Defaults to ~/invokeai.', - ) - model_group.add_argument( - "--config", - "-c", - "-config", - dest="conf", - default="./configs/models.yaml", - help="Path to configuration file for alternate models.", - ) - model_group.add_argument( - "--model", - help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)', - ) - model_group.add_argument( - "--weight_dirs", - nargs="+", - type=str, - help="List of one or more directories that will be auto-scanned for new model weights to import", - ) - model_group.add_argument( - "--png_compression", - "-z", - type=int, - default=6, - choices=range(0, 10), - dest="png_compression", - help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.", - ) - model_group.add_argument( - "-F", - "--full_precision", - dest="full_precision", - action="store_true", - help="Deprecated way to set --precision=float32", - ) - model_group.add_argument( - "--max_cache_size", - dest="max_cache_size", - type=float, - default=6.0, - help="Maximum size of the model RAM cache (in GB). 6 GB is sufficient to keep 2-3 diffusers models in RAM simultaneously.", - ) - model_group.add_argument( - "--free_gpu_mem", - dest="free_gpu_mem", - action="store_true", - help="Force free gpu memory before final decoding", - ) - model_group.add_argument( - "--sequential_guidance", - dest="sequential_guidance", - action="store_true", - help="Calculate guidance in serial instead of in parallel, lowering memory requirement " - "at the expense of speed", - ) - model_group.add_argument( - "--xformers", - action=argparse.BooleanOptionalAction, - default=True, - help="Enable/disable xformers support (default enabled if installed)", - ) - model_group.add_argument( - "--always_use_cpu", - dest="always_use_cpu", - action="store_true", - help="Force use of CPU even if GPU is available", - ) - model_group.add_argument( - "--precision", - dest="precision", - type=str, - choices=PRECISION_CHOICES, - metavar="PRECISION", - help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}', - default="auto", - ) - model_group.add_argument( - "--internet", - action=argparse.BooleanOptionalAction, - dest="internet_available", - default=True, - help="Indicate whether internet is available for just-in-time model downloading (default: probe automatically).", - ) - model_group.add_argument( - "--nsfw_checker", - "--safety_checker", - action=argparse.BooleanOptionalAction, - dest="safety_checker", - default=False, - help="Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.", - ) - model_group.add_argument( - "--autoconvert", - default=None, - type=str, - help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models", - ) - model_group.add_argument( - "--patchmatch", - action=argparse.BooleanOptionalAction, - default=True, - help="Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.", - ) - file_group.add_argument( - "--from_file", - dest="infile", - type=str, - help="If specified, load prompts from this file", - ) - file_group.add_argument( - "--outdir", - "-o", - type=str, - help="Directory to save generated images and a log of prompts and seeds. 
Default: ROOTDIR/outputs", - default="outputs", - ) - file_group.add_argument( - "--prompt_as_dir", - "-p", - action="store_true", - help="Place images in subdirectories named after the prompt.", - ) - render_group.add_argument( - "--fnformat", - default="{prefix}.{seed}.png", - type=str, - help="Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png", - ) - render_group.add_argument( - "-s", "--steps", type=int, default=50, help="Number of steps" - ) - render_group.add_argument( - "-W", - "--width", - type=int, - help="Image width, multiple of 64", - ) - render_group.add_argument( - "-H", - "--height", - type=int, - help="Image height, multiple of 64", - ) - render_group.add_argument( - "-C", - "--cfg_scale", - default=7.5, - type=float, - help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.', - ) - render_group.add_argument( - "--sampler", - "-A", - "-m", - dest="sampler_name", - type=str, - choices=SAMPLER_CHOICES, - metavar="SAMPLER_NAME", - help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}', - default="lms", - ) - render_group.add_argument( - "--log_tokenization", - "-t", - action="store_true", - help="shows how the prompt is split into tokens", - ) - render_group.add_argument( - "-f", - "--strength", - type=float, - help="img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely", - ) - render_group.add_argument( - "-T", - "-fit", - "--fit", - action=argparse.BooleanOptionalAction, - help="If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)", - ) - - render_group.add_argument( - "--grid", - "-g", - action=argparse.BooleanOptionalAction, - help="generate a grid", - ) - render_group.add_argument( - "--embedding_directory", - "--embedding_path", - dest="embedding_path", - default="embeddings", - type=str, - help="Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)", - ) - render_group.add_argument( - "--embeddings", - action=argparse.BooleanOptionalAction, - default=True, - help="Enable embedding directory (default). Use --no-embeddings to disable.", - ) - render_group.add_argument( - "--enable_image_debugging", - action="store_true", - help="Generates debugging image to display", - ) - render_group.add_argument( - "--karras_max", - type=int, - default=None, - help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29].", - ) - # Restoration related args - postprocessing_group.add_argument( - "--no_restore", - dest="restore", - action="store_false", - help="Disable face restoration with GFPGAN or codeformer", - ) - postprocessing_group.add_argument( - "--no_upscale", - dest="esrgan", - action="store_false", - help="Disable upscaling with ESRGAN", - ) - postprocessing_group.add_argument( - "--esrgan_bg_tile", - type=int, - default=400, - help="Tile size for background sampler, 0 for no tile during testing. Default: 400.", - ) - postprocessing_group.add_argument( - "--esrgan_denoise_str", - type=float, - default=0.75, - help="esrgan denoise str. 0 is no denoise, 1 is max denoise. 
Default: 0.75", - ) - postprocessing_group.add_argument( - "--gfpgan_model_path", - type=str, - default="./models/gfpgan/GFPGANv1.4.pth", - help="Indicates the path to the GFPGAN model", - ) - web_server_group.add_argument( - "--web", - dest="web", - action="store_true", - help="Start in web server mode.", - ) - web_server_group.add_argument( - "--web_develop", - dest="web_develop", - action="store_true", - help="Start in web server development mode.", - ) - web_server_group.add_argument( - "--web_verbose", - action="store_true", - help="Enables verbose logging", - ) - web_server_group.add_argument( - "--cors", - nargs="*", - type=str, - help="Additional allowed origins, comma-separated", - ) - web_server_group.add_argument( - "--host", - type=str, - default="127.0.0.1", - help="Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.", - ) - web_server_group.add_argument( - "--port", type=int, default="9090", help="Web server: Port to listen on" - ) - web_server_group.add_argument( - "--certfile", - type=str, - default=None, - help="Web server: Path to certificate file to use for SSL. Use together with --keyfile", - ) - web_server_group.add_argument( - "--keyfile", - type=str, - default=None, - help="Web server: Path to private key file to use for SSL. Use together with --certfile", - ) - web_server_group.add_argument( - "--gui", - dest="gui", - action="store_true", - help="Start InvokeAI GUI", - ) - deprecated_group.add_argument( - "--autoimport", - default=None, - type=str, - help="(DEPRECATED - NONFUNCTIONAL). Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly", - ) - deprecated_group.add_argument( - "--max_loaded_models", - dest="max_loaded_models", - type=int, - default=3, - help="Maximum number of models to keep in RAM cache (deprecated - use max_cache_size)", - ) - return parser - - # This creates the parser that processes commands on the invoke> command line - def _create_dream_cmd_parser(self): - parser = PagingArgumentParser( - formatter_class=ArgFormatter, - description=""" - *Image generation* - invoke> a fantastic alien landscape -W576 -H512 -s60 -n4 - - *postprocessing* - !fix applies upscaling/facefixing to a previously-generated image. - invoke> !fix 0000045.4829112.png -G1 -U4 -ft codeformer - - *embeddings* - invoke> !triggers -- return all trigger phrases contained in loaded embedding files - - *History manipulation* - !fetch retrieves the command used to generate an earlier image. Provide - a directory wildcard and the name of a file to write and all the commands - used to generate the images in the directory will be written to that file. - invoke> !fetch 0000015.8929913.png - invoke> a fantastic alien landscape -W 576 -H 512 -s 60 -A plms -C 7.5 - invoke> !fetch /path/to/images/*.png prompts.txt - - !replay /path/to/prompts.txt - Replays all the prompts contained in the file prompts.txt. - - !history lists all the commands issued during the current session. 
- - !NN retrieves the NNth command from the history - - *Model manipulation* - !models -- list models in configs/models.yaml - !switch -- switch to model named - !import_model /path/to/weights/file.ckpt -- adds a .ckpt model to your config - !import_model /path/to/weights/ -- interactively import models from a directory - !import_model http://path_to_model.ckpt -- downloads and adds a .ckpt model to your config - !import_model hakurei/waifu-diffusion -- downloads and adds a diffusers model to your config - !optimize_model -- converts a .ckpt model to a diffusers model - !convert_model /path/to/weights/file.ckpt -- converts a .ckpt file path to a diffusers model - !edit_model -- edit a model's description - !del_model -- delete a model - """, - ) - render_group = parser.add_argument_group("General rendering") - img2img_group = parser.add_argument_group("Image-to-image and inpainting") - inpainting_group = parser.add_argument_group("Inpainting") - outpainting_group = parser.add_argument_group("Outpainting and outcropping") - variation_group = parser.add_argument_group("Creating and combining variations") - postprocessing_group = parser.add_argument_group("Post-processing") - special_effects_group = parser.add_argument_group("Special effects") - deprecated_group = parser.add_argument_group("Deprecated options") - render_group.add_argument( - "--prompt", - default="", - help="prompt string", - ) - render_group.add_argument("-s", "--steps", type=int, help="Number of steps") - render_group.add_argument( - "-S", - "--seed", - type=int, - default=None, - help="Image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc", - ) - render_group.add_argument( - "-n", - "--iterations", - type=int, - default=1, - help="Number of samplings to perform (slower, but will provide seeds for individual images)", - ) - render_group.add_argument( - "-W", - "--width", - type=int, - help="Image width, multiple of 64", - ) - render_group.add_argument( - "-H", - "--height", - type=int, - help="Image height, multiple of 64", - ) - render_group.add_argument( - "-C", - "--cfg_scale", - type=float, - help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.', - ) - render_group.add_argument( - "--threshold", - default=0.0, - type=float, - help='Latent threshold for classifier free guidance (CFG) - prevent generator from "trying" too hard. Use positive values, 0 disables.', - ) - render_group.add_argument( - "--perlin", - default=0.0, - type=float, - help="Perlin noise scale (0.0 - 1.0) - add perlin noise to the initialization instead of the usual gaussian noise.", - ) - render_group.add_argument( - "--h_symmetry_time_pct", - default=None, - type=float, - help="Horizontal symmetry point (0.0 - 1.0) - apply horizontal symmetry at this point in image generation.", - ) - render_group.add_argument( - "--v_symmetry_time_pct", - default=None, - type=float, - help="Vertical symmetry point (0.0 - 1.0) - apply vertical symmetry at this point in image generation.", - ) - render_group.add_argument( - "--fnformat", - default="{prefix}.{seed}.png", - type=str, - help="Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. 
Default is {prefix}.{seed}.png", - ) - render_group.add_argument( - "--grid", - "-g", - action=argparse.BooleanOptionalAction, - help="generate a grid", - ) - render_group.add_argument( - "-i", - "--individual", - action="store_true", - help="override command-line --grid setting and generate individual images", - ) - render_group.add_argument( - "-x", - "--skip_normalize", - action="store_true", - help="Skip subprompt weight normalization", - ) - render_group.add_argument( - "-A", - "-m", - "--sampler", - dest="sampler_name", - type=str, - choices=SAMPLER_CHOICES, - metavar="SAMPLER_NAME", - help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}', - ) - render_group.add_argument( - "-t", - "--log_tokenization", - action="store_true", - help="shows how the prompt is split into tokens", - ) - render_group.add_argument( - "--outdir", - "-o", - type=str, - help="Directory to save generated images and a log of prompts and seeds", - ) - render_group.add_argument( - "--hires_fix", - action="store_true", - dest="hires_fix", - help="Create hires image using img2img to prevent duplicated objects", - ) - render_group.add_argument( - "--save_intermediates", - type=int, - default=0, - dest="save_intermediates", - help='Save every nth intermediate image into an "intermediates" directory within the output directory', - ) - render_group.add_argument( - "--png_compression", - "-z", - type=int, - choices=range(0, 10), - dest="png_compression", - help="level of PNG compression, from 0 (none) to 9 (maximum). [6]", - ) - render_group.add_argument( - "--karras_max", - type=int, - default=None, - help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29].", - ) - img2img_group.add_argument( - "-I", - "--init_img", - type=str, - help="Path to input image for img2img mode (supersedes width and height)", - ) - img2img_group.add_argument( - "-tm", - "--text_mask", - nargs="+", - type=str, - help='Use the clipseg classifier to generate the mask area for inpainting. Provide a description of the area to mask ("a mug"), optionally followed by the confidence level threshold (0-1.0; defaults to 0.5).', - default=None, - ) - img2img_group.add_argument( - "--init_color", - type=str, - help="Path to reference image for color correction (used for repeated img2img and inpainting)", - ) - img2img_group.add_argument( - "-T", - "-fit", - "--fit", - action="store_true", - help="If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)", - ) - img2img_group.add_argument( - "-f", - "--strength", - type=float, - help="img2img strength for noising/unnoising. 
0.0 preserves image exactly, 1.0 replaces it completely", - ) - inpainting_group.add_argument( - "-M", - "--init_mask", - type=str, - help="Path to input mask for inpainting mode (supersedes width and height)", - ) - inpainting_group.add_argument( - "--invert_mask", - action="store_true", - help="Invert the mask", - ) - inpainting_group.add_argument( - "-r", - "--inpaint_replace", - type=float, - default=0.0, - help="when inpainting, adjust how aggressively to replace the part of the picture under the mask, from 0.0 (a gentle merge) to 1.0 (replace entirely)", - ) - outpainting_group.add_argument( - "-c", - "--outcrop", - nargs="+", - type=str, - metavar=("direction", "pixels"), - help="Outcrop the image with one or more direction/pixel pairs: e.g. -c top 64 bottom 128 left 64 right 64", - ) - outpainting_group.add_argument( - "--force_outpaint", - action="store_true", - default=False, - help="Force outpainting if you have no inpainting mask to pass", - ) - outpainting_group.add_argument( - "--seam_size", - type=int, - default=0, - help="When outpainting, size of the mask around the seam between original and outpainted image", - ) - outpainting_group.add_argument( - "--seam_blur", - type=int, - default=0, - help="When outpainting, the amount to blur the seam inwards", - ) - outpainting_group.add_argument( - "--seam_strength", - type=float, - default=0.7, - help="When outpainting, the img2img strength to use when filling the seam. Values around 0.7 work well", - ) - outpainting_group.add_argument( - "--seam_steps", - type=int, - default=10, - help="When outpainting, the number of steps to use to fill the seam. Low values (~10) work well", - ) - outpainting_group.add_argument( - "--tile_size", - type=int, - default=32, - help="When outpainting, the tile size to use for filling outpaint areas", - ) - postprocessing_group.add_argument( - "--new_prompt", - type=str, - help="Change the text prompt applied during postprocessing (default, use original generation prompt)", - ) - postprocessing_group.add_argument( - "-ft", - "--facetool", - type=str, - default="gfpgan", - help="Select the face restoration AI to use: gfpgan, codeformer", - ) - postprocessing_group.add_argument( - "-G", - "--facetool_strength", - "--gfpgan_strength", - type=float, - help="The strength at which to apply the face restoration to the result.", - default=0.0, - ) - postprocessing_group.add_argument( - "-cf", - "--codeformer_fidelity", - type=float, - help="Used along with CodeFormer. Takes values between 0 and 1. 0 produces high quality but low accuracy. 1 produces high accuracy but low quality.", - default=0.75, - ) - postprocessing_group.add_argument( - "-U", - "--upscale", - nargs="+", - type=float, - help="Scale factor (1, 2, 3, 4, etc..) for upscaling final output followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75", - default=None, - ) - postprocessing_group.add_argument( - "--save_original", - "-save_orig", - action="store_true", - help="Save original. Use it when upscaling to save both versions.", - ) - postprocessing_group.add_argument( - "--embiggen", - "-embiggen", - nargs="+", - type=float, - help="Arbitrary upscaling using img2img. Provide scale factor (0.75), optionally followed by strength (0.75) and tile overlap proportion (0.25).", - default=None, - ) - postprocessing_group.add_argument( - "--embiggen_tiles", - "-embiggen_tiles", - nargs="+", - type=int, - help="For embiggen, provide list of tiles to process and replace onto the image e.g. 
`1 3 5`.", - default=None, - ) - postprocessing_group.add_argument( - "--embiggen_strength", - "-embiggen_strength", - type=float, - help="The strength of the embiggen img2img step, defaults to 0.4", - default=None, - ) - special_effects_group.add_argument( - "--seamless", - action="store_true", - help="Change the model to seamless tiling (circular) mode", - ) - special_effects_group.add_argument( - "--seamless_axes", - default=["x", "y"], - type=list[str], - help="Specify which axes to use circular convolution on.", - ) - variation_group.add_argument( - "-v", - "--variation_amount", - default=0.0, - type=float, - help="If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.", - ) - variation_group.add_argument( - "-V", - "--with_variations", - default=None, - type=str, - help="list of variations to apply, in the format `seed:weight,seed:weight,...", - ) - render_group.add_argument( - "--use_mps_noise", - action="store_true", - dest="use_mps_noise", - help="Simulate noise on M1 systems to get the same results", - ) - deprecated_group.add_argument( - "-D", - "--out_direction", - nargs="+", - type=str, - metavar=("direction", "pixels"), - help="Older outcropping system. Direction to extend the given image (left|right|top|bottom). If a distance pixel value is not specified it defaults to half the image size", - ) - return parser - - -def format_metadata(**kwargs): - logger.warning("format_metadata() is deprecated. Please use metadata_dumps()") - return metadata_dumps(kwargs) - - -def metadata_dumps(opt, seeds=[], model_hash=None, postprocessing=None): - """ - Given an Args object, returns a dict containing the keys and - structure of the proposed stable diffusion metadata standard - https://github.com/lstein/stable-diffusion/discussions/392 - This is intended to be turned into JSON and stored in the - "sd - """ - - # top-level metadata minus `image` or `images` - metadata = { - "model": "stable diffusion", - "model_id": opt.model, - "model_hash": model_hash, - "app_id": APP_ID, - "app_version": APP_VERSION, - } - - # # add some RFC266 fields that are generated internally, and not as - # # user args - image_dict = opt.to_dict(postprocessing=postprocessing) - - # remove any image keys not mentioned in RFC #266 - rfc266_img_fields = [ - "type", - "postprocessing", - "sampler", - "prompt", - "seed", - "variations", - "steps", - "cfg_scale", - "threshold", - "perlin", - "step_number", - "width", - "height", - "extra", - "strength", - "seamless" "init_img", - "init_mask", - "facetool", - "facetool_strength", - "upscale", - "h_symmetry_time_pct", - "v_symmetry_time_pct", - ] - rfc_dict = {} - - for item in image_dict.items(): - key, value = item - if key in rfc266_img_fields: - rfc_dict[key] = value - - # semantic drift - rfc_dict["sampler"] = image_dict.get("sampler_name", None) - - # display weighted subprompts (liable to change) - if opt.prompt: - subprompts = split_weighted_subprompts(opt.prompt) - subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts] - rfc_dict["prompt"] = subprompts - - # 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs - rfc_dict["variations"] = ( - [{"seed": x[0], "weight": x[1]} for x in opt.with_variations] - if opt.with_variations - else [] - ) - - # if variations are present then we need to replace 'seed' with 'orig_seed' - if hasattr(opt, "first_seed"): - rfc_dict["seed"] = opt.first_seed - - if 
opt.init_img: - rfc_dict["type"] = "img2img" - rfc_dict["strength_steps"] = rfc_dict.pop("strength") - rfc_dict["orig_hash"] = calculate_init_img_hash(opt.init_img) - rfc_dict["inpaint_replace"] = opt.inpaint_replace - else: - rfc_dict["type"] = "txt2img" - rfc_dict.pop("strength") - - if len(seeds) == 0 and opt.seed: - seeds = [opt.seed] - - if opt.grid: - images = [] - for seed in seeds: - rfc_dict["seed"] = seed - images.append(copy.copy(rfc_dict)) - metadata["images"] = images - else: - # there should only ever be a single seed if we did not generate a grid - assert len(seeds) == 1, "Expected a single seed" - rfc_dict["seed"] = seeds[0] - metadata["image"] = rfc_dict - - return metadata - - -@functools.lru_cache(maxsize=50) -def args_from_png(png_file_path) -> list[Args]: - """ - Given the path to a PNG file created by invoke.py, - retrieves a list of Args objects containing the image - data. - """ - try: - meta = retrieve_metadata(png_file_path) - except AttributeError: - return [legacy_metadata_load({}, png_file_path)] - - try: - return metadata_loads(meta) - except: - return [legacy_metadata_load(meta, png_file_path)] - - -@functools.lru_cache(maxsize=50) -def metadata_from_png(png_file_path) -> Args: - """ - Given the path to a PNG file created by dream.py, retrieves - an Args object containing the image metadata. Note that this - returns a single Args object, not multiple. - """ - args_list = args_from_png(png_file_path) - return args_list[0] if len(args_list) > 0 else Args() # empty args - - -def dream_cmd_from_png(png_file_path): - opt = metadata_from_png(png_file_path) - return opt.dream_prompt_str() - - -def metadata_loads(metadata) -> list: - """ - Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266) - and returns a series of opt objects for each of the images described in the dictionary. Note that this - returns a list, and not a single object. See metadata_from_png() for a more convenient function for - files that contain a single image. 
- """ - results = [] - try: - if "images" in metadata["sd-metadata"]: - images = metadata["sd-metadata"]["images"] - else: - images = [metadata["sd-metadata"]["image"]] - for image in images: - # repack the prompt and variations - if "prompt" in image: - image["prompt"] = repack_prompt(image["prompt"]) - if "variations" in image: - image["variations"] = ",".join( - [ - ":".join([str(x["seed"]), str(x["weight"])]) - for x in image["variations"] - ] - ) - # fix a bit of semantic drift here - image["sampler_name"] = image.pop("sampler") - opt = Args() - opt._cmd_switches = Namespace(**image) - results.append(opt) - except Exception: - import sys - import traceback - - logger.error("Could not read metadata") - print(traceback.format_exc(), file=sys.stderr) - return results - - -def repack_prompt(prompt_list: list) -> str: - # in the common case of no weighting syntax, just return the prompt as is - if len(prompt_list) > 1: - return ",".join( - [":".join([x["prompt"], str(x["weight"])]) for x in prompt_list] - ) - else: - return prompt_list[0]["prompt"] - - -# image can either be a file path on disk or a base64-encoded -# representation of the file's contents -def calculate_init_img_hash(image_string): - prefix = "data:image/png;base64," - hash = None - if image_string.startswith(prefix): - imagebase64 = image_string[len(prefix) :] - imagedata = base64.b64decode(imagebase64) - with open("outputs/test.png", "wb") as file: - file.write(imagedata) - sha = hashlib.sha256() - sha.update(imagedata) - hash = sha.hexdigest() - else: - hash = sha256(image_string) - return hash - - -# Bah. This should be moved somewhere else... -def sha256(path): - sha = hashlib.sha256() - with open(path, "rb") as f: - while True: - data = f.read(65536) - if not data: - break - sha.update(data) - return sha.hexdigest() - - -def legacy_metadata_load(meta, pathname) -> Args: - opt = Args() - if "Dream" in meta and len(meta["Dream"]) > 0: - dream_prompt = meta["Dream"] - opt.parse_cmd(dream_prompt) - else: # if nothing else, we can get the seed - match = re.search("\d+\.(\d+)", pathname) - if match: - seed = match.groups()[0] - opt.seed = seed - else: - opt.prompt = "" - opt.seed = 0 - return opt diff --git a/invokeai/backend/config/invokeai_configure.py b/invokeai/backend/config/invokeai_configure.py index f95c65cc6c..59f11d35bc 100755 --- a/invokeai/backend/config/invokeai_configure.py +++ b/invokeai/backend/config/invokeai_configure.py @@ -19,10 +19,10 @@ import warnings from argparse import Namespace from pathlib import Path from shutil import get_terminal_size +from typing import get_type_hints from urllib import request import npyscreen -import torch import transformers from diffusers import AutoencoderKL from huggingface_hub import HfFolder @@ -38,34 +38,40 @@ from transformers import ( import invokeai.configs as configs -from ...frontend.install.model_install import addModelsForm, process_and_execute -from ...frontend.install.widgets import ( +from invokeai.frontend.install.model_install import addModelsForm, process_and_execute +from invokeai.frontend.install.widgets import ( CenteredButtonPress, IntTitleSlider, set_min_terminal_size, ) -from ..args import PRECISION_CHOICES, Args -from ..globals import Globals, global_cache_dir, global_config_dir, global_config_file -from .model_install_backend import ( +from invokeai.backend.config.legacy_arg_parsing import legacy_parser +from invokeai.backend.config.model_install_backend import ( default_dataset, download_from_hf, hf_download_with_resume, recommended_datasets, ) 
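The hunks that follow replace the old Globals singleton with the new InvokeAIAppConfig service. As a reading aid, here is a minimal sketch (not part of the patch) of how the substituted calls line up; get_invokeai_config() and the attribute names are taken from the hunks below, the surrounding script is illustrative only.

# Illustrative sketch only: InvokeAIAppConfig attributes on the left, the
# Globals-based lookups they replace in this patch on the right.
from invokeai.app.services.config import get_invokeai_config

config = get_invokeai_config()            # module-level singleton, as in this file

models_yaml = config.model_conf_path      # was Path(global_config_dir()) / "models.yaml"
sd_configs  = config.legacy_conf_path     # was Path(global_config_dir()) / "stable-diffusion"
hub_cache   = config.cache_dir            # was global_cache_dir("hub")
codeformer  = config.root / "models/codeformer/codeformer.pth"   # was os.path.join(Globals.root, ...)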
+from invokeai.app.services.config import ( + get_invokeai_config, + InvokeAIAppConfig, +) warnings.filterwarnings("ignore") transformers.logging.set_verbosity_error() + # --------------------------globals----------------------- +config = get_invokeai_config() + Model_dir = "models" Weights_dir = "ldm/stable-diffusion-v1/" # the initial "configs" dir is now bundled in the `invokeai.configs` package Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml" -Default_config_file = Path(global_config_dir()) / "models.yaml" -SD_Configs = Path(global_config_dir()) / "stable-diffusion" +Default_config_file = config.model_conf_path +SD_Configs = config.legacy_conf_path Datasets = OmegaConf.load(Dataset_path) @@ -73,17 +79,12 @@ Datasets = OmegaConf.load(Dataset_path) MIN_COLS = 135 MIN_LINES = 45 +PRECISION_CHOICES = ['auto','float16','float32','autocast'] + INIT_FILE_PREAMBLE = """# InvokeAI initialization file # This is the InvokeAI initialization file, which contains command-line default values. # Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting # or renaming it and then running invokeai-configure again. -# Place frequently-used startup commands here, one or more per line. -# Examples: -# --outdir=D:\data\images -# --no-nsfw_checker -# --web --host=0.0.0.0 -# --steps=20 -# -Ak_euler_a -C10.0 """ @@ -96,14 +97,13 @@ If you installed manually from source or with 'pip install': activate the virtua then run one of the following commands to start InvokeAI. Web UI: - invokeai --web # (connect to http://localhost:9090) - invokeai --web --host 0.0.0.0 # (connect to http://your-lan-ip:9090 from another computer on the local network) + invokeai-web -Command-line interface: +Command-line client: invokeai If you installed using an installation script, run: - {Globals.root}/invoke.{"bat" if sys.platform == "win32" else "sh"} + {config.root}/invoke.{"bat" if sys.platform == "win32" else "sh"} Add the '--help' argument to see all of the command-line switches available for use. 
""" @@ -216,11 +216,11 @@ def download_realesrgan(): wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth" model_dest = os.path.join( - Globals.root, "models/realesrgan/realesr-general-x4v3.pth" + config.root, "models/realesrgan/realesr-general-x4v3.pth" ) wdn_model_dest = os.path.join( - Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth" + config.root, "models/realesrgan/realesr-general-wdn-x4v3.pth" ) download_with_progress_bar(model_url, model_dest, "RealESRGAN") @@ -243,7 +243,7 @@ def download_gfpgan(): "./models/gfpgan/weights/parsing_parsenet.pth", ], ): - model_url, model_dest = model[0], os.path.join(Globals.root, model[1]) + model_url, model_dest = model[0], os.path.join(config.root, model[1]) download_with_progress_bar(model_url, model_dest, "GFPGAN weights") @@ -253,7 +253,7 @@ def download_codeformer(): model_url = ( "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth" ) - model_dest = os.path.join(Globals.root, "models/codeformer/codeformer.pth") + model_dest = os.path.join(config.root, "models/codeformer/codeformer.pth") download_with_progress_bar(model_url, model_dest, "CodeFormer") @@ -295,7 +295,7 @@ def download_vaes(): # first the diffusers version repo_id = "stabilityai/sd-vae-ft-mse" args = dict( - cache_dir=global_cache_dir("hub"), + cache_dir=config.cache_dir, ) if not AutoencoderKL.from_pretrained(repo_id, **args): raise Exception(f"download of {repo_id} failed") @@ -306,7 +306,7 @@ def download_vaes(): if not hf_download_with_resume( repo_id=repo_id, model_name=model_name, - model_dir=str(Globals.root / Model_dir / Weights_dir), + model_dir=str(config.root / Model_dir / Weights_dir), ): raise Exception(f"download of {model_name} failed") except Exception as e: @@ -321,8 +321,7 @@ def get_root(root: str = None) -> str: elif os.environ.get("INVOKEAI_ROOT"): return os.environ.get("INVOKEAI_ROOT") else: - return Globals.root - + return config.root # ------------------------------------- class editOptsForm(npyscreen.FormMultiPage): @@ -332,7 +331,7 @@ class editOptsForm(npyscreen.FormMultiPage): def create(self): program_opts = self.parentApp.program_opts old_opts = self.parentApp.invokeai_opts - first_time = not (Globals.root / Globals.initfile).exists() + first_time = not (config.root / 'invokeai.yaml').exists() access_token = HfFolder.get_token() window_width, window_height = get_terminal_size() for i in [ @@ -366,7 +365,7 @@ class editOptsForm(npyscreen.FormMultiPage): self.outdir = self.add_widget_intelligent( npyscreen.TitleFilename, name="( autocompletes, ctrl-N advances):", - value=old_opts.outdir or str(default_output_dir()), + value=str(old_opts.outdir) or str(default_output_dir()), select_dir=True, must_exist=False, use_two_lines=False, @@ -381,17 +380,17 @@ class editOptsForm(npyscreen.FormMultiPage): editable=False, color="CONTROL", ) - self.safety_checker = self.add_widget_intelligent( + self.nsfw_checker = self.add_widget_intelligent( npyscreen.Checkbox, name="NSFW checker", - value=old_opts.safety_checker, + value=old_opts.nsfw_checker, relx=5, scroll_exit=True, ) self.nextrely += 1 for i in [ - "If you have an account at HuggingFace you may paste your access token here", - 'to allow InvokeAI to download styles & subjects from the "Concept Library".', + "If you have an account at HuggingFace you may optionally paste your access token here", + 'to allow InvokeAI to download restricted styles & subjects from the "Concept Library".', "See 
https://huggingface.co/settings/tokens", ]: self.add_widget_intelligent( @@ -435,17 +434,10 @@ class editOptsForm(npyscreen.FormMultiPage): relx=5, scroll_exit=True, ) - self.xformers = self.add_widget_intelligent( + self.xformers_enabled = self.add_widget_intelligent( npyscreen.Checkbox, name="Enable xformers support if available", - value=old_opts.xformers, - relx=5, - scroll_exit=True, - ) - self.ckpt_convert = self.add_widget_intelligent( - npyscreen.Checkbox, - name="Load legacy checkpoint models into memory as diffusers models", - value=old_opts.ckpt_convert, + value=old_opts.xformers_enabled, relx=5, scroll_exit=True, ) @@ -480,19 +472,30 @@ class editOptsForm(npyscreen.FormMultiPage): self.nextrely += 1 self.add_widget_intelligent( npyscreen.FixedText, - value="Directory containing embedding/textual inversion files:", + value="Directories containing textual inversion and LoRA models ( autocompletes, ctrl-N advances):", editable=False, color="CONTROL", ) - self.embedding_path = self.add_widget_intelligent( + self.embedding_dir = self.add_widget_intelligent( npyscreen.TitleFilename, - name="( autocompletes, ctrl-N advances):", + name=" Textual Inversion Embeddings:", value=str(default_embedding_dir()), select_dir=True, must_exist=False, use_two_lines=False, labelColor="GOOD", - begin_entry_at=40, + begin_entry_at=32, + scroll_exit=True, + ) + self.lora_dir = self.add_widget_intelligent( + npyscreen.TitleFilename, + name=" LoRA and LyCORIS:", + value=str(default_lora_dir()), + select_dir=True, + must_exist=False, + use_two_lines=False, + labelColor="GOOD", + begin_entry_at=32, scroll_exit=True, ) self.nextrely += 1 @@ -559,9 +562,9 @@ class editOptsForm(npyscreen.FormMultiPage): bad_fields.append( f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory." ) - if not Path(opt.embedding_path).parent.exists(): + if not Path(opt.embedding_dir).parent.exists(): bad_fields.append( - f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_path).parent)} is an existing directory." + f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_dir).parent)} is an existing directory." 
) if len(bad_fields) > 0: message = "The following problems were detected and must be corrected:\n" @@ -576,20 +579,23 @@ class editOptsForm(npyscreen.FormMultiPage): new_opts = Namespace() for attr in [ - "outdir", - "safety_checker", - "free_gpu_mem", - "max_loaded_models", - "xformers", - "always_use_cpu", - "embedding_path", - "ckpt_convert", + "outdir", + "nsfw_checker", + "free_gpu_mem", + "max_loaded_models", + "xformers_enabled", + "always_use_cpu", + "embedding_dir", + "lora_dir", ]: setattr(new_opts, attr, getattr(self, attr).value) new_opts.hf_token = self.hf_token.value new_opts.license_acceptance = self.license_acceptance.value new_opts.precision = PRECISION_CHOICES[self.precision.value[0]] + + # widget library workaround to make max_loaded_models an int rather than a float + new_opts.max_loaded_models = int(new_opts.max_loaded_models) return new_opts @@ -628,15 +634,14 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam def default_startup_options(init_file: Path) -> Namespace: - opts = Args().parse_args([]) + opts = InvokeAIAppConfig(argv=[]) outdir = Path(opts.outdir) if not outdir.is_absolute(): - opts.outdir = str(Globals.root / opts.outdir) + opts.outdir = str(config.root / opts.outdir) if not init_file.exists(): - opts.safety_checker = True + opts.nsfw_checker = True return opts - def default_user_selections(program_opts: Namespace) -> Namespace: return Namespace( starter_models=default_dataset() @@ -690,70 +695,61 @@ def run_console_ui( # ------------------------------------- def write_opts(opts: Namespace, init_file: Path): """ - Update the invokeai.init file with values from opts Namespace + Update the invokeai.yaml file with values from current settings. """ - # touch file if it doesn't exist - if not init_file.exists(): - with open(init_file, "w") as f: - f.write(INIT_FILE_PREAMBLE) - # We want to write in the changed arguments without clobbering - # any other initialization values the user has entered. There is - # no good way to do this because of the one-way nature of - # argparse: i.e. 
--outdir could be --outdir, --out, or -o - # initfile needs to be replaced with a fully structured format - # such as yaml; this is a hack that will work much of the time - args_to_skip = re.compile( - "^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|ckpt|free_gpu)" - ) - # fix windows paths - opts.outdir = opts.outdir.replace("\\", "/") - opts.embedding_path = opts.embedding_path.replace("\\", "/") - new_file = f"{init_file}.new" - try: - lines = [x.strip() for x in open(init_file, "r").readlines()] - with open(new_file, "w") as out_file: - for line in lines: - if len(line) > 0 and not args_to_skip.match(line): - out_file.write(line + "\n") - out_file.write( - f""" ---outdir={opts.outdir} ---embedding_path={opts.embedding_path} ---precision={opts.precision} ---max_loaded_models={int(opts.max_loaded_models)} ---{'no-' if not opts.safety_checker else ''}nsfw_checker ---{'no-' if not opts.xformers else ''}xformers ---{'no-' if not opts.ckpt_convert else ''}ckpt_convert -{'--free_gpu_mem' if opts.free_gpu_mem else ''} -{'--always_use_cpu' if opts.always_use_cpu else ''} -""" - ) - except OSError as e: - print(f"** An error occurred while writing the init file: {str(e)}") - - os.replace(new_file, init_file) - - if opts.hf_token: - HfLogin(opts.hf_token) + # this will load current settings + config = InvokeAIAppConfig() + for key,value in opts.__dict__.items(): + if hasattr(config,key): + setattr(config,key,value) + with open(init_file,'w', encoding='utf-8') as file: + file.write(config.to_yaml()) # ------------------------------------- def default_output_dir() -> Path: - return Globals.root / "outputs" - + return config.root / "outputs" # ------------------------------------- def default_embedding_dir() -> Path: - return Globals.root / "embeddings" + return config.root / "embeddings" +# ------------------------------------- +def default_lora_dir() -> Path: + return config.root / "loras" # ------------------------------------- def write_default_options(program_opts: Namespace, initfile: Path): opt = default_startup_options(initfile) - opt.hf_token = HfFolder.get_token() write_opts(opt, initfile) +# ------------------------------------- +# Here we bring in +# the legacy Args object in order to parse +# the old init file and write out the new +# yaml format. +def migrate_init_file(legacy_format:Path): + old = legacy_parser.parse_args([f'@{str(legacy_format)}']) + new = InvokeAIAppConfig(conf={}) + + fields = list(get_type_hints(InvokeAIAppConfig).keys()) + for attr in fields: + if hasattr(old,attr): + setattr(new,attr,getattr(old,attr)) + + # a few places where the field names have changed and we have to + # manually add in the new names/values + new.nsfw_checker = old.safety_checker + new.xformers_enabled = old.xformers + new.conf_path = old.conf + new.embedding_dir = old.embedding_path + + invokeai_yaml = legacy_format.parent / 'invokeai.yaml' + with open(invokeai_yaml,"w", encoding="utf-8") as outfile: + outfile.write(new.to_yaml()) + + legacy_format.replace(legacy_format.parent / 'invokeai.init.old') # ------------------------------------- def main(): @@ -810,7 +806,8 @@ def main(): opt = parser.parse_args() # setting a global here - Globals.root = Path(os.path.expanduser(get_root(opt.root) or "")) + global config + config.root = Path(os.path.expanduser(get_root(opt.root) or "")) errors = set() @@ -818,19 +815,26 @@ def main(): models_to_download = default_user_selections(opt) # We check for to see if the runtime directory is correctly initialized. 
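Because migrate_init_file() above relies on argparse's @-file support, a minimal sketch of that step may help; the sample invokeai.init contents and its path are hypothetical, while legacy_parser, fromfile_prefix_chars='@' and InvokeAIAppConfig.to_yaml() come from this patch.

# Hypothetical legacy init file at /home/me/invokeai/invokeai.init, one switch per line:
#   --outdir=/home/me/invokeai/outputs
#   --no-nsfw_checker
#   --precision=float16
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.config.legacy_arg_parsing import legacy_parser

# fromfile_prefix_chars='@' plus the shlex-based convert_arg_line_to_args() let the
# whole file be parsed as if its lines had been typed on the command line:
old = legacy_parser.parse_args(["@/home/me/invokeai/invokeai.init"])

new = InvokeAIAppConfig(conf={})          # as in migrate_init_file()
new.outdir = old.outdir
new.precision = old.precision
new.nsfw_checker = old.safety_checker     # renamed field, copied over explicitly in the patch
print(new.to_yaml())                      # this text is what gets written to invokeai.yaml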
-    init_file = Path(Globals.root, Globals.initfile)
-    if not init_file.exists() or not global_config_file().exists():
-        initialize_rootdir(Globals.root, opt.yes_to_all)
+    old_init_file = Path(config.root, 'invokeai.init')
+    new_init_file = Path(config.root, 'invokeai.yaml')
+    if old_init_file.exists() and not new_init_file.exists():
+        print('** Migrating invokeai.init to invokeai.yaml')
+        migrate_init_file(old_init_file)
+        config = get_invokeai_config() # reread defaults
+
+
+    if not config.model_conf_path.exists():
+        initialize_rootdir(config.root, opt.yes_to_all)
     if opt.yes_to_all:
-        write_default_options(opt, init_file)
+        write_default_options(opt, new_init_file)
         init_options = Namespace(
             precision="float32" if opt.full_precision else "float16"
         )
     else:
-        init_options, models_to_download = run_console_ui(opt, init_file)
+        init_options, models_to_download = run_console_ui(opt, new_init_file)
         if init_options:
-            write_opts(init_options, init_file)
+            write_opts(init_options, new_init_file)
         else:
             print(
                 '\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
diff --git a/invokeai/backend/config/legacy_arg_parsing.py b/invokeai/backend/config/legacy_arg_parsing.py
new file mode 100644
index 0000000000..85ca588fe2
--- /dev/null
+++ b/invokeai/backend/config/legacy_arg_parsing.py
@@ -0,0 +1,390 @@
+# Copyright 2023 Lincoln D. Stein and the InvokeAI Team
+
+import argparse
+import shlex
+from argparse import ArgumentParser
+
+SAMPLER_CHOICES = [
+    "ddim",
+    "ddpm",
+    "deis",
+    "lms",
+    "pndm",
+    "heun",
+    "heun_k",
+    "euler",
+    "euler_k",
+    "euler_a",
+    "kdpm_2",
+    "kdpm_2_a",
+    "dpmpp_2s",
+    "dpmpp_2m",
+    "dpmpp_2m_k",
+    "unipc",
+]
+
+PRECISION_CHOICES = [
+    "auto",
+    "float32",
+    "autocast",
+    "float16",
+]
+
+class FileArgumentParser(ArgumentParser):
+    """
+    Supports reading defaults from an init file.
+    """
+    def convert_arg_line_to_args(self, arg_line):
+        return shlex.split(arg_line, comments=True)
+
+
+legacy_parser = FileArgumentParser(
+    description=
+    """
+Generate images using Stable Diffusion.
+    Use --web to launch the web interface.
+    Use --from_file to load prompts from a file path or standard input ("-").
+    Otherwise you will be dropped into an interactive command prompt (type -h for help.)
+    Other command-line arguments are defaults that can usually be overridden
+    at the command prompt.
+    """,
+    fromfile_prefix_chars='@',
+)
+general_group = legacy_parser.add_argument_group('General')
+model_group = legacy_parser.add_argument_group('Model selection')
+file_group = legacy_parser.add_argument_group('Input/output')
+web_server_group = legacy_parser.add_argument_group('Web server')
+render_group = legacy_parser.add_argument_group('Rendering')
+postprocessing_group = legacy_parser.add_argument_group('Postprocessing')
+deprecated_group = legacy_parser.add_argument_group('Deprecated options')
+
+deprecated_group.add_argument('--laion400m')
+deprecated_group.add_argument('--weights') # deprecated
+general_group.add_argument(
+    '--version','-V',
+    action='store_true',
+    help='Print InvokeAI version number'
+)
+model_group.add_argument(
+    '--root_dir',
+    default=None,
+    help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
+)
+model_group.add_argument(
+    '--config',
+    '-c',
+    '-config',
+    dest='conf',
+    default='./configs/models.yaml',
+    help='Path to configuration file for alternate models.',
+)
+model_group.add_argument(
+    '--model',
+    help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
+)
+model_group.add_argument(
+    '--weight_dirs',
+    nargs='+',
+    type=str,
+    help='List of one or more directories that will be auto-scanned for new model weights to import',
+)
+model_group.add_argument(
+    '--png_compression','-z',
+    type=int,
+    default=6,
+    choices=range(0,10),
+    dest='png_compression',
+    help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
+)
+model_group.add_argument(
+    '-F',
+    '--full_precision',
+    dest='full_precision',
+    action='store_true',
+    help='Deprecated way to set --precision=float32',
+)
+model_group.add_argument(
+    '--max_loaded_models',
+    dest='max_loaded_models',
+    type=int,
+    default=2,
+    help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
+)
+model_group.add_argument(
+    '--free_gpu_mem',
+    dest='free_gpu_mem',
+    action='store_true',
+    help='Force free gpu memory before final decoding',
+)
+model_group.add_argument(
+    '--sequential_guidance',
+    dest='sequential_guidance',
+    action='store_true',
+    help="Calculate guidance in serial instead of in parallel, lowering memory requirement "
+    "at the expense of speed",
+)
+model_group.add_argument(
+    '--xformers',
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help='Enable/disable xformers support (default enabled if installed)',
+)
+model_group.add_argument(
+    "--always_use_cpu",
+    dest="always_use_cpu",
+    action="store_true",
+    help="Force use of CPU even if GPU is available"
+)
+model_group.add_argument(
+    '--precision',
+    dest='precision',
+    type=str,
+    choices=PRECISION_CHOICES,
+    metavar='PRECISION',
+    help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
+    default='auto',
+)
+model_group.add_argument(
+    '--ckpt_convert',
+    action=argparse.BooleanOptionalAction,
+    dest='ckpt_convert',
+    default=True,
+    help='Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.'
+)
+model_group.add_argument(
+    '--internet',
+    action=argparse.BooleanOptionalAction,
+    dest='internet_available',
+    default=True,
+    help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).',
+)
+model_group.add_argument(
+    '--nsfw_checker',
+    '--safety_checker',
+    action=argparse.BooleanOptionalAction,
+    dest='safety_checker',
+    default=False,
+    help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
+)
+model_group.add_argument(
+    '--autoimport',
+    default=None,
+    type=str,
+    help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
+)
+model_group.add_argument(
+    '--autoconvert',
+    default=None,
+    type=str,
+    help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
+)
+model_group.add_argument(
+    '--patchmatch',
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help='Load the patchmatch extension for outpainting. 
Use --no-patchmatch to disable.', +) +file_group.add_argument( + '--from_file', + dest='infile', + type=str, + help='If specified, load prompts from this file', +) +file_group.add_argument( + '--outdir', + '-o', + type=str, + help='Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs', + default='outputs', +) +file_group.add_argument( + '--prompt_as_dir', + '-p', + action='store_true', + help='Place images in subdirectories named after the prompt.', +) +render_group.add_argument( + '--fnformat', + default='{prefix}.{seed}.png', + type=str, + help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png', +) +render_group.add_argument( + '-s', + '--steps', + type=int, + default=50, + help='Number of steps' +) +render_group.add_argument( + '-W', + '--width', + type=int, + help='Image width, multiple of 64', +) +render_group.add_argument( + '-H', + '--height', + type=int, + help='Image height, multiple of 64', +) +render_group.add_argument( + '-C', + '--cfg_scale', + default=7.5, + type=float, + help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.', +) +render_group.add_argument( + '--sampler', + '-A', + '-m', + dest='sampler_name', + type=str, + choices=SAMPLER_CHOICES, + metavar='SAMPLER_NAME', + help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}', + default='k_lms', +) +render_group.add_argument( + '--log_tokenization', + '-t', + action='store_true', + help='shows how the prompt is split into tokens' +) +render_group.add_argument( + '-f', + '--strength', + type=float, + help='img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely', +) +render_group.add_argument( + '-T', + '-fit', + '--fit', + action=argparse.BooleanOptionalAction, + help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)', +) + +render_group.add_argument( + '--grid', + '-g', + action=argparse.BooleanOptionalAction, + help='generate a grid' +) +render_group.add_argument( + '--embedding_directory', + '--embedding_path', + dest='embedding_path', + default='embeddings', + type=str, + help='Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)' +) +render_group.add_argument( + '--lora_directory', + dest='lora_path', + default='loras', + type=str, + help='Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)' +) +render_group.add_argument( + '--embeddings', + action=argparse.BooleanOptionalAction, + default=True, + help='Enable embedding directory (default). Use --no-embeddings to disable.', +) +render_group.add_argument( + '--enable_image_debugging', + action='store_true', + help='Generates debugging image to display' +) +render_group.add_argument( + '--karras_max', + type=int, + default=None, + help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]." 
+) +# Restoration related args +postprocessing_group.add_argument( + '--no_restore', + dest='restore', + action='store_false', + help='Disable face restoration with GFPGAN or codeformer', +) +postprocessing_group.add_argument( + '--no_upscale', + dest='esrgan', + action='store_false', + help='Disable upscaling with ESRGAN', +) +postprocessing_group.add_argument( + '--esrgan_bg_tile', + type=int, + default=400, + help='Tile size for background sampler, 0 for no tile during testing. Default: 400.', +) +postprocessing_group.add_argument( + '--esrgan_denoise_str', + type=float, + default=0.75, + help='esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75', +) +postprocessing_group.add_argument( + '--gfpgan_model_path', + type=str, + default='./models/gfpgan/GFPGANv1.4.pth', + help='Indicates the path to the GFPGAN model', +) +web_server_group.add_argument( + '--web', + dest='web', + action='store_true', + help='Start in web server mode.', +) +web_server_group.add_argument( + '--web_develop', + dest='web_develop', + action='store_true', + help='Start in web server development mode.', +) +web_server_group.add_argument( + "--web_verbose", + action="store_true", + help="Enables verbose logging", +) +web_server_group.add_argument( + "--cors", + nargs="*", + type=str, + help="Additional allowed origins, comma-separated", +) +web_server_group.add_argument( + '--host', + type=str, + default='127.0.0.1', + help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.' +) +web_server_group.add_argument( + '--port', + type=int, + default='9090', + help='Web server: Port to listen on' +) +web_server_group.add_argument( + '--certfile', + type=str, + default=None, + help='Web server: Path to certificate file to use for SSL. Use together with --keyfile' +) +web_server_group.add_argument( + '--keyfile', + type=str, + default=None, + help='Web server: Path to private key file to use for SSL. 
Use together with --certfile' +) +web_server_group.add_argument( + '--gui', + dest='gui', + action='store_true', + help='Start InvokeAI GUI', +) diff --git a/invokeai/backend/config/model_install_backend.py b/invokeai/backend/config/model_install_backend.py index 2018cd42af..cb76f955bc 100644 --- a/invokeai/backend/config/model_install_backend.py +++ b/invokeai/backend/config/model_install_backend.py @@ -19,13 +19,15 @@ from tqdm import tqdm import invokeai.configs as configs -from ..globals import Globals, global_cache_dir, global_config_dir +from invokeai.app.services.config import get_invokeai_config from ..model_management import ModelManager from ..stable_diffusion import StableDiffusionGeneratorPipeline + warnings.filterwarnings("ignore") # --------------------------globals----------------------- +config = get_invokeai_config() Model_dir = "models" Weights_dir = "ldm/stable-diffusion-v1/" @@ -47,12 +49,11 @@ Config_preamble = """ def default_config_file(): - return Path(global_config_dir()) / "models.yaml" + return config.model_conf_path def sd_configs(): - return Path(global_config_dir()) / "stable-diffusion" - + return config.legacy_conf_path def initial_models(): global Datasets @@ -121,8 +122,9 @@ def install_requested_models( if scan_at_startup and scan_directory.is_dir(): argument = "--autoconvert" - initfile = Path(Globals.root, Globals.initfile) - replacement = Path(Globals.root, f"{Globals.initfile}.new") + print('** The global initfile is no longer supported; rewrite to support new yaml format **') + initfile = Path(config.root, 'invokeai.init') + replacement = Path(config.root, f"invokeai.init.new") directory = str(scan_directory).replace("\\", "/") with open(initfile, "r") as input: with open(replacement, "w") as output: @@ -150,7 +152,7 @@ def get_root(root: str = None) -> str: elif os.environ.get("INVOKEAI_ROOT"): return os.environ.get("INVOKEAI_ROOT") else: - return Globals.root + return config.root # --------------------------------------------- @@ -183,7 +185,7 @@ def all_datasets() -> dict: # look for legacy model.ckpt in models directory and offer to # normalize its name def migrate_models_ckpt(): - model_path = os.path.join(Globals.root, Model_dir, Weights_dir) + model_path = os.path.join(config.root, Model_dir, Weights_dir) if not os.path.exists(os.path.join(model_path, "model.ckpt")): return new_name = initial_models()["stable-diffusion-1.4"]["file"] @@ -228,7 +230,7 @@ def _download_repo_or_file( def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path: repo_id = mconfig["repo_id"] filename = mconfig["file"] - cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir) + cache_dir = os.path.join(config.root, Model_dir, Weights_dir) return hf_download_with_resume( repo_id=repo_id, model_dir=cache_dir, @@ -239,9 +241,9 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path: # --------------------------------------------- def download_from_hf( - model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs + model_class: object, model_name: str, **kwargs ): - path = global_cache_dir(cache_subdir) + path = config.cache_dir model = model_class.from_pretrained( model_name, cache_dir=path, @@ -417,7 +419,7 @@ def new_config_file_contents( stanza["height"] = mod["height"] if "file" in mod: stanza["weights"] = os.path.relpath( - successfully_downloaded[model], start=Globals.root + successfully_downloaded[model], start=config.root ) stanza["config"] = os.path.normpath( os.path.join(sd_configs(), 
mod["config"]) @@ -456,7 +458,7 @@ def delete_weights(model_name: str, conf_stanza: dict): weights = Path(weights) if not weights.is_absolute(): - weights = Path(Globals.root) / weights + weights = Path(config.root) / weights try: weights.unlink() except OSError as e: diff --git a/invokeai/backend/generate.py b/invokeai/backend/generate.py deleted file mode 100644 index 4682d1a1ca..0000000000 --- a/invokeai/backend/generate.py +++ /dev/null @@ -1,1237 +0,0 @@ -# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein) -# Derived from source code carrying the following copyrights -# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich -# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors - -import gc -import importlib -import logging -import os -import random -import re -import sys -import time -import traceback -from typing import List - -import cv2 -import diffusers -import numpy as np -import skimage -import torch -import transformers -from PIL import Image, ImageOps -from accelerate.utils import set_seed -from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.utils.import_utils import is_xformers_available -from omegaconf import OmegaConf -from pathlib import Path - -import invokeai.backend.util.logging as logger -from .args import metadata_from_png -from .generator import infill_methods -from .globals import Globals, global_cache_dir -from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding -from .model_management import ModelManager -from .safety_checker import SafetyChecker -from .prompting import get_uc_and_c_and_ec -from .prompting.conditioning import log_tokenization -from .stable_diffusion import HuggingFaceConceptsLibrary -from .stable_diffusion.schedulers import SCHEDULER_MAP -from .util import choose_precision, choose_torch_device, torch_dtype - -def fix_func(orig): - if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - - def new_func(*args, **kw): - device = kw.get("device", "mps") - kw["device"] = "cpu" - return orig(*args, **kw).to(device) - - return new_func - return orig - -torch.rand = fix_func(torch.rand) -torch.rand_like = fix_func(torch.rand_like) -torch.randn = fix_func(torch.randn) -torch.randn_like = fix_func(torch.randn_like) -torch.randint = fix_func(torch.randint) -torch.randint_like = fix_func(torch.randint_like) -torch.bernoulli = fix_func(torch.bernoulli) -torch.multinomial = fix_func(torch.multinomial) - -# this is fallback model in case no default is defined -FALLBACK_MODEL_NAME = "stable-diffusion-1.5" - -"""Simplified text to image API for stable diffusion/latent diffusion - -Example Usage: - -from ldm.generate import Generate - -# Create an object with default values -gr = Generate('stable-diffusion-1.4') - -# do the slow model initialization -gr.load_model() - -# Do the fast inference & image generation. Any options passed here -# override the default values assigned during class initialization -# Will call load_model() if the model was not previously loaded and so -# may be slow at first. -# The method returns a list of images. Each row of the list is a sub-list of [filename,seed] -results = gr.prompt2png(prompt = "an astronaut riding a horse", - outdir = "./outputs/samples", - iterations = 3) - -for row in results: - print(f'filename={row[0]}') - print(f'seed ={row[1]}') - -# Same thing, but using an initial image. 
-results = gr.prompt2png(prompt = "an astronaut riding a horse", - outdir = "./outputs/, - iterations = 3, - init_img = "./sketches/horse+rider.png") - -for row in results: - print(f'filename={row[0]}') - print(f'seed ={row[1]}') - -# Same thing, but we return a series of Image objects, which lets you manipulate them, -# combine them, and save them under arbitrary names - -results = gr.prompt2image(prompt = "an astronaut riding a horse" - outdir = "./outputs/") -for row in results: - im = row[0] - seed = row[1] - im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png') - im.thumbnail(100,100).save('./outputs/samples/astronaut_thumb.jpg') - -Note that the old txt2img() and img2img() calls are deprecated but will -still work. - -The full list of arguments to Generate() are: -gr = Generate( - # these values are set once and shouldn't be changed - conf:str = path to configuration file ('configs/models.yaml') - model:str = symbolic name of the model in the configuration file - precision:float = float precision to be used - safety_checker:bool = activate safety checker [False] - - # this value is sticky and maintained between generation calls - sampler_name:str = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_dpmpp_2', 'k_dpmpp_2_a', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms - - # these are deprecated - use conf and model instead - weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt') - config = path to model configuration ('configs/stable-diffusion/v1-inference.yaml') - ) - -""" - - -class Generate: - """Generate class - Stores default values for multiple configuration items - """ - - def __init__( - self, - model=None, - conf="configs/models.yaml", - embedding_path=None, - sampler_name="lms", - ddim_eta=0.0, # deterministic - full_precision=False, - precision="auto", - outdir="outputs/img-samples", - gfpgan=None, - codeformer=None, - esrgan=None, - free_gpu_mem: bool = False, - safety_checker: bool = False, - max_cache_size: int = 6, - # these are deprecated; if present they override values in the conf file - weights=None, - config=None, - ): - self.height = None - self.width = None - self.model_manager = None - self.iterations = 1 - self.steps = 50 - self.cfg_scale = 7.5 - self.sampler_name = sampler_name - self.ddim_eta = ddim_eta # same seed always produces same image - self.precision = precision - self.strength = 0.75 - self.seamless = False - self.seamless_axes = {"x", "y"} - self.hires_fix = False - self.embedding_path = embedding_path - self.model_context = None # empty for now - self.model_hash = None - self.sampler = None - self.device = None - self.max_memory_allocated = 0 - self.memory_allocated = 0 - self.session_peakmem = 0 - self.base_generator = None - self.seed = None - self.outdir = outdir - self.gfpgan = gfpgan - self.codeformer = codeformer - self.esrgan = esrgan - self.free_gpu_mem = free_gpu_mem - self.max_cache_size = max_cache_size - self.size_matters = True # used to warn once about large image sizes and VRAM - self.txt2mask = None - self.safety_checker = None - self.karras_max = None - self.infill_method = None - - # Note that in previous versions, there was an option to pass the - # device to Generate(). However the device was then ignored, so - # it wasn't actually doing anything. This logic could be reinstated. 
- self.device = torch.device(choose_torch_device()) - logger.info(f"Using device_type {self.device.type}") - if full_precision: - if self.precision != "auto": - raise ValueError("Remove --full_precision / -F if using --precision") - logger.warning("Please remove deprecated --full_precision / -F") - logger.warning("If auto config does not work you can use --precision=float32") - self.precision = "float32" - if self.precision == "auto": - self.precision = choose_precision(self.device) - Globals.full_precision = self.precision == "float32" - - if is_xformers_available(): - if torch.cuda.is_available() and not Globals.disable_xformers: - logger.info("xformers memory-efficient attention is available and enabled") - else: - logger.info( - "xformers memory-efficient attention is available but disabled" - ) - else: - logger.info("xformers not installed") - - # model caching system for fast switching - self.model_manager = ModelManager( - conf, - self.device, - torch_dtype(self.device), - max_cache_size=max_cache_size, - sequential_offload=self.free_gpu_mem, -# embedding_path=Path(self.embedding_path), - ) - # don't accept invalid models - fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME - model = model or fallback - if not self.model_manager.model_exists(model): - logger.warning( - f'"{model}" is not a known model name; falling back to {fallback}.' - ) - model = None - self.model_name = model or fallback - - # for VRAM usage statistics - self.session_peakmem = ( - torch.cuda.max_memory_allocated(self.device) if self._has_cuda else None - ) - transformers.logging.set_verbosity_error() - - # gets rid of annoying messages about random seed - logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) - - # load safety checker if requested - if safety_checker: - logger.info("Initializing NSFW checker") - self.safety_checker = SafetyChecker(self.device) - else: - logger.info("NSFW checker is disabled") - - def prompt2png(self, prompt, outdir, **kwargs): - """ - Takes a prompt and an output directory, writes out the requested number - of PNG files, and returns an array of [[filename,seed],[filename,seed]...] 
- Optional named arguments are the same as those passed to Generate and prompt2image() - """ - results = self.prompt2image(prompt, **kwargs) - pngwriter = PngWriter(outdir) - prefix = pngwriter.unique_prefix() - outputs = [] - for image, seed in results: - name = f"{prefix}.{seed}.png" - path = pngwriter.save_image_and_prompt_to_png( - image, dream_prompt=f"{prompt} -S{seed}", name=name - ) - outputs.append([path, seed]) - return outputs - - def txt2img(self, prompt, **kwargs): - outdir = kwargs.pop("outdir", self.outdir) - return self.prompt2png(prompt, outdir, **kwargs) - - def img2img(self, prompt, **kwargs): - outdir = kwargs.pop("outdir", self.outdir) - assert ( - "init_img" in kwargs - ), "call to img2img() must include the init_img argument" - return self.prompt2png(prompt, outdir, **kwargs) - - def prompt2image( - self, - # these are common - prompt, - iterations=None, - steps=None, - seed=None, - cfg_scale=None, - ddim_eta=None, - skip_normalize=False, - image_callback=None, - step_callback=None, - width=None, - height=None, - sampler_name=None, - seamless=False, - seamless_axes={"x", "y"}, - log_tokenization=False, - with_variations=None, - variation_amount=0.0, - threshold=0.0, - perlin=0.0, - h_symmetry_time_pct=None, - v_symmetry_time_pct=None, - karras_max=None, - outdir=None, - # these are specific to img2img and inpaint - init_img=None, - init_mask=None, - text_mask=None, - invert_mask=False, - fit=False, - strength=None, - init_color=None, - # these are specific to embiggen (which also relies on img2img args) - embiggen=None, - embiggen_tiles=None, - embiggen_strength=None, - # these are specific to GFPGAN/ESRGAN - gfpgan_strength=0, - facetool=None, - facetool_strength=0, - codeformer_fidelity=None, - save_original=False, - upscale=None, - upscale_denoise_str=0.75, - # this is specific to inpainting and causes more extreme inpainting - inpaint_replace=0.0, - # This controls the size at which inpaint occurs (scaled up for inpaint, then back down for the result) - inpaint_width=None, - inpaint_height=None, - # This will help match inpainted areas to the original image more smoothly - mask_blur_radius: int = 8, - # Set this True to handle KeyboardInterrupt internally - catch_interrupts=False, - hires_fix=False, - use_mps_noise=False, - # Seam settings for outpainting - seam_size: int = 0, - seam_blur: int = 0, - seam_strength: float = 0.7, - seam_steps: int = 10, - tile_size: int = 32, - infill_method=None, - force_outpaint: bool = False, - enable_image_debugging=False, - **args, - ): # eat up additional cruft - self.clear_cuda_stats() - """ - ldm.generate.prompt2image() is the common entry point for txt2img() and img2img() - It takes the following arguments: - prompt // prompt string (no default) - iterations // iterations (1); image count=iterations - steps // refinement steps per iteration - seed // seed for random number generator - width // width of image, in multiples of 64 (512) - height // height of image, in multiples of 64 (512) - cfg_scale // how strongly the prompt influences the image (7.5) (must be >1) - seamless // whether the generated image should tile - hires_fix // whether the Hires Fix should be applied during generation - init_img // path to an initial image - init_mask // path to a mask for the initial image - text_mask // a text string that will be used to guide clipseg generation of the init_mask - invert_mask // boolean, if true invert the mask - strength // strength for noising/unnoising init_img. 
0.0 preserves image exactly, 1.0 replaces it completely - facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely - ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image) - step_callback // a function or method that will be called each step - image_callback // a function or method that will be called each time an image is generated - with_variations // a weighted list [(seed_1, weight_1), (seed_2, weight_2), ...] of variations which should be applied before doing any generation - variation_amount // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image) - threshold // optional value >=0 to add thresholding to latent values for k-diffusion samplers (0 disables) - perlin // optional 0-1 value to add a percentage of perlin noise to the initial noise - h_symmetry_time_pct // optional 0-1 value that indicates the time at which horizontal symmetry is applied - v_symmetry_time_pct // optional 0-1 value that indicates the time at which vertical symmetry is applied - embiggen // scale factor relative to the size of the --init_img (-I), followed by ESRGAN upscaling strength (0-1.0), followed by minimum amount of overlap between tiles as a decimal ratio (0 - 1.0) or number of pixels - embiggen_tiles // list of tiles by number in order to process and replace onto the image e.g. `0 2 4` - embiggen_strength // strength for embiggen. 0.0 preserves image exactly, 1.0 replaces it completely - - To use the step callback, define a function that receives two arguments: - - Image GPU data - - The step number - - To use the image callback, define a function of method that receives two arguments, an Image object - and the seed. You can then do whatever you like with the image, including converting it to - different formats and manipulating it. For example: - - def process_image(image,seed): - image.save(f{'images/seed.png'}) - - The code used to save images to a directory can be found in ldm/invoke/pngwriter.py. - It contains code to create the requested output directory, select a unique informative - name for each image, and write the prompt into the PNG metadata. 
- """ - # TODO: convert this into a getattr() loop - steps = steps or self.steps - width = width or self.width - height = height or self.height - seamless = seamless or self.seamless - seamless_axes = seamless_axes or self.seamless_axes - hires_fix = hires_fix or self.hires_fix - cfg_scale = cfg_scale or self.cfg_scale - ddim_eta = ddim_eta or self.ddim_eta - iterations = iterations or self.iterations - strength = strength or self.strength - outdir = outdir or self.outdir - self.seed = seed - self.log_tokenization = log_tokenization - self.step_callback = step_callback - self.karras_max = karras_max - self.infill_method = ( - infill_method or infill_methods()[0], - ) # The infill method to use - with_variations = [] if with_variations is None else with_variations - - # will instantiate the model or return it from cache - model_context = self.set_model(self.model_name) - - # self.width and self.height are set by set_model() - # to the width and height of the image training set - width = width or self.width - height = height or self.height - - with model_context as model: - if isinstance(model, DiffusionPipeline): - configure_model_padding(model.unet, seamless, seamless_axes) - configure_model_padding(model.vae, seamless, seamless_axes) - else: - configure_model_padding(model, seamless, seamless_axes) - - assert cfg_scale > 1.0, "CFG_Scale (-C) must be >1.0" - assert threshold >= 0.0, "--threshold must be >=0.0" - assert ( - 0.0 < strength <= 1.0 - ), "img2img and inpaint strength can only work with 0.0 < strength < 1.0" - assert ( - 0.0 <= variation_amount <= 1.0 - ), "-v --variation_amount must be in [0.0, 1.0]" - assert 0.0 <= perlin <= 1.0, "--perlin must be in [0.0, 1.0]" - assert (embiggen == None and embiggen_tiles == None) or ( - (embiggen != None or embiggen_tiles != None) and init_img != None - ), "Embiggen requires an init/input image to be specified" - - if len(with_variations) > 0 or variation_amount > 1.0: - assert seed is not None, "seed must be specified when using with_variations" - if variation_amount == 0.0: - assert ( - iterations == 1 - ), "when using --with_variations, multiple iterations are only possible when using --variation_amount" - assert all( - 0 <= weight <= 1 for _, weight in with_variations - ), f"variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}" - - width, height, _ = self._resolution_check(width, height, log=True) - assert ( - inpaint_replace >= 0.0 and inpaint_replace <= 1.0 - ), "inpaint_replace must be between 0.0 and 1.0" - - if sampler_name and (sampler_name != self.sampler_name): - self.sampler_name = sampler_name - self._set_scheduler(model) - - # apply the concepts library to the prompt - prompt = self.huggingface_concepts_library.replace_concepts_with_triggers( - prompt, - lambda concepts: self.load_huggingface_concepts(concepts), - model.textual_inversion_manager.get_all_trigger_strings(), - ) - - tic = time.time() - if self._has_cuda(): - torch.cuda.reset_peak_memory_stats() - - results = list() - - try: - uc, c, extra_conditioning_info = get_uc_and_c_and_ec( - prompt, - model=model, - skip_normalize_legacy_blend=skip_normalize, - log_tokens=self.log_tokenization, - ) - - init_image, mask_image = self._make_images( - init_img, - init_mask, - width, - height, - fit=fit, - text_mask=text_mask, - invert_mask=invert_mask, - force_outpaint=force_outpaint, - ) - - # TODO: Hacky selection of operation to perform. Needs to be refactored. 
- generator = self.select_generator( - init_image, mask_image, embiggen, hires_fix, force_outpaint - ) - - generator.set_variation(self.seed, variation_amount, with_variations) - generator.use_mps_noise = use_mps_noise - - results = generator.generate( - prompt, - iterations=iterations, - seed=self.seed, - sampler=self.sampler, - steps=steps, - cfg_scale=cfg_scale, - conditioning=(uc, c, extra_conditioning_info), - ddim_eta=ddim_eta, - image_callback=image_callback, # called after the final image is generated - step_callback=step_callback, # called after each intermediate image is generated - width=width, - height=height, - init_img=init_img, # embiggen needs to manipulate from the unmodified init_img - init_image=init_image, # notice that init_image is different from init_img - mask_image=mask_image, - strength=strength, - threshold=threshold, - perlin=perlin, - h_symmetry_time_pct=h_symmetry_time_pct, - v_symmetry_time_pct=v_symmetry_time_pct, - embiggen=embiggen, - embiggen_tiles=embiggen_tiles, - embiggen_strength=embiggen_strength, - inpaint_replace=inpaint_replace, - mask_blur_radius=mask_blur_radius, - safety_checker=self.safety_checker, - seam_size=seam_size, - seam_blur=seam_blur, - seam_strength=seam_strength, - seam_steps=seam_steps, - tile_size=tile_size, - infill_method=infill_method, - force_outpaint=force_outpaint, - inpaint_height=inpaint_height, - inpaint_width=inpaint_width, - enable_image_debugging=enable_image_debugging, - free_gpu_mem=self.free_gpu_mem, - clear_cuda_cache=self.clear_cuda_cache, - ) - - if init_color: - self.correct_colors( - image_list=results, - reference_image_path=init_color, - image_callback=image_callback, - ) - - if upscale is not None or facetool_strength > 0: - self.upscale_and_reconstruct( - results, - upscale=upscale, - upscale_denoise_str=upscale_denoise_str, - facetool=facetool, - strength=facetool_strength, - codeformer_fidelity=codeformer_fidelity, - save_original=save_original, - image_callback=image_callback, - ) - - except KeyboardInterrupt: - # Clear the CUDA cache on an exception - self.clear_cuda_cache() - - if catch_interrupts: - logger.warning("Interrupted** Partial results will be returned.") - else: - raise KeyboardInterrupt - except RuntimeError: - # Clear the CUDA cache on an exception - self.clear_cuda_cache() - - print(traceback.format_exc(), file=sys.stderr) - logger.info("Could not generate image.") - - toc = time.time() - logger.info("Usage stats:") - logger.info(f"{len(results)} image(s) generated in "+"%4.2fs" % (toc - tic)) - self.print_cuda_stats() - return results - - def gather_cuda_stats(self): - if self._has_cuda(): - self.max_memory_allocated = max( - self.max_memory_allocated, torch.cuda.max_memory_allocated(self.device) - ) - self.memory_allocated = max( - self.memory_allocated, torch.cuda.memory_allocated(self.device) - ) - self.session_peakmem = max( - self.session_peakmem, torch.cuda.max_memory_allocated(self.device) - ) - - def clear_cuda_cache(self): - if self._has_cuda(): - self.gather_cuda_stats() - # Run garbage collection prior to emptying the CUDA cache - gc.collect() - torch.cuda.empty_cache() - - def clear_cuda_stats(self): - self.max_memory_allocated = 0 - self.memory_allocated = 0 - - def print_cuda_stats(self): - if self._has_cuda(): - self.gather_cuda_stats() - logger.info( - "Max VRAM used for this generation: "+ - "%4.2fG. 
" % (self.max_memory_allocated / 1e9)+ - "Current VRAM utilization: "+ - "%4.2fG" % (self.memory_allocated / 1e9) - ) - - logger.info( - "Max VRAM used since script start: " + - "%4.2fG" % (self.session_peakmem / 1e9) - ) - - # this needs to be generalized to all sorts of postprocessors, which should be wrapped - # in a nice harmonized call signature. For now we have a bunch of if/elses! - def apply_postprocessor( - self, - image_path, - tool="gfpgan", # one of 'upscale', 'gfpgan', 'codeformer', 'outpaint', or 'embiggen' - facetool_strength=0.0, - codeformer_fidelity=0.75, - upscale=None, - upscale_denoise_str=0.75, - out_direction=None, - outcrop=[], - save_original=True, # to get new name - callback=None, - opt=None, - ): - # retrieve the seed from the image; - seed = None - prompt = None - - args = metadata_from_png(image_path) - seed = opt.seed or args.seed - if seed is None or seed < 0: - seed = random.randrange(0, np.iinfo(np.uint32).max) - - prompt = opt.prompt or args.prompt or "" - logger.info(f'using seed {seed} and prompt "{prompt}" for {image_path}') - - # try to reuse the same filename prefix as the original file. - # we take everything up to the first period - prefix = None - m = re.match(r"^([^.]+)\.", os.path.basename(image_path)) - if m: - prefix = m.groups()[0] - - # face fixers and esrgan take an Image, but embiggen takes a path - image = Image.open(image_path) - - # used by multiple postfixers - # todo: cross-attention control - with self.model_context as model: - uc, c, extra_conditioning_info = get_uc_and_c_and_ec( - prompt, - model=model, - skip_normalize_legacy_blend=opt.skip_normalize, - log_tokens=log_tokenization, - ) - - if tool in ("gfpgan", "codeformer", "upscale"): - if tool == "gfpgan": - facetool = "gfpgan" - elif tool == "codeformer": - facetool = "codeformer" - elif tool == "upscale": - facetool = "gfpgan" # but won't be run - facetool_strength = 0 - return self.upscale_and_reconstruct( - [[image, seed]], - facetool=facetool, - strength=facetool_strength, - codeformer_fidelity=codeformer_fidelity, - save_original=save_original, - upscale=upscale, - upscale_denoise_str=upscale_denoise_str, - image_callback=callback, - prefix=prefix, - ) - - elif tool == "outcrop": - from .restoration.outcrop import Outcrop - - extend_instructions = {} - for direction, pixels in _pairwise(opt.outcrop): - try: - extend_instructions[direction] = int(pixels) - except ValueError: - logger.warning( - 'invalid extension instruction. Use ..., as in "top 64 left 128 right 64 bottom 64"' - ) - - opt.seed = seed - opt.prompt = prompt - - if len(extend_instructions) > 0: - restorer = Outcrop( - image, - self, - ) - return restorer.process( - extend_instructions, - opt=opt, - orig_opt=args, - image_callback=callback, - prefix=prefix, - ) - - elif tool == "embiggen": - # fetch the metadata from the image - generator = self.select_generator(embiggen=True) - opt.strength = opt.embiggen_strength or 0.40 - logger.info( - f"Setting img2img strength to {opt.strength} for happy embiggening" - ) - generator.generate( - prompt, - sampler=self.sampler, - steps=opt.steps, - cfg_scale=opt.cfg_scale, - ddim_eta=self.ddim_eta, - conditioning=(uc, c, extra_conditioning_info), - init_img=image_path, # not the Image! (sigh) - init_image=image, # embiggen wants both! 
(sigh) - strength=opt.strength, - width=opt.width, - height=opt.height, - embiggen=opt.embiggen, - embiggen_tiles=opt.embiggen_tiles, - embiggen_strength=opt.embiggen_strength, - image_callback=callback, - clear_cuda_cache=self.clear_cuda_cache, - ) - elif tool == "outpaint": - from .restoration.outpaint import Outpaint - - restorer = Outpaint(image, self) - return restorer.process(opt, args, image_callback=callback, prefix=prefix) - - elif tool is None: - logger.warning( - "please provide at least one postprocessing option, such as -G or -U" - ) - return None - else: - logger.warning(f"postprocessing tool {tool} is not yet supported") - return None - - def select_generator( - self, - init_image: Image.Image = None, - mask_image: Image.Image = None, - embiggen: bool = False, - hires_fix: bool = False, - force_outpaint: bool = False, - ): - if hires_fix: - return self._make_txt2img2img() - - if embiggen is not None: - return self._make_embiggen() - - if ((init_image is not None) and (mask_image is not None)) or force_outpaint: - return self._make_inpaint() - - if init_image is not None: - return self._make_img2img() - - return self._make_txt2img() - - def _make_images( - self, - img, - mask, - width, - height, - fit=False, - text_mask=None, - invert_mask=False, - force_outpaint=False, - ): - init_image = None - init_mask = None - if not img: - return None, None - - image = self._load_img(img) - - if image.width < self.width and image.height < self.height: - logger.warning( - f"img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions" - ) - - # if image has a transparent area and no mask was provided, then try to generate mask - if self._has_transparency(image): - self._transparency_check_and_warning(image, mask, force_outpaint) - init_mask = self._create_init_mask(image, width, height, fit=fit) - - if (image.width * image.height) > ( - self.width * self.height - ) and self.size_matters: - logger.info( - "This input is larger than your defaults. If you run out of memory, please use a smaller image." 
- ) - self.size_matters = False - - init_image = self._create_init_image(image, width, height, fit=fit) - - if mask: - mask_image = self._load_img(mask) - init_mask = self._create_init_mask(mask_image, width, height, fit=fit) - - elif text_mask: - init_mask = self._txt2mask(image, text_mask, width, height, fit=fit) - - if init_mask and invert_mask: - init_mask = ImageOps.invert(init_mask) - - return init_image, init_mask - - def _make_base(self): - return self._load_generator("", "Generator") - - def _make_txt2img(self): - return self._load_generator(".txt2img", "Txt2Img") - - def _make_img2img(self): - return self._load_generator(".img2img", "Img2Img") - - def _make_embiggen(self): - return self._load_generator(".embiggen", "Embiggen") - - def _make_txt2img2img(self): - return self._load_generator(".txt2img2img", "Txt2Img2Img") - - def _make_inpaint(self): - return self._load_generator(".inpaint", "Inpaint") - - def _load_generator(self, module, class_name): - mn = f"invokeai.backend.generator{module}" - cn = class_name - module = importlib.import_module(mn) - constructor = getattr(module, cn) - with self.model_context as model: - return constructor(model, self.precision) - - def load_model(self): - """ - preload model identified in self.model_name - """ - return self.set_model(self.model_name) - - def set_model(self, model_name): - """ - Given the name of a model defined in models.yaml, will load and initialize it - and return the model object. Previously-used models will be cached. - - If the passed model_name is invalid, raises a KeyError. - If the model fails to load for some reason, will attempt to load the previously- - loaded model (if any). If that fallback fails, will raise an AssertionError - """ - if self.model_name == model_name and self.model_context is not None: - return self.model_context - - previous_model_name = self.model_name - - # the model cache does the loading and offloading - cache = self.model_manager - if not cache.model_exists(model_name): - raise KeyError( - f'** "{model_name}" is not a known model name. Cannot change.' 
- ) - - # have to get rid of all references to model in order - # to free it from GPU memory - self.model_context = None - self.sampler = None - self.generators = {} - gc.collect() - try: - model_data = cache.get_model(model_name) - except Exception as e: - logger.warning(f"model {model_name} could not be loaded: {str(e)}") - print(traceback.format_exc(), file=sys.stderr) - if previous_model_name is None: - raise e - logger.warning("trying to reload previous model") - model_data = cache.get_model(previous_model_name) # load previous - if model_data is None: - raise e - model_name = previous_model_name - - self.model_context = model_data.context - self.width = 512 - self.height = 512 - self.model_hash = model_data.hash - - # uncache generators so they pick up new models - self.generators = {} - - set_seed(random.randrange(0, np.iinfo(np.uint32).max)) - self.model_name = model_name - with self.model_context as model: - self._set_scheduler(model) # requires self.model_name to be set first - return self.model_context - - def load_huggingface_concepts(self, concepts: list[str]): - with self.model_context as model: - model.textual_inversion_manager.load_huggingface_concepts(concepts) - - @property - def huggingface_concepts_library(self) -> HuggingFaceConceptsLibrary: - with self.model_context as model: - return model.textual_inversion_manager.hf_concepts_library - - @property - def embedding_trigger_strings(self) -> List[str]: - with self.model_context as model: - return model.textual_inversion_manager.get_all_trigger_strings() - - def correct_colors(self, image_list, reference_image_path, image_callback=None): - reference_image = Image.open(reference_image_path) - correction_target = cv2.cvtColor(np.asarray(reference_image), cv2.COLOR_RGB2LAB) - for r in image_list: - image, seed = r - image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2LAB) - image = skimage.exposure.match_histograms( - image, correction_target, channel_axis=2 - ) - image = Image.fromarray( - cv2.cvtColor(image, cv2.COLOR_LAB2RGB).astype("uint8") - ) - if image_callback is not None: - image_callback(image, seed) - else: - r[0] = image - - def upscale_and_reconstruct( - self, - image_list, - facetool="gfpgan", - upscale=None, - upscale_denoise_str=0.75, - strength=0.0, - codeformer_fidelity=0.75, - save_original=False, - image_callback=None, - prefix=None, - ): - results = [] - for r in image_list: - image, seed, _ = r - try: - if strength > 0: - if self.gfpgan is not None or self.codeformer is not None: - if facetool == "gfpgan": - if self.gfpgan is None: - logger.info( - "GFPGAN not found. Face restoration is disabled." - ) - else: - image = self.gfpgan.process(image, strength, seed) - if facetool == "codeformer": - if self.codeformer is None: - logger.info( - "CodeFormer not found. Face restoration is disabled." - ) - else: - cf_device = ( - "cpu" if str(self.device) == "mps" else self.device - ) - image = self.codeformer.process( - image=image, - strength=strength, - device=cf_device, - seed=seed, - fidelity=codeformer_fidelity, - ) - else: - logger.info("Face Restoration is disabled.") - if upscale is not None: - if self.esrgan is not None: - if len(upscale) < 2: - upscale.append(0.75) - image = self.esrgan.process( - image, - upscale[1], - seed, - int(upscale[0]), - denoise_str=upscale_denoise_str, - ) - else: - logger.info("ESRGAN is disabled. Image not upscaled.") - except Exception as e: - logger.info( - f"Error running RealESRGAN or GFPGAN. 
Your image was not upscaled.\n{e}" - ) - - if image_callback is not None: - image_callback(image, seed, upscaled=True, use_prefix=prefix) - else: - r[0] = image - - results.append([image, seed]) - - return results - - def apply_textmask( - self, image_path: str, prompt: str, callback, threshold: float = 0.5 - ): - assert os.path.exists( - image_path - ), f'** "{image_path}" not found. Please enter the name of an existing image file to mask **' - basename, _ = os.path.splitext(os.path.basename(image_path)) - if self.txt2mask is None: - self.txt2mask = Txt2Mask(device=self.device, refined=True) - segmented = self.txt2mask.segment(image_path, prompt) - trans = segmented.to_transparent() - inverse = segmented.to_transparent(invert=True) - mask = segmented.to_mask(threshold) - - path_filter = re.compile(r'[<>:"/\\|?*]') - safe_prompt = path_filter.sub("_", prompt)[:50].rstrip(" .") - - callback(trans, f"{safe_prompt}.deselected", use_prefix=basename) - callback(inverse, f"{safe_prompt}.selected", use_prefix=basename) - callback(mask, f"{safe_prompt}.masked", use_prefix=basename) - - # to help WebGUI - front end to generator util function - def sample_to_image(self, samples): - return self._make_base().sample_to_image(samples) - - def sample_to_lowres_estimated_image(self, samples): - return self._make_base().sample_to_lowres_estimated_image(samples) - - def is_legacy_model(self, model_name) -> bool: - return self.model_manager.is_legacy(model_name) - - def _set_scheduler(self,model): - default = model.scheduler - - if self.sampler_name in SCHEDULER_MAP: - sampler_class, sampler_extra_config = SCHEDULER_MAP[self.sampler_name] - msg = ( - f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})" - ) - self.sampler = sampler_class.from_config({**model.scheduler.config, **sampler_extra_config}) - else: - msg = ( - f" Unsupported Sampler: {self.sampler_name} "+ - f"Defaulting to {default}" - ) - self.sampler = default - - logger.info(msg) - - if not hasattr(self.sampler, "uses_inpainting_model"): - # FIXME: terrible kludge! - self.sampler.uses_inpainting_model = lambda: False - - def _load_img(self, img) -> Image: - if isinstance(img, Image.Image): - image = img - logger.info(f"using provided input image of size {image.width}x{image.height}") - elif isinstance(img, str): - assert os.path.exists(img), f"{img}: File not found" - - image = Image.open(img) - logger.info( - f"loaded input image of size {image.width}x{image.height} from {img}" - ) - else: - image = Image.open(img) - logger.info(f"loaded input image of size {image.width}x{image.height}") - image = ImageOps.exif_transpose(image) - return image - - def _create_init_image(self, image: Image.Image, width, height, fit=True): - if image.mode != "RGBA": - image = image.convert("RGBA") - image = ( - self._fit_image(image, (width, height)) - if fit - else self._squeeze_image(image) - ) - return image - - def _create_init_mask(self, image, width, height, fit=True): - # convert into a black/white mask - image = self._image_to_mask(image) - image = image.convert("RGB") - image = ( - self._fit_image(image, (width, height)) - if fit - else self._squeeze_image(image) - ) - return image - - # The mask is expected to have the region to be inpainted - # with alpha transparency. It converts it into a black/white - # image with the transparent part black. 
- def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image: - # Obtain the mask from the transparency channel - if mask_image.mode == "L": - mask = mask_image - elif mask_image.mode in ("RGB", "P"): - mask = mask_image.convert("L") - else: - # Obtain the mask from the transparency channel - mask = Image.new(mode="L", size=mask_image.size, color=255) - mask.putdata(mask_image.getdata(band=3)) - if invert: - mask = ImageOps.invert(mask) - return mask - - def _txt2mask( - self, image: Image, text_mask: list, width, height, fit=True - ) -> Image: - prompt = text_mask[0] - confidence_level = text_mask[1] if len(text_mask) > 1 else 0.5 - if self.txt2mask is None: - self.txt2mask = Txt2Mask(device=self.device) - - segmented = self.txt2mask.segment(image, prompt) - mask = segmented.to_mask(float(confidence_level)) - mask = mask.convert("RGB") - mask = ( - self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask) - ) - return mask - - def _has_transparency(self, image): - if image.info.get("transparency", None) is not None: - return True - if image.mode == "P": - transparent = image.info.get("transparency", -1) - for _, index in image.getcolors(): - if index == transparent: - return True - elif image.mode == "RGBA": - extrema = image.getextrema() - if extrema[3][0] < 255: - return True - return False - - def _check_for_erasure(self, image: Image.Image) -> bool: - if image.mode not in ("RGBA", "RGB"): - return False - width, height = image.size - pixdata = image.load() - colored = 0 - for y in range(height): - for x in range(width): - if pixdata[x, y][3] == 0: - r, g, b, _ = pixdata[x, y] - if (r, g, b) != (0, 0, 0) and (r, g, b) != (255, 255, 255): - colored += 1 - return colored == 0 - - def _transparency_check_and_warning(self, image, mask, force_outpaint=False): - if not mask: - logger.info( - "Initial image has transparent areas. Will inpaint in these regions." - ) - if (not force_outpaint) and self._check_for_erasure(image): - logger.info( - "Colors underneath the transparent region seem to have been erased.\n" + - "Inpainting will be suboptimal. Please preserve the colors when making\n" + - "a transparency mask, or provide mask explicitly using --init_mask (-M)." - ) - - def _squeeze_image(self, image): - x, y, resize_needed = self._resolution_check(image.width, image.height) - if resize_needed: - return InitImageResizer(image).resize(x, y) - return image - - def _fit_image(self, image, max_dimensions): - w, h = max_dimensions - logger.info(f"image will be resized to fit inside a box {w}x{h} in size.") - # note that InitImageResizer does the multiple of 64 truncation internally - image = InitImageResizer(image).resize(width=w, height=h) - logger.info( - f"after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}" - ) - return image - - def _resolution_check(self, width, height, log=False): - resize_needed = False - w, h = map( - lambda x: x - x % 64, (width, height) - ) # resize to integer multiple of 64 - if h != height or w != width: - if log: - logger.info( - f"Provided width and height must be multiples of 64. 
Auto-resizing to {w}x{h}" - ) - height = h - width = w - resize_needed = True - return width, height, resize_needed - - def _has_cuda(self): - return self.device.type == "cuda" - - def write_intermediate_images(self, modulus, path): - counter = -1 - if not os.path.exists(path): - os.makedirs(path) - - def callback(img): - nonlocal counter - counter += 1 - if counter % modulus != 0: - return - image = self.sample_to_image(img) - image.save(os.path.join(path, f"{counter:03}.png"), "PNG") - - return callback - - -def _pairwise(iterable): - "s -> (s0, s1), (s2, s3), (s4, s5), ..." - a = iter(iterable) - return zip(a, a) diff --git a/invokeai/backend/globals.py b/invokeai/backend/globals.py deleted file mode 100644 index 5106ddb67d..0000000000 --- a/invokeai/backend/globals.py +++ /dev/null @@ -1,126 +0,0 @@ -""" -invokeai.backend.globals defines a small number of global variables that would -otherwise have to be passed through long and complex call chains. - -It defines a Namespace object named "Globals" that contains -the attributes: - - - root - the root directory under which "models" and "outputs" can be found - - initfile - path to the initialization file - - try_patchmatch - option to globally disable loading of 'patchmatch' module - - always_use_cpu - force use of CPU even if GPU is available -""" - -import os -import os.path as osp -from argparse import Namespace -from pathlib import Path -from typing import Union - -Globals = Namespace() - -# Where to look for the initialization file and other key components -Globals.initfile = "invokeai.init" -Globals.models_file = "models.yaml" -Globals.models_dir = "models" -Globals.config_dir = "configs" -Globals.autoscan_dir = "weights" -Globals.converted_ckpts_dir = "converted_ckpts" - -# Set the default root directory. This can be overwritten by explicitly -# passing the `--root ` argument on the command line. -# logic is: -# 1) use INVOKEAI_ROOT environment variable (no check for this being a valid directory) -# 2) use VIRTUAL_ENV environment variable, with a check for initfile being there -# 3) use ~/invokeai - -if os.environ.get("INVOKEAI_ROOT"): - Globals.root = osp.abspath(os.environ.get("INVOKEAI_ROOT")) -elif ( - os.environ.get("VIRTUAL_ENV") - and Path(os.environ.get("VIRTUAL_ENV"), "..", Globals.initfile).exists() -): - Globals.root = osp.abspath(osp.join(os.environ.get("VIRTUAL_ENV"), "..")) -else: - Globals.root = osp.abspath(osp.expanduser("~/invokeai")) - -# Try loading patchmatch -Globals.try_patchmatch = True - -# Use CPU even if GPU is available (main use case is for debugging MPS issues) -Globals.always_use_cpu = False - -# Whether the internet is reachable for dynamic downloads -# The CLI will test connectivity at startup time. -Globals.internet_available = True - -# Whether to disable xformers -Globals.disable_xformers = False - -# Low-memory tradeoff for guidance calculations. 
-Globals.sequential_guidance = False - -# whether we are forcing full precision -Globals.full_precision = False - -# whether we should convert ckpt files into diffusers models on the fly -Globals.ckpt_convert = True - -# logging tokenization everywhere -Globals.log_tokenization = False - - -def global_config_file() -> Path: - return Path(Globals.root, Globals.config_dir, Globals.models_file) - - -def global_config_dir() -> Path: - return Path(Globals.root, Globals.config_dir) - - -def global_models_dir() -> Path: - return Path(Globals.root, Globals.models_dir) - - -def global_autoscan_dir() -> Path: - return Path(Globals.root, Globals.autoscan_dir) - - -def global_converted_ckpts_dir() -> Path: - return Path(global_models_dir(), Globals.converted_ckpts_dir) - - -def global_set_root(root_dir: Union[str, Path]): - Globals.root = root_dir - -def global_resolve_path(path: Union[str,Path]): - if path is None: - return None - return Path(Globals.root,path).resolve() - -def global_cache_dir(subdir: Union[str, Path] = "") -> Path: - """ - Returns Path to the model cache directory. If a subdirectory - is provided, it will be appended to the end of the path, allowing - for Hugging Face-style conventions. Currently, Hugging Face has - moved all models into the "hub" subfolder, so for any pretrained - HF model, use: - global_cache_dir('hub') - - The legacy location for transformers used to be global_cache_dir('transformers') - and global_cache_dir('diffusers') for diffusers. - """ - home: str = os.getenv("HF_HOME") - - if home is None: - home = os.getenv("XDG_CACHE_HOME") - - if home is not None: - # Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in Hugging Face Hub Client Library. - # See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome - home += os.sep + "huggingface" - - if home is not None: - return Path(home, subdir) - else: - return Path(Globals.root, "models", subdir) diff --git a/invokeai/backend/image_util/patchmatch.py b/invokeai/backend/image_util/patchmatch.py index 5b5dd75f68..0d2221be41 100644 --- a/invokeai/backend/image_util/patchmatch.py +++ b/invokeai/backend/image_util/patchmatch.py @@ -6,7 +6,7 @@ be suppressed or deferred """ import numpy as np import invokeai.backend.util.logging as logger -from invokeai.backend.globals import Globals +from invokeai.app.services.config import get_invokeai_config class PatchMatch: """ @@ -21,9 +21,10 @@ class PatchMatch: @classmethod def _load_patch_match(self): + config = get_invokeai_config() if self.tried_load: return - if Globals.try_patchmatch: + if config.try_patchmatch: from patchmatch import patch_match as pm if pm.patchmatch_available: diff --git a/invokeai/backend/image_util/txt2mask.py b/invokeai/backend/image_util/txt2mask.py index 248f19d81d..1a8fcfeb90 100644 --- a/invokeai/backend/image_util/txt2mask.py +++ b/invokeai/backend/image_util/txt2mask.py @@ -33,12 +33,11 @@ from PIL import Image, ImageOps from transformers import AutoProcessor, CLIPSegForImageSegmentation import invokeai.backend.util.logging as logger -from invokeai.backend.globals import global_cache_dir +from invokeai.app.services.config import get_invokeai_config CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined" CLIPSEG_SIZE = 352 - class SegmentedGrayscale(object): def __init__(self, image: Image, heatmap: torch.Tensor): self.heatmap = heatmap @@ -84,14 +83,15 @@ class Txt2Mask(object): def __init__(self, device="cpu", refined=False): logger.info("Initializing clipseg model for 
text to mask inference") + config = get_invokeai_config() # BUG: we are not doing anything with the device option at this time self.device = device self.processor = AutoProcessor.from_pretrained( - CLIPSEG_MODEL, cache_dir=global_cache_dir("hub") + CLIPSEG_MODEL, cache_dir=config.cache_dir ) self.model = CLIPSegForImageSegmentation.from_pretrained( - CLIPSEG_MODEL, cache_dir=global_cache_dir("hub") + CLIPSEG_MODEL, cache_dir=config.cache_dir ) @torch.no_grad() diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index 5874d35c6b..d3dee08b75 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -26,7 +26,7 @@ import torch from safetensors.torch import load_file import invokeai.backend.util.logging as logger -from invokeai.backend.globals import global_cache_dir, global_config_dir +from invokeai.app.services.config import get_invokeai_config from .model_manager import ModelManager, SDLegacyType from .model_cache import ModelCache @@ -76,7 +76,6 @@ from transformers import ( from ..stable_diffusion import StableDiffusionGeneratorPipeline - def shave_segments(path, n_shave_prefix_segments=1): """ Removes segments. Positive values shave the first segments, negative shave the last segments. @@ -858,7 +857,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config): def convert_ldm_clip_checkpoint(checkpoint): text_model = CLIPTextModel.from_pretrained( - "openai/clip-vit-large-patch14", cache_dir=global_cache_dir("hub") + "openai/clip-vit-large-patch14", cache_dir=get_invokeai_config().cache_dir ) keys = list(checkpoint.keys()) @@ -913,7 +912,7 @@ textenc_pattern = re.compile("|".join(protected.keys())) def convert_paint_by_example_checkpoint(checkpoint): - cache_dir = global_cache_dir("hub") + cache_dir = get_invokeai_config().cache_dir config = CLIPVisionConfig.from_pretrained( "openai/clip-vit-large-patch14", cache_dir=cache_dir ) @@ -985,7 +984,7 @@ def convert_paint_by_example_checkpoint(checkpoint): def convert_open_clip_checkpoint(checkpoint): - cache_dir = global_cache_dir("hub") + cache_dir = get_invokeai_config().cache_dir text_model = CLIPTextModel.from_pretrained( "stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir ) @@ -1121,7 +1120,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt( :param vae: A diffusers VAE to load into the pipeline. :param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline. 
""" - + config = get_invokeai_config() with warnings.catch_warnings(): warnings.simplefilter("ignore") verbosity = dlogging.get_verbosity() @@ -1134,7 +1133,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt( else: checkpoint = load_file(checkpoint_path) - cache_dir = global_cache_dir("hub") + cache_dir = config.cache_dir pipeline_class = ( StableDiffusionGeneratorPipeline if return_generator_pipeline @@ -1158,25 +1157,23 @@ def load_pipeline_from_original_stable_diffusion_ckpt( if model_type == SDLegacyType.V2_v: original_config_file = ( - global_config_dir() / "stable-diffusion" / "v2-inference-v.yaml" + config.legacy_conf_path / "v2-inference-v.yaml" ) if global_step == 110000: # v2.1 needs to upcast attention upcast_attention = True elif model_type == SDLegacyType.V2_e: original_config_file = ( - global_config_dir() / "stable-diffusion" / "v2-inference.yaml" + config.legacy_conf_path / "v2-inference.yaml" ) elif model_type == SDLegacyType.V1_INPAINT: original_config_file = ( - global_config_dir() - / "stable-diffusion" - / "v1-inpainting-inference.yaml" + config.legacy_conf_path / "v1-inpainting-inference.yaml" ) elif model_type == SDLegacyType.V1: original_config_file = ( - global_config_dir() / "stable-diffusion" / "v1-inference.yaml" + config.legacy_conf_path / "v1-inference.yaml" ) else: @@ -1323,7 +1320,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt( ) safety_checker = StableDiffusionSafetyChecker.from_pretrained( "CompVis/stable-diffusion-safety-checker", - cache_dir=global_cache_dir("hub"), + cache_dir=config.cache_dir, ) feature_extractor = AutoFeatureExtractor.from_pretrained( "CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index 5c2b498acf..f8484affa8 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -25,27 +25,25 @@ import warnings from contextlib import suppress from enum import Enum from pathlib import Path -from typing import Dict, Sequence, Union, Tuple, types, Optional, List, Type, Any +from typing import Dict, Sequence, Union, types, Optional, List, Type, Any import torch -import safetensors.torch - + from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin from diffusers import logging as diffusers_logging from huggingface_hub import HfApi, scan_cache_dir -from picklescan.scanner import scan_file_path -from pydantic import BaseModel from transformers import logging as transformers_logging import invokeai.backend.util.logging as logger -from ..globals import global_cache_dir - +from invokeai.app.services.config import get_invokeai_config def get_model_path(repo_id_or_path: str): + globals = get_invokeai_config() + if os.path.exists(repo_id_or_path): return repo_id_or_path - cache = scan_cache_dir(global_cache_dir("hub")) + cache = scan_cache_dir(globals.cache_dir) for repo in cache.repos: if repo.repo_id != repo_id_or_path: continue @@ -234,7 +232,7 @@ class DiffusersModelInfo(ModelInfoBase): model = self.child_types[child_type].from_pretrained( self.repo_id_or_path, subfolder=child_type.value, - cache_dir=global_cache_dir('hub'), + cache_dir=get_invokeai_config.cache_dir('hub'), torch_dtype=torch_dtype, variant=variant, ) @@ -248,7 +246,7 @@ class DiffusersModelInfo(ModelInfoBase): return model - def get_pipeline(self, **kwrags): + def get_pipeline(self, **kwargs): return DiffusionPipeline.from_pretrained( self.repo_id_or_path, **kwargs, @@ 
-349,7 +347,7 @@ class ClassifierModelInfo(ModelInfoBase): model = self.child_types[child_type].from_pretrained( self.repo_id_or_path, subfolder=child_type.value, - cache_dir=global_cache_dir('hub'), + cache_dir=get_invokeai_config().cache_dir('hub'), torch_dtype=torch_dtype, ) # calc more accurate size @@ -394,7 +392,7 @@ class VaeModelInfo(ModelInfoBase): model = self.vae_type.from_pretrained( self.repo_id_or_path, - cache_dir=global_cache_dir('hub'), + cache_dir=get_invokeai_config().cache_dir('hub'), torch_dtype=torch_dtype, ) # calc more accurate size diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index a77f0613f3..d6aee3652d 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -149,8 +149,7 @@ from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig import invokeai.backend.util.logging as logger -from invokeai.backend.globals import (Globals, global_cache_dir, - global_resolve_path) +from invokeai.app.services.config import get_invokeai_config from invokeai.backend.util import download_with_resume from ..util import CUDA_DEVICE @@ -226,7 +225,8 @@ class ModelManager(object): # check config version number and update on disk/RAM if necessary self._update_config_file_version() - + self.globals = get_invokeai_config() + self.logger = logger self.cache = ModelCache( max_cache_size=max_cache_size, execution_device = device_type, @@ -235,7 +235,6 @@ class ModelManager(object): logger = logger, ) self.cache_keys = dict() - self.logger = logger def model_exists( self, @@ -304,12 +303,6 @@ class ModelManager(object): # raises an InvalidModelError """ - - # Commented-out workaround for callers that use "type/name" as the model name - # because they haven't adjusted to the new return format of `list_models()` - # if "/" in model_name: - # model_key = model_name - # else: model_key = self.create_key(model_name, model_type) if model_key not in self.config: raise InvalidModelError( @@ -326,13 +319,15 @@ class ModelManager(object): if mconfig.format in ["ckpt", "safetensors"]: location = self.convert_ckpt_and_cache(mconfig) else: - location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id') + location = self.globals.root_dir / mconfig.get('path') or mconfig.get('repo_id') + elif p := mconfig.get('path'): + location = self.globals.root_dir / p + elif r := mconfig.get('repo_id'): + location = r + elif w := mconfig.get('weights'): + location = self.globals.root_dir / w else: - location = global_resolve_path( - mconfig.get('path')) \ - or mconfig.get('repo_id') \ - or global_resolve_path(mconfig.get('weights') - ) + location = None revision = mconfig.get('revision') hash = self.cache.model_hash(location, revision) @@ -423,7 +418,7 @@ class ModelManager(object): """ # if we are converting legacy files automatically, then # there are no legacy ckpts! 
- if Globals.ckpt_convert: + if self.globals.ckpt_convert: return False info = self.model_info(model_name, model_type) if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")): @@ -862,25 +857,16 @@ class ModelManager(object): model_type = self.probe_model_type(checkpoint) if model_type == SDLegacyType.V1: self.logger.debug("SD-v1 model detected") - model_config_file = Path( - Globals.root, "configs/stable-diffusion/v1-inference.yaml" - ) + model_config_file = self.globals.legacy_conf_path / "v1-inference.yaml" elif model_type == SDLegacyType.V1_INPAINT: self.logger.debug("SD-v1 inpainting model detected") - model_config_file = Path( - Globals.root, - "configs/stable-diffusion/v1-inpainting-inference.yaml", - ) + model_config_file = self.globals.legacy_conf_path / "v1-inpainting-inference.yaml", elif model_type == SDLegacyType.V2_v: self.logger.debug("SD-v2-v model detected") - model_config_file = Path( - Globals.root, "configs/stable-diffusion/v2-inference-v.yaml" - ) + model_config_file = self.globals.legacy_conf_path / "v2-inference-v.yaml" elif model_type == SDLegacyType.V2_e: self.logger.debug("SD-v2-e model detected") - model_config_file = Path( - Globals.root, "configs/stable-diffusion/v2-inference.yaml" - ) + model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml" elif model_type == SDLegacyType.V2: self.logger.warning( f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path." @@ -907,9 +893,7 @@ class ModelManager(object): self.logger.debug(f"Using VAE file {vae_path.name}") vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse") - diffuser_path = Path( - Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem - ) + diffuser_path = self.globals.converted_ckpts_dir / model_path.stem with SilenceWarnings(): model_name = self.convert_and_import( model_path, @@ -930,9 +914,9 @@ class ModelManager(object): diffusers, cache it to disk, and return Path to converted file. If already on disk then just returns Path. """ - weights = global_resolve_path(mconfig.weights) - config_file = global_resolve_path(mconfig.config) - diffusers_path = global_resolve_path(Path('models',Globals.converted_ckpts_dir)) / weights.stem + weights = self.globals.root_dir / mconfig.weights + config_file = self.globals.root_dir / mconfig.config + diffusers_path = self.globals.converted_ckpts_dir / weights.stem # return cached version if it exists if diffusers_path.exists(): @@ -949,7 +933,7 @@ class ModelManager(object): extract_ema=True, original_config_file=config_file, vae=vae_model, - vae_path=str(global_resolve_path(vae_ckpt_path)) if vae_ckpt_path else None, + vae_path=str(self.globals.root_dir / vae_ckpt_path) if vae_ckpt_path else None, scan_needed=True, ) return diffusers_path @@ -960,9 +944,10 @@ class ModelManager(object): object, cache it to disk, and return Path to converted file. If already on disk then just returns Path. """ - weights_file = global_resolve_path(mconfig.weights) - config_file = global_resolve_path(mconfig.config) - diffusers_path = global_resolve_path(Path('models',Globals.converted_ckpts_dir)) / weights_file.stem + root = self.globals.root_dir + weights_file = root / mconfig.weights + config_file = root / mconfig.config + diffusers_path = self.globals.converted_ckpts_dir / weights_file.stem image_size = mconfig.get('width') or mconfig.get('height') or 512 # return cached version if it exists @@ -1018,7 +1003,9 @@ class ModelManager(object): # 3. 
If mconfig has a vae dict, then we use it as the diffusers-style vae if vae_config and isinstance(vae_config,DictConfig): - vae_diffusers_location = global_resolve_path(vae_config.get('path')) or vae_config.get('repo_id') + vae_diffusers_location = self.globals.root_dir / vae_config.get('path') \ + if vae_config.get('path') \ + else vae_config.get('repo_id') # 4. Otherwise, we use stabilityai/sd-vae-ft-mse "because it works" else: @@ -1072,7 +1059,9 @@ class ModelManager(object): # will be built into the model rather than tacked on afterward via the config file vae_model = None if vae: - vae_location = global_resolve_path(vae.get('path')) or vae.get('repo_id') + vae_location = self.globals.root_dir / vae.get('path') \ + if vae.get('path') \ + else vae.get('repo_id') vae_model = self.cache.get_model(vae_location, SDModelType.Vae).model vae_path = None convert_ckpt_to_diffusers( @@ -1140,6 +1129,7 @@ class ModelManager(object): yaml_str = OmegaConf.to_yaml(self.config) config_file_path = conf_file or self.config_path assert config_file_path is not None,'no config file path to write to' + config_file_path = self.globals.root_dir / config_file_path tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp") with open(tmpfile, "w", encoding="utf-8") as outfile: outfile.write(self.preamble()) @@ -1160,7 +1150,7 @@ class ModelManager(object): @classmethod def _delete_model_from_cache(cls,repo_id): - cache_info = scan_cache_dir(global_cache_dir("hub")) + cache_info = scan_cache_dir(get_invokeai_config().cache_dir) # I'm sure there is a way to do this with comprehensions # but the code quickly became incomprehensible! @@ -1177,9 +1167,10 @@ class ModelManager(object): @staticmethod def _abs_path(path: str | Path) -> Path: + globals = get_invokeai_config() if path is None or Path(path).is_absolute(): return path - return Path(Globals.root, path).resolve() + return Path(globals.root_dir, path).resolve() # This is not the same as global_resolve_path(), which prepends # Globals.root. 
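
The configuration idiom these model-manager hunks converge on (fetch the application config once, then anchor any relative path at the configured root) is small enough to sketch in isolation. Only get_invokeai_config() and the root_dir attribute are taken from the hunks themselves; resolve_under_root is a hypothetical helper and not part of the patch.

from pathlib import Path
from typing import Optional, Union

from invokeai.app.services.config import get_invokeai_config

def resolve_under_root(path: Optional[Union[str, Path]]) -> Optional[Path]:
    # Hypothetical helper mirroring the pattern above: absolute paths pass through
    # unchanged, relative ones are anchored at the configured InvokeAI root.
    if path is None:
        return None
    p = Path(path)
    return p if p.is_absolute() else (get_invokeai_config().root_dir / p).resolve()

For example, resolve_under_root("models/ldm/model.ckpt") would return the absolute path of that file under the configured root, while an already absolute path is returned as given.
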
@@ -1188,15 +1179,11 @@ class ModelManager(object): ) -> Optional[Path]: resolved_path = None if str(source).startswith(("http:", "https:", "ftp:")): - dest_directory = Path(dest_directory) - if not dest_directory.is_absolute(): - dest_directory = Globals.root / dest_directory + dest_directory = self.globals.root_dir / dest_directory dest_directory.mkdir(parents=True, exist_ok=True) resolved_path = download_with_resume(str(source), dest_directory) else: - if not os.path.isabs(source): - source = os.path.join(Globals.root, source) - resolved_path = Path(source) + resolved_path = self.globals.root_dir / source return resolved_path def _update_config_file_version(self): diff --git a/invokeai/backend/prompting/conditioning.py b/invokeai/backend/prompting/conditioning.py index 7c6cc0eea2..c03bd93ede 100644 --- a/invokeai/backend/prompting/conditioning.py +++ b/invokeai/backend/prompting/conditioning.py @@ -17,67 +17,59 @@ from compel.prompt_parser import ( FlattenedPrompt, Fragment, PromptParser, + Conjunction, ) import invokeai.backend.util.logging as logger -from invokeai.backend.globals import Globals +from invokeai.app.services.config import get_invokeai_config from ..stable_diffusion import InvokeAIDiffuserComponent from ..util import torch_dtype - -def get_uc_and_c_and_ec( - prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False -): +def get_uc_and_c_and_ec(prompt_string, + model: InvokeAIDiffuserComponent, + log_tokens=False, skip_normalize_legacy_blend=False): # lazy-load any deferred textual inversions. # this might take a couple of seconds the first time a textual inversion is used. - model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms( - prompt_string - ) + model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string) - tokenizer = model.tokenizer - compel = Compel( - tokenizer=tokenizer, - text_encoder=model.text_encoder, - textual_inversion_manager=model.textual_inversion_manager, - dtype_for_device_getter=torch_dtype, - truncate_long_prompts=False - ) + compel = Compel(tokenizer=model.tokenizer, + text_encoder=model.text_encoder, + textual_inversion_manager=model.textual_inversion_manager, + dtype_for_device_getter=torch_dtype, + truncate_long_prompts=False, + ) + + config = get_invokeai_config() # get rid of any newline characters prompt_string = prompt_string.replace("\n", " ") - ( - positive_prompt_string, - negative_prompt_string, - ) = split_prompt_to_positive_and_negative(prompt_string) - legacy_blend = try_parse_legacy_blend( - positive_prompt_string, skip_normalize_legacy_blend - ) - positive_prompt: Union[FlattenedPrompt, Blend] - if legacy_blend is not None: - positive_prompt = legacy_blend - else: - positive_prompt = Compel.parse_prompt_string(positive_prompt_string) - negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string( - negative_prompt_string - ) + positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string) - if log_tokens or getattr(Globals, "log_tokenization", False): - log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer) + legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend) + positive_conjunction: Conjunction + if legacy_blend is not None: + positive_conjunction = legacy_blend + else: + positive_conjunction = Compel.parse_prompt_string(positive_prompt_string) + positive_prompt = positive_conjunction.prompts[0] + + negative_conjunction = 
Compel.parse_prompt_string(negative_prompt_string) + negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0] + + tokens_count = get_max_token_count(model.tokenizer, positive_prompt) + if log_tokens or config.log_tokenization: + log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer) c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt) uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt) [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) - tokens_count = get_max_token_count(tokenizer, positive_prompt) - - ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( - tokens_count_including_eos_bos=tokens_count, - cross_attention_control_args=options.get("cross_attention_control", None), - ) + ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count, + cross_attention_control_args=options.get( + 'cross_attention_control', None)) return uc, c, ec - def get_prompt_structure( prompt_string, skip_normalize_legacy_blend: bool = False ) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt): @@ -88,18 +80,17 @@ def get_prompt_structure( legacy_blend = try_parse_legacy_blend( positive_prompt_string, skip_normalize_legacy_blend ) - positive_prompt: Union[FlattenedPrompt, Blend] + positive_prompt: Conjunction if legacy_blend is not None: - positive_prompt = legacy_blend + positive_conjunction = legacy_blend else: - positive_prompt = Compel.parse_prompt_string(positive_prompt_string) - negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string( - negative_prompt_string - ) + positive_conjunction = Compel.parse_prompt_string(positive_prompt_string) + positive_prompt = positive_conjunction.prompts[0] + negative_conjunction = Compel.parse_prompt_string(negative_prompt_string) + negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0] return positive_prompt, negative_prompt - def get_max_token_count( tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False ) -> int: @@ -246,22 +237,21 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):") logger.debug(f"{discarded}\x1b[0m") - -def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]: +def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]: weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize) if len(weighted_subprompts) <= 1: return None strings = [x[0] for x in weighted_subprompts] - weights = [x[1] for x in weighted_subprompts] pp = PromptParser() parsed_conjunctions = [pp.parse_conjunction(x) for x in strings] - flattened_prompts = [x.prompts[0] for x in parsed_conjunctions] - - return Blend( - prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize - ) - + flattened_prompts = [] + weights = [] + for i, x in enumerate(parsed_conjunctions): + if len(x.prompts)>0: + flattened_prompts.append(x.prompts[0]) + weights.append(weighted_subprompts[i][1]) + return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)]) def split_weighted_subprompts(text, skip_normalize=False) -> list: """ diff --git a/invokeai/backend/restoration/codeformer.py b/invokeai/backend/restoration/codeformer.py index 5b578af082..b7073f8f8b 100644 --- a/invokeai/backend/restoration/codeformer.py +++ 
b/invokeai/backend/restoration/codeformer.py @@ -6,7 +6,7 @@ import numpy as np import torch import invokeai.backend.util.logging as logger -from ..globals import Globals +from invokeai.app.services.config import get_invokeai_config pretrained_model_url = ( "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth" @@ -17,11 +17,11 @@ class CodeFormerRestoration: def __init__( self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth" ) -> None: - if not os.path.isabs(codeformer_dir): - codeformer_dir = os.path.join(Globals.root, codeformer_dir) - self.model_path = os.path.join(codeformer_dir, codeformer_model_path) - self.codeformer_model_exists = os.path.isfile(self.model_path) + self.globals = get_invokeai_config() + codeformer_dir = self.globals.root_dir / codeformer_dir + self.model_path = codeformer_dir / codeformer_model_path + self.codeformer_model_exists = self.model_path.exists() if not self.codeformer_model_exists: logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path) @@ -71,9 +71,7 @@ class CodeFormerRestoration: upscale_factor=1, use_parse=True, device=device, - model_rootpath=os.path.join( - Globals.root, "models", "gfpgan", "weights" - ), + model_rootpath = self.globals.root_dir / "gfpgan" / "weights" ) face_helper.clean_all() face_helper.read_image(bgr_image_array) diff --git a/invokeai/backend/restoration/gfpgan.py b/invokeai/backend/restoration/gfpgan.py index b5c0278362..063feaa89a 100644 --- a/invokeai/backend/restoration/gfpgan.py +++ b/invokeai/backend/restoration/gfpgan.py @@ -7,14 +7,13 @@ import torch from PIL import Image import invokeai.backend.util.logging as logger -from invokeai.backend.globals import Globals +from invokeai.app.services.config import get_invokeai_config class GFPGAN: def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None: + self.globals = get_invokeai_config() if not os.path.isabs(gfpgan_model_path): - gfpgan_model_path = os.path.abspath( - os.path.join(Globals.root, gfpgan_model_path) - ) + gfpgan_model_path = self.globals.root_dir / gfpgan_model_path self.model_path = gfpgan_model_path self.gfpgan_model_exists = os.path.isfile(self.model_path) @@ -33,7 +32,7 @@ class GFPGAN: warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=UserWarning) cwd = os.getcwd() - os.chdir(os.path.join(Globals.root, "models")) + os.chdir(self.globals.root_dir / 'models') try: from gfpgan import GFPGANer diff --git a/invokeai/backend/restoration/realesrgan.py b/invokeai/backend/restoration/realesrgan.py index 9f26cc63ac..c6c6d2d3b4 100644 --- a/invokeai/backend/restoration/realesrgan.py +++ b/invokeai/backend/restoration/realesrgan.py @@ -1,4 +1,3 @@ -import os import warnings import numpy as np @@ -7,7 +6,8 @@ from PIL import Image from PIL.Image import Image as ImageType import invokeai.backend.util.logging as logger -from invokeai.backend.globals import Globals +from invokeai.app.services.config import get_invokeai_config +config = get_invokeai_config() class ESRGAN: def __init__(self, bg_tile_size=400) -> None: @@ -30,12 +30,8 @@ class ESRGAN: upscale=4, act_type="prelu", ) - model_path = os.path.join( - Globals.root, "models/realesrgan/realesr-general-x4v3.pth" - ) - wdn_model_path = os.path.join( - Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth" - ) + model_path = config.root_dir / "models/realesrgan/realesr-general-x4v3.pth" + wdn_model_path = config.root_dir / 
"models/realesrgan/realesr-general-wdn-x4v3.pth" scale = 4 bg_upsampler = RealESRGANer( diff --git a/invokeai/backend/safety_checker.py b/invokeai/backend/safety_checker.py index 3003981888..55e8eb1987 100644 --- a/invokeai/backend/safety_checker.py +++ b/invokeai/backend/safety_checker.py @@ -15,7 +15,7 @@ from transformers import AutoFeatureExtractor import invokeai.assets.web as web_assets import invokeai.backend.util.logging as logger -from .globals import global_cache_dir +from invokeai.app.services.config import get_invokeai_config from .util import CPU_DEVICE class SafetyChecker(object): @@ -26,10 +26,11 @@ class SafetyChecker(object): caution = Image.open(path) self.caution_img = caution.resize((caution.width // 2, caution.height // 2)) self.device = device - + config = get_invokeai_config() + try: safety_model_id = "CompVis/stable-diffusion-safety-checker" - safety_model_path = global_cache_dir("hub") + safety_model_path = config.cache_dir self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( safety_model_id, local_files_only=True, diff --git a/invokeai/backend/stable_diffusion/concepts_lib.py b/invokeai/backend/stable_diffusion/concepts_lib.py index ebbcc9c3e9..beb884b012 100644 --- a/invokeai/backend/stable_diffusion/concepts_lib.py +++ b/invokeai/backend/stable_diffusion/concepts_lib.py @@ -18,15 +18,15 @@ from huggingface_hub import ( ) import invokeai.backend.util.logging as logger -from invokeai.backend.globals import Globals - +from invokeai.app.services.config import get_invokeai_config class HuggingFaceConceptsLibrary(object): def __init__(self, root=None): """ Initialize the Concepts object. May optionally pass a root directory. """ - self.root = root or Globals.root + self.config = get_invokeai_config() + self.root = root or self.config.root self.hf_api = HfApi() self.local_concepts = dict() self.concept_list = None @@ -58,7 +58,7 @@ class HuggingFaceConceptsLibrary(object): self.concept_list.extend(list(local_concepts_to_add)) return self.concept_list return self.concept_list - elif Globals.internet_available is True: + elif self.config.internet_available is True: try: models = self.hf_api.list_models( filter=ModelFilter(model_name="sd-concepts-library/") diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index c8a932b9e9..4ca2a5cb30 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -33,8 +33,7 @@ from torchvision.transforms.functional import resize as tv_resize from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from typing_extensions import ParamSpec -from invokeai.backend.globals import Globals - +from invokeai.app.services.config import get_invokeai_config from ..util import CPU_DEVICE, normalize_device from .diffusion import ( AttentionMapSaver, @@ -44,7 +43,6 @@ from .diffusion import ( from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup from .textual_inversion_manager import TextualInversionManager - @dataclass class PipelineIntermediateState: run_id: str @@ -348,10 +346,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): """ if xformers is available, use it, otherwise use sliced attention. 
""" + config = get_invokeai_config() if ( torch.cuda.is_available() and is_xformers_available() - and not Globals.disable_xformers + and not config.disable_xformers ): self.enable_xformers_memory_efficient_attention() else: @@ -548,8 +547,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance = [] extra_conditioning_info = conditioning_data.extra with self.invokeai_diffuser.custom_attention_context( - extra_conditioning_info=extra_conditioning_info, - step_count=len(self.scheduler.timesteps), + self.invokeai_diffuser.model, + extra_conditioning_info=extra_conditioning_info, + step_count=len(self.scheduler.timesteps), ): yield PipelineIntermediateState( run_id=run_id, diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py index dfd19ea964..79a0982cfe 100644 --- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py +++ b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py @@ -10,6 +10,7 @@ import diffusers import psutil import torch from compel.cross_attention_control import Arguments +from diffusers.models.unet_2d_condition import UNet2DConditionModel from diffusers.models.attention_processor import AttentionProcessor from torch import nn @@ -352,8 +353,7 @@ def restore_default_cross_attention( else: remove_attention_function(model) - -def override_cross_attention(model, context: Context, is_running_diffusers=False): +def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context): """ Inject attention parameters and functions into the passed in model to enable cross attention editing. @@ -372,37 +372,22 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False indices = torch.arange(max_length, dtype=torch.long) for name, a0, a1, b0, b1 in context.arguments.edit_opcodes: if b0 < max_length: - if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0): + if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): # these tokens have not been edited indices[b0:b1] = indices_target[a0:a1] mask[b0:b1] = 1 context.cross_attention_mask = mask.to(device) context.cross_attention_index_map = indices.to(device) - if is_running_diffusers: - unet = model - old_attn_processors = unet.attn_processors - if torch.backends.mps.is_available(): - # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS - unet.set_attn_processor(SwapCrossAttnProcessor()) - else: - # try to re-use an existing slice size - default_slice_size = 4 - slice_size = next( - ( - p.slice_size - for p in old_attn_processors.values() - if type(p) is SlicedAttnProcessor - ), - default_slice_size, - ) - unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) - return old_attn_processors + old_attn_processors = unet.attn_processors + if torch.backends.mps.is_available(): + # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS + unet.set_attn_processor(SwapCrossAttnProcessor()) else: - context.register_cross_attention_modules(model) - inject_attention_function(model, context) - return None - + # try to re-use an existing slice size + default_slice_size = 4 + slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size) + unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) def get_cross_attention_modules( model, which: 
CrossAttentionType diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index b0c85e9fd3..4131837b41 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -5,11 +5,12 @@ from typing import Any, Callable, Dict, Optional, Union import numpy as np import torch +from diffusers import UNet2DConditionModel from diffusers.models.attention_processor import AttentionProcessor from typing_extensions import TypeAlias import invokeai.backend.util.logging as logger -from invokeai.backend.globals import Globals +from invokeai.app.services.config import get_invokeai_config from .cross_attention_control import ( Arguments, @@ -17,8 +18,8 @@ from .cross_attention_control import ( CrossAttentionType, SwapCrossAttnContext, get_cross_attention_modules, - override_cross_attention, restore_default_cross_attention, + setup_cross_attention_control_attention_processors, ) from .cross_attention_map_saving import AttentionMapSaver @@ -31,7 +32,6 @@ ModelForwardCallback: TypeAlias = Union[ Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], ] - @dataclass(frozen=True) class PostprocessingSettings: threshold: float @@ -72,31 +72,43 @@ class InvokeAIDiffuserComponent: :param model: the unet model to pass through to cross attention control :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) """ + config = get_invokeai_config() self.conditioning = None self.model = model self.is_running_diffusers = is_running_diffusers self.model_forward_callback = model_forward_callback self.cross_attention_control_context = None - self.sequential_guidance = Globals.sequential_guidance + self.sequential_guidance = config.sequential_guidance + @classmethod @contextmanager def custom_attention_context( - self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int + cls, + unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs + extra_conditioning_info: Optional[ExtraConditioningInfo], + step_count: int ): - do_swap = ( - extra_conditioning_info is not None - and extra_conditioning_info.wants_cross_attention_control - ) - old_attn_processor = None - if do_swap: - old_attn_processor = self.override_cross_attention( - extra_conditioning_info, step_count=step_count - ) + old_attn_processors = None + if extra_conditioning_info and ( + extra_conditioning_info.wants_cross_attention_control + ): + old_attn_processors = unet.attn_processors + # Load lora conditions into the model + if extra_conditioning_info.wants_cross_attention_control: + cross_attention_control_context = Context( + arguments=extra_conditioning_info.cross_attention_control_args, + step_count=step_count, + ) + setup_cross_attention_control_attention_processors( + unet, + cross_attention_control_context, + ) + try: yield None finally: - if old_attn_processor is not None: - self.restore_default_cross_attention(old_attn_processor) + if old_attn_processors is not None: + unet.set_attn_processor(old_attn_processors) # TODO resuscitate attention map saving # self.remove_attention_map_saving() diff --git a/invokeai/backend/stable_diffusion/schedulers/schedulers.py b/invokeai/backend/stable_diffusion/schedulers/schedulers.py index 
fab28aca8c..08f85cf559 100644 --- a/invokeai/backend/stable_diffusion/schedulers/schedulers.py +++ b/invokeai/backend/stable_diffusion/schedulers/schedulers.py @@ -9,7 +9,8 @@ SCHEDULER_MAP = dict( deis=(DEISMultistepScheduler, dict()), lms=(LMSDiscreteScheduler, dict()), pndm=(PNDMScheduler, dict()), - heun=(HeunDiscreteScheduler, dict()), + heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)), + heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)), euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)), euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)), euler_a=(EulerAncestralDiscreteScheduler, dict()), diff --git a/invokeai/backend/training/textual_inversion_training.py b/invokeai/backend/training/textual_inversion_training.py index d2d994906a..8c27a6e718 100644 --- a/invokeai/backend/training/textual_inversion_training.py +++ b/invokeai/backend/training/textual_inversion_training.py @@ -7,7 +7,6 @@ This is the backend to "textual_inversion.py" """ -import argparse import logging import math import os @@ -47,8 +46,7 @@ from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer # invokeai stuff -from ..args import ArgFormatter, PagingArgumentParser -from ..globals import Globals, global_cache_dir +from invokeai.app.services.config import InvokeAIAppConfig,PagingArgumentParser if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): PIL_INTERPOLATION = { @@ -90,8 +88,9 @@ def save_progress( def parse_args(): + config = InvokeAIAppConfig(argv=[]) parser = PagingArgumentParser( - description="Textual inversion training", formatter_class=ArgFormatter + description="Textual inversion training" ) general_group = parser.add_argument_group("General") model_group = parser.add_argument_group("Models and Paths") @@ -112,7 +111,7 @@ def parse_args(): "--root_dir", "--root", type=Path, - default=Globals.root, + default=config.root, help="Path to the invokeai runtime directory", ) general_group.add_argument( @@ -127,7 +126,7 @@ def parse_args(): general_group.add_argument( "--output_dir", type=Path, - default=f"{Globals.root}/text-inversion-model", + default=f"{config.root}/text-inversion-model", help="The output directory where the model predictions and checkpoints will be written.", ) model_group.add_argument( @@ -528,6 +527,7 @@ def get_full_repo_name( def do_textual_inversion_training( + config: InvokeAIAppConfig, model: str, train_data_dir: Path, output_dir: Path, @@ -580,7 +580,7 @@ def do_textual_inversion_training( # setting up things the way invokeai expects them if not os.path.isabs(output_dir): - output_dir = os.path.join(Globals.root, output_dir) + output_dir = os.path.join(config.root, output_dir) logging_dir = output_dir / logging_dir @@ -628,7 +628,7 @@ def do_textual_inversion_training( elif output_dir is not None: os.makedirs(output_dir, exist_ok=True) - models_conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml")) + models_conf = OmegaConf.load(config.model_conf_path) model_conf = models_conf.get(model, None) assert model_conf is not None, f"Unknown model: {model}" assert ( @@ -640,7 +640,7 @@ def do_textual_inversion_training( assert ( pretrained_model_name_or_path ), f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}" - pipeline_args = dict(cache_dir=global_cache_dir("hub")) + pipeline_args = dict(cache_dir=config.cache_dir) # Load tokenizer if tokenizer_name: diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index 
c70a43ff09..c6c0819df8 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -4,17 +4,16 @@ from contextlib import nullcontext import torch from torch import autocast - -from invokeai.backend.globals import Globals +from invokeai.app.services.config import get_invokeai_config CPU_DEVICE = torch.device("cpu") CUDA_DEVICE = torch.device("cuda") MPS_DEVICE = torch.device("mps") - def choose_torch_device() -> torch.device: """Convenience routine for guessing which GPU device to run model on""" - if Globals.always_use_cpu: + config = get_invokeai_config() + if config.always_use_cpu: return CPU_DEVICE if torch.cuda.is_available(): return torch.device("cuda") @@ -33,7 +32,8 @@ def choose_precision(device: torch.device) -> str: def torch_dtype(device: torch.device) -> torch.dtype: - if Globals.full_precision: + config = get_invokeai_config() + if config.full_precision: return torch.float32 if choose_precision(device) == "float16": return torch.float16 diff --git a/invokeai/backend/util/logging.py b/invokeai/backend/util/logging.py index 73f980aeff..3822ccafbe 100644 --- a/invokeai/backend/util/logging.py +++ b/invokeai/backend/util/logging.py @@ -2,34 +2,37 @@ """invokeai.util.logging -Logging class for InvokeAI that produces console messages that follow -the conventions established in InvokeAI 1.X through 2.X. +Logging class for InvokeAI that produces console messages - -One way to use it: +Usage: from invokeai.backend.util.logging import InvokeAILogger -logger = InvokeAILogger.getLogger(__name__) -logger.critical('this is critical') -logger.error('this is an error') -logger.warning('this is a warning') -logger.info('this is info') -logger.debug('this is debugging') +logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization +(or) +logger = InvokeAILogger.getLogger(__name__) // To use the filename + +logger.critical('this is critical') // Critical Message +logger.error('this is an error') // Error Message +logger.warning('this is a warning') // Warning Message +logger.info('this is info') // Info Message +logger.debug('this is debugging') // Debug Message Console messages: - ### this is critical - *** this is an error *** - ** this is a warning - >> this is info - | this is debugging + [12-05-2023 20]::[InvokeAI]::CRITICAL --> This is an info message [In Bold Red] + [12-05-2023 20]::[InvokeAI]::ERROR --> This is an info message [In Red] + [12-05-2023 20]::[InvokeAI]::WARNING --> This is an info message [In Yellow] + [12-05-2023 20]::[InvokeAI]::INFO --> This is an info message [In Grey] + [12-05-2023 20]::[InvokeAI]::DEBUG --> This is an info message [In Grey] -Another way: -import invokeai.backend.util.logging as ialog -ialogger.debug('this is a debugging message') +Alternate Method (in this case the logger name will be set to InvokeAI): +import invokeai.backend.util.logging as IAILogger +IAILogger.debug('this is a debugging message') """ + import logging + # module level functions def debug(msg, *args, **kwargs): InvokeAILogger.getLogger().debug(msg, *args, **kwargs) @@ -42,7 +45,7 @@ def warning(msg, *args, **kwargs): def error(msg, *args, **kwargs): InvokeAILogger.getLogger().error(msg, *args, **kwargs) - + def critical(msg, *args, **kwargs): InvokeAILogger.getLogger().critical(msg, *args, **kwargs) @@ -55,49 +58,47 @@ def disable(level=logging.CRITICAL): def basicConfig(**kwargs): InvokeAILogger.getLogger().basicConfig(**kwargs) -def getLogger(name: str=None)->logging.Logger: +def getLogger(name: str = None) -> logging.Logger: return 
InvokeAILogger.getLogger(name) + class InvokeAILogFormatter(logging.Formatter): ''' - Repurposed from: - https://stackoverflow.com/questions/14844970/modifying-logging-message-format-based-on-message-logging-level-in-python3 + Custom Formatting for the InvokeAI Logger ''' - crit_fmt = "### %(msg)s" - err_fmt = "*** %(msg)s" - warn_fmt = "** %(msg)s" - info_fmt = ">> %(msg)s" - dbg_fmt = " | %(msg)s" - def __init__(self): - super().__init__(fmt="%(levelno)d: %(msg)s", datefmt=None, style='%') + # Color Codes + grey = "\x1b[38;20m" + yellow = "\x1b[33;20m" + red = "\x1b[31;20m" + cyan = "\x1b[36;20m" + bold_red = "\x1b[31;1m" + reset = "\x1b[0m" + + # Log Format + format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s" + ## More Formatting Options: %(pathname)s, %(filename)s, %(module)s, %(lineno)d + + # Format Map + FORMATS = { + logging.DEBUG: cyan + format + reset, + logging.INFO: grey + format + reset, + logging.WARNING: yellow + format + reset, + logging.ERROR: red + format + reset, + logging.CRITICAL: bold_red + format + reset + } def format(self, record): - # Remember the format used when the logging module - # was installed (in the event that this formatter is - # used with the vanilla logging module. - format_orig = self._style._fmt - if record.levelno == logging.DEBUG: - self._style._fmt = InvokeAILogFormatter.dbg_fmt - if record.levelno == logging.INFO: - self._style._fmt = InvokeAILogFormatter.info_fmt - if record.levelno == logging.WARNING: - self._style._fmt = InvokeAILogFormatter.warn_fmt - if record.levelno == logging.ERROR: - self._style._fmt = InvokeAILogFormatter.err_fmt - if record.levelno == logging.CRITICAL: - self._style._fmt = InvokeAILogFormatter.crit_fmt + log_fmt = self.FORMATS.get(record.levelno) + formatter = logging.Formatter(log_fmt, datefmt="%d-%m-%Y %H:%M:%S") + return formatter.format(record) - # parent class does the work - result = super().format(record) - self._style._fmt = format_orig - return result class InvokeAILogger(object): loggers = dict() - + @classmethod - def getLogger(self, name:str='invokeai')->logging.Logger: + def getLogger(self, name: str = 'InvokeAI') -> logging.Logger: if name not in self.loggers: logger = logging.getLogger(name) logger.setLevel(logging.DEBUG) diff --git a/invokeai/backend/web/modules/parameters.py b/invokeai/backend/web/modules/parameters.py index 72211857a3..9a4bc0aec3 100644 --- a/invokeai/backend/web/modules/parameters.py +++ b/invokeai/backend/web/modules/parameters.py @@ -9,6 +9,7 @@ SAMPLER_CHOICES = [ "lms", "pndm", "heun", + 'heun_k', "euler", "euler_k", "euler_a", diff --git a/invokeai/frontend/CLI/CLI.py b/invokeai/frontend/CLI/CLI.py deleted file mode 100644 index 8525853e93..0000000000 --- a/invokeai/frontend/CLI/CLI.py +++ /dev/null @@ -1,1286 +0,0 @@ -import os -import re -import shlex -import sys -import traceback -from argparse import Namespace -from pathlib import Path -from typing import Union - -import click -from compel import PromptParser - -if sys.platform == "darwin": - os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" - -import pyparsing # type: ignore - -import invokeai.version as invokeai -import invokeai.backend.util.logging as logger - -from ...backend import Generate, ModelManager -from ...backend.args import Args, dream_cmd_from_png, metadata_dumps, metadata_from_png -from ...backend.globals import Globals, global_config_dir -from ...backend.image_util import ( - PngWriter, - make_grid, - retrieve_metadata, - write_metadata, -) -from ...backend.stable_diffusion import 
PipelineIntermediateState -from ...backend.util import url_attachment_name, write_log -from .readline import Completer, get_completer - -# global used in multiple functions (fix) -infile = None - - -def main(): - """Initialize command-line parsers and the diffusion model""" - global infile - - opt = Args() - args = opt.parse_args() - if not args: - sys.exit(-1) - - if args.laion400m: - print( - "--laion400m flag has been deprecated. Please use --model laion400m instead." - ) - sys.exit(-1) - if args.weights: - print( - "--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead." - ) - sys.exit(-1) - - # alert - setting a few globals here - Globals.try_patchmatch = args.patchmatch - Globals.always_use_cpu = args.always_use_cpu - Globals.internet_available = args.internet_available and check_internet() - Globals.disable_xformers = not args.xformers - Globals.sequential_guidance = args.sequential_guidance - Globals.ckpt_convert = True # always true now - - # run any post-install patches needed - run_patches() - - logger.info(f"Internet connectivity is {Globals.internet_available}") - - if not args.conf: - config_file = os.path.join(Globals.root, "configs", "models.yaml") - if not os.path.exists(config_file): - report_model_error( - opt, FileNotFoundError(f"The file {config_file} could not be found.") - ) - - logger.info(f"{invokeai.__app_name__}, version {invokeai.__version__}") - logger.info(f'InvokeAI runtime directory is "{Globals.root}"') - - # loading here to avoid long delays on startup - # these two lines prevent a horrible warning message from appearing - # when the frozen CLIP tokenizer is imported - import transformers # type: ignore - - transformers.logging.set_verbosity_error() - import diffusers - - diffusers.logging.set_verbosity_error() - - # Loading Face Restoration and ESRGAN Modules - gfpgan, codeformer, esrgan = load_face_restoration(opt) - - # normalize the config directory relative to root - if not os.path.isabs(opt.conf): - opt.conf = os.path.normpath(os.path.join(Globals.root, opt.conf)) - - if opt.embeddings: - if not os.path.isabs(opt.embedding_path): - embedding_path = os.path.normpath( - os.path.join(Globals.root, opt.embedding_path) - ) - else: - embedding_path = opt.embedding_path - else: - embedding_path = None - - # load the infile as a list of lines - if opt.infile: - try: - if os.path.isfile(opt.infile): - infile = open(opt.infile, "r", encoding="utf-8") - elif opt.infile == "-": # stdin - infile = sys.stdin - else: - raise FileNotFoundError(f"{opt.infile} not found.") - except (FileNotFoundError, IOError) as e: - logger.critical('Aborted',exc_info=True) - sys.exit(-1) - - # creating a Generate object: - try: - gen = Generate( - conf=opt.conf, - model=opt.model, - sampler_name=opt.sampler_name, - embedding_path=embedding_path, - full_precision=opt.full_precision, - precision=opt.precision, - gfpgan=gfpgan, - codeformer=codeformer, - esrgan=esrgan, - free_gpu_mem=opt.free_gpu_mem, - safety_checker=opt.safety_checker, - max_cache_size=opt.max_cache_size, - ) - except (FileNotFoundError, TypeError, AssertionError) as e: - report_model_error(opt, e) - except (IOError, KeyError): - logger.critical("Aborted",exc_info=True) - sys.exit(-1) - - if opt.seamless: - logger.info("Changed to seamless tiling mode") - - # preload the model - try: - gen.load_model() - except KeyError: - pass - except Exception as e: - report_model_error(opt, e) - - # try to autoconvert new models - if path := opt.autoconvert: - 
gen.model_manager.heuristic_import( - str(path), commit_to_conf=opt.conf - ) - - # web server loops forever - if opt.web or opt.gui: - invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan) - sys.exit(0) - - if not infile: - print( - "\n* Initialization done! Awaiting your command (-h for help, 'q' to quit)" - ) - - try: - main_loop(gen, opt) - except KeyboardInterrupt: - print( - f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}' - ) - except Exception: - logger.error("An error occurred",exc_info=True) - -# TODO: main_loop() has gotten busy. Needs to be refactored. -def main_loop(gen, opt): - """prompt/read/execute loop""" - global infile - done = False - doneAfterInFile = infile is not None - path_filter = re.compile(r'[<>:"/\\|?*]') - last_results = list() - - # The readline completer reads history from the .dream_history file located in the - # output directory specified at the time of script launch. We do not currently support - # changing the history file midstream when the output directory is changed. - completer = get_completer(opt, models=gen.model_manager.list_models()) - set_default_output_dir(opt, completer) - if gen.model_context: - add_embedding_terms(gen, completer) - output_cntr = completer.get_current_history_length() + 1 - - # os.pathconf is not available on Windows - if hasattr(os, "pathconf"): - path_max = os.pathconf(opt.outdir, "PC_PATH_MAX") - name_max = os.pathconf(opt.outdir, "PC_NAME_MAX") - else: - path_max = 260 - name_max = 255 - - while not done: - operation = "generate" - - try: - command = get_next_command(infile, gen.model_name) - except EOFError: - done = infile is None or doneAfterInFile - infile = None - continue - - # skip empty lines - if not command.strip(): - continue - - if command.startswith(("#", "//")): - continue - - if len(command.strip()) == 1 and command.startswith("q"): - done = True - break - - if not command.startswith("!history"): - completer.add_history(command) - - if command.startswith("!"): - command, operation = do_command(command, gen, opt, completer) - - if operation is None: - continue - - if opt.parse_cmd(command) is None: - continue - - if opt.init_img: - try: - if not opt.prompt: - oldargs = metadata_from_png(opt.init_img) - opt.prompt = oldargs.prompt - logger.info(f'Retrieved old prompt "{opt.prompt}" from {opt.init_img}') - except (OSError, AttributeError, KeyError): - pass - - if len(opt.prompt) == 0: - opt.prompt = "" - - # width and height are set by model if not specified - if not opt.width: - opt.width = gen.width - if not opt.height: - opt.height = gen.height - - # retrieve previous value of init image if requested - if opt.init_img is not None and re.match("^-\\d+$", opt.init_img): - try: - opt.init_img = last_results[int(opt.init_img)][0] - logger.info(f"Reusing previous image {opt.init_img}") - except IndexError: - logger.info(f"No previous initial image at position {opt.init_img} found") - opt.init_img = None - continue - - # the outdir can change with each command, so we adjust it here - set_default_output_dir(opt, completer) - - # try to relativize pathnames - for attr in ("init_img", "init_mask", "init_color"): - if getattr(opt, attr) and not os.path.exists(getattr(opt, attr)): - basename = getattr(opt, attr) - path = os.path.join(opt.outdir, basename) - setattr(opt, attr, path) - - # retrieve previous value of seed if requested - # Exception: for postprocess operations negative seed values - # mean "discard the original seed and generate a new 
one" - # (this is a non-obvious hack and needs to be reworked) - if opt.seed is not None and opt.seed < 0 and operation != "postprocess": - try: - opt.seed = last_results[opt.seed][1] - logger.info(f"Reusing previous seed {opt.seed}") - except IndexError: - logger.info(f"No previous seed at position {opt.seed} found") - opt.seed = None - continue - - if opt.strength is None: - opt.strength = 0.75 if opt.out_direction is None else 0.83 - - if opt.with_variations is not None: - opt.with_variations = split_variations(opt.with_variations) - - if opt.prompt_as_dir and operation == "generate": - # sanitize the prompt to a valid folder name - subdir = path_filter.sub("_", opt.prompt)[:name_max].rstrip(" .") - - # truncate path to maximum allowed length - # 39 is the length of '######.##########.##########-##.png', plus two separators and a NUL - subdir = subdir[: (path_max - 39 - len(os.path.abspath(opt.outdir)))] - current_outdir = os.path.join(opt.outdir, subdir) - - logger.info('Writing files to directory: "' + current_outdir + '"') - - # make sure the output directory exists - if not os.path.exists(current_outdir): - os.makedirs(current_outdir) - else: - if not os.path.exists(opt.outdir): - os.makedirs(opt.outdir) - current_outdir = opt.outdir - - # Here is where the images are actually generated! - last_results = [] - try: - file_writer = PngWriter(current_outdir) - results = [] # list of filename, prompt pairs - grid_images = dict() # seed -> Image, only used if `opt.grid` - prior_variations = opt.with_variations or [] - prefix = file_writer.unique_prefix() - step_callback = ( - make_step_callback(gen, opt, prefix) - if opt.save_intermediates > 0 - else None - ) - - def image_writer( - image, - seed, - upscaled=False, - first_seed=None, - use_prefix=None, - prompt_in=None, - attention_maps_image=None, - ): - # note the seed is the seed of the current image - # the first_seed is the original seed that noise is added to - # when the -v switch is used to generate variations - nonlocal prior_variations - nonlocal prefix - - path = None - if opt.grid: - grid_images[seed] = image - - elif operation == "mask": - filename = f"{prefix}.{use_prefix}.{seed}.png" - tm = opt.text_mask[0] - th = opt.text_mask[1] if len(opt.text_mask) > 1 else 0.5 - formatted_dream_prompt = ( - f"!mask {opt.input_file_path} -tm {tm} {th}" - ) - path = file_writer.save_image_and_prompt_to_png( - image=image, - dream_prompt=formatted_dream_prompt, - metadata={}, - name=filename, - compress_level=opt.png_compression, - ) - results.append([path, formatted_dream_prompt]) - - else: - if use_prefix is not None: - prefix = use_prefix - postprocessed = upscaled if upscaled else operation == "postprocess" - opt.prompt = ( - gen.huggingface_concepts_library.replace_triggers_with_concepts( - opt.prompt or prompt_in - ) - ) # to avoid the problem of non-unique concept triggers - filename, formatted_dream_prompt = prepare_image_metadata( - opt, - prefix, - seed, - operation, - prior_variations, - postprocessed, - first_seed, - ) - path = file_writer.save_image_and_prompt_to_png( - image=image, - dream_prompt=formatted_dream_prompt, - metadata=metadata_dumps( - opt, - seeds=[ - seed - if opt.variation_amount == 0 - and len(prior_variations) == 0 - else first_seed - ], - model_hash=gen.model_hash, - ), - name=filename, - compress_level=opt.png_compression, - ) - - # update rfc metadata - if operation == "postprocess": - tool = re.match( - "postprocess:(\w+)", opt.last_operation - ).groups()[0] - add_postprocessing_to_metadata( - opt, - 
opt.input_file_path, - filename, - tool, - formatted_dream_prompt, - ) - - if (not postprocessed) or opt.save_original: - # only append to results if we didn't overwrite an earlier output - results.append([path, formatted_dream_prompt]) - - # so that the seed autocompletes (on linux|mac when -S or --seed specified - if completer and operation == "generate": - completer.add_seed(seed) - completer.add_seed(first_seed) - last_results.append([path, seed]) - - if operation == "generate": - catch_ctrl_c = ( - infile is None - ) # if running interactively, we catch keyboard interrupts - opt.last_operation = "generate" - try: - gen.prompt2image( - image_callback=image_writer, - step_callback=step_callback, - catch_interrupts=catch_ctrl_c, - **vars(opt), - ) - except (PromptParser.ParsingException, pyparsing.ParseException): - logger.error("An error occurred while processing your prompt",exc_info=True) - elif operation == "postprocess": - logger.info(f"fixing {opt.prompt}") - opt.last_operation = do_postprocess(gen, opt, image_writer) - - elif operation == "mask": - logger.info(f"generating masks from {opt.prompt}") - do_textmask(gen, opt, image_writer) - - if opt.grid and len(grid_images) > 0: - grid_img = make_grid(list(grid_images.values())) - grid_seeds = list(grid_images.keys()) - first_seed = last_results[0][1] - filename = f"{prefix}.{first_seed}.png" - formatted_dream_prompt = opt.dream_prompt_str( - seed=first_seed, grid=True, iterations=len(grid_images) - ) - formatted_dream_prompt += f" # {grid_seeds}" - metadata = metadata_dumps( - opt, seeds=grid_seeds, model_hash=gen.model_hash - ) - path = file_writer.save_image_and_prompt_to_png( - image=grid_img, - dream_prompt=formatted_dream_prompt, - metadata=metadata, - name=filename, - ) - results = [[path, formatted_dream_prompt]] - - except AssertionError: - logger.error(e) - continue - - except OSError as e: - logger.error(e) - continue - - print("Outputs:") - log_path = os.path.join(current_outdir, "invoke_log") - output_cntr = write_log(results, log_path, ("txt", "md"), output_cntr) - print() - - print( - f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}' - ) - - -# TO DO: remove repetitive code and the awkward command.replace() trope -# Just do a simple parse of the command! 
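# --- Editor's illustrative sketch, not part of this changeset ---------------
# The TO DO note just above asks for "a simple parse of the command" instead of
# the long if/elif chain in do_command() below. A hedged, table-driven sketch of
# that idea follows; parse_bang_command(), HANDLERS and do_command_simple() are
# hypothetical names, not functions that exist in InvokeAI.
import shlex

def parse_bang_command(command: str):
    """Split '!fix foo.png -U 2' into ('fix', ['foo.png', '-U', '2'])."""
    tokens = shlex.split(command)
    if not tokens:
        return "", []
    return tokens[0].lstrip("!"), tokens[1:]

# verb -> operation mapping replaces the repeated command.startswith() checks
HANDLERS = {
    "dream": "generate",
    "fix": "postprocess",
    "mask": "mask",
}

def do_command_simple(command: str):
    verb, args = parse_bang_command(command)
    operation = HANDLERS.get(verb)
    if operation is None:
        return "-h", "generate"  # unrecognized verb: fall back to help, like the original
    return " ".join(args), operation
# ---------------------------------------------------------------------------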
-def do_command(command: str, gen, opt: Args, completer) -> tuple: - global infile - operation = "generate" # default operation, alternative is 'postprocess' - command = command.replace("\\", "/") # windows - - if command.startswith( - "!dream" - ): # in case a stored prompt still contains the !dream command - command = command.replace("!dream ", "", 1) - - elif command.startswith("!fix"): - command = command.replace("!fix ", "", 1) - operation = "postprocess" - - elif command.startswith("!mask"): - command = command.replace("!mask ", "", 1) - operation = "mask" - - elif command.startswith("!switch"): - model_name = command.replace("!switch ", "", 1) - try: - gen.set_model(model_name) - add_embedding_terms(gen, completer) - except KeyError as e: - logger.error(e) - except Exception as e: - report_model_error(opt, e) - completer.add_history(command) - operation = None - - elif command.startswith("!models"): - gen.model_manager.print_models() - completer.add_history(command) - operation = None - - elif command.startswith("!import"): - path = shlex.split(command) - if len(path) < 2: - logger.warning( - "please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1" - ) - else: - try: - import_model(path[1], gen, opt, completer) - completer.add_history(command) - except KeyboardInterrupt: - print("\n") - operation = None - - elif command.startswith(("!convert", "!optimize")): - path = shlex.split(command) - if len(path) < 2: - logger.warning("please provide the path to a .ckpt or .safetensors model") - else: - try: - convert_model(path[1], gen, opt, completer) - completer.add_history(command) - except KeyboardInterrupt: - print("\n") - operation = None - - elif command.startswith("!edit"): - path = shlex.split(command) - if len(path) < 2: - logger.warning("please provide the name of a model") - else: - edit_model(path[1], gen, opt, completer) - completer.add_history(command) - operation = None - - elif command.startswith("!del"): - path = shlex.split(command) - if len(path) < 2: - logger.warning("please provide the name of a model") - else: - del_config(path[1], gen, opt, completer) - completer.add_history(command) - operation = None - - elif command.startswith("!fetch"): - file_path = command.replace("!fetch", "", 1).strip() - retrieve_dream_command(opt, file_path, completer) - completer.add_history(command) - operation = None - - elif command.startswith("!replay"): - file_path = command.replace("!replay", "", 1).strip() - file_path = os.path.join(opt.outdir, file_path) - if infile is None and os.path.isfile(file_path): - infile = open(file_path, "r", encoding="utf-8") - completer.add_history(command) - operation = None - - elif command.startswith("!trigger"): - print("Embedding trigger strings: ", ", ".join(gen.embedding_trigger_strings)) - operation = None - - elif command.startswith("!history"): - completer.show_history() - operation = None - - elif command.startswith("!search"): - search_str = command.replace("!search", "", 1).strip() - completer.show_history(search_str) - operation = None - - elif command.startswith("!clear"): - completer.clear_history() - operation = None - - elif re.match("^!(\d+)", command): - command_no = re.match("^!(\d+)", command).groups()[0] - command = completer.get_line(int(command_no)) - completer.set_line(command) - operation = None - - else: # not a recognized command, so give the --help text - command = "-h" - return command, operation - - -def 
set_default_output_dir(opt: Args, completer: Completer): - """ - If opt.outdir is relative, we add the root directory to it - normalize the outdir relative to root and make sure it exists. - """ - if not os.path.isabs(opt.outdir): - opt.outdir = os.path.normpath(os.path.join(Globals.root, opt.outdir)) - if not os.path.exists(opt.outdir): - os.makedirs(opt.outdir) - completer.set_default_dir(opt.outdir) - - -def import_model(model_path: str, gen, opt, completer): - """ - model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; - (3) a huggingface repository id; or (4) a local directory containing a - diffusers model. - """ - default_name = Path(model_path).stem - model_name = None - model_desc = None - - if ( - Path(model_path).is_dir() - and not (Path(model_path) / "model_index.json").exists() - ): - pass - else: - if model_path.startswith(("http:", "https:")): - try: - default_name = url_attachment_name(model_path) - default_name = Path(default_name).stem - except Exception: - logger.warning(f"A problem occurred while assigning the name of the downloaded model",exc_info=True) - model_name, model_desc = _get_model_name_and_desc( - gen.model_manager, - completer, - model_name=default_name, - ) - imported_name = gen.model_manager.heuristic_import( - model_path, - model_name=model_name, - description=model_desc, - ) - - if not imported_name: - if config_file := _pick_configuration_file(completer): - imported_name = gen.model_manager.heuristic_import( - model_path, - model_name=model_name, - description=model_desc, - model_config_file=config_file, - ) - if not imported_name: - logger.error("Aborting import.") - return - - if not _verify_load(imported_name, gen): - logger.error("model failed to load. Discarding configuration entry") - gen.model_manager.del_model(imported_name) - return - if click.confirm("Make this the default model?", default=False): - gen.model_manager.set_default_model(imported_name) - - gen.model_manager.commit(opt.conf) - completer.update_models(gen.model_manager.list_models()) - logger.info(f"{imported_name} successfully installed") - -def _pick_configuration_file(completer)->Path: - print( -""" -Please select the type of this model: -[1] A Stable Diffusion v1.x ckpt/safetensors model -[2] A Stable Diffusion v1.x inpainting ckpt/safetensors model -[3] A Stable Diffusion v2.x base model (512 pixels) -[4] A Stable Diffusion v2.x v-predictive model (768 pixels) -[5] Other (you will be prompted to enter the config file path) -[Q] I have no idea! Skip the import. 
-""") - choices = [ - global_config_dir() / 'stable-diffusion' / x - for x in [ - 'v1-inference.yaml', - 'v1-inpainting-inference.yaml', - 'v2-inference.yaml', - 'v2-inference-v.yaml', - ] - ] - - ok = False - while not ok: - try: - choice = input('select 0-5, Q > ').strip() - if choice.startswith(('q','Q')): - return - if choice == '5': - completer.complete_extensions(('.yaml')) - choice = Path(input('Select config file for this model> ').strip()).absolute() - completer.complete_extensions(None) - ok = choice.exists() - else: - choice = choices[int(choice)-1] - ok = True - except (ValueError, IndexError): - print(f'{choice} is not a valid choice') - except EOFError: - return - return choice - -def _verify_load(model_name: str, gen) -> bool: - logger.info("Verifying that new model loads...") - current_model = gen.model_name - try: - if not gen.set_model(model_name): - return - except Exception as e: - logger.warning(f"model failed to load: {str(e)}") - logger.warning( - "** note that importing 2.X checkpoints is not supported. Please use !convert_model instead." - ) - return False - if click.confirm("Keep model loaded?", default=True): - gen.set_model(model_name) - else: - logger.info("Restoring previous model") - gen.set_model(current_model) - return True - - -def _get_model_name_and_desc( - model_manager, completer, model_name: str = "", model_description: str = "" -): - model_name = _get_model_name(model_manager.list_models(), completer, model_name) - model_description = model_description or f"Imported model {model_name}" - completer.set_line(model_description) - model_description = ( - input(f"Description for this model [{model_description}]: ").strip() - or model_description - ) - return model_name, model_description - -def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer): - model_name_or_path = model_name_or_path.replace("\\", "/") # windows - manager = gen.model_manager - ckpt_path = None - original_config_file = None - if model_name_or_path == gen.model_name: - logger.warning("Can't convert the active model. !switch to another model first. **") - return - elif model_info := manager.model_info(model_name_or_path): - if "weights" in model_info: - ckpt_path = Path(model_info["weights"]) - original_config_file = Path(model_info["config"]) - model_name = model_name_or_path - model_description = model_info["description"] - vae_path = model_info.get("vae") - else: - logger.warning(f"{model_name_or_path} is not a legacy .ckpt weights file") - return - model_name = manager.convert_and_import( - ckpt_path, - diffusers_path=Path( - Globals.root, "models", Globals.converted_ckpts_dir, model_name_or_path - ), - model_name=model_name, - model_description=model_description, - original_config_file=original_config_file, - vae_path=vae_path, - ) - else: - try: - import_model(model_name_or_path, gen, opt, completer) - except KeyboardInterrupt: - return - - manager.commit(opt.conf) - if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False): - ckpt_path.unlink(missing_ok=True) - logger.warning(f"{ckpt_path} deleted") - - -def del_config(model_name: str, gen, opt, completer): - current_model = gen.model_name - if model_name == current_model: - logger.warning("Can't delete active model. !switch to another model first. 
**") - return - if model_name not in gen.model_manager.config: - logger.warning(f"Unknown model {model_name}") - return - - if not click.confirm( - f"Remove {model_name} from the list of models known to InvokeAI?", default=True - ): - return - - delete_completely = click.confirm( - "Completely remove the model file or directory from disk?", default=False - ) - gen.model_manager.del_model(model_name, delete_files=delete_completely) - gen.model_manager.commit(opt.conf) - logger.warning(f"{model_name} deleted") - completer.update_models(gen.model_manager.list_models()) - - -def edit_model(model_name: str, gen, opt, completer): - manager = gen.model_manager - if not (info := manager.model_info(model_name)): - logger.warning(f"** Unknown model {model_name}") - return - print() - logger.info(f"Editing model {model_name} from configuration file {opt.conf}") - new_name = _get_model_name(manager.list_models(), completer, model_name) - - for attribute in info.keys(): - if type(info[attribute]) != str: - continue - if attribute == "format": - continue - completer.set_line(info[attribute]) - info[attribute] = input(f"{attribute}: ") or info[attribute] - - if info["format"] == "diffusers": - vae = info.get("vae", dict(repo_id=None, path=None, subfolder=None)) - completer.set_line(vae.get("repo_id") or "stabilityai/sd-vae-ft-mse") - vae["repo_id"] = input("External VAE repo_id: ").strip() or None - if not vae["repo_id"]: - completer.set_line(vae.get("path") or "") - vae["path"] = ( - input("Path to a local diffusers VAE model (usually none): ").strip() - or None - ) - completer.set_line(vae.get("subfolder") or "") - vae["subfolder"] = ( - input("Name of subfolder containing the VAE model (usually none): ").strip() - or None - ) - info["vae"] = vae - - if new_name != model_name: - manager.del_model(model_name) - - # this does the update - manager.add_model(new_name, info, True) - - if click.confirm("Make this the default model?", default=False): - manager.set_default_model(new_name) - manager.commit(opt.conf) - completer.update_models(manager.list_models()) - logger.info("Model successfully updated") - - -def _get_model_name(existing_names, completer, default_name: str = "") -> str: - done = False - completer.set_line(default_name) - while not done: - model_name = input(f"Short name for this model [{default_name}]: ").strip() - if len(model_name) == 0: - model_name = default_name - if not re.match("^[\w._+:/-]+$", model_name): - logger.warning( - 'model name must contain only words, digits and the characters "._+:/-" **' - ) - elif model_name != default_name and model_name in existing_names: - logger.warning(f"the name {model_name} is already in use. Pick another.") - else: - done = True - return model_name - - -def do_textmask(gen, opt, callback): - image_path = opt.prompt - if not os.path.exists(image_path): - image_path = os.path.join(opt.outdir, image_path) - assert os.path.exists( - image_path - ), '** "{opt.prompt}" not found. 
Please enter the name of an existing image file to mask **' - assert ( - opt.text_mask is not None and len(opt.text_mask) >= 1 - ), "** Please provide a text mask with -tm **" - opt.input_file_path = image_path - tm = opt.text_mask[0] - threshold = float(opt.text_mask[1]) if len(opt.text_mask) > 1 else 0.5 - gen.apply_textmask( - image_path=image_path, - prompt=tm, - threshold=threshold, - callback=callback, - ) - - -def do_postprocess(gen, opt, callback): - file_path = opt.prompt # treat the prompt as the file pathname - if opt.new_prompt is not None: - opt.prompt = opt.new_prompt - else: - opt.prompt = None - - if os.path.dirname(file_path) == "": # basename given - file_path = os.path.join(opt.outdir, file_path) - - opt.input_file_path = file_path - - tool = None - if opt.facetool_strength > 0: - tool = opt.facetool - elif opt.embiggen: - tool = "embiggen" - elif opt.upscale: - tool = "upscale" - elif opt.out_direction: - tool = "outpaint" - elif opt.outcrop: - tool = "outcrop" - opt.save_original = True # do not overwrite old image! - opt.last_operation = f"postprocess:{tool}" - try: - gen.apply_postprocessor( - image_path=file_path, - tool=tool, - facetool_strength=opt.facetool_strength, - codeformer_fidelity=opt.codeformer_fidelity, - save_original=opt.save_original, - upscale=opt.upscale, - upscale_denoise_str=opt.esrgan_denoise_str, - out_direction=opt.out_direction, - outcrop=opt.outcrop, - callback=callback, - opt=opt, - ) - except OSError: - logger.error(f"{file_path}: file could not be read",exc_info=True) - return - except (KeyError, AttributeError): - logger.error(f"an error occurred while applying the {tool} postprocessor",exc_info=True) - return - return opt.last_operation - - -def add_postprocessing_to_metadata(opt, original_file, new_file, tool, command): - original_file = ( - original_file - if os.path.exists(original_file) - else os.path.join(opt.outdir, original_file) - ) - new_file = ( - new_file if os.path.exists(new_file) else os.path.join(opt.outdir, new_file) - ) - try: - meta = retrieve_metadata(original_file)["sd-metadata"] - except AttributeError: - try: - meta = retrieve_metadata(new_file)["sd-metadata"] - except AttributeError: - meta = {} - - if "image" not in meta: - meta = metadata_dumps(opt, seeds=[opt.seed])["image"] - meta["image"] = {} - img_data = meta.get("image") - pp = img_data.get("postprocessing", []) or [] - pp.append( - { - "tool": tool, - "dream_command": command, - } - ) - meta["image"]["postprocessing"] = pp - write_metadata(new_file, meta) - - -def prepare_image_metadata( - opt, - prefix, - seed, - operation="generate", - prior_variations=[], - postprocessed=False, - first_seed=None, -): - if postprocessed and opt.save_original: - filename = choose_postprocess_name(opt, prefix, seed) - else: - wildcards = dict(opt.__dict__) - wildcards["prefix"] = prefix - wildcards["seed"] = seed - try: - filename = opt.fnformat.format(**wildcards) - except KeyError as e: - logger.error( - f"The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead" - ) - filename = f"{prefix}.{seed}.png" - except IndexError: - logger.error( - "The filename format is broken or complete. 
Will use '{prefix}.{seed}.png' instead" - ) - filename = f"{prefix}.{seed}.png" - - if opt.variation_amount > 0: - first_seed = first_seed or seed - this_variation = [[seed, opt.variation_amount]] - opt.with_variations = prior_variations + this_variation - formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed) - elif len(prior_variations) > 0: - formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed) - elif operation == "postprocess": - formatted_dream_prompt = "!fix " + opt.dream_prompt_str( - seed=seed, prompt=opt.input_file_path - ) - else: - formatted_dream_prompt = opt.dream_prompt_str(seed=seed) - return filename, formatted_dream_prompt - - -def choose_postprocess_name(opt, prefix, seed) -> str: - match = re.search("postprocess:(\w+)", opt.last_operation) - if match: - modifier = match.group( - 1 - ) # will look like "gfpgan", "upscale", "outpaint" or "embiggen" - else: - modifier = "postprocessed" - - counter = 0 - filename = None - available = False - while not available: - if counter == 0: - filename = f"{prefix}.{seed}.{modifier}.png" - else: - filename = f"{prefix}.{seed}.{modifier}-{counter:02d}.png" - available = not os.path.exists(os.path.join(opt.outdir, filename)) - counter += 1 - return filename - - -def get_next_command(infile=None, model_name="no model") -> str: # command string - if infile is None: - command = input(f"({model_name}) invoke> ").strip() - else: - command = infile.readline() - if not command: - raise EOFError - else: - command = command.strip() - if len(command) > 0: - print(f"#{command}") - return command - - -def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan): - print("\n* --web was specified, starting web server...") - from invokeai.backend.web import InvokeAIWebServer - - # Change working directory to the stable-diffusion directory - os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - - invoke_ai_web_server = InvokeAIWebServer( - generate=gen, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan - ) - - try: - invoke_ai_web_server.run() - except KeyboardInterrupt: - pass - - -def add_embedding_terms(gen, completer): - """ - Called after setting the model, updates the autocompleter with - any terms loaded by the embedding manager. - """ - with gen.model_context as model: - trigger_strings = model.textual_inversion_manager.get_all_trigger_strings() - completer.add_embedding_terms(trigger_strings) - - -def split_variations(variations_string) -> list: - # shotgun parsing, woo - parts = [] - broken = False # python doesn't have labeled loops... 
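# --- Editor's illustrative sketch, not part of this changeset ---------------
# The "shotgun parsing" loop around this point tracks a `broken` flag because
# Python has no labeled loops. A hedged alternative is to let ValueError abort
# the whole parse. split_variations_simple() is a hypothetical name, not a
# function that exists in InvokeAI.
from typing import Optional

def split_variations_simple(variations_string: str) -> Optional[list]:
    """Parse 'seed:weight,seed:weight,...' into [[seed, weight], ...] or None."""
    try:
        parts = []
        for part in variations_string.split(","):
            seed, weight = part.split(":")  # ValueError unless exactly two fields
            parts.append([int(seed), float(weight)])
        return parts or None
    except ValueError:
        return None
# ---------------------------------------------------------------------------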
- for part in variations_string.split(","): - seed_and_weight = part.split(":") - if len(seed_and_weight) != 2: - logger.warning(f'Could not parse with_variation part "{part}"') - broken = True - break - try: - seed = int(seed_and_weight[0]) - weight = float(seed_and_weight[1]) - except ValueError: - logger.warning(f'Could not parse with_variation part "{part}"') - broken = True - break - parts.append([seed, weight]) - if broken: - return None - elif len(parts) == 0: - return None - else: - return parts - - -def load_face_restoration(opt): - try: - gfpgan, codeformer, esrgan = None, None, None - if opt.restore or opt.esrgan: - from invokeai.backend.restoration import Restoration - - restoration = Restoration() - if opt.restore: - gfpgan, codeformer = restoration.load_face_restore_models( - opt.gfpgan_model_path - ) - else: - logger.info("Face restoration disabled") - if opt.esrgan: - esrgan = restoration.load_esrgan(opt.esrgan_bg_tile) - else: - logger.info("Upscaling disabled") - else: - logger.info("Face restoration and upscaling disabled") - except (ModuleNotFoundError, ImportError): - print(traceback.format_exc(), file=sys.stderr) - logger.info("You may need to install the ESRGAN and/or GFPGAN modules") - return gfpgan, codeformer, esrgan - - -def make_step_callback(gen, opt, prefix): - destination = os.path.join(opt.outdir, "intermediates", prefix) - os.makedirs(destination, exist_ok=True) - logger.info(f"Intermediate images will be written into {destination}") - - def callback(state: PipelineIntermediateState): - latents = state.latents - step = state.step - if step % opt.save_intermediates == 0 or step == opt.steps - 1: - filename = os.path.join(destination, f"{step:04}.png") - image = gen.sample_to_lowres_estimated_image(latents) - image = image.resize((image.size[0] * 8, image.size[1] * 8)) - image.save(filename, "PNG") - - return callback - - -def retrieve_dream_command(opt, command, completer): - """ - Given a full or partial path to a previously-generated image file, - will retrieve and format the dream command used to generate the image, - and pop it into the readline buffer (linux, Mac), or print out a comment - for cut-and-paste (windows) - - Given a wildcard path to a folder with image png files, - will retrieve and format the dream command used to generate the images, - and save them to a file commands.txt for further processing - """ - if len(command) == 0: - return - - tokens = command.split() - dir, basename = os.path.split(tokens[0]) - if len(dir) == 0: - path = os.path.join(opt.outdir, basename) - else: - path = tokens[0] - - if len(tokens) > 1: - return write_commands(opt, path, tokens[1]) - - cmd = "" - try: - cmd = dream_cmd_from_png(path) - except OSError: - logger.error(f"{tokens[0]}: file could not be read") - except (KeyError, AttributeError, IndexError): - logger.error(f"{tokens[0]}: file has no metadata") - except: - logger.error(f"{tokens[0]}: file could not be processed") - if len(cmd) > 0: - completer.set_line(cmd) - -def write_commands(opt, file_path: str, outfilepath: str): - dir, basename = os.path.split(file_path) - try: - paths = sorted(list(Path(dir).glob(basename))) - except ValueError: - logger.error(f'"{basename}": unacceptable pattern') - return - - commands = [] - cmd = None - for path in paths: - try: - cmd = dream_cmd_from_png(path) - except (KeyError, AttributeError, IndexError): - logger.error(f"{path}: file has no metadata") - except: - logger.error(f"{path}: file could not be processed") - if cmd: - commands.append(f"# {path}") - 
commands.append(cmd) - if len(commands) > 0: - dir, basename = os.path.split(outfilepath) - if len(dir) == 0: - outfilepath = os.path.join(opt.outdir, basename) - with open(outfilepath, "w", encoding="utf-8") as f: - f.write("\n".join(commands)) - logger.info(f"File {outfilepath} with commands created") - - -def report_model_error(opt: Namespace, e: Exception): - logger.warning(f'An error occurred while attempting to initialize the model: "{str(e)}"') - logger.warning( - "This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models." - ) - traceback.print_exc() - yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE") - if yes_to_all: - logger.warning( - "Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE" - ) - else: - if not click.confirm( - "Do you want to run invokeai-configure script to select and/or reinstall models?", - default=False, - ): - return - - logger.info("invokeai-configure is launching....\n") - - # Match arguments that were set on the CLI - # only the arguments accepted by the configuration script are parsed - root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else [] - config = ["--config", opt.conf] if opt.conf is not None else [] - previous_args = sys.argv - sys.argv = ["invokeai-configure"] - sys.argv.extend(root_dir) - sys.argv.extend(config) - if yes_to_all is not None: - for arg in yes_to_all.split(): - sys.argv.append(arg) - - from ..install import invokeai_configure - - invokeai_configure() - logger.warning("InvokeAI will now restart") - sys.argv = previous_args - main() # would rather do a os.exec(), but doesn't exist? - sys.exit(0) - - -def check_internet() -> bool: - """ - Return true if the internet is reachable. - It does this by pinging huggingface.co. - """ - import urllib.request - - host = "http://huggingface.co" - try: - urllib.request.urlopen(host, timeout=1) - return True - except: - return False - -# This routine performs any patch-ups needed after installation -def run_patches(): - # install ckpt configuration files that may have been added to the - # distro after original root directory configuration - import invokeai.configs as conf - from shutil import copyfile - - root_configs = Path(global_config_dir(), 'stable-diffusion') - repo_configs = Path(conf.__path__[0], 'stable-diffusion') - if not root_configs.exists(): - os.makedirs(root_configs, exist_ok=True) - for src in repo_configs.iterdir(): - dest = root_configs / src.name - if not dest.exists(): - copyfile(src, dest) - -if __name__ == "__main__": - main() diff --git a/invokeai/frontend/CLI/readline.py b/invokeai/frontend/CLI/readline.py deleted file mode 100644 index 228ab88b57..0000000000 --- a/invokeai/frontend/CLI/readline.py +++ /dev/null @@ -1,498 +0,0 @@ -""" -Readline helper functions for invoke.py. -You may import the global singleton `completer` to get access to the -completer object itself. 
This is useful when you want to autocomplete -seeds: - - from invokeai.frontend.CLI.readline import completer - completer.add_seed(18247566) - completer.add_seed(9281839) -""" -import atexit -import os -import re - -from ...backend.args import Args -from ...backend.globals import Globals -from ...backend.stable_diffusion import HuggingFaceConceptsLibrary - -# ---------------readline utilities--------------------- -try: - import readline - - readline_available = True -except (ImportError, ModuleNotFoundError) as e: - print(f"** An error occurred when loading the readline module: {str(e)}") - readline_available = False - -IMG_EXTENSIONS = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG", ".gif", ".GIF") -WEIGHT_EXTENSIONS = (".ckpt", ".vae", ".safetensors") -TEXT_EXTENSIONS = (".txt", ".TXT") -CONFIG_EXTENSIONS = (".yaml", ".yml") -COMMANDS = ( - "--steps", - "-s", - "--seed", - "-S", - "--iterations", - "-n", - "--width", - "-W", - "--height", - "-H", - "--cfg_scale", - "-C", - "--threshold", - "--perlin", - "--grid", - "-g", - "--individual", - "-i", - "--save_intermediates", - "--init_img", - "-I", - "--init_mask", - "-M", - "--init_color", - "--strength", - "-f", - "--variants", - "-v", - "--outdir", - "-o", - "--sampler", - "-A", - "-m", - "--embedding_path", - "--device", - "--grid", - "-g", - "--facetool", - "-ft", - "--facetool_strength", - "-G", - "--codeformer_fidelity", - "-cf", - "--upscale", - "-U", - "-save_orig", - "--save_original", - "--log_tokenization", - "-t", - "--hires_fix", - "--inpaint_replace", - "-r", - "--png_compression", - "-z", - "--text_mask", - "-tm", - "--h_symmetry_time_pct", - "--v_symmetry_time_pct", - "!fix", - "!fetch", - "!replay", - "!history", - "!search", - "!clear", - "!models", - "!switch", - "!import_model", - "!optimize_model", - "!convert_model", - "!edit_model", - "!del_model", - "!mask", - "!triggers", -) -MODEL_COMMANDS = ( - "!switch", - "!edit_model", - "!del_model", -) -CKPT_MODEL_COMMANDS = ("!optimize_model",) -WEIGHT_COMMANDS = ( - "!import_model", - "!convert_model", -) -IMG_PATH_COMMANDS = ("--outdir[=\s]",) -TEXT_PATH_COMMANDS = ("!replay",) -IMG_FILE_COMMANDS = ( - "!fix", - "!fetch", - "!mask", - "--init_img[=\s]", - "-I", - "--init_mask[=\s]", - "-M", - "--init_color[=\s]", - "--embedding_path[=\s]", -) - -path_regexp = "(" + "|".join(IMG_PATH_COMMANDS + IMG_FILE_COMMANDS) + ")\s*\S*$" -weight_regexp = "(" + "|".join(WEIGHT_COMMANDS) + ")\s*\S*$" -text_regexp = "(" + "|".join(TEXT_PATH_COMMANDS) + ")\s*\S*$" - - -class Completer(object): - def __init__(self, options, models={}): - self.options = sorted(options) - self.models = models - self.seeds = set() - self.matches = list() - self.default_dir = None - self.linebuffer = None - self.auto_history_active = True - self.extensions = None - self.concepts = None - self.embedding_terms = set() - return - - def complete(self, text, state): - """ - Completes invoke command line. - BUG: it doesn't correctly complete files that have spaces in the name. 
- """ - buffer = readline.get_line_buffer() - - if state == 0: - # extensions defined, so go directly into path completion mode - if self.extensions is not None: - self.matches = self._path_completions(text, state, self.extensions) - - # looking for an image file - elif re.search(path_regexp, buffer): - do_shortcut = re.search("^" + "|".join(IMG_FILE_COMMANDS), buffer) - self.matches = self._path_completions( - text, state, IMG_EXTENSIONS, shortcut_ok=do_shortcut - ) - - # looking for a seed - elif re.search("(-S\s*|--seed[=\s])\d*$", buffer): - self.matches = self._seed_completions(text, state) - - # looking for an embedding concept - elif re.search("<[\w-]*$", buffer): - self.matches = self._concept_completions(text, state) - - # looking for a model - elif re.match("^" + "|".join(MODEL_COMMANDS), buffer): - self.matches = self._model_completions(text, state) - - # looking for a ckpt model - elif re.match("^" + "|".join(CKPT_MODEL_COMMANDS), buffer): - self.matches = self._model_completions(text, state, ckpt_only=True) - - elif re.search(weight_regexp, buffer): - self.matches = self._path_completions( - text, - state, - WEIGHT_EXTENSIONS, - default_dir=Globals.root, - ) - - elif re.search(text_regexp, buffer): - self.matches = self._path_completions(text, state, TEXT_EXTENSIONS) - - # This is the first time for this text, so build a match list. - elif text: - self.matches = [s for s in self.options if s and s.startswith(text)] - else: - self.matches = self.options[:] - - # Return the state'th item from the match list, - # if we have that many. - try: - response = self.matches[state] - except IndexError: - response = None - return response - - def complete_extensions(self, extensions: list): - """ - If called with a list of extensions, will force completer - to do file path completions. - """ - self.extensions = extensions - - def add_history(self, line): - """ - Pass thru to readline - """ - if not self.auto_history_active: - readline.add_history(line) - - def clear_history(self): - """ - Pass clear_history() thru to readline - """ - readline.clear_history() - - def search_history(self, match: str): - """ - Like show_history() but only shows items that - contain the match string. - """ - self.show_history(match) - - def remove_history_item(self, pos): - readline.remove_history_item(pos) - - def add_seed(self, seed): - """ - Add a seed to the autocomplete list for display when -S is autocompleted. - """ - if seed is not None: - self.seeds.add(str(seed)) - - def set_default_dir(self, path): - self.default_dir = path - - def set_options(self, options): - self.options = options - - def get_line(self, index): - try: - line = self.get_history_item(index) - except IndexError: - return None - return line - - def get_current_history_length(self): - return readline.get_current_history_length() - - def get_history_item(self, index): - return readline.get_history_item(index) - - def show_history(self, match=None): - """ - Print the session history using the pydoc pager - """ - import pydoc - - lines = list() - h_len = self.get_current_history_length() - if h_len < 1: - print("") - return - - for i in range(0, h_len): - line = self.get_history_item(i + 1) - if match and match not in line: - continue - lines.append(f"[{i+1}] {line}") - pydoc.pager("\n".join(lines)) - - def set_line(self, line) -> None: - """ - Set the default string displayed in the next line of input. 
- """ - self.linebuffer = line - readline.redisplay() - - def update_models(self, models: dict) -> None: - """ - update our list of models - """ - self.models = models - - def _seed_completions(self, text, state): - m = re.search("(-S\s?|--seed[=\s]?)(\d*)", text) - if m: - switch = m.groups()[0] - partial = m.groups()[1] - else: - switch = "" - partial = text - - matches = list() - for s in self.seeds: - if s.startswith(partial): - matches.append(switch + s) - matches.sort() - return matches - - def add_embedding_terms(self, terms: list[str]): - self.embedding_terms = set(terms) - if self.concepts: - self.embedding_terms.update(set(self.concepts.list_concepts())) - - def _concept_completions(self, text, state): - if self.concepts is None: - # cache Concepts() instance so we can check for updates in concepts_list during runtime. - self.concepts = HuggingFaceConceptsLibrary() - self.embedding_terms.update(set(self.concepts.list_concepts())) - else: - self.embedding_terms.update(set(self.concepts.list_concepts())) - - partial = text[1:] # this removes the leading '<' - if len(partial) == 0: - return list(self.embedding_terms) # whole dump - think if user wants this! - - matches = list() - for concept in self.embedding_terms: - if concept.startswith(partial): - matches.append(f"<{concept}>") - matches.sort() - return matches - - def _model_completions(self, text, state, ckpt_only=False): - m = re.search("(!switch\s+)(\w*)", text) - if m: - switch = m.groups()[0] - partial = m.groups()[1] - else: - switch = "" - partial = text - matches = list() - for s in self.models: - name = self.models[s]["model_name"] - format = self.models[s]["format"] - if format == "vae": - continue - if ckpt_only and format != "ckpt": - continue - if name.startswith(partial): - matches.append(switch + name) - matches.sort() - return matches - - def _pre_input_hook(self): - if self.linebuffer: - readline.insert_text(self.linebuffer) - readline.redisplay() - self.linebuffer = None - - def _path_completions( - self, text, state, extensions, shortcut_ok=True, default_dir: str = "" - ): - # separate the switch from the partial path - match = re.search("^(-\w|--\w+=?)(.*)", text) - if match is None: - switch = None - partial_path = text - else: - switch, partial_path = match.groups() - - partial_path = partial_path.lstrip() - - matches = list() - path = os.path.expanduser(partial_path) - - if os.path.isdir(path): - dir = path - elif os.path.dirname(path) != "": - dir = os.path.dirname(path) - else: - dir = default_dir if os.path.exists(default_dir) else "" - path = os.path.join(dir, path) - - dir_list = os.listdir(dir or ".") - if shortcut_ok and os.path.exists(self.default_dir) and dir == "": - dir_list += os.listdir(self.default_dir) - - for node in dir_list: - if node.startswith(".") and len(node) > 1: - continue - full_path = os.path.join(dir, node) - - if not (node.endswith(extensions) or os.path.isdir(full_path)): - continue - - if path and not full_path.startswith(path): - continue - - if switch is None: - match_path = os.path.join(dir, node) - matches.append( - match_path + "/" if os.path.isdir(full_path) else match_path - ) - elif os.path.isdir(full_path): - matches.append( - switch + os.path.join(os.path.dirname(full_path), node) + "/" - ) - elif node.endswith(extensions): - matches.append(switch + os.path.join(os.path.dirname(full_path), node)) - - return matches - - -class DummyCompleter(Completer): - def __init__(self, options): - super().__init__(options) - self.history = list() - - def add_history(self, 
line): - self.history.append(line) - - def clear_history(self): - self.history = list() - - def get_current_history_length(self): - return len(self.history) - - def get_history_item(self, index): - return self.history[index - 1] - - def remove_history_item(self, index): - return self.history.pop(index - 1) - - def set_line(self, line): - print(f"# {line}") - - -def generic_completer(commands: list) -> Completer: - if readline_available: - completer = Completer(commands, []) - readline.set_completer(completer.complete) - readline.set_pre_input_hook(completer._pre_input_hook) - readline.set_completer_delims(" ") - readline.parse_and_bind("tab: complete") - readline.parse_and_bind("set print-completions-horizontally off") - readline.parse_and_bind("set page-completions on") - readline.parse_and_bind("set skip-completed-text on") - readline.parse_and_bind("set show-all-if-ambiguous on") - else: - completer = DummyCompleter(commands) - return completer - - -def get_completer(opt: Args, models=[]) -> Completer: - if readline_available: - completer = Completer(COMMANDS, models) - - readline.set_completer(completer.complete) - # pyreadline3 does not have a set_auto_history() method - try: - readline.set_auto_history(False) - completer.auto_history_active = False - except: - completer.auto_history_active = True - readline.set_pre_input_hook(completer._pre_input_hook) - readline.set_completer_delims(" ") - readline.parse_and_bind("tab: complete") - readline.parse_and_bind("set print-completions-horizontally off") - readline.parse_and_bind("set page-completions on") - readline.parse_and_bind("set skip-completed-text on") - readline.parse_and_bind("set show-all-if-ambiguous on") - - outdir = os.path.expanduser(opt.outdir) - if os.path.isabs(outdir): - histfile = os.path.join(outdir, ".invoke_history") - else: - histfile = os.path.join(Globals.root, outdir, ".invoke_history") - try: - readline.read_history_file(histfile) - readline.set_history_length(1000) - except FileNotFoundError: - pass - except OSError: # file likely corrupted - newname = f"{histfile}.old" - print( - f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}" - ) - os.replace(histfile, newname) - atexit.register(readline.write_history_file, histfile) - - else: - completer = DummyCompleter(COMMANDS) - return completer diff --git a/invokeai/frontend/CLI/sd_metadata.py b/invokeai/frontend/CLI/sd_metadata.py deleted file mode 100644 index c26907a18e..0000000000 --- a/invokeai/frontend/CLI/sd_metadata.py +++ /dev/null @@ -1,30 +0,0 @@ -''' -This is a modularized version of the sd-metadata.py script, -which retrieves and prints the metadata from a series of generated png files. 
-''' -import sys -import json -from invokeai.backend.image_util import retrieve_metadata - - -def print_metadata(): - if len(sys.argv) < 2: - print("Usage: file2prompt.py ...") - print("This script opens up the indicated invoke.py-generated PNG file(s) and prints out their metadata.") - exit(-1) - - filenames = sys.argv[1:] - for f in filenames: - try: - metadata = retrieve_metadata(f) - print(f'{f}:\n',json.dumps(metadata['sd-metadata'], indent=4)) - except FileNotFoundError: - sys.stderr.write(f'{f} not found\n') - continue - except PermissionError: - sys.stderr.write(f'{f} could not be opened due to inadequate permissions\n') - continue - -if __name__== '__main__': - print_metadata() - diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index c12104033f..a283b4952d 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -23,7 +23,6 @@ from npyscreen import widget from omegaconf import OmegaConf import invokeai.backend.util.logging as logger -from invokeai.backend.globals import Globals, global_config_dir from ...backend.config.model_install_backend import ( Dataset_path, @@ -41,11 +40,13 @@ from .widgets import ( TextBox, set_min_terminal_size, ) +from invokeai.app.services.config import get_invokeai_config # minimum size for the UI MIN_COLS = 120 MIN_LINES = 45 +config = get_invokeai_config() class addModelsForm(npyscreen.FormMultiPage): # for responsive resizing - disabled @@ -453,9 +454,9 @@ def main(): opt = parser.parse_args() # setting a global here - Globals.root = os.path.expanduser(get_root(opt.root) or "") + config.root = os.path.expanduser(get_root(opt.root) or "") - if not global_config_dir().exists(): + if not (config.conf_path / '..' ).exists(): logger.info( "Your InvokeAI root directory is not set up. Calling invokeai-configure." 
) diff --git a/invokeai/frontend/merge/merge_diffusers.py b/invokeai/frontend/merge/merge_diffusers.py index 524118ba7c..882a4587b6 100644 --- a/invokeai/frontend/merge/merge_diffusers.py +++ b/invokeai/frontend/merge/merge_diffusers.py @@ -8,7 +8,6 @@ import argparse import curses import os import sys -import traceback import warnings from argparse import Namespace from pathlib import Path @@ -20,20 +19,13 @@ from diffusers import logging as dlogging from npyscreen import widget from omegaconf import OmegaConf -from ...backend.globals import ( - Globals, - global_cache_dir, - global_config_file, - global_models_dir, - global_set_root, -) - import invokeai.backend.util.logging as logger +from invokeai.services.config import get_invokeai_config from ...backend.model_management import ModelManager from ...frontend.install.widgets import FloatTitleSlider DEST_MERGED_MODEL_DIR = "merged_models" - +config = get_invokeai_config() def merge_diffusion_models( model_ids_or_paths: List[Union[str, Path]], @@ -60,7 +52,7 @@ def merge_diffusion_models( pipe = DiffusionPipeline.from_pretrained( model_ids_or_paths[0], - cache_dir=kwargs.get("cache_dir", global_cache_dir()), + cache_dir=kwargs.get("cache_dir", config.cache_dir), custom_pipeline="checkpoint_merger", ) merged_pipe = pipe.merge( @@ -94,7 +86,7 @@ def merge_diffusion_models_and_commit( **kwargs - the default DiffusionPipeline.get_config_dict kwargs: cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map """ - config_file = global_config_file() + config_file = config.model_conf_path model_manager = ModelManager(OmegaConf.load(config_file)) for mod in models: assert mod in model_manager.model_names(), f'** Unknown model "{mod}"' @@ -106,7 +98,7 @@ def merge_diffusion_models_and_commit( merged_pipe = merge_diffusion_models( model_ids_or_paths, alpha, interp, force, **kwargs ) - dump_path = global_models_dir() / DEST_MERGED_MODEL_DIR + dump_path = config.models_dir / DEST_MERGED_MODEL_DIR os.makedirs(dump_path, exist_ok=True) dump_path = dump_path / merged_model_name @@ -126,7 +118,7 @@ def _parse_args() -> Namespace: parser.add_argument( "--root_dir", type=Path, - default=Globals.root, + default=config.root, help="Path to the invokeai runtime directory", ) parser.add_argument( @@ -398,7 +390,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): class Mergeapp(npyscreen.NPSAppManaged): def __init__(self): super().__init__() - conf = OmegaConf.load(global_config_file()) + conf = OmegaConf.load(config.model_conf_path) self.model_manager = ModelManager( conf, "cpu", "float16" ) # precision doesn't really matter here @@ -429,7 +421,7 @@ def run_cli(args: Namespace): f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"' ) - model_manager = ModelManager(OmegaConf.load(global_config_file())) + model_manager = ModelManager(OmegaConf.load(config.model_conf_path)) assert ( args.clobber or args.merged_model_name not in model_manager.model_names() ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' 
@@ -440,9 +432,9 @@ def run_cli(args: Namespace): def main(): args = _parse_args() - global_set_root(args.root_dir) + config.root = args.root_dir - cache_dir = str(global_cache_dir("hub")) + cache_dir = config.cache_dir os.environ[ "HF_HOME" ] = cache_dir # because not clear the merge pipeline is honoring cache_dir diff --git a/invokeai/frontend/training/textual_inversion.py b/invokeai/frontend/training/textual_inversion.py index 23134d2736..90e402f48b 100755 --- a/invokeai/frontend/training/textual_inversion.py +++ b/invokeai/frontend/training/textual_inversion.py @@ -21,14 +21,17 @@ from npyscreen import widget from omegaconf import OmegaConf import invokeai.backend.util.logging as logger -from invokeai.backend.globals import Globals, global_set_root -from ...backend.training import do_textual_inversion_training, parse_args +from invokeai.app.services.config import get_invokeai_config +from ...backend.training import ( + do_textual_inversion_training, + parse_args +) TRAINING_DATA = "text-inversion-training-data" TRAINING_DIR = "text-inversion-output" CONF_FILE = "preferences.conf" - +config = None class textualInversionForm(npyscreen.FormMultiPageAction): resolutions = [512, 768, 1024] @@ -122,7 +125,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction): value=str( saved_args.get( "train_data_dir", - Path(Globals.root) / TRAINING_DATA / default_placeholder_token, + config.root_dir / TRAINING_DATA / default_placeholder_token, ) ), scroll_exit=True, @@ -135,7 +138,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction): value=str( saved_args.get( "output_dir", - Path(Globals.root) / TRAINING_DIR / default_placeholder_token, + config.root_dir / TRAINING_DIR / default_placeholder_token, ) ), scroll_exit=True, @@ -241,9 +244,9 @@ class textualInversionForm(npyscreen.FormMultiPageAction): placeholder = self.placeholder_token.value self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)" self.train_data_dir.value = str( - Path(Globals.root) / TRAINING_DATA / placeholder + config.root_dir / TRAINING_DATA / placeholder ) - self.output_dir.value = str(Path(Globals.root) / TRAINING_DIR / placeholder) + self.output_dir.value = str(config.root_dir / TRAINING_DIR / placeholder) self.resume_from_checkpoint.value = Path(self.output_dir.value).exists() def on_ok(self): @@ -284,7 +287,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction): return True def get_model_names(self) -> Tuple[List[str], int]: - conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml")) + conf = OmegaConf.load(config.root_dir / "configs/models.yaml") model_names = [ idx for idx in sorted(list(conf.keys())) @@ -367,7 +370,7 @@ def copy_to_embeddings_folder(args: dict): """ source = Path(args["output_dir"], "learned_embeds.bin") dest_dir_name = args["placeholder_token"].strip("<>") - destination = Path(Globals.root, "embeddings", dest_dir_name) + destination = config.root_dir / "embeddings" / dest_dir_name os.makedirs(destination, exist_ok=True) logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}") shutil.copy(source, destination) @@ -383,7 +386,7 @@ def save_args(args: dict): """ Save the current argument values to an omegaconf file """ - dest_dir = Path(Globals.root) / TRAINING_DIR + dest_dir = config.root_dir / TRAINING_DIR os.makedirs(dest_dir, exist_ok=True) conf_file = dest_dir / CONF_FILE conf = OmegaConf.create(args) @@ -394,7 +397,7 @@ def previous_args() -> dict: """ Get the previous arguments used. 
""" - conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE + conf_file = config.root_dir / TRAINING_DIR / CONF_FILE try: conf = OmegaConf.load(conf_file) conf["placeholder_token"] = conf["placeholder_token"].strip("<>") @@ -420,7 +423,7 @@ def do_front_end(args: Namespace): save_args(args) try: - do_textual_inversion_training(**args) + do_textual_inversion_training(get_invokeai_config(),**args) copy_to_embeddings_folder(args) except Exception as e: logger.error("An exception occurred during training. The exception was:") @@ -430,13 +433,20 @@ def do_front_end(args: Namespace): def main(): + global config + args = parse_args() - global_set_root(args.root_dir or Globals.root) + config = get_invokeai_config(argv=[]) + + # change root if needed + if args.root_dir: + config.root = args.root_dir + try: if args.front_end: do_front_end(args) else: - do_textual_inversion_training(**vars(args)) + do_textual_inversion_training(config,**vars(args)) except AssertionError as e: logger.error(e) sys.exit(-1) diff --git a/invokeai/frontend/web/docs/PACKAGE_SCRIPTS.md b/invokeai/frontend/web/docs/PACKAGE_SCRIPTS.md index 90d85bb540..5f882717b1 100644 --- a/invokeai/frontend/web/docs/PACKAGE_SCRIPTS.md +++ b/invokeai/frontend/web/docs/PACKAGE_SCRIPTS.md @@ -15,15 +15,3 @@ The `postinstall` script patches a few packages and runs the Chakra CLI to gener ### Patch `@chakra-ui/cli` See: - -### Patch `redux-persist` - -We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`. - -`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it. - -So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that. - -### Patch `redux-deep-persist` - -This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work. 
diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index 13f79f4a44..317929c6a4 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -62,11 +62,13 @@ "@dagrejs/graphlib": "^2.1.12", "@emotion/react": "^11.10.6", "@emotion/styled": "^11.10.6", + "@floating-ui/react-dom": "^2.0.0", "@fontsource/inter": "^4.5.15", "@reduxjs/toolkit": "^1.9.5", "@roarr/browser-log-writer": "^1.1.5", "chakra-ui-contextmenu": "^1.0.5", "dateformat": "^5.0.3", + "downshift": "^7.6.0", "formik": "^2.2.9", "framer-motion": "^10.12.4", "fuse.js": "^6.6.2", @@ -87,18 +89,13 @@ "react-i18next": "^12.2.2", "react-icons": "^4.7.1", "react-konva": "^18.2.7", - "react-konva-utils": "^1.0.4", "react-redux": "^8.0.5", "react-resizable-panels": "^0.0.42", - "react-rnd": "^10.4.1", - "react-transition-group": "^4.4.5", "react-use": "^17.4.0", "react-virtuoso": "^4.3.5", "react-zoom-pan-pinch": "^3.0.7", "reactflow": "^11.7.0", - "redux-deep-persist": "^1.0.7", "redux-dynamic-middlewares": "^2.2.0", - "redux-persist": "^6.0.0", "redux-remember": "^3.3.1", "roarr": "^7.15.0", "serialize-error": "^11.0.0", diff --git a/invokeai/frontend/web/patches/redux-deep-persist+1.0.7.patch b/invokeai/frontend/web/patches/redux-deep-persist+1.0.7.patch deleted file mode 100644 index 47a62e6aac..0000000000 --- a/invokeai/frontend/web/patches/redux-deep-persist+1.0.7.patch +++ /dev/null @@ -1,24 +0,0 @@ -diff --git a/node_modules/redux-deep-persist/lib/types.d.ts b/node_modules/redux-deep-persist/lib/types.d.ts -index b67b8c2..7fc0fa1 100644 ---- a/node_modules/redux-deep-persist/lib/types.d.ts -+++ b/node_modules/redux-deep-persist/lib/types.d.ts -@@ -35,6 +35,7 @@ export interface PersistConfig { - whitelist?: Array; - transforms?: Array>; - throttle?: number; -+ debounce?: number; - migrate?: PersistMigrate; - stateReconciler?: false | StateReconciler; - getStoredState?: (config: PersistConfig) => Promise; -diff --git a/node_modules/redux-deep-persist/src/types.ts b/node_modules/redux-deep-persist/src/types.ts -index 398ac19..cbc5663 100644 ---- a/node_modules/redux-deep-persist/src/types.ts -+++ b/node_modules/redux-deep-persist/src/types.ts -@@ -91,6 +91,7 @@ export interface PersistConfig { - whitelist?: Array; - transforms?: Array>; - throttle?: number; -+ debounce?: number; - migrate?: PersistMigrate; - stateReconciler?: false | StateReconciler; - /** diff --git a/invokeai/frontend/web/patches/redux-persist+6.0.0.patch b/invokeai/frontend/web/patches/redux-persist+6.0.0.patch deleted file mode 100644 index 9e0a8492db..0000000000 --- a/invokeai/frontend/web/patches/redux-persist+6.0.0.patch +++ /dev/null @@ -1,116 +0,0 @@ -diff --git a/node_modules/redux-persist/es/createPersistoid.js b/node_modules/redux-persist/es/createPersistoid.js -index 8b43b9a..184faab 100644 ---- a/node_modules/redux-persist/es/createPersistoid.js -+++ b/node_modules/redux-persist/es/createPersistoid.js -@@ -6,6 +6,7 @@ export default function createPersistoid(config) { - var whitelist = config.whitelist || null; - var transforms = config.transforms || []; - var throttle = config.throttle || 0; -+ var debounce = config.debounce || 0; - var storageKey = "".concat(config.keyPrefix !== undefined ? 
config.keyPrefix : KEY_PREFIX).concat(config.key); - var storage = config.storage; - var serialize; -@@ -28,30 +29,37 @@ export default function createPersistoid(config) { - var timeIterator = null; - var writePromise = null; - -- var update = function update(state) { -- // add any changed keys to the queue -- Object.keys(state).forEach(function (key) { -- if (!passWhitelistBlacklist(key)) return; // is keyspace ignored? noop -+ // Timer for debounced `update()` -+ let timer = 0; - -- if (lastState[key] === state[key]) return; // value unchanged? noop -+ function update(state) { -+ // Debounce the update -+ clearTimeout(timer); -+ timer = setTimeout(() => { -+ // add any changed keys to the queue -+ Object.keys(state).forEach(function (key) { -+ if (!passWhitelistBlacklist(key)) return; // is keyspace ignored? noop - -- if (keysToProcess.indexOf(key) !== -1) return; // is key already queued? noop -+ if (lastState[key] === state[key]) return; // value unchanged? noop - -- keysToProcess.push(key); // add key to queue -- }); //if any key is missing in the new state which was present in the lastState, -- //add it for processing too -+ if (keysToProcess.indexOf(key) !== -1) return; // is key already queued? noop - -- Object.keys(lastState).forEach(function (key) { -- if (state[key] === undefined && passWhitelistBlacklist(key) && keysToProcess.indexOf(key) === -1 && lastState[key] !== undefined) { -- keysToProcess.push(key); -- } -- }); // start the time iterator if not running (read: throttle) -+ keysToProcess.push(key); // add key to queue -+ }); //if any key is missing in the new state which was present in the lastState, -+ //add it for processing too - -- if (timeIterator === null) { -- timeIterator = setInterval(processNextKey, throttle); -- } -+ Object.keys(lastState).forEach(function (key) { -+ if (state[key] === undefined && passWhitelistBlacklist(key) && keysToProcess.indexOf(key) === -1 && lastState[key] !== undefined) { -+ keysToProcess.push(key); -+ } -+ }); // start the time iterator if not running (read: throttle) -+ -+ if (timeIterator === null) { -+ timeIterator = setInterval(processNextKey, throttle); -+ } - -- lastState = state; -+ lastState = state; -+ }, debounce) - }; - - function processNextKey() { -diff --git a/node_modules/redux-persist/es/types.js.flow b/node_modules/redux-persist/es/types.js.flow -index c50d3cd..39d8be2 100644 ---- a/node_modules/redux-persist/es/types.js.flow -+++ b/node_modules/redux-persist/es/types.js.flow -@@ -19,6 +19,7 @@ export type PersistConfig = { - whitelist?: Array, - transforms?: Array, - throttle?: number, -+ debounce?: number, - migrate?: (PersistedState, number) => Promise, - stateReconciler?: false | Function, - getStoredState?: PersistConfig => Promise, // used for migrations -diff --git a/node_modules/redux-persist/lib/types.js.flow b/node_modules/redux-persist/lib/types.js.flow -index c50d3cd..39d8be2 100644 ---- a/node_modules/redux-persist/lib/types.js.flow -+++ b/node_modules/redux-persist/lib/types.js.flow -@@ -19,6 +19,7 @@ export type PersistConfig = { - whitelist?: Array, - transforms?: Array, - throttle?: number, -+ debounce?: number, - migrate?: (PersistedState, number) => Promise, - stateReconciler?: false | Function, - getStoredState?: PersistConfig => Promise, // used for migrations -diff --git a/node_modules/redux-persist/src/types.js b/node_modules/redux-persist/src/types.js -index c50d3cd..39d8be2 100644 ---- a/node_modules/redux-persist/src/types.js -+++ b/node_modules/redux-persist/src/types.js -@@ -19,6 +19,7 @@ 
export type PersistConfig = { - whitelist?: Array, - transforms?: Array, - throttle?: number, -+ debounce?: number, - migrate?: (PersistedState, number) => Promise, - stateReconciler?: false | Function, - getStoredState?: PersistConfig => Promise, // used for migrations -diff --git a/node_modules/redux-persist/types/types.d.ts b/node_modules/redux-persist/types/types.d.ts -index b3733bc..2a1696c 100644 ---- a/node_modules/redux-persist/types/types.d.ts -+++ b/node_modules/redux-persist/types/types.d.ts -@@ -35,6 +35,7 @@ declare module "redux-persist/es/types" { - whitelist?: Array; - transforms?: Array>; - throttle?: number; -+ debounce?: number; - migrate?: PersistMigrate; - stateReconciler?: false | StateReconciler; - /** diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 3592e141d0..94dff3934a 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -450,7 +450,7 @@ "cfgScale": "CFG Scale", "width": "Width", "height": "Height", - "sampler": "Sampler", + "scheduler": "Scheduler", "seed": "Seed", "imageToImage": "Image to Image", "randomizeSeed": "Randomize Seed", @@ -540,7 +540,10 @@ "consoleLogLevel": "Log Level", "shouldLogToConsole": "Console Logging", "developer": "Developer", - "general": "General" + "general": "General", + "generation": "Generation", + "ui": "User Interface", + "availableSchedulers": "Available Schedulers" }, "toast": { "serverError": "Server Error", @@ -549,8 +552,8 @@ "canceled": "Processing Canceled", "tempFoldersEmptied": "Temp Folder Emptied", "uploadFailed": "Upload failed", - "uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time", "uploadFailedUnableToLoadDesc": "Unable to load file", + "uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image", "downloadImageStarted": "Image Download Started", "imageCopied": "Image Copied", "imageLinkCopied": "Image Link Copied", diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index 3fbcbc49ea..40554356b1 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -2,14 +2,11 @@ import ImageUploader from 'common/components/ImageUploader'; import SiteHeader from 'features/system/components/SiteHeader'; import ProgressBar from 'features/system/components/ProgressBar'; import InvokeTabs from 'features/ui/components/InvokeTabs'; - -import useToastWatcher from 'features/system/hooks/useToastWatcher'; - import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton'; import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons'; import { Box, Flex, Grid, Portal } from '@chakra-ui/react'; import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants'; -import GalleryDrawer from 'features/gallery/components/ImageGalleryPanel'; +import GalleryDrawer from 'features/gallery/components/GalleryPanel'; import Lightbox from 'features/lightbox/components/Lightbox'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { memo, ReactNode, useCallback, useEffect, useState } from 'react'; @@ -17,25 +14,28 @@ import { motion, AnimatePresence } from 'framer-motion'; import Loading from 'common/components/Loading/Loading'; import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady'; import { PartialAppConfig } from 'app/types/invokeai'; -import { useGlobalHotkeys 
} from 'common/hooks/useGlobalHotkeys'; import { configChanged } from 'features/system/store/configSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useLogger } from 'app/logging/useLogger'; import ParametersDrawer from 'features/ui/components/ParametersDrawer'; import { languageSelector } from 'features/system/store/systemSelectors'; import i18n from 'i18n'; +import Toaster from './Toaster'; +import GlobalHotkeys from './GlobalHotkeys'; const DEFAULT_CONFIG = {}; interface Props { config?: PartialAppConfig; headerComponent?: ReactNode; + setIsReady?: (isReady: boolean) => void; } -const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => { - useToastWatcher(); - useGlobalHotkeys(); - +const App = ({ + config = DEFAULT_CONFIG, + headerComponent, + setIsReady, +}: Props) => { const language = useAppSelector(languageSelector); const log = useLogger(); @@ -61,66 +61,80 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => { setLoadingOverridden(true); }, []); + useEffect(() => { + if (isApplicationReady && setIsReady) { + setIsReady(true); + } + + return () => { + setIsReady && setIsReady(false); + }; + }, [isApplicationReady, setIsReady]); + return ( - - {isLightboxEnabled && } - - - - {headerComponent || } - + + {isLightboxEnabled && } + + + - - - - + {headerComponent || } + + + + + - - + + - - {!isApplicationReady && !loadingOverridden && ( - - - - - - - )} - + + {!isApplicationReady && !loadingOverridden && ( + + + + + + + )} + - - - - - - - + + + + + + + + + + ); }; diff --git a/invokeai/frontend/web/src/app/components/AuxiliaryProgressIndicator.tsx b/invokeai/frontend/web/src/app/components/AuxiliaryProgressIndicator.tsx new file mode 100644 index 0000000000..a0c5d22266 --- /dev/null +++ b/invokeai/frontend/web/src/app/components/AuxiliaryProgressIndicator.tsx @@ -0,0 +1,44 @@ +import { Flex, Spinner, Tooltip } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppSelector } from 'app/store/storeHooks'; +import { systemSelector } from 'features/system/store/systemSelectors'; +import { memo } from 'react'; + +const selector = createSelector(systemSelector, (system) => { + const { isUploading } = system; + + let tooltip = ''; + + if (isUploading) { + tooltip = 'Uploading...'; + } + + return { + tooltip, + shouldShow: isUploading, + }; +}); + +export const AuxiliaryProgressIndicator = () => { + const { shouldShow, tooltip } = useAppSelector(selector); + + if (!shouldShow) { + return null; + } + + return ( + + + + + + ); +}; + +export default memo(AuxiliaryProgressIndicator); diff --git a/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts b/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts similarity index 89% rename from invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts rename to invokeai/frontend/web/src/app/components/GlobalHotkeys.ts index 3935a390fb..c4660416bf 100644 --- a/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts +++ b/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts @@ -10,6 +10,7 @@ import { togglePinParametersPanel, } from 'features/ui/store/uiSlice'; import { isEqual } from 'lodash-es'; +import React, { memo } from 'react'; import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook'; const globalHotkeysSelector = createSelector( @@ -27,7 +28,11 @@ const globalHotkeysSelector = createSelector( // TODO: Does not catch keypresses while focused in an input. Maybe there is a way? 
-export const useGlobalHotkeys = () => { +/** + * Logical component. Handles app-level global hotkeys. + * @returns null + */ +const GlobalHotkeys: React.FC = () => { const dispatch = useAppDispatch(); const { shift } = useAppSelector(globalHotkeysSelector); @@ -75,4 +80,8 @@ export const useGlobalHotkeys = () => { useHotkeys('4', () => { dispatch(setActiveTab('nodes')); }); + + return null; }; + +export default memo(GlobalHotkeys); diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx index 442c1d967a..c04a8184d7 100644 --- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx +++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx @@ -24,9 +24,16 @@ interface Props extends PropsWithChildren { token?: string; config?: PartialAppConfig; headerComponent?: ReactNode; + setIsReady?: (isReady: boolean) => void; } -const InvokeAIUI = ({ apiUrl, token, config, headerComponent }: Props) => { +const InvokeAIUI = ({ + apiUrl, + token, + config, + headerComponent, + setIsReady, +}: Props) => { useEffect(() => { // configure API client token if (token) { @@ -55,7 +62,11 @@ const InvokeAIUI = ({ apiUrl, token, config, headerComponent }: Props) => { }> - + diff --git a/invokeai/frontend/web/src/app/components/Toaster.ts b/invokeai/frontend/web/src/app/components/Toaster.ts new file mode 100644 index 0000000000..66ba1d4925 --- /dev/null +++ b/invokeai/frontend/web/src/app/components/Toaster.ts @@ -0,0 +1,65 @@ +import { useToast, UseToastOptions } from '@chakra-ui/react'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { toastQueueSelector } from 'features/system/store/systemSelectors'; +import { addToast, clearToastQueue } from 'features/system/store/systemSlice'; +import { useCallback, useEffect } from 'react'; + +export type MakeToastArg = string | UseToastOptions; + +/** + * Makes a toast from a string or a UseToastOptions object. + * If a string is passed, the toast will have the status 'info' and will be closable with a duration of 2500ms. + */ +export const makeToast = (arg: MakeToastArg): UseToastOptions => { + if (typeof arg === 'string') { + return { + title: arg, + status: 'info', + isClosable: true, + duration: 2500, + }; + } + + return { status: 'info', isClosable: true, duration: 2500, ...arg }; +}; + +/** + * Logical component. Watches the toast queue and makes toasts when the queue is not empty. + * @returns null + */ +const Toaster = () => { + const dispatch = useAppDispatch(); + const toastQueue = useAppSelector(toastQueueSelector); + const toast = useToast(); + useEffect(() => { + toastQueue.forEach((t) => { + toast(t); + }); + toastQueue.length > 0 && dispatch(clearToastQueue()); + }, [dispatch, toast, toastQueue]); + + return null; +}; + +/** + * Returns a function that can be used to make a toast. + * @example + * const toaster = useAppToaster(); + * toaster('Hello world!'); + * toaster({ title: 'Hello world!', status: 'success' }); + * @returns A function that can be used to make a toast. 
+ * @see makeToast + * @see MakeToastArg + * @see UseToastOptions + */ +export const useAppToaster = () => { + const dispatch = useAppDispatch(); + const toaster = useCallback( + (arg: MakeToastArg) => dispatch(addToast(makeToast(arg))), + [dispatch] + ); + + return toaster; +}; + +export default Toaster; diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index 189fbc9dd4..d312d725ba 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -1,28 +1,28 @@ // TODO: use Enums? -export const DIFFUSERS_SCHEDULERS: Array = [ +export const SCHEDULERS = [ 'ddim', - 'ddpm', - 'deis', 'lms', - 'pndm', - 'heun', 'euler', 'euler_k', 'euler_a', - 'kdpm_2', - 'kdpm_2_a', 'dpmpp_2s', 'dpmpp_2m', 'dpmpp_2m_k', + 'kdpm_2', + 'kdpm_2_a', + 'deis', + 'ddpm', + 'pndm', + 'heun', + 'heun_k', 'unipc', -]; +] as const; -export const IMG2IMG_DIFFUSERS_SCHEDULERS = DIFFUSERS_SCHEDULERS.filter( - (scheduler) => { - return scheduler !== 'dpmpp_2s'; - } -); +export type Scheduler = (typeof SCHEDULERS)[number]; + +export const isScheduler = (x: string): x is Scheduler => + SCHEDULERS.includes(x as Scheduler); // Valid image widths export const WIDTHS: Array = Array.from(Array(64)).map( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 36bf6adfe7..f23e83a191 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -15,6 +15,10 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; +import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery'; +import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage'; +import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard'; +import { addCanvasMergedListener } from './listeners/canvasMerged'; export const listenerMiddleware = createListenerMiddleware(); @@ -43,3 +47,8 @@ addUserInvokedCanvasListener(); addUserInvokedNodesListener(); addUserInvokedTextToImageListener(); addUserInvokedImageToImageListener(); + +addCanvasSavedToGalleryListener(); +addCanvasDownloadedAsImageListener(); +addCanvasCopiedToClipboardListener(); +addCanvasMergedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard.ts new file mode 100644 index 0000000000..16642f1f32 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard.ts @@ -0,0 +1,33 @@ +import { canvasCopiedToClipboard } from 'features/canvas/store/actions'; +import { startAppListening } from '..'; +import { log } from 'app/logging/useLogger'; +import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; +import { addToast } from 'features/system/store/systemSlice'; +import { copyBlobToClipboard } from 'features/canvas/util/copyBlobToClipboard'; + +const moduleLog = log.child({ namespace: 
'canvasCopiedToClipboardListener' }); + +export const addCanvasCopiedToClipboardListener = () => { + startAppListening({ + actionCreator: canvasCopiedToClipboard, + effect: async (action, { dispatch, getState }) => { + const state = getState(); + + const blob = await getBaseLayerBlob(state); + + if (!blob) { + moduleLog.error('Problem getting base layer blob'); + dispatch( + addToast({ + title: 'Problem Copying Canvas', + description: 'Unable to export base layer', + status: 'error', + }) + ); + return; + } + + copyBlobToClipboard(blob); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage.ts new file mode 100644 index 0000000000..ef4c63b31c --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage.ts @@ -0,0 +1,33 @@ +import { canvasDownloadedAsImage } from 'features/canvas/store/actions'; +import { startAppListening } from '..'; +import { log } from 'app/logging/useLogger'; +import { downloadBlob } from 'features/canvas/util/downloadBlob'; +import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; +import { addToast } from 'features/system/store/systemSlice'; + +const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' }); + +export const addCanvasDownloadedAsImageListener = () => { + startAppListening({ + actionCreator: canvasDownloadedAsImage, + effect: async (action, { dispatch, getState }) => { + const state = getState(); + + const blob = await getBaseLayerBlob(state); + + if (!blob) { + moduleLog.error('Problem getting base layer blob'); + dispatch( + addToast({ + title: 'Problem Downloading Canvas', + description: 'Unable to export base layer', + status: 'error', + }) + ); + return; + } + + downloadBlob(blob, 'mergedCanvas.png'); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasGraphBuilt.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasGraphBuilt.ts deleted file mode 100644 index 532bac3eee..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasGraphBuilt.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { canvasGraphBuilt } from 'features/nodes/store/actions'; -import { startAppListening } from '..'; -import { - canvasSessionIdChanged, - stagingAreaInitialized, -} from 'features/canvas/store/canvasSlice'; -import { sessionInvoked } from 'services/thunks/session'; - -export const addCanvasGraphBuiltListener = () => - startAppListening({ - actionCreator: canvasGraphBuilt, - effect: async (action, { dispatch, getState, take }) => { - const [{ meta }] = await take(sessionInvoked.fulfilled.match); - const { sessionId } = meta.arg; - const state = getState(); - - if (!state.canvas.layerState.stagingArea.boundingBox) { - dispatch( - stagingAreaInitialized({ - sessionId, - boundingBox: { - ...state.canvas.boundingBoxCoordinates, - ...state.canvas.boundingBoxDimensions, - }, - }) - ); - } - - dispatch(canvasSessionIdChanged(sessionId)); - }, - }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts new file mode 100644 index 0000000000..d7a58c2050 --- /dev/null +++ 
b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts @@ -0,0 +1,88 @@ +import { canvasMerged } from 'features/canvas/store/actions'; +import { startAppListening } from '..'; +import { log } from 'app/logging/useLogger'; +import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; +import { addToast } from 'features/system/store/systemSlice'; +import { imageUploaded } from 'services/thunks/image'; +import { v4 as uuidv4 } from 'uuid'; +import { deserializeImageResponse } from 'services/util/deserializeImageResponse'; +import { setMergedCanvas } from 'features/canvas/store/canvasSlice'; +import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider'; + +const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' }); + +export const addCanvasMergedListener = () => { + startAppListening({ + actionCreator: canvasMerged, + effect: async (action, { dispatch, getState, take }) => { + const state = getState(); + + const blob = await getBaseLayerBlob(state, true); + + if (!blob) { + moduleLog.error('Problem getting base layer blob'); + dispatch( + addToast({ + title: 'Problem Merging Canvas', + description: 'Unable to export base layer', + status: 'error', + }) + ); + return; + } + + const canvasBaseLayer = getCanvasBaseLayer(); + + if (!canvasBaseLayer) { + moduleLog.error('Problem getting canvas base layer'); + dispatch( + addToast({ + title: 'Problem Merging Canvas', + description: 'Unable to export base layer', + status: 'error', + }) + ); + return; + } + + const baseLayerRect = canvasBaseLayer.getClientRect({ + relativeTo: canvasBaseLayer.getParent(), + }); + + const filename = `mergedCanvas_${uuidv4()}.png`; + + dispatch( + imageUploaded({ + imageType: 'intermediates', + formData: { + file: new File([blob], filename, { type: 'image/png' }), + }, + }) + ); + + const [{ payload }] = await take( + (action): action is ReturnType => + imageUploaded.fulfilled.match(action) && + action.meta.arg.formData.file.name === filename + ); + + const mergedCanvasImage = deserializeImageResponse(payload.response); + + dispatch( + setMergedCanvas({ + kind: 'image', + layer: 'base', + image: mergedCanvasImage, + ...baseLayerRect, + }) + ); + + dispatch( + addToast({ + title: 'Canvas Merged', + status: 'success', + }) + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts new file mode 100644 index 0000000000..d8237d1d5c --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts @@ -0,0 +1,40 @@ +import { canvasSavedToGallery } from 'features/canvas/store/actions'; +import { startAppListening } from '..'; +import { log } from 'app/logging/useLogger'; +import { imageUploaded } from 'services/thunks/image'; +import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; +import { addToast } from 'features/system/store/systemSlice'; + +const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' }); + +export const addCanvasSavedToGalleryListener = () => { + startAppListening({ + actionCreator: canvasSavedToGallery, + effect: async (action, { dispatch, getState }) => { + const state = getState(); + + const blob = await getBaseLayerBlob(state); + + if (!blob) { + moduleLog.error('Problem getting base layer blob'); + dispatch( + addToast({ + title: 'Problem Saving Canvas', + 
description: 'Unable to export base layer', + status: 'error', + }) + ); + return; + } + + dispatch( + imageUploaded({ + imageType: 'results', + formData: { + file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }), + }, + }) + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index c32da2e710..de06220ecd 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -3,6 +3,10 @@ import { startAppListening } from '..'; import { uploadAdded } from 'features/gallery/store/uploadsSlice'; import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageUploaded } from 'services/thunks/image'; +import { addToast } from 'features/system/store/systemSlice'; +import { initialImageSelected } from 'features/parameters/store/actions'; +import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; +import { resultAdded } from 'features/gallery/store/resultsSlice'; export const addImageUploadedListener = () => { startAppListening({ @@ -11,14 +15,31 @@ export const addImageUploadedListener = () => { action.payload.response.image_type !== 'intermediates', effect: (action, { dispatch, getState }) => { const { response } = action.payload; + const { imageType } = action.meta.arg; const state = getState(); const image = deserializeImageResponse(response); - dispatch(uploadAdded(image)); + if (imageType === 'uploads') { + dispatch(uploadAdded(image)); - if (state.gallery.shouldAutoSwitchToNewImages) { - dispatch(imageSelected(image)); + dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); + + if (state.gallery.shouldAutoSwitchToNewImages) { + dispatch(imageSelected(image)); + } + + if (action.meta.arg.activeTabName === 'img2img') { + dispatch(initialImageSelected(image)); + } + + if (action.meta.arg.activeTabName === 'unifiedCanvas') { + dispatch(setInitialCanvasImage(image)); + } + } + + if (imageType === 'results') { + dispatch(resultAdded(image)); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts index 6bc2f9e9bc..ae3a35f537 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts @@ -2,11 +2,11 @@ import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { Image, isInvokeAIImage } from 'app/types/invokeai'; import { selectResultsById } from 'features/gallery/store/resultsSlice'; import { selectUploadsById } from 'features/gallery/store/uploadsSlice'; -import { makeToast } from 'features/system/hooks/useToastWatcher'; import { t } from 'i18next'; import { addToast } from 'features/system/store/systemSlice'; import { startAppListening } from '..'; import { initialImageSelected } from 'features/parameters/store/actions'; +import { makeToast } from 'app/components/Toaster'; export const addInitialImageSelectedListener = () => { startAppListening({ diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts 
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts
index cdb2c83e12..2ebd3684e9 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts
@@ -1,6 +1,6 @@
 import { startAppListening } from '..';
 import { sessionCreated, sessionInvoked } from 'services/thunks/session';
-import { buildCanvasGraphAndBlobs } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
+import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
 import { log } from 'app/logging/useLogger';
 import { canvasGraphBuilt } from 'features/nodes/store/actions';
 import { imageUploaded } from 'services/thunks/image';
@@ -11,9 +11,17 @@ import {
   stagingAreaInitialized,
 } from 'features/canvas/store/canvasSlice';
 import { userInvoked } from 'app/store/actions';
+import { getCanvasData } from 'features/canvas/util/getCanvasData';
+import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
+import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
+import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
 
 const moduleLog = log.child({ namespace: 'invoke' });
 
+/**
+ * This listener is responsible for building the canvas graph and blobs when the user invokes the canvas.
+ * It is also responsible for uploading the base and mask layers to the server.
+ */
 export const addUserInvokedCanvasListener = () => {
   startAppListening({
     predicate: (action): action is ReturnType =>
@@ -21,25 +29,49 @@ export const addUserInvokedCanvasListener = () => {
     effect: async (action, { getState, dispatch, take }) => {
       const state = getState();
 
-      const data = await buildCanvasGraphAndBlobs(state);
+      // Build canvas blobs
+      const canvasBlobsAndImageData = await getCanvasData(state);
 
-      if (!data) {
+      if (!canvasBlobsAndImageData) {
+        moduleLog.error('Unable to create canvas data');
+        return;
+      }
+
+      const { baseBlob, baseImageData, maskBlob, maskImageData } =
+        canvasBlobsAndImageData;
+
+      // Determine the generation mode
+      const generationMode = getCanvasGenerationMode(
+        baseImageData,
+        maskImageData
+      );
+
+      if (state.system.enableImageDebugging) {
+        const baseDataURL = await blobToDataURL(baseBlob);
+        const maskDataURL = await blobToDataURL(maskBlob);
+        openBase64ImageInTab([
+          { base64: maskDataURL, caption: 'mask b64' },
+          { base64: baseDataURL, caption: 'image b64' },
+        ]);
+      }
+
+      moduleLog.debug(`Generation mode: ${generationMode}`);
+
+      // Build the canvas graph
+      const graphComponents = await buildCanvasGraphComponents(
+        state,
+        generationMode
+      );
+
+      if (!graphComponents) {
         moduleLog.error('Problem building graph');
         return;
       }
 
-      const {
-        rangeNode,
-        iterateNode,
-        baseNode,
-        edges,
-        baseBlob,
-        maskBlob,
-        generationMode,
-      } = data;
+      const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
 
+      // Upload the base layer, to be used as init image
       const baseFilename = `${uuidv4()}.png`;
-      const maskFilename = `${uuidv4()}.png`;
 
       dispatch(
         imageUploaded({
@@ -66,6 +98,9 @@ export const addUserInvokedCanvasListener = () => {
         };
       }
 
+      // Upload the mask layer image
+      const maskFilename = `${uuidv4()}.png`;
+
       if (baseNode.type === 'inpaint') {
         dispatch(
           imageUploaded({
@@ -103,9 +138,12 @@
       dispatch(canvasGraphBuilt(graph));
       moduleLog({ data: graph }, 'Canvas graph built');
 
+      // Actually create the session
      dispatch(sessionCreated({ graph }));
 
+      // Wait for the session to be invoked (this is just the HTTP request to start processing)
       const [{ meta }] = await take(sessionInvoked.fulfilled.match);
+
       const { sessionId } = meta.arg;
 
       if (!state.canvas.layerState.stagingArea.boundingBox) {
diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts
index d0e5437d36..f684dc1ccf 100644
--- a/invokeai/frontend/web/src/app/types/invokeai.ts
+++ b/invokeai/frontend/web/src/app/types/invokeai.ts
@@ -52,6 +52,7 @@ export type CommonGeneratedImageMetadata = {
   | 'lms'
   | 'pndm'
   | 'heun'
+  | 'heun_k'
   | 'euler'
   | 'euler_k'
   | 'euler_a'
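
In the enableImageDebugging branch above, the base and mask blobs are converted to data URLs before being opened in a new tab. This changeset only shows the blobToDataURL import; a plausible minimal implementation of such a helper, assuming it is a thin FileReader wrapper, looks like this:

// Illustrative sketch of a blob → data URL helper; not the changeset's actual util.
export const blobToDataURL = (blob: Blob): Promise<string> =>
  new Promise<string>((resolve, reject) => {
    const reader = new FileReader();
    // reader.result is a base64 data URL string when readAsDataURL is used.
    reader.onload = () => resolve(reader.result as string);
    reader.onerror = () => reject(reader.error);
    reader.readAsDataURL(blob);
  });
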
diff --git a/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx b/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx
new file mode 100644
index 0000000000..d9610346ec
--- /dev/null
+++ b/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx
@@ -0,0 +1,172 @@
+import { CheckIcon } from '@chakra-ui/icons';
+import {
+  Box,
+  Flex,
+  FlexProps,
+  FormControl,
+  FormControlProps,
+  FormLabel,
+  Grid,
+  GridItem,
+  List,
+  ListItem,
+  Select,
+  Text,
+  Tooltip,
+  TooltipProps,
+} from '@chakra-ui/react';
+import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
+import { useSelect } from 'downshift';
+import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
+
+import { memo } from 'react';
+
+type IAICustomSelectProps = {
+  label?: string;
+  items: string[];
+  selectedItem: string;
+  setSelectedItem: (v: string | null | undefined) => void;
+  withCheckIcon?: boolean;
+  formControlProps?: FormControlProps;
+  buttonProps?: FlexProps;
+  tooltip?: string;
+  tooltipProps?: Omit;
+};
+
+const IAICustomSelect = (props: IAICustomSelectProps) => {
+  const {
+    label,
+    items,
+    setSelectedItem,
+    selectedItem,
+    withCheckIcon,
+    formControlProps,
+    tooltip,
+    buttonProps,
+    tooltipProps,
+  } = props;
+
+  const {
+    isOpen,
+    getToggleButtonProps,
+    getLabelProps,
+    getMenuProps,
+    highlightedIndex,
+    getItemProps,
+  } = useSelect({
+    items,
+    selectedItem,
+    onSelectedItemChange: ({ selectedItem: newSelectedItem }) =>
+      setSelectedItem(newSelectedItem),
+  });
+
+  const { refs, floatingStyles } = useFloating({
+    whileElementsMounted: autoUpdate,
+    middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
+  });
+
+  return (
+
+      {label && (
+          {
+            refs.floating.current && refs.floating.current.focus();
+          }}
+        >
+          {label}
+
+      )}
+
+
+
+
+      {isOpen && (
+
+
+            {items.map((item, index) => (
+
+                {withCheckIcon ? (
+
+                    {selectedItem === item && }
+
+
+
+                      {item}
+
+
+
+                ) : (
+
+                    {item}
+
+                )}
+
+            ))}
+
+
+      )}
+
+
+  );
+};
+
+export default memo(IAICustomSelect);
diff --git a/invokeai/frontend/web/src/common/components/IAIInput.tsx b/invokeai/frontend/web/src/common/components/IAIInput.tsx
index 3e90dca83a..3cba36d2c9 100644
--- a/invokeai/frontend/web/src/common/components/IAIInput.tsx
+++ b/invokeai/frontend/web/src/common/components/IAIInput.tsx
@@ -5,6 +5,7 @@ import {
   Input,
   InputProps,
 } from '@chakra-ui/react';
+import { stopPastePropagation } from 'common/util/stopPastePropagation';
 import { ChangeEvent, memo } from 'react';
 
 interface IAIInputProps extends InputProps {
@@ -31,7 +32,7 @@ const IAIInput = (props: IAIInputProps) => {
       {...formControlProps}
     >
       {label !== '' && {label}}
-
+
   );
 };
diff --git a/invokeai/frontend/web/src/common/components/IAINumberInput.tsx b/invokeai/frontend/web/src/common/components/IAINumberInput.tsx
index 762182eb47..bf598f3b12 100644
--- a/invokeai/frontend/web/src/common/components/IAINumberInput.tsx
+++ b/invokeai/frontend/web/src/common/components/IAINumberInput.tsx
@@ -14,6 +14,7 @@ import {
   Tooltip,
   TooltipProps,
 } from '@chakra-ui/react';
+import { stopPastePropagation } from 'common/util/stopPastePropagation';
 import { clamp } from 'lodash-es';
 import { FocusEvent, memo, useEffect, useState } from 'react';
 
@@ -125,6 +126,7 @@ const IAINumberInput = (props: Props) => {
           onChange={handleOnChange}
           onBlur={handleBlur}
           {...rest}
+          onPaste={stopPastePropagation}
         >
           {showStepper && (
diff --git a/invokeai/frontend/web/src/common/components/IAITextarea.tsx b/invokeai/frontend/web/src/common/components/IAITextarea.tsx
new file mode 100644
index 0000000000..b5247887bb
--- /dev/null
+++ b/invokeai/frontend/web/src/common/components/IAITextarea.tsx
@@ -0,0 +1,9 @@
+import { Textarea, TextareaProps, forwardRef } from '@chakra-ui/react';
+import { stopPastePropagation } from 'common/util/stopPastePropagation';
+import { memo } from 'react';
+
+const IAITextarea = forwardRef((props: TextareaProps, ref) => {
+  return