Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00

Commit d96175d127: resolve some undefined symbols in model_cache

.github/workflows/test-invoke-pip.yml (vendored, 20 lines changed)
@@ -80,12 +80,7 @@ jobs:
         uses: actions/checkout@v3

       - name: set test prompt to main branch validation
-        if: ${{ github.ref == 'refs/heads/main' }}
-        run: echo "TEST_PROMPTS=tests/preflight_prompts.txt" >> ${{ matrix.github-env }}
-
-      - name: set test prompt to Pull Request validation
-        if: ${{ github.ref != 'refs/heads/main' }}
-        run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
+        run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}

       - name: setup python
         uses: actions/setup-python@v4
@@ -105,12 +100,6 @@ jobs:
         id: run-pytest
         run: pytest

-      - name: set INVOKEAI_OUTDIR
-        run: >
-          python -c
-          "import os;from invokeai.backend.globals import Globals;OUTDIR=os.path.join(Globals.root,str('outputs'));print(f'INVOKEAI_OUTDIR={OUTDIR}')"
-          >> ${{ matrix.github-env }}
-
       - name: run invokeai-configure
         id: run-preload-models
         env:
@@ -129,15 +118,20 @@ jobs:
           HF_HUB_OFFLINE: 1
           HF_DATASETS_OFFLINE: 1
           TRANSFORMERS_OFFLINE: 1
+          INVOKEAI_OUTDIR: ${{ github.workspace }}/results
         run: >
           invokeai
           --no-patchmatch
           --no-nsfw_checker
-          --from_file ${{ env.TEST_PROMPTS }}
+          --precision=float32
+          --always_use_cpu
           --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
+          --from_file ${{ env.TEST_PROMPTS }}

       - name: Archive results
         id: archive-results
+        env:
+          INVOKEAI_OUTDIR: ${{ github.workspace }}/results
         uses: actions/upload-artifact@v3
         with:
           name: results
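Note: the deleted "set INVOKEAI_OUTDIR" step computed the output directory at
job runtime from the legacy Globals module; the new workflow instead pins
INVOKEAI_OUTDIR to ${{ github.workspace }}/results in the env of both the
generation and archive steps. For readability, the removed one-liner expands
to:

    # Equivalent of the removed workflow step (depends on the legacy
    # invokeai.backend.globals module that this commit phases out):
    import os
    from invokeai.backend.globals import Globals

    OUTDIR = os.path.join(Globals.root, str('outputs'))
    print(f'INVOKEAI_OUTDIR={OUTDIR}')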
.gitignore (vendored, 2 lines changed)
@@ -201,6 +201,8 @@ checkpoints
 # If it's a Mac
 .DS_Store
 
+invokeai/frontend/web/dist/*
+
 # Let the frontend manage its own gitignore
 !invokeai/frontend/web/*
 
@@ -7,7 +7,6 @@ from typing import types
 
 from ..services.default_graphs import create_system_graphs
 from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
-from ...backend import Globals
 from ..services.restoration_services import RestorationServices
 from ..services.graph import GraphExecutionState, LibraryGraph
 from ..services.image_storage import DiskImageStorage
@@ -42,17 +41,8 @@ class ApiDependencies:
 
     invoker: Invoker = None
 
-    @staticmethod
     def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
-        Globals.try_patchmatch = config.patchmatch
-        Globals.always_use_cpu = config.always_use_cpu
-        Globals.internet_available = config.internet_available and check_internet()
-        Globals.disable_xformers = not config.xformers
-        Globals.ckpt_convert = config.ckpt_convert
-
-        # TO DO: Use the config to select the logger rather than use the default
-        # invokeai logging module
-        logger.info(f"Internet connectivity is {Globals.internet_available}")
+        logger.info(f"Internet connectivity is {config.internet_available}")
 
         events = FastAPIEventService(event_handler_id)
 
@@ -72,7 +62,6 @@ class ApiDependencies:
         services = InvocationServices(
             model_manager=ModelManagerService(config,logger),
             events=events,
-            logger=logger,
             latents=latents,
             images=images,
             metadata=metadata,
@@ -85,6 +74,8 @@ class ApiDependencies:
             ),
             processor=DefaultInvocationProcessor(),
             restoration=RestorationServices(config,logger),
+            configuration=config,
+            logger=logger,
         )
 
         create_system_graphs(services.graph_library)
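Note: the lines removed from ApiDependencies.initialize existed only to mirror
settings onto the module-level Globals object; the configuration object is now
handed to InvocationServices directly (the configuration=config argument
above), so invocations reach it as context.services.configuration. A minimal
sketch of the replacement pattern, assuming an installed invokeai package:

    from invokeai.app.services.config import InvokeAIAppConfig

    # Settings are read from the config object itself rather than copied
    # onto Globals at startup.
    config = InvokeAIAppConfig(argv=[])
    device = "cpu" if config.always_use_cpu else "gpu"  # formerly Globals.always_use_cpu
    print(device)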
@@ -13,11 +13,11 @@ from fastapi_events.handlers.local import local_handler
 from fastapi_events.middleware import EventHandlerASGIMiddleware
 from pydantic.schema import schema
 
-from ..backend import Args
 from .api.dependencies import ApiDependencies
 from .api.routers import images, sessions, models
 from .api.sockets import SocketIO
 from .invocations.baseinvocation import BaseInvocation
+from .services.config import InvokeAIAppConfig
 
 # Create the app
 # TODO: create this all in a method so configuration/etc. can be passed in?
@@ -33,30 +33,25 @@ app.add_middleware(
     middleware_id=event_handler_id,
 )
 
-# Add CORS
-# TODO: use configuration for this
-origins = []
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
 socket_io = SocketIO(app)
 
-config = {}
+# initialize config
+# this is a module global
+app_config = InvokeAIAppConfig()
 
 # Add startup event to load dependencies
 @app.on_event("startup")
 async def startup_event():
-    config = Args()
-    config.parse_args()
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=app_config.allow_origins,
+        allow_credentials=app_config.allow_credentials,
+        allow_methods=app_config.allow_methods,
+        allow_headers=app_config.allow_headers,
+    )
 
     ApiDependencies.initialize(
-        config=config, event_handler_id=event_handler_id, logger=logger
+        config=app_config, event_handler_id=event_handler_id, logger=logger
    )
 
 
@@ -148,14 +143,11 @@ app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), n
 
 def invoke_api():
     # Start our own event loop for eventing usage
-    # TODO: determine if there's a better way to do this
     loop = asyncio.new_event_loop()
-    config = uvicorn.Config(app=app, host="0.0.0.0", port=9090, loop=loop)
+    config = uvicorn.Config(app=app, host=app_config.host, port=app_config.port, loop=loop)
     # Use access_log to turn off logging
-
     server = uvicorn.Server(config)
     loop.run_until_complete(server.serve())
 
 
 if __name__ == "__main__":
     invoke_api()
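Note: the CORS middleware is no longer built from a hard-coded, empty origins
list; it now follows the Web Server settings of the module-global app_config,
and its registration moves into the startup event. A sketch of overriding the
values from the command line, using the config module added later in this
commit (the origin URL is only an example):

    from invokeai.app.services.config import InvokeAIAppConfig

    app_config = InvokeAIAppConfig(argv=['--allow_origins', 'http://localhost:5173'])
    print(app_config.allow_origins)   # ['http://localhost:5173']
    print(app_config.allow_methods)   # ['*'] (default)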
@@ -285,3 +285,19 @@ class DrawExecutionGraphCommand(BaseCommand):
         nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
         plt.axis("off")
         plt.show()
+
+class SortedHelpFormatter(argparse.HelpFormatter):
+    def _iter_indented_subactions(self, action):
+        try:
+            get_subactions = action._get_subactions
+        except AttributeError:
+            pass
+        else:
+            self._indent()
+            if isinstance(action, argparse._SubParsersAction):
+                for subaction in sorted(get_subactions(), key=lambda x: x.dest):
+                    yield subaction
+            else:
+                for subaction in get_subactions():
+                    yield subaction
+            self._dedent()
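Note: SortedHelpFormatter overrides argparse's private
_iter_indented_subactions hook so that subcommands are listed alphabetically
by dest in --help output. A small demonstration, assuming the class is
importable from invokeai.app.cli.commands:

    import argparse
    from invokeai.app.cli.commands import SortedHelpFormatter

    parser = argparse.ArgumentParser(prog="demo", formatter_class=SortedHelpFormatter)
    sub = parser.add_subparsers()
    for name in ("upscale", "banner", "reset"):   # registered out of order
        sub.add_parser(name, help=f"{name} command")
    parser.print_help()                           # lists: banner, reset, upscale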
@@ -11,9 +11,10 @@ from pathlib import Path
 from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
 
 import invokeai.backend.util.logging as logger
-from ...backend import ModelManager, Globals
+from ...backend import ModelManager
 from ..invocations.baseinvocation import BaseInvocation
 from .commands import BaseCommand
+from ..services.invocation_services import InvocationServices
 
 # singleton object, class variable
 completer = None
@@ -131,13 +132,13 @@ class Completer(object):
         readline.redisplay()
         self.linebuffer = None
 
-def set_autocompleter(model_manager: ModelManager) -> Completer:
+def set_autocompleter(services: InvocationServices) -> Completer:
     global completer
 
     if completer:
         return completer
 
-    completer = Completer(model_manager)
+    completer = Completer(services.model_manager)
 
     readline.set_completer(completer.complete)
     # pyreadline3 does not have a set_auto_history() method
@@ -153,7 +154,7 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
     readline.parse_and_bind("set skip-completed-text on")
     readline.parse_and_bind("set show-all-if-ambiguous on")
 
-    histfile = Path(Globals.root, ".invoke_history")
+    histfile = Path(services.configuration.root_dir / ".invoke_history")
     try:
         readline.read_history_file(histfile)
         readline.set_history_length(1000)
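Note: set_autocompleter now takes the whole InvocationServices object, pulling
the model manager from services.model_manager and resolving the readline
history file against the configured root directory instead of Globals.root. A
sketch of the new path resolution (the default root is ~/invokeai when nothing
else is configured):

    from pathlib import Path
    from invokeai.app.services.config import get_invokeai_config

    config = get_invokeai_config()
    histfile = Path(config.root_dir / ".invoke_history")
    print(histfile)   # e.g. /home/<user>/invokeai/.invoke_history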
@@ -4,13 +4,14 @@ import argparse
 import os
 import re
 import shlex
+import sys
 import time
 from typing import (
     Union,
     get_type_hints,
 )
 
-from pydantic import BaseModel
+from pydantic import BaseModel, ValidationError
 from pydantic.fields import Field
 
 import invokeai.backend.util.logging as logger
@@ -20,10 +21,7 @@ from invokeai.app.services.metadata import PngMetadataService
 from .services.default_graphs import create_system_graphs
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
 
-from ..backend import Args
-from ..backend import Globals # this should go when pr 3340 merged
-
-from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
+from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, SortedHelpFormatter
 from .cli.completer import set_autocompleter
 from .invocations.baseinvocation import BaseInvocation
 from .services.events import EventServiceBase
@@ -37,6 +35,7 @@ from .services.invoker import Invoker
 from .services.processor import DefaultInvocationProcessor
 from .services.sqlite import SqliteItemStorage
 from .services.model_manager_service import ModelManagerService
+from .services.config import get_invokeai_config
 
 class CliCommand(BaseModel):
     command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type")  # type: ignore
@@ -66,7 +65,7 @@ def add_invocation_args(command_parser):
 
 def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
     # Create invocation parser
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(formatter_class=SortedHelpFormatter)
 
     def exit(*args, **kwargs):
         raise InvalidArgs
@@ -191,28 +190,26 @@ def invoke_all(context: CliContext):
 
 
 def invoke_cli():
-    config = Args()
-    config.parse_args()
+    # this gets the basic configuration
+    config = get_invokeai_config()
 
-    logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
-    logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
+    # get the optional list of invocations to execute on the command line
+    parser = config.get_parser()
+    parser.add_argument('commands',nargs='*')
+    invocation_commands = parser.parse_args().commands
+
+    # get the optional file to read commands from.
+    # Simplest is to use it for STDIN
+    if infile := config.from_file:
+        sys.stdin = open(infile,"r")
 
     model_manager = ModelManagerService(config,logger)
 
-    # This initializes the autocompleter and returns it.
-    # Currently nothing is done with the returned Completer
-    # object, but the object can be used to change autocompletion
-    # behavior on the fly, if desired.
     set_autocompleter(model_manager)
 
     events = EventServiceBase()
+    output_folder = config.output_path
     metadata = PngMetadataService()
 
-    output_folder = os.path.abspath(
-        os.path.join(os.path.dirname(__file__), "../../../outputs")
-    )
-
     # TODO: build a file/path manager?
     db_location = os.path.join(output_folder, "invokeai.db")
 
@@ -232,6 +229,7 @@ def invoke_cli():
         processor=DefaultInvocationProcessor(),
         restoration=RestorationServices(config,logger=logger),
         logger=logger,
+        configuration=config,
     )
 
     system_graphs = create_system_graphs(services.graph_library)
@@ -247,9 +245,17 @@ def invoke_cli():
     # print(services.session_manager.list())
 
     context = CliContext(invoker, session, parser)
+    set_autocompleter(services)
 
-    while True:
+    command_line_args_exist = len(invocation_commands) > 0
+    done = False
+
+    while not done:
         try:
+            if command_line_args_exist:
+                cmd_input = invocation_commands.pop(0)
+                done = len(invocation_commands) == 0
+            else:
                 cmd_input = input("invoke> ")
         except (KeyboardInterrupt, EOFError):
             # Ctrl-c exits
@@ -374,6 +380,9 @@ def invoke_cli():
             invoker.services.logger.warning('Invalid command, use "help" to list commands')
             continue
 
+        except ValidationError:
+            invoker.services.logger.warning('Invalid command arguments, run "<command> --help" for summary')
+
         except SessionError:
             # Start a new session
             invoker.services.logger.warning("Session error: creating a new session")
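Note: the rewritten REPL accepts commands two ways: positional command-line
arguments are consumed once and the loop then exits, otherwise input() prompts
interactively, and --from_file rebinds stdin so a command file behaves like
typed input. A self-contained simulation of just the queueing logic:

    # Stand-in for parser.parse_args().commands in the real client:
    invocation_commands = ['t2i --prompt "a cat"', 'upscale']

    command_line_args_exist = len(invocation_commands) > 0
    done = False
    while not done:
        if command_line_args_exist:
            cmd_input = invocation_commands.pop(0)
            done = len(invocation_commands) == 0
        else:
            cmd_input = input("invoke> ")
        print("would execute:", cmd_input)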
@@ -5,10 +5,8 @@ from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationCont
 
 from .model import ClipField
 
-from ...backend.util.devices import choose_torch_device, torch_dtype
+from ...backend.util.devices import torch_dtype
 from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
-from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
-from ...backend.model_management import SDModelType
 
 from compel import Compel
 from compel.prompt_parser import (
@@ -18,8 +16,6 @@ from compel.prompt_parser import (
     Fragment,
 )
 
-from invokeai.backend.globals import Globals
-
 
 class ConditioningField(BaseModel):
     conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
@@ -91,7 +87,7 @@ class CompelInvocation(BaseInvocation):
 
         prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(self.prompt)
 
-        if getattr(Globals, "log_tokenization", False):
+        if context.services.configuration.log_tokenization:
             log_tokenization_for_prompt_object(prompt, tokenizer)
 
         c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
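Note: the tokenization-logging switch becomes an ordinary config field, so it
can be set in the Features stanza of invokeai.yaml, through the
INVOKEAI_log_tokenization environment variable, or on the command line. A
short sketch:

    from invokeai.app.services.config import InvokeAIAppConfig

    conf = InvokeAIAppConfig(argv=['--log_tokenization'])
    print(conf.log_tokenization)   # True; boolean fields also get a --no-... form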
@@ -5,7 +5,12 @@ from typing import Literal
 from pydantic import BaseModel, Field
 import numpy as np
 
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
+from .baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    InvocationContext,
+    InvocationConfig,
+)
 
 
 class MathInvocationConfig(BaseModel):
@@ -22,19 +27,21 @@ class MathInvocationConfig(BaseModel):
 
 class IntOutput(BaseInvocationOutput):
     """An integer output"""
-    #fmt: off
+
+    # fmt: off
     type: Literal["int_output"] = "int_output"
     a: int = Field(default=None, description="The output integer")
-    #fmt: on
+    # fmt: on
 
 
 class AddInvocation(BaseInvocation, MathInvocationConfig):
     """Adds two numbers"""
-    #fmt: off
+
+    # fmt: off
     type: Literal["add"] = "add"
     a: int = Field(default=0, description="The first number")
     b: int = Field(default=0, description="The second number")
-    #fmt: on
+    # fmt: on
 
     def invoke(self, context: InvocationContext) -> IntOutput:
         return IntOutput(a=self.a + self.b)
@@ -42,11 +49,12 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
 
 class SubtractInvocation(BaseInvocation, MathInvocationConfig):
     """Subtracts two numbers"""
-    #fmt: off
+
+    # fmt: off
     type: Literal["sub"] = "sub"
     a: int = Field(default=0, description="The first number")
     b: int = Field(default=0, description="The second number")
-    #fmt: on
+    # fmt: on
 
     def invoke(self, context: InvocationContext) -> IntOutput:
         return IntOutput(a=self.a - self.b)
@@ -54,11 +62,12 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
 
 class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
     """Multiplies two numbers"""
-    #fmt: off
+
+    # fmt: off
     type: Literal["mul"] = "mul"
     a: int = Field(default=0, description="The first number")
     b: int = Field(default=0, description="The second number")
-    #fmt: on
+    # fmt: on
 
     def invoke(self, context: InvocationContext) -> IntOutput:
         return IntOutput(a=self.a * self.b)
@@ -66,11 +75,12 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
 
 class DivideInvocation(BaseInvocation, MathInvocationConfig):
     """Divides two numbers"""
-    #fmt: off
+
+    # fmt: off
     type: Literal["div"] = "div"
     a: int = Field(default=0, description="The first number")
     b: int = Field(default=0, description="The second number")
-    #fmt: on
+    # fmt: on
 
     def invoke(self, context: InvocationContext) -> IntOutput:
         return IntOutput(a=int(self.a / self.b))
@@ -78,8 +88,13 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
 
 class RandomIntInvocation(BaseInvocation):
     """Outputs a single random integer."""
-    #fmt: off
+
+    # fmt: off
     type: Literal["rand_int"] = "rand_int"
-    #fmt: on
+    low: int = Field(default=0, description="The inclusive low value")
+    high: int = Field(
+        default=np.iinfo(np.int32).max, description="The exclusive high value"
+    )
+    # fmt: on
     def invoke(self, context: InvocationContext) -> IntOutput:
-        return IntOutput(a=np.random.randint(0, np.iinfo(np.int32).max))
+        return IntOutput(a=np.random.randint(self.low, self.high))
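Note: RandomIntInvocation now exposes its range as low/high fields with the
same convention as numpy itself: low is inclusive, high is exclusive, exactly
as the field descriptions state.

    import numpy as np

    # np.random.randint(low, high) draws from the half-open interval [low, high):
    value = np.random.randint(0, 10)
    assert 0 <= value < 10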
invokeai/app/services/config.py (new file, 528 lines)

@@ -0,0 +1,528 @@
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team

'''Invokeai configuration system.

Arguments and fields are taken from the pydantic definition of the
model. Defaults can be set by creating a yaml configuration file that
has a top-level key of "InvokeAI" and subheadings for each of the
categories returned by `invokeai --help`. The file looks like this:

[file: invokeai.yaml]

InvokeAI:
  Paths:
    root: /home/lstein/invokeai-main
    conf_path: configs/models.yaml
    legacy_conf_dir: configs/stable-diffusion
    outdir: outputs
    embedding_dir: embeddings
    lora_dir: loras
    autoconvert_dir: null
    gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth
  Models:
    model: stable-diffusion-1.5
    embeddings: true
  Memory/Performance:
    xformers_enabled: false
    sequential_guidance: false
    precision: float16
    max_loaded_models: 4
    always_use_cpu: false
    free_gpu_mem: false
  Features:
    nsfw_checker: true
    restore: true
    esrgan: true
    patchmatch: true
    internet_available: true
    log_tokenization: false
  Web Server:
    host: 127.0.0.1
    port: 8081
    allow_origins: []
    allow_credentials: true
    allow_methods:
    - '*'
    allow_headers:
    - '*'

The default name of the configuration file is `invokeai.yaml`, located
in INVOKEAI_ROOT. You can replace supersede this by providing any
OmegaConf dictionary object initialization time:

 omegaconf = OmegaConf.load('/tmp/init.yaml')
 conf = InvokeAIAppConfig(conf=omegaconf)

By default, InvokeAIAppConfig will parse the contents of `sys.argv` at
initialization time. You may pass a list of strings in the optional
`argv` argument to use instead of the system argv:

 conf = InvokeAIAppConfig(arg=['--xformers_enabled'])

It is also possible to set a value at initialization time. This value
has highest priority.

 conf = InvokeAIAppConfig(xformers_enabled=True)

Any setting can be overwritten by setting an environment variable of
form: "INVOKEAI_<setting>", as in:

 export INVOKEAI_port=8080

Order of precedence (from highest):
   1) initialization options
   2) command line options
   3) environment variable options
   4) config file options
   5) pydantic defaults

Typical usage:

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.invocations.generate import TextToImageInvocation

 # get global configuration and print its nsfw_checker value
 conf = InvokeAIAppConfig()
 print(conf.nsfw_checker)

 # get the text2image invocation and print its step value
 text2image = TextToImageInvocation()
 print(text2image.steps)

Computed properties:

The InvokeAIAppConfig object has a series of properties that
resolve paths relative to the runtime root directory. They each return
a Path object:

 root_path       - path to InvokeAI root
 output_path     - path to default outputs directory
 model_conf_path - path to models.yaml
 conf            - alias for the above
 embedding_path  - path to the embeddings directory
 lora_path       - path to the LoRA directory

In most cases, you will want to create a single InvokeAIAppConfig
object for the entire application. The get_invokeai_config() function
does this:

 config = get_invokeai_config()
 print(config.root)

# Subclassing

If you wish to create a similar class, please subclass the
`InvokeAISettings` class and define a Literal field named "type",
which is set to the desired top-level name. For example, to create a
"InvokeBatch" configuration, define like this:

 class InvokeBatch(InvokeAISettings):
    type: Literal["InvokeBatch"] = "InvokeBatch"
    node_count : int = Field(default=1, description="Number of nodes to run on", category='Resources')
    cpu_count  : int = Field(default=8, description="Number of GPUs to run on per node", category='Resources')

This will now read and write from the "InvokeBatch" section of the
config file, look for environment variables named INVOKEBATCH_*, and
accept the command-line arguments `--node_count` and `--cpu_count`. The
two configs are kept in separate sections of the config file:

 # invokeai.yaml

 InvokeBatch:
    Resources:
       node_count: 1
       cpu_count: 8

 InvokeAI:
    Paths:
      root: /home/lstein/invokeai-main
      conf_path: configs/models.yaml
      legacy_conf_dir: configs/stable-diffusion
      outdir: outputs
      ...
'''
import argparse
import pydoc
import typing
import os
import sys
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import Any, ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args

INIT_FILE = Path('invokeai.yaml')
LEGACY_INIT_FILE = Path('invokeai.init')

# This global stores a singleton InvokeAIAppConfig configuration object
global_config = None

class InvokeAISettings(BaseSettings):
    '''
    Runtime configuration settings in which default values are
    read from an omegaconf .yaml file.
    '''
    initconf        : ClassVar[DictConfig] = None
    argparse_groups : ClassVar[Dict] = {}

    def parse_args(self, argv: list=sys.argv[1:]):
        parser = self.get_parser()
        opt, _ = parser.parse_known_args(argv)
        for name in self.__fields__:
            if name not in self._excluded():
                setattr(self, name, getattr(opt,name))

    def to_yaml(self)->str:
        """
        Return a YAML string representing our settings. This can be used
        as the contents of `invokeai.yaml` to restore settings later.
        """
        cls = self.__class__
        type = get_args(get_type_hints(cls)['type'])[0]
        field_dict = dict({type:dict()})
        for name,field in self.__fields__.items():
            if name in cls._excluded():
                continue
            category = field.field_info.extra.get("category") or "Uncategorized"
            value = getattr(self,name)
            if category not in field_dict[type]:
                field_dict[type][category] = dict()
            # keep paths as strings to make it easier to read
            field_dict[type][category][name] = str(value) if isinstance(value,Path) else value
        conf = OmegaConf.create(field_dict)
        return OmegaConf.to_yaml(conf)

    @classmethod
    def add_parser_arguments(cls, parser):
        if 'type' in get_type_hints(cls):
            settings_stanza = get_args(get_type_hints(cls)['type'])[0]
        else:
            settings_stanza = "Uncategorized"

        env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper()

        initconf = cls.initconf.get(settings_stanza) \
            if cls.initconf and settings_stanza in cls.initconf \
            else OmegaConf.create()

        # create an upcase version of the environment in
        # order to achieve case-insensitive environment
        # variables (the way Windows does)
        upcase_environ = dict()
        for key,value in os.environ.items():
            upcase_environ[key.upper()] = value

        fields = cls.__fields__
        cls.argparse_groups = {}

        for name, field in fields.items():
            if name not in cls._excluded():
                current_default = field.default

                category = field.field_info.extra.get("category","Uncategorized")
                env_name = env_prefix + '_' + name
                if category in initconf and name in initconf.get(category):
                    field.default = initconf.get(category).get(name)
                if env_name.upper() in upcase_environ:
                    field.default = upcase_environ[env_name.upper()]
                cls.add_field_argument(parser, name, field)

                field.default = current_default

    @classmethod
    def cmd_name(self, command_field: str='type')->str:
        hints = get_type_hints(self)
        if command_field in hints:
            return get_args(hints[command_field])[0]
        else:
            return 'Uncategorized'

    @classmethod
    def get_parser(cls)->ArgumentParser:
        parser = PagingArgumentParser(
            prog=cls.cmd_name(),
            description=cls.__doc__,
        )
        cls.add_parser_arguments(parser)
        return parser

    @classmethod
    def add_subparser(cls, parser: argparse.ArgumentParser):
        parser.add_parser(cls.cmd_name(), help=cls.__doc__)

    @classmethod
    def _excluded(self)->List[str]:
        return ['type','initconf']

    class Config:
        env_file_encoding = 'utf-8'
        arbitrary_types_allowed = True
        case_sensitive = True

    @classmethod
    def add_field_argument(cls, command_parser, name: str, field, default_override = None):
        field_type = get_type_hints(cls).get(name)
        default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
        if category := field.field_info.extra.get("category"):
            if category not in cls.argparse_groups:
                cls.argparse_groups[category] = command_parser.add_argument_group(category)
            argparse_group = cls.argparse_groups[category]
        else:
            argparse_group = command_parser

        if get_origin(field_type) == Literal:
            allowed_values = get_args(field.type_)
            allowed_types = set()
            for val in allowed_values:
                allowed_types.add(type(val))
            allowed_types_list = list(allowed_types)
            field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list]  # type: ignore

            argparse_group.add_argument(
                f"--{name}",
                dest=name,
                type=field_type,
                default=default,
                choices=allowed_values,
                help=field.field_info.description,
            )

        elif get_origin(field_type) == list:
            argparse_group.add_argument(
                f"--{name}",
                dest=name,
                nargs='*',
                type=field.type_,
                default=default,
                action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
                help=field.field_info.description,
            )
        else:
            argparse_group.add_argument(
                f"--{name}",
                dest=name,
                type=field.type_,
                default=default,
                action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
                help=field.field_info.description,
            )

def _find_root()->Path:
    if os.environ.get("INVOKEAI_ROOT"):
        root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
    elif (
        os.environ.get("VIRTUAL_ENV")
        and (Path(os.environ.get("VIRTUAL_ENV"), "..", INIT_FILE).exists()
             or
             Path(os.environ.get("VIRTUAL_ENV"), "..", LEGACY_INIT_FILE).exists()
        )
    ):
        root = Path(os.environ.get("VIRTUAL_ENV"), "..").resolve()
    else:
        root = Path("~/invokeai").expanduser().resolve()
    return root

class InvokeAIAppConfig(InvokeAISettings):
    '''
    Generate images using Stable Diffusion. Use "invokeai" to launch
    the command-line client (recommended for experts only), or
    "invokeai-web" to launch the web server. Global options
    can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
    setting environment variables INVOKEAI_<setting>.
    '''
    #fmt: off
    type: Literal["InvokeAI"] = "InvokeAI"
    host                : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
    port                : int = Field(default=9090, description="Port to bind to", category='Web Server')
    allow_origins       : List[str] = Field(default=[], description="Allowed CORS origins", category='Web Server')
    allow_credentials   : bool = Field(default=True, description="Allow CORS credentials", category='Web Server')
    allow_methods       : List[str] = Field(default=["*"], description="Methods allowed for CORS", category='Web Server')
    allow_headers       : List[str] = Field(default=["*"], description="Headers allowed for CORS", category='Web Server')

    esrgan              : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
    internet_available  : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
    log_tokenization    : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
    nsfw_checker        : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
    patchmatch          : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
    restore             : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')

    always_use_cpu      : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
    free_gpu_mem        : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
    max_loaded_models   : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
    precision           : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
    sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
    xformers_enabled    : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')

    root                : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
    autoconvert_dir     : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
    conf_path           : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
    embedding_dir       : Path = Field(default='embeddings', description='Path to InvokeAI textual inversion aembeddings directory', category='Paths')
    gfpgan_model_dir    : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.', category='Paths')
    legacy_conf_dir     : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
    lora_dir            : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
    outdir              : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
    from_file           : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')

    model               : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
    embeddings          : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
    #fmt: on

    def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs):
        '''
        Initialize InvokeAIAppconfig.
        :param conf: alternate Omegaconf dictionary object
        :param argv: aternate sys.argv list
        :param **kwargs: attributes to initialize with
        '''
        super().__init__(**kwargs)

        # Set the runtime root directory. We parse command-line switches here
        # in order to pick up the --root_dir option.
        self.parse_args(argv)
        if conf is None:
            try:
                conf = OmegaConf.load(self.root_dir / INIT_FILE)
            except:
                pass
        InvokeAISettings.initconf = conf

        # parse args again in order to pick up settings in configuration file
        self.parse_args(argv)

        # restore initialization values
        hints = get_type_hints(self)
        for k in kwargs:
            setattr(self,k,parse_obj_as(hints[k],kwargs[k]))

    @property
    def root_path(self)->Path:
        '''
        Path to the runtime root directory
        '''
        if self.root:
            return Path(self.root).expanduser()
        else:
            return self.find_root()

    @property
    def root_dir(self)->Path:
        '''
        Alias for above.
        '''
        return self.root_path

    def _resolve(self,partial_path:Path)->Path:
        return (self.root_path / partial_path).resolve()

    @property
    def output_path(self)->Path:
        '''
        Path to defaults outputs directory.
        '''
        return self._resolve(self.outdir)

    @property
    def model_conf_path(self)->Path:
        '''
        Path to models configuration file.
        '''
        return self._resolve(self.conf_path)

    @property
    def legacy_conf_path(self)->Path:
        '''
        Path to directory of legacy configuration files (e.g. v1-inference.yaml)
        '''
        return self._resolve(self.legacy_conf_dir)

    @property
    def cache_dir(self)->Path:
        '''
        Path to the global cache directory for HuggingFace hub-managed models
        '''
        return self.models_dir / "hub"

    @property
    def models_dir(self)->Path:
        '''
        Path to the models directory
        '''
        return self._resolve("models")

    @property
    def converted_ckpts_dir(self)->Path:
        '''
        Path to the converted models
        '''
        return self._resolve("models/converted_ckpts")

    @property
    def embedding_path(self)->Path:
        '''
        Path to the textual inversion embeddings directory.
        '''
        return self._resolve(self.embedding_dir) if self.embedding_dir else None

    @property
    def lora_path(self)->Path:
        '''
        Path to the LoRA models directory.
        '''
        return self._resolve(self.lora_dir) if self.lora_dir else None

    @property
    def autoconvert_path(self)->Path:
        '''
        Path to the directory containing models to be imported automatically at startup.
        '''
        return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None

    @property
    def gfpgan_model_path(self)->Path:
        '''
        Path to the GFPGAN model.
        '''
        return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None

    # the following methods support legacy calls leftover from the Globals era
    @property
    def full_precision(self)->bool:
        """Return true if precision set to float32"""
        return self.precision=='float32'

    @property
    def disable_xformers(self)->bool:
        """Return true if xformers_enabled is false"""
        return not self.xformers_enabled

    @property
    def try_patchmatch(self)->bool:
        """Return true if patchmatch true"""
        return self.patchmatch

    @staticmethod
    def find_root()->Path:
        '''
        Choose the runtime root directory when not specified on command line or
        init file.
        '''
        return _find_root()


class PagingArgumentParser(argparse.ArgumentParser):
    '''
    A custom ArgumentParser that uses pydoc to page its output.
    It also supports reading defaults from an init file.
    '''
    def print_help(self, file=None):
        text = self.format_help()
        pydoc.pager(text)

def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAISettings:
    '''
    This returns a singleton InvokeAIAppConfig configuration object.
    '''
    global global_config
    if global_config is None or type(global_config)!=cls:
        global_config = cls(**kwargs)
    return global_config
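A short usage sketch distilled from the docstring above, showing the
precedence rules in action (initialization kwargs beat argv, which beats
environment variables, the YAML file, and the pydantic defaults):

    from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config

    conf = InvokeAIAppConfig(argv=['--port', '8080'], xformers_enabled=True)
    print(conf.port)               # 8080, from argv
    print(conf.xformers_enabled)   # True; the init kwarg has highest priority
    print(conf.output_path)        # <root>/outputs, resolved via root_path

    # Most application code shares one singleton instead:
    config = get_invokeai_config()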
@@ -135,6 +135,7 @@ class GraphInvocationOutput(BaseInvocationOutput):
 
 # TODO: Fill this out and move to invocations
 class GraphInvocation(BaseInvocation):
+    """Execute a graph"""
     type: Literal["graph"] = "graph"
 
     # TODO: figure out how to create a default here
@@ -162,6 +163,7 @@ class IterateInvocationOutput(BaseInvocationOutput):
 
 # TODO: Fill this out and move to invocations
 class IterateInvocation(BaseInvocation):
+    """Iterates over a list of items"""
     type: Literal["iterate"] = "iterate"
 
     collection: list[Any] = Field(
@@ -10,6 +10,7 @@ from .image_storage import ImageStorageBase
 from .restoration_services import RestorationServices
 from .invocation_queue import InvocationQueueABC
 from .item_storage import ItemStorageABC
+from .config import InvokeAISettings
 
 class InvocationServices:
     """Services that can be used by invocations"""
@@ -21,6 +22,7 @@ class InvocationServices:
     queue: InvocationQueueABC
     model_manager: ModelManager
     restoration: RestorationServices
+    configuration: InvokeAISettings
 
     # NOTE: we must forward-declare any types that include invocations, since invocations can use services
     graph_library: ItemStorageABC["LibraryGraph"]
@@ -40,6 +42,7 @@ class InvocationServices:
         graph_execution_manager: ItemStorageABC["GraphExecutionState"],
         processor: "InvocationProcessorABC",
         restoration: RestorationServices,
+        configuration: InvokeAISettings=None,
     ):
         self.model_manager = model_manager
         self.events = events
@@ -52,3 +55,4 @@ class InvocationServices:
         self.graph_execution_manager = graph_execution_manager
         self.processor = processor
         self.restoration = restoration
+        self.configuration = configuration
@@ -1,7 +1,6 @@
 """
 Initialization file for invokeai.backend
 """
-from .generate import Generate
 from .generator import (
     InvokeAIGeneratorBasicParams,
     InvokeAIGenerator,
@@ -12,5 +11,3 @@ from .generator import (
 )
 from .model_management import ModelManager, ModelCache, SDModelType, SDModelInfo
 from .safety_checker import SafetyChecker
-from .args import Args
-from .globals import Globals
(File diff suppressed because it is too large)
@@ -19,10 +19,10 @@ import warnings
 from argparse import Namespace
 from pathlib import Path
 from shutil import get_terminal_size
+from typing import get_type_hints
 from urllib import request

 import npyscreen
-import torch
 import transformers
 from diffusers import AutoencoderKL
 from huggingface_hub import HfFolder
@@ -38,34 +38,40 @@ from transformers import (

 import invokeai.configs as configs

-from ...frontend.install.model_install import addModelsForm, process_and_execute
-from ...frontend.install.widgets import (
+from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
+from invokeai.frontend.install.widgets import (
     CenteredButtonPress,
     IntTitleSlider,
     set_min_terminal_size,
 )
-from ..args import PRECISION_CHOICES, Args
-from ..globals import Globals, global_cache_dir, global_config_dir, global_config_file
-from .model_install_backend import (
+from invokeai.backend.config.legacy_arg_parsing import legacy_parser
+from invokeai.backend.config.model_install_backend import (
     default_dataset,
     download_from_hf,
     hf_download_with_resume,
     recommended_datasets,
 )
+from invokeai.app.services.config import (
+    get_invokeai_config,
+    InvokeAIAppConfig,
+)

 warnings.filterwarnings("ignore")

 transformers.logging.set_verbosity_error()


 # --------------------------globals-----------------------
+config = get_invokeai_config()

 Model_dir = "models"
 Weights_dir = "ldm/stable-diffusion-v1/"

 # the initial "configs" dir is now bundled in the `invokeai.configs` package
 Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"

-Default_config_file = Path(global_config_dir()) / "models.yaml"
-SD_Configs = Path(global_config_dir()) / "stable-diffusion"
+Default_config_file = config.model_conf_path
+SD_Configs = config.legacy_conf_path

 Datasets = OmegaConf.load(Dataset_path)

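The hunk above swaps module-level globals for a configuration singleton. As a minimal sketch (not part of the commit), calling code consumes it like this, assuming only the get_invokeai_config() accessor and the root, model_conf_path, and legacy_conf_path attributes that appear in this diff:

    from invokeai.app.services.config import get_invokeai_config

    config = get_invokeai_config()          # process-wide settings singleton
    models_yaml = config.model_conf_path    # replaces Path(global_config_dir()) / "models.yaml"
    legacy_confs = config.legacy_conf_path  # replaces Path(global_config_dir()) / "stable-diffusion"
    print(f"runtime root: {config.root}")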
@@ -73,17 +79,12 @@ Datasets = OmegaConf.load(Dataset_path)
 MIN_COLS = 135
 MIN_LINES = 45

+PRECISION_CHOICES = ['auto','float16','float32','autocast']

 INIT_FILE_PREAMBLE = """# InvokeAI initialization file
 # This is the InvokeAI initialization file, which contains command-line default values.
 # Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
 # or renaming it and then running invokeai-configure again.
-# Place frequently-used startup commands here, one or more per line.
-# Examples:
-# --outdir=D:\data\images
-# --no-nsfw_checker
-# --web --host=0.0.0.0
-# --steps=20
-# -Ak_euler_a -C10.0
 """


@@ -96,14 +97,13 @@ If you installed manually from source or with 'pip install': activate the virtual environment
 then run one of the following commands to start InvokeAI.

 Web UI:
-   invokeai --web # (connect to http://localhost:9090)
-   invokeai --web --host 0.0.0.0 # (connect to http://your-lan-ip:9090 from another computer on the local network)
+   invokeai-web

-Command-line interface:
+Command-line client:
    invokeai

 If you installed using an installation script, run:
-{Globals.root}/invoke.{"bat" if sys.platform == "win32" else "sh"}
+{config.root}/invoke.{"bat" if sys.platform == "win32" else "sh"}

 Add the '--help' argument to see all of the command-line switches available for use.
 """
@ -216,11 +216,11 @@ def download_realesrgan():
|
|||||||
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
|
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
|
||||||
|
|
||||||
model_dest = os.path.join(
|
model_dest = os.path.join(
|
||||||
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
|
config.root, "models/realesrgan/realesr-general-x4v3.pth"
|
||||||
)
|
)
|
||||||
|
|
||||||
wdn_model_dest = os.path.join(
|
wdn_model_dest = os.path.join(
|
||||||
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
config.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||||
)
|
)
|
||||||
|
|
||||||
download_with_progress_bar(model_url, model_dest, "RealESRGAN")
|
download_with_progress_bar(model_url, model_dest, "RealESRGAN")
|
||||||
@ -243,7 +243,7 @@ def download_gfpgan():
|
|||||||
"./models/gfpgan/weights/parsing_parsenet.pth",
|
"./models/gfpgan/weights/parsing_parsenet.pth",
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
model_url, model_dest = model[0], os.path.join(Globals.root, model[1])
|
model_url, model_dest = model[0], os.path.join(config.root, model[1])
|
||||||
download_with_progress_bar(model_url, model_dest, "GFPGAN weights")
|
download_with_progress_bar(model_url, model_dest, "GFPGAN weights")
|
||||||
|
|
||||||
|
|
||||||
@ -253,7 +253,7 @@ def download_codeformer():
|
|||||||
model_url = (
|
model_url = (
|
||||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
||||||
)
|
)
|
||||||
model_dest = os.path.join(Globals.root, "models/codeformer/codeformer.pth")
|
model_dest = os.path.join(config.root, "models/codeformer/codeformer.pth")
|
||||||
download_with_progress_bar(model_url, model_dest, "CodeFormer")
|
download_with_progress_bar(model_url, model_dest, "CodeFormer")
|
||||||
|
|
||||||
|
|
||||||
@ -295,7 +295,7 @@ def download_vaes():
|
|||||||
# first the diffusers version
|
# first the diffusers version
|
||||||
repo_id = "stabilityai/sd-vae-ft-mse"
|
repo_id = "stabilityai/sd-vae-ft-mse"
|
||||||
args = dict(
|
args = dict(
|
||||||
cache_dir=global_cache_dir("hub"),
|
cache_dir=config.cache_dir,
|
||||||
)
|
)
|
||||||
if not AutoencoderKL.from_pretrained(repo_id, **args):
|
if not AutoencoderKL.from_pretrained(repo_id, **args):
|
||||||
raise Exception(f"download of {repo_id} failed")
|
raise Exception(f"download of {repo_id} failed")
|
||||||
@ -306,7 +306,7 @@ def download_vaes():
|
|||||||
if not hf_download_with_resume(
|
if not hf_download_with_resume(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
model_dir=str(Globals.root / Model_dir / Weights_dir),
|
model_dir=str(config.root / Model_dir / Weights_dir),
|
||||||
):
|
):
|
||||||
raise Exception(f"download of {model_name} failed")
|
raise Exception(f"download of {model_name} failed")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
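For reference, a hedged sketch of driving hf_download_with_resume directly, using only the keyword arguments visible in the hunk above; the repo id and file name are hypothetical examples, not values from the commit:

    from invokeai.backend.config.model_install_backend import hf_download_with_resume

    ok = hf_download_with_resume(
        repo_id="stabilityai/sd-vae-ft-mse-original",    # hypothetical repo
        model_name="vae-ft-mse-840000-ema-pruned.ckpt",  # hypothetical file
        model_dir=str(config.root / Model_dir / Weights_dir),
    )
    if not ok:
        raise Exception("download of model failed")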
@ -321,8 +321,7 @@ def get_root(root: str = None) -> str:
|
|||||||
elif os.environ.get("INVOKEAI_ROOT"):
|
elif os.environ.get("INVOKEAI_ROOT"):
|
||||||
return os.environ.get("INVOKEAI_ROOT")
|
return os.environ.get("INVOKEAI_ROOT")
|
||||||
else:
|
else:
|
||||||
return Globals.root
|
return config.root
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
class editOptsForm(npyscreen.FormMultiPage):
|
class editOptsForm(npyscreen.FormMultiPage):
|
||||||
@@ -332,7 +331,7 @@ class editOptsForm(npyscreen.FormMultiPage):
     def create(self):
         program_opts = self.parentApp.program_opts
         old_opts = self.parentApp.invokeai_opts
-        first_time = not (Globals.root / Globals.initfile).exists()
+        first_time = not (config.root / 'invokeai.yaml').exists()
         access_token = HfFolder.get_token()
         window_width, window_height = get_terminal_size()
         for i in [
@@ -366,7 +365,7 @@ class editOptsForm(npyscreen.FormMultiPage):
         self.outdir = self.add_widget_intelligent(
             npyscreen.TitleFilename,
             name="(<tab> autocompletes, ctrl-N advances):",
-            value=old_opts.outdir or str(default_output_dir()),
+            value=str(old_opts.outdir) or str(default_output_dir()),
             select_dir=True,
             must_exist=False,
             use_two_lines=False,
@@ -381,17 +380,17 @@ class editOptsForm(npyscreen.FormMultiPage):
             editable=False,
             color="CONTROL",
         )
-        self.safety_checker = self.add_widget_intelligent(
+        self.nsfw_checker = self.add_widget_intelligent(
             npyscreen.Checkbox,
             name="NSFW checker",
-            value=old_opts.safety_checker,
+            value=old_opts.nsfw_checker,
             relx=5,
             scroll_exit=True,
         )
         self.nextrely += 1
         for i in [
-            "If you have an account at HuggingFace you may paste your access token here",
-            'to allow InvokeAI to download styles & subjects from the "Concept Library".',
+            "If you have an account at HuggingFace you may optionally paste your access token here",
+            'to allow InvokeAI to download restricted styles & subjects from the "Concept Library".',
             "See https://huggingface.co/settings/tokens",
         ]:
             self.add_widget_intelligent(
@ -435,17 +434,10 @@ class editOptsForm(npyscreen.FormMultiPage):
|
|||||||
relx=5,
|
relx=5,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.xformers = self.add_widget_intelligent(
|
self.xformers_enabled = self.add_widget_intelligent(
|
||||||
npyscreen.Checkbox,
|
npyscreen.Checkbox,
|
||||||
name="Enable xformers support if available",
|
name="Enable xformers support if available",
|
||||||
value=old_opts.xformers,
|
value=old_opts.xformers_enabled,
|
||||||
relx=5,
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
self.ckpt_convert = self.add_widget_intelligent(
|
|
||||||
npyscreen.Checkbox,
|
|
||||||
name="Load legacy checkpoint models into memory as diffusers models",
|
|
||||||
value=old_opts.ckpt_convert,
|
|
||||||
relx=5,
|
relx=5,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
@@ -480,19 +472,30 @@ class editOptsForm(npyscreen.FormMultiPage):
         self.nextrely += 1
         self.add_widget_intelligent(
             npyscreen.FixedText,
-            value="Directory containing embedding/textual inversion files:",
+            value="Directories containing textual inversion and LoRA models (<tab> autocompletes, ctrl-N advances):",
             editable=False,
             color="CONTROL",
         )
-        self.embedding_path = self.add_widget_intelligent(
+        self.embedding_dir = self.add_widget_intelligent(
             npyscreen.TitleFilename,
-            name="(<tab> autocompletes, ctrl-N advances):",
+            name=" Textual Inversion Embeddings:",
             value=str(default_embedding_dir()),
             select_dir=True,
             must_exist=False,
             use_two_lines=False,
             labelColor="GOOD",
-            begin_entry_at=40,
+            begin_entry_at=32,
+            scroll_exit=True,
+        )
+        self.lora_dir = self.add_widget_intelligent(
+            npyscreen.TitleFilename,
+            name=" LoRA and LyCORIS:",
+            value=str(default_lora_dir()),
+            select_dir=True,
+            must_exist=False,
+            use_two_lines=False,
+            labelColor="GOOD",
+            begin_entry_at=32,
             scroll_exit=True,
         )
         self.nextrely += 1
@@ -559,9 +562,9 @@ class editOptsForm(npyscreen.FormMultiPage):
             bad_fields.append(
                 f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
             )
-        if not Path(opt.embedding_path).parent.exists():
+        if not Path(opt.embedding_dir).parent.exists():
             bad_fields.append(
-                f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_path).parent)} is an existing directory."
+                f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_dir).parent)} is an existing directory."
             )
         if len(bad_fields) > 0:
             message = "The following problems were detected and must be corrected:\n"
@@ -577,13 +580,13 @@ class editOptsForm(npyscreen.FormMultiPage):

         for attr in [
             "outdir",
-            "safety_checker",
+            "nsfw_checker",
             "free_gpu_mem",
             "max_loaded_models",
-            "xformers",
+            "xformers_enabled",
             "always_use_cpu",
-            "embedding_path",
-            "ckpt_convert",
+            "embedding_dir",
+            "lora_dir",
         ]:
             setattr(new_opts, attr, getattr(self, attr).value)

@@ -591,6 +594,9 @@ class editOptsForm(npyscreen.FormMultiPage):
         new_opts.license_acceptance = self.license_acceptance.value
         new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]

+        # widget library workaround to make max_loaded_models an int rather than a float
+        new_opts.max_loaded_models = int(new_opts.max_loaded_models)
+
         return new_opts


@@ -628,15 +634,14 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace:


 def default_startup_options(init_file: Path) -> Namespace:
-    opts = Args().parse_args([])
+    opts = InvokeAIAppConfig(argv=[])
     outdir = Path(opts.outdir)
     if not outdir.is_absolute():
-        opts.outdir = str(Globals.root / opts.outdir)
+        opts.outdir = str(config.root / opts.outdir)
     if not init_file.exists():
-        opts.safety_checker = True
+        opts.nsfw_checker = True
     return opts


 def default_user_selections(program_opts: Namespace) -> Namespace:
     return Namespace(
         starter_models=default_dataset()
@@ -690,70 +695,61 @@ def run_console_ui(
 # -------------------------------------
 def write_opts(opts: Namespace, init_file: Path):
     """
-    Update the invokeai.init file with values from opts Namespace
+    Update the invokeai.yaml file with values from current settings.
     """
-    # touch file if it doesn't exist
-    if not init_file.exists():
-        with open(init_file, "w") as f:
-            f.write(INIT_FILE_PREAMBLE)

-    # We want to write in the changed arguments without clobbering
-    # any other initialization values the user has entered. There is
-    # no good way to do this because of the one-way nature of
-    # argparse: i.e. --outdir could be --outdir, --out, or -o
-    # initfile needs to be replaced with a fully structured format
-    # such as yaml; this is a hack that will work much of the time
-    args_to_skip = re.compile(
-        "^--?(o|out|no-xformer|xformer|no-ckpt|ckpt|free|no-nsfw|nsfw|prec|max_load|embed|always|ckpt|free_gpu)"
-    )
-    # fix windows paths
-    opts.outdir = opts.outdir.replace("\\", "/")
-    opts.embedding_path = opts.embedding_path.replace("\\", "/")
-    new_file = f"{init_file}.new"
-    try:
-        lines = [x.strip() for x in open(init_file, "r").readlines()]
-        with open(new_file, "w") as out_file:
-            for line in lines:
-                if len(line) > 0 and not args_to_skip.match(line):
-                    out_file.write(line + "\n")
-            out_file.write(
-                f"""
---outdir={opts.outdir}
---embedding_path={opts.embedding_path}
---precision={opts.precision}
---max_loaded_models={int(opts.max_loaded_models)}
---{'no-' if not opts.safety_checker else ''}nsfw_checker
---{'no-' if not opts.xformers else ''}xformers
---{'no-' if not opts.ckpt_convert else ''}ckpt_convert
-{'--free_gpu_mem' if opts.free_gpu_mem else ''}
-{'--always_use_cpu' if opts.always_use_cpu else ''}
-"""
-            )
-    except OSError as e:
-        print(f"** An error occurred while writing the init file: {str(e)}")

-    os.replace(new_file, init_file)

-    if opts.hf_token:
-        HfLogin(opts.hf_token)

+    # this will load current settings
+    config = InvokeAIAppConfig()
+    for key,value in opts.__dict__.items():
+        if hasattr(config,key):
+            setattr(config,key,value)
+
+    with open(init_file,'w', encoding='utf-8') as file:
+        file.write(config.to_yaml())

 # -------------------------------------
 def default_output_dir() -> Path:
-    return Globals.root / "outputs"
+    return config.root / "outputs"


 # -------------------------------------
 def default_embedding_dir() -> Path:
-    return Globals.root / "embeddings"
+    return config.root / "embeddings"


+# -------------------------------------
+def default_lora_dir() -> Path:
+    return config.root / "loras"
+
+
 # -------------------------------------
 def write_default_options(program_opts: Namespace, initfile: Path):
     opt = default_startup_options(initfile)
-    opt.hf_token = HfFolder.get_token()
     write_opts(opt, initfile)


+# -------------------------------------
+# Here we bring in
+# the legacy Args object in order to parse
+# the old init file and write out the new
+# yaml format.
+def migrate_init_file(legacy_format:Path):
+    old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
+    new = InvokeAIAppConfig(conf={})
+
+    fields = list(get_type_hints(InvokeAIAppConfig).keys())
+    for attr in fields:
+        if hasattr(old,attr):
+            setattr(new,attr,getattr(old,attr))
+
+    # a few places where the field names have changed and we have to
+    # manually add in the new names/values
+    new.nsfw_checker = old.safety_checker
+    new.xformers_enabled = old.xformers
+    new.conf_path = old.conf
+    new.embedding_dir = old.embedding_path
+
+    invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
+    with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
+        outfile.write(new.to_yaml())
+
+    legacy_format.replace(legacy_format.parent / 'invokeai.init.old')
+
+
 # -------------------------------------
 def main():
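A sketch of how the migrate_init_file() helper added above is meant to be driven; this mirrors the main() logic later in this diff rather than introducing new behavior:

    from pathlib import Path

    old_init = Path(config.root, "invokeai.init")
    new_init = Path(config.root, "invokeai.yaml")
    if old_init.exists() and not new_init.exists():
        # writes invokeai.yaml and renames the old file to invokeai.init.old
        migrate_init_file(old_init)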
@@ -810,7 +806,8 @@ def main():
     opt = parser.parse_args()

     # setting a global here
-    Globals.root = Path(os.path.expanduser(get_root(opt.root) or ""))
+    global config
+    config.root = Path(os.path.expanduser(get_root(opt.root) or ""))

     errors = set()

@@ -818,19 +815,26 @@ def main():
         models_to_download = default_user_selections(opt)

         # We check for to see if the runtime directory is correctly initialized.
-        init_file = Path(Globals.root, Globals.initfile)
-        if not init_file.exists() or not global_config_file().exists():
-            initialize_rootdir(Globals.root, opt.yes_to_all)
+        old_init_file = Path(config.root, 'invokeai.init')
+        new_init_file = Path(config.root, 'invokeai.yaml')
+        if old_init_file.exists() and not new_init_file.exists():
+            print('** Migrating invokeai.init to invokeai.yaml')
+            migrate_init_file(old_init_file)
+            config = get_invokeai_config() # reread defaults
+
+
+        if not config.model_conf_path.exists():
+            initialize_rootdir(config.root, opt.yes_to_all)

         if opt.yes_to_all:
-            write_default_options(opt, init_file)
+            write_default_options(opt, new_init_file)
             init_options = Namespace(
                 precision="float32" if opt.full_precision else "float16"
             )
         else:
-            init_options, models_to_download = run_console_ui(opt, init_file)
+            init_options, models_to_download = run_console_ui(opt, new_init_file)
             if init_options:
-                write_opts(init_options, init_file)
+                write_opts(init_options, new_init_file)
             else:
                 print(
                     '\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
390 invokeai/backend/config/legacy_arg_parsing.py Normal file
@@ -0,0 +1,390 @@
+# Copyright 2023 Lincoln D. Stein and the InvokeAI Team
+
+import argparse
+import shlex
+from argparse import ArgumentParser
+
+SAMPLER_CHOICES = [
+    "ddim",
+    "ddpm",
+    "deis",
+    "lms",
+    "pndm",
+    "heun",
+    "heun_k",
+    "euler",
+    "euler_k",
+    "euler_a",
+    "kdpm_2",
+    "kdpm_2_a",
+    "dpmpp_2s",
+    "dpmpp_2m",
+    "dpmpp_2m_k",
+    "unipc",
+]
+
+PRECISION_CHOICES = [
+    "auto",
+    "float32",
+    "autocast",
+    "float16",
+]
+
+class FileArgumentParser(ArgumentParser):
+    """
+    Supports reading defaults from an init file.
+    """
+    def convert_arg_line_to_args(self, arg_line):
+        return shlex.split(arg_line, comments=True)
+
+
+legacy_parser = FileArgumentParser(
+    description=
+    """
+Generate images using Stable Diffusion.
+Use --web to launch the web interface.
+Use --from_file to load prompts from a file path or standard input ("-").
+Otherwise you will be dropped into an interactive command prompt (type -h for help.)
+Other command-line arguments are defaults that can usually be overridden
+prompt the command prompt.
+""",
+    fromfile_prefix_chars='@',
+)
+general_group = legacy_parser.add_argument_group('General')
+model_group = legacy_parser.add_argument_group('Model selection')
+file_group = legacy_parser.add_argument_group('Input/output')
+web_server_group = legacy_parser.add_argument_group('Web server')
+render_group = legacy_parser.add_argument_group('Rendering')
+postprocessing_group = legacy_parser.add_argument_group('Postprocessing')
+deprecated_group = legacy_parser.add_argument_group('Deprecated options')
+
+deprecated_group.add_argument('--laion400m')
+deprecated_group.add_argument('--weights') # deprecated
+general_group.add_argument(
+    '--version','-V',
+    action='store_true',
+    help='Print InvokeAI version number'
+)
+model_group.add_argument(
+    '--root_dir',
+    default=None,
+    help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
+)
+model_group.add_argument(
+    '--config',
+    '-c',
+    '-config',
+    dest='conf',
+    default='./configs/models.yaml',
+    help='Path to configuration file for alternate models.',
+)
+model_group.add_argument(
+    '--model',
+    help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
+)
+model_group.add_argument(
+    '--weight_dirs',
+    nargs='+',
+    type=str,
+    help='List of one or more directories that will be auto-scanned for new model weights to import',
+)
+model_group.add_argument(
+    '--png_compression','-z',
+    type=int,
+    default=6,
+    choices=range(0,9),
+    dest='png_compression',
+    help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
+)
+model_group.add_argument(
+    '-F',
+    '--full_precision',
+    dest='full_precision',
+    action='store_true',
+    help='Deprecated way to set --precision=float32',
+)
+model_group.add_argument(
+    '--max_loaded_models',
+    dest='max_loaded_models',
+    type=int,
+    default=2,
+    help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
+)
+model_group.add_argument(
+    '--free_gpu_mem',
+    dest='free_gpu_mem',
+    action='store_true',
+    help='Force free gpu memory before final decoding',
+)
+model_group.add_argument(
+    '--sequential_guidance',
+    dest='sequential_guidance',
+    action='store_true',
+    help="Calculate guidance in serial instead of in parallel, lowering memory requirement "
+    "at the expense of speed",
+)
+model_group.add_argument(
+    '--xformers',
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help='Enable/disable xformers support (default enabled if installed)',
+)
+model_group.add_argument(
+    "--always_use_cpu",
+    dest="always_use_cpu",
+    action="store_true",
+    help="Force use of CPU even if GPU is available"
+)
+model_group.add_argument(
+    '--precision',
+    dest='precision',
+    type=str,
+    choices=PRECISION_CHOICES,
+    metavar='PRECISION',
+    help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
+    default='auto',
+)
+model_group.add_argument(
+    '--ckpt_convert',
+    action=argparse.BooleanOptionalAction,
+    dest='ckpt_convert',
+    default=True,
+    help='Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.'
+)
+model_group.add_argument(
+    '--internet',
+    action=argparse.BooleanOptionalAction,
+    dest='internet_available',
+    default=True,
+    help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).',
+)
+model_group.add_argument(
+    '--nsfw_checker',
+    '--safety_checker',
+    action=argparse.BooleanOptionalAction,
+    dest='safety_checker',
+    default=False,
+    help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
+)
+model_group.add_argument(
+    '--autoimport',
+    default=None,
+    type=str,
+    help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
+)
+model_group.add_argument(
+    '--autoconvert',
+    default=None,
+    type=str,
+    help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
+)
+model_group.add_argument(
+    '--patchmatch',
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help='Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.',
+)
+file_group.add_argument(
+    '--from_file',
+    dest='infile',
+    type=str,
+    help='If specified, load prompts from this file',
+)
+file_group.add_argument(
+    '--outdir',
+    '-o',
+    type=str,
+    help='Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs',
+    default='outputs',
+)
+file_group.add_argument(
+    '--prompt_as_dir',
+    '-p',
+    action='store_true',
+    help='Place images in subdirectories named after the prompt.',
+)
+render_group.add_argument(
+    '--fnformat',
+    default='{prefix}.{seed}.png',
+    type=str,
+    help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
+)
+render_group.add_argument(
+    '-s',
+    '--steps',
+    type=int,
+    default=50,
+    help='Number of steps'
+)
+render_group.add_argument(
+    '-W',
+    '--width',
+    type=int,
+    help='Image width, multiple of 64',
+)
+render_group.add_argument(
+    '-H',
+    '--height',
+    type=int,
+    help='Image height, multiple of 64',
+)
+render_group.add_argument(
+    '-C',
+    '--cfg_scale',
+    default=7.5,
+    type=float,
+    help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
+)
+render_group.add_argument(
+    '--sampler',
+    '-A',
+    '-m',
+    dest='sampler_name',
+    type=str,
+    choices=SAMPLER_CHOICES,
+    metavar='SAMPLER_NAME',
+    help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
+    default='k_lms',
+)
+render_group.add_argument(
+    '--log_tokenization',
+    '-t',
+    action='store_true',
+    help='shows how the prompt is split into tokens'
+)
+render_group.add_argument(
+    '-f',
+    '--strength',
+    type=float,
+    help='img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
+)
+render_group.add_argument(
+    '-T',
+    '-fit',
+    '--fit',
+    action=argparse.BooleanOptionalAction,
+    help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
+)
+
+render_group.add_argument(
+    '--grid',
+    '-g',
+    action=argparse.BooleanOptionalAction,
+    help='generate a grid'
+)
+render_group.add_argument(
+    '--embedding_directory',
+    '--embedding_path',
+    dest='embedding_path',
+    default='embeddings',
+    type=str,
+    help='Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)'
+)
+render_group.add_argument(
+    '--lora_directory',
+    dest='lora_path',
+    default='loras',
+    type=str,
+    help='Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)'
+)
+render_group.add_argument(
+    '--embeddings',
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help='Enable embedding directory (default). Use --no-embeddings to disable.',
+)
+render_group.add_argument(
+    '--enable_image_debugging',
+    action='store_true',
+    help='Generates debugging image to display'
+)
+render_group.add_argument(
+    '--karras_max',
+    type=int,
+    default=None,
+    help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]."
+)
+# Restoration related args
+postprocessing_group.add_argument(
+    '--no_restore',
+    dest='restore',
+    action='store_false',
+    help='Disable face restoration with GFPGAN or codeformer',
+)
+postprocessing_group.add_argument(
+    '--no_upscale',
+    dest='esrgan',
+    action='store_false',
+    help='Disable upscaling with ESRGAN',
+)
+postprocessing_group.add_argument(
+    '--esrgan_bg_tile',
+    type=int,
+    default=400,
+    help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
+)
+postprocessing_group.add_argument(
+    '--esrgan_denoise_str',
+    type=float,
+    default=0.75,
+    help='esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75',
+)
+postprocessing_group.add_argument(
+    '--gfpgan_model_path',
+    type=str,
+    default='./models/gfpgan/GFPGANv1.4.pth',
+    help='Indicates the path to the GFPGAN model',
+)
+web_server_group.add_argument(
+    '--web',
+    dest='web',
+    action='store_true',
+    help='Start in web server mode.',
+)
+web_server_group.add_argument(
+    '--web_develop',
+    dest='web_develop',
+    action='store_true',
+    help='Start in web server development mode.',
+)
+web_server_group.add_argument(
+    "--web_verbose",
+    action="store_true",
+    help="Enables verbose logging",
+)
+web_server_group.add_argument(
+    "--cors",
+    nargs="*",
+    type=str,
+    help="Additional allowed origins, comma-separated",
+)
+web_server_group.add_argument(
+    '--host',
+    type=str,
+    default='127.0.0.1',
+    help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
+)
+web_server_group.add_argument(
+    '--port',
+    type=int,
+    default='9090',
+    help='Web server: Port to listen on'
+)
+web_server_group.add_argument(
+    '--certfile',
+    type=str,
+    default=None,
+    help='Web server: Path to certificate file to use for SSL. Use together with --keyfile'
+)
+web_server_group.add_argument(
+    '--keyfile',
+    type=str,
+    default=None,
+    help='Web server: Path to private key file to use for SSL. Use together with --certfile'
+)
+web_server_group.add_argument(
+    '--gui',
+    dest='gui',
+    action='store_true',
+    help='Start InvokeAI GUI',
+)
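Because the parser above is constructed with fromfile_prefix_chars='@', a legacy invokeai.init file can be parsed by passing its path with an '@' prefix, exactly as migrate_init_file() does earlier in this diff; a short sketch with a hypothetical path:

    opts = legacy_parser.parse_args(["@/home/user/invokeai/invokeai.init"])  # hypothetical path
    print(opts.outdir, opts.safety_checker, opts.xformers)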
@@ -19,13 +19,15 @@ from tqdm import tqdm

 import invokeai.configs as configs

-from ..globals import Globals, global_cache_dir, global_config_dir
+from invokeai.app.services.config import get_invokeai_config
 from ..model_management import ModelManager
 from ..stable_diffusion import StableDiffusionGeneratorPipeline


 warnings.filterwarnings("ignore")

 # --------------------------globals-----------------------
+config = get_invokeai_config()
 Model_dir = "models"
 Weights_dir = "ldm/stable-diffusion-v1/"

@@ -47,12 +49,11 @@ Config_preamble = """


 def default_config_file():
-    return Path(global_config_dir()) / "models.yaml"
+    return config.model_conf_path


 def sd_configs():
-    return Path(global_config_dir()) / "stable-diffusion"
+    return config.legacy_conf_path


 def initial_models():
     global Datasets
@@ -121,8 +122,9 @@ def install_requested_models(

     if scan_at_startup and scan_directory.is_dir():
         argument = "--autoconvert"
-        initfile = Path(Globals.root, Globals.initfile)
-        replacement = Path(Globals.root, f"{Globals.initfile}.new")
+        print('** The global initfile is no longer supported; rewrite to support new yaml format **')
+        initfile = Path(config.root, 'invokeai.init')
+        replacement = Path(config.root, f"invokeai.init.new")
         directory = str(scan_directory).replace("\\", "/")
         with open(initfile, "r") as input:
             with open(replacement, "w") as output:
@@ -150,7 +152,7 @@ def get_root(root: str = None) -> str:
     elif os.environ.get("INVOKEAI_ROOT"):
         return os.environ.get("INVOKEAI_ROOT")
     else:
-        return Globals.root
+        return config.root


 # ---------------------------------------------
@@ -183,7 +185,7 @@ def all_datasets() -> dict:
 # look for legacy model.ckpt in models directory and offer to
 # normalize its name
 def migrate_models_ckpt():
-    model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
+    model_path = os.path.join(config.root, Model_dir, Weights_dir)
     if not os.path.exists(os.path.join(model_path, "model.ckpt")):
         return
     new_name = initial_models()["stable-diffusion-1.4"]["file"]
@@ -228,7 +230,7 @@ def _download_repo_or_file(
 def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
     repo_id = mconfig["repo_id"]
     filename = mconfig["file"]
-    cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
+    cache_dir = os.path.join(config.root, Model_dir, Weights_dir)
     return hf_download_with_resume(
         repo_id=repo_id,
         model_dir=cache_dir,
@@ -239,9 +241,9 @@ def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:

 # ---------------------------------------------
 def download_from_hf(
-    model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
+    model_class: object, model_name: str, **kwargs
 ):
-    path = global_cache_dir(cache_subdir)
+    path = config.cache_dir
     model = model_class.from_pretrained(
         model_name,
         cache_dir=path,
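With the cache_subdir parameter dropped above, every download lands under config.cache_dir. A hedged usage sketch; the model class and repo name are illustrative, not taken from the commit:

    from diffusers import AutoencoderKL

    vae = download_from_hf(AutoencoderKL, "stabilityai/sd-vae-ft-mse")  # illustrative repo id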
@@ -417,7 +419,7 @@ def new_config_file_contents(
             stanza["height"] = mod["height"]
         if "file" in mod:
             stanza["weights"] = os.path.relpath(
-                successfully_downloaded[model], start=Globals.root
+                successfully_downloaded[model], start=config.root
             )
             stanza["config"] = os.path.normpath(
                 os.path.join(sd_configs(), mod["config"])
@@ -456,7 +458,7 @@ def delete_weights(model_name: str, conf_stanza: dict):

     weights = Path(weights)
     if not weights.is_absolute():
-        weights = Path(Globals.root) / weights
+        weights = Path(config.root) / weights
     try:
         weights.unlink()
     except OSError as e:
(File diff suppressed because it is too large)
@@ -1,126 +0,0 @@
-"""
-invokeai.backend.globals defines a small number of global variables that would
-otherwise have to be passed through long and complex call chains.
-
-It defines a Namespace object named "Globals" that contains
-the attributes:
-
-  - root           - the root directory under which "models" and "outputs" can be found
-  - initfile       - path to the initialization file
-  - try_patchmatch - option to globally disable loading of 'patchmatch' module
-  - always_use_cpu - force use of CPU even if GPU is available
-"""
-
-import os
-import os.path as osp
-from argparse import Namespace
-from pathlib import Path
-from typing import Union
-
-Globals = Namespace()
-
-# Where to look for the initialization file and other key components
-Globals.initfile = "invokeai.init"
-Globals.models_file = "models.yaml"
-Globals.models_dir = "models"
-Globals.config_dir = "configs"
-Globals.autoscan_dir = "weights"
-Globals.converted_ckpts_dir = "converted_ckpts"
-
-# Set the default root directory. This can be overwritten by explicitly
-# passing the `--root <directory>` argument on the command line.
-# logic is:
-# 1) use INVOKEAI_ROOT environment variable (no check for this being a valid directory)
-# 2) use VIRTUAL_ENV environment variable, with a check for initfile being there
-# 3) use ~/invokeai
-
-if os.environ.get("INVOKEAI_ROOT"):
-    Globals.root = osp.abspath(os.environ.get("INVOKEAI_ROOT"))
-elif (
-    os.environ.get("VIRTUAL_ENV")
-    and Path(os.environ.get("VIRTUAL_ENV"), "..", Globals.initfile).exists()
-):
-    Globals.root = osp.abspath(osp.join(os.environ.get("VIRTUAL_ENV"), ".."))
-else:
-    Globals.root = osp.abspath(osp.expanduser("~/invokeai"))
-
-# Try loading patchmatch
-Globals.try_patchmatch = True
-
-# Use CPU even if GPU is available (main use case is for debugging MPS issues)
-Globals.always_use_cpu = False
-
-# Whether the internet is reachable for dynamic downloads
-# The CLI will test connectivity at startup time.
-Globals.internet_available = True
-
-# Whether to disable xformers
-Globals.disable_xformers = False
-
-# Low-memory tradeoff for guidance calculations.
-Globals.sequential_guidance = False
-
-# whether we are forcing full precision
-Globals.full_precision = False
-
-# whether we should convert ckpt files into diffusers models on the fly
-Globals.ckpt_convert = True
-
-# logging tokenization everywhere
-Globals.log_tokenization = False
-
-
-def global_config_file() -> Path:
-    return Path(Globals.root, Globals.config_dir, Globals.models_file)
-
-
-def global_config_dir() -> Path:
-    return Path(Globals.root, Globals.config_dir)
-
-
-def global_models_dir() -> Path:
-    return Path(Globals.root, Globals.models_dir)
-
-
-def global_autoscan_dir() -> Path:
-    return Path(Globals.root, Globals.autoscan_dir)
-
-
-def global_converted_ckpts_dir() -> Path:
-    return Path(global_models_dir(), Globals.converted_ckpts_dir)
-
-
-def global_set_root(root_dir: Union[str, Path]):
-    Globals.root = root_dir
-
-
-def global_resolve_path(path: Union[str,Path]):
-    if path is None:
-        return None
-    return Path(Globals.root,path).resolve()
-
-
-def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
-    """
-    Returns Path to the model cache directory. If a subdirectory
-    is provided, it will be appended to the end of the path, allowing
-    for Hugging Face-style conventions. Currently, Hugging Face has
-    moved all models into the "hub" subfolder, so for any pretrained
-    HF model, use:
-        global_cache_dir('hub')
-
-    The legacy location for transformers used to be global_cache_dir('transformers')
-    and global_cache_dir('diffusers') for diffusers.
-    """
-    home: str = os.getenv("HF_HOME")
-
-    if home is None:
-        home = os.getenv("XDG_CACHE_HOME")
-
-        if home is not None:
-            # Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in Hugging Face Hub Client Library.
-            # See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
-            home += os.sep + "huggingface"
-
-    if home is not None:
-        return Path(home, subdir)
-    else:
-        return Path(Globals.root, "models", subdir)
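As a reading aid for the deletion above, the rough correspondence between the removed helpers and the new config object, inferred from the substitutions made throughout this commit:

    from invokeai.app.services.config import get_invokeai_config

    config = get_invokeai_config()
    config.root              # was Globals.root
    config.cache_dir         # was global_cache_dir("hub")
    config.model_conf_path   # was global_config_file()
    config.legacy_conf_path  # was global_config_dir() / "stable-diffusion"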
@@ -6,7 +6,7 @@ be suppressed or deferred
 """
 import numpy as np
 import invokeai.backend.util.logging as logger
-from invokeai.backend.globals import Globals
+from invokeai.app.services.config import get_invokeai_config

 class PatchMatch:
     """
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_patch_match(self):
|
def _load_patch_match(self):
|
||||||
|
config = get_invokeai_config()
|
||||||
if self.tried_load:
|
if self.tried_load:
|
||||||
return
|
return
|
||||||
if Globals.try_patchmatch:
|
if config.try_patchmatch:
|
||||||
from patchmatch import patch_match as pm
|
from patchmatch import patch_match as pm
|
||||||
|
|
||||||
if pm.patchmatch_available:
|
if pm.patchmatch_available:
|
||||||
|
@@ -33,12 +33,11 @@ from PIL import Image, ImageOps
 from transformers import AutoProcessor, CLIPSegForImageSegmentation

 import invokeai.backend.util.logging as logger
-from invokeai.backend.globals import global_cache_dir
-
+from invokeai.app.services.config import get_invokeai_config

 CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
 CLIPSEG_SIZE = 352


 class SegmentedGrayscale(object):
     def __init__(self, image: Image, heatmap: torch.Tensor):
         self.heatmap = heatmap
@@ -84,14 +83,15 @@ class Txt2Mask(object):

     def __init__(self, device="cpu", refined=False):
         logger.info("Initializing clipseg model for text to mask inference")
+        config = get_invokeai_config()

         # BUG: we are not doing anything with the device option at this time
         self.device = device
         self.processor = AutoProcessor.from_pretrained(
-            CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
+            CLIPSEG_MODEL, cache_dir=config.cache_dir
         )
         self.model = CLIPSegForImageSegmentation.from_pretrained(
-            CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
+            CLIPSEG_MODEL, cache_dir=config.cache_dir
         )

     @torch.no_grad()
@@ -26,7 +26,7 @@ import torch
 from safetensors.torch import load_file

 import invokeai.backend.util.logging as logger
-from invokeai.backend.globals import global_cache_dir, global_config_dir
+from invokeai.app.services.config import get_invokeai_config

 from .model_manager import ModelManager, SDLegacyType
 from .model_cache import ModelCache
@@ -76,7 +76,6 @@ from transformers import (

 from ..stable_diffusion import StableDiffusionGeneratorPipeline

-

 def shave_segments(path, n_shave_prefix_segments=1):
     """
     Removes segments. Positive values shave the first segments, negative shave the last segments.
@@ -858,7 +857,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):

 def convert_ldm_clip_checkpoint(checkpoint):
     text_model = CLIPTextModel.from_pretrained(
-        "openai/clip-vit-large-patch14", cache_dir=global_cache_dir("hub")
+        "openai/clip-vit-large-patch14", cache_dir=get_invokeai_config().cache_dir
     )

     keys = list(checkpoint.keys())
@@ -913,7 +912,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))


 def convert_paint_by_example_checkpoint(checkpoint):
-    cache_dir = global_cache_dir("hub")
+    cache_dir = get_invokeai_config().cache_dir
     config = CLIPVisionConfig.from_pretrained(
         "openai/clip-vit-large-patch14", cache_dir=cache_dir
     )
@@ -985,7 +984,7 @@ def convert_paint_by_example_checkpoint(checkpoint):


 def convert_open_clip_checkpoint(checkpoint):
-    cache_dir = global_cache_dir("hub")
+    cache_dir = get_invokeai_config().cache_dir
     text_model = CLIPTextModel.from_pretrained(
         "stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
     )
@@ -1121,7 +1120,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     :param vae: A diffusers VAE to load into the pipeline.
     :param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
     """
+    config = get_invokeai_config()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         verbosity = dlogging.get_verbosity()
@@ -1134,7 +1133,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     else:
         checkpoint = load_file(checkpoint_path)

-    cache_dir = global_cache_dir("hub")
+    cache_dir = config.cache_dir
     pipeline_class = (
         StableDiffusionGeneratorPipeline
         if return_generator_pipeline
@@ -1158,25 +1157,23 @@ def load_pipeline_from_original_stable_diffusion_ckpt(

     if model_type == SDLegacyType.V2_v:
         original_config_file = (
-            global_config_dir() / "stable-diffusion" / "v2-inference-v.yaml"
+            config.legacy_conf_path / "v2-inference-v.yaml"
         )
         if global_step == 110000:
             # v2.1 needs to upcast attention
             upcast_attention = True
     elif model_type == SDLegacyType.V2_e:
         original_config_file = (
-            global_config_dir() / "stable-diffusion" / "v2-inference.yaml"
+            config.legacy_conf_path / "v2-inference.yaml"
         )
     elif model_type == SDLegacyType.V1_INPAINT:
         original_config_file = (
-            global_config_dir()
-            / "stable-diffusion"
-            / "v1-inpainting-inference.yaml"
+            config.legacy_conf_path / "v1-inpainting-inference.yaml"
         )

     elif model_type == SDLegacyType.V1:
         original_config_file = (
-            global_config_dir() / "stable-diffusion" / "v1-inference.yaml"
+            config.legacy_conf_path / "v1-inference.yaml"
         )

     else:
@ -1323,7 +1320,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
)
|
)
|
||||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||||
"CompVis/stable-diffusion-safety-checker",
|
"CompVis/stable-diffusion-safety-checker",
|
||||||
cache_dir=global_cache_dir("hub"),
|
cache_dir=config.cache_dir,
|
||||||
)
|
)
|
||||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||||
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir
|
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir
|
||||||
|
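Note: the recurring pattern in this commit is to swap the old module-level Globals / global_cache_dir() / global_config_dir() helpers for a single config object. A minimal sketch of the new calling convention, using only names visible in the hunks above (anything beyond them would be an assumption):

    from invokeai.app.services.config import get_invokeai_config

    config = get_invokeai_config()
    cache_dir = config.cache_dir                                 # replaces global_cache_dir("hub")
    v2_yaml = config.legacy_conf_path / "v2-inference-v.yaml"    # replaces global_config_dir() / "stable-diffusion" / ...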
@@ -25,27 +25,25 @@ import warnings
 from contextlib import suppress
 from enum import Enum
 from pathlib import Path
-from typing import Dict, Sequence, Union, Tuple, types, Optional, List, Type, Any
+from typing import Dict, Sequence, Union, types, Optional, List, Type, Any

 import torch
-import safetensors.torch

 from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
 from diffusers import logging as diffusers_logging
 from huggingface_hub import HfApi, scan_cache_dir
-from picklescan.scanner import scan_file_path
-from pydantic import BaseModel
 from transformers import logging as transformers_logging

 import invokeai.backend.util.logging as logger
-from ..globals import global_cache_dir
+from invokeai.app.services.config import get_invokeai_config


 def get_model_path(repo_id_or_path: str):
+    globals = get_invokeai_config()

     if os.path.exists(repo_id_or_path):
         return repo_id_or_path

-    cache = scan_cache_dir(global_cache_dir("hub"))
+    cache = scan_cache_dir(globals.cache_dir)
     for repo in cache.repos:
         if repo.repo_id != repo_id_or_path:
             continue
@@ -234,7 +232,7 @@ class DiffusersModelInfo(ModelInfoBase):
         model = self.child_types[child_type].from_pretrained(
             self.repo_id_or_path,
             subfolder=child_type.value,
-            cache_dir=global_cache_dir('hub'),
+            cache_dir=get_invokeai_config().cache_dir('hub'),
             torch_dtype=torch_dtype,
             variant=variant,
         )
@@ -248,7 +246,7 @@ class DiffusersModelInfo(ModelInfoBase):
         return model


-    def get_pipeline(self, **kwrags):
+    def get_pipeline(self, **kwargs):
         return DiffusionPipeline.from_pretrained(
             self.repo_id_or_path,
             **kwargs,
@@ -349,7 +347,7 @@ class ClassifierModelInfo(ModelInfoBase):
         model = self.child_types[child_type].from_pretrained(
             self.repo_id_or_path,
             subfolder=child_type.value,
-            cache_dir=global_cache_dir('hub'),
+            cache_dir=get_invokeai_config().cache_dir('hub'),
             torch_dtype=torch_dtype,
         )
         # calc more accurate size
@@ -394,7 +392,7 @@ class VaeModelInfo(ModelInfoBase):

         model = self.vae_type.from_pretrained(
             self.repo_id_or_path,
-            cache_dir=global_cache_dir('hub'),
+            cache_dir=get_invokeai_config().cache_dir('hub'),
             torch_dtype=torch_dtype,
         )
         # calc more accurate size
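Note: get_model_path() first treats its argument as a filesystem path and only then falls back to scanning the HuggingFace cache. A minimal usage sketch under that assumption (the repo id is hypothetical):

    path_or_id = get_model_path("runwayml/stable-diffusion-v1-5")  # hypothetical repo id
    # returns the argument unchanged if it exists on disk; otherwise looks for a
    # matching repo via huggingface_hub.scan_cache_dir(config.cache_dir)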
@@ -149,8 +149,7 @@ from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig

 import invokeai.backend.util.logging as logger
-from invokeai.backend.globals import (Globals, global_cache_dir,
-                                      global_resolve_path)
+from invokeai.app.services.config import get_invokeai_config
 from invokeai.backend.util import download_with_resume

 from ..util import CUDA_DEVICE
@@ -226,7 +225,8 @@ class ModelManager(object):

         # check config version number and update on disk/RAM if necessary
         self._update_config_file_version()
+        self.globals = get_invokeai_config()
+        self.logger = logger
         self.cache = ModelCache(
             max_cache_size=max_cache_size,
             execution_device = device_type,
@@ -235,7 +235,6 @@ class ModelManager(object):
             logger = logger,
         )
         self.cache_keys = dict()
-        self.logger = logger

     def model_exists(
         self,
@@ -304,12 +303,6 @@ class ModelManager(object):
         # raises an InvalidModelError

         """
-        # Commented-out workaround for callers that use "type/name" as the model name
-        # because they haven't adjusted to the new return format of `list_models()`
-        # if "/" in model_name:
-        #     model_key = model_name
-        # else:
-
         model_key = self.create_key(model_name, model_type)
         if model_key not in self.config:
             raise InvalidModelError(
@@ -326,13 +319,15 @@ class ModelManager(object):
         if mconfig.format in ["ckpt", "safetensors"]:
             location = self.convert_ckpt_and_cache(mconfig)
-        else:
-            location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
+        elif p := mconfig.get('path'):
+            location = self.globals.root_dir / p
+        elif r := mconfig.get('repo_id'):
+            location = r
+        elif w := mconfig.get('weights'):
+            location = self.globals.root_dir / w
         else:
-            location = global_resolve_path(
-                mconfig.get('path')) \
-                or mconfig.get('repo_id') \
-                or global_resolve_path(mconfig.get('weights')
-                )
+            location = None

         revision = mconfig.get('revision')
         hash = self.cache.model_hash(location, revision)
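Note: the rewritten lookup resolves a model's location in a fixed order. A standalone sketch of the same logic, assuming a plain dict in place of the OmegaConf node:

    from pathlib import Path

    def resolve_location(mconfig: dict, root_dir: Path):
        # mirrors the elif chain above: converted ckpt first, then path, repo_id, weights
        if p := mconfig.get("path"):
            return root_dir / p      # relative paths live under the InvokeAI root
        if r := mconfig.get("repo_id"):
            return r                 # HuggingFace repo ids are used as-is
        if w := mconfig.get("weights"):
            return root_dir / w
        return None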
@@ -423,7 +418,7 @@ class ModelManager(object):
         """
         # if we are converting legacy files automatically, then
         # there are no legacy ckpts!
-        if Globals.ckpt_convert:
+        if self.globals.ckpt_convert:
             return False
         info = self.model_info(model_name, model_type)
         if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
@@ -862,25 +857,16 @@ class ModelManager(object):
         model_type = self.probe_model_type(checkpoint)
         if model_type == SDLegacyType.V1:
             self.logger.debug("SD-v1 model detected")
-            model_config_file = Path(
-                Globals.root, "configs/stable-diffusion/v1-inference.yaml"
-            )
+            model_config_file = self.globals.legacy_conf_path / "v1-inference.yaml"
         elif model_type == SDLegacyType.V1_INPAINT:
             self.logger.debug("SD-v1 inpainting model detected")
-            model_config_file = Path(
-                Globals.root,
-                "configs/stable-diffusion/v1-inpainting-inference.yaml",
-            )
+            model_config_file = self.globals.legacy_conf_path / "v1-inpainting-inference.yaml"
         elif model_type == SDLegacyType.V2_v:
             self.logger.debug("SD-v2-v model detected")
-            model_config_file = Path(
-                Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
-            )
+            model_config_file = self.globals.legacy_conf_path / "v2-inference-v.yaml"
         elif model_type == SDLegacyType.V2_e:
             self.logger.debug("SD-v2-e model detected")
-            model_config_file = Path(
-                Globals.root, "configs/stable-diffusion/v2-inference.yaml"
-            )
+            model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml"
         elif model_type == SDLegacyType.V2:
             self.logger.warning(
                 f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
@@ -907,9 +893,7 @@ class ModelManager(object):
             self.logger.debug(f"Using VAE file {vae_path.name}")
         vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")

-        diffuser_path = Path(
-            Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
-        )
+        diffuser_path = self.globals.converted_ckpts_dir / model_path.stem
         with SilenceWarnings():
             model_name = self.convert_and_import(
                 model_path,
@@ -930,9 +914,9 @@ class ModelManager(object):
         diffusers, cache it to disk, and return Path to converted
         file. If already on disk then just returns Path.
         """
-        weights = global_resolve_path(mconfig.weights)
-        config_file = global_resolve_path(mconfig.config)
-        diffusers_path = global_resolve_path(Path('models',Globals.converted_ckpts_dir)) / weights.stem
+        weights = self.globals.root_dir / mconfig.weights
+        config_file = self.globals.root_dir / mconfig.config
+        diffusers_path = self.globals.converted_ckpts_dir / weights.stem

         # return cached version if it exists
         if diffusers_path.exists():
@@ -949,7 +933,7 @@ class ModelManager(object):
             extract_ema=True,
             original_config_file=config_file,
             vae=vae_model,
-            vae_path=str(global_resolve_path(vae_ckpt_path)) if vae_ckpt_path else None,
+            vae_path=str(self.globals.root_dir / vae_ckpt_path) if vae_ckpt_path else None,
             scan_needed=True,
         )
         return diffusers_path
@@ -960,9 +944,10 @@ class ModelManager(object):
         object, cache it to disk, and return Path to converted
         file. If already on disk then just returns Path.
         """
-        weights_file = global_resolve_path(mconfig.weights)
-        config_file = global_resolve_path(mconfig.config)
-        diffusers_path = global_resolve_path(Path('models',Globals.converted_ckpts_dir)) / weights_file.stem
+        root = self.globals.root_dir
+        weights_file = root / mconfig.weights
+        config_file = root / mconfig.config
+        diffusers_path = self.globals.converted_ckpts_dir / weights_file.stem
         image_size = mconfig.get('width') or mconfig.get('height') or 512

         # return cached version if it exists
@@ -1018,7 +1003,9 @@ class ModelManager(object):

         # 3. If mconfig has a vae dict, then we use it as the diffusers-style vae
         if vae_config and isinstance(vae_config,DictConfig):
-            vae_diffusers_location = global_resolve_path(vae_config.get('path')) or vae_config.get('repo_id')
+            vae_diffusers_location = self.globals.root_dir / vae_config.get('path') \
+                if vae_config.get('path') \
+                else vae_config.get('repo_id')

         # 4. Otherwise, we use stabilityai/sd-vae-ft-mse "because it works"
         else:
@@ -1072,7 +1059,9 @@ class ModelManager(object):
         # will be built into the model rather than tacked on afterward via the config file
         vae_model = None
         if vae:
-            vae_location = global_resolve_path(vae.get('path')) or vae.get('repo_id')
+            vae_location = self.globals.root_dir / vae.get('path') \
+                if vae.get('path') \
+                else vae.get('repo_id')
             vae_model = self.cache.get_model(vae_location, SDModelType.Vae).model
         vae_path = None
         convert_ckpt_to_diffusers(
@@ -1140,6 +1129,7 @@ class ModelManager(object):
         yaml_str = OmegaConf.to_yaml(self.config)
         config_file_path = conf_file or self.config_path
         assert config_file_path is not None,'no config file path to write to'
+        config_file_path = self.globals.root_dir / config_file_path
         tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
         with open(tmpfile, "w", encoding="utf-8") as outfile:
             outfile.write(self.preamble())
@@ -1160,7 +1150,7 @@ class ModelManager(object):

     @classmethod
     def _delete_model_from_cache(cls,repo_id):
-        cache_info = scan_cache_dir(global_cache_dir("hub"))
+        cache_info = scan_cache_dir(get_invokeai_config().cache_dir)

         # I'm sure there is a way to do this with comprehensions
         # but the code quickly became incomprehensible!
@@ -1177,9 +1167,10 @@ class ModelManager(object):

     @staticmethod
     def _abs_path(path: str | Path) -> Path:
+        globals = get_invokeai_config()
         if path is None or Path(path).is_absolute():
             return path
-        return Path(Globals.root, path).resolve()
+        return Path(globals.root_dir, path).resolve()

     # This is not the same as global_resolve_path(), which prepends
     # Globals.root.
@@ -1188,15 +1179,11 @@ class ModelManager(object):
     ) -> Optional[Path]:
         resolved_path = None
         if str(source).startswith(("http:", "https:", "ftp:")):
-            dest_directory = Path(dest_directory)
-            if not dest_directory.is_absolute():
-                dest_directory = Globals.root / dest_directory
+            dest_directory = self.globals.root_dir / dest_directory
             dest_directory.mkdir(parents=True, exist_ok=True)
             resolved_path = download_with_resume(str(source), dest_directory)
         else:
-            if not os.path.isabs(source):
-                source = os.path.join(Globals.root, source)
-            resolved_path = Path(source)
+            resolved_path = self.globals.root_dir / source
         return resolved_path

     def _update_config_file_version(self):
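Note: the simplified `self.globals.root_dir / source` relies on a pathlib property: joining onto an absolute right-hand operand discards the left side, so absolute sources still resolve correctly without the old isabs() checks. For example:

    from pathlib import Path

    root = Path("/opt/invokeai")
    assert root / "models/x.ckpt" == Path("/opt/invokeai/models/x.ckpt")
    assert root / "/tmp/abs.ckpt" == Path("/tmp/abs.ckpt")  # absolute operand wins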
@@ -17,67 +17,59 @@ from compel.prompt_parser import (
     FlattenedPrompt,
     Fragment,
     PromptParser,
+    Conjunction,
 )

 import invokeai.backend.util.logging as logger
-from invokeai.backend.globals import Globals

+from invokeai.app.services.config import get_invokeai_config
 from ..stable_diffusion import InvokeAIDiffuserComponent
 from ..util import torch_dtype

-def get_uc_and_c_and_ec(
-    prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
-):
+def get_uc_and_c_and_ec(prompt_string,
+                        model: InvokeAIDiffuserComponent,
+                        log_tokens=False, skip_normalize_legacy_blend=False):
     # lazy-load any deferred textual inversions.
     # this might take a couple of seconds the first time a textual inversion is used.
-    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
-        prompt_string
-    )
+    model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)

-    tokenizer = model.tokenizer
-    compel = Compel(
-        tokenizer=tokenizer,
+    compel = Compel(tokenizer=model.tokenizer,
         text_encoder=model.text_encoder,
         textual_inversion_manager=model.textual_inversion_manager,
         dtype_for_device_getter=torch_dtype,
-        truncate_long_prompts=False
+        truncate_long_prompts=False,
     )

+    config = get_invokeai_config()

     # get rid of any newline characters
     prompt_string = prompt_string.replace("\n", " ")
-    (
-        positive_prompt_string,
-        negative_prompt_string,
-    ) = split_prompt_to_positive_and_negative(prompt_string)
-    legacy_blend = try_parse_legacy_blend(
-        positive_prompt_string, skip_normalize_legacy_blend
-    )
-    positive_prompt: Union[FlattenedPrompt, Blend]
-    if legacy_blend is not None:
-        positive_prompt = legacy_blend
-    else:
-        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
-        negative_prompt_string
-    )
+    positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)

-    if log_tokens or getattr(Globals, "log_tokenization", False):
-        log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
+    legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
+    positive_conjunction: Conjunction
+    if legacy_blend is not None:
+        positive_conjunction = legacy_blend
+    else:
+        positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
+    positive_prompt = positive_conjunction.prompts[0]
+
+    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
+    negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
+
+    tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
+    if log_tokens or config.log_tokenization:
+        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)

     c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
     uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
     [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])

-    tokens_count = get_max_token_count(tokenizer, positive_prompt)
-
-    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
-        tokens_count_including_eos_bos=tokens_count,
-        cross_attention_control_args=options.get("cross_attention_control", None),
-    )
+    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
+                                                         cross_attention_control_args=options.get(
+                                                             'cross_attention_control', None))
     return uc, c, ec


 def get_prompt_structure(
     prompt_string, skip_normalize_legacy_blend: bool = False
 ) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
@@ -88,18 +80,17 @@ def get_prompt_structure(
     legacy_blend = try_parse_legacy_blend(
         positive_prompt_string, skip_normalize_legacy_blend
     )
-    positive_prompt: Union[FlattenedPrompt, Blend]
+    positive_prompt: Conjunction
     if legacy_blend is not None:
-        positive_prompt = legacy_blend
+        positive_conjunction = legacy_blend
     else:
-        positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
-        negative_prompt_string
-    )
+        positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
+    positive_prompt = positive_conjunction.prompts[0]
+    negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
+    negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]

     return positive_prompt, negative_prompt


 def get_max_token_count(
     tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
 ) -> int:
|
|||||||
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||||
logger.debug(f"{discarded}\x1b[0m")
|
logger.debug(f"{discarded}\x1b[0m")
|
||||||
|
|
||||||
|
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
|
||||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
|
|
||||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
|
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
|
||||||
if len(weighted_subprompts) <= 1:
|
if len(weighted_subprompts) <= 1:
|
||||||
return None
|
return None
|
||||||
strings = [x[0] for x in weighted_subprompts]
|
strings = [x[0] for x in weighted_subprompts]
|
||||||
weights = [x[1] for x in weighted_subprompts]
|
|
||||||
|
|
||||||
pp = PromptParser()
|
pp = PromptParser()
|
||||||
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
|
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
|
||||||
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
|
flattened_prompts = []
|
||||||
|
weights = []
|
||||||
return Blend(
|
for i, x in enumerate(parsed_conjunctions):
|
||||||
prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize
|
if len(x.prompts)>0:
|
||||||
)
|
flattened_prompts.append(x.prompts[0])
|
||||||
|
weights.append(weighted_subprompts[i][1])
|
||||||
|
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
|
||||||
|
|
||||||
def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
||||||
"""
|
"""
|
||||||
|
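Note: try_parse_legacy_blend() now returns a Conjunction wrapping a single Blend instead of a bare Blend, so callers unwrap it via .prompts[0] as get_uc_and_c_and_ec() does above. A minimal sketch (the prompt text is a hypothetical weighted-subprompt input; the exact legacy syntax is an assumption):

    conj = try_parse_legacy_blend("blue sphere:0.25 red cube:0.75", skip_normalize=False)
    if conj is not None:
        blend = conj.prompts[0]  # the Blend built from the weighted subprompts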
@@ -6,7 +6,7 @@ import numpy as np
 import torch

 import invokeai.backend.util.logging as logger
-from ..globals import Globals
+from invokeai.app.services.config import get_invokeai_config

 pretrained_model_url = (
     "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
@@ -17,11 +17,11 @@ class CodeFormerRestoration:
     def __init__(
         self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
     ) -> None:
-        if not os.path.isabs(codeformer_dir):
-            codeformer_dir = os.path.join(Globals.root, codeformer_dir)
-
-        self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
-        self.codeformer_model_exists = os.path.isfile(self.model_path)
+        self.globals = get_invokeai_config()
+        codeformer_dir = self.globals.root_dir / codeformer_dir
+        self.model_path = codeformer_dir / codeformer_model_path
+        self.codeformer_model_exists = self.model_path.exists()

         if not self.codeformer_model_exists:
             logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
@@ -71,9 +71,7 @@ class CodeFormerRestoration:
             upscale_factor=1,
             use_parse=True,
             device=device,
-            model_rootpath=os.path.join(
-                Globals.root, "models", "gfpgan", "weights"
-            ),
+            model_rootpath=self.globals.root_dir / "gfpgan" / "weights",
         )
         face_helper.clean_all()
         face_helper.read_image(bgr_image_array)
@@ -7,14 +7,13 @@ import torch
 from PIL import Image

 import invokeai.backend.util.logging as logger
-from invokeai.backend.globals import Globals
+from invokeai.app.services.config import get_invokeai_config

 class GFPGAN:
     def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
+        self.globals = get_invokeai_config()
         if not os.path.isabs(gfpgan_model_path):
-            gfpgan_model_path = os.path.abspath(
-                os.path.join(Globals.root, gfpgan_model_path)
-            )
+            gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
         self.model_path = gfpgan_model_path
         self.gfpgan_model_exists = os.path.isfile(self.model_path)

@@ -33,7 +32,7 @@ class GFPGAN:
         warnings.filterwarnings("ignore", category=DeprecationWarning)
         warnings.filterwarnings("ignore", category=UserWarning)
         cwd = os.getcwd()
-        os.chdir(os.path.join(Globals.root, "models"))
+        os.chdir(self.globals.root_dir / 'models')
         try:
             from gfpgan import GFPGANer

@@ -1,4 +1,3 @@
-import os
 import warnings

 import numpy as np
@@ -7,7 +6,8 @@ from PIL import Image
 from PIL.Image import Image as ImageType

 import invokeai.backend.util.logging as logger
-from invokeai.backend.globals import Globals
+from invokeai.app.services.config import get_invokeai_config
+config = get_invokeai_config()

 class ESRGAN:
     def __init__(self, bg_tile_size=400) -> None:
@@ -30,12 +30,8 @@ class ESRGAN:
             upscale=4,
             act_type="prelu",
         )
-        model_path = os.path.join(
-            Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
-        )
-        wdn_model_path = os.path.join(
-            Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
-        )
+        model_path = config.root_dir / "models/realesrgan/realesr-general-x4v3.pth"
+        wdn_model_path = config.root_dir / "models/realesrgan/realesr-general-wdn-x4v3.pth"
         scale = 4

         bg_upsampler = RealESRGANer(
@@ -15,7 +15,7 @@ from transformers import AutoFeatureExtractor

 import invokeai.assets.web as web_assets
 import invokeai.backend.util.logging as logger
-from .globals import global_cache_dir
+from invokeai.app.services.config import get_invokeai_config
 from .util import CPU_DEVICE

 class SafetyChecker(object):
@@ -26,10 +26,11 @@ class SafetyChecker(object):
         caution = Image.open(path)
         self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
         self.device = device
+        config = get_invokeai_config()

         try:
             safety_model_id = "CompVis/stable-diffusion-safety-checker"
-            safety_model_path = global_cache_dir("hub")
+            safety_model_path = config.cache_dir
             self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
                 safety_model_id,
                 local_files_only=True,
@@ -18,15 +18,15 @@ from huggingface_hub import (
 )

 import invokeai.backend.util.logging as logger
-from invokeai.backend.globals import Globals
+from invokeai.app.services.config import get_invokeai_config


 class HuggingFaceConceptsLibrary(object):
     def __init__(self, root=None):
         """
         Initialize the Concepts object. May optionally pass a root directory.
         """
-        self.root = root or Globals.root
+        self.config = get_invokeai_config()
+        self.root = root or self.config.root
         self.hf_api = HfApi()
         self.local_concepts = dict()
         self.concept_list = None
@@ -58,7 +58,7 @@ class HuggingFaceConceptsLibrary(object):
                 self.concept_list.extend(list(local_concepts_to_add))
                 return self.concept_list
             return self.concept_list
-        elif Globals.internet_available is True:
+        elif self.config.internet_available is True:
             try:
                 models = self.hf_api.list_models(
                     filter=ModelFilter(model_name="sd-concepts-library/")
@@ -33,8 +33,7 @@ from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from typing_extensions import ParamSpec

-from invokeai.backend.globals import Globals
-
+from invokeai.app.services.config import get_invokeai_config
 from ..util import CPU_DEVICE, normalize_device
 from .diffusion import (
     AttentionMapSaver,
@@ -44,7 +43,6 @@ from .diffusion import (
 from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
 from .textual_inversion_manager import TextualInversionManager

-
 @dataclass
 class PipelineIntermediateState:
     run_id: str
@@ -348,10 +346,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         """
         if xformers is available, use it, otherwise use sliced attention.
         """
+        config = get_invokeai_config()
         if (
             torch.cuda.is_available()
             and is_xformers_available()
-            and not Globals.disable_xformers
+            and not config.disable_xformers
         ):
             self.enable_xformers_memory_efficient_attention()
         else:
@@ -548,6 +547,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
         with self.invokeai_diffuser.custom_attention_context(
+            self.invokeai_diffuser.model,
             extra_conditioning_info=extra_conditioning_info,
             step_count=len(self.scheduler.timesteps),
         ):
@@ -10,6 +10,7 @@ import diffusers
 import psutil
 import torch
 from compel.cross_attention_control import Arguments
+from diffusers.models.unet_2d_condition import UNet2DConditionModel
 from diffusers.models.attention_processor import AttentionProcessor
 from torch import nn

@@ -352,8 +353,7 @@ def restore_default_cross_attention(
     else:
         remove_attention_function(model)

-def override_cross_attention(model, context: Context, is_running_diffusers=False):
+def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.

@@ -372,15 +372,13 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
     indices = torch.arange(max_length, dtype=torch.long)
     for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
         if b0 < max_length:
-            if name == "equal":  # or (name == "replace" and a1 - a0 == b1 - b0):
+            if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
                 # these tokens have not been edited
                 indices[b0:b1] = indices_target[a0:a1]
                 mask[b0:b1] = 1

     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
-    if is_running_diffusers:
-        unet = model
     old_attn_processors = unet.attn_processors
     if torch.backends.mps.is_available():
         # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
@@ -388,21 +386,8 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
     else:
         # try to re-use an existing slice size
         default_slice_size = 4
-        slice_size = next(
-            (
-                p.slice_size
-                for p in old_attn_processors.values()
-                if type(p) is SlicedAttnProcessor
-            ),
-            default_slice_size,
-        )
+        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
     unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
-        return old_attn_processors
-    else:
-        context.register_cross_attention_modules(model)
-        inject_attention_function(model, context)
-        return None


 def get_cross_attention_modules(
     model, which: CrossAttentionType
@@ -5,11 +5,12 @@ from typing import Any, Callable, Dict, Optional, Union

 import numpy as np
 import torch
+from diffusers import UNet2DConditionModel
 from diffusers.models.attention_processor import AttentionProcessor
 from typing_extensions import TypeAlias

 import invokeai.backend.util.logging as logger
-from invokeai.backend.globals import Globals
+from invokeai.app.services.config import get_invokeai_config

 from .cross_attention_control import (
     Arguments,
@@ -17,8 +18,8 @@ from .cross_attention_control import (
     CrossAttentionType,
     SwapCrossAttnContext,
     get_cross_attention_modules,
-    override_cross_attention,
     restore_default_cross_attention,
+    setup_cross_attention_control_attention_processors,
 )
 from .cross_attention_map_saving import AttentionMapSaver

@@ -31,7 +32,6 @@ ModelForwardCallback: TypeAlias = Union[
     Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
 ]

-
 @dataclass(frozen=True)
 class PostprocessingSettings:
     threshold: float
@@ -72,31 +72,43 @@ class InvokeAIDiffuserComponent:
         :param model: the unet model to pass through to cross attention control
         :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
         """
+        config = get_invokeai_config()
         self.conditioning = None
         self.model = model
         self.is_running_diffusers = is_running_diffusers
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None
-        self.sequential_guidance = Globals.sequential_guidance
+        self.sequential_guidance = config.sequential_guidance

+    @classmethod
     @contextmanager
     def custom_attention_context(
-        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+        cls,
+        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
+        extra_conditioning_info: Optional[ExtraConditioningInfo],
+        step_count: int
     ):
-        do_swap = (
-            extra_conditioning_info is not None
-            and extra_conditioning_info.wants_cross_attention_control
-        )
-        old_attn_processor = None
-        if do_swap:
-            old_attn_processor = self.override_cross_attention(
-                extra_conditioning_info, step_count=step_count
-            )
+        old_attn_processors = None
+        if extra_conditioning_info and (
+            extra_conditioning_info.wants_cross_attention_control
+        ):
+            old_attn_processors = unet.attn_processors
+            # Load lora conditions into the model
+            if extra_conditioning_info.wants_cross_attention_control:
+                cross_attention_control_context = Context(
+                    arguments=extra_conditioning_info.cross_attention_control_args,
+                    step_count=step_count,
+                )
+                setup_cross_attention_control_attention_processors(
+                    unet,
+                    cross_attention_control_context,
+                )

         try:
             yield None
         finally:
-            if old_attn_processor is not None:
-                self.restore_default_cross_attention(old_attn_processor)
+            if old_attn_processors is not None:
+                unet.set_attn_processor(old_attn_processors)
             # TODO resuscitate attention map saving
             # self.remove_attention_map_saving()
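Note: custom_attention_context() is now a classmethod that takes the unet explicitly and restores the saved attention processors itself on exit. A minimal sketch of the new call site, mirroring the pipeline change earlier in this diff:

    with InvokeAIDiffuserComponent.custom_attention_context(
        self.invokeai_diffuser.model,          # the unet
        extra_conditioning_info=extra_conditioning_info,
        step_count=len(self.scheduler.timesteps),
    ):
        ...  # denoising loop; unet.set_attn_processor(old_attn_processors) runs in the finally block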
@@ -9,7 +9,8 @@ SCHEDULER_MAP = dict(
     deis=(DEISMultistepScheduler, dict()),
     lms=(LMSDiscreteScheduler, dict()),
     pndm=(PNDMScheduler, dict()),
-    heun=(HeunDiscreteScheduler, dict()),
+    heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
+    heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
     euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
     euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
    euler_a=(EulerAncestralDiscreteScheduler, dict()),
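Note: each SCHEDULER_MAP entry pairs a diffusers scheduler class with the extra constructor kwargs that distinguish it. A minimal sketch of consuming the map, assuming the standard diffusers from_config() API and an existing pipeline object:

    scheduler_class, extra_kwargs = SCHEDULER_MAP["heun_k"]
    pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **extra_kwargs)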
@@ -7,7 +7,6 @@
 This is the backend to "textual_inversion.py"
 """

-import argparse
 import logging
 import math
 import os
@@ -47,8 +46,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer

 # invokeai stuff
-from ..args import ArgFormatter, PagingArgumentParser
-from ..globals import Globals, global_cache_dir
+from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser

 if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
     PIL_INTERPOLATION = {
@@ -90,8 +88,9 @@ def save_progress(


 def parse_args():
+    config = InvokeAIAppConfig(argv=[])
     parser = PagingArgumentParser(
-        description="Textual inversion training", formatter_class=ArgFormatter
+        description="Textual inversion training"
     )
     general_group = parser.add_argument_group("General")
     model_group = parser.add_argument_group("Models and Paths")
@@ -112,7 +111,7 @@ def parse_args():
         "--root_dir",
         "--root",
         type=Path,
-        default=Globals.root,
+        default=config.root,
         help="Path to the invokeai runtime directory",
     )
     general_group.add_argument(
@@ -127,7 +126,7 @@ def parse_args():
     general_group.add_argument(
         "--output_dir",
         type=Path,
-        default=f"{Globals.root}/text-inversion-model",
+        default=f"{config.root}/text-inversion-model",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     model_group.add_argument(
@@ -528,6 +527,7 @@ def get_full_repo_name(


 def do_textual_inversion_training(
+    config: InvokeAIAppConfig,
     model: str,
     train_data_dir: Path,
     output_dir: Path,
@@ -580,7 +580,7 @@ def do_textual_inversion_training(

     # setting up things the way invokeai expects them
     if not os.path.isabs(output_dir):
-        output_dir = os.path.join(Globals.root, output_dir)
+        output_dir = os.path.join(config.root, output_dir)

     logging_dir = output_dir / logging_dir

@@ -628,7 +628,7 @@ def do_textual_inversion_training(
     elif output_dir is not None:
         os.makedirs(output_dir, exist_ok=True)

-    models_conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml"))
+    models_conf = OmegaConf.load(config.model_conf_path)
     model_conf = models_conf.get(model, None)
     assert model_conf is not None, f"Unknown model: {model}"
     assert (
@@ -640,7 +640,7 @@ def do_textual_inversion_training(
     assert (
         pretrained_model_name_or_path
     ), f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
-    pipeline_args = dict(cache_dir=global_cache_dir("hub"))
+    pipeline_args = dict(cache_dir=config.cache_dir)

     # Load tokenizer
     if tokenizer_name:
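Note: do_textual_inversion_training() now receives the app config as its first argument instead of reading Globals. A minimal sketch of the updated call (all values beyond the new config parameter are hypothetical):

    from pathlib import Path

    config = InvokeAIAppConfig(argv=[])  # defaults only; no CLI parsing
    do_textual_inversion_training(
        config,
        model="stable-diffusion-1.5",              # hypothetical models.yaml key
        train_data_dir=Path("./training-images"),
        output_dir=Path("./text-inversion-model"),
    )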
@@ -4,17 +4,16 @@ from contextlib import nullcontext

 import torch
 from torch import autocast
-from invokeai.backend.globals import Globals
+from invokeai.app.services.config import get_invokeai_config

 CPU_DEVICE = torch.device("cpu")
 CUDA_DEVICE = torch.device("cuda")
 MPS_DEVICE = torch.device("mps")


 def choose_torch_device() -> torch.device:
     """Convenience routine for guessing which GPU device to run model on"""
-    if Globals.always_use_cpu:
+    config = get_invokeai_config()
+    if config.always_use_cpu:
         return CPU_DEVICE
     if torch.cuda.is_available():
         return torch.device("cuda")
@@ -33,7 +32,8 @@ def choose_precision(device: torch.device) -> str:


 def torch_dtype(device: torch.device) -> torch.dtype:
-    if Globals.full_precision:
+    config = get_invokeai_config()
+    if config.full_precision:
         return torch.float32
     if choose_precision(device) == "float16":
         return torch.float16
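Note: a minimal sketch of the new device/dtype flow, using only the names visible above (always_use_cpu and full_precision are the config flags these hunks consult):

    from invokeai.backend.util.devices import choose_torch_device, torch_dtype

    device = choose_torch_device()  # returns CPU_DEVICE when config.always_use_cpu is set
    dtype = torch_dtype(device)     # float32 when config.full_precision, else float16 where supported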
@ -2,34 +2,37 @@
|
|||||||
|
|
||||||
"""invokeai.util.logging
|
"""invokeai.util.logging
|
||||||
|
|
||||||
Logging class for InvokeAI that produces console messages that follow
|
Logging class for InvokeAI that produces console messages
|
||||||
the conventions established in InvokeAI 1.X through 2.X.
|
|
||||||
|
|
||||||
|
Usage:
|
||||||
 One way to use it:
 
 from invokeai.backend.util.logging import InvokeAILogger
 
-logger = InvokeAILogger.getLogger(__name__)
-
-logger.critical('this is critical')
-logger.error('this is an error')
-logger.warning('this is a warning')
-logger.info('this is info')
-logger.debug('this is debugging')
+logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
+(or)
+logger = InvokeAILogger.getLogger(__name__) // To use the filename
+
+logger.critical('this is critical') // Critical Message
+logger.error('this is an error') // Error Message
+logger.warning('this is a warning') // Warning Message
+logger.info('this is info') // Info Message
+logger.debug('this is debugging') // Debug Message
 
 Console messages:
-### this is critical
-*** this is an error ***
-** this is a warning
->> this is info
-   | this is debugging
+[12-05-2023 20]::[InvokeAI]::CRITICAL --> this is critical [In Bold Red]
+[12-05-2023 20]::[InvokeAI]::ERROR --> this is an error [In Red]
+[12-05-2023 20]::[InvokeAI]::WARNING --> this is a warning [In Yellow]
+[12-05-2023 20]::[InvokeAI]::INFO --> this is info [In Grey]
+[12-05-2023 20]::[InvokeAI]::DEBUG --> this is debugging [In Grey]
 
-Another way:
-import invokeai.backend.util.logging as ialog
-ialog.debug('this is a debugging message')
+Alternate Method (in this case the logger name will be set to InvokeAI):
+import invokeai.backend.util.logging as IAILogger
+IAILogger.debug('this is a debugging message')
 """
 
 import logging
 
 # module level functions
 def debug(msg, *args, **kwargs):
     InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
@@ -55,49 +58,47 @@ def disable(level=logging.CRITICAL):
 def basicConfig(**kwargs):
     InvokeAILogger.getLogger().basicConfig(**kwargs)
 
-def getLogger(name: str=None)->logging.Logger:
+def getLogger(name: str = None) -> logging.Logger:
     return InvokeAILogger.getLogger(name)
 
 
 class InvokeAILogFormatter(logging.Formatter):
     '''
-    Repurposed from:
-    https://stackoverflow.com/questions/14844970/modifying-logging-message-format-based-on-message-logging-level-in-python3
+    Custom Formatting for the InvokeAI Logger
     '''
-    crit_fmt = "### %(msg)s"
-    err_fmt = "*** %(msg)s"
-    warn_fmt = "** %(msg)s"
-    info_fmt = ">> %(msg)s"
-    dbg_fmt = "   | %(msg)s"
-
-    def __init__(self):
-        super().__init__(fmt="%(levelno)d: %(msg)s", datefmt=None, style='%')
+    # Color Codes
+    grey = "\x1b[38;20m"
+    yellow = "\x1b[33;20m"
+    red = "\x1b[31;20m"
+    cyan = "\x1b[36;20m"
+    bold_red = "\x1b[31;1m"
+    reset = "\x1b[0m"
+
+    # Log Format
+    format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
+    ## More Formatting Options: %(pathname)s, %(filename)s, %(module)s, %(lineno)d
+
+    # Format Map
+    FORMATS = {
+        logging.DEBUG: cyan + format + reset,
+        logging.INFO: grey + format + reset,
+        logging.WARNING: yellow + format + reset,
+        logging.ERROR: red + format + reset,
+        logging.CRITICAL: bold_red + format + reset
+    }
 
     def format(self, record):
-        # Remember the format used when the logging module
-        # was installed (in the event that this formatter is
-        # used with the vanilla logging module.
-        format_orig = self._style._fmt
-        if record.levelno == logging.DEBUG:
-            self._style._fmt = InvokeAILogFormatter.dbg_fmt
-        if record.levelno == logging.INFO:
-            self._style._fmt = InvokeAILogFormatter.info_fmt
-        if record.levelno == logging.WARNING:
-            self._style._fmt = InvokeAILogFormatter.warn_fmt
-        if record.levelno == logging.ERROR:
-            self._style._fmt = InvokeAILogFormatter.err_fmt
-        if record.levelno == logging.CRITICAL:
-            self._style._fmt = InvokeAILogFormatter.crit_fmt
-
-        # parent class does the work
-        result = super().format(record)
-        self._style._fmt = format_orig
-        return result
+        log_fmt = self.FORMATS.get(record.levelno)
+        formatter = logging.Formatter(log_fmt, datefmt="%d-%m-%Y %H:%M:%S")
+        return formatter.format(record)
 
 
 class InvokeAILogger(object):
     loggers = dict()
 
     @classmethod
-    def getLogger(self, name:str='invokeai')->logging.Logger:
+    def getLogger(self, name: str = 'InvokeAI') -> logging.Logger:
         if name not in self.loggers:
             logger = logging.getLogger(name)
             logger.setLevel(logging.DEBUG)
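The new formatter swaps the old per-level prefix strings for ANSI-colored, timestamped output, selecting a whole format string by `record.levelno`. For reference, a minimal self-contained sketch of the same technique (the logger name and messages are illustrative):

import logging

class ColorByLevelFormatter(logging.Formatter):
    """Pick a differently colored format string per log level."""
    RESET = "\x1b[0m"
    COLORS = {
        logging.DEBUG: "\x1b[36;20m",    # cyan
        logging.INFO: "\x1b[38;20m",     # grey
        logging.WARNING: "\x1b[33;20m",  # yellow
        logging.ERROR: "\x1b[31;20m",    # red
        logging.CRITICAL: "\x1b[31;1m",  # bold red
    }
    FMT = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"

    def format(self, record: logging.LogRecord) -> str:
        color = self.COLORS.get(record.levelno, "")
        # Delegate to a vanilla Formatter built with the colored format string.
        return logging.Formatter(
            color + self.FMT + self.RESET, datefmt="%d-%m-%Y %H:%M:%S"
        ).format(record)

handler = logging.StreamHandler()
handler.setFormatter(ColorByLevelFormatter())
logger = logging.getLogger("InvokeAI")
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
logger.warning("this is a warning")  # prints in yellow

Building a fresh inner Formatter per call, as the commit does, avoids the mutate-and-restore dance on `self._style._fmt` that the removed implementation needed.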
@@ -9,6 +9,7 @@ SAMPLER_CHOICES = [
     "lms",
     "pndm",
     "heun",
+    'heun_k',
     "euler",
     "euler_k",
     "euler_a",
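SAMPLER_CHOICES is presumably consumed as a whitelist by the CLI's argument parser; a hypothetical sketch of how such a list is typically wired up (the `-A`/`--sampler` flag names are taken from the CLI command list later in this commit, the rest is an assumption):

import argparse

SAMPLER_CHOICES = ["lms", "pndm", "heun", "heun_k", "euler", "euler_k", "euler_a"]

parser = argparse.ArgumentParser()
# `choices` rejects any scheduler name not in the whitelist at parse time.
parser.add_argument("-A", "--sampler", choices=SAMPLER_CHOICES, default="euler_a")
args = parser.parse_args(["--sampler", "heun_k"])
print(args.sampler)  # -> heun_k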
invokeai/frontend/CLI/readline.py (entire file deleted)
@@ -1,498 +0,0 @@
"""
Readline helper functions for invoke.py.
You may import the global singleton `completer` to get access to the
completer object itself. This is useful when you want to autocomplete
seeds:

    from invokeai.frontend.CLI.readline import completer
    completer.add_seed(18247566)
    completer.add_seed(9281839)
"""
import atexit
import os
import re

from ...backend.args import Args
from ...backend.globals import Globals
from ...backend.stable_diffusion import HuggingFaceConceptsLibrary

# ---------------readline utilities---------------------
try:
    import readline

    readline_available = True
except (ImportError, ModuleNotFoundError) as e:
    print(f"** An error occurred when loading the readline module: {str(e)}")
    readline_available = False

IMG_EXTENSIONS = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG", ".gif", ".GIF")
WEIGHT_EXTENSIONS = (".ckpt", ".vae", ".safetensors")
TEXT_EXTENSIONS = (".txt", ".TXT")
CONFIG_EXTENSIONS = (".yaml", ".yml")
COMMANDS = (
    "--steps", "-s",
    "--seed", "-S",
    "--iterations", "-n",
    "--width", "-W",
    "--height", "-H",
    "--cfg_scale", "-C",
    "--threshold",
    "--perlin",
    "--grid", "-g",
    "--individual", "-i",
    "--save_intermediates",
    "--init_img", "-I",
    "--init_mask", "-M",
    "--init_color",
    "--strength", "-f",
    "--variants", "-v",
    "--outdir", "-o",
    "--sampler", "-A", "-m",
    "--embedding_path",
    "--device",
    "--grid", "-g",
    "--facetool", "-ft",
    "--facetool_strength", "-G",
    "--codeformer_fidelity", "-cf",
    "--upscale", "-U",
    "-save_orig", "--save_original",
    "--log_tokenization", "-t",
    "--hires_fix",
    "--inpaint_replace", "-r",
    "--png_compression", "-z",
    "--text_mask", "-tm",
    "--h_symmetry_time_pct",
    "--v_symmetry_time_pct",
    "!fix", "!fetch", "!replay", "!history", "!search", "!clear",
    "!models", "!switch", "!import_model", "!optimize_model",
    "!convert_model", "!edit_model", "!del_model",
    "!mask", "!triggers",
)
MODEL_COMMANDS = (
    "!switch",
    "!edit_model",
    "!del_model",
)
CKPT_MODEL_COMMANDS = ("!optimize_model",)
WEIGHT_COMMANDS = (
    "!import_model",
    "!convert_model",
)
IMG_PATH_COMMANDS = ("--outdir[=\s]",)
TEXT_PATH_COMMANDS = ("!replay",)
IMG_FILE_COMMANDS = (
    "!fix",
    "!fetch",
    "!mask",
    "--init_img[=\s]",
    "-I",
    "--init_mask[=\s]",
    "-M",
    "--init_color[=\s]",
    "--embedding_path[=\s]",
)

path_regexp = "(" + "|".join(IMG_PATH_COMMANDS + IMG_FILE_COMMANDS) + ")\s*\S*$"
weight_regexp = "(" + "|".join(WEIGHT_COMMANDS) + ")\s*\S*$"
text_regexp = "(" + "|".join(TEXT_PATH_COMMANDS) + ")\s*\S*$"


class Completer(object):
    def __init__(self, options, models={}):
        self.options = sorted(options)
        self.models = models
        self.seeds = set()
        self.matches = list()
        self.default_dir = None
        self.linebuffer = None
        self.auto_history_active = True
        self.extensions = None
        self.concepts = None
        self.embedding_terms = set()
        return

    def complete(self, text, state):
        """
        Completes invoke command line.
        BUG: it doesn't correctly complete files that have spaces in the name.
        """
        buffer = readline.get_line_buffer()

        if state == 0:
            # extensions defined, so go directly into path completion mode
            if self.extensions is not None:
                self.matches = self._path_completions(text, state, self.extensions)

            # looking for an image file
            elif re.search(path_regexp, buffer):
                do_shortcut = re.search("^" + "|".join(IMG_FILE_COMMANDS), buffer)
                self.matches = self._path_completions(
                    text, state, IMG_EXTENSIONS, shortcut_ok=do_shortcut
                )

            # looking for a seed
            elif re.search("(-S\s*|--seed[=\s])\d*$", buffer):
                self.matches = self._seed_completions(text, state)

            # looking for an embedding concept
            elif re.search("<[\w-]*$", buffer):
                self.matches = self._concept_completions(text, state)

            # looking for a model
            elif re.match("^" + "|".join(MODEL_COMMANDS), buffer):
                self.matches = self._model_completions(text, state)

            # looking for a ckpt model
            elif re.match("^" + "|".join(CKPT_MODEL_COMMANDS), buffer):
                self.matches = self._model_completions(text, state, ckpt_only=True)

            elif re.search(weight_regexp, buffer):
                self.matches = self._path_completions(
                    text,
                    state,
                    WEIGHT_EXTENSIONS,
                    default_dir=Globals.root,
                )

            elif re.search(text_regexp, buffer):
                self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)

            # This is the first time for this text, so build a match list.
            elif text:
                self.matches = [s for s in self.options if s and s.startswith(text)]
            else:
                self.matches = self.options[:]

        # Return the state'th item from the match list,
        # if we have that many.
        try:
            response = self.matches[state]
        except IndexError:
            response = None
        return response

    def complete_extensions(self, extensions: list):
        """
        If called with a list of extensions, will force completer
        to do file path completions.
        """
        self.extensions = extensions

    def add_history(self, line):
        """
        Pass thru to readline
        """
        if not self.auto_history_active:
            readline.add_history(line)

    def clear_history(self):
        """
        Pass clear_history() thru to readline
        """
        readline.clear_history()

    def search_history(self, match: str):
        """
        Like show_history() but only shows items that
        contain the match string.
        """
        self.show_history(match)

    def remove_history_item(self, pos):
        readline.remove_history_item(pos)

    def add_seed(self, seed):
        """
        Add a seed to the autocomplete list for display when -S is autocompleted.
        """
        if seed is not None:
            self.seeds.add(str(seed))

    def set_default_dir(self, path):
        self.default_dir = path

    def set_options(self, options):
        self.options = options

    def get_line(self, index):
        try:
            line = self.get_history_item(index)
        except IndexError:
            return None
        return line

    def get_current_history_length(self):
        return readline.get_current_history_length()

    def get_history_item(self, index):
        return readline.get_history_item(index)

    def show_history(self, match=None):
        """
        Print the session history using the pydoc pager
        """
        import pydoc

        lines = list()
        h_len = self.get_current_history_length()
        if h_len < 1:
            print("<empty history>")
            return

        for i in range(0, h_len):
            line = self.get_history_item(i + 1)
            if match and match not in line:
                continue
            lines.append(f"[{i+1}] {line}")
        pydoc.pager("\n".join(lines))

    def set_line(self, line) -> None:
        """
        Set the default string displayed in the next line of input.
        """
        self.linebuffer = line
        readline.redisplay()

    def update_models(self, models: dict) -> None:
        """
        update our list of models
        """
        self.models = models

    def _seed_completions(self, text, state):
        m = re.search("(-S\s?|--seed[=\s]?)(\d*)", text)
        if m:
            switch = m.groups()[0]
            partial = m.groups()[1]
        else:
            switch = ""
            partial = text

        matches = list()
        for s in self.seeds:
            if s.startswith(partial):
                matches.append(switch + s)
        matches.sort()
        return matches

    def add_embedding_terms(self, terms: list[str]):
        self.embedding_terms = set(terms)
        if self.concepts:
            self.embedding_terms.update(set(self.concepts.list_concepts()))

    def _concept_completions(self, text, state):
        if self.concepts is None:
            # cache Concepts() instance so we can check for updates in concepts_list during runtime.
            self.concepts = HuggingFaceConceptsLibrary()
            self.embedding_terms.update(set(self.concepts.list_concepts()))
        else:
            self.embedding_terms.update(set(self.concepts.list_concepts()))

        partial = text[1:]  # this removes the leading '<'
        if len(partial) == 0:
            return list(self.embedding_terms)  # whole dump - think if user wants this!

        matches = list()
        for concept in self.embedding_terms:
            if concept.startswith(partial):
                matches.append(f"<{concept}>")
        matches.sort()
        return matches

    def _model_completions(self, text, state, ckpt_only=False):
        m = re.search("(!switch\s+)(\w*)", text)
        if m:
            switch = m.groups()[0]
            partial = m.groups()[1]
        else:
            switch = ""
            partial = text
        matches = list()
        for s in self.models:
            name = self.models[s]["model_name"]
            format = self.models[s]["format"]
            if format == "vae":
                continue
            if ckpt_only and format != "ckpt":
                continue
            if name.startswith(partial):
                matches.append(switch + name)
        matches.sort()
        return matches

    def _pre_input_hook(self):
        if self.linebuffer:
            readline.insert_text(self.linebuffer)
            readline.redisplay()
            self.linebuffer = None

    def _path_completions(
        self, text, state, extensions, shortcut_ok=True, default_dir: str = ""
    ):
        # separate the switch from the partial path
        match = re.search("^(-\w|--\w+=?)(.*)", text)
        if match is None:
            switch = None
            partial_path = text
        else:
            switch, partial_path = match.groups()

        partial_path = partial_path.lstrip()

        matches = list()
        path = os.path.expanduser(partial_path)

        if os.path.isdir(path):
            dir = path
        elif os.path.dirname(path) != "":
            dir = os.path.dirname(path)
        else:
            dir = default_dir if os.path.exists(default_dir) else ""
            path = os.path.join(dir, path)

        dir_list = os.listdir(dir or ".")
        if shortcut_ok and os.path.exists(self.default_dir) and dir == "":
            dir_list += os.listdir(self.default_dir)

        for node in dir_list:
            if node.startswith(".") and len(node) > 1:
                continue
            full_path = os.path.join(dir, node)

            if not (node.endswith(extensions) or os.path.isdir(full_path)):
                continue

            if path and not full_path.startswith(path):
                continue

            if switch is None:
                match_path = os.path.join(dir, node)
                matches.append(
                    match_path + "/" if os.path.isdir(full_path) else match_path
                )
            elif os.path.isdir(full_path):
                matches.append(
                    switch + os.path.join(os.path.dirname(full_path), node) + "/"
                )
            elif node.endswith(extensions):
                matches.append(switch + os.path.join(os.path.dirname(full_path), node))

        return matches


class DummyCompleter(Completer):
    def __init__(self, options):
        super().__init__(options)
        self.history = list()

    def add_history(self, line):
        self.history.append(line)

    def clear_history(self):
        self.history = list()

    def get_current_history_length(self):
        return len(self.history)

    def get_history_item(self, index):
        return self.history[index - 1]

    def remove_history_item(self, index):
        return self.history.pop(index - 1)

    def set_line(self, line):
        print(f"# {line}")


def generic_completer(commands: list) -> Completer:
    if readline_available:
        completer = Completer(commands, [])
        readline.set_completer(completer.complete)
        readline.set_pre_input_hook(completer._pre_input_hook)
        readline.set_completer_delims(" ")
        readline.parse_and_bind("tab: complete")
        readline.parse_and_bind("set print-completions-horizontally off")
        readline.parse_and_bind("set page-completions on")
        readline.parse_and_bind("set skip-completed-text on")
        readline.parse_and_bind("set show-all-if-ambiguous on")
    else:
        completer = DummyCompleter(commands)
    return completer


def get_completer(opt: Args, models=[]) -> Completer:
    if readline_available:
        completer = Completer(COMMANDS, models)

        readline.set_completer(completer.complete)
        # pyreadline3 does not have a set_auto_history() method
        try:
            readline.set_auto_history(False)
            completer.auto_history_active = False
        except:
            completer.auto_history_active = True
        readline.set_pre_input_hook(completer._pre_input_hook)
        readline.set_completer_delims(" ")
        readline.parse_and_bind("tab: complete")
        readline.parse_and_bind("set print-completions-horizontally off")
        readline.parse_and_bind("set page-completions on")
        readline.parse_and_bind("set skip-completed-text on")
        readline.parse_and_bind("set show-all-if-ambiguous on")

        outdir = os.path.expanduser(opt.outdir)
        if os.path.isabs(outdir):
            histfile = os.path.join(outdir, ".invoke_history")
        else:
            histfile = os.path.join(Globals.root, outdir, ".invoke_history")
        try:
            readline.read_history_file(histfile)
            readline.set_history_length(1000)
        except FileNotFoundError:
            pass
        except OSError:  # file likely corrupted
            newname = f"{histfile}.old"
            print(
                f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
            )
            os.replace(histfile, newname)
        atexit.register(readline.write_history_file, histfile)

    else:
        completer = DummyCompleter(COMMANDS)
    return completer
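The deleted module wires a stateful completer into GNU readline: `complete(text, state)` builds a match list when `state` is 0 and then returns the state'th match on each subsequent call until it returns None. A self-contained sketch of that protocol, with an illustrative command list:

import readline

COMMANDS = ("--steps", "--seed", "--sampler", "!fix", "!history")

class MiniCompleter:
    def __init__(self, options):
        self.options = sorted(options)
        self.matches = []

    def complete(self, text, state):
        # readline calls this repeatedly, incrementing `state` until we return None.
        if state == 0:
            self.matches = (
                [o for o in self.options if o.startswith(text)]
                if text else list(self.options)
            )
        try:
            return self.matches[state]
        except IndexError:
            return None

completer = MiniCompleter(COMMANDS)
readline.set_completer(completer.complete)
readline.set_completer_delims(" ")
readline.parse_and_bind("tab: complete")
# input() now tab-completes the commands above.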
(deleted file: the modularized sd-metadata script)
@@ -1,30 +0,0 @@
'''
This is a modularized version of the sd-metadata.py script,
which retrieves and prints the metadata from a series of generated png files.
'''
import sys
import json
from invokeai.backend.image_util import retrieve_metadata


def print_metadata():
    if len(sys.argv) < 2:
        print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")
        print("This script opens up the indicated invoke.py-generated PNG file(s) and prints out their metadata.")
        exit(-1)

    filenames = sys.argv[1:]
    for f in filenames:
        try:
            metadata = retrieve_metadata(f)
            print(f'{f}:\n', json.dumps(metadata['sd-metadata'], indent=4))
        except FileNotFoundError:
            sys.stderr.write(f'{f} not found\n')
            continue
        except PermissionError:
            sys.stderr.write(f'{f} could not be opened due to inadequate permissions\n')
            continue

if __name__ == '__main__':
    print_metadata()
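`retrieve_metadata` is project-specific, but the underlying mechanism is a PNG text chunk. A rough stand-alone equivalent using Pillow's text-chunk support (the `'sd-metadata'` key comes from the script above; the helper name is mine):

import json
import sys

from PIL import Image  # pip install pillow

def retrieve_sd_metadata(path: str) -> dict:
    """Read the 'sd-metadata' text chunk that invoke.py embeds in its PNGs."""
    with Image.open(path) as im:
        raw = im.text.get("sd-metadata", "{}")  # PNG tEXt/iTXt chunks
    return json.loads(raw)

if __name__ == "__main__":
    for f in sys.argv[1:]:
        print(f, json.dumps(retrieve_sd_metadata(f), indent=4))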
@@ -23,7 +23,6 @@ from npyscreen import widget
 from omegaconf import OmegaConf
 
 import invokeai.backend.util.logging as logger
-from invokeai.backend.globals import Globals, global_config_dir
 
 from ...backend.config.model_install_backend import (
     Dataset_path,
@@ -41,11 +40,13 @@ from .widgets import (
     TextBox,
     set_min_terminal_size,
 )
+from invokeai.app.services.config import get_invokeai_config
 
 # minimum size for the UI
 MIN_COLS = 120
 MIN_LINES = 45
 
+config = get_invokeai_config()
 
 class addModelsForm(npyscreen.FormMultiPage):
     # for responsive resizing - disabled
@@ -453,9 +454,9 @@ def main():
     opt = parser.parse_args()
 
     # setting a global here
-    Globals.root = os.path.expanduser(get_root(opt.root) or "")
+    config.root = os.path.expanduser(get_root(opt.root) or "")
 
-    if not global_config_dir().exists():
+    if not (config.conf_path / '..').exists():
         logger.info(
             "Your InvokeAI root directory is not set up. Calling invokeai-configure."
         )
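This hunk is part of the commit's broader migration from module-level `Globals` attributes to a `get_invokeai_config()` accessor. Purely for illustration, a minimal sketch of that lazy-singleton pattern; the class body and fields here are assumptions (only the names `get_invokeai_config`, `root`, and `conf_path` are taken from the diff, and the real implementation lives in invokeai.app.services.config):

from pathlib import Path

class InvokeAIConfig:
    """Illustrative stand-in: holds runtime paths derived from a root directory."""
    def __init__(self, root: str = "~/invokeai"):
        self.root = Path(root).expanduser()

    @property
    def conf_path(self) -> Path:
        return self.root / "configs" / "models.yaml"

_config = None

def get_invokeai_config() -> InvokeAIConfig:
    # Lazily create one shared instance, so every module sees the same paths.
    global _config
    if _config is None:
        _config = InvokeAIConfig()
    return _config

config = get_invokeai_config()
# `(config.conf_path / '..')` in the hunk above resolves to the configs directory.
print((config.conf_path / "..").resolve())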
@@ -8,7 +8,6 @@ import argparse
 import curses
 import os
 import sys
-import traceback
 import warnings
 from argparse import Namespace
 from pathlib import Path
@@ -20,20 +19,13 @@ from diffusers import logging as dlogging
 from npyscreen import widget
 from omegaconf import OmegaConf
 
-from ...backend.globals import (
-    Globals,
-    global_cache_dir,
-    global_config_file,
-    global_models_dir,
-    global_set_root,
-)
-
 import invokeai.backend.util.logging as logger
+from invokeai.services.config import get_invokeai_config
 from ...backend.model_management import ModelManager
 from ...frontend.install.widgets import FloatTitleSlider
 
 DEST_MERGED_MODEL_DIR = "merged_models"
+config = get_invokeai_config()
 
 def merge_diffusion_models(
     model_ids_or_paths: List[Union[str, Path]],
@@ -60,7 +52,7 @@ def merge_diffusion_models(
 
     pipe = DiffusionPipeline.from_pretrained(
         model_ids_or_paths[0],
-        cache_dir=kwargs.get("cache_dir", global_cache_dir()),
+        cache_dir=kwargs.get("cache_dir", config.cache_dir),
         custom_pipeline="checkpoint_merger",
     )
     merged_pipe = pipe.merge(
@@ -94,7 +86,7 @@ def merge_diffusion_models_and_commit(
     **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
          cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
     """
-    config_file = global_config_file()
+    config_file = config.model_conf_path
     model_manager = ModelManager(OmegaConf.load(config_file))
     for mod in models:
         assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
@@ -106,7 +98,7 @@ def merge_diffusion_models_and_commit(
     merged_pipe = merge_diffusion_models(
         model_ids_or_paths, alpha, interp, force, **kwargs
     )
-    dump_path = global_models_dir() / DEST_MERGED_MODEL_DIR
+    dump_path = config.models_dir / DEST_MERGED_MODEL_DIR
 
     os.makedirs(dump_path, exist_ok=True)
     dump_path = dump_path / merged_model_name
@@ -126,7 +118,7 @@ def _parse_args() -> Namespace:
     parser.add_argument(
         "--root_dir",
         type=Path,
-        default=Globals.root,
+        default=config.root,
         help="Path to the invokeai runtime directory",
     )
     parser.add_argument(
@@ -398,7 +390,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
 class Mergeapp(npyscreen.NPSAppManaged):
     def __init__(self):
         super().__init__()
-        conf = OmegaConf.load(global_config_file())
+        conf = OmegaConf.load(config.model_conf_path)
         self.model_manager = ModelManager(
             conf, "cpu", "float16"
         )  # precision doesn't really matter here
@@ -429,7 +421,7 @@ def run_cli(args: Namespace):
             f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
         )
 
-    model_manager = ModelManager(OmegaConf.load(global_config_file()))
+    model_manager = ModelManager(OmegaConf.load(config.model_conf_path))
     assert (
         args.clobber or args.merged_model_name not in model_manager.model_names()
     ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
@@ -440,9 +432,9 @@ def run_cli(args: Namespace):
 
 def main():
     args = _parse_args()
-    global_set_root(args.root_dir)
+    config.root = args.root_dir
 
-    cache_dir = str(global_cache_dir("hub"))
+    cache_dir = config.cache_dir
     os.environ[
         "HF_HOME"
     ] = cache_dir  # because not clear the merge pipeline is honoring cache_dir
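The merge itself is delegated to the diffusers "checkpoint_merger" community pipeline, as the `from_pretrained`/`merge` hunks above show. A minimal sketch of that call path with the diffusers version this commit targets (model IDs, weight, and output path are illustrative, and the community pipeline must be fetchable at runtime):

from pathlib import Path

from diffusers import DiffusionPipeline

models = ["org/model-a", "org/model-b"]  # two compatible checkpoints

# Load the first model through the community "checkpoint_merger" pipeline.
pipe = DiffusionPipeline.from_pretrained(
    models[0],
    custom_pipeline="checkpoint_merger",
)

# alpha is the interpolation weight; interp selects the blending function
# (e.g. weighted sum or sigmoid), mirroring the arguments passed above.
merged = pipe.merge(models, alpha=0.5, interp="sigmoid")
merged.save_pretrained(Path("merged_models") / "my_merge")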
@@ -21,14 +21,17 @@ from npyscreen import widget
 from omegaconf import OmegaConf
 
 import invokeai.backend.util.logging as logger
-from invokeai.backend.globals import Globals, global_set_root
-
-from ...backend.training import do_textual_inversion_training, parse_args
+from invokeai.app.services.config import get_invokeai_config
+from ...backend.training import (
+    do_textual_inversion_training,
+    parse_args
+)
 
 TRAINING_DATA = "text-inversion-training-data"
 TRAINING_DIR = "text-inversion-output"
 CONF_FILE = "preferences.conf"
+config = None
 
 class textualInversionForm(npyscreen.FormMultiPageAction):
     resolutions = [512, 768, 1024]
@@ -122,7 +125,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
             value=str(
                 saved_args.get(
                     "train_data_dir",
-                    Path(Globals.root) / TRAINING_DATA / default_placeholder_token,
+                    config.root_dir / TRAINING_DATA / default_placeholder_token,
                 )
             ),
             scroll_exit=True,
@@ -135,7 +138,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
             value=str(
                 saved_args.get(
                     "output_dir",
-                    Path(Globals.root) / TRAINING_DIR / default_placeholder_token,
+                    config.root_dir / TRAINING_DIR / default_placeholder_token,
                 )
             ),
             scroll_exit=True,
@@ -241,9 +244,9 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
         placeholder = self.placeholder_token.value
         self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
         self.train_data_dir.value = str(
-            Path(Globals.root) / TRAINING_DATA / placeholder
+            config.root_dir / TRAINING_DATA / placeholder
         )
-        self.output_dir.value = str(Path(Globals.root) / TRAINING_DIR / placeholder)
+        self.output_dir.value = str(config.root_dir / TRAINING_DIR / placeholder)
         self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
 
     def on_ok(self):
@@ -284,7 +287,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
         return True
 
     def get_model_names(self) -> Tuple[List[str], int]:
-        conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml"))
+        conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
         model_names = [
             idx
             for idx in sorted(list(conf.keys()))
@@ -367,7 +370,7 @@ def copy_to_embeddings_folder(args: dict):
     """
     source = Path(args["output_dir"], "learned_embeds.bin")
     dest_dir_name = args["placeholder_token"].strip("<>")
-    destination = Path(Globals.root, "embeddings", dest_dir_name)
+    destination = config.root_dir / "embeddings" / dest_dir_name
     os.makedirs(destination, exist_ok=True)
     logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
     shutil.copy(source, destination)
@@ -383,7 +386,7 @@ def save_args(args: dict):
     """
     Save the current argument values to an omegaconf file
     """
-    dest_dir = Path(Globals.root) / TRAINING_DIR
+    dest_dir = config.root_dir / TRAINING_DIR
     os.makedirs(dest_dir, exist_ok=True)
     conf_file = dest_dir / CONF_FILE
     conf = OmegaConf.create(args)
@@ -394,7 +397,7 @@ def previous_args() -> dict:
     """
     Get the previous arguments used.
     """
-    conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
+    conf_file = config.root_dir / TRAINING_DIR / CONF_FILE
     try:
         conf = OmegaConf.load(conf_file)
         conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
@@ -420,7 +423,7 @@ def do_front_end(args: Namespace):
         save_args(args)
 
         try:
-            do_textual_inversion_training(**args)
+            do_textual_inversion_training(get_invokeai_config(), **args)
             copy_to_embeddings_folder(args)
         except Exception as e:
             logger.error("An exception occurred during training. The exception was:")
@@ -430,13 +433,20 @@ def do_front_end(args: Namespace):
 
 def main():
+    global config
+
     args = parse_args()
-    global_set_root(args.root_dir or Globals.root)
+    config = get_invokeai_config(argv=[])
+
+    # change root if needed
+    if args.root_dir:
+        config.root = args.root_dir
+
     try:
         if args.front_end:
             do_front_end(args)
         else:
-            do_textual_inversion_training(**vars(args))
+            do_textual_inversion_training(config, **vars(args))
     except AssertionError as e:
         logger.error(e)
         sys.exit(-1)
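The `save_args()`/`previous_args()` pair above round-trips the training form's settings through an OmegaConf file so the UI can repopulate itself next session. The pattern in isolation, a runnable sketch (the path and keys are illustrative):

from pathlib import Path

from omegaconf import OmegaConf

conf_file = Path("text-inversion-output") / "preferences.conf"

def save_args(args: dict) -> None:
    conf_file.parent.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(config=OmegaConf.create(args), f=conf_file)

def previous_args() -> dict:
    try:
        conf = OmegaConf.load(conf_file)
        # strip the <...> wrapper the UI adds around the trigger token
        conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
        return OmegaConf.to_container(conf)
    except OSError:
        return {}

save_args({"placeholder_token": "<my-style>", "max_train_steps": 3000})
print(previous_args())  # {'placeholder_token': 'my-style', 'max_train_steps': 3000}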
@@ -15,15 +15,3 @@ The `postinstall` script patches a few packages and runs the Chakra CLI to generate
 ### Patch `@chakra-ui/cli`
 
 See: <https://github.com/chakra-ui/chakra-ui/issues/7394>
-
-### Patch `redux-persist`
-
-We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.
-
-`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.
-
-So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.
-
-### Patch `redux-deep-persist`
-
-This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.
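The removed notes explain why writes to `localStorage` were debounced: rapid canvas updates would otherwise hammer storage. Language aside, the technique the (now-deleted) patch added to `redux-persist` is plain trailing-edge debouncing. A small Python sketch of the same idea (the decorator and names are mine):

import threading

def debounce(wait_seconds: float):
    """Trailing-edge debounce: only the last call in a burst runs, after the wait."""
    def decorator(fn):
        timer = [None]
        def wrapper(*args, **kwargs):
            if timer[0] is not None:
                timer[0].cancel()  # restart the countdown on every call
            timer[0] = threading.Timer(wait_seconds, fn, args, kwargs)
            timer[0].start()
        return wrapper
    return decorator

@debounce(0.3)
def persist(state: dict) -> None:
    print("writing state:", state)

for i in range(100):
    persist({"step": i})  # rapid-fire updates...
# ...only {"step": 99} is written, 300 ms after the last call.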
@@ -62,11 +62,13 @@
     "@dagrejs/graphlib": "^2.1.12",
     "@emotion/react": "^11.10.6",
     "@emotion/styled": "^11.10.6",
+    "@floating-ui/react-dom": "^2.0.0",
     "@fontsource/inter": "^4.5.15",
     "@reduxjs/toolkit": "^1.9.5",
     "@roarr/browser-log-writer": "^1.1.5",
     "chakra-ui-contextmenu": "^1.0.5",
     "dateformat": "^5.0.3",
+    "downshift": "^7.6.0",
     "formik": "^2.2.9",
     "framer-motion": "^10.12.4",
     "fuse.js": "^6.6.2",
@@ -87,18 +89,13 @@
     "react-i18next": "^12.2.2",
     "react-icons": "^4.7.1",
     "react-konva": "^18.2.7",
-    "react-konva-utils": "^1.0.4",
     "react-redux": "^8.0.5",
     "react-resizable-panels": "^0.0.42",
-    "react-rnd": "^10.4.1",
-    "react-transition-group": "^4.4.5",
     "react-use": "^17.4.0",
     "react-virtuoso": "^4.3.5",
     "react-zoom-pan-pinch": "^3.0.7",
     "reactflow": "^11.7.0",
-    "redux-deep-persist": "^1.0.7",
     "redux-dynamic-middlewares": "^2.2.0",
-    "redux-persist": "^6.0.0",
     "redux-remember": "^3.3.1",
     "roarr": "^7.15.0",
     "serialize-error": "^11.0.0",
(deleted file: the redux-deep-persist patch)
@@ -1,24 +0,0 @@
diff --git a/node_modules/redux-deep-persist/lib/types.d.ts b/node_modules/redux-deep-persist/lib/types.d.ts
index b67b8c2..7fc0fa1 100644
--- a/node_modules/redux-deep-persist/lib/types.d.ts
+++ b/node_modules/redux-deep-persist/lib/types.d.ts
@@ -35,6 +35,7 @@ export interface PersistConfig<S, RS = any, HSS = any, ESS = any> {
     whitelist?: Array<string>;
     transforms?: Array<Transform<HSS, ESS, S, RS>>;
     throttle?: number;
+    debounce?: number;
     migrate?: PersistMigrate;
     stateReconciler?: false | StateReconciler<S>;
     getStoredState?: (config: PersistConfig<S, RS, HSS, ESS>) => Promise<PersistedState>;
diff --git a/node_modules/redux-deep-persist/src/types.ts b/node_modules/redux-deep-persist/src/types.ts
index 398ac19..cbc5663 100644
--- a/node_modules/redux-deep-persist/src/types.ts
+++ b/node_modules/redux-deep-persist/src/types.ts
@@ -91,6 +91,7 @@ export interface PersistConfig<S, RS = any, HSS = any, ESS = any> {
     whitelist?: Array<string>;
     transforms?: Array<Transform<HSS, ESS, S, RS>>;
     throttle?: number;
+    debounce?: number;
     migrate?: PersistMigrate;
     stateReconciler?: false | StateReconciler<S>;
     /**
(deleted file: the redux-persist debounce patch)
@@ -1,116 +0,0 @@
diff --git a/node_modules/redux-persist/es/createPersistoid.js b/node_modules/redux-persist/es/createPersistoid.js
index 8b43b9a..184faab 100644
--- a/node_modules/redux-persist/es/createPersistoid.js
+++ b/node_modules/redux-persist/es/createPersistoid.js
@@ -6,6 +6,7 @@ export default function createPersistoid(config) {
   var whitelist = config.whitelist || null;
   var transforms = config.transforms || [];
   var throttle = config.throttle || 0;
+  var debounce = config.debounce || 0;
   var storageKey = "".concat(config.keyPrefix !== undefined ? config.keyPrefix : KEY_PREFIX).concat(config.key);
   var storage = config.storage;
   var serialize;
@@ -28,30 +29,37 @@ export default function createPersistoid(config) {
   var timeIterator = null;
   var writePromise = null;
 
-  var update = function update(state) {
-    // add any changed keys to the queue
-    Object.keys(state).forEach(function (key) {
-      if (!passWhitelistBlacklist(key)) return; // is keyspace ignored? noop
+  // Timer for debounced `update()`
+  let timer = 0;
 
-      if (lastState[key] === state[key]) return; // value unchanged? noop
+  function update(state) {
+    // Debounce the update
+    clearTimeout(timer);
+    timer = setTimeout(() => {
+      // add any changed keys to the queue
+      Object.keys(state).forEach(function (key) {
+        if (!passWhitelistBlacklist(key)) return; // is keyspace ignored? noop
 
-      if (keysToProcess.indexOf(key) !== -1) return; // is key already queued? noop
+        if (lastState[key] === state[key]) return; // value unchanged? noop
 
-      keysToProcess.push(key); // add key to queue
-    }); //if any key is missing in the new state which was present in the lastState,
-    //add it for processing too
+        if (keysToProcess.indexOf(key) !== -1) return; // is key already queued? noop
 
-    Object.keys(lastState).forEach(function (key) {
-      if (state[key] === undefined && passWhitelistBlacklist(key) && keysToProcess.indexOf(key) === -1 && lastState[key] !== undefined) {
-        keysToProcess.push(key);
-      }
-    }); // start the time iterator if not running (read: throttle)
+        keysToProcess.push(key); // add key to queue
+      }); //if any key is missing in the new state which was present in the lastState,
+      //add it for processing too
 
-    if (timeIterator === null) {
-      timeIterator = setInterval(processNextKey, throttle);
-    }
+      Object.keys(lastState).forEach(function (key) {
+        if (state[key] === undefined && passWhitelistBlacklist(key) && keysToProcess.indexOf(key) === -1 && lastState[key] !== undefined) {
+          keysToProcess.push(key);
+        }
+      }); // start the time iterator if not running (read: throttle)
+
+      if (timeIterator === null) {
+        timeIterator = setInterval(processNextKey, throttle);
+      }
 
-    lastState = state;
+      lastState = state;
+    }, debounce)
   };
 
   function processNextKey() {
diff --git a/node_modules/redux-persist/es/types.js.flow b/node_modules/redux-persist/es/types.js.flow
index c50d3cd..39d8be2 100644
--- a/node_modules/redux-persist/es/types.js.flow
+++ b/node_modules/redux-persist/es/types.js.flow
@@ -19,6 +19,7 @@ export type PersistConfig = {
   whitelist?: Array<string>,
   transforms?: Array<Transform>,
   throttle?: number,
+  debounce?: number,
   migrate?: (PersistedState, number) => Promise<PersistedState>,
   stateReconciler?: false | Function,
   getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
diff --git a/node_modules/redux-persist/lib/types.js.flow b/node_modules/redux-persist/lib/types.js.flow
index c50d3cd..39d8be2 100644
--- a/node_modules/redux-persist/lib/types.js.flow
+++ b/node_modules/redux-persist/lib/types.js.flow
@@ -19,6 +19,7 @@ export type PersistConfig = {
   whitelist?: Array<string>,
   transforms?: Array<Transform>,
   throttle?: number,
+  debounce?: number,
   migrate?: (PersistedState, number) => Promise<PersistedState>,
   stateReconciler?: false | Function,
   getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
diff --git a/node_modules/redux-persist/src/types.js b/node_modules/redux-persist/src/types.js
index c50d3cd..39d8be2 100644
--- a/node_modules/redux-persist/src/types.js
+++ b/node_modules/redux-persist/src/types.js
@@ -19,6 +19,7 @@ export type PersistConfig = {
   whitelist?: Array<string>,
   transforms?: Array<Transform>,
   throttle?: number,
+  debounce?: number,
   migrate?: (PersistedState, number) => Promise<PersistedState>,
   stateReconciler?: false | Function,
   getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
diff --git a/node_modules/redux-persist/types/types.d.ts b/node_modules/redux-persist/types/types.d.ts
index b3733bc..2a1696c 100644
--- a/node_modules/redux-persist/types/types.d.ts
+++ b/node_modules/redux-persist/types/types.d.ts
@@ -35,6 +35,7 @@ declare module "redux-persist/es/types" {
   whitelist?: Array<string>;
   transforms?: Array<Transform<HSS, ESS, S, RS>>;
   throttle?: number;
+  debounce?: number;
   migrate?: PersistMigrate;
   stateReconciler?: false | StateReconciler<S>;
   /**
@@ -450,7 +450,7 @@
     "cfgScale": "CFG Scale",
     "width": "Width",
     "height": "Height",
-    "sampler": "Sampler",
+    "scheduler": "Scheduler",
     "seed": "Seed",
     "imageToImage": "Image to Image",
     "randomizeSeed": "Randomize Seed",
@@ -540,7 +540,10 @@
     "consoleLogLevel": "Log Level",
     "shouldLogToConsole": "Console Logging",
     "developer": "Developer",
-    "general": "General"
+    "general": "General",
+    "generation": "Generation",
+    "ui": "User Interface",
+    "availableSchedulers": "Available Schedulers"
   },
   "toast": {
     "serverError": "Server Error",
@@ -549,8 +552,8 @@
     "canceled": "Processing Canceled",
     "tempFoldersEmptied": "Temp Folder Emptied",
     "uploadFailed": "Upload failed",
-    "uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",
     "uploadFailedUnableToLoadDesc": "Unable to load file",
+    "uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
     "downloadImageStarted": "Image Download Started",
     "imageCopied": "Image Copied",
     "imageLinkCopied": "Image Link Copied",
@@ -2,14 +2,11 @@ import ImageUploader from 'common/components/ImageUploader';
 import SiteHeader from 'features/system/components/SiteHeader';
 import ProgressBar from 'features/system/components/ProgressBar';
 import InvokeTabs from 'features/ui/components/InvokeTabs';
-
-import useToastWatcher from 'features/system/hooks/useToastWatcher';
-
 import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
 import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
 import { Box, Flex, Grid, Portal } from '@chakra-ui/react';
 import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
-import GalleryDrawer from 'features/gallery/components/ImageGalleryPanel';
+import GalleryDrawer from 'features/gallery/components/GalleryPanel';
 import Lightbox from 'features/lightbox/components/Lightbox';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import { memo, ReactNode, useCallback, useEffect, useState } from 'react';
@@ -17,25 +14,28 @@ import { motion, AnimatePresence } from 'framer-motion';
 import Loading from 'common/components/Loading/Loading';
 import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady';
 import { PartialAppConfig } from 'app/types/invokeai';
-import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
 import { configChanged } from 'features/system/store/configSlice';
 import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
 import { useLogger } from 'app/logging/useLogger';
 import ParametersDrawer from 'features/ui/components/ParametersDrawer';
 import { languageSelector } from 'features/system/store/systemSelectors';
 import i18n from 'i18n';
+import Toaster from './Toaster';
+import GlobalHotkeys from './GlobalHotkeys';
 
 const DEFAULT_CONFIG = {};
 
 interface Props {
   config?: PartialAppConfig;
   headerComponent?: ReactNode;
+  setIsReady?: (isReady: boolean) => void;
 }
 
-const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
-  useToastWatcher();
-  useGlobalHotkeys();
+const App = ({
+  config = DEFAULT_CONFIG,
+  headerComponent,
+  setIsReady,
+}: Props) => {
   const language = useAppSelector(languageSelector);
 
   const log = useLogger();
@@ -61,7 +61,18 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
     setLoadingOverridden(true);
   }, []);
 
+  useEffect(() => {
+    if (isApplicationReady && setIsReady) {
+      setIsReady(true);
+    }
+
+    return () => {
+      setIsReady && setIsReady(false);
+    };
+  }, [isApplicationReady, setIsReady]);
+
   return (
+    <>
       <Grid w="100vw" h="100vh" position="relative" overflow="hidden">
         {isLightboxEnabled && <Lightbox />}
         <ImageUploader>
@@ -121,6 +132,9 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
           <FloatingGalleryButton />
         </Portal>
       </Grid>
+      <Toaster />
+      <GlobalHotkeys />
+    </>
   );
 };
AuxiliaryProgressIndicator (new file, 44 lines)
@@ -0,0 +1,44 @@
import { Flex, Spinner, Tooltip } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { memo } from 'react';

const selector = createSelector(systemSelector, (system) => {
  const { isUploading } = system;

  let tooltip = '';

  if (isUploading) {
    tooltip = 'Uploading...';
  }

  return {
    tooltip,
    shouldShow: isUploading,
  };
});

export const AuxiliaryProgressIndicator = () => {
  const { shouldShow, tooltip } = useAppSelector(selector);

  if (!shouldShow) {
    return null;
  }

  return (
    <Flex
      sx={{
        alignItems: 'center',
        justifyContent: 'center',
        color: 'base.600',
      }}
    >
      <Tooltip label={tooltip} placement="right" hasArrow>
        <Spinner />
      </Tooltip>
    </Flex>
  );
};

export default memo(AuxiliaryProgressIndicator);
@@ -10,6 +10,7 @@ import {
   togglePinParametersPanel,
 } from 'features/ui/store/uiSlice';
 import { isEqual } from 'lodash-es';
+import React, { memo } from 'react';
 import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
 
 const globalHotkeysSelector = createSelector(
@@ -27,7 +28,11 @@ const globalHotkeysSelector = createSelector(
 
 // TODO: Does not catch keypresses while focused in an input. Maybe there is a way?
 
-export const useGlobalHotkeys = () => {
+/**
+ * Logical component. Handles app-level global hotkeys.
+ * @returns null
+ */
+const GlobalHotkeys: React.FC = () => {
   const dispatch = useAppDispatch();
   const { shift } = useAppSelector(globalHotkeysSelector);
 
@@ -75,4 +80,8 @@ export const useGlobalHotkeys = () => {
   useHotkeys('4', () => {
     dispatch(setActiveTab('nodes'));
   });
+
+  return null;
 };
+
+export default memo(GlobalHotkeys);
@@ -24,9 +24,16 @@ interface Props extends PropsWithChildren {
   token?: string;
   config?: PartialAppConfig;
   headerComponent?: ReactNode;
+  setIsReady?: (isReady: boolean) => void;
 }
 
-const InvokeAIUI = ({ apiUrl, token, config, headerComponent }: Props) => {
+const InvokeAIUI = ({
+  apiUrl,
+  token,
+  config,
+  headerComponent,
+  setIsReady,
+}: Props) => {
   useEffect(() => {
     // configure API client token
     if (token) {
@@ -55,7 +62,11 @@ const InvokeAIUI = ({ apiUrl, token, config, headerComponent }: Props) => {
     <Provider store={store}>
       <React.Suspense fallback={<Loading />}>
         <ThemeLocaleProvider>
-          <App config={config} headerComponent={headerComponent} />
+          <App
+            config={config}
+            headerComponent={headerComponent}
+            setIsReady={setIsReady}
+          />
         </ThemeLocaleProvider>
       </React.Suspense>
     </Provider>
65  invokeai/frontend/web/src/app/components/Toaster.ts  Normal file
@@ -0,0 +1,65 @@
+import { useToast, UseToastOptions } from '@chakra-ui/react';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { toastQueueSelector } from 'features/system/store/systemSelectors';
+import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
+import { useCallback, useEffect } from 'react';
+
+export type MakeToastArg = string | UseToastOptions;
+
+/**
+ * Makes a toast from a string or a UseToastOptions object.
+ * If a string is passed, the toast will have the status 'info' and will be closable with a duration of 2500ms.
+ */
+export const makeToast = (arg: MakeToastArg): UseToastOptions => {
+  if (typeof arg === 'string') {
+    return {
+      title: arg,
+      status: 'info',
+      isClosable: true,
+      duration: 2500,
+    };
+  }
+
+  return { status: 'info', isClosable: true, duration: 2500, ...arg };
+};
+
+/**
+ * Logical component. Watches the toast queue and makes toasts when the queue is not empty.
+ * @returns null
+ */
+const Toaster = () => {
+  const dispatch = useAppDispatch();
+  const toastQueue = useAppSelector(toastQueueSelector);
+  const toast = useToast();
+  useEffect(() => {
+    toastQueue.forEach((t) => {
+      toast(t);
+    });
+    toastQueue.length > 0 && dispatch(clearToastQueue());
+  }, [dispatch, toast, toastQueue]);
+
+  return null;
+};
+
+/**
+ * Returns a function that can be used to make a toast.
+ * @example
+ * const toaster = useAppToaster();
+ * toaster('Hello world!');
+ * toaster({ title: 'Hello world!', status: 'success' });
+ * @returns A function that can be used to make a toast.
+ * @see makeToast
+ * @see MakeToastArg
+ * @see UseToastOptions
+ */
+export const useAppToaster = () => {
+  const dispatch = useAppDispatch();
+  const toaster = useCallback(
+    (arg: MakeToastArg) => dispatch(addToast(makeToast(arg))),
+    [dispatch]
+  );
+
+  return toaster;
+};
+
+export default Toaster;
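The new Toaster module above centralizes toast creation behind the Redux toast queue. A minimal usage sketch, assuming a component rendered inside the app's providers (the consuming component is hypothetical):

import { useAppToaster } from 'app/components/Toaster';

// Hypothetical consumer of the hook defined above.
const SaveStatus = () => {
  const toaster = useAppToaster();

  const notify = () => {
    // String form: status 'info', closable, 2500ms duration.
    toaster('Canvas saved');
    // Object form: any UseToastOptions fields override those defaults.
    toaster({ title: 'Canvas saved', status: 'success', duration: 5000 });
  };

  return <button onClick={notify}>Save</button>;
};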
@@ -1,28 +1,28 @@
 // TODO: use Enums?
 
-export const DIFFUSERS_SCHEDULERS: Array<string> = [
+export const SCHEDULERS = [
   'ddim',
-  'ddpm',
-  'deis',
   'lms',
-  'pndm',
-  'heun',
   'euler',
   'euler_k',
   'euler_a',
-  'kdpm_2',
-  'kdpm_2_a',
   'dpmpp_2s',
   'dpmpp_2m',
   'dpmpp_2m_k',
+  'kdpm_2',
+  'kdpm_2_a',
+  'deis',
+  'ddpm',
+  'pndm',
+  'heun',
+  'heun_k',
   'unipc',
-];
+] as const;
 
-export const IMG2IMG_DIFFUSERS_SCHEDULERS = DIFFUSERS_SCHEDULERS.filter(
-  (scheduler) => {
-    return scheduler !== 'dpmpp_2s';
-  }
-);
+export type Scheduler = (typeof SCHEDULERS)[number];
+
+export const isScheduler = (x: string): x is Scheduler =>
+  SCHEDULERS.includes(x as Scheduler);
 
 // Valid image widths
 export const WIDTHS: Array<number> = Array.from(Array(64)).map(
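The reworked constants above trade a plain string array for a readonly `as const` tuple plus a derived union type and runtime guard. A short sketch of how that pairing is typically consumed at a runtime boundary (the import path and fallback choice are illustrative):

import { SCHEDULERS, isScheduler, Scheduler } from 'app/constants';

// Narrow an untrusted string (e.g. from a URL param) to the Scheduler union.
const parseScheduler = (value: string): Scheduler =>
  isScheduler(value) ? value : SCHEDULERS[0];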
@@ -15,6 +15,10 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
 import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
 import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
 import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
+import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
+import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
+import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
+import { addCanvasMergedListener } from './listeners/canvasMerged';
 
 export const listenerMiddleware = createListenerMiddleware();
 
@@ -43,3 +47,8 @@ addUserInvokedCanvasListener();
 addUserInvokedNodesListener();
 addUserInvokedTextToImageListener();
 addUserInvokedImageToImageListener();
+
+addCanvasSavedToGalleryListener();
+addCanvasDownloadedAsImageListener();
+addCanvasCopiedToClipboardListener();
+addCanvasMergedListener();
@@ -0,0 +1,33 @@
+import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
+import { startAppListening } from '..';
+import { log } from 'app/logging/useLogger';
+import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
+import { addToast } from 'features/system/store/systemSlice';
+import { copyBlobToClipboard } from 'features/canvas/util/copyBlobToClipboard';
+
+const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' });
+
+export const addCanvasCopiedToClipboardListener = () => {
+  startAppListening({
+    actionCreator: canvasCopiedToClipboard,
+    effect: async (action, { dispatch, getState }) => {
+      const state = getState();
+
+      const blob = await getBaseLayerBlob(state);
+
+      if (!blob) {
+        moduleLog.error('Problem getting base layer blob');
+        dispatch(
+          addToast({
+            title: 'Problem Copying Canvas',
+            description: 'Unable to export base layer',
+            status: 'error',
+          })
+        );
+        return;
+      }
+
+      copyBlobToClipboard(blob);
+    },
+  });
+};
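copyBlobToClipboard itself is not part of this diff; a plausible sketch of such a helper uses the async Clipboard API, keyed by the blob's MIME type:

// Assumed shape of the helper imported above; not taken from this diff.
export const copyBlobToClipboard = (blob: Blob): Promise<void> =>
  navigator.clipboard.write([new ClipboardItem({ [blob.type]: blob })]);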
@@ -0,0 +1,33 @@
+import { canvasDownloadedAsImage } from 'features/canvas/store/actions';
+import { startAppListening } from '..';
+import { log } from 'app/logging/useLogger';
+import { downloadBlob } from 'features/canvas/util/downloadBlob';
+import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
+import { addToast } from 'features/system/store/systemSlice';
+
+const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
+
+export const addCanvasDownloadedAsImageListener = () => {
+  startAppListening({
+    actionCreator: canvasDownloadedAsImage,
+    effect: async (action, { dispatch, getState }) => {
+      const state = getState();
+
+      const blob = await getBaseLayerBlob(state);
+
+      if (!blob) {
+        moduleLog.error('Problem getting base layer blob');
+        dispatch(
+          addToast({
+            title: 'Problem Downloading Canvas',
+            description: 'Unable to export base layer',
+            status: 'error',
+          })
+        );
+        return;
+      }
+
+      downloadBlob(blob, 'mergedCanvas.png');
+    },
+  });
+};
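downloadBlob is likewise imported but not defined in this diff; the conventional implementation routes an object URL through a temporary anchor element:

// Assumed shape of the helper imported above; not taken from this diff.
export const downloadBlob = (blob: Blob, filename: string): void => {
  const url = URL.createObjectURL(blob);
  const a = document.createElement('a');
  a.href = url;
  a.download = filename; // the listener above passes 'mergedCanvas.png'
  a.click();
  URL.revokeObjectURL(url);
};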
@@ -1,31 +0,0 @@
-import { canvasGraphBuilt } from 'features/nodes/store/actions';
-import { startAppListening } from '..';
-import {
-  canvasSessionIdChanged,
-  stagingAreaInitialized,
-} from 'features/canvas/store/canvasSlice';
-import { sessionInvoked } from 'services/thunks/session';
-
-export const addCanvasGraphBuiltListener = () =>
-  startAppListening({
-    actionCreator: canvasGraphBuilt,
-    effect: async (action, { dispatch, getState, take }) => {
-      const [{ meta }] = await take(sessionInvoked.fulfilled.match);
-      const { sessionId } = meta.arg;
-      const state = getState();
-
-      if (!state.canvas.layerState.stagingArea.boundingBox) {
-        dispatch(
-          stagingAreaInitialized({
-            sessionId,
-            boundingBox: {
-              ...state.canvas.boundingBoxCoordinates,
-              ...state.canvas.boundingBoxDimensions,
-            },
-          })
-        );
-      }
-
-      dispatch(canvasSessionIdChanged(sessionId));
-    },
-  });
@@ -0,0 +1,88 @@
+import { canvasMerged } from 'features/canvas/store/actions';
+import { startAppListening } from '..';
+import { log } from 'app/logging/useLogger';
+import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
+import { addToast } from 'features/system/store/systemSlice';
+import { imageUploaded } from 'services/thunks/image';
+import { v4 as uuidv4 } from 'uuid';
+import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
+import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
+import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
+
+const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' });
+
+export const addCanvasMergedListener = () => {
+  startAppListening({
+    actionCreator: canvasMerged,
+    effect: async (action, { dispatch, getState, take }) => {
+      const state = getState();
+
+      const blob = await getBaseLayerBlob(state, true);
+
+      if (!blob) {
+        moduleLog.error('Problem getting base layer blob');
+        dispatch(
+          addToast({
+            title: 'Problem Merging Canvas',
+            description: 'Unable to export base layer',
+            status: 'error',
+          })
+        );
+        return;
+      }
+
+      const canvasBaseLayer = getCanvasBaseLayer();
+
+      if (!canvasBaseLayer) {
+        moduleLog.error('Problem getting canvas base layer');
+        dispatch(
+          addToast({
+            title: 'Problem Merging Canvas',
+            description: 'Unable to export base layer',
+            status: 'error',
+          })
+        );
+        return;
+      }
+
+      const baseLayerRect = canvasBaseLayer.getClientRect({
+        relativeTo: canvasBaseLayer.getParent(),
+      });
+
+      const filename = `mergedCanvas_${uuidv4()}.png`;
+
+      dispatch(
+        imageUploaded({
+          imageType: 'intermediates',
+          formData: {
+            file: new File([blob], filename, { type: 'image/png' }),
+          },
+        })
+      );
+
+      const [{ payload }] = await take(
+        (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
+          imageUploaded.fulfilled.match(action) &&
+          action.meta.arg.formData.file.name === filename
+      );
+
+      const mergedCanvasImage = deserializeImageResponse(payload.response);
+
+      dispatch(
+        setMergedCanvas({
+          kind: 'image',
+          layer: 'base',
+          image: mergedCanvasImage,
+          ...baseLayerRect,
+        })
+      );
+
+      dispatch(
+        addToast({
+          title: 'Canvas Merged',
+          status: 'success',
+        })
+      );
+    },
+  });
+};
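The uuid-based filename in the listener above is what lets `take` pick out this upload's fulfilled action when several uploads may be in flight. The correlation idea in isolation (types loosened for brevity; names illustrative):

import { v4 as uuidv4 } from 'uuid';

// Tag the request with a unique name, then match only the action that
// carries that same name back in its meta.
const filename = `mergedCanvas_${uuidv4()}.png`;
const matchesThisUpload = (action: {
  meta?: { arg?: { formData?: { file?: File } } };
}) => action.meta?.arg?.formData?.file?.name === filename;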
@@ -0,0 +1,40 @@
+import { canvasSavedToGallery } from 'features/canvas/store/actions';
+import { startAppListening } from '..';
+import { log } from 'app/logging/useLogger';
+import { imageUploaded } from 'services/thunks/image';
+import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
+import { addToast } from 'features/system/store/systemSlice';
+
+const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
+
+export const addCanvasSavedToGalleryListener = () => {
+  startAppListening({
+    actionCreator: canvasSavedToGallery,
+    effect: async (action, { dispatch, getState }) => {
+      const state = getState();
+
+      const blob = await getBaseLayerBlob(state);
+
+      if (!blob) {
+        moduleLog.error('Problem getting base layer blob');
+        dispatch(
+          addToast({
+            title: 'Problem Saving Canvas',
+            description: 'Unable to export base layer',
+            status: 'error',
+          })
+        );
+        return;
+      }
+
+      dispatch(
+        imageUploaded({
+          imageType: 'results',
+          formData: {
+            file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }),
+          },
+        })
+      );
+    },
+  });
+};
@@ -3,6 +3,10 @@ import { startAppListening } from '..';
 import { uploadAdded } from 'features/gallery/store/uploadsSlice';
 import { imageSelected } from 'features/gallery/store/gallerySlice';
 import { imageUploaded } from 'services/thunks/image';
+import { addToast } from 'features/system/store/systemSlice';
+import { initialImageSelected } from 'features/parameters/store/actions';
+import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
+import { resultAdded } from 'features/gallery/store/resultsSlice';
 
 export const addImageUploadedListener = () => {
   startAppListening({
@@ -11,15 +15,32 @@ export const addImageUploadedListener = () => {
       action.payload.response.image_type !== 'intermediates',
     effect: (action, { dispatch, getState }) => {
       const { response } = action.payload;
+      const { imageType } = action.meta.arg;
+
       const state = getState();
       const image = deserializeImageResponse(response);
 
+      if (imageType === 'uploads') {
         dispatch(uploadAdded(image));
+
+        dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
+
         if (state.gallery.shouldAutoSwitchToNewImages) {
           dispatch(imageSelected(image));
         }
+
+        if (action.meta.arg.activeTabName === 'img2img') {
+          dispatch(initialImageSelected(image));
+        }
+
+        if (action.meta.arg.activeTabName === 'unifiedCanvas') {
+          dispatch(setInitialCanvasImage(image));
+        }
+      }
+
+      if (imageType === 'results') {
+        dispatch(resultAdded(image));
+      }
     },
   });
 };
@@ -2,11 +2,11 @@ import { initialImageChanged } from 'features/parameters/store/generationSlice';
 import { Image, isInvokeAIImage } from 'app/types/invokeai';
 import { selectResultsById } from 'features/gallery/store/resultsSlice';
 import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
-import { makeToast } from 'features/system/hooks/useToastWatcher';
 import { t } from 'i18next';
 import { addToast } from 'features/system/store/systemSlice';
 import { startAppListening } from '..';
 import { initialImageSelected } from 'features/parameters/store/actions';
+import { makeToast } from 'app/components/Toaster';
 
 export const addInitialImageSelectedListener = () => {
   startAppListening({
@@ -1,6 +1,6 @@
 import { startAppListening } from '..';
 import { sessionCreated, sessionInvoked } from 'services/thunks/session';
-import { buildCanvasGraphAndBlobs } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
+import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
 import { log } from 'app/logging/useLogger';
 import { canvasGraphBuilt } from 'features/nodes/store/actions';
 import { imageUploaded } from 'services/thunks/image';
@@ -11,9 +11,17 @@ import {
   stagingAreaInitialized,
 } from 'features/canvas/store/canvasSlice';
 import { userInvoked } from 'app/store/actions';
+import { getCanvasData } from 'features/canvas/util/getCanvasData';
+import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
+import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
+import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
 
 const moduleLog = log.child({ namespace: 'invoke' });
 
+/**
+ * This listener is responsible for building the canvas graph and blobs when the user invokes the canvas.
+ * It is also responsible for uploading the base and mask layers to the server.
+ */
 export const addUserInvokedCanvasListener = () => {
   startAppListening({
     predicate: (action): action is ReturnType<typeof userInvoked> =>
@@ -21,25 +29,49 @@ export const addUserInvokedCanvasListener = () => {
     effect: async (action, { getState, dispatch, take }) => {
       const state = getState();
 
-      const data = await buildCanvasGraphAndBlobs(state);
+      // Build canvas blobs
+      const canvasBlobsAndImageData = await getCanvasData(state);
 
-      if (!data) {
+      if (!canvasBlobsAndImageData) {
+        moduleLog.error('Unable to create canvas data');
+        return;
+      }
+
+      const { baseBlob, baseImageData, maskBlob, maskImageData } =
+        canvasBlobsAndImageData;
+
+      // Determine the generation mode
+      const generationMode = getCanvasGenerationMode(
+        baseImageData,
+        maskImageData
+      );
+
+      if (state.system.enableImageDebugging) {
+        const baseDataURL = await blobToDataURL(baseBlob);
+        const maskDataURL = await blobToDataURL(maskBlob);
+        openBase64ImageInTab([
+          { base64: maskDataURL, caption: 'mask b64' },
+          { base64: baseDataURL, caption: 'image b64' },
+        ]);
+      }
+
+      moduleLog.debug(`Generation mode: ${generationMode}`);
+
+      // Build the canvas graph
+      const graphComponents = await buildCanvasGraphComponents(
+        state,
+        generationMode
+      );
+
+      if (!graphComponents) {
         moduleLog.error('Problem building graph');
         return;
       }
 
-      const {
-        rangeNode,
-        iterateNode,
-        baseNode,
-        edges,
-        baseBlob,
-        maskBlob,
-        generationMode,
-      } = data;
+      const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
 
+      // Upload the base layer, to be used as init image
       const baseFilename = `${uuidv4()}.png`;
-      const maskFilename = `${uuidv4()}.png`;
 
       dispatch(
         imageUploaded({
@@ -66,6 +98,9 @@ export const addUserInvokedCanvasListener = () => {
         };
       }
 
+      // Upload the mask layer image
+      const maskFilename = `${uuidv4()}.png`;
+
       if (baseNode.type === 'inpaint') {
         dispatch(
           imageUploaded({
@@ -103,9 +138,12 @@ export const addUserInvokedCanvasListener = () => {
       dispatch(canvasGraphBuilt(graph));
       moduleLog({ data: graph }, 'Canvas graph built');
 
+      // Actually create the session
      dispatch(sessionCreated({ graph }));
 
+      // Wait for the session to be invoked (this is just the HTTP request to start processing)
       const [{ meta }] = await take(sessionInvoked.fulfilled.match);
 
       const { sessionId } = meta.arg;
 
       if (!state.canvas.layerState.stagingArea.boundingBox) {
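blobToDataURL is imported above without its definition appearing in this diff; the standard FileReader-based version looks like this:

// Assumed shape of the helper imported above; not taken from this diff.
export const blobToDataURL = (blob: Blob): Promise<string> =>
  new Promise((resolve, reject) => {
    const reader = new FileReader();
    reader.onload = () => resolve(reader.result as string);
    reader.onerror = () => reject(reader.error);
    reader.readAsDataURL(blob);
  });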
@@ -52,6 +52,7 @@ export type CommonGeneratedImageMetadata = {
   | 'lms'
   | 'pndm'
   | 'heun'
+  | 'heun_k'
   | 'euler'
   | 'euler_k'
   | 'euler_a'
172  invokeai/frontend/web/src/common/components/IAICustomSelect.tsx  Normal file
@@ -0,0 +1,172 @@
+import { CheckIcon } from '@chakra-ui/icons';
+import {
+  Box,
+  Flex,
+  FlexProps,
+  FormControl,
+  FormControlProps,
+  FormLabel,
+  Grid,
+  GridItem,
+  List,
+  ListItem,
+  Select,
+  Text,
+  Tooltip,
+  TooltipProps,
+} from '@chakra-ui/react';
+import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
+import { useSelect } from 'downshift';
+import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
+
+import { memo } from 'react';
+
+type IAICustomSelectProps = {
+  label?: string;
+  items: string[];
+  selectedItem: string;
+  setSelectedItem: (v: string | null | undefined) => void;
+  withCheckIcon?: boolean;
+  formControlProps?: FormControlProps;
+  buttonProps?: FlexProps;
+  tooltip?: string;
+  tooltipProps?: Omit<TooltipProps, 'children'>;
+};
+
+const IAICustomSelect = (props: IAICustomSelectProps) => {
+  const {
+    label,
+    items,
+    setSelectedItem,
+    selectedItem,
+    withCheckIcon,
+    formControlProps,
+    tooltip,
+    buttonProps,
+    tooltipProps,
+  } = props;
+
+  const {
+    isOpen,
+    getToggleButtonProps,
+    getLabelProps,
+    getMenuProps,
+    highlightedIndex,
+    getItemProps,
+  } = useSelect({
+    items,
+    selectedItem,
+    onSelectedItemChange: ({ selectedItem: newSelectedItem }) =>
+      setSelectedItem(newSelectedItem),
+  });
+
+  const { refs, floatingStyles } = useFloating<HTMLButtonElement>({
+    whileElementsMounted: autoUpdate,
+    middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
+  });
+
+  return (
+    <FormControl sx={{ w: 'full' }} {...formControlProps}>
+      {label && (
+        <FormLabel
+          {...getLabelProps()}
+          onClick={() => {
+            refs.floating.current && refs.floating.current.focus();
+          }}
+        >
+          {label}
+        </FormLabel>
+      )}
+      <Tooltip label={tooltip} {...tooltipProps}>
+        <Select
+          {...getToggleButtonProps({ ref: refs.setReference })}
+          {...buttonProps}
+          as={Flex}
+          sx={{
+            alignItems: 'center',
+            userSelect: 'none',
+            cursor: 'pointer',
+          }}
+        >
+          <Text sx={{ fontSize: 'sm', fontWeight: 500, color: 'base.100' }}>
+            {selectedItem}
+          </Text>
+        </Select>
+      </Tooltip>
+      <Box {...getMenuProps()}>
+        {isOpen && (
+          <List
+            as={Flex}
+            ref={refs.setFloating}
+            sx={{
+              ...floatingStyles,
+              width: 'max-content',
+              top: 0,
+              left: 0,
+              flexDirection: 'column',
+              zIndex: 1,
+              bg: 'base.800',
+              borderRadius: 'base',
+              border: '1px',
+              borderColor: 'base.700',
+              shadow: 'dark-lg',
+              py: 2,
+              px: 0,
+              h: 'fit-content',
+              maxH: 64,
+            }}
+          >
+            <OverlayScrollbarsComponent>
+              {items.map((item, index) => (
+                <ListItem
+                  sx={{
+                    bg: highlightedIndex === index ? 'base.700' : undefined,
+                    py: 1,
+                    paddingInlineStart: 3,
+                    paddingInlineEnd: 6,
+                    cursor: 'pointer',
+                    transitionProperty: 'common',
+                    transitionDuration: '0.15s',
+                  }}
+                  key={`${item}${index}`}
+                  {...getItemProps({ item, index })}
+                >
+                  {withCheckIcon ? (
+                    <Grid gridTemplateColumns="1.25rem auto">
+                      <GridItem>
+                        {selectedItem === item && <CheckIcon boxSize={2} />}
+                      </GridItem>
+                      <GridItem>
+                        <Text
+                          sx={{
+                            fontSize: 'sm',
+                            color: 'base.100',
+                            fontWeight: 500,
+                          }}
+                        >
+                          {item}
+                        </Text>
+                      </GridItem>
+                    </Grid>
+                  ) : (
+                    <Text
+                      sx={{
+                        fontSize: 'sm',
+                        color: 'base.100',
+                        fontWeight: 500,
+                      }}
+                    >
+                      {item}
+                    </Text>
+                  )}
+                </ListItem>
+              ))}
+            </OverlayScrollbarsComponent>
+          </List>
+        )}
+      </Box>
+    </FormControl>
+  );
+};
+
+export default memo(IAICustomSelect);
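A minimal usage sketch for the component above (the items and handler are illustrative):

import { useState } from 'react';
import IAICustomSelect from 'common/components/IAICustomSelect';

const SchedulerSelect = () => {
  const [scheduler, setScheduler] = useState('euler');

  return (
    <IAICustomSelect
      label="Scheduler"
      items={['euler', 'euler_k', 'ddim']}
      selectedItem={scheduler}
      setSelectedItem={(v) => v && setScheduler(v)}
      withCheckIcon
    />
  );
};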
@@ -5,6 +5,7 @@ import {
   Input,
   InputProps,
 } from '@chakra-ui/react';
+import { stopPastePropagation } from 'common/util/stopPastePropagation';
 import { ChangeEvent, memo } from 'react';
 
 interface IAIInputProps extends InputProps {
@@ -31,7 +32,7 @@ const IAIInput = (props: IAIInputProps) => {
       {...formControlProps}
     >
       {label !== '' && <FormLabel>{label}</FormLabel>}
-      <Input {...rest} />
+      <Input {...rest} onPaste={stopPastePropagation} />
     </FormControl>
   );
 };
@@ -14,6 +14,7 @@ import {
   Tooltip,
   TooltipProps,
 } from '@chakra-ui/react';
+import { stopPastePropagation } from 'common/util/stopPastePropagation';
 import { clamp } from 'lodash-es';
 
 import { FocusEvent, memo, useEffect, useState } from 'react';
@@ -125,6 +126,7 @@ const IAINumberInput = (props: Props) => {
         onChange={handleOnChange}
         onBlur={handleBlur}
         {...rest}
+        onPaste={stopPastePropagation}
       >
         <NumberInputField {...numberInputFieldProps} />
         {showStepper && (
@@ -0,0 +1,9 @@
+import { Textarea, TextareaProps, forwardRef } from '@chakra-ui/react';
+import { stopPastePropagation } from 'common/util/stopPastePropagation';
+import { memo } from 'react';
+
+const IAITextarea = forwardRef((props: TextareaProps, ref) => {
+  return <Textarea ref={ref} onPaste={stopPastePropagation} {...props} />;
+});
+
+export default memo(IAITextarea);
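stopPastePropagation is used across these inputs but not defined in this diff; given its usage, it is presumably a one-liner that keeps paste events inside form fields from bubbling up to the document-level paste-to-upload listener:

import { ClipboardEvent } from 'react';

// Assumed shape of the helper; not taken from this diff.
export const stopPastePropagation = (e: ClipboardEvent) => {
  e.stopPropagation();
};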
@@ -1,4 +1,4 @@
-import { Box, useToast } from '@chakra-ui/react';
+import { Box } from '@chakra-ui/react';
 import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import useImageUploader from 'common/hooks/useImageUploader';
@@ -10,12 +10,33 @@ import {
   ReactNode,
   useCallback,
   useEffect,
+  useMemo,
+  useRef,
   useState,
 } from 'react';
 import { FileRejection, useDropzone } from 'react-dropzone';
 import { useTranslation } from 'react-i18next';
 import { imageUploaded } from 'services/thunks/image';
 import ImageUploadOverlay from './ImageUploadOverlay';
+import { useAppToaster } from 'app/components/Toaster';
+import { filter, map, some } from 'lodash-es';
+import { createSelector } from '@reduxjs/toolkit';
+import { systemSelector } from 'features/system/store/systemSelectors';
+import { ErrorCode } from 'react-dropzone';
+
+const selector = createSelector(
+  [systemSelector, activeTabNameSelector],
+  (system, activeTabName) => {
+    const { isConnected, isUploading } = system;
+
+    const isUploaderDisabled = !isConnected || isUploading;
+
+    return {
+      isUploaderDisabled,
+      activeTabName,
+    };
+  }
+);
 
 type ImageUploaderProps = {
   children: ReactNode;
@@ -24,38 +45,49 @@ type ImageUploaderProps = {
 const ImageUploader = (props: ImageUploaderProps) => {
   const { children } = props;
   const dispatch = useAppDispatch();
-  const activeTabName = useAppSelector(activeTabNameSelector);
-  const toast = useToast({});
+  const { isUploaderDisabled, activeTabName } = useAppSelector(selector);
+  const toaster = useAppToaster();
   const { t } = useTranslation();
   const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
-  const { setOpenUploader } = useImageUploader();
+  const { setOpenUploaderFunction } = useImageUploader();
 
   const fileRejectionCallback = useCallback(
     (rejection: FileRejection) => {
       setIsHandlingUpload(true);
-      const msg = rejection.errors.reduce(
-        (acc: string, cur: { message: string }) => `${acc}\n${cur.message}`,
-        ''
-      );
-      toast({
+
+      toaster({
         title: t('toast.uploadFailed'),
-        description: msg,
+        description: rejection.errors.map((error) => error.message).join('\n'),
         status: 'error',
-        isClosable: true,
       });
     },
-    [t, toast]
+    [t, toaster]
   );
 
   const fileAcceptedCallback = useCallback(
     async (file: File) => {
-      dispatch(imageUploaded({ imageType: 'uploads', formData: { file } }));
+      dispatch(
+        imageUploaded({
+          imageType: 'uploads',
+          formData: { file },
+          activeTabName,
+        })
+      );
     },
-    [dispatch]
+    [dispatch, activeTabName]
   );
 
   const onDrop = useCallback(
     (acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
+      if (fileRejections.length > 1) {
+        toaster({
+          title: t('toast.uploadFailed'),
+          description: t('toast.uploadFailedInvalidUploadDesc'),
+          status: 'error',
+        });
+        return;
+      }
+
       fileRejections.forEach((rejection: FileRejection) => {
        fileRejectionCallback(rejection);
       });
@@ -64,7 +96,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
         fileAcceptedCallback(file);
       });
     },
-    [fileAcceptedCallback, fileRejectionCallback]
+    [t, toaster, fileAcceptedCallback, fileRejectionCallback]
   );
 
   const {
@@ -73,73 +105,55 @@ const ImageUploader = (props: ImageUploaderProps) => {
     isDragAccept,
     isDragReject,
     isDragActive,
+    inputRef,
     open,
   } = useDropzone({
     accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
     noClick: true,
     onDrop,
     onDragOver: () => setIsHandlingUpload(true),
-    maxFiles: 1,
+    disabled: isUploaderDisabled,
+    multiple: false,
   });
 
-  setOpenUploader(open);
-
   useEffect(() => {
-    const pasteImageListener = (e: ClipboardEvent) => {
-      const dataTransferItemList = e.clipboardData?.items;
-      if (!dataTransferItemList) return;
-
-      const imageItems: Array<DataTransferItem> = [];
-
-      for (const item of dataTransferItemList) {
-        if (
-          item.kind === 'file' &&
-          ['image/png', 'image/jpg'].includes(item.type)
-        ) {
-          imageItems.push(item);
-        }
-      }
-
-      if (!imageItems.length) return;
-
-      e.stopImmediatePropagation();
-
-      if (imageItems.length > 1) {
-        toast({
-          description: t('toast.uploadFailedMultipleImagesDesc'),
-          status: 'error',
-          isClosable: true,
-        });
-        return;
-      }
-
-      const file = imageItems[0].getAsFile();
-
-      if (!file) {
-        toast({
-          description: t('toast.uploadFailedUnableToLoadDesc'),
-          status: 'error',
-          isClosable: true,
-        });
-        return;
-      }
-
-      dispatch(imageUploaded({ imageType: 'uploads', formData: { file } }));
+    // This is a hack to allow pasting images into the uploader
+    const handlePaste = async (e: ClipboardEvent) => {
+      if (!inputRef.current) {
+        return;
+      }
+
+      if (e.clipboardData?.files) {
+        // Set the files on the inputRef
+        inputRef.current.files = e.clipboardData.files;
+        // Dispatch the change event, dropzone catches this and we get to use its own validation
+        inputRef.current?.dispatchEvent(new Event('change', { bubbles: true }));
+      }
     };
-    document.addEventListener('paste', pasteImageListener);
+
+    // Set the open function so we can open the uploader from anywhere
+    setOpenUploaderFunction(open);
+
+    // Add the paste event listener
+    document.addEventListener('paste', handlePaste);
+
     return () => {
-      document.removeEventListener('paste', pasteImageListener);
+      document.removeEventListener('paste', handlePaste);
+      setOpenUploaderFunction(() => {
+        return;
+      });
     };
-  }, [t, dispatch, toast, activeTabName]);
+  }, [inputRef, open, setOpenUploaderFunction]);
 
-  const overlaySecondaryText = ['img2img', 'unifiedCanvas'].includes(
-    activeTabName
-  )
-    ? ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`
-    : ``;
+  const overlaySecondaryText = useMemo(() => {
+    if (['img2img', 'unifiedCanvas'].includes(activeTabName)) {
+      return ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`;
+    }
+
+    return '';
+  }, [t, activeTabName]);
 
   return (
-    <ImageUploaderTriggerContext.Provider value={open}>
     <Box
       {...getRootProps({ style: {} })}
      onKeyDown={(e: KeyboardEvent) => {
@@ -158,7 +172,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
         />
       )}
     </Box>
-    </ImageUploaderTriggerContext.Provider>
   );
 };
 
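The paste hack above, in isolation: clipboard files are assigned directly to the dropzone's hidden input, and a synthetic change event lets react-dropzone run its own validation on them. A stripped-down sketch (names illustrative):

const forwardPasteToInput = (e: ClipboardEvent, input: HTMLInputElement) => {
  if (e.clipboardData?.files?.length) {
    input.files = e.clipboardData.files;
    input.dispatchEvent(new Event('change', { bubbles: true }));
  }
};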
@@ -1,6 +1,5 @@
 import { Flex, Heading, Icon } from '@chakra-ui/react';
-import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
-import { useContext } from 'react';
+import useImageUploader from 'common/hooks/useImageUploader';
 import { FaUpload } from 'react-icons/fa';
 
 type ImageUploaderButtonProps = {
@@ -9,11 +8,7 @@ type ImageUploaderButtonProps = {
 
 const ImageUploaderButton = (props: ImageUploaderButtonProps) => {
   const { styleClass } = props;
-  const open = useContext(ImageUploaderTriggerContext);
-
-  const handleClickUpload = () => {
-    open && open();
-  };
+  const { openUploader } = useImageUploader();
 
   return (
     <Flex
@@ -26,7 +21,7 @@ const ImageUploaderButton = (props: ImageUploaderButtonProps) => {
       className={styleClass}
     >
       <Flex
-        onClick={handleClickUpload}
+        onClick={openUploader}
         sx={{
           display: 'flex',
           flexDirection: 'column',
|
|||||||
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
|
|
||||||
import { useContext } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaUpload } from 'react-icons/fa';
|
import { FaUpload } from 'react-icons/fa';
|
||||||
import IAIIconButton from './IAIIconButton';
|
import IAIIconButton from './IAIIconButton';
|
||||||
|
import useImageUploader from 'common/hooks/useImageUploader';
|
||||||
|
|
||||||
const ImageUploaderIconButton = () => {
|
const ImageUploaderIconButton = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const openImageUploader = useContext(ImageUploaderTriggerContext);
|
const { openUploader } = useImageUploader();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
aria-label={t('accessibility.uploadImage')}
|
aria-label={t('accessibility.uploadImage')}
|
||||||
tooltip="Upload Image"
|
tooltip="Upload Image"
|
||||||
icon={<FaUpload />}
|
icon={<FaUpload />}
|
||||||
onClick={openImageUploader || undefined}
|
onClick={openUploader}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@@ -6,10 +6,12 @@ import { FaUndo, FaUpload } from 'react-icons/fa';
 import { useAppDispatch } from 'app/store/storeHooks';
 import { useCallback } from 'react';
 import { clearInitialImage } from 'features/parameters/store/generationSlice';
+import useImageUploader from 'common/hooks/useImageUploader';
 
 const InitialImageButtons = () => {
   const dispatch = useAppDispatch();
   const { t } = useTranslation();
+  const { openUploader } = useImageUploader();
 
   const handleResetInitialImage = useCallback(() => {
     dispatch(clearInitialImage());
@@ -27,7 +29,11 @@ const InitialImageButtons = () => {
         aria-label={t('accessibility.reset')}
         onClick={handleResetInitialImage}
       />
-      <IAIIconButton icon={<FaUpload />} aria-label={t('common.upload')} />
+      <IAIIconButton
+        icon={<FaUpload />}
+        onClick={openUploader}
+        aria-label={t('common.upload')}
+      />
     </ButtonGroup>
   </Flex>
 );
@@ -24,7 +24,6 @@ const Loading = () => {
       height="24px !important"
       right="1.5rem"
       bottom="1.5rem"
-      speed="1.2s"
     />
   </Flex>
 );
@@ -1,29 +0,0 @@
-import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
-import { useTranslation } from 'react-i18next';
-import WorkInProgress from './WorkInProgress';
-
-export const PostProcessingWIP = () => {
-  const { t } = useTranslation();
-  return (
-    <WorkInProgress>
-      <Flex
-        sx={{
-          flexDirection: 'column',
-          alignItems: 'center',
-          justifyContent: 'center',
-          w: '100%',
-          h: '100%',
-          gap: 4,
-          textAlign: 'center',
-        }}
-      >
-        <Heading>{t('common.postProcessing')}</Heading>
-        <VStack maxW="50rem" gap={4}>
-          <Text>{t('common.postProcessDesc1')}</Text>
-          <Text>{t('common.postProcessDesc2')}</Text>
-          <Text>{t('common.postProcessDesc3')}</Text>
-        </VStack>
-      </Flex>
-    </WorkInProgress>
-  );
-};
@@ -1,28 +0,0 @@
-import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
-import { useTranslation } from 'react-i18next';
-import WorkInProgress from './WorkInProgress';
-
-export default function TrainingWIP() {
-  const { t } = useTranslation();
-  return (
-    <WorkInProgress>
-      <Flex
-        sx={{
-          flexDirection: 'column',
-          alignItems: 'center',
-          justifyContent: 'center',
-          w: '100%',
-          h: '100%',
-          gap: 4,
-          textAlign: 'center',
-        }}
-      >
-        <Heading>{t('common.training')}</Heading>
-        <VStack maxW="50rem" gap={4}>
-          <Text>{t('common.trainingDesc1')}</Text>
-          <Text>{t('common.trainingDesc2')}</Text>
-        </VStack>
-      </Flex>
-    </WorkInProgress>
-  );
-}
@@ -1,26 +0,0 @@
-import { Flex } from '@chakra-ui/react';
-import { ReactNode } from 'react';
-
-type WorkInProgressProps = {
-  children: ReactNode;
-};
-
-const WorkInProgress = (props: WorkInProgressProps) => {
-  const { children } = props;
-
-  return (
-    <Flex
-      sx={{
-        width: '100%',
-        height: '100%',
-        bg: 'base.850',
-        borderRadius: 'base',
-        position: 'relative',
-      }}
-    >
-      {children}
-    </Flex>
-  );
-};
-
-export default WorkInProgress;
@@ -1,35 +0,0 @@
-import { RefObject, useEffect } from 'react';
-
-const watchers: {
-  ref: RefObject<HTMLElement>;
-  enable: boolean;
-  callback: () => void;
-}[] = [];
-
-const useClickOutsideWatcher = () => {
-  useEffect(() => {
-    function handleClickOutside(e: MouseEvent) {
-      watchers.forEach(({ ref, enable, callback }) => {
-        if (enable && ref.current && !ref.current.contains(e.target as Node)) {
-          callback();
-        }
-      });
-    }
-    document.addEventListener('mousedown', handleClickOutside);
-    return () => {
-      document.removeEventListener('mousedown', handleClickOutside);
-    };
-  }, []);
-
-  return {
-    addWatcher: (watcher: {
-      ref: RefObject<HTMLElement>;
-      callback: () => void;
-      enable: boolean;
-    }) => {
-      watchers.push(watcher);
-    },
-  };
-};
-
-export default useClickOutsideWatcher;
@@ -1,13 +1,22 @@
-let openFunction: () => void;
+import { useCallback } from 'react';
+
+let openUploader = () => {
+  return;
+};
 
 const useImageUploader = () => {
-  return {
-    setOpenUploader: (open?: () => void) => {
-      if (open) {
-        openFunction = open;
+  const setOpenUploaderFunction = useCallback(
+    (openUploaderFunction?: () => void) => {
+      if (openUploaderFunction) {
+        openUploader = openUploaderFunction;
       }
     },
-    openUploader: openFunction,
+    []
+  );
+
+  return {
+    setOpenUploaderFunction,
+    openUploader,
   };
 };
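A usage sketch for the reworked hook: ImageUploader registers dropzone's `open` via setOpenUploaderFunction on mount, after which any component can trigger the file picker (the consuming component is illustrative):

import useImageUploader from 'common/hooks/useImageUploader';

const UploadMenuItem = () => {
  const { openUploader } = useImageUploader();

  return <button onClick={openUploader}>Upload image…</button>;
};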
@@ -1,17 +0,0 @@
-import React from 'react';
-import { useTranslation } from 'react-i18next';
-
-export default function useUpdateTranslations(fn: () => void) {
-  const { i18n } = useTranslation();
-  const currentLang = localStorage.getItem('i18nextLng');
-
-  React.useEffect(() => {
-    fn();
-  }, [fn]);
-
-  React.useEffect(() => {
-    i18n.on('languageChanged', () => {
-      fn();
-    });
-  }, [fn, i18n, currentLang]);
-}
@@ -1,20 +0,0 @@
-import { createIcon } from '@chakra-ui/react';
-
-const ImageToImageIcon = createIcon({
-  displayName: 'ImageToImageIcon',
-  viewBox: '0 0 3543 3543',
-  path: (
-    <g transform="matrix(1.10943,0,0,1.10943,-206.981,-213.533)">
-      <path
-        fill="currentColor"
-        fillRule="evenodd"
-        clipRule="evenodd"
-        d="M688.533,2405.95L542.987,2405.95C349.532,2405.95 192.47,2248.89 192.47,2055.44L192.47,542.987C192.47,349.532 349.532,192.47 542.987,192.47L2527.88,192.47C2721.33,192.47 2878.4,349.532 2878.4,542.987L2878.4,1172.79L3023.94,1172.79C3217.4,1172.79 3374.46,1329.85 3374.46,1523.3C3374.46,1523.3 3374.46,3035.75 3374.46,3035.75C3374.46,3229.21 3217.4,3386.27 3023.94,3386.27L1039.05,3386.27C845.595,3386.27 688.533,3229.21 688.533,3035.75L688.533,2405.95ZM3286.96,2634.37L3286.96,1523.3C3286.96,1378.14 3169.11,1260.29 3023.94,1260.29C3023.94,1260.29 1039.05,1260.29 1039.05,1260.29C893.887,1260.29 776.033,1378.14 776.033,1523.3L776.033,2489.79L1440.94,1736.22L2385.83,2775.59L2880.71,2200.41L3286.96,2634.37ZM2622.05,1405.51C2778.5,1405.51 2905.51,1532.53 2905.51,1688.98C2905.51,1845.42 2778.5,1972.44 2622.05,1972.44C2465.6,1972.44 2338.58,1845.42 2338.58,1688.98C2338.58,1532.53 2465.6,1405.51 2622.05,1405.51ZM2790.9,1172.79L1323.86,1172.79L944.882,755.906L279.97,1509.47L279.97,542.987C279.97,397.824 397.824,279.97 542.987,279.97C542.987,279.97 2527.88,279.97 2527.88,279.97C2673.04,279.97 2790.9,397.824 2790.9,542.987L2790.9,1172.79ZM2125.98,425.197C2282.43,425.197 2409.45,552.213 2409.45,708.661C2409.45,865.11 2282.43,992.126 2125.98,992.126C1969.54,992.126 1842.52,865.11 1842.52,708.661C1842.52,552.213 1969.54,425.197 2125.98,425.197Z"
-      />
-    </g>
-  ),
-  defaultProps: {
-    boxSize: '24px',
-  },
-});
-export default ImageToImageIcon;
File diff suppressed because one or more lines are too long
@@ -1,19 +0,0 @@
-import { createIcon } from '@chakra-ui/react';
-
-const NodesIcon = createIcon({
-  displayName: 'NodesIcon',
-  viewBox: '0 0 3543 3543',
-  path: (
-    <path
-      fill="currentColor"
-      fillRule="evenodd"
-      clipRule="evenodd"
-      d="M3543.31,770.787C3543.31,515.578 3336.11,308.38 3080.9,308.38L462.407,308.38C207.197,308.38 0,515.578 0,770.787L0,2766.03C0,3021.24 207.197,3228.44 462.407,3228.44L3080.9,3228.44C3336.11,3228.44 3543.31,3021.24 3543.31,2766.03C3543.31,2766.03 3543.31,770.787 3543.31,770.787ZM3427.88,770.787L3427.88,2766.03C3427.88,2957.53 3272.4,3113.01 3080.9,3113.01C3080.9,3113.01 462.407,3113.01 462.407,3113.01C270.906,3113.01 115.431,2957.53 115.431,2766.03L115.431,770.787C115.431,579.286 270.906,423.812 462.407,423.812L3080.9,423.812C3272.4,423.812 3427.88,579.286 3427.88,770.787ZM1214.23,1130.69L1321.47,1130.69C1324.01,1130.69 1326.54,1130.53 1329.05,1130.2C1329.05,1130.2 1367.3,1125.33 1397.94,1149.8C1421.63,1168.72 1437.33,1204.3 1437.33,1265.48L1437.33,2078.74L1220.99,2078.74C1146.83,2078.74 1086.61,2138.95 1086.61,2213.12L1086.61,2762.46C1086.61,2836.63 1146.83,2896.84 1220.99,2896.84L1770.34,2896.84C1844.5,2896.84 1904.71,2836.63 1904.71,2762.46L1904.71,2213.12C1904.71,2138.95 1844.5,2078.74 1770.34,2078.74L1554,2078.74L1554,1604.84C1625.84,1658.19 1703.39,1658.1 1703.39,1658.1C1703.54,1658.1 1703.69,1658.11 1703.84,1658.11L2362.2,1658.11L2362.2,1874.44C2362.2,1948.61 2422.42,2008.82 2496.58,2008.82L3045.93,2008.82C3120.09,2008.82 3180.3,1948.61 3180.3,1874.44L3180.3,1325.1C3180.3,1250.93 3120.09,1190.72 3045.93,1190.72L2496.58,1190.72C2422.42,1190.72 2362.2,1250.93 2362.2,1325.1L2362.2,1558.97L2362.2,1541.44L1704.23,1541.44C1702.2,1541.37 1650.96,1539.37 1609.51,1499.26C1577.72,1468.49 1554,1416.47 1554,1331.69L1554,1265.48C1554,1153.86 1513.98,1093.17 1470.76,1058.64C1411.24,1011.1 1338.98,1012.58 1319.15,1014.03L1214.23,1014.03L1214.23,796.992C1214.23,722.828 1154.02,662.617 1079.85,662.617L530.507,662.617C456.343,662.617 396.131,722.828 396.131,796.992L396.131,1346.34C396.131,1420.5 456.343,1480.71 530.507,1480.71L1079.85,1480.71C1154.02,1480.71 1214.23,1420.5 1214.23,1346.34L1214.23,1130.69Z"
-    />
-  ),
-  defaultProps: {
-    boxSize: '24px',
-  },
-});
-
-export default NodesIcon;
File diff suppressed because one or more lines are too long
@ -1,19 +0,0 @@
|
|||||||
import { createIcon } from '@chakra-ui/react';
|
|
||||||
|
|
||||||
const PostprocessingIcon = createIcon({
|
|
||||||
displayName: 'PostprocessingIcon',
|
|
||||||
viewBox: '0 0 3543 3543',
|
|
||||||
path: (
|
|
||||||
<path
|
|
||||||
fill="currentColor"
|
|
||||||
fillRule="evenodd"
|
|
||||||
clipRule="evenodd"
|
|
||||||
d="M709.477,1596.53L992.591,1275.66L2239.09,2646.81L2891.95,1888.03L3427.88,2460.51L3427.88,994.78C3427.88,954.66 3421.05,916.122 3408.5,880.254L3521.9,855.419C3535.8,899.386 3543.31,946.214 3543.31,994.78L3543.31,2990.02C3543.31,3245.23 3336.11,3452.43 3080.9,3452.43C3080.9,3452.43 462.407,3452.43 462.407,3452.43C207.197,3452.43 -0,3245.23 -0,2990.02L-0,994.78C-0,739.571 207.197,532.373 462.407,532.373L505.419,532.373L504.644,532.546L807.104,600.085C820.223,601.729 832.422,607.722 841.77,617.116C850.131,625.517 855.784,636.21 858.055,647.804L462.407,647.804C270.906,647.804 115.431,803.279 115.431,994.78L115.431,2075.73L-0,2101.5L115.431,2127.28L115.431,2269.78L220.47,2150.73L482.345,2209.21C503.267,2211.83 522.722,2221.39 537.63,2236.37C552.538,2251.35 562.049,2270.9 564.657,2291.93L671.84,2776.17L779.022,2291.93C781.631,2270.9 791.141,2251.35 806.05,2236.37C820.958,2221.39 840.413,2211.83 861.334,2209.21L1353.15,2101.5L861.334,1993.8C840.413,1991.18 820.958,1981.62 806.05,1966.64C791.141,1951.66 781.631,1932.11 779.022,1911.08L709.477,1596.53ZM671.84,1573.09L725.556,2006.07C726.863,2016.61 731.63,2026.4 739.101,2033.91C746.573,2041.42 756.323,2046.21 766.808,2047.53L1197.68,2101.5L766.808,2155.48C756.323,2156.8 746.573,2161.59 739.101,2169.09C731.63,2176.6 726.863,2186.4 725.556,2196.94L671.84,2629.92L618.124,2196.94C616.817,2186.4 612.05,2176.6 604.579,2169.09C597.107,2161.59 587.357,2156.8 576.872,2155.48L146.001,2101.5L576.872,2047.53C587.357,2046.21 597.107,2041.42 604.579,2033.91C612.05,2026.4 616.817,2016.61 618.124,2006.07L671.84,1573.09ZM609.035,1710.36L564.657,1911.08C562.049,1932.11 552.538,1951.66 537.63,1966.64C522.722,1981.62 503.267,1991.18 482.345,1993.8L328.665,2028.11L609.035,1710.36ZM2297.12,938.615L2451.12,973.003C2480.59,976.695 2507.99,990.158 2528.99,1011.26C2549.99,1032.37 2563.39,1059.9 2567.07,1089.52L2672.73,1566.9C2634.5,1580.11 2593.44,1587.29 2550.72,1587.29C2344.33,1587.29 2176.77,1419.73 2176.77,1213.34C2176.77,1104.78 2223.13,1006.96 2297.12,938.615ZM2718.05,76.925L2793.72,686.847C2795.56,701.69 2802.27,715.491 2812.8,726.068C2823.32,736.644 2837.06,743.391 2851.83,745.242L3458.78,821.28L2851.83,897.318C2837.06,899.168 2823.32,905.916 2812.8,916.492C2802.27,927.068 2795.56,940.87 2793.72,955.712L2718.05,1565.63L2642.38,955.712C2640.54,940.87 2633.83,927.068 2623.3,916.492C2612.78,905.916 2599.04,899.168 2584.27,897.318L1977.32,821.28L2584.27,745.242C2599.04,743.391 2612.78,736.644 2623.3,726.068C2633.83,715.491 2640.54,701.69 2642.38,686.847L2718.05,76.925ZM2883.68,1043.06C2909.88,1094.13 2924.67,1152.02 2924.67,1213.34C2924.67,1335.4 2866.06,1443.88 2775.49,1512.14L2869.03,1089.52C2871.07,1073.15 2876.07,1057.42 2883.68,1043.06ZM925.928,201.2L959.611,472.704C960.431,479.311 963.42,485.455 968.105,490.163C972.79,494.871 978.904,497.875 985.479,498.698L1255.66,532.546L985.479,566.395C978.904,567.218 972.79,570.222 968.105,574.93C963.42,579.638 960.431,585.781 959.611,592.388L925.928,863.893L892.245,592.388C891.425,585.781 888.436,579.638 883.751,574.93C879.066,570.222 872.952,567.218 866.378,566.395L596.195,532.546L866.378,498.698C872.952,497.875 879.066,494.871 883.751,490.163C888.436,485.455 891.425,479.311 892.245,472.704L925.928,201.2ZM2864.47,532.373L3080.9,532.373C3258.7,532.373 3413.2,632.945 3490.58,780.281L3319.31,742.773C3257.14,683.925 3173.2,647.804 3080.9,647.804L2927.07,647.804C2919.95,642.994 2913.25,637.473 2907.11,631.298C2886.11,610.194 2872.71,582.655 
2869.03,553.04L2864.47,532.373ZM1352.36,532.373L2571.64,532.373L2567.07,553.04C2563.39,582.655 2549.99,610.194 2528.99,631.298C2522.85,637.473 2516.16,642.994 2509.03,647.804L993.801,647.804C996.072,636.21 1001.73,625.517 1010.09,617.116C1019.43,607.722 1031.63,601.729 1044.75,600.085L1353.15,532.546L1352.36,532.373Z"
/>
),
defaultProps: {
boxSize: '24px',
},
});

export default PostprocessingIcon;
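Components built with createIcon, like the one above, accept the standard Chakra UI Icon props, so the '24px' defaultProps.boxSize can be overridden wherever the icon is rendered. A minimal consumption sketch, assuming a mounted ChakraProvider; the App wrapper and the chosen prop values are illustrative, not taken from this commit:

import { ChakraProvider } from '@chakra-ui/react';
import PostprocessingIcon from './PostprocessingIcon';

// boxSize overrides the '24px' defaultProp; the color token drives
// fill="currentColor" inside the icon's path.
const App = () => (
  <ChakraProvider>
    <PostprocessingIcon boxSize="32px" color="gray.300" />
  </ChakraProvider>
);

export default App;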
File diff suppressed because one or more lines are too long
@ -1,19 +0,0 @@
import { createIcon } from '@chakra-ui/react';

const TrainingIcon = createIcon({
displayName: 'TrainingIcon',
viewBox: '0 0 3544 3544',
path: (
<path
fill="currentColor"
fillRule="evenodd"
clipRule="evenodd"
d="M0,768.593L0,2774.71C0,2930.6 78.519,3068.3 198.135,3150.37C273.059,3202.68 364.177,3233.38 462.407,3233.38C462.407,3233.38 3080.9,3233.38 3080.9,3233.38C3179.13,3233.38 3270.25,3202.68 3345.17,3150.37C3464.79,3068.3 3543.31,2930.6 3543.31,2774.71L3543.31,768.593C3543.31,517.323 3339.31,313.324 3088.04,313.324L455.269,313.324C203.999,313.324 0,517.323 0,768.593ZM3427.88,775.73L3427.88,2770.97C3427.88,2962.47 3272.4,3117.95 3080.9,3117.95L462.407,3117.95C270.906,3117.95 115.431,2962.47 115.431,2770.97C115.431,2770.97 115.431,775.73 115.431,775.73C115.431,584.229 270.906,428.755 462.407,428.755C462.407,428.755 3080.9,428.755 3080.9,428.755C3272.4,428.755 3427.88,584.229 3427.88,775.73ZM796.24,1322.76L796.24,1250.45C796.24,1199.03 836.16,1157.27 885.331,1157.27C885.331,1157.27 946.847,1157.27 946.847,1157.27C996.017,1157.27 1035.94,1199.03 1035.94,1250.45L1035.94,1644.81L2507.37,1644.81L2507.37,1250.45C2507.37,1199.03 2547.29,1157.27 2596.46,1157.27C2596.46,1157.27 2657.98,1157.27 2657.98,1157.27C2707.15,1157.27 2747.07,1199.03 2747.07,1250.45L2747.07,1322.76C2756.66,1319.22 2767.02,1317.29 2777.83,1317.29C2777.83,1317.29 2839.34,1317.29 2839.34,1317.29C2888.51,1317.29 2928.43,1357.21 2928.43,1406.38L2928.43,1527.32C2933.51,1526.26 2938.77,1525.71 2944.16,1525.71L2995.3,1525.71C3036.18,1525.71 3069.37,1557.59 3069.37,1596.86C3069.37,1596.86 3069.37,1946.44 3069.37,1946.44C3069.37,1985.72 3036.18,2017.6 2995.3,2017.6C2995.3,2017.6 2944.16,2017.6 2944.16,2017.6C2938.77,2017.6 2933.51,2017.04 2928.43,2015.99L2928.43,2136.92C2928.43,2186.09 2888.51,2226.01 2839.34,2226.01L2777.83,2226.01C2767.02,2226.01 2756.66,2224.08 2747.07,2220.55L2747.07,2292.85C2747.07,2344.28 2707.15,2386.03 2657.98,2386.03C2657.98,2386.03 2596.46,2386.03 2596.46,2386.03C2547.29,2386.03 2507.37,2344.28 2507.37,2292.85L2507.37,1898.5L1035.94,1898.5L1035.94,2292.85C1035.94,2344.28 996.017,2386.03 946.847,2386.03C946.847,2386.03 885.331,2386.03 885.331,2386.03C836.16,2386.03 796.24,2344.28 796.24,2292.85L796.24,2220.55C786.651,2224.08 776.29,2226.01 765.482,2226.01L703.967,2226.01C654.796,2226.01 614.876,2186.09 614.876,2136.92L614.876,2015.99C609.801,2017.04 604.539,2017.6 599.144,2017.6C599.144,2017.6 548.003,2017.6 548.003,2017.6C507.125,2017.6 473.937,1985.72 473.937,1946.44C473.937,1946.44 473.937,1596.86 473.937,1596.86C473.937,1557.59 507.125,1525.71 548.003,1525.71L599.144,1525.71C604.539,1525.71 609.801,1526.26 614.876,1527.32L614.876,1406.38C614.876,1357.21 654.796,1317.29 703.967,1317.29C703.967,1317.29 765.482,1317.29 765.482,1317.29C776.29,1317.29 786.651,1319.22 796.24,1322.76ZM977.604,1250.45C977.604,1232.7 963.822,1218.29 946.847,1218.29L885.331,1218.29C868.355,1218.29 854.573,1232.7 854.573,1250.45L854.573,2292.85C854.573,2310.61 868.355,2325.02 885.331,2325.02L946.847,2325.02C963.822,2325.02 977.604,2310.61 977.604,2292.85L977.604,1250.45ZM2565.7,1250.45C2565.7,1232.7 2579.49,1218.29 2596.46,1218.29L2657.98,1218.29C2674.95,1218.29 2688.73,1232.7 2688.73,1250.45L2688.73,2292.85C2688.73,2310.61 2674.95,2325.02 2657.98,2325.02L2596.46,2325.02C2579.49,2325.02 2565.7,2310.61 2565.7,2292.85L2565.7,1250.45ZM673.209,1406.38L673.209,2136.92C673.209,2153.9 686.991,2167.68 703.967,2167.68L765.482,2167.68C782.458,2167.68 796.24,2153.9 796.24,2136.92L796.24,1406.38C796.24,1389.41 782.458,1375.63 765.482,1375.63L703.967,1375.63C686.991,1375.63 673.209,1389.41 673.209,1406.38ZM2870.1,1406.38L2870.1,2136.92C2870.1,2153.9 2856.32,2167.68 2839.34,2167.68L2777.83,2167.68C2760.85,2167.68 2747.07,2153.9 
2747.07,2136.92L2747.07,1406.38C2747.07,1389.41 2760.85,1375.63 2777.83,1375.63L2839.34,1375.63C2856.32,1375.63 2870.1,1389.41 2870.1,1406.38ZM614.876,1577.5C610.535,1574.24 605.074,1572.3 599.144,1572.3L548.003,1572.3C533.89,1572.3 522.433,1583.3 522.433,1596.86L522.433,1946.44C522.433,1960 533.89,1971.01 548.003,1971.01L599.144,1971.01C605.074,1971.01 610.535,1969.07 614.876,1965.81L614.876,1577.5ZM2928.43,1965.81L2928.43,1577.5C2932.77,1574.24 2938.23,1572.3 2944.16,1572.3L2995.3,1572.3C3009.42,1572.3 3020.87,1583.3 3020.87,1596.86L3020.87,1946.44C3020.87,1960 3009.42,1971.01 2995.3,1971.01L2944.16,1971.01C2938.23,1971.01 2932.77,1969.07 2928.43,1965.81ZM2507.37,1703.14L1035.94,1703.14L1035.94,1840.16L2507.37,1840.16L2507.37,1898.38L2507.37,1659.46L2507.37,1703.14Z"
/>
),
defaultProps: {
boxSize: '24px',
},
});

export default TrainingIcon;
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg width="100%" height="100%" viewBox="0 0 3543 3543" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:space="preserve" xmlns:serif="http://www.serif.com/" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linejoin:round;stroke-miterlimit:2;">
<g transform="matrix(1.10943,0,0,1.10943,-206.981,-213.533)">
<path d="M688.533,2405.95L542.987,2405.95C349.532,2405.95 192.47,2248.89 192.47,2055.44L192.47,542.987C192.47,349.532 349.532,192.47 542.987,192.47L2527.88,192.47C2721.33,192.47 2878.4,349.532 2878.4,542.987L2878.4,1172.79L3023.94,1172.79C3217.4,1172.79 3374.46,1329.85 3374.46,1523.3C3374.46,1523.3 3374.46,3035.75 3374.46,3035.75C3374.46,3229.21 3217.4,3386.27 3023.94,3386.27L1039.05,3386.27C845.595,3386.27 688.533,3229.21 688.533,3035.75L688.533,2405.95ZM3286.96,2634.37L3286.96,1523.3C3286.96,1378.14 3169.11,1260.29 3023.94,1260.29C3023.94,1260.29 1039.05,1260.29 1039.05,1260.29C893.887,1260.29 776.033,1378.14 776.033,1523.3L776.033,2489.79L1440.94,1736.22L2385.83,2775.59L2880.71,2200.41L3286.96,2634.37ZM2622.05,1405.51C2778.5,1405.51 2905.51,1532.53 2905.51,1688.98C2905.51,1845.42 2778.5,1972.44 2622.05,1972.44C2465.6,1972.44 2338.58,1845.42 2338.58,1688.98C2338.58,1532.53 2465.6,1405.51 2622.05,1405.51ZM2790.9,1172.79L1323.86,1172.79L944.882,755.906L279.97,1509.47L279.97,542.987C279.97,397.824 397.824,279.97 542.987,279.97C542.987,279.97 2527.88,279.97 2527.88,279.97C2673.04,279.97 2790.9,397.824 2790.9,542.987L2790.9,1172.79ZM2125.98,425.197C2282.43,425.197 2409.45,552.213 2409.45,708.661C2409.45,865.11 2282.43,992.126 2125.98,992.126C1969.54,992.126 1842.52,865.11 1842.52,708.661C1842.52,552.213 1969.54,425.197 2125.98,425.197Z"/>
</g>
</svg>
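The standalone .svg assets removed here carry the same path geometry that the deleted .tsx icon modules embed. A hedged sketch of how such an asset maps onto the createIcon pattern used in this commit; GalleryIcon is a hypothetical name, and the d value is truncated to stand in for the full path string from the file above:

import { createIcon } from '@chakra-ui/react';

// Hypothetical wrapper. The real d attribute is the full path data from the
// SVG above; '…' marks the truncation. The <g> transform from the asset is
// dropped, so the viewBox keeps the asset's 0 0 3543 3543 coordinate space.
const GalleryIcon = createIcon({
  displayName: 'GalleryIcon',
  viewBox: '0 0 3543 3543',
  path: <path fill="currentColor" d="M688.533,2405.95 …" />,
  defaultProps: { boxSize: '24px' },
});

export default GalleryIcon;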
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
@ -1,5 +0,0 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg width="100%" height="100%" viewBox="0 0 3543 3543" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:space="preserve" xmlns:serif="http://www.serif.com/" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linejoin:round;stroke-miterlimit:2;">
<path d="M3543.31,770.787C3543.31,515.578 3336.11,308.38 3080.9,308.38L462.407,308.38C207.197,308.38 0,515.578 0,770.787L0,2766.03C0,3021.24 207.197,3228.44 462.407,3228.44L3080.9,3228.44C3336.11,3228.44 3543.31,3021.24 3543.31,2766.03C3543.31,2766.03 3543.31,770.787 3543.31,770.787ZM3427.88,770.787L3427.88,2766.03C3427.88,2957.53 3272.4,3113.01 3080.9,3113.01C3080.9,3113.01 462.407,3113.01 462.407,3113.01C270.906,3113.01 115.431,2957.53 115.431,2766.03L115.431,770.787C115.431,579.286 270.906,423.812 462.407,423.812L3080.9,423.812C3272.4,423.812 3427.88,579.286 3427.88,770.787ZM1214.23,1130.69L1321.47,1130.69C1324.01,1130.69 1326.54,1130.53 1329.05,1130.2C1329.05,1130.2 1367.3,1125.33 1397.94,1149.8C1421.63,1168.72 1437.33,1204.3 1437.33,1265.48L1437.33,2078.74L1220.99,2078.74C1146.83,2078.74 1086.61,2138.95 1086.61,2213.12L1086.61,2762.46C1086.61,2836.63 1146.83,2896.84 1220.99,2896.84L1770.34,2896.84C1844.5,2896.84 1904.71,2836.63 1904.71,2762.46L1904.71,2213.12C1904.71,2138.95 1844.5,2078.74 1770.34,2078.74L1554,2078.74L1554,1604.84C1625.84,1658.19 1703.39,1658.1 1703.39,1658.1C1703.54,1658.1 1703.69,1658.11 1703.84,1658.11L2362.2,1658.11L2362.2,1874.44C2362.2,1948.61 2422.42,2008.82 2496.58,2008.82L3045.93,2008.82C3120.09,2008.82 3180.3,1948.61 3180.3,1874.44L3180.3,1325.1C3180.3,1250.93 3120.09,1190.72 3045.93,1190.72L2496.58,1190.72C2422.42,1190.72 2362.2,1250.93 2362.2,1325.1L2362.2,1558.97L2362.2,1541.44L1704.23,1541.44C1702.2,1541.37 1650.96,1539.37 1609.51,1499.26C1577.72,1468.49 1554,1416.47 1554,1331.69L1554,1265.48C1554,1153.86 1513.98,1093.17 1470.76,1058.64C1411.24,1011.1 1338.98,1012.58 1319.15,1014.03L1214.23,1014.03L1214.23,796.992C1214.23,722.828 1154.02,662.617 1079.85,662.617L530.507,662.617C456.343,662.617 396.131,722.828 396.131,796.992L396.131,1346.34C396.131,1420.5 456.343,1480.71 530.507,1480.71L1079.85,1480.71C1154.02,1480.71 1214.23,1420.5 1214.23,1346.34L1214.23,1130.69Z"/>
</svg>
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
@ -1,5 +0,0 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg width="100%" height="100%" viewBox="0 0 3543 3543" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:space="preserve" xmlns:serif="http://www.serif.com/" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linecap:round;stroke-linejoin:round;stroke-miterlimit:1.5;">
<path d="M709.477,1596.53L992.591,1275.66L2239.09,2646.81L2891.95,1888.03L3427.88,2460.51L3427.88,994.78C3427.88,954.66 3421.05,916.122 3408.5,880.254L3521.9,855.419C3535.8,899.386 3543.31,946.214 3543.31,994.78L3543.31,2990.02C3543.31,3245.23 3336.11,3452.43 3080.9,3452.43C3080.9,3452.43 462.407,3452.43 462.407,3452.43C207.197,3452.43 -0,3245.23 -0,2990.02L-0,994.78C-0,739.571 207.197,532.373 462.407,532.373L505.419,532.373L504.644,532.546L807.104,600.085C820.223,601.729 832.422,607.722 841.77,617.116C850.131,625.517 855.784,636.21 858.055,647.804L462.407,647.804C270.906,647.804 115.431,803.279 115.431,994.78L115.431,2075.73L-0,2101.5L115.431,2127.28L115.431,2269.78L220.47,2150.73L482.345,2209.21C503.267,2211.83 522.722,2221.39 537.63,2236.37C552.538,2251.35 562.049,2270.9 564.657,2291.93L671.84,2776.17L779.022,2291.93C781.631,2270.9 791.141,2251.35 806.05,2236.37C820.958,2221.39 840.413,2211.83 861.334,2209.21L1353.15,2101.5L861.334,1993.8C840.413,1991.18 820.958,1981.62 806.05,1966.64C791.141,1951.66 781.631,1932.11 779.022,1911.08L709.477,1596.53ZM671.84,1573.09L725.556,2006.07C726.863,2016.61 731.63,2026.4 739.101,2033.91C746.573,2041.42 756.323,2046.21 766.808,2047.53L1197.68,2101.5L766.808,2155.48C756.323,2156.8 746.573,2161.59 739.101,2169.09C731.63,2176.6 726.863,2186.4 725.556,2196.94L671.84,2629.92L618.124,2196.94C616.817,2186.4 612.05,2176.6 604.579,2169.09C597.107,2161.59 587.357,2156.8 576.872,2155.48L146.001,2101.5L576.872,2047.53C587.357,2046.21 597.107,2041.42 604.579,2033.91C612.05,2026.4 616.817,2016.61 618.124,2006.07L671.84,1573.09ZM609.035,1710.36L564.657,1911.08C562.049,1932.11 552.538,1951.66 537.63,1966.64C522.722,1981.62 503.267,1991.18 482.345,1993.8L328.665,2028.11L609.035,1710.36ZM2297.12,938.615L2451.12,973.003C2480.59,976.695 2507.99,990.158 2528.99,1011.26C2549.99,1032.37 2563.39,1059.9 2567.07,1089.52L2672.73,1566.9C2634.5,1580.11 2593.44,1587.29 2550.72,1587.29C2344.33,1587.29 2176.77,1419.73 2176.77,1213.34C2176.77,1104.78 2223.13,1006.96 2297.12,938.615ZM2718.05,76.925L2793.72,686.847C2795.56,701.69 2802.27,715.491 2812.8,726.068C2823.32,736.644 2837.06,743.391 2851.83,745.242L3458.78,821.28L2851.83,897.318C2837.06,899.168 2823.32,905.916 2812.8,916.492C2802.27,927.068 2795.56,940.87 2793.72,955.712L2718.05,1565.63L2642.38,955.712C2640.54,940.87 2633.83,927.068 2623.3,916.492C2612.78,905.916 2599.04,899.168 2584.27,897.318L1977.32,821.28L2584.27,745.242C2599.04,743.391 2612.78,736.644 2623.3,726.068C2633.83,715.491 2640.54,701.69 2642.38,686.847L2718.05,76.925ZM2883.68,1043.06C2909.88,1094.13 2924.67,1152.02 2924.67,1213.34C2924.67,1335.4 2866.06,1443.88 2775.49,1512.14L2869.03,1089.52C2871.07,1073.15 2876.07,1057.42 2883.68,1043.06ZM925.928,201.2L959.611,472.704C960.431,479.311 963.42,485.455 968.105,490.163C972.79,494.871 978.904,497.875 985.479,498.698L1255.66,532.546L985.479,566.395C978.904,567.218 972.79,570.222 968.105,574.93C963.42,579.638 960.431,585.781 959.611,592.388L925.928,863.893L892.245,592.388C891.425,585.781 888.436,579.638 883.751,574.93C879.066,570.222 872.952,567.218 866.378,566.395L596.195,532.546L866.378,498.698C872.952,497.875 879.066,494.871 883.751,490.163C888.436,485.455 891.425,479.311 892.245,472.704L925.928,201.2ZM2864.47,532.373L3080.9,532.373C3258.7,532.373 3413.2,632.945 3490.58,780.281L3319.31,742.773C3257.14,683.925 3173.2,647.804 3080.9,647.804L2927.07,647.804C2919.95,642.994 2913.25,637.473 2907.11,631.298C2886.11,610.194 2872.71,582.655 
2869.03,553.04L2864.47,532.373ZM1352.36,532.373L2571.64,532.373L2567.07,553.04C2563.39,582.655 2549.99,610.194 2528.99,631.298C2522.85,637.473 2516.16,642.994 2509.03,647.804L993.801,647.804C996.072,636.21 1001.73,625.517 1010.09,617.116C1019.43,607.722 1031.63,601.729 1044.75,600.085L1353.15,532.546L1352.36,532.373Z" style="stroke:white;stroke-opacity:0;stroke-width:1px;"/>
</svg>
Some files were not shown because too many files have changed in this diff.