Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00

Commit da751da3dd: Merge branch 'main' into feat/onnx
@@ -1 +1,2 @@
b3dccfaeb636599c02effc377cdd8a87d658256c
218b6d0546b990fc449c876fb99f44b50c4daa35
.github/workflows/style-checks.yml (vendored, new file, 27 lines)
@@ -0,0 +1,27 @@
name: Black # TODO: add isort and flake8 later

on:
  pull_request: {}
  push:
    branches: master
    tags: "*"

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3

      - name: Setup Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'

      - name: Install dependencies with pip
        run: |
          pip install --upgrade pip wheel
          pip install .[test]

      # - run: isort --check-only .
      - run: black --check .
      # - run: flake8
.gitignore (vendored, 1 line changed)
@@ -38,7 +38,6 @@ develop-eggs/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
.pre-commit-config.yaml (new file, 10 lines)
@@ -0,0 +1,10 @@
# See https://pre-commit.com/ for usage and config
repos:
  - repo: local
    hooks:
      - id: black
        name: black
        stages: [commit]
        language: system
        entry: black
        types: [python]
@@ -123,7 +123,7 @@ and go to http://localhost:9090.

### Command-Line Installation (for developers and users familiar with Terminals)

-You must have Python 3.9 or 3.10 installed on your machine. Earlier or
+You must have Python 3.9 through 3.11 installed on your machine. Earlier or
later versions are not supported.
Node.js also needs to be installed along with yarn (can be installed with
the command `npm install -g yarn` if needed)
@@ -40,10 +40,8 @@ experimental versions later.

this, open up a command-line window ("Terminal" on Linux and
Macintosh, "Command" or "Powershell" on Windows) and type `python
--version`. If Python is installed, it will print out the version
-number. If it is version `3.9.*` or `3.10.*`, you meet
-requirements. We do not recommend using Python 3.11 or higher,
-as not all the libraries that InvokeAI depends on work properly
-with this version.
+number. If it is version `3.9.*`, `3.10.*` or `3.11.*` you meet
+requirements.

!!! warning "What to do if you have an unsupported version"
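The documented check is easy to script; a minimal sketch (illustrative only, not part of InvokeAI) that enforces the same 3.9 through 3.11 window:

```python
import sys

# Illustrative only: enforce the 3.9-3.11 window the docs describe.
if not ((3, 9) <= sys.version_info[:2] <= (3, 11)):
    raise SystemExit(
        f"Python {sys.version.split()[0]} is unsupported; "
        "install Python 3.9, 3.10, or 3.11."
    )
```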
@@ -32,7 +32,7 @@ gaming):

* **Python**

-  version 3.9 or 3.10 (3.11 is not recommended).
+  version 3.9 through 3.11

* **CUDA Tools**
@@ -65,7 +65,7 @@ gaming):

To install InvokeAI with virtual environments and the PIP package
manager, please follow these steps:

-1. Please make sure you are using Python 3.9 or 3.10. The rest of the install
+1. Please make sure you are using Python 3.9 through 3.11. The rest of the install
   procedure depends on this and will not work with other versions:

   ```bash
@@ -9,16 +9,20 @@ cd $scriptdir

function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }

MINIMUM_PYTHON_VERSION=3.9.0
-MAXIMUM_PYTHON_VERSION=3.11.0
+MAXIMUM_PYTHON_VERSION=3.11.100
PYTHON=""
-for candidate in python3.10 python3.9 python3 python ; do
+for candidate in python3.11 python3.10 python3.9 python3 python ; do
    if ppath=`which $candidate`; then
        # when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
        # we check that this found executable can actually run
        if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi

        python_version=$($ppath -V | awk '{ print $2 }')
        if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then
-            if [ $(version $python_version) -lt $(version "$MAXIMUM_PYTHON_VERSION") ]; then
-                PYTHON=$ppath
-                break
-            fi
+            if [ $(version $python_version) -le $(version "$MAXIMUM_PYTHON_VERSION") ]; then
+                PYTHON=$ppath
+                break
+            fi
        fi
    fi
done
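The awk `version` helper packs each dotted component into a fixed-width integer so that version comparison becomes plain numeric comparison; that is also why the ceiling moves from `3.11.0` with `-lt` (which rejected every 3.11.x release) to `3.11.100` with `-le`. A sketch of the same packing in Python (illustrative, not shipped with the installer):

```python
def version_key(v: str) -> int:
    # Mirror awk's printf("%d%03d%03d%03d", $1,$2,$3,$4): the major
    # version prints as-is, each following field zero-padded to 3 digits.
    parts = (v.split(".") + ["0", "0", "0"])[:4]
    return int(parts[0] + "".join(f"{int(p):03d}" for p in parts[1:]))

assert version_key("3.11.4") > version_key("3.11.0")     # old `-lt 3.11.0` ceiling rejected all 3.11.x
assert version_key("3.11.4") <= version_key("3.11.100")  # new `-le 3.11.100` admits 3.11 patch releases
```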
@@ -141,15 +141,16 @@ class Installer:

        # upgrade pip in Python 3.9 environments
        if int(platform.python_version_tuple()[1]) == 9:

            from plumbum import FG, local

            pip = local[get_pip_from_venv(venv_dir)]
-            pip[ "install", "--upgrade", "pip"] & FG
+            pip["install", "--upgrade", "pip"] & FG

        return venv_dir

-    def install(self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None) -> None:
+    def install(
+        self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None
+    ) -> None:
        """
        Install the InvokeAI application into the given runtime path
@@ -175,7 +176,7 @@ class Installer:
        self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv, version=version)

        # install dependencies and the InvokeAI application
-        (extra_index_url,optional_modules) = get_torch_source() if not yes_to_all else (None,None)
+        (extra_index_url, optional_modules) = get_torch_source() if not yes_to_all else (None, None)
        self.instance.install(
            extra_index_url,
            optional_modules,
@@ -188,6 +189,7 @@ class Installer:
        # run through the configuration flow
        self.instance.configure()


class InvokeAiInstance:
    """
    Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory.
@@ -196,7 +198,6 @@ class InvokeAiInstance:
    """

    def __init__(self, runtime: Path, venv: Path, version: str) -> None:

        self.runtime = runtime
        self.venv = venv
        self.pip = get_pip_from_venv(venv)
@@ -312,7 +313,7 @@ class InvokeAiInstance:
            "install",
            "--require-virtualenv",
            "--use-pep517",
-            str(src)+(optional_modules if optional_modules else ''),
+            str(src) + (optional_modules if optional_modules else ""),
            "--find-links" if find_links is not None else None,
            find_links,
            "--extra-index-url" if extra_index_url is not None else None,
@@ -329,15 +330,15 @@ class InvokeAiInstance:

        # set sys.argv to a consistent state
        new_argv = [sys.argv[0]]
-        for i in range(1,len(sys.argv)):
+        for i in range(1, len(sys.argv)):
            el = sys.argv[i]
-            if el in ['-r','--root']:
+            if el in ["-r", "--root"]:
                new_argv.append(el)
-                new_argv.append(sys.argv[i+1])
-            elif el in ['-y','--yes','--yes-to-all']:
+                new_argv.append(sys.argv[i + 1])
+            elif el in ["-y", "--yes", "--yes-to-all"]:
                new_argv.append(el)
        sys.argv = new_argv


        import requests  # to catch download exceptions
        from messages import introduction
@@ -353,16 +354,16 @@ class InvokeAiInstance:
            invokeai_configure()
            succeeded = True
        except requests.exceptions.ConnectionError as e:
-            print(f'\nA network error was encountered during configuration and download: {str(e)}')
+            print(f"\nA network error was encountered during configuration and download: {str(e)}")
        except OSError as e:
-            print(f'\nAn OS error was encountered during configuration and download: {str(e)}')
+            print(f"\nAn OS error was encountered during configuration and download: {str(e)}")
        except Exception as e:
-            print(f'\nA problem was encountered during the configuration and download steps: {str(e)}')
+            print(f"\nA problem was encountered during the configuration and download steps: {str(e)}")
        finally:
            if not succeeded:
                print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"')
-                print('and choose option 7 to fix a broken install, optionally followed by option 5 to install models.')
-                print('Alternatively you can relaunch the installer.')
+                print("and choose option 7 to fix a broken install, optionally followed by option 5 to install models.")
+                print("Alternatively you can relaunch the installer.")

    def install_user_scripts(self):
        """
@@ -371,11 +372,11 @@ class InvokeAiInstance:

        ext = "bat" if OS == "Windows" else "sh"

-        #scripts = ['invoke', 'update']
-        scripts = ['invoke']
+        # scripts = ['invoke', 'update']
+        scripts = ["invoke"]

        for script in scripts:
-            src = Path(__file__).parent / '..' / "templates" / f"{script}.{ext}.in"
+            src = Path(__file__).parent / ".." / "templates" / f"{script}.{ext}.in"
            dest = self.runtime / f"{script}.{ext}"
            shutil.copy(src, dest)
            os.chmod(dest, 0o0755)
@@ -420,11 +421,7 @@ def set_sys_path(venv_path: Path) -> None:
    # filter out any paths in sys.path that may be system- or user-wide
    # but leave the temporary bootstrap virtualenv as it contains packages we
    # temporarily need at install time
-    sys.path = list(filter(
-        lambda p: not p.endswith("-packages")
-        or p.find(BOOTSTRAP_VENV_PREFIX) != -1,
-        sys.path
-    ))
+    sys.path = list(filter(lambda p: not p.endswith("-packages") or p.find(BOOTSTRAP_VENV_PREFIX) != -1, sys.path))

    # determine site-packages/lib directory location for the venv
    lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
@@ -433,7 +430,7 @@ def set_sys_path(venv_path: Path) -> None:
    sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve()))


-def get_torch_source() -> (Union[str, None],str):
+def get_torch_source() -> (Union[str, None], str):
    """
    Determine the extra index URL for pip to use for torch installation.
    This depends on the OS and the graphics accelerator in use.

@@ -461,9 +458,9 @@ def get_torch_source() -> (Union[str, None], str):
    elif device == "cpu":
        url = "https://download.pytorch.org/whl/cpu"

-    if device == 'cuda':
-        url = 'https://download.pytorch.org/whl/cu117'
-        optional_modules = '[xformers]'
+    if device == "cuda":
+        url = "https://download.pytorch.org/whl/cu117"
+        optional_modules = "[xformers]"

    # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
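Downstream, the `(extra_index_url, optional_modules)` pair drives the pip invocation shown earlier in this diff: the extras string widens the package spec, and the index URL is appended only when present. An illustrative sketch of that wiring (the values and the `src` spec are assumptions, not taken from a live run):

```python
# Assumed example values; on a plain-PyPI install both would be None.
extra_index_url = "https://download.pytorch.org/whl/cu117"
optional_modules = "[xformers]"

src = "InvokeAI"  # hypothetical package spec for illustration
args = ["install", str(src) + (optional_modules if optional_modules else "")]
if extra_index_url is not None:
    args += ["--extra-index-url", extra_index_url]
# args == ['install', 'InvokeAI[xformers]',
#          '--extra-index-url', 'https://download.pytorch.org/whl/cu117']
```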
@@ -41,7 +41,7 @@ if __name__ == "__main__":
        type=Path,
        default=None,
    )

    args = parser.parse_args()

    inst = Installer()
@@ -36,13 +36,15 @@ else:

def welcome():

    @group()
    def text():
        if (platform_specific := _platform_specific_help()) != "":
            yield platform_specific
            yield ""
-        yield Text.from_markup("Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.", justify="center")
+        yield Text.from_markup(
+            "Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.",
+            justify="center",
+        )

    console.rule()
    print(
@@ -58,6 +60,7 @@ def welcome():
    )
    console.line()


def confirm_install(dest: Path) -> bool:
    if dest.exists():
        print(f":exclamation: Directory {dest} already exists :exclamation:")
@@ -92,7 +95,6 @@ def dest_path(dest=None) -> Path:
    dest_confirmed = confirm_install(dest)

    while not dest_confirmed:

        # if the given destination already exists, the starting point for browsing is its parent directory.
        # the user may have made a typo, or otherwise wants to place the root dir next to an existing one.
        # if the destination dir does NOT exist, then the user must have changed their mind about the selection.
@@ -300,15 +302,20 @@ def introduction() -> None:
    )
    console.line(2)

-def _platform_specific_help()->str:
+
+def _platform_specific_help() -> str:
    if OS == "Darwin":
-        text = Text.from_markup("""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/].""")
+        text = Text.from_markup(
+            """[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/]."""
+        )
    elif OS == "Windows":
-        text = Text.from_markup("""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
+        text = Text.from_markup(
+            """[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
  1. Double-click on the file [b wheat1]WinLongPathsEnabled.reg[/] in order to
     enable long path support on your system.
  2. Make sure you have the [b wheat1]Visual C++ core libraries[/] installed. If not, install from
-    [deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]""")
+    [deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]"""
+        )
    else:
        text = ""
    return text
@@ -90,7 +90,7 @@ async def update_model(
            new_name=info.model_name,
            new_base=info.base_model,
        )
-        logger.info(f"Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}")
+        logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
        # update information to support an update of attributes
        model_name = info.model_name
        base_model = info.base_model
@@ -3,6 +3,7 @@ import asyncio
import sys
from inspect import signature

+import logging
import uvicorn
import socket

@@ -210,11 +211,25 @@ def invoke_api():
    port = find_port(app_config.port)
    if port != app_config.port:
        logger.warn(f"Port {app_config.port} in use, using port {port}")

    # Start our own event loop for eventing usage
    loop = asyncio.new_event_loop()
-    config = uvicorn.Config(app=app, host=app_config.host, port=port, loop=loop)
-    # Use access_log to turn off logging
+    config = uvicorn.Config(
+        app=app,
+        host=app_config.host,
+        port=port,
+        loop=loop,
+        log_level=app_config.log_level,
+    )
    server = uvicorn.Server(config)

+    # replace uvicorn's loggers with InvokeAI's for consistent appearance
+    for logname in ["uvicorn.access", "uvicorn"]:
+        l = logging.getLogger(logname)
+        l.handlers.clear()
+        for ch in logger.handlers:
+            l.addHandler(ch)

    loop.run_until_complete(server.serve())
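`find_port` is only called here, not shown in the hunk; its contract is to return the configured port when free and otherwise the next available one, which is what the warning reports. A hypothetical sketch of such a helper (an assumption for illustration, not the function defined in api_app.py):

```python
import socket

def find_port(port: int) -> int:
    # Hypothetical sketch: probe upward from `port` until a bind succeeds.
    while True:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind(("127.0.0.1", port))
                return port   # free; uvicorn can claim it
            except OSError:
                port += 1     # taken; try the next port
```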
@@ -12,7 +12,7 @@ from pydantic import BaseModel, Field, validator

from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.model_management.models.base import ModelType
+from invokeai.backend.model_management.models import ModelType, SilenceWarnings

from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
@@ -312,70 +312,71 @@ class TextToLatentsInvocation(BaseInvocation):

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        noise = context.services.latents.get(self.noise.latents_name)
        with SilenceWarnings():
            noise = context.services.latents.get(self.noise.latents_name)

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
            # Get the source node id (we are invoking the prepared node)
            graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
            source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        def step_callback(state: PipelineIntermediateState):
            self.dispatch_progress(context, source_node_id, state)
            def step_callback(state: PipelineIntermediateState):
                self.dispatch_progress(context, source_node_id, state)

        def _lora_loader():
            for lora in self.unet.loras:
                lora_info = context.services.model_manager.get_model(
                    **lora.dict(exclude={"weight"}),
            def _lora_loader():
                for lora in self.unet.loras:
                    lora_info = context.services.model_manager.get_model(
                        **lora.dict(exclude={"weight"}),
                        context=context,
                    )
                yield (lora_info.context.model, lora.weight)
                del lora_info
            return

        unet_info = context.services.model_manager.get_model(
            **self.unet.unet.dict(),
            context=context,
        )
        with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
            unet_info.context.model, _lora_loader()
        ), unet_info as unet:
            noise = noise.to(device=unet.device, dtype=unet.dtype)

            scheduler = get_scheduler(
                context=context,
                scheduler_info=self.unet.scheduler,
                scheduler_name=self.scheduler,
            )
                    yield (lora_info.context.model, lora.weight)
                    del lora_info
                return

            unet_info = context.services.model_manager.get_model(
                **self.unet.unet.dict(),
                context=context,
            )
            with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
                unet_info.context.model, _lora_loader()
            ), unet_info as unet:
                noise = noise.to(device=unet.device, dtype=unet.dtype)
                pipeline = self.create_pipeline(unet, scheduler)
                conditioning_data = self.get_conditioning_data(context, scheduler, unet)

                scheduler = get_scheduler(
                    context=context,
                    scheduler_info=self.unet.scheduler,
                    scheduler_name=self.scheduler,
                )
                control_data = self.prep_control_data(
                    model=pipeline,
                    context=context,
                    control_input=self.control,
                    latents_shape=noise.shape,
                    # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                    do_classifier_free_guidance=True,
                    exit_stack=exit_stack,
                )

            pipeline = self.create_pipeline(unet, scheduler)
            conditioning_data = self.get_conditioning_data(context, scheduler, unet)
                # TODO: Verify the noise is the right size
                result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
                    latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
                    noise=noise,
                    num_inference_steps=self.steps,
                    conditioning_data=conditioning_data,
                    control_data=control_data,  # list[ControlNetData]
                    callback=step_callback,
                )

            control_data = self.prep_control_data(
                model=pipeline,
                context=context,
                control_input=self.control,
                latents_shape=noise.shape,
                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                do_classifier_free_guidance=True,
                exit_stack=exit_stack,
            )
                # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
                result_latents = result_latents.to("cpu")
                torch.cuda.empty_cache()

            # TODO: Verify the noise is the right size
            result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
                latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
                noise=noise,
                num_inference_steps=self.steps,
                conditioning_data=conditioning_data,
                control_data=control_data,  # list[ControlNetData]
                callback=step_callback,
            )

            # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
            result_latents = result_latents.to("cpu")
            torch.cuda.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.save(name, result_latents)
        return build_latents_output(latents_name=name, latents=result_latents)
            name = f"{context.graph_execution_state_id}__{self.id}"
            context.services.latents.save(name, result_latents)
            return build_latents_output(latents_name=name, latents=result_latents)


class LatentsToLatentsInvocation(TextToLatentsInvocation):
@@ -403,82 +404,83 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        noise = context.services.latents.get(self.noise.latents_name)
        latent = context.services.latents.get(self.latents.latents_name)
        with SilenceWarnings():  # this quenches NSFW nag from diffusers
            noise = context.services.latents.get(self.noise.latents_name)
            latent = context.services.latents.get(self.latents.latents_name)

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
            # Get the source node id (we are invoking the prepared node)
            graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
            source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        def step_callback(state: PipelineIntermediateState):
            self.dispatch_progress(context, source_node_id, state)
            def step_callback(state: PipelineIntermediateState):
                self.dispatch_progress(context, source_node_id, state)

        def _lora_loader():
            for lora in self.unet.loras:
                lora_info = context.services.model_manager.get_model(
                    **lora.dict(exclude={"weight"}),
            def _lora_loader():
                for lora in self.unet.loras:
                    lora_info = context.services.model_manager.get_model(
                        **lora.dict(exclude={"weight"}),
                        context=context,
                    )
                yield (lora_info.context.model, lora.weight)
                del lora_info
            return

        unet_info = context.services.model_manager.get_model(
            **self.unet.unet.dict(),
            context=context,
        )
        with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
            unet_info.context.model, _lora_loader()
        ), unet_info as unet:
            noise = noise.to(device=unet.device, dtype=unet.dtype)
            latent = latent.to(device=unet.device, dtype=unet.dtype)

            scheduler = get_scheduler(
                context=context,
                scheduler_info=self.unet.scheduler,
                scheduler_name=self.scheduler,
            )
                    yield (lora_info.context.model, lora.weight)
                    del lora_info
                return

            unet_info = context.services.model_manager.get_model(
                **self.unet.unet.dict(),
                context=context,
            )
            with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
                unet_info.context.model, _lora_loader()
            ), unet_info as unet:
                noise = noise.to(device=unet.device, dtype=unet.dtype)
                latent = latent.to(device=unet.device, dtype=unet.dtype)
                pipeline = self.create_pipeline(unet, scheduler)
                conditioning_data = self.get_conditioning_data(context, scheduler, unet)

                scheduler = get_scheduler(
                    context=context,
                    scheduler_info=self.unet.scheduler,
                    scheduler_name=self.scheduler,
                )
                control_data = self.prep_control_data(
                    model=pipeline,
                    context=context,
                    control_input=self.control,
                    latents_shape=noise.shape,
                    # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                    do_classifier_free_guidance=True,
                    exit_stack=exit_stack,
                )

            pipeline = self.create_pipeline(unet, scheduler)
            conditioning_data = self.get_conditioning_data(context, scheduler, unet)
                # TODO: Verify the noise is the right size
                initial_latents = (
                    latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
                )

            control_data = self.prep_control_data(
                model=pipeline,
                context=context,
                control_input=self.control,
                latents_shape=noise.shape,
                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                do_classifier_free_guidance=True,
                exit_stack=exit_stack,
            )
                timesteps, _ = pipeline.get_img2img_timesteps(
                    self.steps,
                    self.strength,
                    device=unet.device,
                )

            # TODO: Verify the noise is the right size
            initial_latents = (
                latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
            )
                result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
                    latents=initial_latents,
                    timesteps=timesteps,
                    noise=noise,
                    num_inference_steps=self.steps,
                    conditioning_data=conditioning_data,
                    control_data=control_data,  # list[ControlNetData]
                    callback=step_callback,
                )

            timesteps, _ = pipeline.get_img2img_timesteps(
                self.steps,
                self.strength,
                device=unet.device,
            )
                # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
                result_latents = result_latents.to("cpu")
                torch.cuda.empty_cache()

            result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
                latents=initial_latents,
                timesteps=timesteps,
                noise=noise,
                num_inference_steps=self.steps,
                conditioning_data=conditioning_data,
                control_data=control_data,  # list[ControlNetData]
                callback=step_callback,
            )

            # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
            result_latents = result_latents.to("cpu")
            torch.cuda.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.save(name, result_latents)
            name = f"{context.graph_execution_state_id}__{self.id}"
            context.services.latents.save(name, result_latents)
            return build_latents_output(latents_name=name, latents=result_latents)
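Both `invoke` bodies are now wrapped in `SilenceWarnings` from `invokeai.backend.model_management.models`; per the inline comment it quenches the NSFW nag that diffusers emits during denoising. A minimal sketch of what such a context manager can look like (an assumption for illustration, not the library's implementation):

```python
import logging
import warnings

class SilenceWarnings:
    """Illustrative sketch: mute Python warnings and diffusers' logger
    for the duration of the block, then restore the previous state."""

    def __enter__(self):
        self._catcher = warnings.catch_warnings()
        self._catcher.__enter__()
        warnings.simplefilter("ignore")
        self._diffusers = logging.getLogger("diffusers")
        self._prev_level = self._diffusers.level
        self._diffusers.setLevel(logging.ERROR)
        return self

    def __exit__(self, *exc):
        self._diffusers.setLevel(self._prev_level)
        self._catcher.__exit__(*exc)
        return False
```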
@@ -491,7 +493,7 @@ class LatentsToImageInvocation(BaseInvocation):
    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
    vae: VaeField = Field(default=None, description="Vae submodel")
-    tiled: bool = Field(default=False, description="Decode latents by overlapping tiles(less memory consumption)")
+    tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)")
    fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
    metadata: Optional[CoreMetadata] = Field(
        default=None, description="Optional core metadata to be written to the image"
@@ -712,6 +712,7 @@ class TextualInversionManager(BaseTextualInversionManager):


class ONNXModelPatcher:
+    from .models.base import IAIOnnxRuntimeModel, OnnxRuntimeModel

    @classmethod
    @contextmanager
    def apply_lora_unet(
@@ -358,6 +358,7 @@ class ModelCache(object):
            # 2 refs:
            # 1 from cache_entry
            # 1 from getrefcount function
+            # 1 from onnx runtime object
            if not cache_entry.locked and refs <= 3 if "onnx" in model_key else 2:
                self.logger.debug(
                    f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
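The eviction test compares `sys.getrefcount` against an expected baseline: one reference held by the cache entry, one created by `getrefcount` itself, and (per the new comment) one more held by the ONNX runtime object. Note that the unparenthesized conditional in the hunk binds looser than `and`, so it evaluates as `(not cache_entry.locked and refs <= 3) if "onnx" in model_key else 2`; a parenthesized sketch of the intended comparison:

```python
import sys

def can_unload(model, cache_entry, model_key: str) -> bool:
    # Sketch of the intended test (not the shipped line): ONNX models
    # carry one extra baseline reference from the runtime session.
    refs = sys.getrefcount(model)
    max_expected = 3 if "onnx" in model_key else 2
    return not cache_entry.locked and refs <= max_expected
```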
@@ -401,7 +401,11 @@ class ModelManager(object):
        base_model: BaseModelType,
        model_type: ModelType,
    ) -> str:
-        return f"{base_model}/{model_type}/{model_name}"
+        # In 3.11, the behavior of (str,enum) when interpolated into a
+        # string has changed. The next two lines are defensive.
+        base_model = BaseModelType(base_model)
+        model_type = ModelType(model_type)
+        return f"{base_model.value}/{model_type.value}/{model_name}"

    @classmethod
    def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
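The defensive `.value` interpolation matters because Python 3.11 changed how `(str, Enum)` members format: f-strings now go through `Enum.__str__` and yield the qualified member name rather than the bare value. A small sketch of the difference (`BaseModelType` shown with one illustrative member):

```python
from enum import Enum

class BaseModelType(str, Enum):
    StableDiffusion1 = "sd-1"

member = BaseModelType.StableDiffusion1
# Python 3.10: f"{member}"  -> "sd-1"
# Python 3.11: f"{member}"  -> "BaseModelType.StableDiffusion1"
# Interpolating .value is stable across versions:
assert f"{member.value}/lora/my-model" == "sd-1/lora/my-model"
```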
@@ -19,13 +19,9 @@ from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callab

import onnx
from onnx import numpy_helper
-from onnx.external_data_helper import set_external_data
from onnxruntime import (
    InferenceSession,
    OrtValue,
    SessionOptions,
-    ExecutionMode,
-    GraphOptimizationLevel,
-    get_available_providers,
)
@@ -57,7 +57,7 @@ class LoRAModel(ModelBase):

    @classproperty
    def save_to_config(cls) -> bool:
-        return False
+        return True

    @classmethod
    def detect_format(cls, path: str):
invokeai/frontend/web/dist/assets/App-58b095d3.js (vendored, new file, 169 lines)
File diff suppressed because one or more lines are too long

invokeai/frontend/web/dist/assets/MantineProvider-ea42d3d1.js (vendored, new file, 1 line)
File diff suppressed because one or more lines are too long

invokeai/frontend/web/dist/assets/ThemeLocaleProvider-13e3db3d.js (vendored, new file, 302 lines)
File diff suppressed because one or more lines are too long

invokeai/frontend/web/dist/assets/index-5a784cdd.js (vendored, new file, 125 lines)
File diff suppressed because one or more lines are too long

invokeai/frontend/web/dist/index.html (vendored, 2 lines changed)
@@ -12,7 +12,7 @@
        margin: 0;
      }
    </style>
-    <script type="module" crossorigin src="./assets/index-b3976531.js"></script>
+    <script type="module" crossorigin src="./assets/index-5a784cdd.js"></script>
  </head>

  <body dir="ltr">
@@ -340,6 +340,7 @@
    "allModels": "All Models",
    "checkpointModels": "Checkpoints",
    "diffusersModels": "Diffusers",
+    "loraModels": "LoRAs",
    "safetensorModels": "SafeTensors",
    "onnxModels": "Onnx",
    "oliveModels": "Olives",
@@ -10,8 +10,11 @@ export const addAppConfigReceivedListener = () => {
  startAppListening({
    matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
    effect: async (action, { getState, dispatch }) => {
-      const { infill_methods, nsfw_methods, watermarking_methods } =
-        action.payload;
+      const {
+        infill_methods = [],
+        nsfw_methods = [],
+        watermarking_methods = [],
+      } = action.payload;
      const infillMethod = getState().generation.infillMethod;

      if (!infill_methods.includes(infillMethod)) {
@@ -148,7 +148,7 @@ const ParamPositiveConditioning = () => {
          <Box
            sx={{
              position: 'absolute',
-              top: shouldPinParametersPanel ? 6 : 0,
+              top: shouldPinParametersPanel ? 5 : 0,
              insetInlineEnd: 0,
            }}
          >
@@ -0,0 +1,17 @@
import { Flex } from '@chakra-ui/react';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';

export default function ParamPromptArea() {
  return (
    <Flex
      sx={{
        flexDirection: 'column',
        gap: 2,
      }}
    >
      <ParamPositiveConditioning />
      <ParamNegativeConditioning />
    </Flex>
  );
}
@@ -1,3 +1,5 @@
+import { components } from 'services/api/schema';
+
export const MODEL_TYPE_MAP = {
  'sd-1': 'Stable Diffusion 1.x',
  'sd-2': 'Stable Diffusion 2.x',
@@ -5,6 +7,13 @@ export const MODEL_TYPE_MAP = {
  'sdxl-refiner': 'Stable Diffusion XL Refiner',
};

+export const MODEL_TYPE_SHORT_MAP = {
+  'sd-1': 'SD1',
+  'sd-2': 'SD2',
+  sdxl: 'SDXL',
+  'sdxl-refiner': 'SDXLR',
+};

export const clipSkipMap = {
  'sd-1': {
    maxClip: 12,
@@ -23,3 +32,12 @@ export const clipSkipMap = {
    markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
  },
};

+type LoRAModelFormatMap = {
+  [key in components['schemas']['LoRAModelFormat']]: string;
+};
+
+export const LORA_MODEL_FORMAT_MAP: LoRAModelFormatMap = {
+  lycoris: 'LyCORIS',
+  diffusers: 'Diffusers',
+};
@@ -0,0 +1,43 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaLink } from 'react-icons/fa';
import { setShouldConcatSDXLStylePrompt } from '../store/sdxlSlice';

export default function ParamSDXLConcatButton() {
  const shouldConcatSDXLStylePrompt = useAppSelector(
    (state: RootState) => state.sdxl.shouldConcatSDXLStylePrompt
  );

  const shouldPinParametersPanel = useAppSelector(
    (state: RootState) => state.ui.shouldPinParametersPanel
  );

  const dispatch = useAppDispatch();

  const handleShouldConcatPromptChange = () => {
    dispatch(setShouldConcatSDXLStylePrompt(!shouldConcatSDXLStylePrompt));
  };

  return (
    <IAIIconButton
      aria-label="Concat"
      tooltip="Concatenates Basic Prompt with Style (Recommended)"
      variant="outline"
      isChecked={shouldConcatSDXLStylePrompt}
      onClick={handleShouldConcatPromptChange}
      icon={<FaLink />}
      size="xs"
      sx={{
        position: 'absolute',
        insetInlineEnd: 1,
        top: shouldPinParametersPanel ? 12 : 20,
        border: 'none',
        color: shouldConcatSDXLStylePrompt ? 'accent.500' : 'base.500',
        _hover: {
          bg: 'none',
        },
      }}
    ></IAIIconButton>
  );
}
@@ -1,33 +0,0 @@
import { Box } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { ChangeEvent } from 'react';
import { setShouldConcatSDXLStylePrompt } from '../store/sdxlSlice';

export default function ParamSDXLConcatPrompt() {
  const shouldConcatSDXLStylePrompt = useAppSelector(
    (state: RootState) => state.sdxl.shouldConcatSDXLStylePrompt
  );

  const dispatch = useAppDispatch();

  const handleShouldConcatPromptChange = (e: ChangeEvent<HTMLInputElement>) => {
    dispatch(setShouldConcatSDXLStylePrompt(e.target.checked));
  };

  return (
    <Box
      sx={{
        px: 2,
      }}
    >
      <IAISwitch
        label="Concat Style Prompt"
        tooltip="Concatenates Basic Prompt with Style (Recommended)"
        isChecked={shouldConcatSDXLStylePrompt}
        onChange={handleShouldConcatPromptChange}
      />
    </Box>
  );
}
@@ -0,0 +1,61 @@
import { Box, Flex } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import { AnimatePresence } from 'framer-motion';
import ParamSDXLConcatButton from './ParamSDXLConcatButton';
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
import SDXLConcatLink from './SDXLConcatLink';

export default function ParamSDXLPromptArea() {
  const shouldPinParametersPanel = useAppSelector(
    (state: RootState) => state.ui.shouldPinParametersPanel
  );

  const shouldConcatSDXLStylePrompt = useAppSelector(
    (state: RootState) => state.sdxl.shouldConcatSDXLStylePrompt
  );

  return (
    <Flex
      sx={{
        flexDirection: 'column',
        gap: 2,
      }}
    >
      <AnimatePresence>
        {shouldConcatSDXLStylePrompt && (
          <Box
            sx={{
              position: 'absolute',
              w: 'full',
              top: shouldPinParametersPanel ? '119px' : '175px',
            }}
          >
            <SDXLConcatLink />
          </Box>
        )}
      </AnimatePresence>
      <AnimatePresence>
        {shouldConcatSDXLStylePrompt && (
          <Box
            sx={{
              position: 'absolute',
              w: 'full',
              top: shouldPinParametersPanel ? '263px' : '319px',
            }}
          >
            <SDXLConcatLink />
          </Box>
        )}
      </AnimatePresence>
      <ParamPositiveConditioning />
      <ParamSDXLConcatButton />
      <ParamSDXLPositiveStyleConditioning />
      <ParamNegativeConditioning />
      <ParamSDXLNegativeStyleConditioning />
    </Flex>
  );
}
@@ -0,0 +1,109 @@
import { Box, Flex } from '@chakra-ui/react';
import { CSSObject } from '@emotion/react';
import { motion } from 'framer-motion';
import { FaLink } from 'react-icons/fa';

const sharedConcatLinkStyle: CSSObject = {
  position: 'absolute',
  bg: 'none',
  w: 'full',
  minH: 2,
  borderRadius: 0,
  borderLeft: 'none',
  borderRight: 'none',
  zIndex: 2,
  maskImage:
    'radial-gradient(circle at center, black, black 65%, black 30%, black 15%, transparent)',
};

export default function SDXLConcatLink() {
  return (
    <Flex
      sx={{
        h: 0.5,
        placeContent: 'center',
        gap: 2,
        flexDirection: 'column',
      }}
    >
      <Box
        as={motion.div}
        initial={{
          scaleX: 0,
          borderWidth: 0,
          display: 'none',
        }}
        animate={{
          display: ['block', 'block', 'block', 'none'],
          scaleX: [0, 0.25, 0.5, 1],
          borderWidth: [0, 3, 3, 0],
          transition: { duration: 0.37, times: [0, 0.25, 0.5, 1] },
        }}
        sx={{
          top: '1px',
          borderTop: 'none',
          borderColor: 'base.400',
          zIndex: 2,
          ...sharedConcatLinkStyle,
          _dark: {
            borderColor: 'accent.500',
          },
        }}
      />
      <Box
        as={motion.div}
        initial={{
          opacity: 0,
          scale: 0,
        }}
        animate={{
          opacity: [0, 1, 1, 1],
          scale: [0, 0.75, 1.5, 1],
          transition: { duration: 0.42, times: [0, 0.25, 0.5, 1] },
        }}
        exit={{
          opacity: 0,
          scale: 0,
        }}
        sx={{
          zIndex: 3,
          position: 'absolute',
          left: '48%',
          top: '3px',
          p: 1,
          borderRadius: 4,
          bg: 'accent.400',
          color: 'base.50',
          _dark: {
            bg: 'accent.500',
          },
        }}
      >
        <FaLink size={12} />
      </Box>
      <Box
        as={motion.div}
        initial={{
          scaleX: 0,
          borderWidth: 0,
          display: 'none',
        }}
        animate={{
          display: ['block', 'block', 'block', 'none'],
          scaleX: [0, 0.25, 0.5, 1],
          borderWidth: [0, 3, 3, 0],
          transition: { duration: 0.37, times: [0, 0.25, 0.5, 1] },
        }}
        sx={{
          top: '17px',
          borderBottom: 'none',
          borderColor: 'base.400',
          ...sharedConcatLinkStyle,
          _dark: {
            borderColor: 'accent.500',
          },
        }}
      />
    </Flex>
  );
}
@@ -1,34 +1,14 @@
-import { Flex } from '@chakra-ui/react';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
-import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
-import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
-import ParamSDXLConcatPrompt from './ParamSDXLConcatPrompt';
-import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
-import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
+import ParamSDXLPromptArea from './ParamSDXLPromptArea';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters';

const SDXLImageToImageTabParameters = () => {
  return (
    <>
-      <Flex
-        sx={{
-          flexDirection: 'column',
-          gap: 2,
-          p: 2,
-          borderRadius: 4,
-          bg: 'base.100',
-          _dark: { bg: 'base.850' },
-        }}
-      >
-        <ParamPositiveConditioning />
-        <ParamSDXLPositiveStyleConditioning />
-        <ParamNegativeConditioning />
-        <ParamSDXLNegativeStyleConditioning />
-        <ParamSDXLConcatPrompt />
-      </Flex>
+      <ParamSDXLPromptArea />
      <ProcessButtons />
      <SDXLImageToImageTabCoreParameters />
      <ParamSDXLRefinerCollapse />
@@ -1,34 +1,14 @@
-import { Flex } from '@chakra-ui/react';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
-import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
-import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
-import ParamSDXLConcatPrompt from './ParamSDXLConcatPrompt';
-import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
-import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
+import ParamSDXLPromptArea from './ParamSDXLPromptArea';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';

const SDXLTextToImageTabParameters = () => {
  return (
    <>
-      <Flex
-        sx={{
-          flexDirection: 'column',
-          gap: 2,
-          p: 2,
-          borderRadius: 4,
-          bg: 'base.100',
-          _dark: { bg: 'base.850' },
-        }}
-      >
-        <ParamPositiveConditioning />
-        <ParamSDXLPositiveStyleConditioning />
-        <ParamNegativeConditioning />
-        <ParamSDXLNegativeStyleConditioning />
-        <ParamSDXLConcatPrompt />
-      </Flex>
+      <ParamSDXLPromptArea />
      <ProcessButtons />
      <TextToImageTabCoreParameters />
      <ParamSDXLRefinerCollapse />
@@ -227,7 +227,8 @@ const InvokeTabs = () => {
            id="gallery"
            order={3}
            defaultSize={
-              galleryMinSizePct > DEFAULT_GALLERY_PCT
+              galleryMinSizePct > DEFAULT_GALLERY_PCT &&
+              galleryMinSizePct < 100 // prevent this error https://github.com/bvaughn/react-resizable-panels/blob/main/packages/react-resizable-panels/src/Panel.ts#L96
                ? galleryMinSizePct
                : DEFAULT_GALLERY_PCT
            }
@@ -2,20 +2,18 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
-import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
-import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
+import ParamPromptArea from 'features/parameters/components/Parameters/Prompt/ParamPromptArea';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters';

const ImageToImageTabParameters = () => {
  return (
    <>
-      <ParamPositiveConditioning />
-      <ParamNegativeConditioning />
+      <ParamPromptArea />
      <ProcessButtons />
      <ImageToImageTabCoreParameters />
      <ParamControlNetCollapse />
@@ -3,20 +3,31 @@ import { Flex, Text } from '@chakra-ui/react';
import { useState } from 'react';
import {
  MainModelConfigEntity,
+  DiffusersModelConfigEntity,
+  LoRAModelConfigEntity,
  useGetMainModelsQuery,
+  useGetLoRAModelsQuery,
} from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
+import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
import ModelList from './ModelManagerPanel/ModelList';
import { ALL_BASE_MODELS } from 'services/api/constants';

export default function ModelManagerPanel() {
  const [selectedModelId, setSelectedModelId] = useState<string>();
-  const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, {
+  const { mainModel } = useGetMainModelsQuery(ALL_BASE_MODELS, {
    selectFromResult: ({ data }) => ({
-      model: selectedModelId ? data?.entities[selectedModelId] : undefined,
+      mainModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
    }),
  });
+  const { loraModel } = useGetLoRAModelsQuery(undefined, {
+    selectFromResult: ({ data }) => ({
+      loraModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
+    }),
+  });
+
+  const model = mainModel ? mainModel : loraModel;

  return (
    <Flex sx={{ gap: 8, w: 'full', h: 'full' }}>

@@ -30,7 +41,7 @@ export default function ModelManagerPanel() {
}

type ModelEditProps = {
-  model: MainModelConfigEntity | undefined;
+  model: MainModelConfigEntity | LoRAModelConfigEntity | undefined;
};

const ModelEdit = (props: ModelEditProps) => {

@@ -41,7 +52,16 @@ const ModelEdit = (props: ModelEditProps) => {
  }

  if (model?.model_format === 'diffusers') {
-    return <DiffusersModelEdit key={model.id} model={model} />;
+    return (
+      <DiffusersModelEdit
+        key={model.id}
+        model={model as DiffusersModelConfigEntity}
+      />
+    );
+  }
+
+  if (model?.model_type === 'lora') {
+    return <LoRAModelEdit key={model.id} model={model} />;
  }

  return (
@@ -0,0 +1,137 @@
import { Divider, Flex, Text } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
  LORA_MODEL_FORMAT_MAP,
  MODEL_TYPE_MAP,
} from 'features/parameters/types/constants';
import {
  LoRAModelConfigEntity,
  useUpdateLoRAModelsMutation,
} from 'services/api/endpoints/models';
import { LoRAModelConfig } from 'services/api/types';
import BaseModelSelect from '../shared/BaseModelSelect';

type LoRAModelEditProps = {
  model: LoRAModelConfigEntity;
};

export default function LoRAModelEdit(props: LoRAModelEditProps) {
  const isBusy = useAppSelector(selectIsBusy);

  const { model } = props;

  const [updateLoRAModel, { isLoading }] = useUpdateLoRAModelsMutation();

  const dispatch = useAppDispatch();
  const { t } = useTranslation();

  const loraEditForm = useForm<LoRAModelConfig>({
    initialValues: {
      model_name: model.model_name ? model.model_name : '',
      base_model: model.base_model,
      model_type: 'lora',
      path: model.path ? model.path : '',
      description: model.description ? model.description : '',
      model_format: model.model_format,
    },
    validate: {
      path: (value) =>
        value.trim().length === 0 ? 'Must provide a path' : null,
    },
  });

  const editModelFormSubmitHandler = useCallback(
    (values: LoRAModelConfig) => {
      const responseBody = {
        base_model: model.base_model,
        model_name: model.model_name,
        body: values,
      };

      updateLoRAModel(responseBody)
        .unwrap()
        .then((payload) => {
          loraEditForm.setValues(payload as LoRAModelConfig);
          dispatch(
            addToast(
              makeToast({
                title: t('modelManager.modelUpdated'),
                status: 'success',
              })
            )
          );
        })
        .catch((_) => {
          loraEditForm.reset();
          dispatch(
            addToast(
              makeToast({
                title: t('modelManager.modelUpdateFailed'),
                status: 'error',
              })
            )
          );
        });
    },
    [
      dispatch,
      loraEditForm,
      model.base_model,
      model.model_name,
      t,
      updateLoRAModel,
    ]
  );

  return (
    <Flex flexDirection="column" rowGap={4} width="100%">
      <Flex flexDirection="column">
        <Text fontSize="lg" fontWeight="bold">
          {model.model_name}
        </Text>
        <Text fontSize="sm" color="base.400">
          {MODEL_TYPE_MAP[model.base_model]} Model ⋅{' '}
          {LORA_MODEL_FORMAT_MAP[model.model_format]} format
        </Text>
      </Flex>
      <Divider />

      <form
        onSubmit={loraEditForm.onSubmit((values) =>
          editModelFormSubmitHandler(values)
        )}
      >
        <Flex flexDirection="column" overflowY="scroll" gap={4}>
          <IAIMantineTextInput
            label={t('modelManager.name')}
            {...loraEditForm.getInputProps('model_name')}
          />
          <IAIMantineTextInput
            label={t('modelManager.description')}
            {...loraEditForm.getInputProps('description')}
          />
          <BaseModelSelect {...loraEditForm.getInputProps('base_model')} />
          <IAIMantineTextInput
            label={t('modelManager.modelLocation')}
            {...loraEditForm.getInputProps('path')}
          />
          <IAIButton
            type="submit"
            isDisabled={isBusy || isLoading}
            isLoading={isLoading}
          >
            {t('modelManager.updateModel')}
          </IAIButton>
        </Flex>
      </form>
    </Flex>
  );
}
@ -11,6 +11,8 @@ import {
OnnxModelConfigEntity,
useGetMainModelsQuery,
useGetOnnxModelsQuery,
useGetLoRAModelsQuery,
LoRAModelConfigEntity,
} from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem';
import { ALL_BASE_MODELS } from 'services/api/constants';
@ -22,22 +24,42 @@ type ModelListProps = {

type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';

type ModelType = 'main' | 'lora';

type CombinedModelFormat = ModelFormat | 'lora';

const ModelList = (props: ModelListProps) => {
const { selectedModelId, setSelectedModelId } = props;
const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] =
useState<ModelFormat>('images');
useState<CombinedModelFormat>('images');

const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
filteredDiffusersModels: modelsFilter(
data,
'main',
'diffusers',
nameFilter
),
}),
});

const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
filteredCheckpointModels: modelsFilter(
data,
'main',
'checkpoint',
nameFilter
),
}),
});

const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
}),
});

@ -89,6 +111,13 @@ const ModelList = (props: ModelListProps) => {
>
{t('modelManager.oliveModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('lora')}
isChecked={modelFormatFilter === 'lora'}
>
{t('modelManager.loraModels')}
</IAIButton>
</ButtonGroup>

<IAIInput
@ -175,6 +204,24 @@ const ModelList = (props: ModelListProps) => {
</Flex>
</StyledModelContainer>
)}
{['images', 'lora'].includes(modelFormatFilter) &&
filteredLoraModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
LoRAs
</Text>
{filteredLoraModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
</Flex>
</Flex>
</Flex>
@ -183,15 +230,18 @@ const ModelList = (props: ModelListProps) => {

export default ModelList;

const modelsFilter = (
data:
| EntityState<MainModelConfigEntity>
| EntityState<OnnxModelConfigEntity>
| undefined,
model_format: ModelFormat,
const modelsFilter = <
T extends
| MainModelConfigEntity
| LoRAModelConfigEntity
| OnnxModelConfigEntity
>(
data: EntityState<T> | undefined,
model_type: ModelType,
model_format: ModelFormat | undefined,
nameFilter: string
) => {
const filteredModels: MainModelConfigEntity[] = [];
const filteredModels: T[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
@ -201,9 +251,11 @@ const modelsFilter = (
.toLowerCase()
.includes(nameFilter.toLowerCase());

const matchesFormat = model.model_format === model_format;
const matchesFormat =
model_format === undefined || model.model_format === model_format;
const matchesType = model.model_type === model_type;

if (matchesFilter && matchesFormat) {
if (matchesFilter && matchesFormat && matchesType) {
filteredModels.push(model);
}
});
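The reworked modelsFilter above is now generic over the entity type, so one helper serves the main, ONNX, and LoRA lists alike; passing undefined as model_format skips the format check, which is what the LoRA call site relies on. A minimal standalone sketch of the same filtering logic, with simplified stand-in entity shapes rather than the real config types:

// Simplified stand-ins for the real config entity types.
type ModelEntity = {
  model_name: string;
  model_type: 'main' | 'lora';
  model_format?: string;
};

// Same predicate logic as modelsFilter: name substring match,
// optional format match (undefined acts as a wildcard), exact type match.
const filterModels = <T extends ModelEntity>(
  models: T[],
  modelType: T['model_type'],
  modelFormat: string | undefined,
  nameFilter: string
): T[] =>
  models.filter(
    (m) =>
      m.model_name.toLowerCase().includes(nameFilter.toLowerCase()) &&
      (modelFormat === undefined || m.model_format === modelFormat) &&
      m.model_type === modelType
  );

// LoRAs carry no meaningful format filter, so the wildcard is used:
// filterModels(loras, 'lora', undefined, 'detail');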
@ -9,29 +9,26 @@ import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import {
MainModelConfigEntity,
LoRAModelConfigEntity,
useDeleteMainModelsMutation,
useDeleteLoRAModelsMutation,
} from 'services/api/endpoints/models';

type ModelListItemProps = {
model: MainModelConfigEntity;
model: MainModelConfigEntity | LoRAModelConfigEntity;
isSelected: boolean;
setSelectedModelId: (v: string | undefined) => void;
};

const modelBaseTypeMap = {
'sd-1': 'SD1',
'sd-2': 'SD2',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
};

export default function ModelListItem(props: ModelListItemProps) {
const isBusy = useAppSelector(selectIsBusy);
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [deleteMainModel] = useDeleteMainModelsMutation();
const [deleteLoRAModel] = useDeleteLoRAModelsMutation();

const { model, isSelected, setSelectedModelId } = props;

@ -40,7 +37,10 @@ export default function ModelListItem(props: ModelListItemProps) {
}, [model.id, setSelectedModelId]);

const handleModelDelete = useCallback(() => {
deleteMainModel(model)
const method = { main: deleteMainModel, lora: deleteLoRAModel }[
model.model_type
];
method(model)
.unwrap()
.then((_) => {
dispatch(
@ -60,14 +60,21 @@ export default function ModelListItem(props: ModelListItemProps) {
title: `${t('modelManager.modelDeleteFailed')}: ${
model.model_name
}`,
status: 'success',
status: 'error',
})
)
);
}
});
setSelectedModelId(undefined);
}, [deleteMainModel, model, setSelectedModelId, dispatch, t]);
}, [
deleteMainModel,
deleteLoRAModel,
model,
setSelectedModelId,
dispatch,
t,
]);

return (
<Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>
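The delete handler now picks the right mutation from a lookup object keyed by model_type instead of branching. The same pattern in isolation, with placeholder triggers rather than the real RTK Query hooks:

// Hypothetical stand-ins for the RTK Query mutation triggers.
type ModelType = 'main' | 'lora';
type DeleteTrigger = (modelName: string) => Promise<void>;

const deleteMain: DeleteTrigger = async (name) =>
  console.log(`deleting main model ${name}`);
const deleteLora: DeleteTrigger = async (name) =>
  console.log(`deleting LoRA model ${name}`);

// Lookup object keyed by the discriminant; indexing replaces an
// if/else or switch and stays exhaustive as the union grows.
const deleteByType: Record<ModelType, DeleteTrigger> = {
  main: deleteMain,
  lora: deleteLora,
};

deleteByType['lora']('my-lora'); // dispatches to deleteLora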
@ -100,8 +107,8 @@ export default function ModelListItem(props: ModelListItemProps) {
<Flex gap={4} alignItems="center">
<Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
{
modelBaseTypeMap[
model.base_model as keyof typeof modelBaseTypeMap
MODEL_TYPE_SHORT_MAP[
model.base_model as keyof typeof MODEL_TYPE_SHORT_MAP
]
}
</Badge>
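The badge now reads its label from the shared MODEL_TYPE_SHORT_MAP constant rather than the local duplicate removed above; the as keyof typeof cast narrows the string key so the indexed access type-checks. A small sketch of that idiom, where the map literal mirrors the removed local one and is not the shared constant itself:

// A const object doubles as both the data and the key type.
const SHORT_MAP = {
  'sd-1': 'SD1',
  'sd-2': 'SD2',
  sdxl: 'SDXL',
  'sdxl-refiner': 'SDXLR',
} as const;

// base_model arrives as a plain string from the API layer,
// so it must be narrowed before indexing.
const label = (baseModel: string): string =>
  SHORT_MAP[baseModel as keyof typeof SHORT_MAP] ?? baseModel;

console.log(label('sdxl')); // "SDXL"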
@ -2,20 +2,18 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import ParamPromptArea from '../../../../parameters/components/Parameters/Prompt/ParamPromptArea';
import TextToImageTabCoreParameters from './TextToImageTabCoreParameters';

const TextToImageTabParameters = () => {
return (
<>
<ParamPositiveConditioning />
<ParamNegativeConditioning />
<ParamPromptArea />
<ProcessButtons />
<TextToImageTabCoreParameters />
<ParamControlNetCollapse />
@ -4,18 +4,16 @@ import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Adv
import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
import ParamSeamCorrectionCollapse from 'features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ParamPromptArea from 'features/parameters/components/Parameters/Prompt/ParamPromptArea';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';

const UnifiedCanvasParameters = () => {
return (
<>
<ParamPositiveConditioning />
<ParamNegativeConditioning />
<ParamPromptArea />
<ProcessButtons />
<UnifiedCanvasCoreParameters />
<ParamControlNetCollapse />
@ -57,9 +57,17 @@ type UpdateMainModelArg = {
body: MainModelConfig;
};

type UpdateLoRAModelArg = {
base_model: BaseModelType;
model_name: string;
body: LoRAModelConfig;
};

type UpdateMainModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];

type UpdateLoRAModelResponse = UpdateMainModelResponse;

type DeleteMainModelArg = {
base_model: BaseModelType;
model_name: string;
@ -68,6 +76,10 @@ type DeleteMainModelArg = {

type DeleteMainModelResponse = void;

type DeleteLoRAModelArg = DeleteMainModelArg;

type DeleteLoRAModelResponse = void;

type ConvertMainModelArg = {
base_model: BaseModelType;
model_name: string;
@ -373,6 +385,31 @@ export const modelsApi = api.injectEndpoints({
);
},
}),
updateLoRAModels: build.mutation<
UpdateLoRAModelResponse,
UpdateLoRAModelArg
>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/lora/${model_name}`,
method: 'PATCH',
body: body,
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
deleteLoRAModels: build.mutation<
DeleteLoRAModelResponse,
DeleteLoRAModelArg
>({
query: ({ base_model, model_name }) => {
return {
url: `models/${base_model}/lora/${model_name}`,
method: 'DELETE',
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
getControlNetModels: build.query<
EntityState<ControlNetModelConfigEntity>,
void
@ -521,6 +558,8 @@ export const {
useAddMainModelsMutation,
useConvertMainModelsMutation,
useMergeMainModelsMutation,
useDeleteLoRAModelsMutation,
useUpdateLoRAModelsMutation,
useSyncModelsMutation,
useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery,
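The two new endpoints follow the standard RTK Query mutation shape: build the request from the arg, then invalidate the LoRA list tag so any mounted useGetLoRAModelsQuery refetches. A minimal sketch of that invalidation round-trip against a hypothetical API slice; names like lorasApi and the /loras URL are illustrative, not InvokeAI's real routes:

import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';

// Hypothetical slice demonstrating the tag-based cache invalidation
// used by updateLoRAModels/deleteLoRAModels above.
export const lorasApi = createApi({
  reducerPath: 'lorasApi',
  baseQuery: fetchBaseQuery({ baseUrl: '/api/v1/' }),
  tagTypes: ['LoRAModel'],
  endpoints: (build) => ({
    getLoras: build.query<string[], void>({
      query: () => 'loras',
      // The list result carries the LIST tag...
      providesTags: [{ type: 'LoRAModel', id: 'LIST' }],
    }),
    deleteLora: build.mutation<void, { name: string }>({
      query: ({ name }) => ({ url: `loras/${name}`, method: 'DELETE' }),
      // ...so invalidating it here triggers an automatic refetch
      // of every mounted getLoras subscription.
      invalidatesTags: [{ type: 'LoRAModel', id: 'LIST' }],
    }),
  }),
});

export const { useGetLorasQuery, useDeleteLoraMutation } = lorasApi;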
@ -5857,11 +5857,11 @@ export type components = {
image?: components["schemas"]["ImageField"];
};
/**
* StableDiffusion1ModelFormat
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
@ -5880,6 +5880,12 @@ export type components = {
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;
@ -43,8 +43,13 @@ export type ControlField = components['schemas']['ControlField'];
// Model Configs
export type LoRAModelConfig = components['schemas']['LoRAModelConfig'];
export type VaeModelConfig = components['schemas']['VaeModelConfig'];
export type ControlNetModelCheckpointConfig =
components['schemas']['ControlNetModelCheckpointConfig'];
export type ControlNetModelDiffusersConfig =
components['schemas']['ControlNetModelDiffusersConfig'];
export type ControlNetModelConfig =
components['schemas']['ControlNetModelConfig'];
| ControlNetModelCheckpointConfig
| ControlNetModelDiffusersConfig;
export type TextualInversionModelConfig =
components['schemas']['TextualInversionModelConfig'];
export type DiffusersModelConfig =
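ControlNetModelConfig stops aliasing a single generated schema type and becomes a union of the checkpoint and diffusers variants, so call sites can narrow back to the concrete shape. A toy version of that pattern, with field names invented for illustration:

// Hypothetical variant shapes; the real ones come from the OpenAPI schema.
type CheckpointConfig = { model_format: 'checkpoint'; config_path: string };
type DiffusersConfig = { model_format: 'diffusers'; repo_id: string };

type ControlNetConfig = CheckpointConfig | DiffusersConfig;

// model_format acts as the discriminant, so TypeScript narrows
// each branch to the matching variant automatically.
const describe = (cfg: ControlNetConfig): string =>
  cfg.model_format === 'checkpoint'
    ? `checkpoint at ${cfg.config_path}`
    : `diffusers repo ${cfg.repo_id}`;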
@ -13,6 +13,15 @@ const invokeAI = defineStyle((props) => ({
var(--invokeai-colors-base-200) 70%,
var(--invokeai-colors-base-200) 100%)`,
},
_disabled: {
'::-webkit-resizer': {
backgroundImage: `linear-gradient(135deg,
var(--invokeai-colors-base-50) 0%,
var(--invokeai-colors-base-50) 70%,
var(--invokeai-colors-base-200) 70%,
var(--invokeai-colors-base-200) 100%)`,
},
},
_dark: {
'::-webkit-resizer': {
backgroundImage: `linear-gradient(135deg,
@ -21,6 +30,15 @@ const invokeAI = defineStyle((props) => ({
var(--invokeai-colors-base-800) 70%,
var(--invokeai-colors-base-800) 100%)`,
},
_disabled: {
'::-webkit-resizer': {
backgroundImage: `linear-gradient(135deg,
var(--invokeai-colors-base-900) 0%,
var(--invokeai-colors-base-900) 70%,
var(--invokeai-colors-base-800) 70%,
var(--invokeai-colors-base-800) 100%)`,
},
},
},
}));
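The hunk above adds a _disabled variant of the ::-webkit-resizer gradient for both light and dark schemes. For context, this lives inside a Chakra UI defineStyle call; a pared-down sketch of the same nesting, with colors shortened for brevity:

import { defineStyle } from '@chakra-ui/react';

// Pared-down version of the textarea style: the _disabled and _dark
// pseudo-props nest the same ::-webkit-resizer override.
const invokeAI = defineStyle(() => ({
  '::-webkit-resizer': {
    backgroundImage: 'linear-gradient(135deg, #ddd 0%, #bbb 100%)',
  },
  _disabled: {
    '::-webkit-resizer': {
      backgroundImage: 'linear-gradient(135deg, #eee 0%, #ccc 100%)',
    },
  },
  _dark: {
    '::-webkit-resizer': {
      backgroundImage: 'linear-gradient(135deg, #333 0%, #111 100%)',
    },
  },
}));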
@ -1 +1 @@
__version__ = "3.0.1rc1"
__version__ = "3.0.1rc2"
@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "InvokeAI"
description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process"
requires-python = ">=3.9, <3.11"
requires-python = ">=3.9, <3.12"
readme = { content-type = "text/markdown", file = "README.md" }
keywords = ["stable-diffusion", "AI"]
dynamic = ["version"]
@ -32,16 +32,16 @@ classifiers = [
'Topic :: Scientific/Engineering :: Image Processing',
]
dependencies = [
"accelerate~=0.16",
"accelerate~=0.21.0",
"albumentations",
"click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==2.0.0",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel~=2.0.0",
"controlnet-aux>=0.0.6",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets",
"diffusers[torch]~=0.18.1",
"dnspython==2.2.1",
"diffusers[torch]~=0.19.0",
"dnspython~=2.4.0",
"dynamicprompts",
"easing-functions",
"einops",
@ -54,13 +54,12 @@ dependencies = [
"flask_cors==3.0.10",
"flask_socketio==5.3.0",
"flaskwebgui==1.0.3",
"gfpgan==1.3.8",
"huggingface-hub>=0.11.1",
"invisible-watermark>=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model
"npyscreen",
"numpy<1.24",
"numpy==1.24.4",
"omegaconf",
"onnx",
"onnxruntime",
@ -68,25 +67,26 @@ dependencies = [
"picklescan",
"pillow",
"prompt-toolkit",
"pympler==1.0.1",
"pydantic==1.10.10",
"pympler~=1.0.1",
"pypatchmatch",
'pyperclip',
"pyreadline3",
"python-multipart==0.0.6",
"pytorch-lightning==1.7.7",
"python-multipart",
"pytorch-lightning",
"realesrgan",
"requests==2.28.2",
"requests~=2.28.2",
"rich~=13.3",
"safetensors~=0.3.0",
"scikit-image>=0.19",
"scikit-image~=0.21.0",
"send2trash",
"test-tube>=0.7.5",
"torch~=2.0.0",
"torchvision>=0.14.1",
"torchmetrics==0.11.4",
"torchsde==0.2.5",
"test-tube~=0.7.5",
"torch~=2.0.1",
"torchvision~=0.15.2",
"torchmetrics~=1.0.1",
"torchsde~=0.2.5",
"transformers~=4.31.0",
"uvicorn[standard]==0.21.1",
"uvicorn[standard]~=0.21.1",
"windows-curses; sys_platform=='win32'",
]

@ -100,7 +100,7 @@ dependencies = [
"dev" = [
"pudb",
]
"test" = ["pytest>6.0.0", "pytest-cov"]
"test" = ["pytest>6.0.0", "pytest-cov", "black"]
"xformers" = [
"xformers~=0.0.19; sys_platform!='darwin'",
"triton; sys_platform=='linux'",
@ -187,7 +187,7 @@ directory = "coverage/html"
output = "coverage/index.xml"
#=== End: PyTest and Coverage

[flake8]
[tool.flake8]
max-line-length = 120

[tool.black]
|
||||
#!/bin/env python
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from invokeai.backend.model_management.model_probe import ModelProbe
|
||||
|
||||
info = ModelProbe().probe(Path(sys.argv[1]))
|
||||
parser = argparse.ArgumentParser(description="Probe model type")
|
||||
parser.add_argument(
|
||||
"model_path",
|
||||
type=Path,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
info = ModelProbe().probe(args.model_path)
|
||||
print(info)
|
||||
|