InvokeAI (mirror of https://github.com/invoke-ai/InvokeAI)

Merge branch 'main' into feat/onnx
commit da751da3dd
@@ -1 +1,2 @@
 b3dccfaeb636599c02effc377cdd8a87d658256c
+218b6d0546b990fc449c876fb99f44b50c4daa35
.github/workflows/style-checks.yml (vendored, new file)
@@ -0,0 +1,27 @@
+name: Black # TODO: add isort and flake8 later
+
+on:
+  pull_request: {}
+  push:
+    branches: master
+    tags: "*"
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+
+      - name: Install dependencies with pip
+        run: |
+          pip install --upgrade pip wheel
+          pip install .[test]
+
+      # - run: isort --check-only .
+      - run: black --check .
+      # - run: flake8
.gitignore (vendored)
@@ -38,7 +38,6 @@ develop-eggs/
 downloads/
 eggs/
 .eggs/
-lib/
 lib64/
 parts/
 sdist/
.pre-commit-config.yaml (new file)
@@ -0,0 +1,10 @@
+# See https://pre-commit.com/ for usage and config
+repos:
+  - repo: local
+    hooks:
+      - id: black
+        name: black
+        stages: [commit]
+        language: system
+        entry: black
+        types: [python]
@@ -123,7 +123,7 @@ and go to http://localhost:9090.

 ### Command-Line Installation (for developers and users familiar with Terminals)

-You must have Python 3.9 or 3.10 installed on your machine. Earlier or
+You must have Python 3.9 through 3.11 installed on your machine. Earlier or
 later versions are not supported.
 Node.js also needs to be installed along with yarn (can be installed with
 the command `npm install -g yarn` if needed)
@@ -40,10 +40,8 @@ experimental versions later.
 this, open up a command-line window ("Terminal" on Linux and
 Macintosh, "Command" or "Powershell" on Windows) and type `python
 --version`. If Python is installed, it will print out the version
-number. If it is version `3.9.*` or `3.10.*`, you meet
-requirements. We do not recommend using Python 3.11 or higher,
-as not all the libraries that InvokeAI depends on work properly
-with this version.
+number. If it is version `3.9.*`, `3.10.*`, or `3.11.*`, you meet
+requirements.

 !!! warning "What to do if you have an unsupported version"
@@ -32,7 +32,7 @@ gaming):

 * **Python**

-  version 3.9 or 3.10 (3.11 is not recommended).
+  version 3.9 through 3.11

 * **CUDA Tools**
@@ -65,7 +65,7 @@ gaming):
 To install InvokeAI with virtual environments and the PIP package
 manager, please follow these steps:

-1. Please make sure you are using Python 3.9 or 3.10. The rest of the install
+1. Please make sure you are using Python 3.9 through 3.11. The rest of the install
    procedure depends on this and will not work with other versions:

 ```bash
@@ -9,16 +9,20 @@ cd $scriptdir
 function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }

 MINIMUM_PYTHON_VERSION=3.9.0
-MAXIMUM_PYTHON_VERSION=3.11.0
+MAXIMUM_PYTHON_VERSION=3.11.100
 PYTHON=""
-for candidate in python3.10 python3.9 python3 python ; do
+for candidate in python3.11 python3.10 python3.9 python3 python ; do
     if ppath=`which $candidate`; then
+        # when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
+        # we check that this found executable can actually run
+        if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi
+
         python_version=$($ppath -V | awk '{ print $2 }')
         if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then
-            if [ $(version $python_version) -lt $(version "$MAXIMUM_PYTHON_VERSION") ]; then
+            if [ $(version $python_version) -le $(version "$MAXIMUM_PYTHON_VERSION") ]; then
                 PYTHON=$ppath
                 break
             fi
         fi
     fi
 done
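Note: the `version` helper above packs a dotted version into one sortable integer (`3.11.2` becomes `3011002000` via awk's `%d%03d%03d%03d`), which is why the upper bound moves to `3.11.100` and the comparison to `-le`: any realistic `3.11.x` patch release must still pass. A Python sketch of the same encoding (illustrative, not repository code):

```python
def version_key(v: str) -> int:
    """Pack up to four dotted components into one sortable integer,
    mirroring the installer's awk printf("%d%03d%03d%03d") trick."""
    parts = (v.split(".") + ["0"] * 4)[:4]
    key = int(parts[0])
    for p in parts[1:]:
        key = key * 1000 + int(p)
    return key

assert version_key("3.9.0") <= version_key("3.11.2") <= version_key("3.11.100")
```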
@@ -141,15 +141,16 @@ class Installer:

         # upgrade pip in Python 3.9 environments
         if int(platform.python_version_tuple()[1]) == 9:

             from plumbum import FG, local

             pip = local[get_pip_from_venv(venv_dir)]
-            pip[ "install", "--upgrade", "pip"] & FG
+            pip["install", "--upgrade", "pip"] & FG

         return venv_dir

-    def install(self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None) -> None:
+    def install(
+        self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None
+    ) -> None:
         """
         Install the InvokeAI application into the given runtime path
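The `pip["install", "--upgrade", "pip"] & FG` idiom is plumbum's command API: indexing a command object binds its arguments, and `& FG` runs the bound command in the foreground with output attached to the terminal. A minimal sketch of the pattern (assumes the plumbum package; here `pip` is resolved from PATH rather than from a venv):

```python
from plumbum import FG, local

pip = local["pip"]                        # look up an executable on PATH
pip["install", "--upgrade", "pip"] & FG   # run it, streaming output to the terminal
```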
@@ -175,7 +176,7 @@ class Installer:
         self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv, version=version)

         # install dependencies and the InvokeAI application
-        (extra_index_url,optional_modules) = get_torch_source() if not yes_to_all else (None,None)
+        (extra_index_url, optional_modules) = get_torch_source() if not yes_to_all else (None, None)
         self.instance.install(
             extra_index_url,
             optional_modules,
@@ -188,6 +189,7 @@ class Installer:
         # run through the configuration flow
         self.instance.configure()

+
 class InvokeAiInstance:
     """
     Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory.
@@ -196,7 +198,6 @@ class InvokeAiInstance:
     """

     def __init__(self, runtime: Path, venv: Path, version: str) -> None:
-
         self.runtime = runtime
         self.venv = venv
         self.pip = get_pip_from_venv(venv)
@@ -312,7 +313,7 @@ class InvokeAiInstance:
             "install",
             "--require-virtualenv",
             "--use-pep517",
-            str(src)+(optional_modules if optional_modules else ''),
+            str(src) + (optional_modules if optional_modules else ""),
             "--find-links" if find_links is not None else None,
             find_links,
             "--extra-index-url" if extra_index_url is not None else None,
@@ -329,12 +330,12 @@ class InvokeAiInstance:

         # set sys.argv to a consistent state
         new_argv = [sys.argv[0]]
-        for i in range(1,len(sys.argv)):
+        for i in range(1, len(sys.argv)):
             el = sys.argv[i]
-            if el in ['-r','--root']:
+            if el in ["-r", "--root"]:
                 new_argv.append(el)
-                new_argv.append(sys.argv[i+1])
+                new_argv.append(sys.argv[i + 1])
-            elif el in ['-y','--yes','--yes-to-all']:
+            elif el in ["-y", "--yes", "--yes-to-all"]:
                 new_argv.append(el)
         sys.argv = new_argv
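This loop whitelists the only flags the downstream configuration step understands, so stray installer arguments never reach its parser. Tracing it with illustrative values:

```python
import sys

sys.argv = ["installer.py", "--verbose", "-r", "/opt/invokeai", "-y"]
new_argv = [sys.argv[0]]
for i in range(1, len(sys.argv)):
    el = sys.argv[i]
    if el in ["-r", "--root"]:
        new_argv.append(el)
        new_argv.append(sys.argv[i + 1])  # keep the flag's value as well
    elif el in ["-y", "--yes", "--yes-to-all"]:
        new_argv.append(el)
sys.argv = new_argv  # -> ["installer.py", "-r", "/opt/invokeai", "-y"]
```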
@@ -353,16 +354,16 @@ class InvokeAiInstance:
             invokeai_configure()
             succeeded = True
         except requests.exceptions.ConnectionError as e:
-            print(f'\nA network error was encountered during configuration and download: {str(e)}')
+            print(f"\nA network error was encountered during configuration and download: {str(e)}")
         except OSError as e:
-            print(f'\nAn OS error was encountered during configuration and download: {str(e)}')
+            print(f"\nAn OS error was encountered during configuration and download: {str(e)}")
         except Exception as e:
-            print(f'\nA problem was encountered during the configuration and download steps: {str(e)}')
+            print(f"\nA problem was encountered during the configuration and download steps: {str(e)}")
         finally:
             if not succeeded:
                 print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"')
-                print('and choose option 7 to fix a broken install, optionally followed by option 5 to install models.')
+                print("and choose option 7 to fix a broken install, optionally followed by option 5 to install models.")
-                print('Alternatively you can relaunch the installer.')
+                print("Alternatively you can relaunch the installer.")

     def install_user_scripts(self):
         """
@@ -371,11 +372,11 @@ class InvokeAiInstance:

         ext = "bat" if OS == "Windows" else "sh"

-        #scripts = ['invoke', 'update']
-        scripts = ['invoke']
+        # scripts = ['invoke', 'update']
+        scripts = ["invoke"]

         for script in scripts:
-            src = Path(__file__).parent / '..' / "templates" / f"{script}.{ext}.in"
+            src = Path(__file__).parent / ".." / "templates" / f"{script}.{ext}.in"
             dest = self.runtime / f"{script}.{ext}"
             shutil.copy(src, dest)
             os.chmod(dest, 0o0755)
@@ -420,11 +421,7 @@ def set_sys_path(venv_path: Path) -> None:
     # filter out any paths in sys.path that may be system- or user-wide
     # but leave the temporary bootstrap virtualenv as it contains packages we
     # temporarily need at install time
-    sys.path = list(filter(
-        lambda p: not p.endswith("-packages")
-        or p.find(BOOTSTRAP_VENV_PREFIX) != -1,
-        sys.path
-    ))
+    sys.path = list(filter(lambda p: not p.endswith("-packages") or p.find(BOOTSTRAP_VENV_PREFIX) != -1, sys.path))

     # determine site-packages/lib directory location for the venv
     lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
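The filter keeps any `sys.path` entry that is not a `site-packages`/`dist-packages` directory, plus `-packages` directories that live inside the bootstrap venv. With illustrative paths (and an assumed value for `BOOTSTRAP_VENV_PREFIX`):

```python
BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp"  # assumed value, for illustration only

sys_path = [
    "/usr/lib/python3.10/site-packages",                  # dropped: system-wide
    "/home/user/.local/lib/python3.10/site-packages",     # dropped: user-wide
    "/tmp/invokeai-installer-tmp1234/lib/site-packages",  # kept: bootstrap venv
    "/usr/lib/python310.zip",                             # kept: not a -packages dir
]
kept = [p for p in sys_path if not p.endswith("-packages") or BOOTSTRAP_VENV_PREFIX in p]
```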
@@ -433,7 +430,7 @@ def set_sys_path(venv_path: Path) -> None:
     sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve()))


-def get_torch_source() -> (Union[str, None],str):
+def get_torch_source() -> (Union[str, None], str):
     """
     Determine the extra index URL for pip to use for torch installation.
     This depends on the OS and the graphics accelerator in use.
@@ -461,9 +458,9 @@ def get_torch_source() -> (Union[str, None], str):
     elif device == "cpu":
         url = "https://download.pytorch.org/whl/cpu"

-    if device == 'cuda':
-        url = 'https://download.pytorch.org/whl/cu117'
-        optional_modules = '[xformers]'
+    if device == "cuda":
+        url = "https://download.pytorch.org/whl/cu117"
+        optional_modules = "[xformers]"

     # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
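`get_torch_source()` returns an optional extra index URL plus an optional extras suffix, and both feed the pip invocation assembled earlier in the installer. A hedged sketch of how the pair is typically consumed (argument assembly only, values illustrative):

```python
extra_index_url = "https://download.pytorch.org/whl/cu117"  # None on plain-PyPI platforms
optional_modules = "[xformers]"                             # None when no extras apply

args = ["install", f"InvokeAI{optional_modules or ''}"]
if extra_index_url is not None:
    args += ["--extra-index-url", extra_index_url]
# -> ['install', 'InvokeAI[xformers]', '--extra-index-url', 'https://download.pytorch.org/whl/cu117']
```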
@@ -36,13 +36,15 @@ else:


 def welcome():

     @group()
     def text():
         if (platform_specific := _platform_specific_help()) != "":
             yield platform_specific
             yield ""
-        yield Text.from_markup("Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.", justify="center")
+        yield Text.from_markup(
+            "Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.",
+            justify="center",
+        )

     console.rule()
     print(
@@ -58,6 +60,7 @@ def welcome():
     )
     console.line()

+
 def confirm_install(dest: Path) -> bool:
     if dest.exists():
         print(f":exclamation: Directory {dest} already exists :exclamation:")
@@ -92,7 +95,6 @@ def dest_path(dest=None) -> Path:
     dest_confirmed = confirm_install(dest)

     while not dest_confirmed:
-
         # if the given destination already exists, the starting point for browsing is its parent directory.
         # the user may have made a typo, or otherwise wants to place the root dir next to an existing one.
         # if the destination dir does NOT exist, then the user must have changed their mind about the selection.
@@ -300,15 +302,20 @@ def introduction() -> None:
     )
     console.line(2)

-def _platform_specific_help()->str:
+
+def _platform_specific_help() -> str:
     if OS == "Darwin":
-        text = Text.from_markup("""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/].""")
+        text = Text.from_markup(
+            """[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/]."""
+        )
     elif OS == "Windows":
-        text = Text.from_markup("""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
+        text = Text.from_markup(
+            """[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
 1. Double-click on the file [b wheat1]WinLongPathsEnabled.reg[/] in order to
    enable long path support on your system.
 2. Make sure you have the [b wheat1]Visual C++ core libraries[/] installed. If not, install from
-[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]""")
+[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]"""
+        )
     else:
         text = ""
     return text
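These messages go through rich's console markup, where `[b wheat1]...[/]` style tags are parsed by `Text.from_markup`. A minimal usage example (requires the `rich` package; the message text is illustrative):

```python
from rich.console import Console
from rich.text import Text

console = Console()
console.print(
    Text.from_markup("[b wheat1]macOS Users![/] see [deep_sky_blue1]https://invoke.ai[/]", justify="center")
)
```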
@@ -90,7 +90,7 @@ async def update_model(
             new_name=info.model_name,
             new_base=info.base_model,
         )
-        logger.info(f"Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}")
+        logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
         # update information to support an update of attributes
         model_name = info.model_name
         base_model = info.base_model
@@ -3,6 +3,7 @@ import asyncio
 import sys
 from inspect import signature

+import logging
 import uvicorn
 import socket

|
|||||||
port = find_port(app_config.port)
|
port = find_port(app_config.port)
|
||||||
if port != app_config.port:
|
if port != app_config.port:
|
||||||
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
||||||
|
|
||||||
# Start our own event loop for eventing usage
|
# Start our own event loop for eventing usage
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
config = uvicorn.Config(app=app, host=app_config.host, port=port, loop=loop)
|
config = uvicorn.Config(
|
||||||
# Use access_log to turn off logging
|
app=app,
|
||||||
|
host=app_config.host,
|
||||||
|
port=port,
|
||||||
|
loop=loop,
|
||||||
|
log_level=app_config.log_level,
|
||||||
|
)
|
||||||
server = uvicorn.Server(config)
|
server = uvicorn.Server(config)
|
||||||
|
|
||||||
|
# replace uvicorn's loggers with InvokeAI's for consistent appearance
|
||||||
|
for logname in ["uvicorn.access", "uvicorn"]:
|
||||||
|
l = logging.getLogger(logname)
|
||||||
|
l.handlers.clear()
|
||||||
|
for ch in logger.handlers:
|
||||||
|
l.addHandler(ch)
|
||||||
|
|
||||||
loop.run_until_complete(server.serve())
|
loop.run_until_complete(server.serve())
|
||||||
|
|
||||||
|
|
||||||
|
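Routing uvicorn's records through the application's handlers is the standard way to get one consistent log format: clear the uvicorn loggers' handler lists and attach your own. A standalone sketch (the logger names are uvicorn's real ones; the handler here is a stand-in for InvokeAI's formatted handler):

```python
import logging

app_logger = logging.getLogger("InvokeAI")
app_logger.addHandler(logging.StreamHandler())  # stand-in for the app's formatted handler

for logname in ("uvicorn.access", "uvicorn"):
    uv_logger = logging.getLogger(logname)
    uv_logger.handlers.clear()                  # drop uvicorn's default formatting
    for handler in app_logger.handlers:
        uv_logger.addHandler(handler)           # emit through the app's handlers instead
```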
@@ -12,7 +12,7 @@ from pydantic import BaseModel, Field, validator

 from invokeai.app.invocations.metadata import CoreMetadata
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.model_management.models.base import ModelType
+from invokeai.backend.model_management.models import ModelType, SilenceWarnings

 from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
@@ -312,70 +312,71 @@ class TextToLatentsInvocation(BaseInvocation):
(the entire body of invoke() is re-indented one level under a new `with SilenceWarnings():` block, and `context=context` is passed to the model-manager lookups; the new version is shown)

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        with SilenceWarnings():
            noise = context.services.latents.get(self.noise.latents_name)

            # Get the source node id (we are invoking the prepared node)
            graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
            source_node_id = graph_execution_state.prepared_source_mapping[self.id]

            def step_callback(state: PipelineIntermediateState):
                self.dispatch_progress(context, source_node_id, state)

            def _lora_loader():
                for lora in self.unet.loras:
                    lora_info = context.services.model_manager.get_model(
                        **lora.dict(exclude={"weight"}),
                        context=context,
                    )
                    yield (lora_info.context.model, lora.weight)
                    del lora_info
                return

            unet_info = context.services.model_manager.get_model(
                **self.unet.unet.dict(),
                context=context,
            )
            with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
                unet_info.context.model, _lora_loader()
            ), unet_info as unet:
                noise = noise.to(device=unet.device, dtype=unet.dtype)

                scheduler = get_scheduler(
                    context=context,
                    scheduler_info=self.unet.scheduler,
                    scheduler_name=self.scheduler,
                )

                pipeline = self.create_pipeline(unet, scheduler)
                conditioning_data = self.get_conditioning_data(context, scheduler, unet)

                control_data = self.prep_control_data(
                    model=pipeline,
                    context=context,
                    control_input=self.control,
                    latents_shape=noise.shape,
                    # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                    do_classifier_free_guidance=True,
                    exit_stack=exit_stack,
                )

                # TODO: Verify the noise is the right size
                result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
                    latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
                    noise=noise,
                    num_inference_steps=self.steps,
                    conditioning_data=conditioning_data,
                    control_data=control_data,  # list[ControlNetData]
                    callback=step_callback,
                )

                # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
                result_latents = result_latents.to("cpu")
                torch.cuda.empty_cache()

                name = f"{context.graph_execution_state_id}__{self.id}"
                context.services.latents.save(name, result_latents)
                return build_latents_output(latents_name=name, latents=result_latents)


class LatentsToLatentsInvocation(TextToLatentsInvocation):
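`SilenceWarnings` is used as a context manager that quiets the warning chatter (for example diffusers' NSFW-checker nag, per the comment in the next hunk) for the duration of the denoise. A minimal sketch of that pattern, not InvokeAI's exact implementation:

```python
import logging
import warnings


class SilenceWarnings:
    """Suppress Python warnings and diffusers log noise inside a with-block."""

    def __enter__(self):
        self._catcher = warnings.catch_warnings()
        self._catcher.__enter__()
        warnings.simplefilter("ignore")
        self._prev = logging.getLogger("diffusers").level
        logging.getLogger("diffusers").setLevel(logging.ERROR)
        return self

    def __exit__(self, *exc):
        logging.getLogger("diffusers").setLevel(self._prev)
        self._catcher.__exit__(*exc)
        return False
```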
@@ -403,82 +404,83 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
(as above, the invoke() body is re-indented under `with SilenceWarnings():`; the new version is shown)

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        with SilenceWarnings():  # this quenches NSFW nag from diffusers
            noise = context.services.latents.get(self.noise.latents_name)
            latent = context.services.latents.get(self.latents.latents_name)

            # Get the source node id (we are invoking the prepared node)
            graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
            source_node_id = graph_execution_state.prepared_source_mapping[self.id]

            def step_callback(state: PipelineIntermediateState):
                self.dispatch_progress(context, source_node_id, state)

            def _lora_loader():
                for lora in self.unet.loras:
                    lora_info = context.services.model_manager.get_model(
                        **lora.dict(exclude={"weight"}),
                        context=context,
                    )
                    yield (lora_info.context.model, lora.weight)
                    del lora_info
                return

            unet_info = context.services.model_manager.get_model(
                **self.unet.unet.dict(),
                context=context,
            )
            with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
                unet_info.context.model, _lora_loader()
            ), unet_info as unet:
                noise = noise.to(device=unet.device, dtype=unet.dtype)
                latent = latent.to(device=unet.device, dtype=unet.dtype)

                scheduler = get_scheduler(
                    context=context,
                    scheduler_info=self.unet.scheduler,
                    scheduler_name=self.scheduler,
                )

                pipeline = self.create_pipeline(unet, scheduler)
                conditioning_data = self.get_conditioning_data(context, scheduler, unet)

                control_data = self.prep_control_data(
                    model=pipeline,
                    context=context,
                    control_input=self.control,
                    latents_shape=noise.shape,
                    # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                    do_classifier_free_guidance=True,
                    exit_stack=exit_stack,
                )

                # TODO: Verify the noise is the right size
                initial_latents = (
                    latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
                )

                timesteps, _ = pipeline.get_img2img_timesteps(
                    self.steps,
                    self.strength,
                    device=unet.device,
                )

                result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
                    latents=initial_latents,
                    timesteps=timesteps,
                    noise=noise,
                    num_inference_steps=self.steps,
                    conditioning_data=conditioning_data,
                    control_data=control_data,  # list[ControlNetData]
                    callback=step_callback,
                )

                # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
                result_latents = result_latents.to("cpu")
                torch.cuda.empty_cache()

                name = f"{context.graph_execution_state_id}__{self.id}"
                context.services.latents.save(name, result_latents)
                return build_latents_output(latents_name=name, latents=result_latents)
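In the image-to-image path, `strength` decides how much of the schedule actually runs: at `strength == 1.0` the start is pure noise (hence the `torch.zeros_like` fallback), while lower strengths keep the input latents and skip the early timesteps. Illustrative arithmetic in the style of diffusers' img2img timestep trimming (not the library's exact code):

```python
steps, strength = 30, 0.4
init_timestep = min(int(steps * strength), steps)  # 12 denoising steps actually run
t_start = max(steps - init_timestep, 0)            # the first 18 of 30 are skipped
```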
@@ -491,7 +493,7 @@ class LatentsToImageInvocation(BaseInvocation):
     # Inputs
     latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
     vae: VaeField = Field(default=None, description="Vae submodel")
-    tiled: bool = Field(default=False, description="Decode latents by overlapping tiles(less memory consumption)")
+    tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)")
     fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
     metadata: Optional[CoreMetadata] = Field(
         default=None, description="Optional core metadata to be written to the image"
@@ -712,6 +712,7 @@ class TextualInversionManager(BaseTextualInversionManager):


 class ONNXModelPatcher:
+    from .models.base import IAIOnnxRuntimeModel, OnnxRuntimeModel
     @classmethod
     @contextmanager
     def apply_lora_unet(
@@ -358,6 +358,7 @@ class ModelCache(object):
             # 2 refs:
             # 1 from cache_entry
             # 1 from getrefcount function
+            # 1 from onnx runtime object
             if not cache_entry.locked and refs <= 3 if "onnx" in model_key else 2:
                 self.logger.debug(
                     f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
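The eviction test counts live references with `sys.getrefcount`: the cache entry holds one, the `getrefcount` call itself adds one, and with ONNX an additional runtime session holds a third, hence the `3 if "onnx" in model_key else 2` threshold. A tiny standalone illustration of the counting (not InvokeAI code):

```python
import sys


class CacheEntry:
    def __init__(self, model):
        self.model = model


entry = CacheEntry(object())
# One reference from the cache entry plus one from getrefcount's own argument:
assert sys.getrefcount(entry.model) == 2  # an ONNX session holding it would make this 3
```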
@@ -401,7 +401,11 @@ class ModelManager(object):
         base_model: BaseModelType,
         model_type: ModelType,
     ) -> str:
-        return f"{base_model}/{model_type}/{model_name}"
+        # In 3.11, the behavior of (str,enum) when interpolated into a
+        # string has changed. The next two lines are defensive.
+        base_model = BaseModelType(base_model)
+        model_type = ModelType(model_type)
+        return f"{base_model.value}/{model_type.value}/{model_name}"

     @classmethod
     def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
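The comment refers to a real CPython change: for enums that mix in `str`, f-string interpolation used the member's value through Python 3.10 but switched to the `ClassName.MEMBER` form in 3.11, so spelling out `.value` behaves the same everywhere. A quick illustration with a hypothetical stand-in enum:

```python
from enum import Enum


class BaseModelKind(str, Enum):  # hypothetical stand-in for BaseModelType
    SD1 = "sd-1"


# Python <= 3.10: f"{BaseModelKind.SD1}"  -> "sd-1"
# Python >= 3.11: f"{BaseModelKind.SD1}"  -> "BaseModelKind.SD1"
print(f"{BaseModelKind.SD1.value}/lora/my-model")  # "sd-1/lora/my-model" on every version
```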
@@ -19,13 +19,9 @@ from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callab

 import onnx
 from onnx import numpy_helper
-from onnx.external_data_helper import set_external_data
 from onnxruntime import (
     InferenceSession,
-    OrtValue,
     SessionOptions,
-    ExecutionMode,
-    GraphOptimizationLevel,
     get_available_providers,
 )
@@ -57,7 +57,7 @@ class LoRAModel(ModelBase):

     @classproperty
     def save_to_config(cls) -> bool:
-        return False
+        return True

     @classmethod
     def detect_format(cls, path: str):
invokeai/frontend/web/dist/assets/App-58b095d3.js (vendored, new file, 169 lines)
File diff suppressed because one or more lines are too long

invokeai/frontend/web/dist/assets/MantineProvider-ea42d3d1.js (vendored, new file, 1 line)
File diff suppressed because one or more lines are too long

invokeai/frontend/web/dist/assets/ThemeLocaleProvider-13e3db3d.js (vendored, new file, 302 lines)
File diff suppressed because one or more lines are too long

invokeai/frontend/web/dist/assets/index-5a784cdd.js (vendored, new file, 125 lines)
File diff suppressed because one or more lines are too long
invokeai/frontend/web/dist/index.html (vendored)
@@ -12,7 +12,7 @@
       margin: 0;
     }
   </style>
-  <script type="module" crossorigin src="./assets/index-b3976531.js"></script>
+  <script type="module" crossorigin src="./assets/index-5a784cdd.js"></script>
 </head>

 <body dir="ltr">
@@ -340,6 +340,7 @@
     "allModels": "All Models",
     "checkpointModels": "Checkpoints",
     "diffusersModels": "Diffusers",
+    "loraModels": "LoRAs",
     "safetensorModels": "SafeTensors",
     "onnxModels": "Onnx",
     "oliveModels": "Olives",
@@ -10,8 +10,11 @@ export const addAppConfigReceivedListener = () => {
   startAppListening({
     matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
     effect: async (action, { getState, dispatch }) => {
-      const { infill_methods, nsfw_methods, watermarking_methods } =
-        action.payload;
+      const {
+        infill_methods = [],
+        nsfw_methods = [],
+        watermarking_methods = [],
+      } = action.payload;
       const infillMethod = getState().generation.infillMethod;

       if (!infill_methods.includes(infillMethod)) {
@@ -148,7 +148,7 @@ const ParamPositiveConditioning = () => {
           <Box
             sx={{
               position: 'absolute',
-              top: shouldPinParametersPanel ? 6 : 0,
+              top: shouldPinParametersPanel ? 5 : 0,
               insetInlineEnd: 0,
             }}
           >
ParamPromptArea.tsx (new file; imported elsewhere as 'features/parameters/components/Parameters/Prompt/ParamPromptArea')
@@ -0,0 +1,17 @@
+import { Flex } from '@chakra-ui/react';
+import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
+import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
+
+export default function ParamPromptArea() {
+  return (
+    <Flex
+      sx={{
+        flexDirection: 'column',
+        gap: 2,
+      }}
+    >
+      <ParamPositiveConditioning />
+      <ParamNegativeConditioning />
+    </Flex>
+  );
+}
@@ -1,3 +1,5 @@
+import { components } from 'services/api/schema';
+
 export const MODEL_TYPE_MAP = {
   'sd-1': 'Stable Diffusion 1.x',
   'sd-2': 'Stable Diffusion 2.x',
@@ -5,6 +7,13 @@ export const MODEL_TYPE_MAP = {
   'sdxl-refiner': 'Stable Diffusion XL Refiner',
 };

+export const MODEL_TYPE_SHORT_MAP = {
+  'sd-1': 'SD1',
+  'sd-2': 'SD2',
+  sdxl: 'SDXL',
+  'sdxl-refiner': 'SDXLR',
+};
+
 export const clipSkipMap = {
   'sd-1': {
     maxClip: 12,
@@ -23,3 +32,12 @@ export const clipSkipMap = {
     markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
   },
 };
+
+type LoRAModelFormatMap = {
+  [key in components['schemas']['LoRAModelFormat']]: string;
+};
+
+export const LORA_MODEL_FORMAT_MAP: LoRAModelFormatMap = {
+  lycoris: 'LyCORIS',
+  diffusers: 'Diffusers',
+};
ParamSDXLConcatButton.tsx (new file)
@@ -0,0 +1,43 @@
+import { RootState } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import IAIIconButton from 'common/components/IAIIconButton';
+import { FaLink } from 'react-icons/fa';
+import { setShouldConcatSDXLStylePrompt } from '../store/sdxlSlice';
+
+export default function ParamSDXLConcatButton() {
+  const shouldConcatSDXLStylePrompt = useAppSelector(
+    (state: RootState) => state.sdxl.shouldConcatSDXLStylePrompt
+  );
+
+  const shouldPinParametersPanel = useAppSelector(
+    (state: RootState) => state.ui.shouldPinParametersPanel
+  );
+
+  const dispatch = useAppDispatch();
+
+  const handleShouldConcatPromptChange = () => {
+    dispatch(setShouldConcatSDXLStylePrompt(!shouldConcatSDXLStylePrompt));
+  };
+
+  return (
+    <IAIIconButton
+      aria-label="Concat"
+      tooltip="Concatenates Basic Prompt with Style (Recommended)"
+      variant="outline"
+      isChecked={shouldConcatSDXLStylePrompt}
+      onClick={handleShouldConcatPromptChange}
+      icon={<FaLink />}
+      size="xs"
+      sx={{
+        position: 'absolute',
+        insetInlineEnd: 1,
+        top: shouldPinParametersPanel ? 12 : 20,
+        border: 'none',
+        color: shouldConcatSDXLStylePrompt ? 'accent.500' : 'base.500',
+        _hover: {
+          bg: 'none',
+        },
+      }}
+    ></IAIIconButton>
+  );
+}
ParamSDXLConcatPrompt.tsx (deleted file)
@@ -1,33 +0,0 @@
-import { Box } from '@chakra-ui/react';
-import { RootState } from 'app/store/store';
-import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import IAISwitch from 'common/components/IAISwitch';
-import { ChangeEvent } from 'react';
-import { setShouldConcatSDXLStylePrompt } from '../store/sdxlSlice';
-
-export default function ParamSDXLConcatPrompt() {
-  const shouldConcatSDXLStylePrompt = useAppSelector(
-    (state: RootState) => state.sdxl.shouldConcatSDXLStylePrompt
-  );
-
-  const dispatch = useAppDispatch();
-
-  const handleShouldConcatPromptChange = (e: ChangeEvent<HTMLInputElement>) => {
-    dispatch(setShouldConcatSDXLStylePrompt(e.target.checked));
-  };
-
-  return (
-    <Box
-      sx={{
-        px: 2,
-      }}
-    >
-      <IAISwitch
-        label="Concat Style Prompt"
-        tooltip="Concatenates Basic Prompt with Style (Recommended)"
-        isChecked={shouldConcatSDXLStylePrompt}
-        onChange={handleShouldConcatPromptChange}
-      />
-    </Box>
-  );
-}
ParamSDXLPromptArea.tsx (new file)
@@ -0,0 +1,61 @@
+import { Box, Flex } from '@chakra-ui/react';
+import { RootState } from 'app/store/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
+import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
+import { AnimatePresence } from 'framer-motion';
+import ParamSDXLConcatButton from './ParamSDXLConcatButton';
+import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
+import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
+import SDXLConcatLink from './SDXLConcatLink';
+
+export default function ParamSDXLPromptArea() {
+  const shouldPinParametersPanel = useAppSelector(
+    (state: RootState) => state.ui.shouldPinParametersPanel
+  );
+
+  const shouldConcatSDXLStylePrompt = useAppSelector(
+    (state: RootState) => state.sdxl.shouldConcatSDXLStylePrompt
+  );
+
+  return (
+    <Flex
+      sx={{
+        flexDirection: 'column',
+        gap: 2,
+      }}
+    >
+      <AnimatePresence>
+        {shouldConcatSDXLStylePrompt && (
+          <Box
+            sx={{
+              position: 'absolute',
+              w: 'full',
+              top: shouldPinParametersPanel ? '119px' : '175px',
+            }}
+          >
+            <SDXLConcatLink />
+          </Box>
+        )}
+      </AnimatePresence>
+      <AnimatePresence>
+        {shouldConcatSDXLStylePrompt && (
+          <Box
+            sx={{
+              position: 'absolute',
+              w: 'full',
+              top: shouldPinParametersPanel ? '263px' : '319px',
+            }}
+          >
+            <SDXLConcatLink />
+          </Box>
+        )}
+      </AnimatePresence>
+      <ParamPositiveConditioning />
+      <ParamSDXLConcatButton />
+      <ParamSDXLPositiveStyleConditioning />
+      <ParamNegativeConditioning />
+      <ParamSDXLNegativeStyleConditioning />
+    </Flex>
+  );
+}
SDXLConcatLink.tsx (new file)
@@ -0,0 +1,109 @@
+import { Box, Flex } from '@chakra-ui/react';
+import { CSSObject } from '@emotion/react';
+import { motion } from 'framer-motion';
+import { FaLink } from 'react-icons/fa';
+
+const sharedConcatLinkStyle: CSSObject = {
+  position: 'absolute',
+  bg: 'none',
+  w: 'full',
+  minH: 2,
+  borderRadius: 0,
+  borderLeft: 'none',
+  borderRight: 'none',
+  zIndex: 2,
+  maskImage:
+    'radial-gradient(circle at center, black, black 65%, black 30%, black 15%, transparent)',
+};
+
+export default function SDXLConcatLink() {
+  return (
+    <Flex
+      sx={{
+        h: 0.5,
+        placeContent: 'center',
+        gap: 2,
+        flexDirection: 'column',
+      }}
+    >
+      <Box
+        as={motion.div}
+        initial={{
+          scaleX: 0,
+          borderWidth: 0,
+          display: 'none',
+        }}
+        animate={{
+          display: ['block', 'block', 'block', 'none'],
+          scaleX: [0, 0.25, 0.5, 1],
+          borderWidth: [0, 3, 3, 0],
+          transition: { duration: 0.37, times: [0, 0.25, 0.5, 1] },
+        }}
+        sx={{
+          top: '1px',
+          borderTop: 'none',
+          borderColor: 'base.400',
+          zIndex: 2,
+          ...sharedConcatLinkStyle,
+          _dark: {
+            borderColor: 'accent.500',
+          },
+        }}
+      />
+      <Box
+        as={motion.div}
+        initial={{
+          opacity: 0,
+          scale: 0,
+        }}
+        animate={{
+          opacity: [0, 1, 1, 1],
+          scale: [0, 0.75, 1.5, 1],
+          transition: { duration: 0.42, times: [0, 0.25, 0.5, 1] },
+        }}
+        exit={{
+          opacity: 0,
+          scale: 0,
+        }}
+        sx={{
+          zIndex: 3,
+          position: 'absolute',
+          left: '48%',
+          top: '3px',
+          p: 1,
+          borderRadius: 4,
+          bg: 'accent.400',
+          color: 'base.50',
+          _dark: {
+            bg: 'accent.500',
+          },
+        }}
+      >
+        <FaLink size={12} />
+      </Box>
+      <Box
+        as={motion.div}
+        initial={{
+          scaleX: 0,
+          borderWidth: 0,
+          display: 'none',
+        }}
+        animate={{
+          display: ['block', 'block', 'block', 'none'],
+          scaleX: [0, 0.25, 0.5, 1],
+          borderWidth: [0, 3, 3, 0],
+          transition: { duration: 0.37, times: [0, 0.25, 0.5, 1] },
+        }}
+        sx={{
+          top: '17px',
+          borderBottom: 'none',
+          borderColor: 'base.400',
+          ...sharedConcatLinkStyle,
+          _dark: {
+            borderColor: 'accent.500',
+          },
+        }}
+      />
+    </Flex>
+  );
+}
SDXLImageToImageTabParameters.tsx
@@ -1,34 +1,14 @@
-import { Flex } from '@chakra-ui/react';
 import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
-import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
-import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
 import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
 import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
-import ParamSDXLConcatPrompt from './ParamSDXLConcatPrompt';
-import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
-import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
+import ParamSDXLPromptArea from './ParamSDXLPromptArea';
 import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
 import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters';

 const SDXLImageToImageTabParameters = () => {
   return (
     <>
-      <Flex
-        sx={{
-          flexDirection: 'column',
-          gap: 2,
-          p: 2,
-          borderRadius: 4,
-          bg: 'base.100',
-          _dark: { bg: 'base.850' },
-        }}
-      >
-        <ParamPositiveConditioning />
-        <ParamSDXLPositiveStyleConditioning />
-        <ParamNegativeConditioning />
-        <ParamSDXLNegativeStyleConditioning />
-        <ParamSDXLConcatPrompt />
-      </Flex>
+      <ParamSDXLPromptArea />
       <ProcessButtons />
       <SDXLImageToImageTabCoreParameters />
       <ParamSDXLRefinerCollapse />
SDXLTextToImageTabParameters.tsx
@@ -1,34 +1,14 @@
-import { Flex } from '@chakra-ui/react';
 import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
-import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
-import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
 import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
 import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
 import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
-import ParamSDXLConcatPrompt from './ParamSDXLConcatPrompt';
-import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
-import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
+import ParamSDXLPromptArea from './ParamSDXLPromptArea';
 import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';

 const SDXLTextToImageTabParameters = () => {
   return (
     <>
-      <Flex
-        sx={{
-          flexDirection: 'column',
-          gap: 2,
-          p: 2,
-          borderRadius: 4,
-          bg: 'base.100',
-          _dark: { bg: 'base.850' },
-        }}
-      >
-        <ParamPositiveConditioning />
-        <ParamSDXLPositiveStyleConditioning />
-        <ParamNegativeConditioning />
-        <ParamSDXLNegativeStyleConditioning />
-        <ParamSDXLConcatPrompt />
-      </Flex>
+      <ParamSDXLPromptArea />
       <ProcessButtons />
       <TextToImageTabCoreParameters />
       <ParamSDXLRefinerCollapse />
@@ -227,7 +227,8 @@ const InvokeTabs = () => {
             id="gallery"
             order={3}
             defaultSize={
-              galleryMinSizePct > DEFAULT_GALLERY_PCT
+              galleryMinSizePct > DEFAULT_GALLERY_PCT &&
+              galleryMinSizePct < 100 // prevent this error https://github.com/bvaughn/react-resizable-panels/blob/main/packages/react-resizable-panels/src/Panel.ts#L96
                 ? galleryMinSizePct
                 : DEFAULT_GALLERY_PCT
             }
@@ -2,20 +2,18 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
 import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
 import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse';
 import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
-import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
-import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
 import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
 import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
 import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
 // import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
+import ParamPromptArea from 'features/parameters/components/Parameters/Prompt/ParamPromptArea';
 import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
 import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters';

 const ImageToImageTabParameters = () => {
   return (
     <>
-      <ParamPositiveConditioning />
-      <ParamNegativeConditioning />
+      <ParamPromptArea />
       <ProcessButtons />
       <ImageToImageTabCoreParameters />
       <ParamControlNetCollapse />
@@ -3,20 +3,31 @@ import { Flex, Text } from '@chakra-ui/react';
 import { useState } from 'react';
 import {
   MainModelConfigEntity,
+  DiffusersModelConfigEntity,
+  LoRAModelConfigEntity,
   useGetMainModelsQuery,
+  useGetLoRAModelsQuery,
 } from 'services/api/endpoints/models';
 import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
 import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
+import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
 import ModelList from './ModelManagerPanel/ModelList';
 import { ALL_BASE_MODELS } from 'services/api/constants';
 
 export default function ModelManagerPanel() {
   const [selectedModelId, setSelectedModelId] = useState<string>();
-  const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, {
+  const { mainModel } = useGetMainModelsQuery(ALL_BASE_MODELS, {
     selectFromResult: ({ data }) => ({
-      model: selectedModelId ? data?.entities[selectedModelId] : undefined,
+      mainModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
     }),
   });
+  const { loraModel } = useGetLoRAModelsQuery(undefined, {
+    selectFromResult: ({ data }) => ({
+      loraModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
+    }),
+  });
+
+  const model = mainModel ? mainModel : loraModel;
 
   return (
     <Flex sx={{ gap: 8, w: 'full', h: 'full' }}>
@@ -30,7 +41,7 @@ export default function ModelManagerPanel() {
 }
 
 type ModelEditProps = {
-  model: MainModelConfigEntity | undefined;
+  model: MainModelConfigEntity | LoRAModelConfigEntity | undefined;
 };
 
 const ModelEdit = (props: ModelEditProps) => {
@@ -41,7 +52,16 @@ const ModelEdit = (props: ModelEditProps) => {
   }
 
   if (model?.model_format === 'diffusers') {
-    return <DiffusersModelEdit key={model.id} model={model} />;
+    return (
+      <DiffusersModelEdit
+        key={model.id}
+        model={model as DiffusersModelConfigEntity}
+      />
+    );
+  }
+
+  if (model?.model_type === 'lora') {
+    return <LoRAModelEdit key={model.id} model={model} />;
   }
 
   return (
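The panel resolves the selected id against two RTK Query caches and prefers the main-model hit. A stripped-down sketch of that fallback with the hooks and entity types reduced to plain data (shapes and ids here are simplified for illustration):

```ts
// Simplified model of the two-cache lookup in ModelManagerPanel.
type Entity = { id: string; model_name: string };
type Cache = Record<string, Entity | undefined>;

const pickSelected = (
  selectedModelId: string | undefined,
  mainEntities: Cache,
  loraEntities: Cache
): Entity | undefined => {
  const mainModel = selectedModelId ? mainEntities[selectedModelId] : undefined;
  const loraModel = selectedModelId ? loraEntities[selectedModelId] : undefined;
  // Same precedence as the component: the main model wins when both resolve.
  return mainModel ? mainModel : loraModel;
};

// Usage with illustrative ids:
const main: Cache = { 'sd-1/main/foo': { id: 'sd-1/main/foo', model_name: 'foo' } };
const lora: Cache = { 'sd-1/lora/bar': { id: 'sd-1/lora/bar', model_name: 'bar' } };
console.log(pickSelected('sd-1/lora/bar', main, lora)?.model_name); // 'bar'
```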
@@ -0,0 +1,137 @@
+import { Divider, Flex, Text } from '@chakra-ui/react';
+import { useForm } from '@mantine/form';
+import { makeToast } from 'features/system/util/makeToast';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import IAIButton from 'common/components/IAIButton';
+import IAIMantineTextInput from 'common/components/IAIMantineInput';
+import { selectIsBusy } from 'features/system/store/systemSelectors';
+import { addToast } from 'features/system/store/systemSlice';
+import { useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import {
+  LORA_MODEL_FORMAT_MAP,
+  MODEL_TYPE_MAP,
+} from 'features/parameters/types/constants';
+import {
+  LoRAModelConfigEntity,
+  useUpdateLoRAModelsMutation,
+} from 'services/api/endpoints/models';
+import { LoRAModelConfig } from 'services/api/types';
+import BaseModelSelect from '../shared/BaseModelSelect';
+
+type LoRAModelEditProps = {
+  model: LoRAModelConfigEntity;
+};
+
+export default function LoRAModelEdit(props: LoRAModelEditProps) {
+  const isBusy = useAppSelector(selectIsBusy);
+
+  const { model } = props;
+
+  const [updateLoRAModel, { isLoading }] = useUpdateLoRAModelsMutation();
+
+  const dispatch = useAppDispatch();
+  const { t } = useTranslation();
+
+  const loraEditForm = useForm<LoRAModelConfig>({
+    initialValues: {
+      model_name: model.model_name ? model.model_name : '',
+      base_model: model.base_model,
+      model_type: 'lora',
+      path: model.path ? model.path : '',
+      description: model.description ? model.description : '',
+      model_format: model.model_format,
+    },
+    validate: {
+      path: (value) =>
+        value.trim().length === 0 ? 'Must provide a path' : null,
+    },
+  });
+
+  const editModelFormSubmitHandler = useCallback(
+    (values: LoRAModelConfig) => {
+      const responseBody = {
+        base_model: model.base_model,
+        model_name: model.model_name,
+        body: values,
+      };
+
+      updateLoRAModel(responseBody)
+        .unwrap()
+        .then((payload) => {
+          loraEditForm.setValues(payload as LoRAModelConfig);
+          dispatch(
+            addToast(
+              makeToast({
+                title: t('modelManager.modelUpdated'),
+                status: 'success',
+              })
+            )
+          );
+        })
+        .catch((_) => {
+          loraEditForm.reset();
+          dispatch(
+            addToast(
+              makeToast({
+                title: t('modelManager.modelUpdateFailed'),
+                status: 'error',
+              })
+            )
+          );
+        });
+    },
+    [
+      dispatch,
+      loraEditForm,
+      model.base_model,
+      model.model_name,
+      t,
+      updateLoRAModel,
+    ]
+  );
+
+  return (
+    <Flex flexDirection="column" rowGap={4} width="100%">
+      <Flex flexDirection="column">
+        <Text fontSize="lg" fontWeight="bold">
+          {model.model_name}
+        </Text>
+        <Text fontSize="sm" color="base.400">
+          {MODEL_TYPE_MAP[model.base_model]} Model ⋅{' '}
+          {LORA_MODEL_FORMAT_MAP[model.model_format]} format
+        </Text>
+      </Flex>
+      <Divider />
+
+      <form
+        onSubmit={loraEditForm.onSubmit((values) =>
+          editModelFormSubmitHandler(values)
+        )}
+      >
+        <Flex flexDirection="column" overflowY="scroll" gap={4}>
+          <IAIMantineTextInput
+            label={t('modelManager.name')}
+            {...loraEditForm.getInputProps('model_name')}
+          />
+          <IAIMantineTextInput
+            label={t('modelManager.description')}
+            {...loraEditForm.getInputProps('description')}
+          />
+          <BaseModelSelect {...loraEditForm.getInputProps('base_model')} />
+          <IAIMantineTextInput
+            label={t('modelManager.modelLocation')}
+            {...loraEditForm.getInputProps('path')}
+          />
+          <IAIButton
+            type="submit"
+            isDisabled={isBusy || isLoading}
+            isLoading={isLoading}
+          >
+            {t('modelManager.updateModel')}
+          </IAIButton>
+        </Flex>
+      </form>
+    </Flex>
+  );
+}
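The new editor is a straightforward @mantine/form flow: `initialValues` seeds the fields from the entity, `validate.path` blocks submission on an empty path, and `getInputProps` wires each input. A minimal standalone sketch of just that validation flow (the field names match the component; the hook name and everything else is illustrative):

```ts
import { useForm } from '@mantine/form';

type PathForm = { model_name: string; path: string };

// Minimal form mirroring LoRAModelEdit's rule: an empty path is rejected
// before the submit handler (and thus the PATCH request) ever runs.
export function usePathForm(initial: PathForm) {
  const form = useForm<PathForm>({
    initialValues: initial,
    validate: {
      path: (value) =>
        value.trim().length === 0 ? 'Must provide a path' : null,
    },
  });
  // Consumers spread form.getInputProps('path') onto an input and wrap the
  // submit handler with form.onSubmit(...), exactly as the component does.
  return form;
}
```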
@@ -11,6 +11,8 @@ import {
   OnnxModelConfigEntity,
   useGetMainModelsQuery,
   useGetOnnxModelsQuery,
+  useGetLoRAModelsQuery,
+  LoRAModelConfigEntity,
 } from 'services/api/endpoints/models';
 import ModelListItem from './ModelListItem';
 import { ALL_BASE_MODELS } from 'services/api/constants';
@@ -22,22 +24,42 @@ type ModelListProps = {
 
 type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
 
+type ModelType = 'main' | 'lora';
+
+type CombinedModelFormat = ModelFormat | 'lora';
+
 const ModelList = (props: ModelListProps) => {
   const { selectedModelId, setSelectedModelId } = props;
   const { t } = useTranslation();
   const [nameFilter, setNameFilter] = useState<string>('');
   const [modelFormatFilter, setModelFormatFilter] =
-    useState<ModelFormat>('images');
+    useState<CombinedModelFormat>('images');
 
   const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
     selectFromResult: ({ data }) => ({
-      filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
+      filteredDiffusersModels: modelsFilter(
+        data,
+        'main',
+        'diffusers',
+        nameFilter
+      ),
     }),
   });
 
   const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
     selectFromResult: ({ data }) => ({
-      filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
+      filteredCheckpointModels: modelsFilter(
+        data,
+        'main',
+        'checkpoint',
+        nameFilter
+      ),
+    }),
+  });
+
+  const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
+    selectFromResult: ({ data }) => ({
+      filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
     }),
   });
 
@@ -89,6 +111,13 @@ const ModelList = (props: ModelListProps) => {
           >
             {t('modelManager.oliveModels')}
           </IAIButton>
+          <IAIButton
+            size="sm"
+            onClick={() => setModelFormatFilter('lora')}
+            isChecked={modelFormatFilter === 'lora'}
+          >
+            {t('modelManager.loraModels')}
+          </IAIButton>
         </ButtonGroup>
 
         <IAIInput
@@ -175,6 +204,24 @@ const ModelList = (props: ModelListProps) => {
               </Flex>
             </StyledModelContainer>
           )}
+          {['images', 'lora'].includes(modelFormatFilter) &&
+            filteredLoraModels.length > 0 && (
+              <StyledModelContainer>
+                <Flex sx={{ gap: 2, flexDir: 'column' }}>
+                  <Text variant="subtext" fontSize="sm">
+                    LoRAs
+                  </Text>
+                  {filteredLoraModels.map((model) => (
+                    <ModelListItem
+                      key={model.id}
+                      model={model}
+                      isSelected={selectedModelId === model.id}
+                      setSelectedModelId={setSelectedModelId}
+                    />
+                  ))}
+                </Flex>
+              </StyledModelContainer>
+            )}
         </Flex>
       </Flex>
     </Flex>
@@ -183,15 +230,18 @@ const ModelList = (props: ModelListProps) => {
 
 export default ModelList;
 
-const modelsFilter = (
-  data:
-    | EntityState<MainModelConfigEntity>
-    | EntityState<OnnxModelConfigEntity>
-    | undefined,
-  model_format: ModelFormat,
+const modelsFilter = <
+  T extends
+    | MainModelConfigEntity
+    | LoRAModelConfigEntity
+    | OnnxModelConfigEntity
+>(
+  data: EntityState<T> | undefined,
+  model_type: ModelType,
+  model_format: ModelFormat | undefined,
   nameFilter: string
 ) => {
-  const filteredModels: MainModelConfigEntity[] = [];
+  const filteredModels: T[] = [];
   forEach(data?.entities, (model) => {
     if (!model) {
       return;
@@ -201,9 +251,11 @@ const modelsFilter = (
       .toLowerCase()
       .includes(nameFilter.toLowerCase());
 
-    const matchesFormat = model.model_format === model_format;
+    const matchesFormat =
+      model_format === undefined || model.model_format === model_format;
+    const matchesType = model.model_type === model_type;
 
-    if (matchesFilter && matchesFormat) {
+    if (matchesFilter && matchesFormat && matchesType) {
       filteredModels.push(model);
     }
   });
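Rewriting `modelsFilter` as a generic is what lets one helper serve the main, ONNX, and LoRA lists while each call site keeps its own element type. A self-contained sketch of the idea, with `EntityState` and the entity unions reduced to the fields the filter actually reads:

```ts
// Reduced shapes: the real EntityState comes from @reduxjs/toolkit and the
// entity types from the generated API schema.
type EntityState<T> = { entities: Record<string, T | undefined> };
type BaseEntity = { model_name: string; model_type: string; model_format?: string };

const modelsFilter = <T extends BaseEntity>(
  data: EntityState<T> | undefined,
  model_type: string,
  model_format: string | undefined,
  nameFilter: string
): T[] => {
  const filtered: T[] = [];
  for (const model of Object.values(data?.entities ?? {})) {
    if (!model) continue;
    const matchesFilter = model.model_name
      .toLowerCase()
      .includes(nameFilter.toLowerCase());
    // `undefined` acts as a format wildcard; the LoRA call site relies on
    // this, since LoRA entities are not filtered by format.
    const matchesFormat =
      model_format === undefined || model.model_format === model_format;
    const matchesType = model.model_type === model_type;
    if (matchesFilter && matchesFormat && matchesType) {
      filtered.push(model);
    }
  }
  return filtered;
};

// T is inferred from `data`, so the result is typed per call site:
type LoRAEntity = BaseEntity & { model_type: 'lora' };
declare const loraState: EntityState<LoRAEntity>;
const loras: LoRAEntity[] = modelsFilter(loraState, 'lora', undefined, 'my');
```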
@@ -9,29 +9,26 @@ import { selectIsBusy } from 'features/system/store/systemSelectors';
 import { addToast } from 'features/system/store/systemSlice';
 import { useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
+import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
 import {
   MainModelConfigEntity,
+  LoRAModelConfigEntity,
   useDeleteMainModelsMutation,
+  useDeleteLoRAModelsMutation,
 } from 'services/api/endpoints/models';
 
 type ModelListItemProps = {
-  model: MainModelConfigEntity;
+  model: MainModelConfigEntity | LoRAModelConfigEntity;
   isSelected: boolean;
   setSelectedModelId: (v: string | undefined) => void;
 };
 
-const modelBaseTypeMap = {
-  'sd-1': 'SD1',
-  'sd-2': 'SD2',
-  sdxl: 'SDXL',
-  'sdxl-refiner': 'SDXLR',
-};
-
 export default function ModelListItem(props: ModelListItemProps) {
   const isBusy = useAppSelector(selectIsBusy);
   const { t } = useTranslation();
   const dispatch = useAppDispatch();
   const [deleteMainModel] = useDeleteMainModelsMutation();
+  const [deleteLoRAModel] = useDeleteLoRAModelsMutation();
 
   const { model, isSelected, setSelectedModelId } = props;
 
@@ -40,7 +37,10 @@ export default function ModelListItem(props: ModelListItemProps) {
   }, [model.id, setSelectedModelId]);
 
   const handleModelDelete = useCallback(() => {
-    deleteMainModel(model)
+    const method = { main: deleteMainModel, lora: deleteLoRAModel }[
+      model.model_type
+    ];
+    method(model)
       .unwrap()
       .then((_) => {
         dispatch(
@@ -60,14 +60,21 @@ export default function ModelListItem(props: ModelListItemProps) {
               title: `${t('modelManager.modelDeleteFailed')}: ${
                 model.model_name
               }`,
-              status: 'success',
+              status: 'error',
             })
           )
         );
       }
     });
     setSelectedModelId(undefined);
-  }, [deleteMainModel, model, setSelectedModelId, dispatch, t]);
+  }, [
+    deleteMainModel,
+    deleteLoRAModel,
+    model,
+    setSelectedModelId,
+    dispatch,
+    t,
+  ]);
 
   return (
     <Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>
@@ -100,8 +107,8 @@ export default function ModelListItem(props: ModelListItemProps) {
       <Flex gap={4} alignItems="center">
         <Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
           {
-            modelBaseTypeMap[
-              model.base_model as keyof typeof modelBaseTypeMap
+            MODEL_TYPE_SHORT_MAP[
+              model.base_model as keyof typeof MODEL_TYPE_SHORT_MAP
             ]
           }
         </Badge>
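Two details in this hunk: the `status` change from `'success'` to `'error'` fixes a toast that previously reported a failed delete as a success, and `handleModelDelete` now picks its mutation from an object literal keyed by `model_type`. The lookup stays type-safe because the key is a closed union. The pattern in isolation:

```ts
// Dispatch-by-key: select a handler from a map whose keys cover a string union.
type ModelType = 'main' | 'lora';

const deleteMainModel = (name: string) => console.log(`delete main: ${name}`);
const deleteLoRAModel = (name: string) => console.log(`delete lora: ${name}`);

const handleModelDelete = (modelType: ModelType, name: string) => {
  // Indexing with a ModelType key always resolves: both union members exist.
  const method = { main: deleteMainModel, lora: deleteLoRAModel }[modelType];
  method(name);
};

handleModelDelete('lora', 'my-lora'); // delete lora: my-lora
```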
@@ -2,20 +2,18 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
 import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
 import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse';
 import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
-import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
-import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
 import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
 import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
 import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
 // import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
 import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
+import ParamPromptArea from '../../../../parameters/components/Parameters/Prompt/ParamPromptArea';
 import TextToImageTabCoreParameters from './TextToImageTabCoreParameters';
 
 const TextToImageTabParameters = () => {
   return (
     <>
-      <ParamPositiveConditioning />
-      <ParamNegativeConditioning />
+      <ParamPromptArea />
       <ProcessButtons />
       <TextToImageTabCoreParameters />
       <ParamControlNetCollapse />
@@ -4,18 +4,16 @@ import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Adv
 import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
 import ParamSeamCorrectionCollapse from 'features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse';
 import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
-import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
-import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
 import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
 // import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
+import ParamPromptArea from 'features/parameters/components/Parameters/Prompt/ParamPromptArea';
 import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
 import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
 
 const UnifiedCanvasParameters = () => {
   return (
     <>
-      <ParamPositiveConditioning />
-      <ParamNegativeConditioning />
+      <ParamPromptArea />
       <ProcessButtons />
       <UnifiedCanvasCoreParameters />
       <ParamControlNetCollapse />
@@ -57,9 +57,17 @@ type UpdateMainModelArg = {
   body: MainModelConfig;
 };
 
+type UpdateLoRAModelArg = {
+  base_model: BaseModelType;
+  model_name: string;
+  body: LoRAModelConfig;
+};
+
 type UpdateMainModelResponse =
   paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
 
+type UpdateLoRAModelResponse = UpdateMainModelResponse;
+
 type DeleteMainModelArg = {
   base_model: BaseModelType;
   model_name: string;
@@ -68,6 +76,10 @@ type DeleteMainModelArg = {
 
 type DeleteMainModelResponse = void;
 
+type DeleteLoRAModelArg = DeleteMainModelArg;
+
+type DeleteLoRAModelResponse = void;
+
 type ConvertMainModelArg = {
   base_model: BaseModelType;
   model_name: string;
@@ -373,6 +385,31 @@ export const modelsApi = api.injectEndpoints({
         );
       },
     }),
+    updateLoRAModels: build.mutation<
+      UpdateLoRAModelResponse,
+      UpdateLoRAModelArg
+    >({
+      query: ({ base_model, model_name, body }) => {
+        return {
+          url: `models/${base_model}/lora/${model_name}`,
+          method: 'PATCH',
+          body: body,
+        };
+      },
+      invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
+    }),
+    deleteLoRAModels: build.mutation<
+      DeleteLoRAModelResponse,
+      DeleteLoRAModelArg
+    >({
+      query: ({ base_model, model_name }) => {
+        return {
+          url: `models/${base_model}/lora/${model_name}`,
+          method: 'DELETE',
+        };
+      },
+      invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
+    }),
     getControlNetModels: build.query<
       EntityState<ControlNetModelConfigEntity>,
       void
@@ -521,6 +558,8 @@ export const {
   useAddMainModelsMutation,
   useConvertMainModelsMutation,
   useMergeMainModelsMutation,
+  useDeleteLoRAModelsMutation,
+  useUpdateLoRAModelsMutation,
   useSyncModelsMutation,
   useGetModelsInFolderQuery,
   useGetCheckpointConfigsQuery,
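The two new endpoints are standard RTK Query mutations: `query` maps the argument onto the REST route, and `invalidatesTags` forces any cached LoRA list to refetch after a write. A minimal self-contained sketch of the same structure; the slice below is hypothetical, while InvokeAI's real `api`, `LIST_TAG`, and response types are defined elsewhere in this file:

```ts
import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';

// Hypothetical slice for illustration only.
const api = createApi({
  baseQuery: fetchBaseQuery({ baseUrl: '/api/v1/' }),
  tagTypes: ['LoRAModel'],
  endpoints: (build) => ({
    updateLoRAModel: build.mutation<
      unknown,
      { base_model: string; model_name: string; body: Record<string, unknown> }
    >({
      query: ({ base_model, model_name, body }) => ({
        url: `models/${base_model}/lora/${model_name}`,
        method: 'PATCH',
        body,
      }),
      // Invalidating the list tag makes subscribed LoRA queries refetch.
      invalidatesTags: [{ type: 'LoRAModel', id: 'LIST' }],
    }),
  }),
});

export const { useUpdateLoRAModelMutation } = api;
```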
@@ -5857,11 +5857,11 @@ export type components = {
       image?: components["schemas"]["ImageField"];
     };
     /**
-     * StableDiffusion1ModelFormat
+     * StableDiffusion2ModelFormat
      * @description An enumeration.
      * @enum {string}
      */
-    StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
+    StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
     /**
      * StableDiffusionXLModelFormat
      * @description An enumeration.
@@ -5880,6 +5880,12 @@ export type components = {
      * @enum {string}
      */
     ControlNetModelFormat: "checkpoint" | "diffusers";
+    /**
+     * StableDiffusion1ModelFormat
+     * @description An enumeration.
+     * @enum {string}
+     */
+    StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
   };
   responses: never;
   parameters: never;
@@ -43,8 +43,13 @@ export type ControlField = components['schemas']['ControlField'];
 // Model Configs
 export type LoRAModelConfig = components['schemas']['LoRAModelConfig'];
 export type VaeModelConfig = components['schemas']['VaeModelConfig'];
+export type ControlNetModelCheckpointConfig =
+  components['schemas']['ControlNetModelCheckpointConfig'];
+export type ControlNetModelDiffusersConfig =
+  components['schemas']['ControlNetModelDiffusersConfig'];
 export type ControlNetModelConfig =
-  components['schemas']['ControlNetModelConfig'];
+  | ControlNetModelCheckpointConfig
+  | ControlNetModelDiffusersConfig;
 export type TextualInversionModelConfig =
   components['schemas']['TextualInversionModelConfig'];
 export type DiffusersModelConfig =
@@ -13,6 +13,15 @@ const invokeAI = defineStyle((props) => ({
       var(--invokeai-colors-base-200) 70%,
       var(--invokeai-colors-base-200) 100%)`,
   },
+  _disabled: {
+    '::-webkit-resizer': {
+      backgroundImage: `linear-gradient(135deg,
+        var(--invokeai-colors-base-50) 0%,
+        var(--invokeai-colors-base-50) 70%,
+        var(--invokeai-colors-base-200) 70%,
+        var(--invokeai-colors-base-200) 100%)`,
+    },
+  },
   _dark: {
     '::-webkit-resizer': {
       backgroundImage: `linear-gradient(135deg,
@@ -21,6 +30,15 @@ const invokeAI = defineStyle((props) => ({
         var(--invokeai-colors-base-800) 70%,
         var(--invokeai-colors-base-800) 100%)`,
     },
+    _disabled: {
+      '::-webkit-resizer': {
+        backgroundImage: `linear-gradient(135deg,
+          var(--invokeai-colors-base-900) 0%,
+          var(--invokeai-colors-base-900) 70%,
+          var(--invokeai-colors-base-800) 70%,
+          var(--invokeai-colors-base-800) 100%)`,
+      },
+    },
   },
 }));
 
@@ -1 +1 @@
-__version__ = "3.0.1rc1"
+__version__ = "3.0.1rc2"
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "InvokeAI"
 description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process"
-requires-python = ">=3.9, <3.11"
+requires-python = ">=3.9, <3.12"
 readme = { content-type = "text/markdown", file = "README.md" }
 keywords = ["stable-diffusion", "AI"]
 dynamic = ["version"]
@@ -32,16 +32,16 @@ classifiers = [
   'Topic :: Scientific/Engineering :: Image Processing',
 ]
 dependencies = [
-  "accelerate~=0.16",
+  "accelerate~=0.21.0",
   "albumentations",
   "click",
   "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
-  "compel==2.0.0",
+  "compel~=2.0.0",
   "controlnet-aux>=0.0.6",
   "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
   "datasets",
-  "diffusers[torch]~=0.18.1",
-  "dnspython==2.2.1",
+  "diffusers[torch]~=0.19.0",
+  "dnspython~=2.4.0",
   "dynamicprompts",
   "easing-functions",
   "einops",
@@ -54,13 +54,12 @@ dependencies = [
   "flask_cors==3.0.10",
   "flask_socketio==5.3.0",
   "flaskwebgui==1.0.3",
-  "gfpgan==1.3.8",
   "huggingface-hub>=0.11.1",
-  "invisible-watermark>=0.2.0", # needed to install SDXL base and refiner using their repo_ids
+  "invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
   "matplotlib", # needed for plotting of Penner easing functions
   "mediapipe", # needed for "mediapipeface" controlnet model
   "npyscreen",
-  "numpy<1.24",
+  "numpy==1.24.4",
   "omegaconf",
   "onnx",
   "onnxruntime",
@@ -68,25 +67,26 @@ dependencies = [
   "picklescan",
   "pillow",
   "prompt-toolkit",
-  "pympler==1.0.1",
+  "pydantic==1.10.10",
+  "pympler~=1.0.1",
   "pypatchmatch",
   'pyperclip',
   "pyreadline3",
-  "python-multipart==0.0.6",
-  "pytorch-lightning==1.7.7",
+  "python-multipart",
+  "pytorch-lightning",
   "realesrgan",
-  "requests==2.28.2",
+  "requests~=2.28.2",
   "rich~=13.3",
   "safetensors~=0.3.0",
-  "scikit-image>=0.19",
+  "scikit-image~=0.21.0",
   "send2trash",
-  "test-tube>=0.7.5",
-  "torch~=2.0.0",
-  "torchvision>=0.14.1",
-  "torchmetrics==0.11.4",
-  "torchsde==0.2.5",
+  "test-tube~=0.7.5",
+  "torch~=2.0.1",
+  "torchvision~=0.15.2",
+  "torchmetrics~=1.0.1",
+  "torchsde~=0.2.5",
   "transformers~=4.31.0",
-  "uvicorn[standard]==0.21.1",
+  "uvicorn[standard]~=0.21.1",
   "windows-curses; sys_platform=='win32'",
 ]
 
@@ -100,7 +100,7 @@ dependencies = [
 "dev" = [
   "pudb",
 ]
-"test" = ["pytest>6.0.0", "pytest-cov"]
+"test" = ["pytest>6.0.0", "pytest-cov", "black"]
 "xformers" = [
   "xformers~=0.0.19; sys_platform!='darwin'",
   "triton; sys_platform=='linux'",
@@ -187,7 +187,7 @@ directory = "coverage/html"
 output = "coverage/index.xml"
 #=== End: PyTest and Coverage
 
-[flake8]
+[tool.flake8]
 max-line-length = 120
 
 [tool.black]
@@ -1,8 +1,16 @@
 #!/bin/env python
 
+import argparse
 import sys
 from pathlib import Path
 from invokeai.backend.model_management.model_probe import ModelProbe
 
-info = ModelProbe().probe(Path(sys.argv[1]))
+parser = argparse.ArgumentParser(description="Probe model type")
+parser.add_argument(
+    "model_path",
+    type=Path,
+)
+args = parser.parse_args()
+
+info = ModelProbe().probe(args.model_path)
 print(info)