Merge branch 'main' into feat/onnx

This commit is contained in:
Brandon Rising 2023-07-28 09:59:35 -04:00
commit da751da3dd
50 changed files with 1456 additions and 331 deletions

View File

@ -1 +1,2 @@
b3dccfaeb636599c02effc377cdd8a87d658256c b3dccfaeb636599c02effc377cdd8a87d658256c
218b6d0546b990fc449c876fb99f44b50c4daa35

27
.github/workflows/style-checks.yml vendored Normal file
View File

@ -0,0 +1,27 @@
name: Black # TODO: add isort and flake8 later
on:
pull_request: {}
push:
branches: master
tags: "*"
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies with pip
run: |
pip install --upgrade pip wheel
pip install .[test]
# - run: isort --check-only .
- run: black --check .
# - run: flake8

1
.gitignore vendored
View File

@ -38,7 +38,6 @@ develop-eggs/
downloads/ downloads/
eggs/ eggs/
.eggs/ .eggs/
lib/
lib64/ lib64/
parts/ parts/
sdist/ sdist/

10
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,10 @@
# See https://pre-commit.com/ for usage and config
repos:
- repo: local
hooks:
- id: black
name: black
stages: [commit]
language: system
entry: black
types: [python]

View File

@ -123,7 +123,7 @@ and go to http://localhost:9090.
### Command-Line Installation (for developers and users familiar with Terminals) ### Command-Line Installation (for developers and users familiar with Terminals)
You must have Python 3.9 or 3.10 installed on your machine. Earlier or You must have Python 3.9 through 3.11 installed on your machine. Earlier or
later versions are not supported. later versions are not supported.
Node.js also needs to be installed along with yarn (can be installed with Node.js also needs to be installed along with yarn (can be installed with
the command `npm install -g yarn` if needed) the command `npm install -g yarn` if needed)

View File

@ -40,10 +40,8 @@ experimental versions later.
this, open up a command-line window ("Terminal" on Linux and this, open up a command-line window ("Terminal" on Linux and
Macintosh, "Command" or "Powershell" on Windows) and type `python Macintosh, "Command" or "Powershell" on Windows) and type `python
--version`. If Python is installed, it will print out the version --version`. If Python is installed, it will print out the version
number. If it is version `3.9.*` or `3.10.*`, you meet number. If it is version `3.9.*`, `3.10.*` or `3.11.*` you meet
requirements. We do not recommend using Python 3.11 or higher, requirements.
as not all the libraries that InvokeAI depends on work properly
with this version.
!!! warning "What to do if you have an unsupported version" !!! warning "What to do if you have an unsupported version"

View File

@ -32,7 +32,7 @@ gaming):
* **Python** * **Python**
version 3.9 or 3.10 (3.11 is not recommended). version 3.9 through 3.11
* **CUDA Tools** * **CUDA Tools**
@ -65,7 +65,7 @@ gaming):
To install InvokeAI with virtual environments and the PIP package To install InvokeAI with virtual environments and the PIP package
manager, please follow these steps: manager, please follow these steps:
1. Please make sure you are using Python 3.9 or 3.10. The rest of the install 1. Please make sure you are using Python 3.9 through 3.11. The rest of the install
procedure depends on this and will not work with other versions: procedure depends on this and will not work with other versions:
```bash ```bash

View File

@ -9,16 +9,20 @@ cd $scriptdir
function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; } function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }
MINIMUM_PYTHON_VERSION=3.9.0 MINIMUM_PYTHON_VERSION=3.9.0
MAXIMUM_PYTHON_VERSION=3.11.0 MAXIMUM_PYTHON_VERSION=3.11.100
PYTHON="" PYTHON=""
for candidate in python3.10 python3.9 python3 python ; do for candidate in python3.11 python3.10 python3.9 python3 python ; do
if ppath=`which $candidate`; then if ppath=`which $candidate`; then
# when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
# we check that this found executable can actually run
if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi
python_version=$($ppath -V | awk '{ print $2 }') python_version=$($ppath -V | awk '{ print $2 }')
if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then
if [ $(version $python_version) -lt $(version "$MAXIMUM_PYTHON_VERSION") ]; then if [ $(version $python_version) -le $(version "$MAXIMUM_PYTHON_VERSION") ]; then
PYTHON=$ppath PYTHON=$ppath
break break
fi fi
fi fi
fi fi
done done

View File

