mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
resolve some undefined symbols in model_cache
This commit is contained in:
commit
d96175d127
20
.github/workflows/test-invoke-pip.yml
vendored
20
.github/workflows/test-invoke-pip.yml
vendored
@ -80,12 +80,7 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: set test prompt to main branch validation
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
run: echo "TEST_PROMPTS=tests/preflight_prompts.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: set test prompt to Pull Request validation
|
||||
if: ${{ github.ref != 'refs/heads/main' }}
|
||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
||||
run:echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: setup python
|
||||
uses: actions/setup-python@v4
|
||||
@ -105,12 +100,6 @@ jobs:
|
||||
id: run-pytest
|
||||
run: pytest
|
||||
|
||||
- name: set INVOKEAI_OUTDIR
|
||||
run: >
|
||||
python -c
|
||||
"import os;from invokeai.backend.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
|
||||
>> ${{ matrix.github-env }}
|
||||
|
||||
- name: run invokeai-configure
|
||||
id: run-preload-models
|
||||
env:
|
||||
@ -129,15 +118,20 @@ jobs:
|
||||
HF_HUB_OFFLINE: 1
|
||||
HF_DATASETS_OFFLINE: 1
|
||||
TRANSFORMERS_OFFLINE: 1
|
||||
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
|
||||
run: >
|
||||
invokeai
|
||||
--no-patchmatch
|
||||
--no-nsfw_checker
|
||||
--from_file ${{ env.TEST_PROMPTS }}
|
||||
--precision=float32
|
||||
--always_use_cpu
|
||||
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
|
||||
--from_file ${{ env.TEST_PROMPTS }}
|
||||
|
||||
- name: Archive results
|
||||
id: archive-results
|
||||
env:
|
||||
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: results
|
||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -201,6 +201,8 @@ checkpoints
|
||||
# If it's a Mac
|
||||
.DS_Store
|
||||
|
||||
invokeai/frontend/web/dist/*
|
||||
|
||||
# Let the frontend manage its own gitignore
|
||||
!invokeai/frontend/web/*
|
||||
|
||||
|
@ -7,7 +7,6 @@ from typing import types
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from ...backend import Globals
|
||||
from ..services.restoration_services import RestorationServices
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
@ -42,17 +41,8 @@ class ApiDependencies:
|
||||
|
||||
invoker: Invoker = None
|
||||
|
||||
@staticmethod
|
||||
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
|
||||
Globals.try_patchmatch = config.patchmatch
|
||||
Globals.always_use_cpu = config.always_use_cpu
|
||||
Globals.internet_available = config.internet_available and check_internet()
|
||||
Globals.disable_xformers = not config.xformers
|
||||
Globals.ckpt_convert = config.ckpt_convert
|
||||
|
||||
# TO DO: Use the config to select the logger rather than use the default
|
||||
# invokeai logging module
|
||||
logger.info(f"Internet connectivity is {Globals.internet_available}")
|
||||
logger.info(f"Internet connectivity is {config.internet_available}")
|
||||
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
|
||||
@ -72,7 +62,6 @@ class ApiDependencies:
|
||||
services = InvocationServices(
|
||||
model_manager=ModelManagerService(config,logger),
|
||||
events=events,
|
||||
logger=logger,
|
||||
latents=latents,
|
||||
images=images,
|
||||
metadata=metadata,
|
||||
@ -85,6 +74,8 @@ class ApiDependencies:
|
||||
),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config,logger),
|
||||
configuration=config,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
@ -13,11 +13,11 @@ from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.schema import schema
|
||||
|
||||
from ..backend import Args
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import images, sessions, models
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
@ -33,30 +33,25 @@ app.add_middleware(
|
||||
middleware_id=event_handler_id,
|
||||
)
|
||||
|
||||
# Add CORS
|
||||
# TODO: use configuration for this
|
||||
origins = []
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
socket_io = SocketIO(app)
|
||||
|
||||
config = {}
|
||||
|
||||
# initialize config
|
||||
# this is a module global
|
||||
app_config = InvokeAIAppConfig()
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=app_config.allow_origins,
|
||||
allow_credentials=app_config.allow_credentials,
|
||||
allow_methods=app_config.allow_methods,
|
||||
allow_headers=app_config.allow_headers,
|
||||
)
|
||||
|
||||
ApiDependencies.initialize(
|
||||
config=config, event_handler_id=event_handler_id, logger=logger
|
||||
config=app_config, event_handler_id=event_handler_id, logger=logger
|
||||
)
|
||||
|
||||
|
||||
@ -148,14 +143,11 @@ app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), n
|
||||
|
||||
def invoke_api():
|
||||
# Start our own event loop for eventing usage
|
||||
# TODO: determine if there's a better way to do this
|
||||
loop = asyncio.new_event_loop()
|
||||
config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
|
||||
config = uvicorn.Config(app=app, host=app_config.host, port=app_config.port, loop=loop)
|
||||
# Use access_log to turn off logging
|
||||
|
||||
server = uvicorn.Server(config)
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_api()
|
||||
|
@ -285,3 +285,19 @@ class DrawExecutionGraphCommand(BaseCommand):
|
||||
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
class SortedHelpFormatter(argparse.HelpFormatter):
|
||||
def _iter_indented_subactions(self, action):
|
||||
try:
|
||||
get_subactions = action._get_subactions
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self._indent()
|
||||
if isinstance(action, argparse._SubParsersAction):
|
||||
for subaction in sorted(get_subactions(), key=lambda x: x.dest):
|
||||
yield subaction
|
||||
else:
|
||||
for subaction in get_subactions():
|
||||
yield subaction
|
||||
self._dedent()
|
||||
|
@ -11,9 +11,10 @@ from pathlib import Path
|
||||
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ...backend import ModelManager, Globals
|
||||
from ...backend import ModelManager
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from .commands import BaseCommand
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
# singleton object, class variable
|
||||
completer = None
|
||||
@ -131,13 +132,13 @@ class Completer(object):
|
||||
readline.redisplay()
|
||||
self.linebuffer = None
|
||||
|
||||
def set_autocompleter(model_manager: ModelManager) -> Completer:
|
||||
def set_autocompleter(services: InvocationServices) -> Completer:
|
||||
global completer
|
||||
|
||||
if completer:
|
||||
return completer
|
||||
|
||||
completer = Completer(model_manager)
|
||||
completer = Completer(services.model_manager)
|
||||
|
||||
readline.set_completer(completer.complete)
|
||||
# pyreadline3 does not have a set_auto_history() method
|
||||
@ -153,7 +154,7 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
|
||||
readline.parse_and_bind("set skip-completed-text on")
|
||||
readline.parse_and_bind("set show-all-if-ambiguous on")
|
||||
|
||||
histfile = Path(Globals.root, ".invoke_history")
|
||||
histfile = Path(services.configuration.root_dir / ".invoke_history")
|
||||
try:
|
||||
readline.read_history_file(histfile)
|
||||
readline.set_history_length(1000)
|
||||
|
@ -4,13 +4,14 @@ import argparse
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import sys
|
||||
import time
|
||||
from typing import (
|
||||
Union,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic.fields import Field
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
@ -20,10 +21,7 @@ from invokeai.app.services.metadata import PngMetadataService
|
||||
from .services.default_graphs import create_system_graphs
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ..backend import Args
|
||||
from ..backend import Globals # this should go when pr 3340 merged
|
||||
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, SortedHelpFormatter
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
@ -37,6 +35,7 @@ from .services.invoker import Invoker
|
||||
from .services.processor import DefaultInvocationProcessor
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
from .services.model_manager_service import ModelManagerService
|
||||
from .services.config import get_invokeai_config
|
||||
|
||||
class CliCommand(BaseModel):
|
||||
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||
@ -66,7 +65,7 @@ def add_invocation_args(command_parser):
|
||||
|
||||
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||
# Create invocation parser
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
|
||||
|
||||
def exit(*args, **kwargs):
|
||||
raise InvalidArgs
|
||||
@ -191,28 +190,26 @@ def invoke_all(context: CliContext):
|
||||
|
||||
|
||||
def invoke_cli():
|
||||
config = Args()
|
||||
config.parse_args()
|
||||
# this gets the basic configuration
|
||||
config = get_invokeai_config()
|
||||
|
||||
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
|
||||
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
|
||||
# get the optional list of invocations to execute on the command line
|
||||
parser = config.get_parser()
|
||||
parser.add_argument('commands',nargs='*')
|
||||
invocation_commands = parser.parse_args().commands
|
||||
|
||||
# get the optional file to read commands from.
|
||||
# Simplest is to use it for STDIN
|
||||
if infile := config.from_file:
|
||||
sys.stdin = open(infile,"r")
|
||||
|
||||
model_manager = ModelManagerService(config,logger)
|
||||
|
||||
# This initializes the autocompleter and returns it.
|
||||
# Currently nothing is done with the returned Completer
|
||||
# object, but the object can be used to change autocompletion
|
||||
# behavior on the fly, if desired.
|
||||
set_autocompleter(model_manager)
|
||||
|
||||
events = EventServiceBase()
|
||||
|
||||
output_folder = config.output_path
|
||||
metadata = PngMetadataService()
|
||||
|
||||
output_folder = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "../../../outputs")
|
||||
)
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
|
||||
@ -232,6 +229,7 @@ def invoke_cli():
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config,logger=logger),
|
||||
logger=logger,
|
||||
configuration=config,
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
@ -247,10 +245,18 @@ def invoke_cli():
|
||||
# print(services.session_manager.list())
|
||||
|
||||
context = CliContext(invoker, session, parser)
|
||||
set_autocompleter(services)
|
||||
|
||||
while True:
|
||||
command_line_args_exist = len(invocation_commands) > 0
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
try:
|
||||
cmd_input = input("invoke> ")
|
||||
if command_line_args_exist:
|
||||
cmd_input = invocation_commands.pop(0)
|
||||
done = len(invocation_commands) == 0
|
||||
else:
|
||||
cmd_input = input("invoke> ")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
# Ctrl-c exits
|
||||
break
|
||||
@ -374,6 +380,9 @@ def invoke_cli():
|
||||
invoker.services.logger.warning('Invalid command, use "help" to list commands')
|
||||
continue
|
||||
|
||||
except ValidationError:
|
||||
invoker.services.logger.warning('Invalid command arguments, run "<command> --help" for summary')
|
||||
|
||||
except SessionError:
|
||||
# Start a new session
|
||||
invoker.services.logger.warning("Session error: creating a new session")
|
||||
|
@ -5,10 +5,8 @@ from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationCont
|
||||
|
||||
from .model import ClipField
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
|
||||
from ...backend.model_management import SDModelType
|
||||
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import (
|
||||
@ -18,8 +16,6 @@ from compel.prompt_parser import (
|
||||
Fragment,
|
||||
)
|
||||
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
|
||||
@ -91,7 +87,7 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(self.prompt)
|
||||
|
||||
if getattr(Globals, "log_tokenization", False):
|
||||
if context.services.configuration.log_tokenization:
|
||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
||||
|
@ -5,7 +5,12 @@ from typing import Literal
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
|
||||
class MathInvocationConfig(BaseModel):
|
||||
@ -22,19 +27,21 @@ class MathInvocationConfig(BaseModel):
|
||||
|
||||
class IntOutput(BaseInvocationOutput):
|
||||
"""An integer output"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["int_output"] = "int_output"
|
||||
a: int = Field(default=None, description="The output integer")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
|
||||
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Adds two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["add"] = "add"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a + self.b)
|
||||
@ -42,11 +49,12 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Subtracts two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["sub"] = "sub"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a - self.b)
|
||||
@ -54,11 +62,12 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Multiplies two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mul"] = "mul"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a * self.b)
|
||||
@ -66,11 +75,12 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Divides two numbers"""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["div"] = "div"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=int(self.a / self.b))
|
||||
@ -78,8 +88,13 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
#fmt: off
|
||||
|
||||
# fmt: off
|
||||
type: Literal["rand_int"] = "rand_int"
|
||||
#fmt: on
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
# fmt: on
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=np.random.randint(0, np.iinfo(np.int32).max))
|
||||
return IntOutput(a=np.random.randint(self.low, self.high))
|
||||
|
528
invokeai/app/services/config.py
Normal file
528
invokeai/app/services/config.py
Normal file
@ -0,0 +1,528 @@
|
||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||
|
||||
'''Invokeai configuration system.
|
||||
|
||||
Arguments and fields are taken from the pydantic definition of the
|
||||
model. Defaults can be set by creating a yaml configuration file that
|
||||
has a top-level key of "InvokeAI" and subheadings for each of the
|
||||
categories returned by `invokeai --help`. The file looks like this:
|
||||
|
||||
[file: invokeai.yaml]
|
||||
|
||||
InvokeAI:
|
||||
Paths:
|
||||
root: /home/lstein/invokeai-main
|
||||
conf_path: configs/models.yaml
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
outdir: outputs
|
||||
embedding_dir: embeddings
|
||||
lora_dir: loras
|
||||
autoconvert_dir: null
|
||||
gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth
|
||||
Models:
|
||||
model: stable-diffusion-1.5
|
||||
embeddings: true
|
||||
Memory/Performance:
|
||||
xformers_enabled: false
|
||||
sequential_guidance: false
|
||||
precision: float16
|
||||
max_loaded_models: 4
|
||||
always_use_cpu: false
|
||||
free_gpu_mem: false
|
||||
Features:
|
||||
nsfw_checker: true
|
||||
restore: true
|
||||
esrgan: true
|
||||
patchmatch: true
|
||||
internet_available: true
|
||||
log_tokenization: false
|
||||
Web Server:
|
||||
host: 127.0.0.1
|
||||
port: 8081
|
||||
allow_origins: []
|
||||
allow_credentials: true
|
||||
allow_methods:
|
||||
- '*'
|
||||
allow_headers:
|
||||
- '*'
|
||||
|
||||
The default name of the configuration file is `invokeai.yaml`, located
|
||||
in INVOKEAI_ROOT. You can replace supersede this by providing any
|
||||
OmegaConf dictionary object initialization time:
|
||||
|
||||
omegaconf = OmegaConf.load('/tmp/init.yaml')
|
||||
conf = InvokeAIAppConfig(conf=omegaconf)
|
||||
|
||||
By default, InvokeAIAppConfig will parse the contents of `sys.argv` at
|
||||
initialization time. You may pass a list of strings in the optional
|
||||
`argv` argument to use instead of the system argv:
|
||||
|
||||
conf = InvokeAIAppConfig(arg=['--xformers_enabled'])
|
||||
|
||||
It is also possible to set a value at initialization time. This value
|
||||
has highest priority.
|
||||
|
||||
conf = InvokeAIAppConfig(xformers_enabled=True)
|
||||
|
||||
Any setting can be overwritten by setting an environment variable of
|
||||
form: "INVOKEAI_<setting>", as in:
|
||||
|
||||
export INVOKEAI_port=8080
|
||||
|
||||
Order of precedence (from highest):
|
||||
1) initialization options
|
||||
2) command line options
|
||||
3) environment variable options
|
||||
4) config file options
|
||||
5) pydantic defaults
|
||||
|
||||
Typical usage:
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.invocations.generate import TextToImageInvocation
|
||||
|
||||
# get global configuration and print its nsfw_checker value
|
||||
conf = InvokeAIAppConfig()
|
||||
print(conf.nsfw_checker)
|
||||
|
||||
# get the text2image invocation and print its step value
|
||||
text2image = TextToImageInvocation()
|
||||
print(text2image.steps)
|
||||
|
||||
Computed properties:
|
||||
|
||||
The InvokeAIAppConfig object has a series of properties that
|
||||
resolve paths relative to the runtime root directory. They each return
|
||||
a Path object:
|
||||
|
||||
root_path - path to InvokeAI root
|
||||
output_path - path to default outputs directory
|
||||
model_conf_path - path to models.yaml
|
||||
conf - alias for the above
|
||||
embedding_path - path to the embeddings directory
|
||||
lora_path - path to the LoRA directory
|
||||
|
||||
In most cases, you will want to create a single InvokeAIAppConfig
|
||||
object for the entire application. The get_invokeai_config() function
|
||||
does this:
|
||||
|
||||
config = get_invokeai_config()
|
||||
print(config.root)
|
||||
|
||||
# Subclassing
|
||||
|
||||
If you wish to create a similar class, please subclass the
|
||||
`InvokeAISettings` class and define a Literal field named "type",
|
||||
which is set to the desired top-level name. For example, to create a
|
||||
"InvokeBatch" configuration, define like this:
|
||||
|
||||
class InvokeBatch(InvokeAISettings):
|
||||
type: Literal["InvokeBatch"] = "InvokeBatch"
|
||||
node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources')
|
||||
cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", category='Resources')
|
||||
|
||||
This will now read and write from the "InvokeBatch" section of the
|
||||
config file, look for environment variables named INVOKEBATCH_*, and
|
||||
accept the command-line arguments `--node_count` and `--cpu_count`. The
|
||||
two configs are kept in separate sections of the config file:
|
||||
|
||||
# invokeai.yaml
|
||||
|
||||
InvokeBatch:
|
||||
Resources:
|
||||
node_count: 1
|
||||
cpu_count: 8
|
||||
|
||||
InvokeAI:
|
||||
Paths:
|
||||
root: /home/lstein/invokeai-main
|
||||
conf_path: configs/models.yaml
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
outdir: outputs
|
||||
...
|
||||
'''
|
||||
import argparse
|
||||
import pydoc
|
||||
import typing
|
||||
import os
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
from pathlib import Path
|
||||
from pydantic import BaseSettings, Field, parse_obj_as
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args
|
||||
|
||||
INIT_FILE = Path('invokeai.yaml')
|
||||
LEGACY_INIT_FILE = Path('invokeai.init')
|
||||
|
||||
# This global stores a singleton InvokeAIAppConfig configuration object
|
||||
global_config = None
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
'''
|
||||
Runtime configuration settings in which default values are
|
||||
read from an omegaconf .yaml file.
|
||||
'''
|
||||
initconf : ClassVar[DictConfig] = None
|
||||
argparse_groups : ClassVar[Dict] = {}
|
||||
|
||||
def parse_args(self, argv: list=sys.argv[1:]):
|
||||
parser = self.get_parser()
|
||||
opt, _ = parser.parse_known_args(argv)
|
||||
for name in self.__fields__:
|
||||
if name not in self._excluded():
|
||||
setattr(self, name, getattr(opt,name))
|
||||
|
||||
def to_yaml(self)->str:
|
||||
"""
|
||||
Return a YAML string representing our settings. This can be used
|
||||
as the contents of `invokeai.yaml` to restore settings later.
|
||||
"""
|
||||
cls = self.__class__
|
||||
type = get_args(get_type_hints(cls)['type'])[0]
|
||||
field_dict = dict({type:dict()})
|
||||
for name,field in self.__fields__.items():
|
||||
if name in cls._excluded():
|
||||
continue
|
||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||
value = getattr(self,name)
|
||||
if category not in field_dict[type]:
|
||||
field_dict[type][category] = dict()
|
||||
# keep paths as strings to make it easier to read
|
||||
field_dict[type][category][name] = str(value) if isinstance(value,Path) else value
|
||||
conf = OmegaConf.create(field_dict)
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
if 'type' in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)['type'])[0]
|
||||
else:
|
||||
settings_stanza = "Uncategorized"
|
||||
|
||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper()
|
||||
|
||||
initconf = cls.initconf.get(settings_stanza) \
|
||||
if cls.initconf and settings_stanza in cls.initconf \
|
||||
else OmegaConf.create()
|
||||
|
||||
# create an upcase version of the environment in
|
||||
# order to achieve case-insensitive environment
|
||||
# variables (the way Windows does)
|
||||
upcase_environ = dict()
|
||||
for key,value in os.environ.items():
|
||||
upcase_environ[key.upper()] = value
|
||||
|
||||
fields = cls.__fields__
|
||||
cls.argparse_groups = {}
|
||||
|
||||
for name, field in fields.items():
|
||||
if name not in cls._excluded():
|
||||
current_default = field.default
|
||||
|
||||
category = field.field_info.extra.get("category","Uncategorized")
|
||||
env_name = env_prefix + '_' + name
|
||||
if category in initconf and name in initconf.get(category):
|
||||
field.default = initconf.get(category).get(name)
|
||||
if env_name.upper() in upcase_environ:
|
||||
field.default = upcase_environ[env_name.upper()]
|
||||
cls.add_field_argument(parser, name, field)
|
||||
|
||||
field.default = current_default
|
||||
|
||||
@classmethod
|
||||
def cmd_name(self, command_field: str='type')->str:
|
||||
hints = get_type_hints(self)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
else:
|
||||
return 'Uncategorized'
|
||||
|
||||
@classmethod
|
||||
def get_parser(cls)->ArgumentParser:
|
||||
parser = PagingArgumentParser(
|
||||
prog=cls.cmd_name(),
|
||||
description=cls.__doc__,
|
||||
)
|
||||
cls.add_parser_arguments(parser)
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def add_subparser(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||
|
||||
@classmethod
|
||||
def _excluded(self)->List[str]:
|
||||
return ['type','initconf']
|
||||
|
||||
class Config:
|
||||
env_file_encoding = 'utf-8'
|
||||
arbitrary_types_allowed = True
|
||||
case_sensitive = True
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
if category := field.field_info.extra.get("category"):
|
||||
if category not in cls.argparse_groups:
|
||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||
argparse_group = cls.argparse_groups[category]
|
||||
else:
|
||||
argparse_group = command_parser
|
||||
|
||||
if get_origin(field_type) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
elif get_origin(field_type) == list:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
nargs='*',
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||
help=field.field_info.description,
|
||||
)
|
||||
def _find_root()->Path:
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
|
||||
elif (
|
||||
os.environ.get("VIRTUAL_ENV")
|
||||
and (Path(os.environ.get("VIRTUAL_ENV"), "..", INIT_FILE).exists()
|
||||
or
|
||||
Path(os.environ.get("VIRTUAL_ENV"), "..", LEGACY_INIT_FILE).exists()
|
||||
)
|
||||
):
|
||||
root = Path(os.environ.get("VIRTUAL_ENV"), "..").resolve()
|
||||
else:
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
return root
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
'''
|
||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||
the command-line client (recommended for experts only), or
|
||||
"invokeai-web" to launch the web server. Global options
|
||||
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
||||
setting environment variables INVOKEAI_<setting>.
|
||||
'''
|
||||
#fmt: off
|
||||
type: Literal["InvokeAI"] = "InvokeAI"
|
||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
|
||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", category='Web Server')
|
||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
|
||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')
|
||||
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
|
||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
|
||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
|
||||
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
|
||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
|
||||
restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
|
||||
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||
max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
|
||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||
|
||||
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
||||
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
|
||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||
embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI textual inversion aembeddings directory', category='Paths')
|
||||
gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.', category='Paths')
|
||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
|
||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||
|
||||
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
||||
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
|
||||
#fmt: on
|
||||
|
||||
def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs):
|
||||
'''
|
||||
Initialize InvokeAIAppconfig.
|
||||
:param conf: alternate Omegaconf dictionary object
|
||||
:param argv: aternate sys.argv list
|
||||
:param **kwargs: attributes to initialize with
|
||||
'''
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set the runtime root directory. We parse command-line switches here
|
||||
# in order to pick up the --root_dir option.
|
||||
self.parse_args(argv)
|
||||
if conf is None:
|
||||
try:
|
||||
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||
except:
|
||||
pass
|
||||
InvokeAISettings.initconf = conf
|
||||
|
||||
# parse args again in order to pick up settings in configuration file
|
||||
self.parse_args(argv)
|
||||
|
||||
# restore initialization values
|
||||
hints = get_type_hints(self)
|
||||
for k in kwargs:
|
||||
setattr(self,k,parse_obj_as(hints[k],kwargs[k]))
|
||||
|
||||
@property
|
||||
def root_path(self)->Path:
|
||||
'''
|
||||
Path to the runtime root directory
|
||||
'''
|
||||
if self.root:
|
||||
return Path(self.root).expanduser()
|
||||
else:
|
||||
return self.find_root()
|
||||
|
||||
@property
|
||||
def root_dir(self)->Path:
|
||||
'''
|
||||
Alias for above.
|
||||
'''
|
||||
return self.root_path
|
||||
|
||||
def _resolve(self,partial_path:Path)->Path:
|
||||
return (self.root_path / partial_path).resolve()
|
||||
|
||||
@property
|
||||
def output_path(self)->Path:
|
||||
'''
|
||||
Path to defaults outputs directory.
|
||||
'''
|
||||
return self._resolve(self.outdir)
|
||||
|
||||
@property
|
||||
def model_conf_path(self)->Path:
|
||||
'''
|
||||
Path to models configuration file.
|
||||
'''
|
||||
return self._resolve(self.conf_path)
|
||||
|
||||
@property
|
||||
def legacy_conf_path(self)->Path:
|
||||
'''
|
||||
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
|
||||
'''
|
||||
return self._resolve(self.legacy_conf_dir)
|
||||
|
||||
@property
|
||||
def cache_dir(self)->Path:
|
||||
'''
|
||||
Path to the global cache directory for HuggingFace hub-managed models
|
||||
'''
|
||||
return self.models_dir / "hub"
|
||||
|
||||
@property
|
||||
def models_dir(self)->Path:
|
||||
'''
|
||||
Path to the models directory
|
||||
'''
|
||||
return self._resolve("models")
|
||||
|
||||
@property
|
||||
def converted_ckpts_dir(self)->Path:
|
||||
'''
|
||||
Path to the converted models
|
||||
'''
|
||||
return self._resolve("models/converted_ckpts")
|
||||
|
||||
@property
|
||||
def embedding_path(self)->Path:
|
||||
'''
|
||||
Path to the textual inversion embeddings directory.
|
||||
'''
|
||||
return self._resolve(self.embedding_dir) if self.embedding_dir else None
|
||||
|
||||
@property
|
||||
def lora_path(self)->Path:
|
||||
'''
|
||||
Path to the LoRA models directory.
|
||||
'''
|
||||
return self._resolve(self.lora_dir) if self.lora_dir else None
|
||||
|
||||
@property
|
||||
def autoconvert_path(self)->Path:
|
||||
'''
|
||||
Path to the directory containing models to be imported automatically at startup.
|
||||
'''
|
||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||
|
||||
@property
|
||||
def gfpgan_model_path(self)->Path:
|
||||
'''
|
||||
Path to the GFPGAN model.
|
||||
'''
|
||||
return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None
|
||||
|
||||
# the following methods support legacy calls leftover from the Globals era
|
||||
@property
|
||||
def full_precision(self)->bool:
|
||||
"""Return true if precision set to float32"""
|
||||
return self.precision=='float32'
|
||||
|
||||
@property
|
||||
def disable_xformers(self)->bool:
|
||||
"""Return true if xformers_enabled is false"""
|
||||
return not self.xformers_enabled
|
||||
|
||||
@property
|
||||
def try_patchmatch(self)->bool:
|
||||
"""Return true if patchmatch true"""
|
||||
return self.patchmatch
|
||||
|
||||
@staticmethod
|
||||
def find_root()->Path:
|
||||
'''
|
||||
Choose the runtime root directory when not specified on command line or
|
||||
init file.
|
||||
'''
|
||||
return _find_root()
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
'''
|
||||
A custom ArgumentParser that uses pydoc to page its output.
|
||||
It also supports reading defaults from an init file.
|
||||
'''
|
||||
def print_help(self, file=None):
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAISettings:
|
||||
'''
|
||||
This returns a singleton InvokeAIAppConfig configuration object.
|
||||
'''
|
||||
global global_config
|
||||
if global_config is None or type(global_config)!=cls:
|
||||
global_config = cls(**kwargs)
|
||||
return global_config
|
@ -135,6 +135,7 @@ class GraphInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class GraphInvocation(BaseInvocation):
|
||||
"""Execute a graph"""
|
||||
type: Literal["graph"] = "graph"
|
||||
|
||||
# TODO: figure out how to create a default here
|
||||
@ -162,6 +163,7 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
class IterateInvocation(BaseInvocation):
|
||||
"""Iterates over a list of items"""
|
||||
type: Literal["iterate"] = "iterate"
|
||||
|
||||
collection: list[Any] = Field(
|
||||
|
@ -10,6 +10,7 @@ from .image_storage import ImageStorageBase
|
||||
from .restoration_services import RestorationServices
|
||||
from .invocation_queue import InvocationQueueABC
|
||||
from .item_storage import ItemStorageABC
|
||||
from .config import InvokeAISettings
|
||||
|
||||
class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
@ -21,7 +22,8 @@ class InvocationServices:
|
||||
queue: InvocationQueueABC
|
||||
model_manager: ModelManager
|
||||
restoration: RestorationServices
|
||||
|
||||
configuration: InvokeAISettings
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_library: ItemStorageABC["LibraryGraph"]
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
@ -40,6 +42,7 @@ class InvocationServices:
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: RestorationServices,
|
||||
configuration: InvokeAISettings=None,
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.events = events
|
||||
@ -52,3 +55,4 @@ class InvocationServices:
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
self.configuration = configuration
|
||||
|
@ -1,7 +1,6 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend
|
||||
"""
|
||||
from .generate import Generate
|
||||
from .generator import (
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGenerator,
|
||||
@ -12,5 +11,3 @@ from .generator import (
|
||||
)
|
||||
from .model_management import ModelManager, ModelCache, SDModelType, SDModelInfo
|
||||
from .safety_checker import SafetyChecker
|
||||
from .args import Args
|
||||
from .globals import Globals
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -19,10 +19,10 @@ import warnings
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import get_type_hints
|
||||
from urllib import request
|
||||
|
||||
import npyscreen
|
||||
import torch
|
||||
import transformers
|
||||
from diffusers import AutoencoderKL
|
||||
from huggingface_hub import HfFolder
|
||||
@ -38,34 +38,40 @@ from transformers import (
|
||||
|
||||
import invokeai.configs as configs
|
||||
|
||||
from ...frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from ...frontend.install.widgets import (
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from invokeai.frontend.install.widgets import (
|
||||
CenteredButtonPress,
|
||||
IntTitleSlider,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
from ..args import PRECISION_CHOICES, Args
|
||||
from ..globals import Globals, global_cache_dir, global_config_dir, global_config_file
|
||||
from .model_install_backend import (
|
||||
from invokeai.backend.config.legacy_arg_parsing import legacy_parser
|
||||
from invokeai.backend.config.model_install_backend import (
|
||||
default_dataset,
|
||||
download_from_hf,
|
||||
hf_download_with_resume,
|
||||
recommended_datasets,
|
||||
)
|
||||
from invokeai.app.services.config import (
|
||||
get_invokeai_config,
|
||||
InvokeAIAppConfig,
|
||||
)
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
config = get_invokeai_config()
|
||||
|
||||
Model_dir = "models"
|
||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
|
||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
|
||||
Default_config_file = Path(global_config_dir()) / "models.yaml"
|
||||
SD_Configs = Path(global_config_dir()) / "stable-diffusion"
|
||||
Default_config_file = config.model_conf_path
|
||||
SD_Configs = config.legacy_conf_path
|
||||
|
||||
Datasets = OmegaConf.load(Dataset_path)
|
||||
|
||||
@ -73,17 +79,12 @@ Datasets = OmegaConf.load(Dataset_path)
|
||||
MIN_COLS = 135
|
||||
MIN_LINES = 45
|
||||
|
||||
PRECISION_CHOICES = ['auto','float16','float32','autocast']
|
||||
|
||||
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# This is the InvokeAI initialization file, which contains command-line default values.
|
||||
# Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
|
||||
# or renaming it and then running invokeai-configure again.
|
||||
# Place frequently-used startup commands here, one or more per line.
|
||||
# Examples:
|
||||
# --outdir=D:\data\images
|
||||
# --no-nsfw_checker
|
||||
# --web --host=0.0.0.0
|
||||
# --steps=20
|
||||
# -Ak_euler_a -C10.0
|
||||
"""
|
||||
|
||||
|
||||
@ -96,14 +97,13 @@ If you installed manually from source or with 'pip install': activate the virtua
|
||||
then run one of the following commands to start InvokeAI.
|
||||
|
||||
Web UI:
|
||||
invokeai --web # (connect to http://localhost:9090)
|
||||
invokeai --web --host 0.0.0.0 # (connect to http://your-lan-ip:9090 from another computer on the local network)
|
||||
invokeai-web
|
||||
|
||||
Command-line interface:
|
||||
Command-line client:
|
||||
invokeai
|
||||
|
||||
If you installed using an installation script, run:
|
||||
{Globals.root}/invoke.{"bat" if sys.platform == "win32" else "sh"}
|
||||
{config.root}/invoke.{"bat" if sys.platform == "win32" else "sh"}
|
||||
|
||||
Add the '--help' argument to see all of the command-line switches available for use.
|
||||
"""
|
||||
@ -216,11 +216,11 @@ def download_realesrgan():
|
||||
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
|
||||
|
||||
model_dest = os.path.join(
|
||||
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
|
||||
config.root, "models/realesrgan/realesr-general-x4v3.pth"
|
||||
)
|
||||
|
||||
wdn_model_dest = os.path.join(
|
||||
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
config.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
)
|
||||
|
||||
download_with_progress_bar(model_url, model_dest, "RealESRGAN")
|
||||
@ -243,7 +243,7 @@ def download_gfpgan():
|
||||
"./models/gfpgan/weights/parsing_parsenet.pth",
|
||||
],
|
||||
):
|
||||
model_url, model_dest = model[0], os.path.join(Globals.root, model[1])
|
||||
model_url, model_dest = model[0], os.path.join(config.root, model[1])
|
||||
download_with_progress_bar(model_url, model_dest, "GFPGAN weights")
|
||||
|
||||
|
||||
@ -253,7 +253,7 @@ def download_codeformer():
|
||||
model_url = (
|
||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
||||
)
|
||||
model_dest = os.path.join(Globals.root, "models/codeformer/codeformer.pth")
|
||||
model_dest = os.path.join(config.root, "models/codeformer/codeformer.pth")
|
||||
download_with_progress_bar(model_url, model_dest, "CodeFormer")
|
||||
|
||||
|
||||
@ -295,7 +295,7 @@ def download_vaes():
|
||||
# first the diffusers version
|
||||
repo_id = "stabilityai/sd-vae-ft-mse"
|
||||
args = dict(
|
||||
cache_dir=global_cache_dir("hub"),
|
||||
cache_dir=config.cache_dir,
|
||||
)
|
||||
if not AutoencoderKL.from_pretrained(repo_id, **args):
|
||||
raise Exception(f"download of {repo_id} failed")
|
||||
@ -306,7 +306,7 @@ def download_vaes():
|
||||
if not hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_name=model_name,
|
||||
model_dir=str(Globals.root / Model_dir / Weights_dir),
|
||||
model_dir=str(config.root / Model_dir / Weights_dir),
|
||||
):
|
||||
raise Exception(f"download of {model_name} failed")
|
||||
except Exception as e:
|
||||
@ -321,8 +321,7 @@ def get_root(root: str = None) -> str:
|
||||
elif os.environ.get("INVOKEAI_ROOT"):
|
||||
return os.environ.get("INVOKEAI_ROOT")
|
||||
else:
|
||||
return Globals.root
|
||||
|
||||
return config.root
|
||||
|
||||
# -------------------------------------
|
||||
class editOptsForm(npyscreen.FormMultiPage):
|
||||
@ -332,7 +331,7 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
def create(self):
|
||||
program_opts = self.parentApp.program_opts
|
||||
old_opts = self.parentApp.invokeai_opts
|
||||
first_time = not (Globals.root / Globals.initfile).exists()
|
||||
first_time = not (config.root / 'invokeai.yaml').exists()
|
||||
access_token = HfFolder.get_token()
|
||||
window_width, window_height = get_terminal_size()
|
||||
for i in [
|
||||
@ -366,7 +365,7 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="(<tab> autocompletes, ctrl-N advances):",
|
||||
value=old_opts.outdir or str(default_output_dir()),
|
||||
value=str(old_opts.outdir) or str(default_output_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
@ -381,17 +380,17 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.safety_checker = self.add_widget_intelligent(
|
||||
self.nsfw_checker = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="NSFW checker",
|
||||
value=old_opts.safety_checker,
|
||||
value=old_opts.nsfw_checker,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
for i in [
|
||||
"If you have an account at HuggingFace you may paste your access token here",
|
||||
'to allow InvokeAI to download styles & subjects from the "Concept Library".',
|
||||
"If you have an account at HuggingFace you may optionally paste your access token here",
|
||||
'to allow InvokeAI to download restricted styles & subjects from the "Concept Library".',
|
||||
"See https://huggingface.co/settings/tokens",
|
||||
]:
|
||||
self.add_widget_intelligent(
|
||||
@ -435,17 +434,10 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.xformers = self.add_widget_intelligent(
|
||||
self.xformers_enabled = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Enable xformers support if available",
|
||||
value=old_opts.xformers,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.ckpt_convert = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Load legacy checkpoint models into memory as diffusers models",
|
||||
value=old_opts.ckpt_convert,
|
||||
value=old_opts.xformers_enabled,
|
||||
relx=5,
|
||||
scroll_exit=True,
|
||||
)
|
||||
@ -480,19 +472,30 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Directory containing embedding/textual inversion files:",
|
||||
value="Directories containing textual inversion and LoRA models (<tab> autocompletes, ctrl-N advances):",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.embedding_path = self.add_widget_intelligent(
|
||||
self.embedding_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="(<tab> autocompletes, ctrl-N advances):",
|
||||
name=" Textual Inversion Embeddings:",
|
||||
value=str(default_embedding_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=40,
|
||||
begin_entry_at=32,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.lora_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name=" LoRA and LyCORIS:",
|
||||
value=str(default_lora_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
@ -559,9 +562,9 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
bad_fields.append(
|
||||
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
|
||||
)
|
||||
if not Path(opt.embedding_path).parent.exists():
|
||||
if not Path(opt.embedding_dir).parent.exists():
|
||||
bad_fields.append(
|
||||
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_path).parent)} is an existing directory."
|
||||
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_dir).parent)} is an existing directory."
|
||||
)
|
||||
if len(bad_fields) > 0:
|
||||
message = "The following problems were detected and must be corrected:\n"
|
||||
@ -576,20 +579,23 @@ class editOptsForm(npyscreen.FormMultiPage):
|
||||
new_opts = Namespace()
|
||||
|
||||
for attr in [
|
||||
"outdir",
|
||||
"safety_checker",
|
||||
"free_gpu_mem",
|
||||
"max_loaded_models",
|
||||
"xformers",
|
||||
"always_use_cpu",
|
||||
"embedding_path",
|
||||
"ckpt_convert",
|
||||
"outdir",
|
||||
"nsfw_checker",
|
||||
"free_gpu_mem",
|
||||
"max_loaded_models",
|
||||
"xformers_enabled",
|
||||
"always_use_cpu",
|
||||
"embedding_dir",
|
||||
"lora_dir",
|
||||
]:
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
|
||||
new_opts.hf_token = self.hf_token.value
|
||||
new_opts.license_acceptance = self.license_acceptance.value
|
||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||
|
||||
# widget library workaround to make max_loaded_models an int rather than a float
|
||||
new_opts.max_loaded_models = int(new_opts.max_loaded_models)
|
||||
|
||||
return new_opts
|
||||
|
||||
@ -628,15 +634,14 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
|
||||
|
||||
|
||||
def default_startup_options(init_file: Path) -> Namespace:
|
||||
opts = Args().parse_args([])
|
||||
opts = InvokeAIAppConfig(argv=[])
|
||||
outdir = Path(opts.outdir)
|
||||
if not outdir.is_absolute():
|
||||
opts.outdir = str(Globals.root / opts.outdir)
|
||||
opts.outdir = str(config.root / opts.outdir)
|
||||
if not init_file.exists():
|
||||
opts.safety_checker = True
|
||||
opts.nsfw_checker = True
|
||||
return opts
|
||||
|
||||
|
||||
def default_user_selections(program_opts: Namespace) -> Namespace:
|
||||
return Namespace(
|
||||
starter_models=default_dataset()
|
||||
@ -690,70 +695,61 @@ def run_console_ui(
|
||||
# -------------------------------------
|
||||
def write_opts(opts: Namespace, init_file: Path):
|
||||
"""
|
||||
Update the invokeai.init file with values from opts Namespace
|
||||
Update the invokeai.yaml file with values from current settings.
|
||||
"""
|
||||
# touch file if it doesn't exist
|
||||
if not init_file.exists():
|
||||
with open(init_file, "w") as f:
|
||||
f.write(INIT_FILE_PREAMBLE)
|
||||
|
||||
# We want to write in the changed arguments without clobbering
|
||||
# any other initialization values the user has entered. There is
|
||||
# no good way to do this because of the one-way nature of
|
||||
# argparse: i.e. --outdir could be --outdir, --out, or -o
|
||||
# initfile needs to be replaced with a fully structured format
|
||||
# such as yaml; this is a hack that will work much of the time
|
||||
args_to_skip = re.compile(
|
||||
"^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|ckpt|free_gpu)"
|
||||
)
|
||||
# fix windows paths
|
||||
opts.outdir = opts.outdir.replace("\\", "/")
|
||||
opts.embedding_path = opts.embedding_path.replace("\\", "/")
|
||||
new_file = f"{init_file}.new"
|
||||
try:
|
||||
lines = [x.strip() for x in open(init_file, "r").readlines()]
|
||||
with open(new_file, "w") as out_file:
|
||||
for line in lines:
|
||||
if len(line) > 0 and not args_to_skip.match(line):
|
||||
out_file.write(line + "\n")
|
||||
out_file.write(
|
||||
f"""
|
||||
--outdir={opts.outdir}
|
||||
--embedding_path={opts.embedding_path}
|
||||
--precision={opts.precision}
|
||||
--max_loaded_models={int(opts.max_loaded_models)}
|
||||
--{'no-' if not opts.safety_checker else ''}nsfw_checker
|
||||
--{'no-' if not opts.xformers else ''}xformers
|
||||
--{'no-' if not opts.ckpt_convert else ''}ckpt_convert
|
||||
{'--free_gpu_mem' if opts.free_gpu_mem else ''}
|
||||
{'--always_use_cpu' if opts.always_use_cpu else ''}
|
||||
"""
|
||||
)
|
||||
except OSError as e:
|
||||
print(f"** An error occurred while writing the init file: {str(e)}")
|
||||
|
||||
os.replace(new_file, init_file)
|
||||
|
||||
if opts.hf_token:
|
||||
HfLogin(opts.hf_token)
|
||||
# this will load current settings
|
||||
config = InvokeAIAppConfig()
|
||||
for key,value in opts.__dict__.items():
|
||||
if hasattr(config,key):
|
||||
setattr(config,key,value)
|
||||
|
||||
with open(init_file,'w', encoding='utf-8') as file:
|
||||
file.write(config.to_yaml())
|
||||
|
||||
# -------------------------------------
|
||||
def default_output_dir() -> Path:
|
||||
return Globals.root / "outputs"
|
||||
|
||||
return config.root / "outputs"
|
||||
|
||||
# -------------------------------------
|
||||
def default_embedding_dir() -> Path:
|
||||
return Globals.root / "embeddings"
|
||||
return config.root / "embeddings"
|
||||
|
||||
# -------------------------------------
|
||||
def default_lora_dir() -> Path:
|
||||
return config.root / "loras"
|
||||
|
||||
# -------------------------------------
|
||||
def write_default_options(program_opts: Namespace, initfile: Path):
|
||||
opt = default_startup_options(initfile)
|
||||
opt.hf_token = HfFolder.get_token()
|
||||
write_opts(opt, initfile)
|
||||
|
||||
# -------------------------------------
|
||||
# Here we bring in
|
||||
# the legacy Args object in order to parse
|
||||
# the old init file and write out the new
|
||||
# yaml format.
|
||||
def migrate_init_file(legacy_format:Path):
|
||||
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
||||
new = InvokeAIAppConfig(conf={})
|
||||
|
||||
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
||||
for attr in fields:
|
||||
if hasattr(old,attr):
|
||||
setattr(new,attr,getattr(old,attr))
|
||||
|
||||
# a few places where the field names have changed and we have to
|
||||
# manually add in the new names/values
|
||||
new.nsfw_checker = old.safety_checker
|
||||
new.xformers_enabled = old.xformers
|
||||
new.conf_path = old.conf
|
||||
new.embedding_dir = old.embedding_path
|
||||
|
||||
invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
|
||||
with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
|
||||
outfile.write(new.to_yaml())
|
||||
|
||||
legacy_format.replace(legacy_format.parent / 'invokeai.init.old')
|
||||
|
||||
# -------------------------------------
|
||||
def main():
|
||||
@ -810,7 +806,8 @@ def main():
|
||||
opt = parser.parse_args()
|
||||
|
||||
# setting a global here
|
||||
Globals.root = Path(os.path.expanduser(get_root(opt.root) or ""))
|
||||
global config
|
||||
config.root = Path(os.path.expanduser(get_root(opt.root) or ""))
|
||||
|
||||
errors = set()
|
||||
|
||||
@ -818,19 +815,26 @@ def main():
|
||||
models_to_download = default_user_selections(opt)
|
||||
|
||||
# We check for to see if the runtime directory is correctly initialized.
|
||||
init_file = Path(Globals.root, Globals.initfile)
|
||||
if not init_file.exists() or not global_config_file().exists():
|
||||
initialize_rootdir(Globals.root, opt.yes_to_all)
|
||||
old_init_file = Path(config.root, 'invokeai.init')
|
||||
new_init_file = Path(config.root, 'invokeai.yaml')
|
||||
if old_init_file.exists() and not new_init_file.exists():
|
||||
print('** Migrating invokeai.init to invokeai.yaml')
|
||||
migrate_init_file(old_init_file)
|
||||
config = get_invokeai_config() # reread defaults
|
||||
|
||||
|
||||
if not config.model_conf_path.exists():
|
||||
initialize_rootdir(config.root, opt.yes_to_all)
|
||||
|
||||
if opt.yes_to_all:
|
||||
write_default_options(opt, init_file)
|
||||
write_default_options(opt, new_init_file)
|
||||
init_options = Namespace(
|
||||
precision="float32" if opt.full_precision else "float16"
|
||||
)
|
||||
else:
|
||||
init_options, models_to_download = run_console_ui(opt, init_file)
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
||||
if init_options:
|
||||
write_opts(init_options, init_file)
|
||||
write_opts(init_options, new_init_file)
|
||||
else:
|
||||
print(
|
||||
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
|
||||
|
390
invokeai/backend/config/legacy_arg_parsing.py
Normal file
390
invokeai/backend/config/legacy_arg_parsing.py
Normal file
@ -0,0 +1,390 @@
|
||||
# Copyright 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
import argparse
|
||||
import shlex
|
||||
from argparse import ArgumentParser
|
||||
|
||||
SAMPLER_CHOICES = [
|
||||
"ddim",
|
||||
"ddpm",
|
||||
"deis",
|
||||
"lms",
|
||||
"pndm",
|
||||
"heun",
|
||||
"heun_k",
|
||||
"euler",
|
||||
"euler_k",
|
||||
"euler_a",
|
||||
"kdpm_2",
|
||||
"kdpm_2_a",
|
||||
"dpmpp_2s",
|
||||
"dpmpp_2m",
|
||||
"dpmpp_2m_k",
|
||||
"unipc",
|
||||
]
|
||||
|
||||
PRECISION_CHOICES = [
|
||||
"auto",
|
||||
"float32",
|
||||
"autocast",
|
||||
"float16",
|
||||
]
|
||||
|
||||
class FileArgumentParser(ArgumentParser):
|
||||
"""
|
||||
Supports reading defaults from an init file.
|
||||
"""
|
||||
def convert_arg_line_to_args(self, arg_line):
|
||||
return shlex.split(arg_line, comments=True)
|
||||
|
||||
|
||||
legacy_parser = FileArgumentParser(
|
||||
description=
|
||||
"""
|
||||
Generate images using Stable Diffusion.
|
||||
Use --web to launch the web interface.
|
||||
Use --from_file to load prompts from a file path or standard input ("-").
|
||||
Otherwise you will be dropped into an interactive command prompt (type -h for help.)
|
||||
Other command-line arguments are defaults that can usually be overridden
|
||||
prompt the command prompt.
|
||||
""",
|
||||
fromfile_prefix_chars='@',
|
||||
)
|
||||
general_group = legacy_parser.add_argument_group('General')
|
||||
model_group = legacy_parser.add_argument_group('Model selection')
|
||||
file_group = legacy_parser.add_argument_group('Input/output')
|
||||
web_server_group = legacy_parser.add_argument_group('Web server')
|
||||
render_group = legacy_parser.add_argument_group('Rendering')
|
||||
postprocessing_group = legacy_parser.add_argument_group('Postprocessing')
|
||||
deprecated_group = legacy_parser.add_argument_group('Deprecated options')
|
||||
|
||||
deprecated_group.add_argument('--laion400m')
|
||||
deprecated_group.add_argument('--weights') # deprecated
|
||||
general_group.add_argument(
|
||||
'--version','-V',
|
||||
action='store_true',
|
||||
help='Print InvokeAI version number'
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--root_dir',
|
||||
default=None,
|
||||
help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--config',
|
||||
'-c',
|
||||
'-config',
|
||||
dest='conf',
|
||||
default='./configs/models.yaml',
|
||||
help='Path to configuration file for alternate models.',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--model',
|
||||
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--weight_dirs',
|
||||
nargs='+',
|
||||
type=str,
|
||||
help='List of one or more directories that will be auto-scanned for new model weights to import',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--png_compression','-z',
|
||||
type=int,
|
||||
default=6,
|
||||
choices=range(0,9),
|
||||
dest='png_compression',
|
||||
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
|
||||
)
|
||||
model_group.add_argument(
|
||||
'-F',
|
||||
'--full_precision',
|
||||
dest='full_precision',
|
||||
action='store_true',
|
||||
help='Deprecated way to set --precision=float32',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--max_loaded_models',
|
||||
dest='max_loaded_models',
|
||||
type=int,
|
||||
default=2,
|
||||
help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--free_gpu_mem',
|
||||
dest='free_gpu_mem',
|
||||
action='store_true',
|
||||
help='Force free gpu memory before final decoding',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--sequential_guidance',
|
||||
dest='sequential_guidance',
|
||||
action='store_true',
|
||||
help="Calculate guidance in serial instead of in parallel, lowering memory requirement "
|
||||
"at the expense of speed",
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--xformers',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='Enable/disable xformers support (default enabled if installed)',
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--always_use_cpu",
|
||||
dest="always_use_cpu",
|
||||
action="store_true",
|
||||
help="Force use of CPU even if GPU is available"
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--precision',
|
||||
dest='precision',
|
||||
type=str,
|
||||
choices=PRECISION_CHOICES,
|
||||
metavar='PRECISION',
|
||||
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
||||
default='auto',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--ckpt_convert',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest='ckpt_convert',
|
||||
default=True,
|
||||
help='Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.'
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--internet',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest='internet_available',
|
||||
default=True,
|
||||
help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--nsfw_checker',
|
||||
'--safety_checker',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest='safety_checker',
|
||||
default=False,
|
||||
help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--autoimport',
|
||||
default=None,
|
||||
type=str,
|
||||
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--autoconvert',
|
||||
default=None,
|
||||
type=str,
|
||||
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--patchmatch',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.',
|
||||
)
|
||||
file_group.add_argument(
|
||||
'--from_file',
|
||||
dest='infile',
|
||||
type=str,
|
||||
help='If specified, load prompts from this file',
|
||||
)
|
||||
file_group.add_argument(
|
||||
'--outdir',
|
||||
'-o',
|
||||
type=str,
|
||||
help='Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs',
|
||||
default='outputs',
|
||||
)
|
||||
file_group.add_argument(
|
||||
'--prompt_as_dir',
|
||||
'-p',
|
||||
action='store_true',
|
||||
help='Place images in subdirectories named after the prompt.',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--fnformat',
|
||||
default='{prefix}.{seed}.png',
|
||||
type=str,
|
||||
help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-s',
|
||||
'--steps',
|
||||
type=int,
|
||||
default=50,
|
||||
help='Number of steps'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-W',
|
||||
'--width',
|
||||
type=int,
|
||||
help='Image width, multiple of 64',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-H',
|
||||
'--height',
|
||||
type=int,
|
||||
help='Image height, multiple of 64',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-C',
|
||||
'--cfg_scale',
|
||||
default=7.5,
|
||||
type=float,
|
||||
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--sampler',
|
||||
'-A',
|
||||
'-m',
|
||||
dest='sampler_name',
|
||||
type=str,
|
||||
choices=SAMPLER_CHOICES,
|
||||
metavar='SAMPLER_NAME',
|
||||
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
||||
default='k_lms',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--log_tokenization',
|
||||
'-t',
|
||||
action='store_true',
|
||||
help='shows how the prompt is split into tokens'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-f',
|
||||
'--strength',
|
||||
type=float,
|
||||
help='img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'-T',
|
||||
'-fit',
|
||||
'--fit',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
|
||||
)
|
||||
|
||||
render_group.add_argument(
|
||||
'--grid',
|
||||
'-g',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help='generate a grid'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--embedding_directory',
|
||||
'--embedding_path',
|
||||
dest='embedding_path',
|
||||
default='embeddings',
|
||||
type=str,
|
||||
help='Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--lora_directory',
|
||||
dest='lora_path',
|
||||
default='loras',
|
||||
type=str,
|
||||
help='Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--embeddings',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='Enable embedding directory (default). Use --no-embeddings to disable.',
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--enable_image_debugging',
|
||||
action='store_true',
|
||||
help='Generates debugging image to display'
|
||||
)
|
||||
render_group.add_argument(
|
||||
'--karras_max',
|
||||
type=int,
|
||||
default=None,
|
||||
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]."
|
||||
)
|
||||
# Restoration related args
|
||||
postprocessing_group.add_argument(
|
||||
'--no_restore',
|
||||
dest='restore',
|
||||
action='store_false',
|
||||
help='Disable face restoration with GFPGAN or codeformer',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--no_upscale',
|
||||
dest='esrgan',
|
||||
action='store_false',
|
||||
help='Disable upscaling with ESRGAN',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--esrgan_bg_tile',
|
||||
type=int,
|
||||
default=400,
|
||||
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--esrgan_denoise_str',
|
||||
type=float,
|
||||
default=0.75,
|
||||
help='esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--gfpgan_model_path',
|
||||
type=str,
|
||||
default='./models/gfpgan/GFPGANv1.4.pth',
|
||||
help='Indicates the path to the GFPGAN model',
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--web',
|
||||
dest='web',
|
||||
action='store_true',
|
||||
help='Start in web server mode.',
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--web_develop',
|
||||
dest='web_develop',
|
||||
action='store_true',
|
||||
help='Start in web server development mode.',
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--web_verbose",
|
||||
action="store_true",
|
||||
help="Enables verbose logging",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--cors",
|
||||
nargs="*",
|
||||
type=str,
|
||||
help="Additional allowed origins, comma-separated",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--host',
|
||||
type=str,
|
||||
default='127.0.0.1',
|
||||
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--port',
|
||||
type=int,
|
||||
default='9090',
|
||||
help='Web server: Port to listen on'
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--certfile',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Web server: Path to certificate file to use for SSL. Use together with --keyfile'
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--keyfile',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Web server: Path to private key file to use for SSL. Use together with --certfile'
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
'--gui',
|
||||
dest='gui',
|
||||
action='store_true',
|
||||
help='Start InvokeAI GUI',
|
||||
)
|
@ -19,13 +19,15 @@ from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
|
||||
from ..globals import Globals, global_cache_dir, global_config_dir
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from ..model_management import ModelManager
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
config = get_invokeai_config()
|
||||
Model_dir = "models"
|
||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
|
||||
@ -47,12 +49,11 @@ Config_preamble = """
|
||||
|
||||
|
||||
def default_config_file():
|
||||
return Path(global_config_dir()) / "models.yaml"
|
||||
return config.model_conf_path
|
||||
|
||||
|
||||
def sd_configs():
|
||||
return Path(global_config_dir()) / "stable-diffusion"
|
||||
|
||||
return config.legacy_conf_path
|
||||
|
||||
def initial_models():
|
||||
global Datasets
|
||||
@ -121,8 +122,9 @@ def install_requested_models(
|
||||
|
||||
if scan_at_startup and scan_directory.is_dir():
|
||||
argument = "--autoconvert"
|
||||
initfile = Path(Globals.root, Globals.initfile)
|
||||
replacement = Path(Globals.root, f"{Globals.initfile}.new")
|
||||
print('** The global initfile is no longer supported; rewrite to support new yaml format **')
|
||||
initfile = Path(config.root, 'invokeai.init')
|
||||
replacement = Path(config.root, f"invokeai.init.new")
|
||||
directory = str(scan_directory).replace("\\", "/")
|
||||
with open(initfile, "r") as input:
|
||||
with open(replacement, "w") as output:
|
||||
@ -150,7 +152,7 @@ def get_root(root: str = None) -> str:
|
||||
elif os.environ.get("INVOKEAI_ROOT"):
|
||||
return os.environ.get("INVOKEAI_ROOT")
|
||||
else:
|
||||
return Globals.root
|
||||
return config.root
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
@ -183,7 +185,7 @@ def all_datasets() -> dict:
|
||||
# look for legacy model.ckpt in models directory and offer to
|
||||
# normalize its name
|
||||
def migrate_models_ckpt():
|
||||
model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
model_path = os.path.join(config.root, Model_dir, Weights_dir)
|
||||
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
|
||||
return
|
||||
new_name = initial_models()["stable-diffusion-1.4"]["file"]
|
||||
@ -228,7 +230,7 @@ def _download_repo_or_file(
|
||||
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
||||
repo_id = mconfig["repo_id"]
|
||||
filename = mconfig["file"]
|
||||
cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
cache_dir = os.path.join(config.root, Model_dir, Weights_dir)
|
||||
return hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_dir=cache_dir,
|
||||
@ -239,9 +241,9 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_from_hf(
|
||||
model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
|
||||
model_class: object, model_name: str, **kwargs
|
||||
):
|
||||
path = global_cache_dir(cache_subdir)
|
||||
path = config.cache_dir
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=path,
|
||||
@ -417,7 +419,7 @@ def new_config_file_contents(
|
||||
stanza["height"] = mod["height"]
|
||||
if "file" in mod:
|
||||
stanza["weights"] = os.path.relpath(
|
||||
successfully_downloaded[model], start=Globals.root
|
||||
successfully_downloaded[model], start=config.root
|
||||
)
|
||||
stanza["config"] = os.path.normpath(
|
||||
os.path.join(sd_configs(), mod["config"])
|
||||
@ -456,7 +458,7 @@ def delete_weights(model_name: str, conf_stanza: dict):
|
||||
|
||||
weights = Path(weights)
|
||||
if not weights.is_absolute():
|
||||
weights = Path(Globals.root) / weights
|
||||
weights = Path(config.root) / weights
|
||||
try:
|
||||
weights.unlink()
|
||||
except OSError as e:
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,126 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.globals defines a small number of global variables that would
|
||||
otherwise have to be passed through long and complex call chains.
|
||||
|
||||
It defines a Namespace object named "Globals" that contains
|
||||
the attributes:
|
||||
|
||||
- root - the root directory under which "models" and "outputs" can be found
|
||||
- initfile - path to the initialization file
|
||||
- try_patchmatch - option to globally disable loading of 'patchmatch' module
|
||||
- always_use_cpu - force use of CPU even if GPU is available
|
||||
"""
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
Globals = Namespace()
|
||||
|
||||
# Where to look for the initialization file and other key components
|
||||
Globals.initfile = "invokeai.init"
|
||||
Globals.models_file = "models.yaml"
|
||||
Globals.models_dir = "models"
|
||||
Globals.config_dir = "configs"
|
||||
Globals.autoscan_dir = "weights"
|
||||
Globals.converted_ckpts_dir = "converted_ckpts"
|
||||
|
||||
# Set the default root directory. This can be overwritten by explicitly
|
||||
# passing the `--root <directory>` argument on the command line.
|
||||
# logic is:
|
||||
# 1) use INVOKEAI_ROOT environment variable (no check for this being a valid directory)
|
||||
# 2) use VIRTUAL_ENV environment variable, with a check for initfile being there
|
||||
# 3) use ~/invokeai
|
||||
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
Globals.root = osp.abspath(os.environ.get("INVOKEAI_ROOT"))
|
||||
elif (
|
||||
os.environ.get("VIRTUAL_ENV")
|
||||
and Path(os.environ.get("VIRTUAL_ENV"), "..", Globals.initfile).exists()
|
||||
):
|
||||
Globals.root = osp.abspath(osp.join(os.environ.get("VIRTUAL_ENV"), ".."))
|
||||
else:
|
||||
Globals.root = osp.abspath(osp.expanduser("~/invokeai"))
|
||||
|
||||
# Try loading patchmatch
|
||||
Globals.try_patchmatch = True
|
||||
|
||||
# Use CPU even if GPU is available (main use case is for debugging MPS issues)
|
||||
Globals.always_use_cpu = False
|
||||
|
||||
# Whether the internet is reachable for dynamic downloads
|
||||
# The CLI will test connectivity at startup time.
|
||||
Globals.internet_available = True
|
||||
|
||||
# Whether to disable xformers
|
||||
Globals.disable_xformers = False
|
||||
|
||||
# Low-memory tradeoff for guidance calculations.
|
||||
Globals.sequential_guidance = False
|
||||
|
||||
# whether we are forcing full precision
|
||||
Globals.full_precision = False
|
||||
|
||||
# whether we should convert ckpt files into diffusers models on the fly
|
||||
Globals.ckpt_convert = True
|
||||
|
||||
# logging tokenization everywhere
|
||||
Globals.log_tokenization = False
|
||||
|
||||
|
||||
def global_config_file() -> Path:
|
||||
return Path(Globals.root, Globals.config_dir, Globals.models_file)
|
||||
|
||||
|
||||
def global_config_dir() -> Path:
|
||||
return Path(Globals.root, Globals.config_dir)
|
||||
|
||||
|
||||
def global_models_dir() -> Path:
|
||||
return Path(Globals.root, Globals.models_dir)
|
||||
|
||||
|
||||
def global_autoscan_dir() -> Path:
|
||||
return Path(Globals.root, Globals.autoscan_dir)
|
||||
|
||||
|
||||
def global_converted_ckpts_dir() -> Path:
|
||||
return Path(global_models_dir(), Globals.converted_ckpts_dir)
|
||||
|
||||
|
||||
def global_set_root(root_dir: Union[str, Path]):
|
||||
Globals.root = root_dir
|
||||
|
||||
def global_resolve_path(path: Union[str,Path]):
|
||||
if path is None:
|
||||
return None
|
||||
return Path(Globals.root,path).resolve()
|
||||
|
||||
def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
|
||||
"""
|
||||
Returns Path to the model cache directory. If a subdirectory
|
||||
is provided, it will be appended to the end of the path, allowing
|
||||
for Hugging Face-style conventions. Currently, Hugging Face has
|
||||
moved all models into the "hub" subfolder, so for any pretrained
|
||||
HF model, use:
|
||||
global_cache_dir('hub')
|
||||
|
||||
The legacy location for transformers used to be global_cache_dir('transformers')
|
||||
and global_cache_dir('diffusers') for diffusers.
|
||||
"""
|
||||
home: str = os.getenv("HF_HOME")
|
||||
|
||||
if home is None:
|
||||
home = os.getenv("XDG_CACHE_HOME")
|
||||
|
||||
if home is not None:
|
||||
# Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in Hugging Face Hub Client Library.
|
||||
# See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
|
||||
home += os.sep + "huggingface"
|
||||
|
||||
if home is not None:
|
||||
return Path(home, subdir)
|
||||
else:
|
||||
return Path(Globals.root, "models", subdir)
|
@ -6,7 +6,7 @@ be suppressed or deferred
|
||||
"""
|
||||
import numpy as np
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
|
||||
class PatchMatch:
|
||||
"""
|
||||
@ -21,9 +21,10 @@ class PatchMatch:
|
||||
|
||||
@classmethod
|
||||
def _load_patch_match(self):
|
||||
config = get_invokeai_config()
|
||||
if self.tried_load:
|
||||
return
|
||||
if Globals.try_patchmatch:
|
||||
if config.try_patchmatch:
|
||||
from patchmatch import patch_match as pm
|
||||
|
||||
if pm.patchmatch_available:
|
||||
|
@ -33,12 +33,11 @@ from PIL import Image, ImageOps
|
||||
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import global_cache_dir
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
|
||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||
CLIPSEG_SIZE = 352
|
||||
|
||||
|
||||
class SegmentedGrayscale(object):
|
||||
def __init__(self, image: Image, heatmap: torch.Tensor):
|
||||
self.heatmap = heatmap
|
||||
@ -84,14 +83,15 @@ class Txt2Mask(object):
|
||||
|
||||
def __init__(self, device="cpu", refined=False):
|
||||
logger.info("Initializing clipseg model for text to mask inference")
|
||||
config = get_invokeai_config()
|
||||
|
||||
# BUG: we are not doing anything with the device option at this time
|
||||
self.device = device
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
|
||||
CLIPSEG_MODEL, cache_dir=config.cache_dir
|
||||
)
|
||||
self.model = CLIPSegForImageSegmentation.from_pretrained(
|
||||
CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
|
||||
CLIPSEG_MODEL, cache_dir=config.cache_dir
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
|
@ -26,7 +26,7 @@ import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import global_cache_dir, global_config_dir
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
|
||||
from .model_manager import ModelManager, SDLegacyType
|
||||
from .model_cache import ModelCache
|
||||
@ -76,7 +76,6 @@ from transformers import (
|
||||
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
|
||||
|
||||
def shave_segments(path, n_shave_prefix_segments=1):
|
||||
"""
|
||||
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
||||
@ -858,7 +857,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
|
||||
|
||||
def convert_ldm_clip_checkpoint(checkpoint):
|
||||
text_model = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", cache_dir=global_cache_dir("hub")
|
||||
"openai/clip-vit-large-patch14", cache_dir=get_invokeai_config().cache_dir
|
||||
)
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
@ -913,7 +912,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
|
||||
|
||||
def convert_paint_by_example_checkpoint(checkpoint):
|
||||
cache_dir = global_cache_dir("hub")
|
||||
cache_dir = get_invokeai_config().cache_dir
|
||||
config = CLIPVisionConfig.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", cache_dir=cache_dir
|
||||
)
|
||||
@ -985,7 +984,7 @@ def convert_paint_by_example_checkpoint(checkpoint):
|
||||
|
||||
|
||||
def convert_open_clip_checkpoint(checkpoint):
|
||||
cache_dir = global_cache_dir("hub")
|
||||
cache_dir = get_invokeai_config().cache_dir
|
||||
text_model = CLIPTextModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
|
||||
)
|
||||
@ -1121,7 +1120,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
:param vae: A diffusers VAE to load into the pipeline.
|
||||
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
|
||||
"""
|
||||
|
||||
config = get_invokeai_config()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
@ -1134,7 +1133,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
else:
|
||||
checkpoint = load_file(checkpoint_path)
|
||||
|
||||
cache_dir = global_cache_dir("hub")
|
||||
cache_dir = config.cache_dir
|
||||
pipeline_class = (
|
||||
StableDiffusionGeneratorPipeline
|
||||
if return_generator_pipeline
|
||||
@ -1158,25 +1157,23 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
|
||||
if model_type == SDLegacyType.V2_v:
|
||||
original_config_file = (
|
||||
global_config_dir() / "stable-diffusion" / "v2-inference-v.yaml"
|
||||
config.legacy_conf_path / "v2-inference-v.yaml"
|
||||
)
|
||||
if global_step == 110000:
|
||||
# v2.1 needs to upcast attention
|
||||
upcast_attention = True
|
||||
elif model_type == SDLegacyType.V2_e:
|
||||
original_config_file = (
|
||||
global_config_dir() / "stable-diffusion" / "v2-inference.yaml"
|
||||
config.legacy_conf_path / "v2-inference.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V1_INPAINT:
|
||||
original_config_file = (
|
||||
global_config_dir()
|
||||
/ "stable-diffusion"
|
||||
/ "v1-inpainting-inference.yaml"
|
||||
config.legacy_conf_path / "v1-inpainting-inference.yaml"
|
||||
)
|
||||
|
||||
elif model_type == SDLegacyType.V1:
|
||||
original_config_file = (
|
||||
global_config_dir() / "stable-diffusion" / "v1-inference.yaml"
|
||||
config.legacy_conf_path / "v1-inference.yaml"
|
||||
)
|
||||
|
||||
else:
|
||||
@ -1323,7 +1320,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker",
|
||||
cache_dir=global_cache_dir("hub"),
|
||||
cache_dir=config.cache_dir,
|
||||
)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir
|
||||
|
@ -25,27 +25,25 @@ import warnings
|
||||
from contextlib import suppress
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, Sequence, Union, Tuple, types, Optional, List, Type, Any
|
||||
from typing import Dict, Sequence, Union, types, Optional, List, Type, Any
|
||||
|
||||
import torch
|
||||
import safetensors.torch
|
||||
|
||||
|
||||
from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
|
||||
from diffusers import logging as diffusers_logging
|
||||
from huggingface_hub import HfApi, scan_cache_dir
|
||||
from picklescan.scanner import scan_file_path
|
||||
from pydantic import BaseModel
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ..globals import global_cache_dir
|
||||
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
|
||||
def get_model_path(repo_id_or_path: str):
|
||||
globals = get_invokeai_config()
|
||||
|
||||
if os.path.exists(repo_id_or_path):
|
||||
return repo_id_or_path
|
||||
|
||||
cache = scan_cache_dir(global_cache_dir("hub"))
|
||||
cache = scan_cache_dir(globals.cache_dir)
|
||||
for repo in cache.repos:
|
||||
if repo.repo_id != repo_id_or_path:
|
||||
continue
|
||||
@ -234,7 +232,7 @@ class DiffusersModelInfo(ModelInfoBase):
|
||||
model = self.child_types[child_type].from_pretrained(
|
||||
self.repo_id_or_path,
|
||||
subfolder=child_type.value,
|
||||
cache_dir=global_cache_dir('hub'),
|
||||
cache_dir=get_invokeai_config.cache_dir('hub'),
|
||||
torch_dtype=torch_dtype,
|
||||
variant=variant,
|
||||
)
|
||||
@ -248,7 +246,7 @@ class DiffusersModelInfo(ModelInfoBase):
|
||||
return model
|
||||
|
||||
|
||||
def get_pipeline(self, **kwrags):
|
||||
def get_pipeline(self, **kwargs):
|
||||
return DiffusionPipeline.from_pretrained(
|
||||
self.repo_id_or_path,
|
||||
**kwargs,
|
||||
@ -349,7 +347,7 @@ class ClassifierModelInfo(ModelInfoBase):
|
||||
model = self.child_types[child_type].from_pretrained(
|
||||
self.repo_id_or_path,
|
||||
subfolder=child_type.value,
|
||||
cache_dir=global_cache_dir('hub'),
|
||||
cache_dir=get_invokeai_config().cache_dir('hub'),
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
# calc more accurate size
|
||||
@ -394,7 +392,7 @@ class VaeModelInfo(ModelInfoBase):
|
||||
|
||||
model = self.vae_type.from_pretrained(
|
||||
self.repo_id_or_path,
|
||||
cache_dir=global_cache_dir('hub'),
|
||||
cache_dir=get_invokeai_config().cache_dir('hub'),
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
# calc more accurate size
|
||||
|
@ -149,8 +149,7 @@ from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import (Globals, global_cache_dir,
|
||||
global_resolve_path)
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from invokeai.backend.util import download_with_resume
|
||||
|
||||
from ..util import CUDA_DEVICE
|
||||
@ -226,7 +225,8 @@ class ModelManager(object):
|
||||
|
||||
# check config version number and update on disk/RAM if necessary
|
||||
self._update_config_file_version()
|
||||
|
||||
self.globals = get_invokeai_config()
|
||||
self.logger = logger
|
||||
self.cache = ModelCache(
|
||||
max_cache_size=max_cache_size,
|
||||
execution_device = device_type,
|
||||
@ -235,7 +235,6 @@ class ModelManager(object):
|
||||
logger = logger,
|
||||
)
|
||||
self.cache_keys = dict()
|
||||
self.logger = logger
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
@ -304,12 +303,6 @@ class ModelManager(object):
|
||||
# raises an InvalidModelError
|
||||
|
||||
"""
|
||||
|
||||
# Commented-out workaround for callers that use "type/name" as the model name
|
||||
# because they haven't adjusted to the new return format of `list_models()`
|
||||
# if "/" in model_name:
|
||||
# model_key = model_name
|
||||
# else:
|
||||
model_key = self.create_key(model_name, model_type)
|
||||
if model_key not in self.config:
|
||||
raise InvalidModelError(
|
||||
@ -326,13 +319,15 @@ class ModelManager(object):
|
||||
if mconfig.format in ["ckpt", "safetensors"]:
|
||||
location = self.convert_ckpt_and_cache(mconfig)
|
||||
else:
|
||||
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
|
||||
location = self.globals.root_dir / mconfig.get('path') or mconfig.get('repo_id')
|
||||
elif p := mconfig.get('path'):
|
||||
location = self.globals.root_dir / p
|
||||
elif r := mconfig.get('repo_id'):
|
||||
location = r
|
||||
elif w := mconfig.get('weights'):
|
||||
location = self.globals.root_dir / w
|
||||
else:
|
||||
location = global_resolve_path(
|
||||
mconfig.get('path')) \
|
||||
or mconfig.get('repo_id') \
|
||||
or global_resolve_path(mconfig.get('weights')
|
||||
)
|
||||
location = None
|
||||
|
||||
revision = mconfig.get('revision')
|
||||
hash = self.cache.model_hash(location, revision)
|
||||
@ -423,7 +418,7 @@ class ModelManager(object):
|
||||
"""
|
||||
# if we are converting legacy files automatically, then
|
||||
# there are no legacy ckpts!
|
||||
if Globals.ckpt_convert:
|
||||
if self.globals.ckpt_convert:
|
||||
return False
|
||||
info = self.model_info(model_name, model_type)
|
||||
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
|
||||
@ -862,25 +857,16 @@ class ModelManager(object):
|
||||
model_type = self.probe_model_type(checkpoint)
|
||||
if model_type == SDLegacyType.V1:
|
||||
self.logger.debug("SD-v1 model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
model_config_file = self.globals.legacy_conf_path / "v1-inference.yaml"
|
||||
elif model_type == SDLegacyType.V1_INPAINT:
|
||||
self.logger.debug("SD-v1 inpainting model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root,
|
||||
"configs/stable-diffusion/v1-inpainting-inference.yaml",
|
||||
)
|
||||
model_config_file = self.globals.legacy_conf_path / "v1-inpainting-inference.yaml",
|
||||
elif model_type == SDLegacyType.V2_v:
|
||||
self.logger.debug("SD-v2-v model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
||||
)
|
||||
model_config_file = self.globals.legacy_conf_path / "v2-inference-v.yaml"
|
||||
elif model_type == SDLegacyType.V2_e:
|
||||
self.logger.debug("SD-v2-e model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
||||
)
|
||||
model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml"
|
||||
elif model_type == SDLegacyType.V2:
|
||||
self.logger.warning(
|
||||
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
||||
@ -907,9 +893,7 @@ class ModelManager(object):
|
||||
self.logger.debug(f"Using VAE file {vae_path.name}")
|
||||
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
|
||||
|
||||
diffuser_path = Path(
|
||||
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
|
||||
)
|
||||
diffuser_path = self.globals.converted_ckpts_dir / model_path.stem
|
||||
with SilenceWarnings():
|
||||
model_name = self.convert_and_import(
|
||||
model_path,
|
||||
@ -930,9 +914,9 @@ class ModelManager(object):
|
||||
diffusers, cache it to disk, and return Path to converted
|
||||
file. If already on disk then just returns Path.
|
||||
"""
|
||||
weights = global_resolve_path(mconfig.weights)
|
||||
config_file = global_resolve_path(mconfig.config)
|
||||
diffusers_path = global_resolve_path(Path('models',Globals.converted_ckpts_dir)) / weights.stem
|
||||
weights = self.globals.root_dir / mconfig.weights
|
||||
config_file = self.globals.root_dir / mconfig.config
|
||||
diffusers_path = self.globals.converted_ckpts_dir / weights.stem
|
||||
|
||||
# return cached version if it exists
|
||||
if diffusers_path.exists():
|
||||
@ -949,7 +933,7 @@ class ModelManager(object):
|
||||
extract_ema=True,
|
||||
original_config_file=config_file,
|
||||
vae=vae_model,
|
||||
vae_path=str(global_resolve_path(vae_ckpt_path)) if vae_ckpt_path else None,
|
||||
vae_path=str(self.globals.root_dir / vae_ckpt_path) if vae_ckpt_path else None,
|
||||
scan_needed=True,
|
||||
)
|
||||
return diffusers_path
|
||||
@ -960,9 +944,10 @@ class ModelManager(object):
|
||||
object, cache it to disk, and return Path to converted
|
||||
file. If already on disk then just returns Path.
|
||||
"""
|
||||
weights_file = global_resolve_path(mconfig.weights)
|
||||
config_file = global_resolve_path(mconfig.config)
|
||||
diffusers_path = global_resolve_path(Path('models',Globals.converted_ckpts_dir)) / weights_file.stem
|
||||
root = self.globals.root_dir
|
||||
weights_file = root / mconfig.weights
|
||||
config_file = root / mconfig.config
|
||||
diffusers_path = self.globals.converted_ckpts_dir / weights_file.stem
|
||||
image_size = mconfig.get('width') or mconfig.get('height') or 512
|
||||
|
||||
# return cached version if it exists
|
||||
@ -1018,7 +1003,9 @@ class ModelManager(object):
|
||||
|
||||
# 3. If mconfig has a vae dict, then we use it as the diffusers-style vae
|
||||
if vae_config and isinstance(vae_config,DictConfig):
|
||||
vae_diffusers_location = global_resolve_path(vae_config.get('path')) or vae_config.get('repo_id')
|
||||
vae_diffusers_location = self.globals.root_dir / vae_config.get('path') \
|
||||
if vae_config.get('path') \
|
||||
else vae_config.get('repo_id')
|
||||
|
||||
# 4. Otherwise, we use stabilityai/sd-vae-ft-mse "because it works"
|
||||
else:
|
||||
@ -1072,7 +1059,9 @@ class ModelManager(object):
|
||||
# will be built into the model rather than tacked on afterward via the config file
|
||||
vae_model = None
|
||||
if vae:
|
||||
vae_location = global_resolve_path(vae.get('path')) or vae.get('repo_id')
|
||||
vae_location = self.globals.root_dir / vae.get('path') \
|
||||
if vae.get('path') \
|
||||
else vae.get('repo_id')
|
||||
vae_model = self.cache.get_model(vae_location, SDModelType.Vae).model
|
||||
vae_path = None
|
||||
convert_ckpt_to_diffusers(
|
||||
@ -1140,6 +1129,7 @@ class ModelManager(object):
|
||||
yaml_str = OmegaConf.to_yaml(self.config)
|
||||
config_file_path = conf_file or self.config_path
|
||||
assert config_file_path is not None,'no config file path to write to'
|
||||
config_file_path = self.globals.root_dir / config_file_path
|
||||
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
||||
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(self.preamble())
|
||||
@ -1160,7 +1150,7 @@ class ModelManager(object):
|
||||
|
||||
@classmethod
|
||||
def _delete_model_from_cache(cls,repo_id):
|
||||
cache_info = scan_cache_dir(global_cache_dir("hub"))
|
||||
cache_info = scan_cache_dir(get_invokeai_config().cache_dir)
|
||||
|
||||
# I'm sure there is a way to do this with comprehensions
|
||||
# but the code quickly became incomprehensible!
|
||||
@ -1177,9 +1167,10 @@ class ModelManager(object):
|
||||
|
||||
@staticmethod
|
||||
def _abs_path(path: str | Path) -> Path:
|
||||
globals = get_invokeai_config()
|
||||
if path is None or Path(path).is_absolute():
|
||||
return path
|
||||
return Path(Globals.root, path).resolve()
|
||||
return Path(globals.root_dir, path).resolve()
|
||||
|
||||
# This is not the same as global_resolve_path(), which prepends
|
||||
# Globals.root.
|
||||
@ -1188,15 +1179,11 @@ class ModelManager(object):
|
||||
) -> Optional[Path]:
|
||||
resolved_path = None
|
||||
if str(source).startswith(("http:", "https:", "ftp:")):
|
||||
dest_directory = Path(dest_directory)
|
||||
if not dest_directory.is_absolute():
|
||||
dest_directory = Globals.root / dest_directory
|
||||
dest_directory = self.globals.root_dir / dest_directory
|
||||
dest_directory.mkdir(parents=True, exist_ok=True)
|
||||
resolved_path = download_with_resume(str(source), dest_directory)
|
||||
else:
|
||||
if not os.path.isabs(source):
|
||||
source = os.path.join(Globals.root, source)
|
||||
resolved_path = Path(source)
|
||||
resolved_path = self.globals.root_dir / source
|
||||
return resolved_path
|
||||
|
||||
def _update_config_file_version(self):
|
||||
|
@ -17,67 +17,59 @@ from compel.prompt_parser import (
|
||||
FlattenedPrompt,
|
||||
Fragment,
|
||||
PromptParser,
|
||||
Conjunction,
|
||||
)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||
from ..util import torch_dtype
|
||||
|
||||
|
||||
def get_uc_and_c_and_ec(
|
||||
prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
|
||||
):
|
||||
def get_uc_and_c_and_ec(prompt_string,
|
||||
model: InvokeAIDiffuserComponent,
|
||||
log_tokens=False, skip_normalize_legacy_blend=False):
|
||||
# lazy-load any deferred textual inversions.
|
||||
# this might take a couple of seconds the first time a textual inversion is used.
|
||||
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
|
||||
prompt_string
|
||||
)
|
||||
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
|
||||
|
||||
tokenizer = model.tokenizer
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=model.text_encoder,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False
|
||||
)
|
||||
compel = Compel(tokenizer=model.tokenizer,
|
||||
text_encoder=model.text_encoder,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
|
||||
config = get_invokeai_config()
|
||||
|
||||
# get rid of any newline characters
|
||||
prompt_string = prompt_string.replace("\n", " ")
|
||||
(
|
||||
positive_prompt_string,
|
||||
negative_prompt_string,
|
||||
) = split_prompt_to_positive_and_negative(prompt_string)
|
||||
legacy_blend = try_parse_legacy_blend(
|
||||
positive_prompt_string, skip_normalize_legacy_blend
|
||||
)
|
||||
positive_prompt: Union[FlattenedPrompt, Blend]
|
||||
if legacy_blend is not None:
|
||||
positive_prompt = legacy_blend
|
||||
else:
|
||||
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
|
||||
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
|
||||
negative_prompt_string
|
||||
)
|
||||
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
|
||||
|
||||
if log_tokens or getattr(Globals, "log_tokenization", False):
|
||||
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
|
||||
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
|
||||
positive_conjunction: Conjunction
|
||||
if legacy_blend is not None:
|
||||
positive_conjunction = legacy_blend
|
||||
else:
|
||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
|
||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
||||
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
|
||||
|
||||
tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
|
||||
if log_tokens or config.log_tokenization:
|
||||
log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
|
||||
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
|
||||
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
|
||||
tokens_count = get_max_token_count(tokenizer, positive_prompt)
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=tokens_count,
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
|
||||
cross_attention_control_args=options.get(
|
||||
'cross_attention_control', None))
|
||||
return uc, c, ec
|
||||
|
||||
|
||||
def get_prompt_structure(
|
||||
prompt_string, skip_normalize_legacy_blend: bool = False
|
||||
) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
|
||||
@ -88,18 +80,17 @@ def get_prompt_structure(
|
||||
legacy_blend = try_parse_legacy_blend(
|
||||
positive_prompt_string, skip_normalize_legacy_blend
|
||||
)
|
||||
positive_prompt: Union[FlattenedPrompt, Blend]
|
||||
positive_prompt: Conjunction
|
||||
if legacy_blend is not None:
|
||||
positive_prompt = legacy_blend
|
||||
positive_conjunction = legacy_blend
|
||||
else:
|
||||
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
|
||||
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
|
||||
negative_prompt_string
|
||||
)
|
||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
||||
negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
|
||||
|
||||
return positive_prompt, negative_prompt
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
|
||||
) -> int:
|
||||
@ -246,22 +237,21 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
|
||||
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||
logger.debug(f"{discarded}\x1b[0m")
|
||||
|
||||
|
||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
|
||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
|
||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
|
||||
if len(weighted_subprompts) <= 1:
|
||||
return None
|
||||
strings = [x[0] for x in weighted_subprompts]
|
||||
weights = [x[1] for x in weighted_subprompts]
|
||||
|
||||
pp = PromptParser()
|
||||
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
|
||||
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
|
||||
|
||||
return Blend(
|
||||
prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize
|
||||
)
|
||||
|
||||
flattened_prompts = []
|
||||
weights = []
|
||||
for i, x in enumerate(parsed_conjunctions):
|
||||
if len(x.prompts)>0:
|
||||
flattened_prompts.append(x.prompts[0])
|
||||
weights.append(weighted_subprompts[i][1])
|
||||
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
||||
"""
|
||||
|
@ -6,7 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from ..globals import Globals
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
|
||||
pretrained_model_url = (
|
||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
||||
@ -17,11 +17,11 @@ class CodeFormerRestoration:
|
||||
def __init__(
|
||||
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
|
||||
) -> None:
|
||||
if not os.path.isabs(codeformer_dir):
|
||||
codeformer_dir = os.path.join(Globals.root, codeformer_dir)
|
||||
|
||||
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
|
||||
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
||||
self.globals = get_invokeai_config()
|
||||
codeformer_dir = self.globals.root_dir / codeformer_dir
|
||||
self.model_path = codeformer_dir / codeformer_model_path
|
||||
self.codeformer_model_exists = self.model_path.exists()
|
||||
|
||||
if not self.codeformer_model_exists:
|
||||
logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
|
||||
@ -71,9 +71,7 @@ class CodeFormerRestoration:
|
||||
upscale_factor=1,
|
||||
use_parse=True,
|
||||
device=device,
|
||||
model_rootpath=os.path.join(
|
||||
Globals.root, "models", "gfpgan", "weights"
|
||||
),
|
||||
model_rootpath = self.globals.root_dir / "gfpgan" / "weights"
|
||||
)
|
||||
face_helper.clean_all()
|
||||
face_helper.read_image(bgr_image_array)
|
||||
|
@ -7,14 +7,13 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
|
||||
class GFPGAN:
|
||||
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
||||
self.globals = get_invokeai_config()
|
||||
if not os.path.isabs(gfpgan_model_path):
|
||||
gfpgan_model_path = os.path.abspath(
|
||||
os.path.join(Globals.root, gfpgan_model_path)
|
||||
)
|
||||
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
|
||||
self.model_path = gfpgan_model_path
|
||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
||||
|
||||
@ -33,7 +32,7 @@ class GFPGAN:
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
cwd = os.getcwd()
|
||||
os.chdir(os.path.join(Globals.root, "models"))
|
||||
os.chdir(self.globals.root_dir / 'models')
|
||||
try:
|
||||
from gfpgan import GFPGANer
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -7,7 +6,8 @@ from PIL import Image
|
||||
from PIL.Image import Image as ImageType
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
config = get_invokeai_config()
|
||||
|
||||
class ESRGAN:
|
||||
def __init__(self, bg_tile_size=400) -> None:
|
||||
@ -30,12 +30,8 @@ class ESRGAN:
|
||||
upscale=4,
|
||||
act_type="prelu",
|
||||
)
|
||||
model_path = os.path.join(
|
||||
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
|
||||
)
|
||||
wdn_model_path = os.path.join(
|
||||
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
)
|
||||
model_path = config.root_dir / "models/realesrgan/realesr-general-x4v3.pth"
|
||||
wdn_model_path = config.root_dir / "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||
scale = 4
|
||||
|
||||
bg_upsampler = RealESRGANer(
|
||||
|
@ -15,7 +15,7 @@ from transformers import AutoFeatureExtractor
|
||||
|
||||
import invokeai.assets.web as web_assets
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .globals import global_cache_dir
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from .util import CPU_DEVICE
|
||||
|
||||
class SafetyChecker(object):
|
||||
@ -26,10 +26,11 @@ class SafetyChecker(object):
|
||||
caution = Image.open(path)
|
||||
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
||||
self.device = device
|
||||
|
||||
config = get_invokeai_config()
|
||||
|
||||
try:
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_model_path = global_cache_dir("hub")
|
||||
safety_model_path = config.cache_dir
|
||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
safety_model_id,
|
||||
local_files_only=True,
|
||||
|
@ -18,15 +18,15 @@ from huggingface_hub import (
|
||||
)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
|
||||
class HuggingFaceConceptsLibrary(object):
|
||||
def __init__(self, root=None):
|
||||
"""
|
||||
Initialize the Concepts object. May optionally pass a root directory.
|
||||
"""
|
||||
self.root = root or Globals.root
|
||||
self.config = get_invokeai_config()
|
||||
self.root = root or self.config.root
|
||||
self.hf_api = HfApi()
|
||||
self.local_concepts = dict()
|
||||
self.concept_list = None
|
||||
@ -58,7 +58,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
return self.concept_list
|
||||
return self.concept_list
|
||||
elif Globals.internet_available is True:
|
||||
elif self.config.internet_available is True:
|
||||
try:
|
||||
models = self.hf_api.list_models(
|
||||
filter=ModelFilter(model_name="sd-concepts-library/")
|
||||
|
@ -33,8 +33,7 @@ from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from ..util import CPU_DEVICE, normalize_device
|
||||
from .diffusion import (
|
||||
AttentionMapSaver,
|
||||
@ -44,7 +43,6 @@ from .diffusion import (
|
||||
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
|
||||
from .textual_inversion_manager import TextualInversionManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
run_id: str
|
||||
@ -348,10 +346,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
"""
|
||||
if xformers is available, use it, otherwise use sliced attention.
|
||||
"""
|
||||
config = get_invokeai_config()
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and is_xformers_available()
|
||||
and not Globals.disable_xformers
|
||||
and not config.disable_xformers
|
||||
):
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
@ -548,8 +547,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance = []
|
||||
extra_conditioning_info = conditioning_data.extra
|
||||
with self.invokeai_diffuser.custom_attention_context(
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
):
|
||||
yield PipelineIntermediateState(
|
||||
run_id=run_id,
|
||||
|
@ -10,6 +10,7 @@ import diffusers
|
||||
import psutil
|
||||
import torch
|
||||
from compel.cross_attention_control import Arguments
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttentionProcessor
|
||||
from torch import nn
|
||||
|
||||
@ -352,8 +353,7 @@ def restore_default_cross_attention(
|
||||
else:
|
||||
remove_attention_function(model)
|
||||
|
||||
|
||||
def override_cross_attention(model, context: Context, is_running_diffusers=False):
|
||||
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
|
||||
@ -372,37 +372,22 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
|
||||
indices = torch.arange(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
# these tokens have not been edited
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
context.cross_attention_mask = mask.to(device)
|
||||
context.cross_attention_index_map = indices.to(device)
|
||||
if is_running_diffusers:
|
||||
unet = model
|
||||
old_attn_processors = unet.attn_processors
|
||||
if torch.backends.mps.is_available():
|
||||
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
|
||||
unet.set_attn_processor(SwapCrossAttnProcessor())
|
||||
else:
|
||||
# try to re-use an existing slice size
|
||||
default_slice_size = 4
|
||||
slice_size = next(
|
||||
(
|
||||
p.slice_size
|
||||
for p in old_attn_processors.values()
|
||||
if type(p) is SlicedAttnProcessor
|
||||
),
|
||||
default_slice_size,
|
||||
)
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
return old_attn_processors
|
||||
old_attn_processors = unet.attn_processors
|
||||
if torch.backends.mps.is_available():
|
||||
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
|
||||
unet.set_attn_processor(SwapCrossAttnProcessor())
|
||||
else:
|
||||
context.register_cross_attention_modules(model)
|
||||
inject_attention_function(model, context)
|
||||
return None
|
||||
|
||||
# try to re-use an existing slice size
|
||||
default_slice_size = 4
|
||||
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
|
||||
def get_cross_attention_modules(
|
||||
model, which: CrossAttentionType
|
||||
|
@ -5,11 +5,12 @@ from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttentionProcessor
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
|
||||
from .cross_attention_control import (
|
||||
Arguments,
|
||||
@ -17,8 +18,8 @@ from .cross_attention_control import (
|
||||
CrossAttentionType,
|
||||
SwapCrossAttnContext,
|
||||
get_cross_attention_modules,
|
||||
override_cross_attention,
|
||||
restore_default_cross_attention,
|
||||
setup_cross_attention_control_attention_processors,
|
||||
)
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
|
||||
@ -31,7 +32,6 @@ ModelForwardCallback: TypeAlias = Union[
|
||||
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PostprocessingSettings:
|
||||
threshold: float
|
||||
@ -72,31 +72,43 @@ class InvokeAIDiffuserComponent:
|
||||
:param model: the unet model to pass through to cross attention control
|
||||
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
||||
"""
|
||||
config = get_invokeai_config()
|
||||
self.conditioning = None
|
||||
self.model = model
|
||||
self.is_running_diffusers = is_running_diffusers
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
self.sequential_guidance = Globals.sequential_guidance
|
||||
self.sequential_guidance = config.sequential_guidance
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def custom_attention_context(
|
||||
self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
|
||||
cls,
|
||||
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
step_count: int
|
||||
):
|
||||
do_swap = (
|
||||
extra_conditioning_info is not None
|
||||
and extra_conditioning_info.wants_cross_attention_control
|
||||
)
|
||||
old_attn_processor = None
|
||||
if do_swap:
|
||||
old_attn_processor = self.override_cross_attention(
|
||||
extra_conditioning_info, step_count=step_count
|
||||
)
|
||||
old_attn_processors = None
|
||||
if extra_conditioning_info and (
|
||||
extra_conditioning_info.wants_cross_attention_control
|
||||
):
|
||||
old_attn_processors = unet.attn_processors
|
||||
# Load lora conditions into the model
|
||||
if extra_conditioning_info.wants_cross_attention_control:
|
||||
cross_attention_control_context = Context(
|
||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||
step_count=step_count,
|
||||
)
|
||||
setup_cross_attention_control_attention_processors(
|
||||
unet,
|
||||
cross_attention_control_context,
|
||||
)
|
||||
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if old_attn_processor is not None:
|
||||
self.restore_default_cross_attention(old_attn_processor)
|
||||
if old_attn_processors is not None:
|
||||
unet.set_attn_processor(old_attn_processors)
|
||||
# TODO resuscitate attention map saving
|
||||
# self.remove_attention_map_saving()
|
||||
|
||||
|
@ -9,7 +9,8 @@ SCHEDULER_MAP = dict(
|
||||
deis=(DEISMultistepScheduler, dict()),
|
||||
lms=(LMSDiscreteScheduler, dict()),
|
||||
pndm=(PNDMScheduler, dict()),
|
||||
heun=(HeunDiscreteScheduler, dict()),
|
||||
heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
euler_a=(EulerAncestralDiscreteScheduler, dict()),
|
||||
|
@ -7,7 +7,6 @@
|
||||
This is the backend to "textual_inversion.py"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@ -47,8 +46,7 @@ from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# invokeai stuff
|
||||
from ..args import ArgFormatter, PagingArgumentParser
|
||||
from ..globals import Globals, global_cache_dir
|
||||
from invokeai.app.services.config import InvokeAIAppConfig,PagingArgumentParser
|
||||
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
@ -90,8 +88,9 @@ def save_progress(
|
||||
|
||||
|
||||
def parse_args():
|
||||
config = InvokeAIAppConfig(argv=[])
|
||||
parser = PagingArgumentParser(
|
||||
description="Textual inversion training", formatter_class=ArgFormatter
|
||||
description="Textual inversion training"
|
||||
)
|
||||
general_group = parser.add_argument_group("General")
|
||||
model_group = parser.add_argument_group("Models and Paths")
|
||||
@ -112,7 +111,7 @@ def parse_args():
|
||||
"--root_dir",
|
||||
"--root",
|
||||
type=Path,
|
||||
default=Globals.root,
|
||||
default=config.root,
|
||||
help="Path to the invokeai runtime directory",
|
||||
)
|
||||
general_group.add_argument(
|
||||
@ -127,7 +126,7 @@ def parse_args():
|
||||
general_group.add_argument(
|
||||
"--output_dir",
|
||||
type=Path,
|
||||
default=f"{Globals.root}/text-inversion-model",
|
||||
default=f"{config.root}/text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
@ -528,6 +527,7 @@ def get_full_repo_name(
|
||||
|
||||
|
||||
def do_textual_inversion_training(
|
||||
config: InvokeAIAppConfig,
|
||||
model: str,
|
||||
train_data_dir: Path,
|
||||
output_dir: Path,
|
||||
@ -580,7 +580,7 @@ def do_textual_inversion_training(
|
||||
|
||||
# setting up things the way invokeai expects them
|
||||
if not os.path.isabs(output_dir):
|
||||
output_dir = os.path.join(Globals.root, output_dir)
|
||||
output_dir = os.path.join(config.root, output_dir)
|
||||
|
||||
logging_dir = output_dir / logging_dir
|
||||
|
||||
@ -628,7 +628,7 @@ def do_textual_inversion_training(
|
||||
elif output_dir is not None:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
models_conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml"))
|
||||
models_conf = OmegaConf.load(config.model_conf_path)
|
||||
model_conf = models_conf.get(model, None)
|
||||
assert model_conf is not None, f"Unknown model: {model}"
|
||||
assert (
|
||||
@ -640,7 +640,7 @@ def do_textual_inversion_training(
|
||||
assert (
|
||||
pretrained_model_name_or_path
|
||||
), f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
|
||||
pipeline_args = dict(cache_dir=global_cache_dir("hub"))
|
||||
pipeline_args = dict(cache_dir=config.cache_dir)
|
||||
|
||||
# Load tokenizer
|
||||
if tokenizer_name:
|
||||
|
@ -4,17 +4,16 @@ from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from torch import autocast
|
||||
|
||||
from invokeai.backend.globals import Globals
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
CUDA_DEVICE = torch.device("cuda")
|
||||
MPS_DEVICE = torch.device("mps")
|
||||
|
||||
|
||||
def choose_torch_device() -> torch.device:
|
||||
"""Convenience routine for guessing which GPU device to run model on"""
|
||||
if Globals.always_use_cpu:
|
||||
config = get_invokeai_config()
|
||||
if config.always_use_cpu:
|
||||
return CPU_DEVICE
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
@ -33,7 +32,8 @@ def choose_precision(device: torch.device) -> str:
|
||||
|
||||
|
||||
def torch_dtype(device: torch.device) -> torch.dtype:
|
||||
if Globals.full_precision:
|
||||
config = get_invokeai_config()
|
||||
if config.full_precision:
|
||||
return torch.float32
|
||||
if choose_precision(device) == "float16":
|
||||
return torch.float16
|
||||
|
@ -2,34 +2,37 @@
|
||||
|
||||
"""invokeai.util.logging
|
||||
|
||||
Logging class for InvokeAI that produces console messages that follow
|
||||
the conventions established in InvokeAI 1.X through 2.X.
|
||||
Logging class for InvokeAI that produces console messages
|
||||
|
||||
|
||||
One way to use it:
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.getLogger(__name__)
|
||||
logger.critical('this is critical')
|
||||
logger.error('this is an error')
|
||||
logger.warning('this is a warning')
|
||||
logger.info('this is info')
|
||||
logger.debug('this is debugging')
|
||||
logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
|
||||
(or)
|
||||
logger = InvokeAILogger.getLogger(__name__) // To use the filename
|
||||
|
||||
logger.critical('this is critical') // Critical Message
|
||||
logger.error('this is an error') // Error Message
|
||||
logger.warning('this is a warning') // Warning Message
|
||||
logger.info('this is info') // Info Message
|
||||
logger.debug('this is debugging') // Debug Message
|
||||
|
||||
Console messages:
|
||||
### this is critical
|
||||
*** this is an error ***
|
||||
** this is a warning
|
||||
>> this is info
|
||||
| this is debugging
|
||||
[12-05-2023 20]::[InvokeAI]::CRITICAL --> This is an info message [In Bold Red]
|
||||
[12-05-2023 20]::[InvokeAI]::ERROR --> This is an info message [In Red]
|
||||
[12-05-2023 20]::[InvokeAI]::WARNING --> This is an info message [In Yellow]
|
||||
[12-05-2023 20]::[InvokeAI]::INFO --> This is an info message [In Grey]
|
||||
[12-05-2023 20]::[InvokeAI]::DEBUG --> This is an info message [In Grey]
|
||||
|
||||
Another way:
|
||||
import invokeai.backend.util.logging as ialog
|
||||
ialogger.debug('this is a debugging message')
|
||||
Alternate Method (in this case the logger name will be set to InvokeAI):
|
||||
import invokeai.backend.util.logging as IAILogger
|
||||
IAILogger.debug('this is a debugging message')
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
# module level functions
|
||||
def debug(msg, *args, **kwargs):
|
||||
InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
|
||||
@ -42,7 +45,7 @@ def warning(msg, *args, **kwargs):
|
||||
|
||||
def error(msg, *args, **kwargs):
|
||||
InvokeAILogger.getLogger().error(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def critical(msg, *args, **kwargs):
|
||||
InvokeAILogger.getLogger().critical(msg, *args, **kwargs)
|
||||
|
||||
@ -55,49 +58,47 @@ def disable(level=logging.CRITICAL):
|
||||
def basicConfig(**kwargs):
|
||||
InvokeAILogger.getLogger().basicConfig(**kwargs)
|
||||
|
||||
def getLogger(name: str=None)->logging.Logger:
|
||||
def getLogger(name: str = None) -> logging.Logger:
|
||||
return InvokeAILogger.getLogger(name)
|
||||
|
||||
|
||||
class InvokeAILogFormatter(logging.Formatter):
|
||||
'''
|
||||
Repurposed from:
|
||||
https://stackoverflow.com/questions/14844970/modifying-logging-message-format-based-on-message-logging-level-in-python3
|
||||
Custom Formatting for the InvokeAI Logger
|
||||
'''
|
||||
crit_fmt = "### %(msg)s"
|
||||
err_fmt = "*** %(msg)s"
|
||||
warn_fmt = "** %(msg)s"
|
||||
info_fmt = ">> %(msg)s"
|
||||
dbg_fmt = " | %(msg)s"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(fmt="%(levelno)d: %(msg)s", datefmt=None, style='%')
|
||||
# Color Codes
|
||||
grey = "\x1b[38;20m"
|
||||
yellow = "\x1b[33;20m"
|
||||
red = "\x1b[31;20m"
|
||||
cyan = "\x1b[36;20m"
|
||||
bold_red = "\x1b[31;1m"
|
||||
reset = "\x1b[0m"
|
||||
|
||||
# Log Format
|
||||
format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
|
||||
## More Formatting Options: %(pathname)s, %(filename)s, %(module)s, %(lineno)d
|
||||
|
||||
# Format Map
|
||||
FORMATS = {
|
||||
logging.DEBUG: cyan + format + reset,
|
||||
logging.INFO: grey + format + reset,
|
||||
logging.WARNING: yellow + format + reset,
|
||||
logging.ERROR: red + format + reset,
|
||||
logging.CRITICAL: bold_red + format + reset
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
# Remember the format used when the logging module
|
||||
# was installed (in the event that this formatter is
|
||||
# used with the vanilla logging module.
|
||||
format_orig = self._style._fmt
|
||||
if record.levelno == logging.DEBUG:
|
||||
self._style._fmt = InvokeAILogFormatter.dbg_fmt
|
||||
if record.levelno == logging.INFO:
|
||||
self._style._fmt = InvokeAILogFormatter.info_fmt
|
||||
if record.levelno == logging.WARNING:
|
||||
self._style._fmt = InvokeAILogFormatter.warn_fmt
|
||||
if record.levelno == logging.ERROR:
|
||||
self._style._fmt = InvokeAILogFormatter.err_fmt
|
||||
if record.levelno == logging.CRITICAL:
|
||||
self._style._fmt = InvokeAILogFormatter.crit_fmt
|
||||
log_fmt = self.FORMATS.get(record.levelno)
|
||||
formatter = logging.Formatter(log_fmt, datefmt="%d-%m-%Y %H:%M:%S")
|
||||
return formatter.format(record)
|
||||
|
||||
# parent class does the work
|
||||
result = super().format(record)
|
||||
self._style._fmt = format_orig
|
||||
return result
|
||||
|
||||
class InvokeAILogger(object):
|
||||
loggers = dict()
|
||||
|
||||
|
||||
@classmethod
|
||||
def getLogger(self, name:str='invokeai')->logging.Logger:
|
||||
def getLogger(self, name: str = 'InvokeAI') -> logging.Logger:
|
||||
if name not in self.loggers:
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
@ -9,6 +9,7 @@ SAMPLER_CHOICES = [
|
||||
"lms",
|
||||
"pndm",
|
||||
"heun",
|
||||
'heun_k',
|
||||
"euler",
|
||||
"euler_k",
|
||||
"euler_a",
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,498 +0,0 @@
|
||||
"""
|
||||
Readline helper functions for invoke.py.
|
||||
You may import the global singleton `completer` to get access to the
|
||||
completer object itself. This is useful when you want to autocomplete
|
||||
seeds:
|
||||
|
||||
from invokeai.frontend.CLI.readline import completer
|
||||
completer.add_seed(18247566)
|
||||
completer.add_seed(9281839)
|
||||
"""
|
||||
import atexit
|
||||
import os
|
||||
import re
|
||||
|
||||
from ...backend.args import Args
|
||||
from ...backend.globals import Globals
|
||||
from ...backend.stable_diffusion import HuggingFaceConceptsLibrary
|
||||
|
||||
# ---------------readline utilities---------------------
|
||||
try:
|
||||
import readline
|
||||
|
||||
readline_available = True
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
print(f"** An error occurred when loading the readline module: {str(e)}")
|
||||
readline_available = False
|
||||
|
||||
IMG_EXTENSIONS = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG", ".gif", ".GIF")
|
||||
WEIGHT_EXTENSIONS = (".ckpt", ".vae", ".safetensors")
|
||||
TEXT_EXTENSIONS = (".txt", ".TXT")
|
||||
CONFIG_EXTENSIONS = (".yaml", ".yml")
|
||||
COMMANDS = (
|
||||
"--steps",
|
||||
"-s",
|
||||
"--seed",
|
||||
"-S",
|
||||
"--iterations",
|
||||
"-n",
|
||||
"--width",
|
||||
"-W",
|
||||
"--height",
|
||||
"-H",
|
||||
"--cfg_scale",
|
||||
"-C",
|
||||
"--threshold",
|
||||
"--perlin",
|
||||
"--grid",
|
||||
"-g",
|
||||
"--individual",
|
||||
"-i",
|
||||
"--save_intermediates",
|
||||
"--init_img",
|
||||
"-I",
|
||||
"--init_mask",
|
||||
"-M",
|
||||
"--init_color",
|
||||
"--strength",
|
||||
"-f",
|
||||
"--variants",
|
||||
"-v",
|
||||
"--outdir",
|
||||
"-o",
|
||||
"--sampler",
|
||||
"-A",
|
||||
"-m",
|
||||
"--embedding_path",
|
||||
"--device",
|
||||
"--grid",
|
||||
"-g",
|
||||
"--facetool",
|
||||
"-ft",
|
||||
"--facetool_strength",
|
||||
"-G",
|
||||
"--codeformer_fidelity",
|
||||
"-cf",
|
||||
"--upscale",
|
||||
"-U",
|
||||
"-save_orig",
|
||||
"--save_original",
|
||||
"--log_tokenization",
|
||||
"-t",
|
||||
"--hires_fix",
|
||||
"--inpaint_replace",
|
||||
"-r",
|
||||
"--png_compression",
|
||||
"-z",
|
||||
"--text_mask",
|
||||
"-tm",
|
||||
"--h_symmetry_time_pct",
|
||||
"--v_symmetry_time_pct",
|
||||
"!fix",
|
||||
"!fetch",
|
||||
"!replay",
|
||||
"!history",
|
||||
"!search",
|
||||
"!clear",
|
||||
"!models",
|
||||
"!switch",
|
||||
"!import_model",
|
||||
"!optimize_model",
|
||||
"!convert_model",
|
||||
"!edit_model",
|
||||
"!del_model",
|
||||
"!mask",
|
||||
"!triggers",
|
||||
)
|
||||
MODEL_COMMANDS = (
|
||||
"!switch",
|
||||
"!edit_model",
|
||||
"!del_model",
|
||||
)
|
||||
CKPT_MODEL_COMMANDS = ("!optimize_model",)
|
||||
WEIGHT_COMMANDS = (
|
||||
"!import_model",
|
||||
"!convert_model",
|
||||
)
|
||||
IMG_PATH_COMMANDS = ("--outdir[=\s]",)
|
||||
TEXT_PATH_COMMANDS = ("!replay",)
|
||||
IMG_FILE_COMMANDS = (
|
||||
"!fix",
|
||||
"!fetch",
|
||||
"!mask",
|
||||
"--init_img[=\s]",
|
||||
"-I",
|
||||
"--init_mask[=\s]",
|
||||
"-M",
|
||||
"--init_color[=\s]",
|
||||
"--embedding_path[=\s]",
|
||||
)
|
||||
|
||||
path_regexp = "(" + "|".join(IMG_PATH_COMMANDS + IMG_FILE_COMMANDS) + ")\s*\S*$"
|
||||
weight_regexp = "(" + "|".join(WEIGHT_COMMANDS) + ")\s*\S*$"
|
||||
text_regexp = "(" + "|".join(TEXT_PATH_COMMANDS) + ")\s*\S*$"
|
||||
|
||||
|
||||
class Completer(object):
|
||||
def __init__(self, options, models={}):
|
||||
self.options = sorted(options)
|
||||
self.models = models
|
||||
self.seeds = set()
|
||||
self.matches = list()
|
||||
self.default_dir = None
|
||||
self.linebuffer = None
|
||||
self.auto_history_active = True
|
||||
self.extensions = None
|
||||
self.concepts = None
|
||||
self.embedding_terms = set()
|
||||
return
|
||||
|
||||
def complete(self, text, state):
|
||||
"""
|
||||
Completes invoke command line.
|
||||
BUG: it doesn't correctly complete files that have spaces in the name.
|
||||
"""
|
||||
buffer = readline.get_line_buffer()
|
||||
|
||||
if state == 0:
|
||||
# extensions defined, so go directly into path completion mode
|
||||
if self.extensions is not None:
|
||||
self.matches = self._path_completions(text, state, self.extensions)
|
||||
|
||||
# looking for an image file
|
||||
elif re.search(path_regexp, buffer):
|
||||
do_shortcut = re.search("^" + "|".join(IMG_FILE_COMMANDS), buffer)
|
||||
self.matches = self._path_completions(
|
||||
text, state, IMG_EXTENSIONS, shortcut_ok=do_shortcut
|
||||
)
|
||||
|
||||
# looking for a seed
|
||||
elif re.search("(-S\s*|--seed[=\s])\d*$", buffer):
|
||||
self.matches = self._seed_completions(text, state)
|
||||
|
||||
# looking for an embedding concept
|
||||
elif re.search("<[\w-]*$", buffer):
|
||||
self.matches = self._concept_completions(text, state)
|
||||
|
||||
# looking for a model
|
||||
elif re.match("^" + "|".join(MODEL_COMMANDS), buffer):
|
||||
self.matches = self._model_completions(text, state)
|
||||
|
||||
# looking for a ckpt model
|
||||
elif re.match("^" + "|".join(CKPT_MODEL_COMMANDS), buffer):
|
||||
self.matches = self._model_completions(text, state, ckpt_only=True)
|
||||
|
||||
elif re.search(weight_regexp, buffer):
|
||||
self.matches = self._path_completions(
|
||||
text,
|
||||
state,
|
||||
WEIGHT_EXTENSIONS,
|
||||
default_dir=Globals.root,
|
||||
)
|
||||
|
||||
elif re.search(text_regexp, buffer):
|
||||
self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)
|
||||
|
||||
# This is the first time for this text, so build a match list.
|
||||
elif text:
|
||||
self.matches = [s for s in self.options if s and s.startswith(text)]
|
||||
else:
|
||||
self.matches = self.options[:]
|
||||
|
||||
# Return the state'th item from the match list,
|
||||
# if we have that many.
|
||||
try:
|
||||
response = self.matches[state]
|
||||
except IndexError:
|
||||
response = None
|
||||
return response
|
||||
|
||||
def complete_extensions(self, extensions: list):
|
||||
"""
|
||||
If called with a list of extensions, will force completer
|
||||
to do file path completions.
|
||||
"""
|
||||
self.extensions = extensions
|
||||
|
||||
def add_history(self, line):
|
||||
"""
|
||||
Pass thru to readline
|
||||
"""
|
||||
if not self.auto_history_active:
|
||||
readline.add_history(line)
|
||||
|
||||
def clear_history(self):
|
||||
"""
|
||||
Pass clear_history() thru to readline
|
||||
"""
|
||||
readline.clear_history()
|
||||
|
||||
def search_history(self, match: str):
|
||||
"""
|
||||
Like show_history() but only shows items that
|
||||
contain the match string.
|
||||
"""
|
||||
self.show_history(match)
|
||||
|
||||
def remove_history_item(self, pos):
|
||||
readline.remove_history_item(pos)
|
||||
|
||||
def add_seed(self, seed):
|
||||
"""
|
||||
Add a seed to the autocomplete list for display when -S is autocompleted.
|
||||
"""
|
||||
if seed is not None:
|
||||
self.seeds.add(str(seed))
|
||||
|
||||
def set_default_dir(self, path):
|
||||
self.default_dir = path
|
||||
|
||||
def set_options(self, options):
|
||||
self.options = options
|
||||
|
||||
def get_line(self, index):
|
||||
try:
|
||||
line = self.get_history_item(index)
|
||||
except IndexError:
|
||||
return None
|
||||
return line
|
||||
|
||||
def get_current_history_length(self):
|
||||
return readline.get_current_history_length()
|
||||
|
||||
def get_history_item(self, index):
|
||||
return readline.get_history_item(index)
|
||||
|
||||
def show_history(self, match=None):
|
||||
"""
|
||||
Print the session history using the pydoc pager
|
||||
"""
|
||||
import pydoc
|
||||
|
||||
lines = list()
|
||||
h_len = self.get_current_history_length()
|
||||
if h_len < 1:
|
||||
print("<empty history>")
|
||||
return
|
||||
|
||||
for i in range(0, h_len):
|
||||
line = self.get_history_item(i + 1)
|
||||
if match and match not in line:
|
||||
continue
|
||||
lines.append(f"[{i+1}] {line}")
|
||||
pydoc.pager("\n".join(lines))
|
||||
|
||||
def set_line(self, line) -> None:
|
||||
"""
|
||||
Set the default string displayed in the next line of input.
|
||||
"""
|
||||
self.linebuffer = line
|
||||
readline.redisplay()
|
||||
|
||||
def update_models(self, models: dict) -> None:
|
||||
"""
|
||||
update our list of models
|
||||
"""
|
||||
self.models = models
|
||||
|
||||
def _seed_completions(self, text, state):
|
||||
m = re.search("(-S\s?|--seed[=\s]?)(\d*)", text)
|
||||
if m:
|
||||
switch = m.groups()[0]
|
||||
partial = m.groups()[1]
|
||||
else:
|
||||
switch = ""
|
||||
partial = text
|
||||
|
||||
matches = list()
|
||||
for s in self.seeds:
|
||||
if s.startswith(partial):
|
||||
matches.append(switch + s)
|
||||
matches.sort()
|
||||
return matches
|
||||
|
||||
def add_embedding_terms(self, terms: list[str]):
|
||||
self.embedding_terms = set(terms)
|
||||
if self.concepts:
|
||||
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
||||
|
||||
def _concept_completions(self, text, state):
|
||||
if self.concepts is None:
|
||||
# cache Concepts() instance so we can check for updates in concepts_list during runtime.
|
||||
self.concepts = HuggingFaceConceptsLibrary()
|
||||
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
||||
else:
|
||||
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
||||
|
||||
partial = text[1:] # this removes the leading '<'
|
||||
if len(partial) == 0:
|
||||
return list(self.embedding_terms) # whole dump - think if user wants this!
|
||||
|
||||
matches = list()
|
||||
for concept in self.embedding_terms:
|
||||
if concept.startswith(partial):
|
||||
matches.append(f"<{concept}>")
|
||||
matches.sort()
|
||||
return matches
|
||||
|
||||
def _model_completions(self, text, state, ckpt_only=False):
|
||||
m = re.search("(!switch\s+)(\w*)", text)
|
||||
if m:
|
||||
switch = m.groups()[0]
|
||||
partial = m.groups()[1]
|
||||
else:
|
||||
switch = ""
|
||||
partial = text
|
||||
matches = list()
|
||||
for s in self.models:
|
||||
name = self.models[s]["model_name"]
|
||||
format = self.models[s]["format"]
|
||||
if format == "vae":
|
||||
continue
|
||||
if ckpt_only and format != "ckpt":
|
||||
continue
|
||||
if name.startswith(partial):
|
||||
matches.append(switch + name)
|
||||
matches.sort()
|
||||
return matches
|
||||
|
||||
def _pre_input_hook(self):
|
||||
if self.linebuffer:
|
||||
readline.insert_text(self.linebuffer)
|
||||
readline.redisplay()
|
||||
self.linebuffer = None
|
||||
|
||||
def _path_completions(
|
||||
self, text, state, extensions, shortcut_ok=True, default_dir: str = ""
|
||||
):
|
||||
# separate the switch from the partial path
|
||||
match = re.search("^(-\w|--\w+=?)(.*)", text)
|
||||
if match is None:
|
||||
switch = None
|
||||
partial_path = text
|
||||
else:
|
||||
switch, partial_path = match.groups()
|
||||
|
||||
partial_path = partial_path.lstrip()
|
||||
|
||||
matches = list()
|
||||
path = os.path.expanduser(partial_path)
|
||||
|
||||
if os.path.isdir(path):
|
||||
dir = path
|
||||
elif os.path.dirname(path) != "":
|
||||
dir = os.path.dirname(path)
|
||||
else:
|
||||
dir = default_dir if os.path.exists(default_dir) else ""
|
||||
path = os.path.join(dir, path)
|
||||
|
||||
dir_list = os.listdir(dir or ".")
|
||||
if shortcut_ok and os.path.exists(self.default_dir) and dir == "":
|
||||
dir_list += os.listdir(self.default_dir)
|
||||
|
||||
for node in dir_list:
|
||||
if node.startswith(".") and len(node) > 1:
|
||||
continue
|
||||
full_path = os.path.join(dir, node)
|
||||
|
||||
if not (node.endswith(extensions) or os.path.isdir(full_path)):
|
||||
continue
|
||||
|
||||
if path and not full_path.startswith(path):
|
||||
continue
|
||||
|
||||
if switch is None:
|
||||
match_path = os.path.join(dir, node)
|
||||
matches.append(
|
||||
match_path + "/" if os.path.isdir(full_path) else match_path
|
||||
)
|
||||
elif os.path.isdir(full_path):
|
||||
matches.append(
|
||||
switch + os.path.join(os.path.dirname(full_path), node) + "/"
|
||||
)
|
||||
elif node.endswith(extensions):
|
||||
matches.append(switch + os.path.join(os.path.dirname(full_path), node))
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
class DummyCompleter(Completer):
|
||||
def __init__(self, options):
|
||||
super().__init__(options)
|
||||
self.history = list()
|
||||
|
||||
def add_history(self, line):
|
||||
self.history.append(line)
|
||||
|
||||
def clear_history(self):
|
||||
self.history = list()
|
||||
|
||||
def get_current_history_length(self):
|
||||
return len(self.history)
|
||||
|
||||
def get_history_item(self, index):
|
||||
return self.history[index - 1]
|
||||
|
||||
def remove_history_item(self, index):
|
||||
return self.history.pop(index - 1)
|
||||
|
||||
def set_line(self, line):
|
||||
print(f"# {line}")
|
||||
|
||||
|
||||
def generic_completer(commands: list) -> Completer:
|
||||
if readline_available:
|
||||
completer = Completer(commands, [])
|
||||
readline.set_completer(completer.complete)
|
||||
readline.set_pre_input_hook(completer._pre_input_hook)
|
||||
readline.set_completer_delims(" ")
|
||||
readline.parse_and_bind("tab: complete")
|
||||
readline.parse_and_bind("set print-completions-horizontally off")
|
||||
readline.parse_and_bind("set page-completions on")
|
||||
readline.parse_and_bind("set skip-completed-text on")
|
||||
readline.parse_and_bind("set show-all-if-ambiguous on")
|
||||
else:
|
||||
completer = DummyCompleter(commands)
|
||||
return completer
|
||||
|
||||
|
||||
def get_completer(opt: Args, models=[]) -> Completer:
|
||||
if readline_available:
|
||||
completer = Completer(COMMANDS, models)
|
||||
|
||||
readline.set_completer(completer.complete)
|
||||
# pyreadline3 does not have a set_auto_history() method
|
||||
try:
|
||||
readline.set_auto_history(False)
|
||||
completer.auto_history_active = False
|
||||
except:
|
||||
completer.auto_history_active = True
|
||||
readline.set_pre_input_hook(completer._pre_input_hook)
|
||||
readline.set_completer_delims(" ")
|
||||
readline.parse_and_bind("tab: complete")
|
||||
readline.parse_and_bind("set print-completions-horizontally off")
|
||||
readline.parse_and_bind("set page-completions on")
|
||||
readline.parse_and_bind("set skip-completed-text on")
|
||||
readline.parse_and_bind("set show-all-if-ambiguous on")
|
||||
|
||||
outdir = os.path.expanduser(opt.outdir)
|
||||
if os.path.isabs(outdir):
|
||||
histfile = os.path.join(outdir, ".invoke_history")
|
||||
else:
|
||||
histfile = os.path.join(Globals.root, outdir, ".invoke_history")
|
||||
try:
|
||||
readline.read_history_file(histfile)
|
||||
readline.set_history_length(1000)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except OSError: # file likely corrupted
|
||||
newname = f"{histfile}.old"
|
||||
print(
|
||||
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||
)
|
||||
os.replace(histfile, newname)
|
||||
atexit.register(readline.write_history_file, histfile)
|
||||
|
||||
else:
|
||||
completer = DummyCompleter(COMMANDS)
|
||||
return completer
|
@ -1,30 +0,0 @@
|
||||
'''
|
||||
This is a modularized version of the sd-metadata.py script,
|
||||
which retrieves and prints the metadata from a series of generated png files.
|
||||
'''
|
||||
import sys
|
||||
import json
|
||||
from invokeai.backend.image_util import retrieve_metadata
|
||||
|
||||
|
||||
def print_metadata():
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")
|
||||
print("This script opens up the indicated invoke.py-generated PNG file(s) and prints out their metadata.")
|
||||
exit(-1)
|
||||
|
||||
filenames = sys.argv[1:]
|
||||
for f in filenames:
|
||||
try:
|
||||
metadata = retrieve_metadata(f)
|
||||
print(f'{f}:\n',json.dumps(metadata['sd-metadata'], indent=4))
|
||||
except FileNotFoundError:
|
||||
sys.stderr.write(f'{f} not found\n')
|
||||
continue
|
||||
except PermissionError:
|
||||
sys.stderr.write(f'{f} could not be opened due to inadequate permissions\n')
|
||||
continue
|
||||
|
||||
if __name__== '__main__':
|
||||
print_metadata()
|
||||
|
@ -23,7 +23,6 @@ from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals, global_config_dir
|
||||
|
||||
from ...backend.config.model_install_backend import (
|
||||
Dataset_path,
|
||||
@ -41,11 +40,13 @@ from .widgets import (
|
||||
TextBox,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
|
||||
# minimum size for the UI
|
||||
MIN_COLS = 120
|
||||
MIN_LINES = 45
|
||||
|
||||
config = get_invokeai_config()
|
||||
|
||||
class addModelsForm(npyscreen.FormMultiPage):
|
||||
# for responsive resizing - disabled
|
||||
@ -453,9 +454,9 @@ def main():
|
||||
opt = parser.parse_args()
|
||||
|
||||
# setting a global here
|
||||
Globals.root = os.path.expanduser(get_root(opt.root) or "")
|
||||
config.root = os.path.expanduser(get_root(opt.root) or "")
|
||||
|
||||
if not global_config_dir().exists():
|
||||
if not (config.conf_path / '..' ).exists():
|
||||
logger.info(
|
||||
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
||||
)
|
||||
|
@ -8,7 +8,6 @@ import argparse
|
||||
import curses
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
@ -20,20 +19,13 @@ from diffusers import logging as dlogging
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ...backend.globals import (
|
||||
Globals,
|
||||
global_cache_dir,
|
||||
global_config_file,
|
||||
global_models_dir,
|
||||
global_set_root,
|
||||
)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.services.config import get_invokeai_config
|
||||
from ...backend.model_management import ModelManager
|
||||
from ...frontend.install.widgets import FloatTitleSlider
|
||||
|
||||
DEST_MERGED_MODEL_DIR = "merged_models"
|
||||
|
||||
config = get_invokeai_config()
|
||||
|
||||
def merge_diffusion_models(
|
||||
model_ids_or_paths: List[Union[str, Path]],
|
||||
@ -60,7 +52,7 @@ def merge_diffusion_models(
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_ids_or_paths[0],
|
||||
cache_dir=kwargs.get("cache_dir", global_cache_dir()),
|
||||
cache_dir=kwargs.get("cache_dir", config.cache_dir),
|
||||
custom_pipeline="checkpoint_merger",
|
||||
)
|
||||
merged_pipe = pipe.merge(
|
||||
@ -94,7 +86,7 @@ def merge_diffusion_models_and_commit(
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
config_file = global_config_file()
|
||||
config_file = config.model_conf_path
|
||||
model_manager = ModelManager(OmegaConf.load(config_file))
|
||||
for mod in models:
|
||||
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
|
||||
@ -106,7 +98,7 @@ def merge_diffusion_models_and_commit(
|
||||
merged_pipe = merge_diffusion_models(
|
||||
model_ids_or_paths, alpha, interp, force, **kwargs
|
||||
)
|
||||
dump_path = global_models_dir() / DEST_MERGED_MODEL_DIR
|
||||
dump_path = config.models_dir / DEST_MERGED_MODEL_DIR
|
||||
|
||||
os.makedirs(dump_path, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
@ -126,7 +118,7 @@ def _parse_args() -> Namespace:
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
type=Path,
|
||||
default=Globals.root,
|
||||
default=config.root,
|
||||
help="Path to the invokeai runtime directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -398,7 +390,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
conf = OmegaConf.load(global_config_file())
|
||||
conf = OmegaConf.load(config.model_conf_path)
|
||||
self.model_manager = ModelManager(
|
||||
conf, "cpu", "float16"
|
||||
) # precision doesn't really matter here
|
||||
@ -429,7 +421,7 @@ def run_cli(args: Namespace):
|
||||
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||
)
|
||||
|
||||
model_manager = ModelManager(OmegaConf.load(global_config_file()))
|
||||
model_manager = ModelManager(OmegaConf.load(config.model_conf_path))
|
||||
assert (
|
||||
args.clobber or args.merged_model_name not in model_manager.model_names()
|
||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
@ -440,9 +432,9 @@ def run_cli(args: Namespace):
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
global_set_root(args.root_dir)
|
||||
config.root = args.root_dir
|
||||
|
||||
cache_dir = str(global_cache_dir("hub"))
|
||||
cache_dir = config.cache_dir
|
||||
os.environ[
|
||||
"HF_HOME"
|
||||
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
|
||||
|
@ -21,14 +21,17 @@ from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.globals import Globals, global_set_root
|
||||
|
||||
from ...backend.training import do_textual_inversion_training, parse_args
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from ...backend.training import (
|
||||
do_textual_inversion_training,
|
||||
parse_args
|
||||
)
|
||||
|
||||
TRAINING_DATA = "text-inversion-training-data"
|
||||
TRAINING_DIR = "text-inversion-output"
|
||||
CONF_FILE = "preferences.conf"
|
||||
|
||||
config = None
|
||||
|
||||
class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
resolutions = [512, 768, 1024]
|
||||
@ -122,7 +125,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
value=str(
|
||||
saved_args.get(
|
||||
"train_data_dir",
|
||||
Path(Globals.root) / TRAINING_DATA / default_placeholder_token,
|
||||
config.root_dir / TRAINING_DATA / default_placeholder_token,
|
||||
)
|
||||
),
|
||||
scroll_exit=True,
|
||||
@ -135,7 +138,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
value=str(
|
||||
saved_args.get(
|
||||
"output_dir",
|
||||
Path(Globals.root) / TRAINING_DIR / default_placeholder_token,
|
||||
config.root_dir / TRAINING_DIR / default_placeholder_token,
|
||||
)
|
||||
),
|
||||
scroll_exit=True,
|
||||
@ -241,9 +244,9 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
placeholder = self.placeholder_token.value
|
||||
self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
|
||||
self.train_data_dir.value = str(
|
||||
Path(Globals.root) / TRAINING_DATA / placeholder
|
||||
config.root_dir / TRAINING_DATA / placeholder
|
||||
)
|
||||
self.output_dir.value = str(Path(Globals.root) / TRAINING_DIR / placeholder)
|
||||
self.output_dir.value = str(config.root_dir / TRAINING_DIR / placeholder)
|
||||
self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
|
||||
|
||||
def on_ok(self):
|
||||
@ -284,7 +287,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
return True
|
||||
|
||||
def get_model_names(self) -> Tuple[List[str], int]:
|
||||
conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml"))
|
||||
conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
|
||||
model_names = [
|
||||
idx
|
||||
for idx in sorted(list(conf.keys()))
|
||||
@ -367,7 +370,7 @@ def copy_to_embeddings_folder(args: dict):
|
||||
"""
|
||||
source = Path(args["output_dir"], "learned_embeds.bin")
|
||||
dest_dir_name = args["placeholder_token"].strip("<>")
|
||||
destination = Path(Globals.root, "embeddings", dest_dir_name)
|
||||
destination = config.root_dir / "embeddings" / dest_dir_name
|
||||
os.makedirs(destination, exist_ok=True)
|
||||
logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
|
||||
shutil.copy(source, destination)
|
||||
@ -383,7 +386,7 @@ def save_args(args: dict):
|
||||
"""
|
||||
Save the current argument values to an omegaconf file
|
||||
"""
|
||||
dest_dir = Path(Globals.root) / TRAINING_DIR
|
||||
dest_dir = config.root_dir / TRAINING_DIR
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
conf_file = dest_dir / CONF_FILE
|
||||
conf = OmegaConf.create(args)
|
||||
@ -394,7 +397,7 @@ def previous_args() -> dict:
|
||||
"""
|
||||
Get the previous arguments used.
|
||||
"""
|
||||
conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
|
||||
conf_file = config.root_dir / TRAINING_DIR / CONF_FILE
|
||||
try:
|
||||
conf = OmegaConf.load(conf_file)
|
||||
conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
|
||||
@ -420,7 +423,7 @@ def do_front_end(args: Namespace):
|
||||
save_args(args)
|
||||
|
||||
try:
|
||||
do_textual_inversion_training(**args)
|
||||
do_textual_inversion_training(get_invokeai_config(),**args)
|
||||
copy_to_embeddings_folder(args)
|
||||
except Exception as e:
|
||||
logger.error("An exception occurred during training. The exception was:")
|
||||
@ -430,13 +433,20 @@ def do_front_end(args: Namespace):
|
||||
|
||||
|
||||
def main():
|
||||
global config
|
||||
|
||||
args = parse_args()
|
||||
global_set_root(args.root_dir or Globals.root)
|
||||
config = get_invokeai_config(argv=[])
|
||||
|
||||
# change root if needed
|
||||
if args.root_dir:
|
||||
config.root = args.root_dir
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
do_front_end(args)
|
||||
else:
|
||||
do_textual_inversion_training(**vars(args))
|
||||
do_textual_inversion_training(config,**vars(args))
|
||||
except AssertionError as e:
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
|
@ -15,15 +15,3 @@ The `postinstall` script patches a few packages and runs the Chakra CLI to gener
|
||||
### Patch `@chakra-ui/cli`
|
||||
|
||||
See: <https://github.com/chakra-ui/chakra-ui/issues/7394>
|
||||
|
||||
### Patch `redux-persist`
|
||||
|
||||
We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.
|
||||
|
||||
`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.
|
||||
|
||||
So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.
|
||||
|
||||
### Patch `redux-deep-persist`
|
||||
|
||||
This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.
|
||||
|
@ -62,11 +62,13 @@
|
||||
"@dagrejs/graphlib": "^2.1.12",
|
||||
"@emotion/react": "^11.10.6",
|
||||
"@emotion/styled": "^11.10.6",
|
||||
"@floating-ui/react-dom": "^2.0.0",
|
||||
"@fontsource/inter": "^4.5.15",
|
||||
"@reduxjs/toolkit": "^1.9.5",
|
||||
"@roarr/browser-log-writer": "^1.1.5",
|
||||
"chakra-ui-contextmenu": "^1.0.5",
|
||||
"dateformat": "^5.0.3",
|
||||
"downshift": "^7.6.0",
|
||||
"formik": "^2.2.9",
|
||||
"framer-motion": "^10.12.4",
|
||||
"fuse.js": "^6.6.2",
|
||||
@ -87,18 +89,13 @@
|
||||
"react-i18next": "^12.2.2",
|
||||
"react-icons": "^4.7.1",
|
||||
"react-konva": "^18.2.7",
|
||||
"react-konva-utils": "^1.0.4",
|
||||
"react-redux": "^8.0.5",
|
||||
"react-resizable-panels": "^0.0.42",
|
||||
"react-rnd": "^10.4.1",
|
||||
"react-transition-group": "^4.4.5",
|
||||
"react-use": "^17.4.0",
|
||||
"react-virtuoso": "^4.3.5",
|
||||
"react-zoom-pan-pinch": "^3.0.7",
|
||||
"reactflow": "^11.7.0",
|
||||
"redux-deep-persist": "^1.0.7",
|
||||
"redux-dynamic-middlewares": "^2.2.0",
|
||||
"redux-persist": "^6.0.0",
|
||||
"redux-remember": "^3.3.1",
|
||||
"roarr": "^7.15.0",
|
||||
"serialize-error": "^11.0.0",
|
||||
|
@ -1,24 +0,0 @@
|
||||
diff --git a/node_modules/redux-deep-persist/lib/types.d.ts b/node_modules/redux-deep-persist/lib/types.d.ts
|
||||
index b67b8c2..7fc0fa1 100644
|
||||
--- a/node_modules/redux-deep-persist/lib/types.d.ts
|
||||
+++ b/node_modules/redux-deep-persist/lib/types.d.ts
|
||||
@@ -35,6 +35,7 @@ export interface PersistConfig<S, RS = any, HSS = any, ESS = any> {
|
||||
whitelist?: Array<string>;
|
||||
transforms?: Array<Transform<HSS, ESS, S, RS>>;
|
||||
throttle?: number;
|
||||
+ debounce?: number;
|
||||
migrate?: PersistMigrate;
|
||||
stateReconciler?: false | StateReconciler<S>;
|
||||
getStoredState?: (config: PersistConfig<S, RS, HSS, ESS>) => Promise<PersistedState>;
|
||||
diff --git a/node_modules/redux-deep-persist/src/types.ts b/node_modules/redux-deep-persist/src/types.ts
|
||||
index 398ac19..cbc5663 100644
|
||||
--- a/node_modules/redux-deep-persist/src/types.ts
|
||||
+++ b/node_modules/redux-deep-persist/src/types.ts
|
||||
@@ -91,6 +91,7 @@ export interface PersistConfig<S, RS = any, HSS = any, ESS = any> {
|
||||
whitelist?: Array<string>;
|
||||
transforms?: Array<Transform<HSS, ESS, S, RS>>;
|
||||
throttle?: number;
|
||||
+ debounce?: number;
|
||||
migrate?: PersistMigrate;
|
||||
stateReconciler?: false | StateReconciler<S>;
|
||||
/**
|
@ -1,116 +0,0 @@
|
||||
diff --git a/node_modules/redux-persist/es/createPersistoid.js b/node_modules/redux-persist/es/createPersistoid.js
|
||||
index 8b43b9a..184faab 100644
|
||||
--- a/node_modules/redux-persist/es/createPersistoid.js
|
||||
+++ b/node_modules/redux-persist/es/createPersistoid.js
|
||||
@@ -6,6 +6,7 @@ export default function createPersistoid(config) {
|
||||
var whitelist = config.whitelist || null;
|
||||
var transforms = config.transforms || [];
|
||||
var throttle = config.throttle || 0;
|
||||
+ var debounce = config.debounce || 0;
|
||||
var storageKey = "".concat(config.keyPrefix !== undefined ? config.keyPrefix : KEY_PREFIX).concat(config.key);
|
||||
var storage = config.storage;
|
||||
var serialize;
|
||||
@@ -28,30 +29,37 @@ export default function createPersistoid(config) {
|
||||
var timeIterator = null;
|
||||
var writePromise = null;
|
||||
|
||||
- var update = function update(state) {
|
||||
- // add any changed keys to the queue
|
||||
- Object.keys(state).forEach(function (key) {
|
||||
- if (!passWhitelistBlacklist(key)) return; // is keyspace ignored? noop
|
||||
+ // Timer for debounced `update()`
|
||||
+ let timer = 0;
|
||||
|
||||
- if (lastState[key] === state[key]) return; // value unchanged? noop
|
||||
+ function update(state) {
|
||||
+ // Debounce the update
|
||||
+ clearTimeout(timer);
|
||||
+ timer = setTimeout(() => {
|
||||
+ // add any changed keys to the queue
|
||||
+ Object.keys(state).forEach(function (key) {
|
||||
+ if (!passWhitelistBlacklist(key)) return; // is keyspace ignored? noop
|
||||
|
||||
- if (keysToProcess.indexOf(key) !== -1) return; // is key already queued? noop
|
||||
+ if (lastState[key] === state[key]) return; // value unchanged? noop
|
||||
|
||||
- keysToProcess.push(key); // add key to queue
|
||||
- }); //if any key is missing in the new state which was present in the lastState,
|
||||
- //add it for processing too
|
||||
+ if (keysToProcess.indexOf(key) !== -1) return; // is key already queued? noop
|
||||
|
||||
- Object.keys(lastState).forEach(function (key) {
|
||||
- if (state[key] === undefined && passWhitelistBlacklist(key) && keysToProcess.indexOf(key) === -1 && lastState[key] !== undefined) {
|
||||
- keysToProcess.push(key);
|
||||
- }
|
||||
- }); // start the time iterator if not running (read: throttle)
|
||||
+ keysToProcess.push(key); // add key to queue
|
||||
+ }); //if any key is missing in the new state which was present in the lastState,
|
||||
+ //add it for processing too
|
||||
|
||||
- if (timeIterator === null) {
|
||||
- timeIterator = setInterval(processNextKey, throttle);
|
||||
- }
|
||||
+ Object.keys(lastState).forEach(function (key) {
|
||||
+ if (state[key] === undefined && passWhitelistBlacklist(key) && keysToProcess.indexOf(key) === -1 && lastState[key] !== undefined) {
|
||||
+ keysToProcess.push(key);
|
||||
+ }
|
||||
+ }); // start the time iterator if not running (read: throttle)
|
||||
+
|
||||
+ if (timeIterator === null) {
|
||||
+ timeIterator = setInterval(processNextKey, throttle);
|
||||
+ }
|
||||
|
||||
- lastState = state;
|
||||
+ lastState = state;
|
||||
+ }, debounce)
|
||||
};
|
||||
|
||||
function processNextKey() {
|
||||
diff --git a/node_modules/redux-persist/es/types.js.flow b/node_modules/redux-persist/es/types.js.flow
|
||||
index c50d3cd..39d8be2 100644
|
||||
--- a/node_modules/redux-persist/es/types.js.flow
|
||||
+++ b/node_modules/redux-persist/es/types.js.flow
|
||||
@@ -19,6 +19,7 @@ export type PersistConfig = {
|
||||
whitelist?: Array<string>,
|
||||
transforms?: Array<Transform>,
|
||||
throttle?: number,
|
||||
+ debounce?: number,
|
||||
migrate?: (PersistedState, number) => Promise<PersistedState>,
|
||||
stateReconciler?: false | Function,
|
||||
getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
|
||||
diff --git a/node_modules/redux-persist/lib/types.js.flow b/node_modules/redux-persist/lib/types.js.flow
|
||||
index c50d3cd..39d8be2 100644
|
||||
--- a/node_modules/redux-persist/lib/types.js.flow
|
||||
+++ b/node_modules/redux-persist/lib/types.js.flow
|
||||
@@ -19,6 +19,7 @@ export type PersistConfig = {
|
||||
whitelist?: Array<string>,
|
||||
transforms?: Array<Transform>,
|
||||
throttle?: number,
|
||||
+ debounce?: number,
|
||||
migrate?: (PersistedState, number) => Promise<PersistedState>,
|
||||
stateReconciler?: false | Function,
|
||||
getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
|
||||
diff --git a/node_modules/redux-persist/src/types.js b/node_modules/redux-persist/src/types.js
|
||||
index c50d3cd..39d8be2 100644
|
||||
--- a/node_modules/redux-persist/src/types.js
|
||||
+++ b/node_modules/redux-persist/src/types.js
|
||||
@@ -19,6 +19,7 @@ export type PersistConfig = {
|
||||
whitelist?: Array<string>,
|
||||
transforms?: Array<Transform>,
|
||||
throttle?: number,
|
||||
+ debounce?: number,
|
||||
migrate?: (PersistedState, number) => Promise<PersistedState>,
|
||||
stateReconciler?: false | Function,
|
||||
getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
|
||||
diff --git a/node_modules/redux-persist/types/types.d.ts b/node_modules/redux-persist/types/types.d.ts
|
||||
index b3733bc..2a1696c 100644
|
||||
--- a/node_modules/redux-persist/types/types.d.ts
|
||||
+++ b/node_modules/redux-persist/types/types.d.ts
|
||||
@@ -35,6 +35,7 @@ declare module "redux-persist/es/types" {
|
||||
whitelist?: Array<string>;
|
||||
transforms?: Array<Transform<HSS, ESS, S, RS>>;
|
||||
throttle?: number;
|
||||
+ debounce?: number;
|
||||
migrate?: PersistMigrate;
|
||||
stateReconciler?: false | StateReconciler<S>;
|
||||
/**
|
@ -450,7 +450,7 @@
|
||||
"cfgScale": "CFG Scale",
|
||||
"width": "Width",
|
||||
"height": "Height",
|
||||
"sampler": "Sampler",
|
||||
"scheduler": "Scheduler",
|
||||
"seed": "Seed",
|
||||
"imageToImage": "Image to Image",
|
||||
"randomizeSeed": "Randomize Seed",
|
||||
@ -540,7 +540,10 @@
|
||||
"consoleLogLevel": "Log Level",
|
||||
"shouldLogToConsole": "Console Logging",
|
||||
"developer": "Developer",
|
||||
"general": "General"
|
||||
"general": "General",
|
||||
"generation": "Generation",
|
||||
"ui": "User Interface",
|
||||
"availableSchedulers": "Available Schedulers"
|
||||
},
|
||||
"toast": {
|
||||
"serverError": "Server Error",
|
||||
@ -549,8 +552,8 @@
|
||||
"canceled": "Processing Canceled",
|
||||
"tempFoldersEmptied": "Temp Folder Emptied",
|
||||
"uploadFailed": "Upload failed",
|
||||
"uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",
|
||||
"uploadFailedUnableToLoadDesc": "Unable to load file",
|
||||
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
|
||||
"downloadImageStarted": "Image Download Started",
|
||||
"imageCopied": "Image Copied",
|
||||
"imageLinkCopied": "Image Link Copied",
|
||||
|
@ -2,14 +2,11 @@ import ImageUploader from 'common/components/ImageUploader';
|
||||
import SiteHeader from 'features/system/components/SiteHeader';
|
||||
import ProgressBar from 'features/system/components/ProgressBar';
|
||||
import InvokeTabs from 'features/ui/components/InvokeTabs';
|
||||
|
||||
import useToastWatcher from 'features/system/hooks/useToastWatcher';
|
||||
|
||||
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
|
||||
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
|
||||
import { Box, Flex, Grid, Portal } from '@chakra-ui/react';
|
||||
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
||||
import GalleryDrawer from 'features/gallery/components/ImageGalleryPanel';
|
||||
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo, ReactNode, useCallback, useEffect, useState } from 'react';
|
||||
@ -17,25 +14,28 @@ import { motion, AnimatePresence } from 'framer-motion';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady';
|
||||
import { PartialAppConfig } from 'app/types/invokeai';
|
||||
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
||||
import { languageSelector } from 'features/system/store/systemSelectors';
|
||||
import i18n from 'i18n';
|
||||
import Toaster from './Toaster';
|
||||
import GlobalHotkeys from './GlobalHotkeys';
|
||||
|
||||
const DEFAULT_CONFIG = {};
|
||||
|
||||
interface Props {
|
||||
config?: PartialAppConfig;
|
||||
headerComponent?: ReactNode;
|
||||
setIsReady?: (isReady: boolean) => void;
|
||||
}
|
||||
|
||||
const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
||||
useToastWatcher();
|
||||
useGlobalHotkeys();
|
||||
|
||||
const App = ({
|
||||
config = DEFAULT_CONFIG,
|
||||
headerComponent,
|
||||
setIsReady,
|
||||
}: Props) => {
|
||||
const language = useAppSelector(languageSelector);
|
||||
|
||||
const log = useLogger();
|
||||
@ -61,66 +61,80 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
||||
setLoadingOverridden(true);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (isApplicationReady && setIsReady) {
|
||||
setIsReady(true);
|
||||
}
|
||||
|
||||
return () => {
|
||||
setIsReady && setIsReady(false);
|
||||
};
|
||||
}, [isApplicationReady, setIsReady]);
|
||||
|
||||
return (
|
||||
<Grid w="100vw" h="100vh" position="relative" overflow="hidden">
|
||||
{isLightboxEnabled && <Lightbox />}
|
||||
<ImageUploader>
|
||||
<ProgressBar />
|
||||
<Grid
|
||||
gap={4}
|
||||
p={4}
|
||||
gridAutoRows="min-content auto"
|
||||
w={APP_WIDTH}
|
||||
h={APP_HEIGHT}
|
||||
>
|
||||
{headerComponent || <SiteHeader />}
|
||||
<Flex
|
||||
<>
|
||||
<Grid w="100vw" h="100vh" position="relative" overflow="hidden">
|
||||
{isLightboxEnabled && <Lightbox />}
|
||||
<ImageUploader>
|
||||
<ProgressBar />
|
||||
<Grid
|
||||
gap={4}
|
||||
w={{ base: '100vw', xl: 'full' }}
|
||||
h="full"
|
||||
flexDir={{ base: 'column', xl: 'row' }}
|
||||
p={4}
|
||||
gridAutoRows="min-content auto"
|
||||
w={APP_WIDTH}
|
||||
h={APP_HEIGHT}
|
||||
>
|
||||
<InvokeTabs />
|
||||
</Flex>
|
||||
</Grid>
|
||||
</ImageUploader>
|
||||
{headerComponent || <SiteHeader />}
|
||||
<Flex
|
||||
gap={4}
|
||||
w={{ base: '100vw', xl: 'full' }}
|
||||
h="full"
|
||||
flexDir={{ base: 'column', xl: 'row' }}
|
||||
>
|
||||
<InvokeTabs />
|
||||
</Flex>
|
||||
</Grid>
|
||||
</ImageUploader>
|
||||
|
||||
<GalleryDrawer />
|
||||
<ParametersDrawer />
|
||||
<GalleryDrawer />
|
||||
<ParametersDrawer />
|
||||
|
||||
<AnimatePresence>
|
||||
{!isApplicationReady && !loadingOverridden && (
|
||||
<motion.div
|
||||
key="loading"
|
||||
initial={{ opacity: 1 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
style={{ zIndex: 3 }}
|
||||
>
|
||||
<Box position="absolute" top={0} left={0} w="100vw" h="100vh">
|
||||
<Loading />
|
||||
</Box>
|
||||
<Box
|
||||
onClick={handleOverrideClicked}
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
cursor="pointer"
|
||||
w="2rem"
|
||||
h="2rem"
|
||||
/>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
<AnimatePresence>
|
||||
{!isApplicationReady && !loadingOverridden && (
|
||||
<motion.div
|
||||
key="loading"
|
||||
initial={{ opacity: 1 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
style={{ zIndex: 3 }}
|
||||
>
|
||||
<Box position="absolute" top={0} left={0} w="100vw" h="100vh">
|
||||
<Loading />
|
||||
</Box>
|
||||
<Box
|
||||
onClick={handleOverrideClicked}
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
cursor="pointer"
|
||||
w="2rem"
|
||||
h="2rem"
|
||||
/>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
|
||||
<Portal>
|
||||
<FloatingParametersPanelButtons />
|
||||
</Portal>
|
||||
<Portal>
|
||||
<FloatingGalleryButton />
|
||||
</Portal>
|
||||
</Grid>
|
||||
<Portal>
|
||||
<FloatingParametersPanelButtons />
|
||||
</Portal>
|
||||
<Portal>
|
||||
<FloatingGalleryButton />
|
||||
</Portal>
|
||||
</Grid>
|
||||
<Toaster />
|
||||
<GlobalHotkeys />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -0,0 +1,44 @@
|
||||
import { Flex, Spinner, Tooltip } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selector = createSelector(systemSelector, (system) => {
|
||||
const { isUploading } = system;
|
||||
|
||||
let tooltip = '';
|
||||
|
||||
if (isUploading) {
|
||||
tooltip = 'Uploading...';
|
||||
}
|
||||
|
||||
return {
|
||||
tooltip,
|
||||
shouldShow: isUploading,
|
||||
};
|
||||
});
|
||||
|
||||
export const AuxiliaryProgressIndicator = () => {
|
||||
const { shouldShow, tooltip } = useAppSelector(selector);
|
||||
|
||||
if (!shouldShow) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
color: 'base.600',
|
||||
}}
|
||||
>
|
||||
<Tooltip label={tooltip} placement="right" hasArrow>
|
||||
<Spinner />
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(AuxiliaryProgressIndicator);
|
@ -10,6 +10,7 @@ import {
|
||||
togglePinParametersPanel,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import React, { memo } from 'react';
|
||||
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
||||
|
||||
const globalHotkeysSelector = createSelector(
|
||||
@ -27,7 +28,11 @@ const globalHotkeysSelector = createSelector(
|
||||
|
||||
// TODO: Does not catch keypresses while focused in an input. Maybe there is a way?
|
||||
|
||||
export const useGlobalHotkeys = () => {
|
||||
/**
|
||||
* Logical component. Handles app-level global hotkeys.
|
||||
* @returns null
|
||||
*/
|
||||
const GlobalHotkeys: React.FC = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { shift } = useAppSelector(globalHotkeysSelector);
|
||||
|
||||
@ -75,4 +80,8 @@ export const useGlobalHotkeys = () => {
|
||||
useHotkeys('4', () => {
|
||||
dispatch(setActiveTab('nodes'));
|
||||
});
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
export default memo(GlobalHotkeys);
|
@ -24,9 +24,16 @@ interface Props extends PropsWithChildren {
|
||||
token?: string;
|
||||
config?: PartialAppConfig;
|
||||
headerComponent?: ReactNode;
|
||||
setIsReady?: (isReady: boolean) => void;
|
||||
}
|
||||
|
||||
const InvokeAIUI = ({ apiUrl, token, config, headerComponent }: Props) => {
|
||||
const InvokeAIUI = ({
|
||||
apiUrl,
|
||||
token,
|
||||
config,
|
||||
headerComponent,
|
||||
setIsReady,
|
||||
}: Props) => {
|
||||
useEffect(() => {
|
||||
// configure API client token
|
||||
if (token) {
|
||||
@ -55,7 +62,11 @@ const InvokeAIUI = ({ apiUrl, token, config, headerComponent }: Props) => {
|
||||
<Provider store={store}>
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App config={config} headerComponent={headerComponent} />
|
||||
<App
|
||||
config={config}
|
||||
headerComponent={headerComponent}
|
||||
setIsReady={setIsReady}
|
||||
/>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
</Provider>
|
||||
|
65
invokeai/frontend/web/src/app/components/Toaster.ts
Normal file
65
invokeai/frontend/web/src/app/components/Toaster.ts
Normal file
@ -0,0 +1,65 @@
|
||||
import { useToast, UseToastOptions } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { toastQueueSelector } from 'features/system/store/systemSelectors';
|
||||
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
|
||||
import { useCallback, useEffect } from 'react';
|
||||
|
||||
export type MakeToastArg = string | UseToastOptions;
|
||||
|
||||
/**
|
||||
* Makes a toast from a string or a UseToastOptions object.
|
||||
* If a string is passed, the toast will have the status 'info' and will be closable with a duration of 2500ms.
|
||||
*/
|
||||
export const makeToast = (arg: MakeToastArg): UseToastOptions => {
|
||||
if (typeof arg === 'string') {
|
||||
return {
|
||||
title: arg,
|
||||
status: 'info',
|
||||
isClosable: true,
|
||||
duration: 2500,
|
||||
};
|
||||
}
|
||||
|
||||
return { status: 'info', isClosable: true, duration: 2500, ...arg };
|
||||
};
|
||||
|
||||
/**
|
||||
* Logical component. Watches the toast queue and makes toasts when the queue is not empty.
|
||||
* @returns null
|
||||
*/
|
||||
const Toaster = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const toastQueue = useAppSelector(toastQueueSelector);
|
||||
const toast = useToast();
|
||||
useEffect(() => {
|
||||
toastQueue.forEach((t) => {
|
||||
toast(t);
|
||||
});
|
||||
toastQueue.length > 0 && dispatch(clearToastQueue());
|
||||
}, [dispatch, toast, toastQueue]);
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Returns a function that can be used to make a toast.
|
||||
* @example
|
||||
* const toaster = useAppToaster();
|
||||
* toaster('Hello world!');
|
||||
* toaster({ title: 'Hello world!', status: 'success' });
|
||||
* @returns A function that can be used to make a toast.
|
||||
* @see makeToast
|
||||
* @see MakeToastArg
|
||||
* @see UseToastOptions
|
||||
*/
|
||||
export const useAppToaster = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const toaster = useCallback(
|
||||
(arg: MakeToastArg) => dispatch(addToast(makeToast(arg))),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return toaster;
|
||||
};
|
||||
|
||||
export default Toaster;
|
@ -1,28 +1,28 @@
|
||||
// TODO: use Enums?
|
||||
|
||||
export const DIFFUSERS_SCHEDULERS: Array<string> = [
|
||||
export const SCHEDULERS = [
|
||||
'ddim',
|
||||
'ddpm',
|
||||
'deis',
|
||||
'lms',
|
||||
'pndm',
|
||||
'heun',
|
||||
'euler',
|
||||
'euler_k',
|
||||
'euler_a',
|
||||
'kdpm_2',
|
||||
'kdpm_2_a',
|
||||
'dpmpp_2s',
|
||||
'dpmpp_2m',
|
||||
'dpmpp_2m_k',
|
||||
'kdpm_2',
|
||||
'kdpm_2_a',
|
||||
'deis',
|
||||
'ddpm',
|
||||
'pndm',
|
||||
'heun',
|
||||
'heun_k',
|
||||
'unipc',
|
||||
];
|
||||
] as const;
|
||||
|
||||
export const IMG2IMG_DIFFUSERS_SCHEDULERS = DIFFUSERS_SCHEDULERS.filter(
|
||||
(scheduler) => {
|
||||
return scheduler !== 'dpmpp_2s';
|
||||
}
|
||||
);
|
||||
export type Scheduler = (typeof SCHEDULERS)[number];
|
||||
|
||||
export const isScheduler = (x: string): x is Scheduler =>
|
||||
SCHEDULERS.includes(x as Scheduler);
|
||||
|
||||
// Valid image widths
|
||||
export const WIDTHS: Array<number> = Array.from(Array(64)).map(
|
||||
|
@ -15,6 +15,10 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
|
||||
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
|
||||
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
|
||||
import { addCanvasMergedListener } from './listeners/canvasMerged';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -43,3 +47,8 @@ addUserInvokedCanvasListener();
|
||||
addUserInvokedNodesListener();
|
||||
addUserInvokedTextToImageListener();
|
||||
addUserInvokedImageToImageListener();
|
||||
|
||||
addCanvasSavedToGalleryListener();
|
||||
addCanvasDownloadedAsImageListener();
|
||||
addCanvasCopiedToClipboardListener();
|
||||
addCanvasMergedListener();
|
||||
|
@ -0,0 +1,33 @@
|
||||
import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { copyBlobToClipboard } from 'features/canvas/util/copyBlobToClipboard';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' });
|
||||
|
||||
export const addCanvasCopiedToClipboardListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: canvasCopiedToClipboard,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state);
|
||||
|
||||
if (!blob) {
|
||||
moduleLog.error('Problem getting base layer blob');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Problem Copying Canvas',
|
||||
description: 'Unable to export base layer',
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
copyBlobToClipboard(blob);
|
||||
},
|
||||
});
|
||||
};
|
@ -0,0 +1,33 @@
|
||||
import { canvasDownloadedAsImage } from 'features/canvas/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { downloadBlob } from 'features/canvas/util/downloadBlob';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
|
||||
|
||||
export const addCanvasDownloadedAsImageListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: canvasDownloadedAsImage,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state);
|
||||
|
||||
if (!blob) {
|
||||
moduleLog.error('Problem getting base layer blob');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Problem Downloading Canvas',
|
||||
description: 'Unable to export base layer',
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
downloadBlob(blob, 'mergedCanvas.png');
|
||||
},
|
||||
});
|
||||
};
|
@ -1,31 +0,0 @@
|
||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import {
|
||||
canvasSessionIdChanged,
|
||||
stagingAreaInitialized,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { sessionInvoked } from 'services/thunks/session';
|
||||
|
||||
export const addCanvasGraphBuiltListener = () =>
|
||||
startAppListening({
|
||||
actionCreator: canvasGraphBuilt,
|
||||
effect: async (action, { dispatch, getState, take }) => {
|
||||
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
|
||||
const { sessionId } = meta.arg;
|
||||
const state = getState();
|
||||
|
||||
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
||||
dispatch(
|
||||
stagingAreaInitialized({
|
||||
sessionId,
|
||||
boundingBox: {
|
||||
...state.canvas.boundingBoxCoordinates,
|
||||
...state.canvas.boundingBoxDimensions,
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(canvasSessionIdChanged(sessionId));
|
||||
},
|
||||
});
|
@ -0,0 +1,88 @@
|
||||
import { canvasMerged } from 'features/canvas/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
|
||||
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' });
|
||||
|
||||
export const addCanvasMergedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: canvasMerged,
|
||||
effect: async (action, { dispatch, getState, take }) => {
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state, true);
|
||||
|
||||
if (!blob) {
|
||||
moduleLog.error('Problem getting base layer blob');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Problem Merging Canvas',
|
||||
description: 'Unable to export base layer',
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const canvasBaseLayer = getCanvasBaseLayer();
|
||||
|
||||
if (!canvasBaseLayer) {
|
||||
moduleLog.error('Problem getting canvas base layer');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Problem Merging Canvas',
|
||||
description: 'Unable to export base layer',
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const baseLayerRect = canvasBaseLayer.getClientRect({
|
||||
relativeTo: canvasBaseLayer.getParent(),
|
||||
});
|
||||
|
||||
const filename = `mergedCanvas_${uuidv4()}.png`;
|
||||
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
imageType: 'intermediates',
|
||||
formData: {
|
||||
file: new File([blob], filename, { type: 'image/png' }),
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
const [{ payload }] = await take(
|
||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.meta.arg.formData.file.name === filename
|
||||
);
|
||||
|
||||
const mergedCanvasImage = deserializeImageResponse(payload.response);
|
||||
|
||||
dispatch(
|
||||
setMergedCanvas({
|
||||
kind: 'image',
|
||||
layer: 'base',
|
||||
image: mergedCanvasImage,
|
||||
...baseLayerRect,
|
||||
})
|
||||
);
|
||||
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Canvas Merged',
|
||||
status: 'success',
|
||||
})
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
@ -0,0 +1,40 @@
|
||||
import { canvasSavedToGallery } from 'features/canvas/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
|
||||
|
||||
export const addCanvasSavedToGalleryListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: canvasSavedToGallery,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
|
||||
const blob = await getBaseLayerBlob(state);
|
||||
|
||||
if (!blob) {
|
||||
moduleLog.error('Problem getting base layer blob');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: 'Problem Saving Canvas',
|
||||
description: 'Unable to export base layer',
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
imageType: 'results',
|
||||
formData: {
|
||||
file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }),
|
||||
},
|
||||
})
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
@ -3,6 +3,10 @@ import { startAppListening } from '..';
|
||||
import { uploadAdded } from 'features/gallery/store/uploadsSlice';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { resultAdded } from 'features/gallery/store/resultsSlice';
|
||||
|
||||
export const addImageUploadedListener = () => {
|
||||
startAppListening({
|
||||
@ -11,14 +15,31 @@ export const addImageUploadedListener = () => {
|
||||
action.payload.response.image_type !== 'intermediates',
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { response } = action.payload;
|
||||
const { imageType } = action.meta.arg;
|
||||
|
||||
const state = getState();
|
||||
const image = deserializeImageResponse(response);
|
||||
|
||||
dispatch(uploadAdded(image));
|
||||
if (imageType === 'uploads') {
|
||||
dispatch(uploadAdded(image));
|
||||
|
||||
if (state.gallery.shouldAutoSwitchToNewImages) {
|
||||
dispatch(imageSelected(image));
|
||||
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
|
||||
|
||||
if (state.gallery.shouldAutoSwitchToNewImages) {
|
||||
dispatch(imageSelected(image));
|
||||
}
|
||||
|
||||
if (action.meta.arg.activeTabName === 'img2img') {
|
||||
dispatch(initialImageSelected(image));
|
||||
}
|
||||
|
||||
if (action.meta.arg.activeTabName === 'unifiedCanvas') {
|
||||
dispatch(setInitialCanvasImage(image));
|
||||
}
|
||||
}
|
||||
|
||||
if (imageType === 'results') {
|
||||
dispatch(resultAdded(image));
|
||||
}
|
||||
},
|
||||
});
|
||||
|
@ -2,11 +2,11 @@ import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { Image, isInvokeAIImage } from 'app/types/invokeai';
|
||||
import { selectResultsById } from 'features/gallery/store/resultsSlice';
|
||||
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
|
||||
import { makeToast } from 'features/system/hooks/useToastWatcher';
|
||||
import { t } from 'i18next';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { startAppListening } from '..';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
|
||||
export const addInitialImageSelectedListener = () => {
|
||||
startAppListening({
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { startAppListening } from '..';
|
||||
import { sessionCreated, sessionInvoked } from 'services/thunks/session';
|
||||
import { buildCanvasGraphAndBlobs } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
@ -11,9 +11,17 @@ import {
|
||||
stagingAreaInitialized,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
||||
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
|
||||
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'invoke' });
|
||||
|
||||
/**
|
||||
* This listener is responsible for building the canvas graph and blobs when the user invokes the canvas.
|
||||
* It is also responsible for uploading the base and mask layers to the server.
|
||||
*/
|
||||
export const addUserInvokedCanvasListener = () => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||
@ -21,25 +29,49 @@ export const addUserInvokedCanvasListener = () => {
|
||||
effect: async (action, { getState, dispatch, take }) => {
|
||||
const state = getState();
|
||||
|
||||
const data = await buildCanvasGraphAndBlobs(state);
|
||||
// Build canvas blobs
|
||||
const canvasBlobsAndImageData = await getCanvasData(state);
|
||||
|
||||
if (!data) {
|
||||
if (!canvasBlobsAndImageData) {
|
||||
moduleLog.error('Unable to create canvas data');
|
||||
return;
|
||||
}
|
||||
|
||||
const { baseBlob, baseImageData, maskBlob, maskImageData } =
|
||||
canvasBlobsAndImageData;
|
||||
|
||||
// Determine the generation mode
|
||||
const generationMode = getCanvasGenerationMode(
|
||||
baseImageData,
|
||||
maskImageData
|
||||
);
|
||||
|
||||
if (state.system.enableImageDebugging) {
|
||||
const baseDataURL = await blobToDataURL(baseBlob);
|
||||
const maskDataURL = await blobToDataURL(maskBlob);
|
||||
openBase64ImageInTab([
|
||||
{ base64: maskDataURL, caption: 'mask b64' },
|
||||
{ base64: baseDataURL, caption: 'image b64' },
|
||||
]);
|
||||
}
|
||||
|
||||
moduleLog.debug(`Generation mode: ${generationMode}`);
|
||||
|
||||
// Build the canvas graph
|
||||
const graphComponents = await buildCanvasGraphComponents(
|
||||
state,
|
||||
generationMode
|
||||
);
|
||||
|
||||
if (!graphComponents) {
|
||||
moduleLog.error('Problem building graph');
|
||||
return;
|
||||
}
|
||||
|
||||
const {
|
||||
rangeNode,
|
||||
iterateNode,
|
||||
baseNode,
|
||||
edges,
|
||||
baseBlob,
|
||||
maskBlob,
|
||||
generationMode,
|
||||
} = data;
|
||||
const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
|
||||
|
||||
// Upload the base layer, to be used as init image
|
||||
const baseFilename = `${uuidv4()}.png`;
|
||||
const maskFilename = `${uuidv4()}.png`;
|
||||
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
@ -66,6 +98,9 @@ export const addUserInvokedCanvasListener = () => {
|
||||
};
|
||||
}
|
||||
|
||||
// Upload the mask layer image
|
||||
const maskFilename = `${uuidv4()}.png`;
|
||||
|
||||
if (baseNode.type === 'inpaint') {
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
@ -103,9 +138,12 @@ export const addUserInvokedCanvasListener = () => {
|
||||
dispatch(canvasGraphBuilt(graph));
|
||||
moduleLog({ data: graph }, 'Canvas graph built');
|
||||
|
||||
// Actually create the session
|
||||
dispatch(sessionCreated({ graph }));
|
||||
|
||||
// Wait for the session to be invoked (this is just the HTTP request to start processing)
|
||||
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
|
||||
|
||||
const { sessionId } = meta.arg;
|
||||
|
||||
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
||||
|
@ -52,6 +52,7 @@ export type CommonGeneratedImageMetadata = {
|
||||
| 'lms'
|
||||
| 'pndm'
|
||||
| 'heun'
|
||||
| 'heun_k'
|
||||
| 'euler'
|
||||
| 'euler_k'
|
||||
| 'euler_a'
|
||||
|
172
invokeai/frontend/web/src/common/components/IAICustomSelect.tsx
Normal file
172
invokeai/frontend/web/src/common/components/IAICustomSelect.tsx
Normal file
@ -0,0 +1,172 @@
|
||||
import { CheckIcon } from '@chakra-ui/icons';
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
FlexProps,
|
||||
FormControl,
|
||||
FormControlProps,
|
||||
FormLabel,
|
||||
Grid,
|
||||
GridItem,
|
||||
List,
|
||||
ListItem,
|
||||
Select,
|
||||
Text,
|
||||
Tooltip,
|
||||
TooltipProps,
|
||||
} from '@chakra-ui/react';
|
||||
import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
|
||||
import { useSelect } from 'downshift';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
|
||||
import { memo } from 'react';
|
||||
|
||||
type IAICustomSelectProps = {
|
||||
label?: string;
|
||||
items: string[];
|
||||
selectedItem: string;
|
||||
setSelectedItem: (v: string | null | undefined) => void;
|
||||
withCheckIcon?: boolean;
|
||||
formControlProps?: FormControlProps;
|
||||
buttonProps?: FlexProps;
|
||||
tooltip?: string;
|
||||
tooltipProps?: Omit<TooltipProps, 'children'>;
|
||||
};
|
||||
|
||||
const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
const {
|
||||
label,
|
||||
items,
|
||||
setSelectedItem,
|
||||
selectedItem,
|
||||
withCheckIcon,
|
||||
formControlProps,
|
||||
tooltip,
|
||||
buttonProps,
|
||||
tooltipProps,
|
||||
} = props;
|
||||
|
||||
const {
|
||||
isOpen,
|
||||
getToggleButtonProps,
|
||||
getLabelProps,
|
||||
getMenuProps,
|
||||
highlightedIndex,
|
||||
getItemProps,
|
||||
} = useSelect({
|
||||
items,
|
||||
selectedItem,
|
||||
onSelectedItemChange: ({ selectedItem: newSelectedItem }) =>
|
||||
setSelectedItem(newSelectedItem),
|
||||
});
|
||||
|
||||
const { refs, floatingStyles } = useFloating<HTMLButtonElement>({
|
||||
whileElementsMounted: autoUpdate,
|
||||
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
|
||||
});
|
||||
|
||||
return (
|
||||
<FormControl sx={{ w: 'full' }} {...formControlProps}>
|
||||
{label && (
|
||||
<FormLabel
|
||||
{...getLabelProps()}
|
||||
onClick={() => {
|
||||
refs.floating.current && refs.floating.current.focus();
|
||||
}}
|
||||
>
|
||||
{label}
|
||||
</FormLabel>
|
||||
)}
|
||||
<Tooltip label={tooltip} {...tooltipProps}>
|
||||
<Select
|
||||
{...getToggleButtonProps({ ref: refs.setReference })}
|
||||
{...buttonProps}
|
||||
as={Flex}
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
userSelect: 'none',
|
||||
cursor: 'pointer',
|
||||
}}
|
||||
>
|
||||
<Text sx={{ fontSize: 'sm', fontWeight: 500, color: 'base.100' }}>
|
||||
{selectedItem}
|
||||
</Text>
|
||||
</Select>
|
||||
</Tooltip>
|
||||
<Box {...getMenuProps()}>
|
||||
{isOpen && (
|
||||
<List
|
||||
as={Flex}
|
||||
ref={refs.setFloating}
|
||||
sx={{
|
||||
...floatingStyles,
|
||||
width: 'max-content',
|
||||
top: 0,
|
||||
left: 0,
|
||||
flexDirection: 'column',
|
||||
zIndex: 1,
|
||||
bg: 'base.800',
|
||||
borderRadius: 'base',
|
||||
border: '1px',
|
||||
borderColor: 'base.700',
|
||||
shadow: 'dark-lg',
|
||||
py: 2,
|
||||
px: 0,
|
||||
h: 'fit-content',
|
||||
maxH: 64,
|
||||
}}
|
||||
>
|
||||
<OverlayScrollbarsComponent>
|
||||
{items.map((item, index) => (
|
||||
<ListItem
|
||||
sx={{
|
||||
bg: highlightedIndex === index ? 'base.700' : undefined,
|
||||
py: 1,
|
||||
paddingInlineStart: 3,
|
||||
paddingInlineEnd: 6,
|
||||
cursor: 'pointer',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: '0.15s',
|
||||
}}
|
||||
key={`${item}${index}`}
|
||||
{...getItemProps({ item, index })}
|
||||
>
|
||||
{withCheckIcon ? (
|
||||
<Grid gridTemplateColumns="1.25rem auto">
|
||||
<GridItem>
|
||||
{selectedItem === item && <CheckIcon boxSize={2} />}
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<Text
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
color: 'base.100',
|
||||
fontWeight: 500,
|
||||
}}
|
||||
>
|
||||
{item}
|
||||
</Text>
|
||||
</GridItem>
|
||||
</Grid>
|
||||
) : (
|
||||
<Text
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
color: 'base.100',
|
||||
fontWeight: 500,
|
||||
}}
|
||||
>
|
||||
{item}
|
||||
</Text>
|
||||
)}
|
||||
</ListItem>
|
||||
))}
|
||||
</OverlayScrollbarsComponent>
|
||||
</List>
|
||||
)}
|
||||
</Box>
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAICustomSelect);
|
@ -5,6 +5,7 @@ import {
|
||||
Input,
|
||||
InputProps,
|
||||
} from '@chakra-ui/react';
|
||||
import { stopPastePropagation } from 'common/util/stopPastePropagation';
|
||||
import { ChangeEvent, memo } from 'react';
|
||||
|
||||
interface IAIInputProps extends InputProps {
|
||||
@ -31,7 +32,7 @@ const IAIInput = (props: IAIInputProps) => {
|
||||
{...formControlProps}
|
||||
>
|
||||
{label !== '' && <FormLabel>{label}</FormLabel>}
|
||||
<Input {...rest} />
|
||||
<Input {...rest} onPaste={stopPastePropagation} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
@ -14,6 +14,7 @@ import {
|
||||
Tooltip,
|
||||
TooltipProps,
|
||||
} from '@chakra-ui/react';
|
||||
import { stopPastePropagation } from 'common/util/stopPastePropagation';
|
||||
import { clamp } from 'lodash-es';
|
||||
|
||||
import { FocusEvent, memo, useEffect, useState } from 'react';
|
||||
@ -125,6 +126,7 @@ const IAINumberInput = (props: Props) => {
|
||||
onChange={handleOnChange}
|
||||
onBlur={handleBlur}
|
||||
{...rest}
|
||||
onPaste={stopPastePropagation}
|
||||
>
|
||||
<NumberInputField {...numberInputFieldProps} />
|
||||
{showStepper && (
|
||||
|
@ -0,0 +1,9 @@
|
||||
import { Textarea, TextareaProps, forwardRef } from '@chakra-ui/react';
|
||||
import { stopPastePropagation } from 'common/util/stopPastePropagation';
|
||||
import { memo } from 'react';
|
||||
|
||||
const IAITextarea = forwardRef((props: TextareaProps, ref) => {
|
||||
return <Textarea ref={ref} onPaste={stopPastePropagation} {...props} />;
|
||||
});
|
||||
|
||||
export default memo(IAITextarea);
|
@ -1,4 +1,4 @@
|
||||
import { Box, useToast } from '@chakra-ui/react';
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import useImageUploader from 'common/hooks/useImageUploader';
|
||||
@ -10,12 +10,33 @@ import {
|
||||
ReactNode,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react';
|
||||
import { FileRejection, useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import ImageUploadOverlay from './ImageUploadOverlay';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { filter, map, some } from 'lodash-es';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { ErrorCode } from 'react-dropzone';
|
||||
|
||||
const selector = createSelector(
|
||||
[systemSelector, activeTabNameSelector],
|
||||
(system, activeTabName) => {
|
||||
const { isConnected, isUploading } = system;
|
||||
|
||||
const isUploaderDisabled = !isConnected || isUploading;
|
||||
|
||||
return {
|
||||
isUploaderDisabled,
|
||||
activeTabName,
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
type ImageUploaderProps = {
|
||||
children: ReactNode;
|
||||
@ -24,38 +45,49 @@ type ImageUploaderProps = {
|
||||
const ImageUploader = (props: ImageUploaderProps) => {
|
||||
const { children } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const toast = useToast({});
|
||||
const { isUploaderDisabled, activeTabName } = useAppSelector(selector);
|
||||
const toaster = useAppToaster();
|
||||
const { t } = useTranslation();
|
||||
const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
|
||||
const { setOpenUploader } = useImageUploader();
|
||||
const { setOpenUploaderFunction } = useImageUploader();
|
||||
|
||||
const fileRejectionCallback = useCallback(
|
||||
(rejection: FileRejection) => {
|
||||
setIsHandlingUpload(true);
|
||||
const msg = rejection.errors.reduce(
|
||||
(acc: string, cur: { message: string }) => `${acc}\n${cur.message}`,
|
||||
''
|
||||
);
|
||||
toast({
|
||||
|
||||
toaster({
|
||||
title: t('toast.uploadFailed'),
|
||||
description: msg,
|
||||
description: rejection.errors.map((error) => error.message).join('\n'),
|
||||
status: 'error',
|
||||
isClosable: true,
|
||||
});
|
||||
},
|
||||
[t, toast]
|
||||
[t, toaster]
|
||||
);
|
||||
|
||||
const fileAcceptedCallback = useCallback(
|
||||
async (file: File) => {
|
||||
dispatch(imageUploaded({ imageType: 'uploads', formData: { file } }));
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
imageType: 'uploads',
|
||||
formData: { file },
|
||||
activeTabName,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch]
|
||||
[dispatch, activeTabName]
|
||||
);
|
||||
|
||||
const onDrop = useCallback(
|
||||
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
|
||||
if (fileRejections.length > 1) {
|
||||
toaster({
|
||||
title: t('toast.uploadFailed'),
|
||||
description: t('toast.uploadFailedInvalidUploadDesc'),
|
||||
status: 'error',
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
fileRejections.forEach((rejection: FileRejection) => {
|
||||
fileRejectionCallback(rejection);
|
||||
});
|
||||
@ -64,7 +96,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
fileAcceptedCallback(file);
|
||||
});
|
||||
},
|
||||
[fileAcceptedCallback, fileRejectionCallback]
|
||||
[t, toaster, fileAcceptedCallback, fileRejectionCallback]
|
||||
);
|
||||
|
||||
const {
|
||||
@ -73,92 +105,73 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
isDragAccept,
|
||||
isDragReject,
|
||||
isDragActive,
|
||||
inputRef,
|
||||
open,
|
||||
} = useDropzone({
|
||||
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
|
||||
noClick: true,
|
||||
onDrop,
|
||||
onDragOver: () => setIsHandlingUpload(true),
|
||||
maxFiles: 1,
|
||||
disabled: isUploaderDisabled,
|
||||
multiple: false,
|
||||
});
|
||||
|
||||
setOpenUploader(open);
|
||||
|
||||
useEffect(() => {
|
||||
const pasteImageListener = (e: ClipboardEvent) => {
|
||||
const dataTransferItemList = e.clipboardData?.items;
|
||||
if (!dataTransferItemList) return;
|
||||
|
||||
const imageItems: Array<DataTransferItem> = [];
|
||||
|
||||
for (const item of dataTransferItemList) {
|
||||
if (
|
||||
item.kind === 'file' &&
|
||||
['image/png', 'image/jpg'].includes(item.type)
|
||||
) {
|
||||
imageItems.push(item);
|
||||
}
|
||||
}
|
||||
|
||||
if (!imageItems.length) return;
|
||||
|
||||
e.stopImmediatePropagation();
|
||||
|
||||
if (imageItems.length > 1) {
|
||||
toast({
|
||||
description: t('toast.uploadFailedMultipleImagesDesc'),
|
||||
status: 'error',
|
||||
isClosable: true,
|
||||
});
|
||||
// This is a hack to allow pasting images into the uploader
|
||||
const handlePaste = async (e: ClipboardEvent) => {
|
||||
if (!inputRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
const file = imageItems[0].getAsFile();
|
||||
|
||||
if (!file) {
|
||||
toast({
|
||||
description: t('toast.uploadFailedUnableToLoadDesc'),
|
||||
status: 'error',
|
||||
isClosable: true,
|
||||
});
|
||||
return;
|
||||
if (e.clipboardData?.files) {
|
||||
// Set the files on the inputRef
|
||||
inputRef.current.files = e.clipboardData.files;
|
||||
// Dispatch the change event, dropzone catches this and we get to use its own validation
|
||||
inputRef.current?.dispatchEvent(new Event('change', { bubbles: true }));
|
||||
}
|
||||
|
||||
dispatch(imageUploaded({ imageType: 'uploads', formData: { file } }));
|
||||
};
|
||||
document.addEventListener('paste', pasteImageListener);
|
||||
|
||||
// Set the open function so we can open the uploader from anywhere
|
||||
setOpenUploaderFunction(open);
|
||||
|
||||
// Add the paste event listener
|
||||
document.addEventListener('paste', handlePaste);
|
||||
|
||||
return () => {
|
||||
document.removeEventListener('paste', pasteImageListener);
|
||||
document.removeEventListener('paste', handlePaste);
|
||||
setOpenUploaderFunction(() => {
|
||||
return;
|
||||
});
|
||||
};
|
||||
}, [t, dispatch, toast, activeTabName]);
|
||||
}, [inputRef, open, setOpenUploaderFunction]);
|
||||
|
||||
const overlaySecondaryText = ['img2img', 'unifiedCanvas'].includes(
|
||||
activeTabName
|
||||
)
|
||||
? ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`
|
||||
: ``;
|
||||
const overlaySecondaryText = useMemo(() => {
|
||||
if (['img2img', 'unifiedCanvas'].includes(activeTabName)) {
|
||||
return ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`;
|
||||
}
|
||||
|
||||
return '';
|
||||
}, [t, activeTabName]);
|
||||
|
||||
return (
|
||||
<ImageUploaderTriggerContext.Provider value={open}>
|
||||
<Box
|
||||
{...getRootProps({ style: {} })}
|
||||
onKeyDown={(e: KeyboardEvent) => {
|
||||
// Bail out if user hits spacebar - do not open the uploader
|
||||
if (e.key === ' ') return;
|
||||
}}
|
||||
>
|
||||
<input {...getInputProps()} />
|
||||
{children}
|
||||
{isDragActive && isHandlingUpload && (
|
||||
<ImageUploadOverlay
|
||||
isDragAccept={isDragAccept}
|
||||
isDragReject={isDragReject}
|
||||
overlaySecondaryText={overlaySecondaryText}
|
||||
setIsHandlingUpload={setIsHandlingUpload}
|
||||
/>
|
||||
)}
|
||||
</Box>
|
||||
</ImageUploaderTriggerContext.Provider>
|
||||
<Box
|
||||
{...getRootProps({ style: {} })}
|
||||
onKeyDown={(e: KeyboardEvent) => {
|
||||
// Bail out if user hits spacebar - do not open the uploader
|
||||
if (e.key === ' ') return;
|
||||
}}
|
||||
>
|
||||
<input {...getInputProps()} />
|
||||
{children}
|
||||
{isDragActive && isHandlingUpload && (
|
||||
<ImageUploadOverlay
|
||||
isDragAccept={isDragAccept}
|
||||
isDragReject={isDragReject}
|
||||
overlaySecondaryText={overlaySecondaryText}
|
||||
setIsHandlingUpload={setIsHandlingUpload}
|
||||
/>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { Flex, Heading, Icon } from '@chakra-ui/react';
|
||||
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
|
||||
import { useContext } from 'react';
|
||||
import useImageUploader from 'common/hooks/useImageUploader';
|
||||
import { FaUpload } from 'react-icons/fa';
|
||||
|
||||
type ImageUploaderButtonProps = {
|
||||
@ -9,11 +8,7 @@ type ImageUploaderButtonProps = {
|
||||
|
||||
const ImageUploaderButton = (props: ImageUploaderButtonProps) => {
|
||||
const { styleClass } = props;
|
||||
const open = useContext(ImageUploaderTriggerContext);
|
||||
|
||||
const handleClickUpload = () => {
|
||||
open && open();
|
||||
};
|
||||
const { openUploader } = useImageUploader();
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@ -26,7 +21,7 @@ const ImageUploaderButton = (props: ImageUploaderButtonProps) => {
|
||||
className={styleClass}
|
||||
>
|
||||
<Flex
|
||||
onClick={handleClickUpload}
|
||||
onClick={openUploader}
|
||||
sx={{
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
|
@ -1,19 +1,18 @@
|
||||
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
|
||||
import { useContext } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaUpload } from 'react-icons/fa';
|
||||
import IAIIconButton from './IAIIconButton';
|
||||
import useImageUploader from 'common/hooks/useImageUploader';
|
||||
|
||||
const ImageUploaderIconButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const openImageUploader = useContext(ImageUploaderTriggerContext);
|
||||
const { openUploader } = useImageUploader();
|
||||
|
||||
return (
|
||||
<IAIIconButton
|
||||
aria-label={t('accessibility.uploadImage')}
|
||||
tooltip="Upload Image"
|
||||
icon={<FaUpload />}
|
||||
onClick={openImageUploader || undefined}
|
||||
onClick={openUploader}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -6,10 +6,12 @@ import { FaUndo, FaUpload } from 'react-icons/fa';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import useImageUploader from 'common/hooks/useImageUploader';
|
||||
|
||||
const InitialImageButtons = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { openUploader } = useImageUploader();
|
||||
|
||||
const handleResetInitialImage = useCallback(() => {
|
||||
dispatch(clearInitialImage());
|
||||
@ -27,7 +29,11 @@ const InitialImageButtons = () => {
|
||||
aria-label={t('accessibility.reset')}
|
||||
onClick={handleResetInitialImage}
|
||||
/>
|
||||
<IAIIconButton icon={<FaUpload />} aria-label={t('common.upload')} />
|
||||
<IAIIconButton
|
||||
icon={<FaUpload />}
|
||||
onClick={openUploader}
|
||||
aria-label={t('common.upload')}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
);
|
@ -24,7 +24,6 @@ const Loading = () => {
|
||||
height="24px !important"
|
||||
right="1.5rem"
|
||||
bottom="1.5rem"
|
||||
speed="1.2s"
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
|
@ -1,29 +0,0 @@
|
||||
import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import WorkInProgress from './WorkInProgress';
|
||||
|
||||
export const PostProcessingWIP = () => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<WorkInProgress>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
w: '100%',
|
||||
h: '100%',
|
||||
gap: 4,
|
||||
textAlign: 'center',
|
||||
}}
|
||||
>
|
||||
<Heading>{t('common.postProcessing')}</Heading>
|
||||
<VStack maxW="50rem" gap={4}>
|
||||
<Text>{t('common.postProcessDesc1')}</Text>
|
||||
<Text>{t('common.postProcessDesc2')}</Text>
|
||||
<Text>{t('common.postProcessDesc3')}</Text>
|
||||
</VStack>
|
||||
</Flex>
|
||||
</WorkInProgress>
|
||||
);
|
||||
};
|
@ -1,28 +0,0 @@
|
||||
import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import WorkInProgress from './WorkInProgress';
|
||||
|
||||
export default function TrainingWIP() {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<WorkInProgress>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
w: '100%',
|
||||
h: '100%',
|
||||
gap: 4,
|
||||
textAlign: 'center',
|
||||
}}
|
||||
>
|
||||
<Heading>{t('common.training')}</Heading>
|
||||
<VStack maxW="50rem" gap={4}>
|
||||
<Text>{t('common.trainingDesc1')}</Text>
|
||||
<Text>{t('common.trainingDesc2')}</Text>
|
||||
</VStack>
|
||||
</Flex>
|
||||
</WorkInProgress>
|
||||
);
|
||||
}
|
@ -1,26 +0,0 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { ReactNode } from 'react';
|
||||
|
||||
type WorkInProgressProps = {
|
||||
children: ReactNode;
|
||||
};
|
||||
|
||||
const WorkInProgress = (props: WorkInProgressProps) => {
|
||||
const { children } = props;
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
bg: 'base.850',
|
||||
borderRadius: 'base',
|
||||
position: 'relative',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default WorkInProgress;
|
@ -1,35 +0,0 @@
|
||||
import { RefObject, useEffect } from 'react';
|
||||
|
||||
const watchers: {
|
||||
ref: RefObject<HTMLElement>;
|
||||
enable: boolean;
|
||||
callback: () => void;
|
||||
}[] = [];
|
||||
|
||||
const useClickOutsideWatcher = () => {
|
||||
useEffect(() => {
|
||||
function handleClickOutside(e: MouseEvent) {
|
||||
watchers.forEach(({ ref, enable, callback }) => {
|
||||
if (enable && ref.current && !ref.current.contains(e.target as Node)) {
|
||||
callback();
|
||||
}
|
||||
});
|
||||
}
|
||||
document.addEventListener('mousedown', handleClickOutside);
|
||||
return () => {
|
||||
document.removeEventListener('mousedown', handleClickOutside);
|
||||
};
|
||||
}, []);
|
||||
|
||||
return {
|
||||
addWatcher: (watcher: {
|
||||
ref: RefObject<HTMLElement>;
|
||||
callback: () => void;
|
||||
enable: boolean;
|
||||
}) => {
|
||||
watchers.push(watcher);
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
export default useClickOutsideWatcher;
|
@ -1,13 +1,22 @@
|
||||
let openFunction: () => void;
|
||||
import { useCallback } from 'react';
|
||||
|
||||
let openUploader = () => {
|
||||
return;
|
||||
};
|
||||
|
||||
const useImageUploader = () => {
|
||||
return {
|
||||
setOpenUploader: (open?: () => void) => {
|
||||
if (open) {
|
||||
openFunction = open;
|
||||
const setOpenUploaderFunction = useCallback(
|
||||
(openUploaderFunction?: () => void) => {
|
||||
if (openUploaderFunction) {
|
||||
openUploader = openUploaderFunction;
|
||||
}
|
||||
},
|
||||
openUploader: openFunction,
|
||||
[]
|
||||
);
|
||||
|
||||
return {
|
||||
setOpenUploaderFunction,
|
||||
openUploader,
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -1,17 +0,0 @@
|
||||
import React from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export default function useUpdateTranslations(fn: () => void) {
|
||||
const { i18n } = useTranslation();
|
||||
const currentLang = localStorage.getItem('i18nextLng');
|
||||
|
||||
React.useEffect(() => {
|
||||
fn();
|
||||
}, [fn]);
|
||||
|
||||
React.useEffect(() => {
|
||||
i18n.on('languageChanged', () => {
|
||||
fn();
|
||||
});
|
||||
}, [fn, i18n, currentLang]);
|
||||
}
|
@ -1,20 +0,0 @@
|
||||
import { createIcon } from '@chakra-ui/react';
|
||||
|
||||
const ImageToImageIcon = createIcon({
|
||||
displayName: 'ImageToImageIcon',
|
||||
viewBox: '0 0 3543 3543',
|
||||
path: (
|
||||
<g transform="matrix(1.10943,0,0,1.10943,-206.981,-213.533)">
|
||||
<path
|
||||
fill="currentColor"
|
||||
fillRule="evenodd"
|
||||
clipRule="evenodd"
|
||||
d="M688.533,2405.95L542.987,2405.95C349.532,2405.95 192.47,2248.89 192.47,2055.44L192.47,542.987C192.47,349.532 349.532,192.47 542.987,192.47L2527.88,192.47C2721.33,192.47 2878.4,349.532 2878.4,542.987L2878.4,1172.79L3023.94,1172.79C3217.4,1172.79 3374.46,1329.85 3374.46,1523.3C3374.46,1523.3 3374.46,3035.75 3374.46,3035.75C3374.46,3229.21 3217.4,3386.27 3023.94,3386.27L1039.05,3386.27C845.595,3386.27 688.533,3229.21 688.533,3035.75L688.533,2405.95ZM3286.96,2634.37L3286.96,1523.3C3286.96,1378.14 3169.11,1260.29 3023.94,1260.29C3023.94,1260.29 1039.05,1260.29 1039.05,1260.29C893.887,1260.29 776.033,1378.14 776.033,1523.3L776.033,2489.79L1440.94,1736.22L2385.83,2775.59L2880.71,2200.41L3286.96,2634.37ZM2622.05,1405.51C2778.5,1405.51 2905.51,1532.53 2905.51,1688.98C2905.51,1845.42 2778.5,1972.44 2622.05,1972.44C2465.6,1972.44 2338.58,1845.42 2338.58,1688.98C2338.58,1532.53 2465.6,1405.51 2622.05,1405.51ZM2790.9,1172.79L1323.86,1172.79L944.882,755.906L279.97,1509.47L279.97,542.987C279.97,397.824 397.824,279.97 542.987,279.97C542.987,279.97 2527.88,279.97 2527.88,279.97C2673.04,279.97 2790.9,397.824 2790.9,542.987L2790.9,1172.79ZM2125.98,425.197C2282.43,425.197 2409.45,552.213 2409.45,708.661C2409.45,865.11 2282.43,992.126 2125.98,992.126C1969.54,992.126 1842.52,865.11 1842.52,708.661C1842.52,552.213 1969.54,425.197 2125.98,425.197Z"
|
||||
/>
|
||||
</g>
|
||||
),
|
||||
defaultProps: {
|
||||
boxSize: '24px',
|
||||
},
|
||||
});
|
||||
export default ImageToImageIcon;
|
File diff suppressed because one or more lines are too long
@ -1,19 +0,0 @@
|
||||
import { createIcon } from '@chakra-ui/react';
|
||||
|
||||
const NodesIcon = createIcon({
|
||||
displayName: 'NodesIcon',
|
||||
viewBox: '0 0 3543 3543',
|
||||
path: (
|
||||
<path
|
||||
fill="currentColor"
|
||||
fillRule="evenodd"
|
||||
clipRule="evenodd"
|
||||
d="M3543.31,770.787C3543.31,515.578 3336.11,308.38 3080.9,308.38L462.407,308.38C207.197,308.38 0,515.578 0,770.787L0,2766.03C0,3021.24 207.197,3228.44 462.407,3228.44L3080.9,3228.44C3336.11,3228.44 3543.31,3021.24 3543.31,2766.03C3543.31,2766.03 3543.31,770.787 3543.31,770.787ZM3427.88,770.787L3427.88,2766.03C3427.88,2957.53 3272.4,3113.01 3080.9,3113.01C3080.9,3113.01 462.407,3113.01 462.407,3113.01C270.906,3113.01 115.431,2957.53 115.431,2766.03L115.431,770.787C115.431,579.286 270.906,423.812 462.407,423.812L3080.9,423.812C3272.4,423.812 3427.88,579.286 3427.88,770.787ZM1214.23,1130.69L1321.47,1130.69C1324.01,1130.69 1326.54,1130.53 1329.05,1130.2C1329.05,1130.2 1367.3,1125.33 1397.94,1149.8C1421.63,1168.72 1437.33,1204.3 1437.33,1265.48L1437.33,2078.74L1220.99,2078.74C1146.83,2078.74 1086.61,2138.95 1086.61,2213.12L1086.61,2762.46C1086.61,2836.63 1146.83,2896.84 1220.99,2896.84L1770.34,2896.84C1844.5,2896.84 1904.71,2836.63 1904.71,2762.46L1904.71,2213.12C1904.71,2138.95 1844.5,2078.74 1770.34,2078.74L1554,2078.74L1554,1604.84C1625.84,1658.19 1703.39,1658.1 1703.39,1658.1C1703.54,1658.1 1703.69,1658.11 1703.84,1658.11L2362.2,1658.11L2362.2,1874.44C2362.2,1948.61 2422.42,2008.82 2496.58,2008.82L3045.93,2008.82C3120.09,2008.82 3180.3,1948.61 3180.3,1874.44L3180.3,1325.1C3180.3,1250.93 3120.09,1190.72 3045.93,1190.72L2496.58,1190.72C2422.42,1190.72 2362.2,1250.93 2362.2,1325.1L2362.2,1558.97L2362.2,1541.44L1704.23,1541.44C1702.2,1541.37 1650.96,1539.37 1609.51,1499.26C1577.72,1468.49 1554,1416.47 1554,1331.69L1554,1265.48C1554,1153.86 1513.98,1093.17 1470.76,1058.64C1411.24,1011.1 1338.98,1012.58 1319.15,1014.03L1214.23,1014.03L1214.23,796.992C1214.23,722.828 1154.02,662.617 1079.85,662.617L530.507,662.617C456.343,662.617 396.131,722.828 396.131,796.992L396.131,1346.34C396.131,1420.5 456.343,1480.71 530.507,1480.71L1079.85,1480.71C1154.02,1480.71 1214.23,1420.5 1214.23,1346.34L1214.23,1130.69Z"
|
||||
/>
|
||||
),
|
||||
defaultProps: {
|
||||
boxSize: '24px',
|
||||
},
|
||||
});
|
||||
|
||||
export default NodesIcon;
|
File diff suppressed because one or more lines are too long
@ -1,19 +0,0 @@
|
||||
import { createIcon } from '@chakra-ui/react';
|
||||
|
||||
const PostprocessingIcon = createIcon({
|
||||
displayName: 'PostprocessingIcon',
|
||||
viewBox: '0 0 3543 3543',
|
||||
path: (
|
||||
<path
|
||||
fill="currentColor"
|
||||
fillRule="evenodd"
|
||||
clipRule="evenodd"
|
||||
d="M709.477,1596.53L992.591,1275.66L2239.09,2646.81L2891.95,1888.03L3427.88,2460.51L3427.88,994.78C3427.88,954.66 3421.05,916.122 3408.5,880.254L3521.9,855.419C3535.8,899.386 3543.31,946.214 3543.31,994.78L3543.31,2990.02C3543.31,3245.23 3336.11,3452.43 3080.9,3452.43C3080.9,3452.43 462.407,3452.43 462.407,3452.43C207.197,3452.43 -0,3245.23 -0,2990.02L-0,994.78C-0,739.571 207.197,532.373 462.407,532.373L505.419,532.373L504.644,532.546L807.104,600.085C820.223,601.729 832.422,607.722 841.77,617.116C850.131,625.517 855.784,636.21 858.055,647.804L462.407,647.804C270.906,647.804 115.431,803.279 115.431,994.78L115.431,2075.73L-0,2101.5L115.431,2127.28L115.431,2269.78L220.47,2150.73L482.345,2209.21C503.267,2211.83 522.722,2221.39 537.63,2236.37C552.538,2251.35 562.049,2270.9 564.657,2291.93L671.84,2776.17L779.022,2291.93C781.631,2270.9 791.141,2251.35 806.05,2236.37C820.958,2221.39 840.413,2211.83 861.334,2209.21L1353.15,2101.5L861.334,1993.8C840.413,1991.18 820.958,1981.62 806.05,1966.64C791.141,1951.66 781.631,1932.11 779.022,1911.08L709.477,1596.53ZM671.84,1573.09L725.556,2006.07C726.863,2016.61 731.63,2026.4 739.101,2033.91C746.573,2041.42 756.323,2046.21 766.808,2047.53L1197.68,2101.5L766.808,2155.48C756.323,2156.8 746.573,2161.59 739.101,2169.09C731.63,2176.6 726.863,2186.4 725.556,2196.94L671.84,2629.92L618.124,2196.94C616.817,2186.4 612.05,2176.6 604.579,2169.09C597.107,2161.59 587.357,2156.8 576.872,2155.48L146.001,2101.5L576.872,2047.53C587.357,2046.21 597.107,2041.42 604.579,2033.91C612.05,2026.4 616.817,2016.61 618.124,2006.07L671.84,1573.09ZM609.035,1710.36L564.657,1911.08C562.049,1932.11 552.538,1951.66 537.63,1966.64C522.722,1981.62 503.267,1991.18 482.345,1993.8L328.665,2028.11L609.035,1710.36ZM2297.12,938.615L2451.12,973.003C2480.59,976.695 2507.99,990.158 2528.99,1011.26C2549.99,1032.37 2563.39,1059.9 2567.07,1089.52L2672.73,1566.9C2634.5,1580.11 2593.44,1587.29 2550.72,1587.29C2344.33,1587.29 2176.77,1419.73 2176.77,1213.34C2176.77,1104.78 2223.13,1006.96 2297.12,938.615ZM2718.05,76.925L2793.72,686.847C2795.56,701.69 2802.27,715.491 2812.8,726.068C2823.32,736.644 2837.06,743.391 2851.83,745.242L3458.78,821.28L2851.83,897.318C2837.06,899.168 2823.32,905.916 2812.8,916.492C2802.27,927.068 2795.56,940.87 2793.72,955.712L2718.05,1565.63L2642.38,955.712C2640.54,940.87 2633.83,927.068 2623.3,916.492C2612.78,905.916 2599.04,899.168 2584.27,897.318L1977.32,821.28L2584.27,745.242C2599.04,743.391 2612.78,736.644 2623.3,726.068C2633.83,715.491 2640.54,701.69 2642.38,686.847L2718.05,76.925ZM2883.68,1043.06C2909.88,1094.13 2924.67,1152.02 2924.67,1213.34C2924.67,1335.4 2866.06,1443.88 2775.49,1512.14L2869.03,1089.52C2871.07,1073.15 2876.07,1057.42 2883.68,1043.06ZM925.928,201.2L959.611,472.704C960.431,479.311 963.42,485.455 968.105,490.163C972.79,494.871 978.904,497.875 985.479,498.698L1255.66,532.546L985.479,566.395C978.904,567.218 972.79,570.222 968.105,574.93C963.42,579.638 960.431,585.781 959.611,592.388L925.928,863.893L892.245,592.388C891.425,585.781 888.436,579.638 883.751,574.93C879.066,570.222 872.952,567.218 866.378,566.395L596.195,532.546L866.378,498.698C872.952,497.875 879.066,494.871 883.751,490.163C888.436,485.455 891.425,479.311 892.245,472.704L925.928,201.2ZM2864.47,532.373L3080.9,532.373C3258.7,532.373 3413.2,632.945 3490.58,780.281L3319.31,742.773C3257.14,683.925 3173.2,647.804 3080.9,647.804L2927.07,647.804C2919.95,642.994 2913.25,637.473 2907.11,631.298C2886.11,610.194 2872.71,582.655 2869.03,553.04L2864.47,532.373ZM1352.36,532.373L2571.64,532.373L2567.07,553.04C2563.39,582.655 2549.99,610.194 2528.99,631.298C2522.85,637.473 2516.16,642.994 2509.03,647.804L993.801,647.804C996.072,636.21 1001.73,625.517 1010.09,617.116C1019.43,607.722 1031.63,601.729 1044.75,600.085L1353.15,532.546L1352.36,532.373Z"
|
||||
/>
|
||||
),
|
||||
defaultProps: {
|
||||
boxSize: '24px',
|
||||
},
|
||||
});
|
||||
|
||||
export default PostprocessingIcon;
|
File diff suppressed because one or more lines are too long
@ -1,19 +0,0 @@
|
||||
import { createIcon } from '@chakra-ui/react';
|
||||
|
||||
const TrainingIcon = createIcon({
|
||||
displayName: 'TrainingIcon',
|
||||
viewBox: '0 0 3544 3544',
|
||||
path: (
|
||||
<path
|
||||
fill="currentColor"
|
||||
fillRule="evenodd"
|
||||
clipRule="evenodd"
|
||||
d="M0,768.593L0,2774.71C0,2930.6 78.519,3068.3 198.135,3150.37C273.059,3202.68 364.177,3233.38 462.407,3233.38C462.407,3233.38 3080.9,3233.38 3080.9,3233.38C3179.13,3233.38 3270.25,3202.68 3345.17,3150.37C3464.79,3068.3 3543.31,2930.6 3543.31,2774.71L3543.31,768.593C3543.31,517.323 3339.31,313.324 3088.04,313.324L455.269,313.324C203.999,313.324 0,517.323 0,768.593ZM3427.88,775.73L3427.88,2770.97C3427.88,2962.47 3272.4,3117.95 3080.9,3117.95L462.407,3117.95C270.906,3117.95 115.431,2962.47 115.431,2770.97C115.431,2770.97 115.431,775.73 115.431,775.73C115.431,584.229 270.906,428.755 462.407,428.755C462.407,428.755 3080.9,428.755 3080.9,428.755C3272.4,428.755 3427.88,584.229 3427.88,775.73ZM796.24,1322.76L796.24,1250.45C796.24,1199.03 836.16,1157.27 885.331,1157.27C885.331,1157.27 946.847,1157.27 946.847,1157.27C996.017,1157.27 1035.94,1199.03 1035.94,1250.45L1035.94,1644.81L2507.37,1644.81L2507.37,1250.45C2507.37,1199.03 2547.29,1157.27 2596.46,1157.27C2596.46,1157.27 2657.98,1157.27 2657.98,1157.27C2707.15,1157.27 2747.07,1199.03 2747.07,1250.45L2747.07,1322.76C2756.66,1319.22 2767.02,1317.29 2777.83,1317.29C2777.83,1317.29 2839.34,1317.29 2839.34,1317.29C2888.51,1317.29 2928.43,1357.21 2928.43,1406.38L2928.43,1527.32C2933.51,1526.26 2938.77,1525.71 2944.16,1525.71L2995.3,1525.71C3036.18,1525.71 3069.37,1557.59 3069.37,1596.86C3069.37,1596.86 3069.37,1946.44 3069.37,1946.44C3069.37,1985.72 3036.18,2017.6 2995.3,2017.6C2995.3,2017.6 2944.16,2017.6 2944.16,2017.6C2938.77,2017.6 2933.51,2017.04 2928.43,2015.99L2928.43,2136.92C2928.43,2186.09 2888.51,2226.01 2839.34,2226.01L2777.83,2226.01C2767.02,2226.01 2756.66,2224.08 2747.07,2220.55L2747.07,2292.85C2747.07,2344.28 2707.15,2386.03 2657.98,2386.03C2657.98,2386.03 2596.46,2386.03 2596.46,2386.03C2547.29,2386.03 2507.37,2344.28 2507.37,2292.85L2507.37,1898.5L1035.94,1898.5L1035.94,2292.85C1035.94,2344.28 996.017,2386.03 946.847,2386.03C946.847,2386.03 885.331,2386.03 885.331,2386.03C836.16,2386.03 796.24,2344.28 796.24,2292.85L796.24,2220.55C786.651,2224.08 776.29,2226.01 765.482,2226.01L703.967,2226.01C654.796,2226.01 614.876,2186.09 614.876,2136.92L614.876,2015.99C609.801,2017.04 604.539,2017.6 599.144,2017.6C599.144,2017.6 548.003,2017.6 548.003,2017.6C507.125,2017.6 473.937,1985.72 473.937,1946.44C473.937,1946.44 473.937,1596.86 473.937,1596.86C473.937,1557.59 507.125,1525.71 548.003,1525.71L599.144,1525.71C604.539,1525.71 609.801,1526.26 614.876,1527.32L614.876,1406.38C614.876,1357.21 654.796,1317.29 703.967,1317.29C703.967,1317.29 765.482,1317.29 765.482,1317.29C776.29,1317.29 786.651,1319.22 796.24,1322.76ZM977.604,1250.45C977.604,1232.7 963.822,1218.29 946.847,1218.29L885.331,1218.29C868.355,1218.29 854.573,1232.7 854.573,1250.45L854.573,2292.85C854.573,2310.61 868.355,2325.02 885.331,2325.02L946.847,2325.02C963.822,2325.02 977.604,2310.61 977.604,2292.85L977.604,1250.45ZM2565.7,1250.45C2565.7,1232.7 2579.49,1218.29 2596.46,1218.29L2657.98,1218.29C2674.95,1218.29 2688.73,1232.7 2688.73,1250.45L2688.73,2292.85C2688.73,2310.61 2674.95,2325.02 2657.98,2325.02L2596.46,2325.02C2579.49,2325.02 2565.7,2310.61 2565.7,2292.85L2565.7,1250.45ZM673.209,1406.38L673.209,2136.92C673.209,2153.9 686.991,2167.68 703.967,2167.68L765.482,2167.68C782.458,2167.68 796.24,2153.9 796.24,2136.92L796.24,1406.38C796.24,1389.41 782.458,1375.63 765.482,1375.63L703.967,1375.63C686.991,1375.63 673.209,1389.41 673.209,1406.38ZM2870.1,1406.38L2870.1,2136.92C2870.1,2153.9 2856.32,2167.68 2839.34,2167.68L2777.83,2167.68C2760.85,2167.68 2747.07,2153.9 2747.07,2136.92L2747.07,1406.38C2747.07,1389.41 2760.85,1375.63 2777.83,1375.63L2839.34,1375.63C2856.32,1375.63 2870.1,1389.41 2870.1,1406.38ZM614.876,1577.5C610.535,1574.24 605.074,1572.3 599.144,1572.3L548.003,1572.3C533.89,1572.3 522.433,1583.3 522.433,1596.86L522.433,1946.44C522.433,1960 533.89,1971.01 548.003,1971.01L599.144,1971.01C605.074,1971.01 610.535,1969.07 614.876,1965.81L614.876,1577.5ZM2928.43,1965.81L2928.43,1577.5C2932.77,1574.24 2938.23,1572.3 2944.16,1572.3L2995.3,1572.3C3009.42,1572.3 3020.87,1583.3 3020.87,1596.86L3020.87,1946.44C3020.87,1960 3009.42,1971.01 2995.3,1971.01L2944.16,1971.01C2938.23,1971.01 2932.77,1969.07 2928.43,1965.81ZM2507.37,1703.14L1035.94,1703.14L1035.94,1840.16L2507.37,1840.16L2507.37,1898.38L2507.37,1659.46L2507.37,1703.14Z"
|
||||
/>
|
||||
),
|
||||
defaultProps: {
|
||||
boxSize: '24px',
|
||||
},
|
||||
});
|
||||
|
||||
export default TrainingIcon;
|
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
@ -1,7 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
<svg width="100%" height="100%" viewBox="0 0 3543 3543" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:space="preserve" xmlns:serif="http://www.serif.com/" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linejoin:round;stroke-miterlimit:2;">
|
||||
<g transform="matrix(1.10943,0,0,1.10943,-206.981,-213.533)">
|
||||
<path d="M688.533,2405.95L542.987,2405.95C349.532,2405.95 192.47,2248.89 192.47,2055.44L192.47,542.987C192.47,349.532 349.532,192.47 542.987,192.47L2527.88,192.47C2721.33,192.47 2878.4,349.532 2878.4,542.987L2878.4,1172.79L3023.94,1172.79C3217.4,1172.79 3374.46,1329.85 3374.46,1523.3C3374.46,1523.3 3374.46,3035.75 3374.46,3035.75C3374.46,3229.21 3217.4,3386.27 3023.94,3386.27L1039.05,3386.27C845.595,3386.27 688.533,3229.21 688.533,3035.75L688.533,2405.95ZM3286.96,2634.37L3286.96,1523.3C3286.96,1378.14 3169.11,1260.29 3023.94,1260.29C3023.94,1260.29 1039.05,1260.29 1039.05,1260.29C893.887,1260.29 776.033,1378.14 776.033,1523.3L776.033,2489.79L1440.94,1736.22L2385.83,2775.59L2880.71,2200.41L3286.96,2634.37ZM2622.05,1405.51C2778.5,1405.51 2905.51,1532.53 2905.51,1688.98C2905.51,1845.42 2778.5,1972.44 2622.05,1972.44C2465.6,1972.44 2338.58,1845.42 2338.58,1688.98C2338.58,1532.53 2465.6,1405.51 2622.05,1405.51ZM2790.9,1172.79L1323.86,1172.79L944.882,755.906L279.97,1509.47L279.97,542.987C279.97,397.824 397.824,279.97 542.987,279.97C542.987,279.97 2527.88,279.97 2527.88,279.97C2673.04,279.97 2790.9,397.824 2790.9,542.987L2790.9,1172.79ZM2125.98,425.197C2282.43,425.197 2409.45,552.213 2409.45,708.661C2409.45,865.11 2282.43,992.126 2125.98,992.126C1969.54,992.126 1842.52,865.11 1842.52,708.661C1842.52,552.213 1969.54,425.197 2125.98,425.197Z"/>
|
||||
</g>
|
||||
</svg>
|
Before Width: | Height: | Size: 1.9 KiB |
Binary file not shown.
File diff suppressed because one or more lines are too long
Before Width: | Height: | Size: 8.9 KiB |
Binary file not shown.
@ -1,5 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
<svg width="100%" height="100%" viewBox="0 0 3543 3543" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:space="preserve" xmlns:serif="http://www.serif.com/" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linejoin:round;stroke-miterlimit:2;">
|
||||
<path d="M3543.31,770.787C3543.31,515.578 3336.11,308.38 3080.9,308.38L462.407,308.38C207.197,308.38 0,515.578 0,770.787L0,2766.03C0,3021.24 207.197,3228.44 462.407,3228.44L3080.9,3228.44C3336.11,3228.44 3543.31,3021.24 3543.31,2766.03C3543.31,2766.03 3543.31,770.787 3543.31,770.787ZM3427.88,770.787L3427.88,2766.03C3427.88,2957.53 3272.4,3113.01 3080.9,3113.01C3080.9,3113.01 462.407,3113.01 462.407,3113.01C270.906,3113.01 115.431,2957.53 115.431,2766.03L115.431,770.787C115.431,579.286 270.906,423.812 462.407,423.812L3080.9,423.812C3272.4,423.812 3427.88,579.286 3427.88,770.787ZM1214.23,1130.69L1321.47,1130.69C1324.01,1130.69 1326.54,1130.53 1329.05,1130.2C1329.05,1130.2 1367.3,1125.33 1397.94,1149.8C1421.63,1168.72 1437.33,1204.3 1437.33,1265.48L1437.33,2078.74L1220.99,2078.74C1146.83,2078.74 1086.61,2138.95 1086.61,2213.12L1086.61,2762.46C1086.61,2836.63 1146.83,2896.84 1220.99,2896.84L1770.34,2896.84C1844.5,2896.84 1904.71,2836.63 1904.71,2762.46L1904.71,2213.12C1904.71,2138.95 1844.5,2078.74 1770.34,2078.74L1554,2078.74L1554,1604.84C1625.84,1658.19 1703.39,1658.1 1703.39,1658.1C1703.54,1658.1 1703.69,1658.11 1703.84,1658.11L2362.2,1658.11L2362.2,1874.44C2362.2,1948.61 2422.42,2008.82 2496.58,2008.82L3045.93,2008.82C3120.09,2008.82 3180.3,1948.61 3180.3,1874.44L3180.3,1325.1C3180.3,1250.93 3120.09,1190.72 3045.93,1190.72L2496.58,1190.72C2422.42,1190.72 2362.2,1250.93 2362.2,1325.1L2362.2,1558.97L2362.2,1541.44L1704.23,1541.44C1702.2,1541.37 1650.96,1539.37 1609.51,1499.26C1577.72,1468.49 1554,1416.47 1554,1331.69L1554,1265.48C1554,1153.86 1513.98,1093.17 1470.76,1058.64C1411.24,1011.1 1338.98,1012.58 1319.15,1014.03L1214.23,1014.03L1214.23,796.992C1214.23,722.828 1154.02,662.617 1079.85,662.617L530.507,662.617C456.343,662.617 396.131,722.828 396.131,796.992L396.131,1346.34C396.131,1420.5 456.343,1480.71 530.507,1480.71L1079.85,1480.71C1154.02,1480.71 1214.23,1420.5 1214.23,1346.34L1214.23,1130.69Z"/>
|
||||
</svg>
|
Before Width: | Height: | Size: 2.3 KiB |
Binary file not shown.
File diff suppressed because one or more lines are too long
Before Width: | Height: | Size: 6.3 KiB |
Binary file not shown.
@ -1,5 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
<svg width="100%" height="100%" viewBox="0 0 3543 3543" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:space="preserve" xmlns:serif="http://www.serif.com/" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linecap:round;stroke-linejoin:round;stroke-miterlimit:1.5;">
|
||||
<path d="M709.477,1596.53L992.591,1275.66L2239.09,2646.81L2891.95,1888.03L3427.88,2460.51L3427.88,994.78C3427.88,954.66 3421.05,916.122 3408.5,880.254L3521.9,855.419C3535.8,899.386 3543.31,946.214 3543.31,994.78L3543.31,2990.02C3543.31,3245.23 3336.11,3452.43 3080.9,3452.43C3080.9,3452.43 462.407,3452.43 462.407,3452.43C207.197,3452.43 -0,3245.23 -0,2990.02L-0,994.78C-0,739.571 207.197,532.373 462.407,532.373L505.419,532.373L504.644,532.546L807.104,600.085C820.223,601.729 832.422,607.722 841.77,617.116C850.131,625.517 855.784,636.21 858.055,647.804L462.407,647.804C270.906,647.804 115.431,803.279 115.431,994.78L115.431,2075.73L-0,2101.5L115.431,2127.28L115.431,2269.78L220.47,2150.73L482.345,2209.21C503.267,2211.83 522.722,2221.39 537.63,2236.37C552.538,2251.35 562.049,2270.9 564.657,2291.93L671.84,2776.17L779.022,2291.93C781.631,2270.9 791.141,2251.35 806.05,2236.37C820.958,2221.39 840.413,2211.83 861.334,2209.21L1353.15,2101.5L861.334,1993.8C840.413,1991.18 820.958,1981.62 806.05,1966.64C791.141,1951.66 781.631,1932.11 779.022,1911.08L709.477,1596.53ZM671.84,1573.09L725.556,2006.07C726.863,2016.61 731.63,2026.4 739.101,2033.91C746.573,2041.42 756.323,2046.21 766.808,2047.53L1197.68,2101.5L766.808,2155.48C756.323,2156.8 746.573,2161.59 739.101,2169.09C731.63,2176.6 726.863,2186.4 725.556,2196.94L671.84,2629.92L618.124,2196.94C616.817,2186.4 612.05,2176.6 604.579,2169.09C597.107,2161.59 587.357,2156.8 576.872,2155.48L146.001,2101.5L576.872,2047.53C587.357,2046.21 597.107,2041.42 604.579,2033.91C612.05,2026.4 616.817,2016.61 618.124,2006.07L671.84,1573.09ZM609.035,1710.36L564.657,1911.08C562.049,1932.11 552.538,1951.66 537.63,1966.64C522.722,1981.62 503.267,1991.18 482.345,1993.8L328.665,2028.11L609.035,1710.36ZM2297.12,938.615L2451.12,973.003C2480.59,976.695 2507.99,990.158 2528.99,1011.26C2549.99,1032.37 2563.39,1059.9 2567.07,1089.52L2672.73,1566.9C2634.5,1580.11 2593.44,1587.29 2550.72,1587.29C2344.33,1587.29 2176.77,1419.73 2176.77,1213.34C2176.77,1104.78 2223.13,1006.96 2297.12,938.615ZM2718.05,76.925L2793.72,686.847C2795.56,701.69 2802.27,715.491 2812.8,726.068C2823.32,736.644 2837.06,743.391 2851.83,745.242L3458.78,821.28L2851.83,897.318C2837.06,899.168 2823.32,905.916 2812.8,916.492C2802.27,927.068 2795.56,940.87 2793.72,955.712L2718.05,1565.63L2642.38,955.712C2640.54,940.87 2633.83,927.068 2623.3,916.492C2612.78,905.916 2599.04,899.168 2584.27,897.318L1977.32,821.28L2584.27,745.242C2599.04,743.391 2612.78,736.644 2623.3,726.068C2633.83,715.491 2640.54,701.69 2642.38,686.847L2718.05,76.925ZM2883.68,1043.06C2909.88,1094.13 2924.67,1152.02 2924.67,1213.34C2924.67,1335.4 2866.06,1443.88 2775.49,1512.14L2869.03,1089.52C2871.07,1073.15 2876.07,1057.42 2883.68,1043.06ZM925.928,201.2L959.611,472.704C960.431,479.311 963.42,485.455 968.105,490.163C972.79,494.871 978.904,497.875 985.479,498.698L1255.66,532.546L985.479,566.395C978.904,567.218 972.79,570.222 968.105,574.93C963.42,579.638 960.431,585.781 959.611,592.388L925.928,863.893L892.245,592.388C891.425,585.781 888.436,579.638 883.751,574.93C879.066,570.222 872.952,567.218 866.378,566.395L596.195,532.546L866.378,498.698C872.952,497.875 879.066,494.871 883.751,490.163C888.436,485.455 891.425,479.311 892.245,472.704L925.928,201.2ZM2864.47,532.373L3080.9,532.373C3258.7,532.373 3413.2,632.945 3490.58,780.281L3319.31,742.773C3257.14,683.925 3173.2,647.804 3080.9,647.804L2927.07,647.804C2919.95,642.994 2913.25,637.473 2907.11,631.298C2886.11,610.194 2872.71,582.655 2869.03,553.04L2864.47,532.373ZM1352.36,532.373L2571.64,532.373L2567.07,553.04C2563.39,582.655 2549.99,610.194 2528.99,631.298C2522.85,637.473 2516.16,642.994 2509.03,647.804L993.801,647.804C996.072,636.21 1001.73,625.517 1010.09,617.116C1019.43,607.722 1031.63,601.729 1044.75,600.085L1353.15,532.546L1352.36,532.373Z" style="stroke:white;stroke-opacity:0;stroke-width:1px;"/>
|
||||
</svg>
|
Before Width: | Height: | Size: 4.2 KiB |
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user