@ -141,15 +141,16 @@ class Installer:
# upgrade pip in Python 3.9 environments # upgrade pip in Python 3.9 environments
if int(platform.python_version_tuple()[1]) == 9: if int(platform.python_version_tuple()[1]) == 9:
from plumbum import FG, local from plumbum import FG, local
pip = local[get_pip_from_venv(venv_dir)] pip = local[get_pip_from_venv(venv_dir)]
pip[ "install", "--upgrade", "pip"] & FG pip["install", "--upgrade", "pip"] & FG
return venv_dir return venv_dir
def install(self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None) -> None: def install(
self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None
) -> None:
""" """
Install the InvokeAI application into the given runtime path Install the InvokeAI application into the given runtime path
@ -175,7 +176,7 @@ class Installer:
self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv, version=version) self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv, version=version)
# install dependencies and the InvokeAI application # install dependencies and the InvokeAI application
(extra_index_url,optional_modules) = get_torch_source() if not yes_to_all else (None,None) (extra_index_url, optional_modules) = get_torch_source() if not yes_to_all else (None, None)
self.instance.install( self.instance.install(
extra_index_url, extra_index_url,
optional_modules, optional_modules,
@ -188,6 +189,7 @@ class Installer:
# run through the configuration flow # run through the configuration flow
self.instance.configure() self.instance.configure()
class InvokeAiInstance: class InvokeAiInstance:
""" """
Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory. Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory.
@ -196,7 +198,6 @@ class InvokeAiInstance:
""" """
def __init__(self, runtime: Path, venv: Path, version: str) -> None: def __init__(self, runtime: Path, venv: Path, version: str) -> None:
self.runtime = runtime self.runtime = runtime
self.venv = venv self.venv = venv
self.pip = get_pip_from_venv(venv) self.pip = get_pip_from_venv(venv)
@ -312,7 +313,7 @@ class InvokeAiInstance:
"install", "install",
"--require-virtualenv", "--require-virtualenv",
"--use-pep517", "--use-pep517",
str(src)+(optional_modules if optional_modules else ''), str(src) + (optional_modules if optional_modules else ""),
"--find-links" if find_links is not None else None, "--find-links" if find_links is not None else None,
find_links, find_links,
"--extra-index-url" if extra_index_url is not None else None, "--extra-index-url" if extra_index_url is not None else None,
@ -329,12 +330,12 @@ class InvokeAiInstance:
# set sys.argv to a consistent state # set sys.argv to a consistent state
new_argv = [sys.argv[0]] new_argv = [sys.argv[0]]
for i in range(1,len(sys.argv)): for i in range(1, len(sys.argv)):
el = sys.argv[i] el = sys.argv[i]
if el in ['-r','--root']: if el in ["-r", "--root"]:
new_argv.append(el) new_argv.append(el)
new_argv.append(sys.argv[i+1]) new_argv.append(sys.argv[i + 1])
elif el in ['-y','--yes','--yes-to-all']: elif el in ["-y", "--yes", "--yes-to-all"]:
new_argv.append(el) new_argv.append(el)
sys.argv = new_argv sys.argv = new_argv
@ -353,16 +354,16 @@ class InvokeAiInstance:
invokeai_configure() invokeai_configure()
succeeded = True succeeded = True
except requests.exceptions.ConnectionError as e: except requests.exceptions.ConnectionError as e:
print(f'\nA network error was encountered during configuration and download: {str(e)}') print(f"\nA network error was encountered during configuration and download: {str(e)}")
except OSError as e: except OSError as e:
print(f'\nAn OS error was encountered during configuration and download: {str(e)}') print(f"\nAn OS error was encountered during configuration and download: {str(e)}")
except Exception as e: except Exception as e:
print(f'\nA problem was encountered during the configuration and download steps: {str(e)}') print(f"\nA problem was encountered during the configuration and download steps: {str(e)}")
finally: finally:
if not succeeded: if not succeeded:
print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"') print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"')
print('and choose option 7 to fix a broken install, optionally followed by option 5 to install models.') print("and choose option 7 to fix a broken install, optionally followed by option 5 to install models.")
print('Alternatively you can relaunch the installer.') print("Alternatively you can relaunch the installer.")
def install_user_scripts(self): def install_user_scripts(self):
""" """
@ -371,11 +372,11 @@ class InvokeAiInstance:
ext = "bat" if OS == "Windows" else "sh" ext = "bat" if OS == "Windows" else "sh"
#scripts = ['invoke', 'update'] # scripts = ['invoke', 'update']
scripts = ['invoke'] scripts = ["invoke"]
for script in scripts: for script in scripts:
src = Path(__file__).parent / '..' / "templates" / f"{script}.{ext}.in" src = Path(__file__).parent / ".." / "templates" / f"{script}.{ext}.in"
dest = self.runtime / f"{script}.{ext}" dest = self.runtime / f"{script}.{ext}"
shutil.copy(src, dest) shutil.copy(src, dest)
os.chmod(dest, 0o0755) os.chmod(dest, 0o0755)
@ -420,11 +421,7 @@ def set_sys_path(venv_path: Path) -> None:
# filter out any paths in sys.path that may be system- or user-wide # filter out any paths in sys.path that may be system- or user-wide
# but leave the temporary bootstrap virtualenv as it contains packages we # but leave the temporary bootstrap virtualenv as it contains packages we
# temporarily need at install time # temporarily need at install time
sys.path = list(filter( sys.path = list(filter(lambda p: not p.endswith("-packages") or p.find(BOOTSTRAP_VENV_PREFIX) != -1, sys.path))
lambda p: not p.endswith("-packages")
or p.find(BOOTSTRAP_VENV_PREFIX) != -1,
sys.path
))
# determine site-packages/lib directory location for the venv # determine site-packages/lib directory location for the venv
lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}" lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
@ -433,7 +430,7 @@ def set_sys_path(venv_path: Path) -> None:
sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve())) sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve()))
def get_torch_source() -> (Union[str, None],str): def get_torch_source() -> (Union[str, None], str):
""" """
Determine the extra index URL for pip to use for torch installation. Determine the extra index URL for pip to use for torch installation.
This depends on the OS and the graphics accelerator in use. This depends on the OS and the graphics accelerator in use.
@ -461,9 +458,9 @@ def get_torch_source() -> (Union[str, None],str):
elif device == "cpu": elif device == "cpu":
url = "https://download.pytorch.org/whl/cpu" url = "https://download.pytorch.org/whl/cpu"
if device == 'cuda': if device == "cuda":
url = 'https://download.pytorch.org/whl/cu117' url = "https://download.pytorch.org/whl/cu117"
optional_modules = '[xformers]' optional_modules = "[xformers]"
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13 # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@ -36,13 +36,15 @@ else:
def welcome(): def welcome():
@group() @group()
def text(): def text():
if (platform_specific := _platform_specific_help()) != "": if (platform_specific := _platform_specific_help()) != "":
yield platform_specific yield platform_specific
yield "" yield ""
yield Text.from_markup("Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.", justify="center") yield Text.from_markup(
"Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.",
justify="center",
)
console.rule() console.rule()
print( print(
@ -58,6 +60,7 @@ def welcome():
) )
console.line() console.line()
def confirm_install(dest: Path) -> bool: def confirm_install(dest: Path) -> bool:
if dest.exists(): if dest.exists():
print(f":exclamation: Directory {dest} already exists :exclamation:") print(f":exclamation: Directory {dest} already exists :exclamation:")
@ -92,7 +95,6 @@ def dest_path(dest=None) -> Path:
dest_confirmed = confirm_install(dest) dest_confirmed = confirm_install(dest)
while not dest_confirmed: while not dest_confirmed:
# if the given destination already exists, the starting point for browsing is its parent directory. # if the given destination already exists, the starting point for browsing is its parent directory.
# the user may have made a typo, or otherwise wants to place the root dir next to an existing one. # the user may have made a typo, or otherwise wants to place the root dir next to an existing one.
# if the destination dir does NOT exist, then the user must have changed their mind about the selection. # if the destination dir does NOT exist, then the user must have changed their mind about the selection.
@ -300,15 +302,20 @@ def introduction() -> None:
) )
console.line(2) console.line(2)
def _platform_specific_help()->str:
def _platform_specific_help() -> str:
if OS == "Darwin": if OS == "Darwin":
text = Text.from_markup("""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/].""") text = Text.from_markup(
"""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/]."""
)
elif OS == "Windows": elif OS == "Windows":
text = Text.from_markup("""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following: text = Text.from_markup(
"""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
1. Double-click on the file [b wheat1]WinLongPathsEnabled.reg[/] in order to 1. Double-click on the file [b wheat1]WinLongPathsEnabled.reg[/] in order to
enable long path support on your system. enable long path support on your system.
2. Make sure you have the [b wheat1]Visual C++ core libraries[/] installed. If not, install from 2. Make sure you have the [b wheat1]Visual C++ core libraries[/] installed. If not, install from
[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]""") [deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]"""
)
else: else:
text = "" text = ""
return text return text

View File

@ -90,7 +90,7 @@ async def update_model(
new_name=info.model_name, new_name=info.model_name,
new_base=info.base_model, new_base=info.base_model,
) )
logger.info(f"Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}") logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
# update information to support an update of attributes # update information to support an update of attributes
model_name = info.model_name model_name = info.model_name
base_model = info.base_model base_model = info.base_model

View File

@ -3,6 +3,7 @@ import asyncio
import sys import sys
from inspect import signature from inspect import signature
import logging
import uvicorn import uvicorn
import socket import socket
@ -210,11 +211,25 @@ def invoke_api():
port = find_port(app_config.port) port = find_port(app_config.port)
if port != app_config.port: if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}") logger.warn(f"Port {app_config.port} in use, using port {port}")
# Start our own event loop for eventing usage # Start our own event loop for eventing usage
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
config = uvicorn.Config(app=app, host=app_config.host, port=port, loop=loop) config = uvicorn.Config(
# Use access_log to turn off logging app=app,
host=app_config.host,
port=port,
loop=loop,
log_level=app_config.log_level,
)
server = uvicorn.Server(config) server = uvicorn.Server(config)
# replace uvicorn's loggers with InvokeAI's for consistent appearance
for logname in ["uvicorn.access", "uvicorn"]:
l = logging.getLogger(logname)
l.handlers.clear()
for ch in logger.handlers:
l.addHandler(ch)
loop.run_until_complete(server.serve()) loop.run_until_complete(server.serve())

View File

@ -12,7 +12,7 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models.base import ModelType from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
@ -312,70 +312,71 @@ class TextToLatentsInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name) with SilenceWarnings():
noise = context.services.latents.get(self.noise.latents_name)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state) self.dispatch_progress(context, source_node_id, state)
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}), **lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context, context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
) )
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model( pipeline = self.create_pipeline(unet, scheduler)
**self.unet.unet.dict(), conditioning_data = self.get_conditioning_data(context, scheduler, unet)
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler( control_data = self.prep_control_data(
context=context, model=pipeline,
scheduler_info=self.unet.scheduler, context=context,
scheduler_name=self.scheduler, control_input=self.control,
) latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
pipeline = self.create_pipeline(unet, scheduler) # TODO: Verify the noise is the right size
conditioning_data = self.get_conditioning_data(context, scheduler, unet) result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
control_data = self.prep_control_data( # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
model=pipeline, result_latents = result_latents.to("cpu")
context=context, torch.cuda.empty_cache()
control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# TODO: Verify the noise is the right size name = f"{context.graph_execution_state_id}__{self.id}"
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( context.services.latents.save(name, result_latents)
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), return build_latents_output(latents_name=name, latents=result_latents)
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
class LatentsToLatentsInvocation(TextToLatentsInvocation): class LatentsToLatentsInvocation(TextToLatentsInvocation):
@ -403,82 +404,83 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name) with SilenceWarnings(): # this quenches NSFW nag from diffusers
latent = context.services.latents.get(self.latents.latents_name) noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state) self.dispatch_progress(context, source_node_id, state)
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}), **lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
latent = latent.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context, context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
) )
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model( pipeline = self.create_pipeline(unet, scheduler)
**self.unet.unet.dict(), conditioning_data = self.get_conditioning_data(context, scheduler, unet)
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
latent = latent.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler( control_data = self.prep_control_data(
context=context, model=pipeline,
scheduler_info=self.unet.scheduler, context=context,
scheduler_name=self.scheduler, control_input=self.control,
) latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
pipeline = self.create_pipeline(unet, scheduler) # TODO: Verify the noise is the right size
conditioning_data = self.get_conditioning_data(context, scheduler, unet) initial_latents = (
latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
)
control_data = self.prep_control_data( timesteps, _ = pipeline.get_img2img_timesteps(
model=pipeline, self.steps,
context=context, self.strength,
control_input=self.control, device=unet.device,
latents_shape=noise.shape, )
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# TODO: Verify the noise is the right size result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
initial_latents = ( latents=initial_latents,
latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype) timesteps=timesteps,
) noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
timesteps, _ = pipeline.get_img2img_timesteps( # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
self.steps, result_latents = result_latents.to("cpu")
self.strength, torch.cuda.empty_cache()
device=unet.device,
)
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( name = f"{context.graph_execution_state_id}__{self.id}"
latents=initial_latents, context.services.latents.save(name, result_latents)
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents) return build_latents_output(latents_name=name, latents=result_latents)
@ -491,7 +493,7 @@ class LatentsToImageInvocation(BaseInvocation):
# Inputs # Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from") latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Decode latents by overlapping tiles(less memory consumption)") tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)")
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision") fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
metadata: Optional[CoreMetadata] = Field( metadata: Optional[CoreMetadata] = Field(
default=None, description="Optional core metadata to be written to the image" default=None, description="Optional core metadata to be written to the image"

View File

@ -712,6 +712,7 @@ class TextualInversionManager(BaseTextualInversionManager):
class ONNXModelPatcher: class ONNXModelPatcher:
from .models.base import IAIOnnxRuntimeModel, OnnxRuntimeModel
@classmethod @classmethod
@contextmanager @contextmanager
def apply_lora_unet( def apply_lora_unet(

View File

@ -358,6 +358,7 @@ class ModelCache(object):
# 2 refs: # 2 refs:
# 1 from cache_entry # 1 from cache_entry
# 1 from getrefcount function # 1 from getrefcount function
# 1 from onnx runtime object
if not cache_entry.locked and refs <= 3 if "onnx" in model_key else 2: if not cache_entry.locked and refs <= 3 if "onnx" in model_key else 2:
self.logger.debug( self.logger.debug(
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"

View File

@ -401,7 +401,11 @@ class ModelManager(object):
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
) -> str: ) -> str:
return f"{base_model}/{model_type}/{model_name}" # In 3.11, the behavior of (str,enum) when interpolated into a
# string has changed. The next two lines are defensive.
base_model = BaseModelType(base_model)
model_type = ModelType(model_type)
return f"{base_model.value}/{model_type.value}/{model_name}"
@classmethod @classmethod
def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]: def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]:

View File

@ -19,13 +19,9 @@ from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callab
import onnx import onnx
from onnx import numpy_helper from onnx import numpy_helper
from onnx.external_data_helper import set_external_data
from onnxruntime import ( from onnxruntime import (
InferenceSession, InferenceSession,
OrtValue,
SessionOptions, SessionOptions,
ExecutionMode,
GraphOptimizationLevel,
get_available_providers, get_available_providers,
) )

View File

@ -57,7 +57,7 @@ class LoRAModel(ModelBase):
@classproperty @classproperty
def save_to_config(cls) -> bool: def save_to_config(cls) -> bool:
return False return True
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -12,7 +12,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-b3976531.js"></script> <script type="module" crossorigin src="./assets/index-5a784cdd.js"></script>
</head> </head>
<body dir="ltr"> <body dir="ltr">

View File

@ -340,6 +340,7 @@
"allModels": "All Models", "allModels": "All Models",
"checkpointModels": "Checkpoints", "checkpointModels": "Checkpoints",
"diffusersModels": "Diffusers", "diffusersModels": "Diffusers",
"loraModels": "LoRAs",
"safetensorModels": "SafeTensors", "safetensorModels": "SafeTensors",
"onnxModels": "Onnx", "onnxModels": "Onnx",
"oliveModels": "Olives", "oliveModels": "Olives",

View File

@ -10,8 +10,11 @@ export const addAppConfigReceivedListener = () => {
startAppListening({ startAppListening({
matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled, matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
const { infill_methods, nsfw_methods, watermarking_methods } = const {
action.payload; infill_methods = [],
nsfw_methods = [],
watermarking_methods = [],
} = action.payload;
const infillMethod = getState().generation.infillMethod; const infillMethod = getState().generation.infillMethod;
if (!infill_methods.includes(infillMethod)) { if (!infill_methods.includes(infillMethod)) {

View File

@ -148,7 +148,7 @@ const ParamPositiveConditioning = () => {
<Box <Box
sx={{ sx={{
position: 'absolute', position: 'absolute',
top: shouldPinParametersPanel ? 6 : 0, top: shouldPinParametersPanel ? 5 : 0,
insetInlineEnd: 0, insetInlineEnd: 0,
}} }}
> >

View File

@ -0,0 +1,17 @@
import { Flex } from '@chakra-ui/react';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
export default function ParamPromptArea() {
return (
<Flex
sx={{
flexDirection: 'column',
gap: 2,
}}
>
<ParamPositiveConditioning />
<ParamNegativeConditioning />
</Flex>
);
}

View File

@ -1,3 +1,5 @@
import { components } from 'services/api/schema';
export const MODEL_TYPE_MAP = { export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x', 'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x', 'sd-2': 'Stable Diffusion 2.x',
@ -5,6 +7,13 @@ export const MODEL_TYPE_MAP = {
'sdxl-refiner': 'Stable Diffusion XL Refiner', 'sdxl-refiner': 'Stable Diffusion XL Refiner',
}; };
export const MODEL_TYPE_SHORT_MAP = {
'sd-1': 'SD1',
'sd-2': 'SD2',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
};
export const clipSkipMap = { export const clipSkipMap = {
'sd-1': { 'sd-1': {
maxClip: 12, maxClip: 12,
@ -23,3 +32,12 @@ export const clipSkipMap = {
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24], markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
}, },
}; };
type LoRAModelFormatMap = {
[key in components['schemas']['LoRAModelFormat']]: string;
};
export const LORA_MODEL_FORMAT_MAP: LoRAModelFormatMap = {
lycoris: 'LyCORIS',
diffusers: 'Diffusers',
};

View File

@ -0,0 +1,43 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaLink } from 'react-icons/fa';
import { setShouldConcatSDXLStylePrompt } from '../store/sdxlSlice';
export default function ParamSDXLConcatButton() {
const shouldConcatSDXLStylePrompt = useAppSelector(
(state: RootState) => state.sdxl.shouldConcatSDXLStylePrompt
);
const shouldPinParametersPanel = useAppSelector(
(state: RootState) => state.ui.shouldPinParametersPanel
);
const dispatch = useAppDispatch();
const handleShouldConcatPromptChange = () => {
dispatch(setShouldConcatSDXLStylePrompt(!shouldConcatSDXLStylePrompt));
};
return (
<IAIIconButton
aria-label="Concat"
tooltip="Concatenates Basic Prompt with Style (Recommended)"
variant="outline"
isChecked={shouldConcatSDXLStylePrompt}
onClick={handleShouldConcatPromptChange}
icon={<FaLink />}
size="xs"
sx={{
position: 'absolute',
insetInlineEnd: 1,
top: shouldPinParametersPanel ? 12 : 20,
border: 'none',
color: shouldConcatSDXLStylePrompt ? 'accent.500' : 'base.500',
_hover: {
bg: 'none',
},
}}
></IAIIconButton>
);
}

View File

@ -1,33 +0,0 @@
import { Box } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { ChangeEvent } from 'react';
import { setShouldConcatSDXLStylePrompt } from '../store/sdxlSlice';
export default function ParamSDXLConcatPrompt() {
const shouldConcatSDXLStylePrompt = useAppSelector(
(state: RootState) => state.sdxl.shouldConcatSDXLStylePrompt
);
const dispatch = useAppDispatch();
const handleShouldConcatPromptChange = (e: ChangeEvent<HTMLInputElement>) => {
dispatch(setShouldConcatSDXLStylePrompt(e.target.checked));
};
return (
<Box
sx={{
px: 2,
}}
>
<IAISwitch
label="Concat Style Prompt"
tooltip="Concatenates Basic Prompt with Style (Recommended)"
isChecked={shouldConcatSDXLStylePrompt}
onChange={handleShouldConcatPromptChange}
/>
</Box>
);
}

View File

@ -0,0 +1,61 @@
import { Box, Flex } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import { AnimatePresence } from 'framer-motion';
import ParamSDXLConcatButton from './ParamSDXLConcatButton';
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
import SDXLConcatLink from './SDXLConcatLink';
export default function ParamSDXLPromptArea() {
const shouldPinParametersPanel = useAppSelector(
(state: RootState) => state.ui.shouldPinParametersPanel
);
const shouldConcatSDXLStylePrompt = useAppSelector(
(state: RootState) => state.sdxl.shouldConcatSDXLStylePrompt
);
return (
<Flex
sx={{
flexDirection: 'column',
gap: 2,
}}
>
<AnimatePresence>
{shouldConcatSDXLStylePrompt && (
<Box
sx={{
position: 'absolute',
w: 'full',
top: shouldPinParametersPanel ? '119px' : '175px',
}}
>
<SDXLConcatLink />
</Box>
)}
</AnimatePresence>
<AnimatePresence>
{shouldConcatSDXLStylePrompt && (
<Box
sx={{
position: 'absolute',
w: 'full',
top: shouldPinParametersPanel ? '263px' : '319px',
}}
>
<SDXLConcatLink />
</Box>
)}
</AnimatePresence>
<ParamPositiveConditioning />
<ParamSDXLConcatButton />
<ParamSDXLPositiveStyleConditioning />
<ParamNegativeConditioning />
<ParamSDXLNegativeStyleConditioning />
</Flex>
);
}

View File

@ -0,0 +1,109 @@
import { Box, Flex } from '@chakra-ui/react';
import { CSSObject } from '@emotion/react';
import { motion } from 'framer-motion';
import { FaLink } from 'react-icons/fa';
const sharedConcatLinkStyle: CSSObject = {
position: 'absolute',
bg: 'none',
w: 'full',
minH: 2,
borderRadius: 0,
borderLeft: 'none',
borderRight: 'none',
zIndex: 2,
maskImage:
'radial-gradient(circle at center, black, black 65%, black 30%, black 15%, transparent)',
};
export default function SDXLConcatLink() {
return (
<Flex
sx={{
h: 0.5,
placeContent: 'center',
gap: 2,
flexDirection: 'column',
}}
>
<Box
as={motion.div}
initial={{
scaleX: 0,
borderWidth: 0,
display: 'none',
}}
animate={{
display: ['block', 'block', 'block', 'none'],
scaleX: [0, 0.25, 0.5, 1],
borderWidth: [0, 3, 3, 0],
transition: { duration: 0.37, times: [0, 0.25, 0.5, 1] },
}}
sx={{
top: '1px',
borderTop: 'none',
borderColor: 'base.400',
zIndex: 2,
...sharedConcatLinkStyle,
_dark: {
borderColor: 'accent.500',
},
}}
/>
<Box
as={motion.div}
initial={{
opacity: 0,
scale: 0,
}}
animate={{
opacity: [0, 1, 1, 1],
scale: [0, 0.75, 1.5, 1],
transition: { duration: 0.42, times: [0, 0.25, 0.5, 1] },
}}
exit={{
opacity: 0,
scale: 0,
}}
sx={{
zIndex: 3,
position: 'absolute',
left: '48%',
top: '3px',
p: 1,
borderRadius: 4,
bg: 'accent.400',
color: 'base.50',
_dark: {
bg: 'accent.500',
},
}}
>
<FaLink size={12} />
</Box>
<Box
as={motion.div}
initial={{
scaleX: 0,
borderWidth: 0,
display: 'none',
}}
animate={{
display: ['block', 'block', 'block', 'none'],
scaleX: [0, 0.25, 0.5, 1],
borderWidth: [0, 3, 3, 0],
transition: { duration: 0.37, times: [0, 0.25, 0.5, 1] },
}}
sx={{
top: '17px',
borderBottom: 'none',
borderColor: 'base.400',
...sharedConcatLinkStyle,
_dark: {
borderColor: 'accent.500',
},
}}
/>
</Flex>
);
}

View File

@ -1,34 +1,14 @@
import { Flex } from '@chakra-ui/react';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import ParamSDXLConcatPrompt from './ParamSDXLConcatPrompt'; import ParamSDXLPromptArea from './ParamSDXLPromptArea';
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse'; import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters'; import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters';
const SDXLImageToImageTabParameters = () => { const SDXLImageToImageTabParameters = () => {
return ( return (
<> <>
<Flex <ParamSDXLPromptArea />
sx={{
flexDirection: 'column',
gap: 2,
p: 2,
borderRadius: 4,
bg: 'base.100',
_dark: { bg: 'base.850' },
}}
>
<ParamPositiveConditioning />
<ParamSDXLPositiveStyleConditioning />
<ParamNegativeConditioning />
<ParamSDXLNegativeStyleConditioning />
<ParamSDXLConcatPrompt />
</Flex>
<ProcessButtons /> <ProcessButtons />
<SDXLImageToImageTabCoreParameters /> <SDXLImageToImageTabCoreParameters />
<ParamSDXLRefinerCollapse /> <ParamSDXLRefinerCollapse />

View File

@ -1,34 +1,14 @@
import { Flex } from '@chakra-ui/react';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters'; import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
import ParamSDXLConcatPrompt from './ParamSDXLConcatPrompt'; import ParamSDXLPromptArea from './ParamSDXLPromptArea';
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse'; import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
const SDXLTextToImageTabParameters = () => { const SDXLTextToImageTabParameters = () => {
return ( return (
<> <>
<Flex <ParamSDXLPromptArea />
sx={{
flexDirection: 'column',
gap: 2,
p: 2,
borderRadius: 4,
bg: 'base.100',
_dark: { bg: 'base.850' },
}}
>
<ParamPositiveConditioning />
<ParamSDXLPositiveStyleConditioning />
<ParamNegativeConditioning />
<ParamSDXLNegativeStyleConditioning />
<ParamSDXLConcatPrompt />
</Flex>
<ProcessButtons /> <ProcessButtons />
<TextToImageTabCoreParameters /> <TextToImageTabCoreParameters />
<ParamSDXLRefinerCollapse /> <ParamSDXLRefinerCollapse />

View File

@ -227,7 +227,8 @@ const InvokeTabs = () => {
id="gallery" id="gallery"
order={3} order={3}
defaultSize={ defaultSize={
galleryMinSizePct > DEFAULT_GALLERY_PCT galleryMinSizePct > DEFAULT_GALLERY_PCT &&
galleryMinSizePct < 100 // prevent this error https://github.com/bvaughn/react-resizable-panels/blob/main/packages/react-resizable-panels/src/Panel.ts#L96
? galleryMinSizePct ? galleryMinSizePct
: DEFAULT_GALLERY_PCT : DEFAULT_GALLERY_PCT
} }

View File

@ -2,20 +2,18 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse'; import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; // import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ParamPromptArea from 'features/parameters/components/Parameters/Prompt/ParamPromptArea';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters'; import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters';
const ImageToImageTabParameters = () => { const ImageToImageTabParameters = () => {
return ( return (
<> <>
<ParamPositiveConditioning /> <ParamPromptArea />
<ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<ImageToImageTabCoreParameters /> <ImageToImageTabCoreParameters />
<ParamControlNetCollapse /> <ParamControlNetCollapse />

View File

@ -3,20 +3,31 @@ import { Flex, Text } from '@chakra-ui/react';
import { useState } from 'react'; import { useState } from 'react';
import { import {
MainModelConfigEntity, MainModelConfigEntity,
DiffusersModelConfigEntity,
LoRAModelConfigEntity,
useGetMainModelsQuery, useGetMainModelsQuery,
useGetLoRAModelsQuery,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
import ModelList from './ModelManagerPanel/ModelList'; import ModelList from './ModelManagerPanel/ModelList';
import { ALL_BASE_MODELS } from 'services/api/constants'; import { ALL_BASE_MODELS } from 'services/api/constants';
export default function ModelManagerPanel() { export default function ModelManagerPanel() {
const [selectedModelId, setSelectedModelId] = useState<string>(); const [selectedModelId, setSelectedModelId] = useState<string>();
const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, { const { mainModel } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({ selectFromResult: ({ data }) => ({
model: selectedModelId ? data?.entities[selectedModelId] : undefined, mainModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
}), }),
}); });
const { loraModel } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
loraModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
}),
});
const model = mainModel ? mainModel : loraModel;
return ( return (
<Flex sx={{ gap: 8, w: 'full', h: 'full' }}> <Flex sx={{ gap: 8, w: 'full', h: 'full' }}>
@ -30,7 +41,7 @@ export default function ModelManagerPanel() {
} }
type ModelEditProps = { type ModelEditProps = {
model: MainModelConfigEntity | undefined; model: MainModelConfigEntity | LoRAModelConfigEntity | undefined;
}; };
const ModelEdit = (props: ModelEditProps) => { const ModelEdit = (props: ModelEditProps) => {
@ -41,7 +52,16 @@ const ModelEdit = (props: ModelEditProps) => {
} }
if (model?.model_format === 'diffusers') { if (model?.model_format === 'diffusers') {
return <DiffusersModelEdit key={model.id} model={model} />; return (
<DiffusersModelEdit
key={model.id}
model={model as DiffusersModelConfigEntity}
/>
);
}
if (model?.model_type === 'lora') {
return <LoRAModelEdit key={model.id} model={model} />;
} }
return ( return (

View File

@ -0,0 +1,137 @@
import { Divider, Flex, Text } from '@chakra-ui/react';
import { useForm } from '@mantine/form';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
LORA_MODEL_FORMAT_MAP,
MODEL_TYPE_MAP,
} from 'features/parameters/types/constants';
import {
LoRAModelConfigEntity,
useUpdateLoRAModelsMutation,
} from 'services/api/endpoints/models';
import { LoRAModelConfig } from 'services/api/types';
import BaseModelSelect from '../shared/BaseModelSelect';
type LoRAModelEditProps = {
model: LoRAModelConfigEntity;
};
export default function LoRAModelEdit(props: LoRAModelEditProps) {
const isBusy = useAppSelector(selectIsBusy);
const { model } = props;
const [updateLoRAModel, { isLoading }] = useUpdateLoRAModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const loraEditForm = useForm<LoRAModelConfig>({
initialValues: {
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'lora',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
model_format: model.model_format,
},
validate: {
path: (value) =>
value.trim().length === 0 ? 'Must provide a path' : null,
},
});
const editModelFormSubmitHandler = useCallback(
(values: LoRAModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
};
updateLoRAModel(responseBody)
.unwrap()
.then((payload) => {
loraEditForm.setValues(payload as LoRAModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
loraEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
dispatch,
loraEditForm,
model.base_model,
model.model_name,
t,
updateLoRAModel,
]
);
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{model.model_name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[model.base_model]} Model {' '}
{LORA_MODEL_FORMAT_MAP[model.model_format]} format
</Text>
</Flex>
<Divider />
<form
onSubmit={loraEditForm.onSubmit((values) =>
editModelFormSubmitHandler(values)
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIMantineTextInput
label={t('modelManager.name')}
{...loraEditForm.getInputProps('model_name')}
/>
<IAIMantineTextInput
label={t('modelManager.description')}
{...loraEditForm.getInputProps('description')}
/>
<BaseModelSelect {...loraEditForm.getInputProps('base_model')} />
<IAIMantineTextInput
label={t('modelManager.modelLocation')}
{...loraEditForm.getInputProps('path')}
/>
<IAIButton
type="submit"
isDisabled={isBusy || isLoading}
isLoading={isLoading}
>
{t('modelManager.updateModel')}
</IAIButton>
</Flex>
</form>
</Flex>
);
}

View File

@ -11,6 +11,8 @@ import {
OnnxModelConfigEntity, OnnxModelConfigEntity,
useGetMainModelsQuery, useGetMainModelsQuery,
useGetOnnxModelsQuery, useGetOnnxModelsQuery,
useGetLoRAModelsQuery,
LoRAModelConfigEntity,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem'; import ModelListItem from './ModelListItem';
import { ALL_BASE_MODELS } from 'services/api/constants'; import { ALL_BASE_MODELS } from 'services/api/constants';
@ -22,22 +24,42 @@ type ModelListProps = {
type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx'; type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
type ModelType = 'main' | 'lora';
type CombinedModelFormat = ModelFormat | 'lora';
const ModelList = (props: ModelListProps) => { const ModelList = (props: ModelListProps) => {
const { selectedModelId, setSelectedModelId } = props; const { selectedModelId, setSelectedModelId } = props;
const { t } = useTranslation(); const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>(''); const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] = const [modelFormatFilter, setModelFormatFilter] =
useState<ModelFormat>('images'); useState<CombinedModelFormat>('images');
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({ selectFromResult: ({ data }) => ({
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter), filteredDiffusersModels: modelsFilter(
data,
'main',
'diffusers',
nameFilter
),
}), }),
}); });
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, { const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({ selectFromResult: ({ data }) => ({
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter), filteredCheckpointModels: modelsFilter(
data,
'main',
'checkpoint',
nameFilter
),
}),
});
const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
}), }),
}); });
@ -89,6 +111,13 @@ const ModelList = (props: ModelListProps) => {
> >
{t('modelManager.oliveModels')} {t('modelManager.oliveModels')}
</IAIButton> </IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('lora')}
isChecked={modelFormatFilter === 'lora'}
>
{t('modelManager.loraModels')}
</IAIButton>
</ButtonGroup> </ButtonGroup>
<IAIInput <IAIInput
@ -175,6 +204,24 @@ const ModelList = (props: ModelListProps) => {
</Flex> </Flex>
</StyledModelContainer> </StyledModelContainer>
)} )}
{['images', 'lora'].includes(modelFormatFilter) &&
filteredLoraModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
LoRAs
</Text>
{filteredLoraModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
</Flex> </Flex>
</Flex> </Flex>
</Flex> </Flex>
@ -183,15 +230,18 @@ const ModelList = (props: ModelListProps) => {
export default ModelList; export default ModelList;
const modelsFilter = ( const modelsFilter = <
data: T extends
| EntityState<MainModelConfigEntity> | MainModelConfigEntity
| EntityState<OnnxModelConfigEntity> | LoRAModelConfigEntity
| undefined, | OnnxModelConfigEntity
model_format: ModelFormat, >(
data: EntityState<T> | undefined,
model_type: ModelType,
model_format: ModelFormat | undefined,
nameFilter: string nameFilter: string
) => { ) => {
const filteredModels: MainModelConfigEntity[] = []; const filteredModels: T[] = [];
forEach(data?.entities, (model) => { forEach(data?.entities, (model) => {
if (!model) { if (!model) {
return; return;
@ -201,9 +251,11 @@ const modelsFilter = (
.toLowerCase() .toLowerCase()
.includes(nameFilter.toLowerCase()); .includes(nameFilter.toLowerCase());
const matchesFormat = model.model_format === model_format; const matchesFormat =
model_format === undefined || model.model_format === model_format;
const matchesType = model.model_type === model_type;
if (matchesFilter && matchesFormat) { if (matchesFilter && matchesFormat && matchesType) {
filteredModels.push(model); filteredModels.push(model);
} }
}); });

View File

@ -9,29 +9,26 @@ import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { import {
MainModelConfigEntity, MainModelConfigEntity,
LoRAModelConfigEntity,
useDeleteMainModelsMutation, useDeleteMainModelsMutation,
useDeleteLoRAModelsMutation,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
type ModelListItemProps = { type ModelListItemProps = {
model: MainModelConfigEntity; model: MainModelConfigEntity | LoRAModelConfigEntity;
isSelected: boolean; isSelected: boolean;
setSelectedModelId: (v: string | undefined) => void; setSelectedModelId: (v: string | undefined) => void;
}; };
const modelBaseTypeMap = {
'sd-1': 'SD1',
'sd-2': 'SD2',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
};
export default function ModelListItem(props: ModelListItemProps) { export default function ModelListItem(props: ModelListItemProps) {
const isBusy = useAppSelector(selectIsBusy); const isBusy = useAppSelector(selectIsBusy);
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [deleteMainModel] = useDeleteMainModelsMutation(); const [deleteMainModel] = useDeleteMainModelsMutation();
const [deleteLoRAModel] = useDeleteLoRAModelsMutation();
const { model, isSelected, setSelectedModelId } = props; const { model, isSelected, setSelectedModelId } = props;
@ -40,7 +37,10 @@ export default function ModelListItem(props: ModelListItemProps) {
}, [model.id, setSelectedModelId]); }, [model.id, setSelectedModelId]);
const handleModelDelete = useCallback(() => { const handleModelDelete = useCallback(() => {
deleteMainModel(model) const method = { main: deleteMainModel, lora: deleteLoRAModel }[
model.model_type
];
method(model)
.unwrap() .unwrap()
.then((_) => { .then((_) => {
dispatch( dispatch(
@ -60,14 +60,21 @@ export default function ModelListItem(props: ModelListItemProps) {
title: `${t('modelManager.modelDeleteFailed')}: ${ title: `${t('modelManager.modelDeleteFailed')}: ${
model.model_name model.model_name
}`, }`,
status: 'success', status: 'error',
}) })
) )
); );
} }
}); });
setSelectedModelId(undefined); setSelectedModelId(undefined);
}, [deleteMainModel, model, setSelectedModelId, dispatch, t]); }, [
deleteMainModel,
deleteLoRAModel,
model,
setSelectedModelId,
dispatch,
t,
]);
return ( return (
<Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}> <Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>
@ -100,8 +107,8 @@ export default function ModelListItem(props: ModelListItemProps) {
<Flex gap={4} alignItems="center"> <Flex gap={4} alignItems="center">
<Badge minWidth={14} p={0.5} fontSize="sm" variant="solid"> <Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
{ {
modelBaseTypeMap[ MODEL_TYPE_SHORT_MAP[
model.base_model as keyof typeof modelBaseTypeMap model.base_model as keyof typeof MODEL_TYPE_SHORT_MAP
] ]
} }
</Badge> </Badge>

View File

@ -2,20 +2,18 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse'; import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Advanced/ParamAdvancedCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; // import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import ParamPromptArea from '../../../../parameters/components/Parameters/Prompt/ParamPromptArea';
import TextToImageTabCoreParameters from './TextToImageTabCoreParameters'; import TextToImageTabCoreParameters from './TextToImageTabCoreParameters';
const TextToImageTabParameters = () => { const TextToImageTabParameters = () => {
return ( return (
<> <>
<ParamPositiveConditioning /> <ParamPromptArea />
<ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<TextToImageTabCoreParameters /> <TextToImageTabCoreParameters />
<ParamControlNetCollapse /> <ParamControlNetCollapse />

View File

@ -4,18 +4,16 @@ import ParamAdvancedCollapse from 'features/parameters/components/Parameters/Adv
import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse'; import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
import ParamSeamCorrectionCollapse from 'features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse'; import ParamSeamCorrectionCollapse from 'features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; // import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ParamPromptArea from 'features/parameters/components/Parameters/Prompt/ParamPromptArea';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters'; import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
const UnifiedCanvasParameters = () => { const UnifiedCanvasParameters = () => {
return ( return (
<> <>
<ParamPositiveConditioning /> <ParamPromptArea />
<ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<UnifiedCanvasCoreParameters /> <UnifiedCanvasCoreParameters />
<ParamControlNetCollapse /> <ParamControlNetCollapse />

View File

@ -57,9 +57,17 @@ type UpdateMainModelArg = {
body: MainModelConfig; body: MainModelConfig;
}; };
type UpdateLoRAModelArg = {
base_model: BaseModelType;
model_name: string;
body: LoRAModelConfig;
};
type UpdateMainModelResponse = type UpdateMainModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json']; paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
type UpdateLoRAModelResponse = UpdateMainModelResponse;
type DeleteMainModelArg = { type DeleteMainModelArg = {
base_model: BaseModelType; base_model: BaseModelType;
model_name: string; model_name: string;
@ -68,6 +76,10 @@ type DeleteMainModelArg = {
type DeleteMainModelResponse = void; type DeleteMainModelResponse = void;
type DeleteLoRAModelArg = DeleteMainModelArg;
type DeleteLoRAModelResponse = void;
type ConvertMainModelArg = { type ConvertMainModelArg = {
base_model: BaseModelType; base_model: BaseModelType;
model_name: string; model_name: string;
@ -373,6 +385,31 @@ export const modelsApi = api.injectEndpoints({
); );
}, },
}), }),
updateLoRAModels: build.mutation<
UpdateLoRAModelResponse,
UpdateLoRAModelArg
>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/lora/${model_name}`,
method: 'PATCH',
body: body,
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
deleteLoRAModels: build.mutation<
DeleteLoRAModelResponse,
DeleteLoRAModelArg
>({
query: ({ base_model, model_name }) => {
return {
url: `models/${base_model}/lora/${model_name}`,
method: 'DELETE',
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
getControlNetModels: build.query< getControlNetModels: build.query<
EntityState<ControlNetModelConfigEntity>, EntityState<ControlNetModelConfigEntity>,
void void
@ -521,6 +558,8 @@ export const {
useAddMainModelsMutation, useAddMainModelsMutation,
useConvertMainModelsMutation, useConvertMainModelsMutation,
useMergeMainModelsMutation, useMergeMainModelsMutation,
useDeleteLoRAModelsMutation,
useUpdateLoRAModelsMutation,
useSyncModelsMutation, useSyncModelsMutation,
useGetModelsInFolderQuery, useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery, useGetCheckpointConfigsQuery,

View File

@ -5857,11 +5857,11 @@ export type components = {
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/** /**
* StableDiffusion1ModelFormat * StableDiffusion2ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusionXLModelFormat * StableDiffusionXLModelFormat
* @description An enumeration. * @description An enumeration.
@ -5880,6 +5880,12 @@ export type components = {
* @enum {string} * @enum {string}
*/ */
ControlNetModelFormat: "checkpoint" | "diffusers"; ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;

View File

@ -43,8 +43,13 @@ export type ControlField = components['schemas']['ControlField'];
// Model Configs // Model Configs
export type LoRAModelConfig = components['schemas']['LoRAModelConfig']; export type LoRAModelConfig = components['schemas']['LoRAModelConfig'];
export type VaeModelConfig = components['schemas']['VaeModelConfig']; export type VaeModelConfig = components['schemas']['VaeModelConfig'];
export type ControlNetModelCheckpointConfig =
components['schemas']['ControlNetModelCheckpointConfig'];
export type ControlNetModelDiffusersConfig =
components['schemas']['ControlNetModelDiffusersConfig'];
export type ControlNetModelConfig = export type ControlNetModelConfig =
components['schemas']['ControlNetModelConfig']; | ControlNetModelCheckpointConfig
| ControlNetModelDiffusersConfig;
export type TextualInversionModelConfig = export type TextualInversionModelConfig =
components['schemas']['TextualInversionModelConfig']; components['schemas']['TextualInversionModelConfig'];
export type DiffusersModelConfig = export type DiffusersModelConfig =

View File

@ -13,6 +13,15 @@ const invokeAI = defineStyle((props) => ({
var(--invokeai-colors-base-200) 70%, var(--invokeai-colors-base-200) 70%,
var(--invokeai-colors-base-200) 100%)`, var(--invokeai-colors-base-200) 100%)`,
}, },
_disabled: {
'::-webkit-resizer': {
backgroundImage: `linear-gradient(135deg,
var(--invokeai-colors-base-50) 0%,
var(--invokeai-colors-base-50) 70%,
var(--invokeai-colors-base-200) 70%,
var(--invokeai-colors-base-200) 100%)`,
},
},
_dark: { _dark: {
'::-webkit-resizer': { '::-webkit-resizer': {
backgroundImage: `linear-gradient(135deg, backgroundImage: `linear-gradient(135deg,
@ -21,6 +30,15 @@ const invokeAI = defineStyle((props) => ({
var(--invokeai-colors-base-800) 70%, var(--invokeai-colors-base-800) 70%,
var(--invokeai-colors-base-800) 100%)`, var(--invokeai-colors-base-800) 100%)`,
}, },
_disabled: {
'::-webkit-resizer': {
backgroundImage: `linear-gradient(135deg,
var(--invokeai-colors-base-900) 0%,
var(--invokeai-colors-base-900) 70%,
var(--invokeai-colors-base-800) 70%,
var(--invokeai-colors-base-800) 100%)`,
},
},
}, },
})); }));

View File

@ -1 +1 @@
__version__ = "3.0.1rc1" __version__ = "3.0.1rc2"

View File

@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "InvokeAI" name = "InvokeAI"
description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process" description = "An implementation of Stable Diffusion which provides various new features and options to aid the image generation process"
requires-python = ">=3.9, <3.11" requires-python = ">=3.9, <3.12"
readme = { content-type = "text/markdown", file = "README.md" } readme = { content-type = "text/markdown", file = "README.md" }
keywords = ["stable-diffusion", "AI"] keywords = ["stable-diffusion", "AI"]
dynamic = ["version"] dynamic = ["version"]
@ -32,16 +32,16 @@ classifiers = [
'Topic :: Scientific/Engineering :: Image Processing', 'Topic :: Scientific/Engineering :: Image Processing',
] ]
dependencies = [ dependencies = [
"accelerate~=0.16", "accelerate~=0.21.0",
"albumentations", "albumentations",
"click", "click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==2.0.0", "compel~=2.0.0",
"controlnet-aux>=0.0.6", "controlnet-aux>=0.0.6",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets", "datasets",
"diffusers[torch]~=0.18.1", "diffusers[torch]~=0.19.0",
"dnspython==2.2.1", "dnspython~=2.4.0",
"dynamicprompts", "dynamicprompts",
"easing-functions", "easing-functions",
"einops", "einops",
@ -54,13 +54,12 @@ dependencies = [
"flask_cors==3.0.10", "flask_cors==3.0.10",
"flask_socketio==5.3.0", "flask_socketio==5.3.0",
"flaskwebgui==1.0.3", "flaskwebgui==1.0.3",
"gfpgan==1.3.8",
"huggingface-hub>=0.11.1", "huggingface-hub>=0.11.1",
"invisible-watermark>=0.2.0", # needed to install SDXL base and refiner using their repo_ids "invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions "matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model "mediapipe", # needed for "mediapipeface" controlnet model
"npyscreen", "npyscreen",
"numpy<1.24", "numpy==1.24.4",
"omegaconf", "omegaconf",
"onnx", "onnx",
"onnxruntime", "onnxruntime",
@ -68,25 +67,26 @@ dependencies = [
"picklescan", "picklescan",
"pillow", "pillow",
"prompt-toolkit", "prompt-toolkit",
"pympler==1.0.1", "pydantic==1.10.10",
"pympler~=1.0.1",
"pypatchmatch", "pypatchmatch",
'pyperclip', 'pyperclip',
"pyreadline3", "pyreadline3",
"python-multipart==0.0.6", "python-multipart",
"pytorch-lightning==1.7.7", "pytorch-lightning",
"realesrgan", "realesrgan",
"requests==2.28.2", "requests~=2.28.2",
"rich~=13.3", "rich~=13.3",
"safetensors~=0.3.0", "safetensors~=0.3.0",
"scikit-image>=0.19", "scikit-image~=0.21.0",
"send2trash", "send2trash",
"test-tube>=0.7.5", "test-tube~=0.7.5",
"torch~=2.0.0", "torch~=2.0.1",
"torchvision>=0.14.1", "torchvision~=0.15.2",
"torchmetrics==0.11.4", "torchmetrics~=1.0.1",
"torchsde==0.2.5", "torchsde~=0.2.5",
"transformers~=4.31.0", "transformers~=4.31.0",
"uvicorn[standard]==0.21.1", "uvicorn[standard]~=0.21.1",
"windows-curses; sys_platform=='win32'", "windows-curses; sys_platform=='win32'",
] ]
@ -100,7 +100,7 @@ dependencies = [
"dev" = [ "dev" = [
"pudb", "pudb",
] ]
"test" = ["pytest>6.0.0", "pytest-cov"] "test" = ["pytest>6.0.0", "pytest-cov", "black"]
"xformers" = [ "xformers" = [
"xformers~=0.0.19; sys_platform!='darwin'", "xformers~=0.0.19; sys_platform!='darwin'",
"triton; sys_platform=='linux'", "triton; sys_platform=='linux'",
@ -187,7 +187,7 @@ directory = "coverage/html"
output = "coverage/index.xml" output = "coverage/index.xml"
#=== End: PyTest and Coverage #=== End: PyTest and Coverage
[flake8] [tool.flake8]
max-line-length = 120 max-line-length = 120
[tool.black] [tool.black]

View File

@ -1,8 +1,16 @@
#!/bin/env python #!/bin/env python
import argparse
import sys import sys
from pathlib import Path from pathlib import Path
from invokeai.backend.model_management.model_probe import ModelProbe from invokeai.backend.model_management.model_probe import ModelProbe
info = ModelProbe().probe(Path(sys.argv[1])) parser = argparse.ArgumentParser(description="Probe model type")
parser.add_argument(
"model_path",
type=Path,
)
args = parser.parse_args()
info = ModelProbe().probe(args.model_path)
print(info) print(info)