Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00

Commit 3bb81bedbd: Merge branch 'main' into unify-prompt
@@ -20,13 +20,13 @@ def calc_images_mean_L1(image1_path, image2_path):


 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('image1_path')
-    parser.add_argument('image2_path')
+    parser.add_argument("image1_path")
+    parser.add_argument("image2_path")
     args = parser.parse_args()
     return args


-if __name__ == '__main__':
+if __name__ == "__main__":
     args = parse_args()
     mean_L1 = calc_images_mean_L1(args.image1_path, args.image2_path)
     print(mean_L1)
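Note: the hunk above only touches quoting in the argument parser; the calc_images_mean_L1 body itself is not part of this hunk. As a rough sketch of what such a mean-L1 comparison can look like (assuming Pillow and NumPy; illustrative only, not the repository's implementation):

import numpy as np
from PIL import Image


def calc_images_mean_L1(image1_path, image2_path):
    # Load both images as RGB float arrays; assumes they share the same dimensions.
    image1 = np.asarray(Image.open(image1_path).convert("RGB"), dtype=np.float64)
    image2 = np.asarray(Image.open(image2_path).convert("RGB"), dtype=np.float64)
    # Mean absolute per-pixel difference (L1) over all pixels and channels.
    return float(np.abs(image1 - image2).mean())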
@@ -1 +1,2 @@
 b3dccfaeb636599c02effc377cdd8a87d658256c
+218b6d0546b990fc449c876fb99f44b50c4daa35
.github/workflows/style-checks.yml (vendored): new file, 27 lines
@@ -0,0 +1,27 @@
+name: Black # TODO: add isort and flake8 later
+
+on:
+  pull_request: {}
+  push:
+    branches: master
+    tags: "*"
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+
+      - name: Install dependencies with pip
+        run: |
+          pip install --upgrade pip wheel
+          pip install .[test]
+
+      # - run: isort --check-only .
+      - run: black --check .
+      # - run: flake8
.gitignore (vendored): 1 line removed
@@ -38,7 +38,6 @@ develop-eggs/
 downloads/
 eggs/
 .eggs/
-lib/
 lib64/
 parts/
 sdist/
.pre-commit-config.yaml: new file, 10 lines
@@ -0,0 +1,10 @@
+# See https://pre-commit.com/ for usage and config
+repos:
+  - repo: local
+    hooks:
+      - id: black
+        name: black
+        stages: [commit]
+        language: system
+        entry: black
+        types: [python]
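The hunks that follow are mostly the mechanical normalization that Black (the hook configured above) applies to the Python sources: single quotes become double quotes, spaces are dropped inside braces and around keyword-argument equals signs, and long call sites are either collapsed onto one line or split one argument per line with a trailing comma. A tiny illustrative snippet of that before/after shape (made-up values, not lines from the repository):

# Before Black (old style seen on the minus lines below):
#     models = parse_obj_as(ModelsList, { "models": models_raw })
#     upscaling_method = 'esrgan'
# After Black:
models_raw = ["model-a", "model-b"]  # hypothetical values
payload = {"models": models_raw}
upscaling_method = "esrgan"
print(payload, upscaling_method)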
@@ -141,15 +141,16 @@ class Installer:

        # upgrade pip in Python 3.9 environments
        if int(platform.python_version_tuple()[1]) == 9:

            from plumbum import FG, local

            pip = local[get_pip_from_venv(venv_dir)]
-            pip[ "install", "--upgrade", "pip"] & FG
+            pip["install", "--upgrade", "pip"] & FG

        return venv_dir

-    def install(self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None) -> None:
+    def install(
+        self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None
+    ) -> None:
        """
        Install the InvokeAI application into the given runtime path
@@ -175,7 +176,7 @@ class Installer:
        self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv, version=version)

        # install dependencies and the InvokeAI application
-        (extra_index_url,optional_modules) = get_torch_source() if not yes_to_all else (None,None)
+        (extra_index_url, optional_modules) = get_torch_source() if not yes_to_all else (None, None)
        self.instance.install(
            extra_index_url,
            optional_modules,
@@ -188,6 +189,7 @@ class Installer:
        # run through the configuration flow
        self.instance.configure()

+
class InvokeAiInstance:
    """
    Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory.
@@ -196,7 +198,6 @@ class InvokeAiInstance:
    """

    def __init__(self, runtime: Path, venv: Path, version: str) -> None:
-
        self.runtime = runtime
        self.venv = venv
        self.pip = get_pip_from_venv(venv)
@@ -312,7 +313,7 @@ class InvokeAiInstance:
            "install",
            "--require-virtualenv",
            "--use-pep517",
-            str(src)+(optional_modules if optional_modules else ''),
+            str(src) + (optional_modules if optional_modules else ""),
            "--find-links" if find_links is not None else None,
            find_links,
            "--extra-index-url" if extra_index_url is not None else None,
@@ -329,15 +330,15 @@ class InvokeAiInstance:

        # set sys.argv to a consistent state
        new_argv = [sys.argv[0]]
-        for i in range(1,len(sys.argv)):
+        for i in range(1, len(sys.argv)):
            el = sys.argv[i]
-            if el in ['-r','--root']:
+            if el in ["-r", "--root"]:
                new_argv.append(el)
-                new_argv.append(sys.argv[i+1])
-            elif el in ['-y','--yes','--yes-to-all']:
+                new_argv.append(sys.argv[i + 1])
+            elif el in ["-y", "--yes", "--yes-to-all"]:
                new_argv.append(el)
        sys.argv = new_argv

        import requests  # to catch download exceptions
        from messages import introduction

@@ -353,16 +354,16 @@ class InvokeAiInstance:
            invokeai_configure()
            succeeded = True
        except requests.exceptions.ConnectionError as e:
-            print(f'\nA network error was encountered during configuration and download: {str(e)}')
+            print(f"\nA network error was encountered during configuration and download: {str(e)}")
        except OSError as e:
-            print(f'\nAn OS error was encountered during configuration and download: {str(e)}')
+            print(f"\nAn OS error was encountered during configuration and download: {str(e)}")
        except Exception as e:
-            print(f'\nA problem was encountered during the configuration and download steps: {str(e)}')
+            print(f"\nA problem was encountered during the configuration and download steps: {str(e)}")
        finally:
            if not succeeded:
                print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"')
-                print('and choose option 7 to fix a broken install, optionally followed by option 5 to install models.')
-                print('Alternatively you can relaunch the installer.')
+                print("and choose option 7 to fix a broken install, optionally followed by option 5 to install models.")
+                print("Alternatively you can relaunch the installer.")

    def install_user_scripts(self):
        """
@@ -371,11 +372,11 @@ class InvokeAiInstance:

        ext = "bat" if OS == "Windows" else "sh"

-        #scripts = ['invoke', 'update']
-        scripts = ['invoke']
+        # scripts = ['invoke', 'update']
+        scripts = ["invoke"]

        for script in scripts:
-            src = Path(__file__).parent / '..' / "templates" / f"{script}.{ext}.in"
+            src = Path(__file__).parent / ".." / "templates" / f"{script}.{ext}.in"
            dest = self.runtime / f"{script}.{ext}"
            shutil.copy(src, dest)
            os.chmod(dest, 0o0755)
@@ -420,11 +421,7 @@ def set_sys_path(venv_path: Path) -> None:
    # filter out any paths in sys.path that may be system- or user-wide
    # but leave the temporary bootstrap virtualenv as it contains packages we
    # temporarily need at install time
-    sys.path = list(filter(
-        lambda p: not p.endswith("-packages")
-        or p.find(BOOTSTRAP_VENV_PREFIX) != -1,
-        sys.path
-    ))
+    sys.path = list(filter(lambda p: not p.endswith("-packages") or p.find(BOOTSTRAP_VENV_PREFIX) != -1, sys.path))

    # determine site-packages/lib directory location for the venv
    lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
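For illustration, the collapsed filter above keeps bootstrap-venv entries while dropping other site-packages directories. A standalone sketch of the same expression with made-up paths (the BOOTSTRAP_VENV_PREFIX value here is an assumption, not the installer's actual constant):

BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp"  # assumed value for illustration

sample_paths = [
    "/usr/lib/python3.10/site-packages",  # dropped: ends in -packages and is not the bootstrap venv
    "/tmp/invokeai-installer-tmp-xyz/lib/python3.10/site-packages",  # kept: bootstrap venv
    "/home/user/tools",  # kept: not a *-packages path
]
kept = list(filter(lambda p: not p.endswith("-packages") or p.find(BOOTSTRAP_VENV_PREFIX) != -1, sample_paths))
print(kept)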
@@ -433,7 +430,7 @@ def set_sys_path(venv_path: Path) -> None:
    sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve()))


-def get_torch_source() -> (Union[str, None],str):
+def get_torch_source() -> (Union[str, None], str):
    """
    Determine the extra index URL for pip to use for torch installation.
    This depends on the OS and the graphics accelerator in use.
@@ -461,9 +458,9 @@ def get_torch_source() -> (Union[str, None],str):
    elif device == "cpu":
        url = "https://download.pytorch.org/whl/cpu"

-    if device == 'cuda':
-        url = 'https://download.pytorch.org/whl/cu117'
-        optional_modules = '[xformers]'
+    if device == "cuda":
+        url = "https://download.pytorch.org/whl/cu117"
+        optional_modules = "[xformers]"

    # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
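Read together with the install() hunk earlier, the (extra_index_url, optional_modules) pair returned here feeds pip's --extra-index-url flag and the optional extras suffix on the package spec. A condensed, hypothetical sketch of that flow (function and package names are illustrative; only the URLs and the "[xformers]" extra appear in the hunk above):

from typing import Optional, Tuple


def pick_torch_source(device: str) -> Tuple[Optional[str], Optional[str]]:
    # Mirrors the spirit of get_torch_source(): choose an extra index URL per device.
    url: Optional[str] = None
    optional_modules: Optional[str] = None
    if device == "cpu":
        url = "https://download.pytorch.org/whl/cpu"
    if device == "cuda":
        url = "https://download.pytorch.org/whl/cu117"
        optional_modules = "[xformers]"
    # In all other cases, Torch wheels come from PyPI (as of Torch 1.13), so url stays None.
    return url, optional_modules


extra_index_url, optional_modules = pick_torch_source("cuda")
pip_args = ["install", "--use-pep517", "invokeai" + (optional_modules or "")]  # package name assumed
if extra_index_url is not None:
    pip_args += ["--extra-index-url", extra_index_url]
print(pip_args)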
@@ -41,7 +41,7 @@ if __name__ == "__main__":
        type=Path,
        default=None,
    )

    args = parser.parse_args()

    inst = Installer()
@@ -36,13 +36,15 @@ else:


def welcome():

    @group()
    def text():
        if (platform_specific := _platform_specific_help()) != "":
            yield platform_specific
            yield ""
-        yield Text.from_markup("Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.", justify="center")
+        yield Text.from_markup(
+            "Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.",
+            justify="center",
+        )

    console.rule()
    print(
@@ -58,6 +60,7 @@ def welcome():
    )
    console.line()

+
def confirm_install(dest: Path) -> bool:
    if dest.exists():
        print(f":exclamation: Directory {dest} already exists :exclamation:")
@@ -92,7 +95,6 @@ def dest_path(dest=None) -> Path:
    dest_confirmed = confirm_install(dest)

    while not dest_confirmed:
-
        # if the given destination already exists, the starting point for browsing is its parent directory.
        # the user may have made a typo, or otherwise wants to place the root dir next to an existing one.
        # if the destination dir does NOT exist, then the user must have changed their mind about the selection.
@@ -300,15 +302,20 @@ def introduction() -> None:
    )
    console.line(2)

-def _platform_specific_help()->str:
+
+def _platform_specific_help() -> str:
    if OS == "Darwin":
-        text = Text.from_markup("""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/].""")
+        text = Text.from_markup(
+            """[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/]."""
+        )
    elif OS == "Windows":
-        text = Text.from_markup("""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
+        text = Text.from_markup(
+            """[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
1. Double-click on the file [b wheat1]WinLongPathsEnabled.reg[/] in order to
   enable long path support on your system.
2. Make sure you have the [b wheat1]Visual C++ core libraries[/] installed. If not, install from
-[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]""")
+[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]"""
+        )
    else:
        text = ""
    return text
@@ -78,9 +78,7 @@ class ApiDependencies:
        image_record_storage = SqliteImageRecordStorage(db_location)
        image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
        names = SimpleNameService()
-        latents = ForwardCacheLatentsStorage(
-            DiskLatentsStorage(f"{output_folder}/latents")
-        )
+        latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))

        board_record_storage = SqliteBoardRecordStorage(db_location)
        board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
@@ -125,9 +123,7 @@ class ApiDependencies:
            boards=boards,
            board_images=board_images,
            queue=MemoryInvocationQueue(),
-            graph_library=SqliteItemStorage[LibraryGraph](
-                filename=db_location, table_name="graphs"
-            ),
+            graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
            graph_execution_manager=graph_execution_manager,
            processor=DefaultInvocationProcessor(),
            configuration=config,
@@ -15,6 +15,7 @@ from invokeai.version import __version__
from ..dependencies import ApiDependencies
from invokeai.backend.util.logging import logging

+
class LogLevel(int, Enum):
    NotSet = logging.NOTSET
    Debug = logging.DEBUG
@@ -23,10 +24,12 @@ class LogLevel(int, Enum):
    Error = logging.ERROR
    Critical = logging.CRITICAL

+
class Upscaler(BaseModel):
    upscaling_method: str = Field(description="Name of upscaling method")
    upscaling_models: list[str] = Field(description="List of upscaling models for this method")

+
app_router = APIRouter(prefix="/v1/app", tags=["app"])


@@ -45,38 +48,30 @@ class AppConfig(BaseModel):
    watermarking_methods: list[str] = Field(description="List of invisible watermark methods")


-@app_router.get(
-    "/version", operation_id="app_version", status_code=200, response_model=AppVersion
-)
+@app_router.get("/version", operation_id="app_version", status_code=200, response_model=AppVersion)
async def get_version() -> AppVersion:
    return AppVersion(version=__version__)


-@app_router.get(
-    "/config", operation_id="get_config", status_code=200, response_model=AppConfig
-)
+@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
async def get_config() -> AppConfig:
-    infill_methods = ['tile']
+    infill_methods = ["tile"]
    if PatchMatch.patchmatch_available():
-        infill_methods.append('patchmatch')
+        infill_methods.append("patchmatch")

-
    upscaling_models = []
    for model in typing.get_args(ESRGAN_MODELS):
        upscaling_models.append(str(Path(model).stem))
-    upscaler = Upscaler(
-        upscaling_method = 'esrgan',
-        upscaling_models = upscaling_models
-    )
+    upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)

    nsfw_methods = []
    if SafetyChecker.safety_checker_available():
-        nsfw_methods.append('nsfw_checker')
+        nsfw_methods.append("nsfw_checker")

    watermarking_methods = []
    if InvisibleWatermark.invisible_watermark_available():
-        watermarking_methods.append('invisible_watermark')
+        watermarking_methods.append("invisible_watermark")

    return AppConfig(
        infill_methods=infill_methods,
        upscaling_methods=[upscaler],
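The router hunks before and after this point repeat one pattern: decorators and Query/Body parameters that fit within the line limit are collapsed onto single lines. A minimal, self-contained illustration of that post-Black style (assumes FastAPI and pydantic are installed; this endpoint is invented, not one of the repository's routes):

from typing import Optional

from fastapi import FastAPI, Query
from pydantic import BaseModel

app = FastAPI()


class EchoResponse(BaseModel):
    message: str


@app.get("/echo", operation_id="echo", status_code=200, response_model=EchoResponse)
async def echo(message: Optional[str] = Query(default="hello", description="Text to echo back")) -> EchoResponse:
    return EchoResponse(message=message or "")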
@@ -84,25 +79,26 @@ async def get_config() -> AppConfig:
        watermarking_methods=watermarking_methods,
    )


@app_router.get(
    "/logging",
    operation_id="get_log_level",
-    responses={200: {"description" : "The operation was successful"}},
-    response_model = LogLevel,
+    responses={200: {"description": "The operation was successful"}},
+    response_model=LogLevel,
)
-async def get_log_level(
-) -> LogLevel:
+async def get_log_level() -> LogLevel:
    """Returns the log level"""
    return LogLevel(ApiDependencies.invoker.services.logger.level)


@app_router.post(
    "/logging",
    operation_id="set_log_level",
-    responses={200: {"description" : "The operation was successful"}},
-    response_model = LogLevel,
+    responses={200: {"description": "The operation was successful"}},
+    response_model=LogLevel,
)
async def set_log_level(
    level: LogLevel = Body(description="New log verbosity level"),
) -> LogLevel:
    """Sets the log verbosity level"""
    ApiDependencies.invoker.services.logger.setLevel(level)
@@ -52,4 +52,3 @@ async def remove_board_image(
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail="Failed to update board")
-
@@ -18,9 +18,7 @@ class DeleteBoardResult(BaseModel):
    deleted_board_images: list[str] = Field(
        description="The image names of the board-images relationships that were deleted."
    )
-    deleted_images: list[str] = Field(
-        description="The names of the images that were deleted."
-    )
+    deleted_images: list[str] = Field(description="The names of the images that were deleted.")


@boards_router.post(
@@ -73,22 +71,16 @@ async def update_board(
) -> BoardDTO:
    """Updates a board"""
    try:
-        result = ApiDependencies.invoker.services.boards.update(
-            board_id=board_id, changes=changes
-        )
+        result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail="Failed to update board")


-@boards_router.delete(
-    "/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult
-)
+@boards_router.delete("/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult)
async def delete_board(
    board_id: str = Path(description="The id of board to delete"),
-    include_images: Optional[bool] = Query(
-        description="Permanently delete all images on the board", default=False
-    ),
+    include_images: Optional[bool] = Query(description="Permanently delete all images on the board", default=False),
) -> DeleteBoardResult:
    """Deletes a board"""
    try:
@@ -96,9 +88,7 @@ async def delete_board(
            deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
                board_id=board_id
            )
-            ApiDependencies.invoker.services.images.delete_images_on_board(
-                board_id=board_id
-            )
+            ApiDependencies.invoker.services.images.delete_images_on_board(board_id=board_id)
            ApiDependencies.invoker.services.boards.delete(board_id=board_id)
            return DeleteBoardResult(
                board_id=board_id,
@@ -127,9 +117,7 @@ async def delete_board(
async def list_boards(
    all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
    offset: Optional[int] = Query(default=None, description="The page offset"),
-    limit: Optional[int] = Query(
-        default=None, description="The number of boards per page"
-    ),
+    limit: Optional[int] = Query(default=None, description="The number of boards per page"),
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
    """Gets a list of boards"""
    if all:
@@ -40,15 +40,9 @@ async def upload_image(
    response: Response,
    image_category: ImageCategory = Query(description="The category of the image"),
    is_intermediate: bool = Query(description="Whether this is an intermediate image"),
-    board_id: Optional[str] = Query(
-        default=None, description="The board to add this image to, if any"
-    ),
-    session_id: Optional[str] = Query(
-        default=None, description="The session ID associated with this upload, if any"
-    ),
-    crop_visible: Optional[bool] = Query(
-        default=False, description="Whether to crop the image"
-    ),
+    board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
+    session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
+    crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
) -> ImageDTO:
    """Uploads an image"""
    if not file.content_type.startswith("image"):
@@ -115,9 +109,7 @@ async def clear_intermediates() -> int:
)
async def update_image(
    image_name: str = Path(description="The name of the image to update"),
-    image_changes: ImageRecordChanges = Body(
-        description="The changes to apply to the image"
-    ),
+    image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"),
) -> ImageDTO:
    """Updates an image"""

@@ -212,15 +204,11 @@ async def get_image_thumbnail(
    """Gets a thumbnail image file"""

    try:
-        path = ApiDependencies.invoker.services.images.get_path(
-            image_name, thumbnail=True
-        )
+        path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
        if not ApiDependencies.invoker.services.images.validate_path(path):
            raise HTTPException(status_code=404)

-        response = FileResponse(
-            path, media_type="image/webp", content_disposition_type="inline"
-        )
+        response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
        response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
        return response
    except Exception as e:
@@ -239,9 +227,7 @@ async def get_image_urls(

    try:
        image_url = ApiDependencies.invoker.services.images.get_url(image_name)
-        thumbnail_url = ApiDependencies.invoker.services.images.get_url(
-            image_name, thumbnail=True
-        )
+        thumbnail_url = ApiDependencies.invoker.services.images.get_url(image_name, thumbnail=True)
        return ImageUrlsDTO(
            image_name=image_name,
            image_url=image_url,
@@ -257,15 +243,9 @@ async def get_image_urls(
    response_model=OffsetPaginatedResults[ImageDTO],
)
async def list_image_dtos(
-    image_origin: Optional[ResourceOrigin] = Query(
-        default=None, description="The origin of images to list."
-    ),
-    categories: Optional[list[ImageCategory]] = Query(
-        default=None, description="The categories of image to include."
-    ),
-    is_intermediate: Optional[bool] = Query(
-        default=None, description="Whether to list intermediate images."
-    ),
+    image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
+    categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
+    is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
    board_id: Optional[str] = Query(
        default=None,
        description="The board id to filter by. Use 'none' to find images without a board.",
@@ -28,49 +28,52 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]


class ModelsList(BaseModel):
    models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]


@models_router.get(
    "/",
    operation_id="list_models",
-    responses={200: {"model": ModelsList }},
+    responses={200: {"model": ModelsList}},
)
async def list_models(
    base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
    model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
    """Gets a list of models"""
-    if base_models and len(base_models)>0:
+    if base_models and len(base_models) > 0:
        models_raw = list()
        for base_model in base_models:
            models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
    else:
        models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
-    models = parse_obj_as(ModelsList, { "models": models_raw })
+    models = parse_obj_as(ModelsList, {"models": models_raw})
    return models


@models_router.patch(
    "/{base_model}/{model_type}/{model_name}",
    operation_id="update_model",
-    responses={200: {"description" : "The model was updated successfully"},
-               400: {"description" : "Bad request"},
-               404: {"description" : "The model could not be found"},
-               409: {"description" : "There is already a model corresponding to the new name"},
-    },
-    status_code = 200,
-    response_model = UpdateModelResponse,
+    responses={
+        200: {"description": "The model was updated successfully"},
+        400: {"description": "Bad request"},
+        404: {"description": "The model could not be found"},
+        409: {"description": "There is already a model corresponding to the new name"},
+    },
+    status_code=200,
+    response_model=UpdateModelResponse,
)
async def update_model(
    base_model: BaseModelType = Path(description="Base model"),
    model_type: ModelType = Path(description="The type of model"),
    model_name: str = Path(description="model name"),
    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
-    """ Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
+    """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
    logger = ApiDependencies.invoker.services.logger

    try:
        previous_info = ApiDependencies.invoker.services.model_manager.list_model(
            model_name=model_name,
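The reformatted parse_obj_as(ModelsList, {"models": models_raw}) call above is pydantic v1's helper for validating a plain object against a model. A small standalone example of the same call shape (simplified ModelsList; the real one is a union of model-config types):

from typing import List

from pydantic import BaseModel, parse_obj_as


class ModelsList(BaseModel):
    models: List[str]


# Validate a raw dict against the model, as list_models() does with the model manager's output.
models = parse_obj_as(ModelsList, {"models": ["stable-diffusion-1.5", "sdxl-base"]})
print(models.models)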
@@ -81,13 +84,13 @@ async def update_model(
        # rename operation requested
        if info.model_name != model_name or info.base_model != base_model:
            ApiDependencies.invoker.services.model_manager.rename_model(
-                base_model = base_model,
-                model_type = model_type,
-                model_name = model_name,
-                new_name = info.model_name,
-                new_base = info.base_model,
+                base_model=base_model,
+                model_type=model_type,
+                model_name=model_name,
+                new_name=info.model_name,
+                new_base=info.base_model,
            )
-            logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
+            logger.info(f"Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}")
            # update information to support an update of attributes
            model_name = info.model_name
            base_model = info.base_model
@@ -96,16 +99,15 @@ async def update_model(
            base_model=base_model,
            model_type=model_type,
        )
-        if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it
-            info.path = new_info.get('path')
+        if new_info.get("path") != previous_info.get(
+            "path"
+        ):  # model manager moved model path during rename - don't overwrite it
+            info.path = new_info.get("path")

        ApiDependencies.invoker.services.model_manager.update_model(
-            model_name=model_name,
-            base_model=base_model,
-            model_type=model_type,
-            model_attributes=info.dict()
+            model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict()
        )

        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
            model_name=model_name,
            base_model=base_model,
@@ -123,49 +125,48 @@ async def update_model(

    return model_response


@models_router.post(
    "/import",
    operation_id="import_model",
-    responses= {
-        201: {"description" : "The model imported successfully"},
-        404: {"description" : "The model could not be found"},
-        415: {"description" : "Unrecognized file/folder format"},
-        424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
-        409: {"description" : "There is already a model corresponding to this path or repo_id"},
+    responses={
+        201: {"description": "The model imported successfully"},
+        404: {"description": "The model could not be found"},
+        415: {"description": "Unrecognized file/folder format"},
+        424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
+        409: {"description": "There is already a model corresponding to this path or repo_id"},
    },
    status_code=201,
-    response_model=ImportModelResponse
+    response_model=ImportModelResponse,
)
async def import_model(
    location: str = Body(description="A model path, repo_id or URL to import"),
-    prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
-        Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
+    prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
+        description="Prediction type for SDv2 checkpoint files", default="v_prediction"
+    ),
) -> ImportModelResponse:
-    """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
+    """Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""

    items_to_import = {location}
-    prediction_types = { x.value: x for x in SchedulerPredictionType }
+    prediction_types = {x.value: x for x in SchedulerPredictionType}
    logger = ApiDependencies.invoker.services.logger

    try:
        installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
-            items_to_import = items_to_import,
-            prediction_type_helper = lambda x: prediction_types.get(prediction_type)
+            items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
        )
        info = installed_models.get(location)

        if not info:
            logger.error("Import failed")
            raise HTTPException(status_code=415)

-        logger.info(f'Successfully imported {location}, got {info}')
+        logger.info(f"Successfully imported {location}, got {info}")
        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
-            model_name=info.name,
-            base_model=info.base_model,
-            model_type=info.model_type
+            model_name=info.name, base_model=info.base_model, model_type=info.model_type
        )
        return parse_obj_as(ImportModelResponse, model_raw)

    except ModelNotFoundException as e:
        logger.error(str(e))
        raise HTTPException(status_code=404, detail=str(e))
@@ -175,38 +176,34 @@ async def import_model(
    except ValueError as e:
        logger.error(str(e))
        raise HTTPException(status_code=409, detail=str(e))


@models_router.post(
    "/add",
    operation_id="add_model",
-    responses= {
-        201: {"description" : "The model added successfully"},
-        404: {"description" : "The model could not be found"},
-        424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
-        409: {"description" : "There is already a model corresponding to this path or repo_id"},
+    responses={
+        201: {"description": "The model added successfully"},
+        404: {"description": "The model could not be found"},
+        424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
+        409: {"description": "There is already a model corresponding to this path or repo_id"},
    },
    status_code=201,
-    response_model=ImportModelResponse
+    response_model=ImportModelResponse,
)
async def add_model(
    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> ImportModelResponse:
-    """ Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
+    """Add a model using the configuration information appropriate for its type. Only local models can be added by path"""

    logger = ApiDependencies.invoker.services.logger

    try:
        ApiDependencies.invoker.services.model_manager.add_model(
-            info.model_name,
-            info.base_model,
-            info.model_type,
-            model_attributes = info.dict()
+            info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
        )
-        logger.info(f'Successfully added {info.model_name}')
+        logger.info(f"Successfully added {info.model_name}")
        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
-            model_name=info.model_name,
-            base_model=info.base_model,
-            model_type=info.model_type
+            model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
        )
        return parse_obj_as(ImportModelResponse, model_raw)
    except ModelNotFoundException as e:
@@ -216,66 +213,66 @@ async def add_model(
        logger.error(str(e))
        raise HTTPException(status_code=409, detail=str(e))


@models_router.delete(
    "/{base_model}/{model_type}/{model_name}",
    operation_id="del_model",
-    responses={
-        204: { "description": "Model deleted successfully" },
-        404: { "description": "Model not found" }
-    },
-    status_code = 204,
-    response_model = None,
+    responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
+    status_code=204,
+    response_model=None,
)
async def delete_model(
    base_model: BaseModelType = Path(description="Base model"),
    model_type: ModelType = Path(description="The type of model"),
    model_name: str = Path(description="model name"),
) -> Response:
    """Delete Model"""
    logger = ApiDependencies.invoker.services.logger

    try:
-        ApiDependencies.invoker.services.model_manager.del_model(model_name,
-            base_model = base_model,
-            model_type = model_type
-        )
+        ApiDependencies.invoker.services.model_manager.del_model(
+            model_name, base_model=base_model, model_type=model_type
+        )
        logger.info(f"Deleted model: {model_name}")
        return Response(status_code=204)
    except ModelNotFoundException as e:
        logger.error(str(e))
        raise HTTPException(status_code=404, detail=str(e))


@models_router.put(
    "/convert/{base_model}/{model_type}/{model_name}",
    operation_id="convert_model",
    responses={
-        200: { "description": "Model converted successfully" },
-        400: {"description" : "Bad request" },
-        404: { "description": "Model not found" },
+        200: {"description": "Model converted successfully"},
+        400: {"description": "Bad request"},
+        404: {"description": "Model not found"},
    },
-    status_code = 200,
-    response_model = ConvertModelResponse,
+    status_code=200,
+    response_model=ConvertModelResponse,
)
async def convert_model(
    base_model: BaseModelType = Path(description="Base model"),
    model_type: ModelType = Path(description="The type of model"),
    model_name: str = Path(description="model name"),
-    convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
+    convert_dest_directory: Optional[str] = Query(
+        default=None, description="Save the converted model to the designated directory"
+    ),
) -> ConvertModelResponse:
    """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
    logger = ApiDependencies.invoker.services.logger
    try:
        logger.info(f"Converting model: {model_name}")
        dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
-        ApiDependencies.invoker.services.model_manager.convert_model(model_name,
-            base_model = base_model,
-            model_type = model_type,
-            convert_dest_directory = dest,
-        )
-        model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
-            base_model = base_model,
-            model_type = model_type)
+        ApiDependencies.invoker.services.model_manager.convert_model(
+            model_name,
+            base_model=base_model,
+            model_type=model_type,
+            convert_dest_directory=dest,
+        )
+        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
+            model_name, base_model=base_model, model_type=model_type
+        )
        response = parse_obj_as(ConvertModelResponse, model_raw)
    except ModelNotFoundException as e:
        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
@ -283,91 +280,101 @@ async def convert_model(
|
|||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
"/search",
|
"/search",
|
||||||
operation_id="search_for_models",
|
operation_id="search_for_models",
|
||||||
responses={
|
responses={
|
||||||
200: { "description": "Directory searched successfully" },
|
200: {"description": "Directory searched successfully"},
|
||||||
404: { "description": "Invalid directory path" },
|
404: {"description": "Invalid directory path"},
|
||||||
},
|
},
|
||||||
status_code = 200,
|
status_code=200,
|
||||||
response_model = List[pathlib.Path]
|
response_model=List[pathlib.Path],
|
||||||
)
|
)
|
||||||
async def search_for_models(
|
async def search_for_models(
|
||||||
search_path: pathlib.Path = Query(description="Directory path to search for models")
|
search_path: pathlib.Path = Query(description="Directory path to search for models"),
|
||||||
)->List[pathlib.Path]:
|
) -> List[pathlib.Path]:
|
||||||
if not search_path.is_dir():
|
if not search_path.is_dir():
|
||||||
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
|
raise HTTPException(
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
)
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)


@models_router.get(
"/ckpt_confs",
operation_id="list_ckpt_configs",
responses={
200: { "description" : "paths retrieved successfully" },
200: {"description": "paths retrieved successfully"},
},
status_code = 200,
status_code=200,
response_model = List[pathlib.Path]
response_model=List[pathlib.Path],
)
async def list_ckpt_configs(
)->List[pathlib.Path]:
async def list_ckpt_configs() -> List[pathlib.Path]:
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()


@models_router.post(
"/sync",
operation_id="sync_to_config",
responses={
201: { "description": "synchronization successful" },
201: {"description": "synchronization successful"},
},
status_code = 201,
status_code=201,
response_model = bool
response_model=bool,
)
async def sync_to_config(
)->bool:
async def sync_to_config() -> bool:
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
in-memory data structures with disk data structures."""
ApiDependencies.invoker.services.model_manager.sync_to_config()
return True


@models_router.put(
"/merge/{base_model}",
operation_id="merge_models",
responses={
200: { "description": "Model converted successfully" },
200: {"description": "Model converted successfully"},
400: { "description": "Incompatible models" },
400: {"description": "Incompatible models"},
404: { "description": "One or more models not found" },
404: {"description": "One or more models not found"},
},
status_code = 200,
status_code=200,
response_model = MergeModelResponse,
response_model=MergeModelResponse,
)
async def merge_models(
base_model: BaseModelType = Path(description="Base model"),
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description="Name of destination model"),
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
force: Optional[bool] = Body(
description="Force merging of models created with different versions of diffusers", default=False
),
merge_dest_directory: Optional[str] = Body(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None,
),
) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
base_model,
merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory = dest
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
base_model = base_model,
model_type = ModelType.Main,
)
result = ApiDependencies.invoker.services.model_manager.merge_models(
model_names,
base_model,
merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
result.name,
base_model=base_model,
model_type=ModelType.Main,
)
response = parse_obj_as(ConvertModelResponse, model_raw)
except ModelNotFoundException:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
@@ -30,9 +30,7 @@ session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
},
)
async def create_session(
graph: Optional[Graph] = Body(
default=None, description="The graph to initialize the session with"
)
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
) -> GraphExecutionState:
"""Creates a new session, optionally initializing it with an invocation graph"""
session = ApiDependencies.invoker.create_execution_state(graph)
@@ -51,13 +49,9 @@ async def list_sessions(
) -> PaginatedResults[GraphExecutionState]:
"""Gets a list of sessions, optionally searching"""
if query == "":
result = ApiDependencies.invoker.services.graph_execution_manager.list(
page, per_page
)
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
else:
result = ApiDependencies.invoker.services.graph_execution_manager.search(
query, page, per_page
)
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
return result


@@ -91,9 +85,9 @@ async def get_session(
)
async def add_node(
session_id: str = Path(description="The id of the session"),
node: Annotated[
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
] = Body(description="The node to add"),
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
description="The node to add"
),
) -> str:
"""Adds a node to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
@@ -124,9 +118,9 @@ async def add_node(
async def update_node(
session_id: str = Path(description="The id of the session"),
node_path: str = Path(description="The path to the node in the graph"),
node: Annotated[
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
] = Body(description="The new node"),
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
description="The new node"
),
) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
@@ -230,7 +224,7 @@ async def delete_edge(
try:
edge = Edge(
source=EdgeConnection(node_id=from_node_id, field=from_field),
destination=EdgeConnection(node_id=to_node_id, field=to_field)
destination=EdgeConnection(node_id=to_node_id, field=to_field),
)
session.delete_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set(
@@ -255,9 +249,7 @@ async def delete_edge(
)
async def invoke_session(
session_id: str = Path(description="The id of the session to invoke"),
all: bool = Query(
default=False, description="Whether or not to invoke all remaining invocations"
),
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
) -> Response:
"""Invokes a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
@@ -274,9 +266,7 @@ async def invoke_session(
@session_router.delete(
"/{session_id}/invoke",
operation_id="cancel_session_invoke",
responses={
202: {"description": "The invocation is canceled"}
},
responses={202: {"description": "The invocation is canceled"}},
)
async def cancel_session_invoke(
session_id: str = Path(description="The id of the session to cancel"),
@@ -16,9 +16,7 @@ class SocketIO:
self.__sio.on("subscribe", handler=self._handle_sub)
self.__sio.on("unsubscribe", handler=self._handle_unsub)

local_handler.register(
event_name=EventServiceBase.session_event, _func=self._handle_session_event
)
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)

async def _handle_session_event(self, event: Event):
await self.__sio.emit(
@@ -16,9 +16,10 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
from pathlib import Path
from pydantic.schema import schema

#This should come early so that modules can log their initialization properly
# This should come early so that modules can log their initialization properly
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger

app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config)
@@ -27,7 +28,7 @@ from invokeai.version.invokeai_version import __version__
# we call this early so that the message appears before
# other invokeai initialization messages
if app_config.version:
print(f'InvokeAI version {__version__}')
print(f"InvokeAI version {__version__}")
sys.exit(0)

import invokeai.frontend.web as web_dir
@@ -37,17 +38,18 @@ from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation


import torch
import invokeai.backend.util.hotfixes

if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes

# fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type('application/javascript', '.js')
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type('text/css', '.css')
mimetypes.add_type("text/css", ".css")

# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
@@ -57,14 +59,13 @@ app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
event_handler_id: int = id(app)
app.add_middleware(
EventHandlerASGIMiddleware,
handlers=[
local_handler
], # TODO: consider doing this in services to support different configurations
handlers=[local_handler], # TODO: consider doing this in services to support different configurations
middleware_id=event_handler_id,
)

socket_io = SocketIO(app)


# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event():
@@ -76,9 +77,7 @@ async def startup_event():
allow_headers=app_config.allow_headers,
)

ApiDependencies.initialize(
config=app_config, event_handler_id=event_handler_id, logger=logger
)
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)


# Shut down threads
@@ -103,7 +102,8 @@ app.include_router(boards.boards_router, prefix="/api")

app.include_router(board_images.board_images_router, prefix="/api")

app.include_router(app_info.app_router, prefix='/api')
app.include_router(app_info.app_router, prefix="/api")


# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
@@ -144,6 +144,7 @@ def custom_openapi():
invoker_schema["output"] = outputs_ref

from invokeai.backend.model_management.models import get_model_config_enums

for model_config_format_enum in set(get_model_config_enums()):
name = model_config_format_enum.__qualname__

@@ -166,7 +167,8 @@ def custom_openapi():
app.openapi = custom_openapi

# Override API doc favicons
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static")
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static")


@app.get("/docs", include_in_schema=False)
def overridden_swagger():
@@ -187,11 +189,8 @@ def overridden_redoc():


# Must mount *after* the other routes else it borks em
app.mount("/",
StaticFiles(directory=Path(web_dir.__path__[0],"dist"),
html=True
), name="ui"
)
app.mount("/", StaticFiles(directory=Path(web_dir.__path__[0], "dist"), html=True), name="ui")

def invoke_api():
def find_port(port: int):
@@ -203,10 +202,11 @@ def invoke_api():
return find_port(port=port + 1)
else:
return port

from invokeai.backend.install.check_root import check_invokeai_root

check_invokeai_root(app_config) # note, may exit with an exception if root not set up

port = find_port(app_config.port)
if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}")
@@ -217,5 +217,6 @@ def invoke_api():
server = uvicorn.Server(config)
loop.run_until_complete(server.serve())


if __name__ == "__main__":
invoke_api()
@@ -14,8 +14,14 @@ from ..services.graph import GraphExecutionState, LibraryGraph, Edge
from ..services.invoker import Invoker


def add_field_argument(command_parser, name: str, field, default_override = None):
def add_field_argument(command_parser, name: str, field, default_override=None):
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
default = (
default_override
if default_override is not None
else field.default
if field.default_factory is None
else field.default_factory()
)
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
@@ -47,8 +53,8 @@ def add_parsers(
commands: list[type],
command_field: str = "type",
exclude_fields: list[str] = ["id", "type"],
add_arguments: Union[Callable[[argparse.ArgumentParser], None],None] = None
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None,
):
"""Adds parsers for each command to the subparsers"""

# Create subparsers for each command
@@ -61,7 +67,7 @@ def add_parsers(
add_arguments(command_parser)

# Convert all fields to arguments
fields = command.__fields__ # type: ignore
for name, field in fields.items():
if name in exclude_fields:
continue
@@ -70,13 +76,11 @@ def add_parsers(


def add_graph_parsers(
subparsers,
graphs: list[LibraryGraph],
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
subparsers, graphs: list[LibraryGraph], add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
):
for graph in graphs:
command_parser = subparsers.add_parser(graph.name, help=graph.description)

if add_arguments is not None:
add_arguments(command_parser)

@@ -128,6 +132,7 @@ class CliContext:

class ExitCli(Exception):
"""Exception to exit the CLI"""

pass


@@ -155,7 +160,7 @@ class BaseCommand(ABC, BaseModel):
@classmethod
def get_commands_map(cls):
# Get the type strings out of the literals and into a dictionary
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseCommand.get_all_subclasses()))
return dict(map(lambda t: (get_args(get_type_hints(t)["type"])[0], t), BaseCommand.get_all_subclasses()))

@abstractmethod
def run(self, context: CliContext) -> None:
@@ -165,7 +170,8 @@ class BaseCommand(ABC, BaseModel):

class ExitCommand(BaseCommand):
"""Exits the CLI"""
type: Literal['exit'] = 'exit'

type: Literal["exit"] = "exit"

def run(self, context: CliContext) -> None:
raise ExitCli()
@@ -173,7 +179,8 @@ class ExitCommand(BaseCommand):

class HelpCommand(BaseCommand):
"""Shows help"""
type: Literal['help'] = 'help'

type: Literal["help"] = "help"

def run(self, context: CliContext) -> None:
context.parser.print_help()
@@ -183,11 +190,7 @@ def get_graph_execution_history(
graph_execution_state: GraphExecutionState,
) -> Iterable[str]:
"""Gets the history of fully-executed invocations for a graph execution"""
return (
n
for n in reversed(graph_execution_state.executed_history)
if n in graph_execution_state.graph.nodes
)
return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes)


def get_invocation_command(invocation) -> str:
@@ -218,7 +221,8 @@ def get_invocation_command(invocation) -> str:

class HistoryCommand(BaseCommand):
"""Shows the invocation history"""
type: Literal['history'] = 'history'

type: Literal["history"] = "history"

# Inputs
# fmt: off
@@ -235,7 +239,8 @@ class HistoryCommand(BaseCommand):

class SetDefaultCommand(BaseCommand):
"""Sets a default value for a field"""
type: Literal['default'] = 'default'

type: Literal["default"] = "default"

# Inputs
# fmt: off
@@ -253,7 +258,8 @@ class SetDefaultCommand(BaseCommand):

class DrawGraphCommand(BaseCommand):
"""Debugs a graph"""
type: Literal['draw_graph'] = 'draw_graph'

type: Literal["draw_graph"] = "draw_graph"

def run(self, context: CliContext) -> None:
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
@@ -271,7 +277,8 @@ class DrawGraphCommand(BaseCommand):

class DrawExecutionGraphCommand(BaseCommand):
"""Debugs an execution graph"""
type: Literal['draw_xgraph'] = 'draw_xgraph'

type: Literal["draw_xgraph"] = "draw_xgraph"

def run(self, context: CliContext) -> None:
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
@@ -286,6 +293,7 @@ class DrawExecutionGraphCommand(BaseCommand):
plt.axis("off")
plt.show()


class SortedHelpFormatter(argparse.HelpFormatter):
def _iter_indented_subactions(self, action):
try:
@@ -19,8 +19,8 @@ from ..services.invocation_services import InvocationServices
# singleton object, class variable
completer = None


class Completer(object):

def __init__(self, model_manager: ModelManager):
self.commands = self.get_commands()
self.matches = None
@@ -43,7 +43,7 @@ class Completer(object):
except IndexError:
pass
options = options or list(self.parse_commands().keys())

if not text: # first time
self.matches = options
else:
@@ -56,17 +56,17 @@ class Completer(object):
return match

@classmethod
def get_commands(self)->List[object]:
def get_commands(self) -> List[object]:
"""
Return a list of all the client commands and invocations.
"""
return BaseCommand.get_commands() + BaseInvocation.get_invocations()

def get_current_command(self, buffer: str)->tuple[str, str]:
def get_current_command(self, buffer: str) -> tuple[str, str]:
"""
Parse the readline buffer to find the most recent command and its switch.
"""
if len(buffer)==0:
if len(buffer) == 0:
return None, None
tokens = shlex.split(buffer)
command = None
@@ -78,11 +78,11 @@ class Completer(object):
else:
switch = t
# don't try to autocomplete switches that are already complete
if switch and buffer.endswith(' '):
if switch and buffer.endswith(" "):
switch=None
switch = None
return command or '', switch or ''
return command or "", switch or ""

def parse_commands(self)->Dict[str, List[str]]:
def parse_commands(self) -> Dict[str, List[str]]:
"""
Return a dict in which the keys are the command name
and the values are the parameters the command takes.
@@ -90,11 +90,11 @@ class Completer(object):
result = dict()
for command in self.commands:
hints = get_type_hints(command)
name = get_args(hints['type'])[0]
name = get_args(hints["type"])[0]
result.update({name:hints})
result.update({name: hints})
return result

def get_command_options(self, command: str, switch: str)->List[str]:
def get_command_options(self, command: str, switch: str) -> List[str]:
"""
Return all the parameters that can be passed to the command as
command-line switches. Returns None if the command is unrecognized.
@@ -102,42 +102,46 @@ class Completer(object):
parsed_commands = self.parse_commands()
if command not in parsed_commands:
return None

# handle switches in the format "-foo=bar"
argument = None
if switch and '=' in switch:
if switch and "=" in switch:
switch, argument = switch.split('=')
switch, argument = switch.split("=")

parameter = switch.strip('-')
parameter = switch.strip("-")
if parameter in parsed_commands[command]:
if argument is None:
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
else:
return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
return [
f"--{parameter}={x}"
for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])
]
else:
return [f"--{x}" for x in parsed_commands[command].keys()]

def get_parameter_options(self, parameter: str, typehint)->List[str]:
def get_parameter_options(self, parameter: str, typehint) -> List[str]:
"""
Given a parameter type (such as Literal), offers autocompletions.
"""
if get_origin(typehint) == Literal:
return get_args(typehint)
if parameter == 'model':
if parameter == "model":
return self.manager.model_names()

def _pre_input_hook(self):
if self.linebuffer:
readline.insert_text(self.linebuffer)
readline.redisplay()
self.linebuffer = None


def set_autocompleter(services: InvocationServices) -> Completer:
global completer

if completer:
return completer

completer = Completer(services.model_manager)

readline.set_completer(completer.complete)
@@ -162,8 +166,6 @@ def set_autocompleter(services: InvocationServices) -> Completer:
pass
except OSError: # file likely corrupted
newname = f"{histfile}.old"
logger.error(
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
)
logger.error(f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}")
histfile.replace(Path(newname))
atexit.register(readline.write_history_file, histfile)
@@ -13,6 +13,7 @@ from pydantic.fields import Field
# This should come early so that the logger can pick up its configuration options
from .services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger

config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().getLogger(config=config)
@@ -20,7 +21,7 @@ from invokeai.version.invokeai_version import __version__

# we call this early so that the message appears before other invokeai initialization messages
if config.version:
print(f'InvokeAI version {__version__}')
print(f"InvokeAI version {__version__}")
sys.exit(0)

from invokeai.app.services.board_image_record_storage import (
@@ -36,18 +37,21 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from .services.default_graphs import (default_text_to_image_graph_id,
create_system_graphs)
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage

from .cli.commands import (BaseCommand, CliContext, ExitCli,
SortedHelpFormatter, add_graph_parsers, add_parsers)
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.graph import (Edge, EdgeConnection, GraphExecutionState,
GraphInvocation, LibraryGraph,
are_connection_types_compatible)
from .services.graph import (
Edge,
EdgeConnection,
GraphExecutionState,
GraphInvocation,
LibraryGraph,
are_connection_types_compatible,
)
from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
@@ -58,6 +62,7 @@ from .services.sqlite import SqliteItemStorage

import torch
import invokeai.backend.util.hotfixes

if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes

@@ -69,6 +74,7 @@ class CliCommand(BaseModel):
class InvalidArgs(Exception):
pass


def add_invocation_args(command_parser):
# Add linking capability
command_parser.add_argument(
@@ -113,7 +119,7 @@ def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
return parser


class NodeField():
class NodeField:
alias: str
node_path: str
field: str
@@ -126,15 +132,20 @@ class NodeField():
self.field_type = field_type


def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str, NodeField]:
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
return {k: NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}


def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
"""Gets the node field for the specified field alias"""
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_input.node_path))
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
return NodeField(
alias=exposed_input.alias,
node_path=f"{node_id}.{exposed_input.node_path}",
field=exposed_input.field,
field_type=get_type_hints(node_type)[exposed_input.field],
)


def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
@@ -142,7 +153,12 @@ def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_output.node_path))
node_output_type = node_type.get_output_type()
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
return NodeField(
alias=exposed_output.alias,
node_path=f"{node_id}.{exposed_output.node_path}",
field=exposed_output.field,
field_type=get_type_hints(node_output_type)[exposed_output.field],
)


def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
@@ -165,9 +181,7 @@ def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[st
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}


def generate_matching_edges(
a: BaseInvocation, b: BaseInvocation, context: CliContext
) -> list[Edge]:
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation, context: CliContext) -> list[Edge]:
"""Generates all possible edges between two invocations"""
afields = get_node_outputs(a, context)
bfields = get_node_inputs(b, context)
@@ -179,12 +193,14 @@ def generate_matching_edges(
matching_fields = matching_fields.difference(invalid_fields)

# Validate types
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
matching_fields = [
f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)
]

edges = [
Edge(
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field),
)
for alias in matching_fields
]
@@ -193,6 +209,7 @@ def generate_matching_edges(

class SessionError(Exception):
"""Raised when a session error has occurred"""

pass


@@ -209,22 +226,23 @@ def invoke_all(context: CliContext):
context.invoker.services.logger.error(
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
)

raise SessionError()


def invoke_cli():
logger.info(f'InvokeAI version {__version__}')
logger.info(f"InvokeAI version {__version__}")
# get the optional list of invocations to execute on the command line
parser = config.get_parser()
parser.add_argument('commands',nargs='*')
parser.add_argument("commands", nargs="*")
invocation_commands = parser.parse_args().commands

# get the optional file to read commands from.
# Simplest is to use it for STDIN
if infile := config.from_file:
sys.stdin = open(infile,"r")
sys.stdin = open(infile, "r")

model_manager = ModelManagerService(config,logger)
model_manager = ModelManagerService(config, logger)

events = EventServiceBase()
output_folder = config.output_path
@@ -234,13 +252,13 @@ def invoke_cli():
db_location = ":memory:"
else:
db_location = config.db_path
db_location.parent.mkdir(parents=True,exist_ok=True)
db_location.parent.mkdir(parents=True, exist_ok=True)

logger.info(f'InvokeAI database location is "{db_location}"')

graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)

urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(db_location)
@@ -281,24 +299,21 @@ def invoke_cli():
graph_execution_manager=graph_execution_manager,
)
)

services = InvocationServices(
model_manager=model_manager,
events=events,
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
images=images,
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
logger=logger,
configuration=config,
)


system_graphs = create_system_graphs(services.graph_library)
system_graph_names = set([g.name for g in system_graphs])
@@ -308,7 +323,7 @@ def invoke_cli():
session: GraphExecutionState = invoker.create_execution_state()
parser = get_command_parser(services)

re_negid = re.compile('^-[0-9]+$')
re_negid = re.compile("^-[0-9]+$")

# Uncomment to print out previous sessions at startup
# print(services.session_manager.list())
@@ -318,7 +333,7 @@ def invoke_cli():

command_line_args_exist = len(invocation_commands) > 0
done = False

while not done:
try:
if command_line_args_exist:
@@ -332,7 +347,7 @@ def invoke_cli():

try:
# Refresh the state of the session
#history = list(get_graph_execution_history(context.session))
# history = list(get_graph_execution_history(context.session))
history = list(reversed(context.nodes_added))

# Split the command for piping
@@ -353,17 +368,17 @@ def invoke_cli():
args[field_name] = field_default

# Parse invocation
command: CliCommand = None # type:ignore
system_graph: Optional[LibraryGraph] = None
if args['type'] in system_graph_names:
if args["type"] in system_graph_names:
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
system_graph = next(filter(lambda g: g.name == args["type"], system_graphs))
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
for exposed_input in system_graph.exposed_inputs:
if exposed_input.alias in args:
node = invocation.graph.get_node(exposed_input.node_path)
field = exposed_input.field
setattr(node, field, args[exposed_input.alias])
command = CliCommand(command = invocation)
command = CliCommand(command=invocation)
context.graph_nodes[invocation.id] = system_graph.id
else:
args["id"] = current_id
@@ -385,17 +400,13 @@ def invoke_cli():
# Pipe previous command output (if there was a previous command)
edges: list[Edge] = list()
if len(history) > 0 or current_id != start_id:
from_id = (
history[0] if current_id == start_id else str(current_id - 1)
)
from_id = history[0] if current_id == start_id else str(current_id - 1)
from_node = (
next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
if current_id != start_id
else context.session.graph.get_node(from_id)
)
matching_edges = generate_matching_edges(
from_node, command.command, context
)
matching_edges = generate_matching_edges(from_node, command.command, context)
edges.extend(matching_edges)

# Parse provided links
@@ -406,16 +417,18 @@ def invoke_cli():
node_id = str(current_id + int(node_id))

link_node = context.session.graph.get_node(node_id)
matching_edges = generate_matching_edges(
link_node, command.command, context
)
matching_edges = generate_matching_edges(link_node, command.command, context)
matching_destinations = [e.destination for e in matching_edges]
edges = [e for e in edges if e.destination not in matching_destinations]
edges.extend(matching_edges)

if "link" in args and args["link"]:
for link in args["link"]:
edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
edges = [
e
for e in edges
if e.destination.node_id != command.command.id or e.destination.field != link[2]
]

node_id = link[0]
if re_negid.match(node_id):
@@ -428,7 +441,7 @@ def invoke_cli():
edges.append(
Edge(
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field),
)
)

@@ -4,9 +4,5 @@ __all__ = []

dirname = os.path.dirname(os.path.abspath(__file__))
for f in os.listdir(dirname):
if (
f != "__init__.py"
and os.path.isfile("%s/%s" % (dirname, f))
and f[-3:] == ".py"
):
if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py":
__all__.append(f[:-3])
@@ -4,8 +4,7 @@ from __future__ import annotations

from abc import ABC, abstractmethod
from inspect import signature
from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args,
get_type_hints)
from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, get_type_hints

from pydantic import BaseConfig, BaseModel, Field

|
@ -8,8 +8,7 @@ from pydantic import Field, validator
|
|||||||
from invokeai.app.models.image import ImageField
|
from invokeai.app.models.image import ImageField
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
|
||||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext, UIConfig
|
||||||
InvocationConfig, InvocationContext, UIConfig)
|
|
||||||
|
|
||||||
|
|
||||||
class IntCollectionOutput(BaseInvocationOutput):
|
class IntCollectionOutput(BaseInvocationOutput):
|
||||||
@ -27,8 +26,7 @@ class FloatCollectionOutput(BaseInvocationOutput):
|
|||||||
type: Literal["float_collection"] = "float_collection"
|
type: Literal["float_collection"] = "float_collection"
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
collection: list[float] = Field(
|
collection: list[float] = Field(default=[], description="The float collection")
|
||||||
default=[], description="The float collection")
|
|
||||||
|
|
||||||
|
|
||||||
class ImageCollectionOutput(BaseInvocationOutput):
|
class ImageCollectionOutput(BaseInvocationOutput):
|
||||||
@ -37,8 +35,7 @@ class ImageCollectionOutput(BaseInvocationOutput):
|
|||||||
type: Literal["image_collection"] = "image_collection"
|
type: Literal["image_collection"] = "image_collection"
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
collection: list[ImageField] = Field(
|
collection: list[ImageField] = Field(default=[], description="The output images")
|
||||||
default=[], description="The output images")
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["type", "collection"]}
|
schema_extra = {"required": ["type", "collection"]}
|
||||||
@@ -56,10 +53,7 @@ class RangeInvocation(BaseInvocation):

     class Config(InvocationConfig):
         schema_extra = {
-            "ui": {
-                "title": "Range",
-                "tags": ["range", "integer", "collection"]
-            },
+            "ui": {"title": "Range", "tags": ["range", "integer", "collection"]},
         }

     @validator("stop")
@@ -69,9 +63,7 @@ class RangeInvocation(BaseInvocation):
         return v

     def invoke(self, context: InvocationContext) -> IntCollectionOutput:
-        return IntCollectionOutput(
-            collection=list(range(self.start, self.stop, self.step))
-        )
+        return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))


 class RangeOfSizeInvocation(BaseInvocation):
@@ -86,18 +78,11 @@ class RangeOfSizeInvocation(BaseInvocation):

     class Config(InvocationConfig):
         schema_extra = {
-            "ui": {
-                "title": "Sized Range",
-                "tags": ["range", "integer", "size", "collection"]
-            },
+            "ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]},
         }

     def invoke(self, context: InvocationContext) -> IntCollectionOutput:
-        return IntCollectionOutput(
-            collection=list(
-                range(
-                    self.start, self.start + self.size,
-                    self.step)))
+        return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))


 class RandomRangeInvocation(BaseInvocation):
@@ -107,9 +92,7 @@ class RandomRangeInvocation(BaseInvocation):

     # Inputs
     low: int = Field(default=0, description="The inclusive low value")
-    high: int = Field(
-        default=np.iinfo(np.int32).max, description="The exclusive high value"
-    )
+    high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
     size: int = Field(default=1, description="The number of values to generate")
     seed: int = Field(
         ge=0,
@@ -120,19 +103,12 @@ class RandomRangeInvocation(BaseInvocation):

     class Config(InvocationConfig):
         schema_extra = {
-            "ui": {
-                "title": "Random Range",
-                "tags": ["range", "integer", "random", "collection"]
-            },
+            "ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]},
         }

     def invoke(self, context: InvocationContext) -> IntCollectionOutput:
         rng = np.random.default_rng(self.seed)
-        return IntCollectionOutput(
-            collection=list(
-                rng.integers(
-                    low=self.low, high=self.high,
-                    size=self.size)))
+        return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))


 class ImageCollectionInvocation(BaseInvocation):
@@ -3,64 +3,63 @@ from pydantic import BaseModel, Field
 import re
 import torch
 from compel import Compel, ReturnedEmbeddingsType
-from compel.prompt_parser import (Blend, Conjunction,
-                                  CrossAttentionControlSubstitute,
-                                  FlattenedPrompt, Fragment)
+from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
 from ...backend.util.devices import torch_dtype
 from ...backend.model_management import ModelType
 from ...backend.model_management.models import ModelNotFoundException
 from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
-from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
-                             InvocationConfig, InvocationContext)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
 from .model import ClipField
 from dataclasses import dataclass


 class ConditioningField(BaseModel):
-    conditioning_name: Optional[str] = Field(
-        default=None, description="The name of conditioning data")
+    conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")

     class Config:
         schema_extra = {"required": ["conditioning_name"]}


 @dataclass
 class BasicConditioningInfo:
-    #type: Literal["basic_conditioning"] = "basic_conditioning"
+    # type: Literal["basic_conditioning"] = "basic_conditioning"
     embeds: torch.Tensor
     extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
     # weight: float
     # mode: ConditioningAlgo


 @dataclass
 class SDXLConditioningInfo(BasicConditioningInfo):
-    #type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
+    # type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
     pooled_embeds: torch.Tensor
     add_time_ids: torch.Tensor

-ConditioningInfoType = Annotated[
-    Union[BasicConditioningInfo, SDXLConditioningInfo],
-    Field(discriminator="type")
-]
+ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]

 @dataclass
 class ConditioningFieldData:
     conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
-    #unconditioned: Optional[torch.Tensor]
+    # unconditioned: Optional[torch.Tensor]

-#class ConditioningAlgo(str, Enum):
+# class ConditioningAlgo(str, Enum):
 #    Compose = "compose"
 #    ComposeEx = "compose_ex"
 #    PerpNeg = "perp_neg"


 class CompelOutput(BaseInvocationOutput):
     """Compel parser output"""

-    #fmt: off
+    # fmt: off
     type: Literal["compel_output"] = "compel_output"

     conditioning: ConditioningField = Field(default=None, description="Conditioning")
-    #fmt: on
+    # fmt: on


 class CompelInvocation(BaseInvocation):
@@ -74,33 +73,28 @@ class CompelInvocation(BaseInvocation):
     # Schema customisation
     class Config(InvocationConfig):
         schema_extra = {
-            "ui": {
-                "title": "Prompt (Compel)",
-                "tags": ["prompt", "compel"],
-                "type_hints": {
-                    "model": "model"
-                }
-            },
+            "ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
         }

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> CompelOutput:
         tokenizer_info = context.services.model_manager.get_model(
-            **self.clip.tokenizer.dict(), context=context,
+            **self.clip.tokenizer.dict(),
+            context=context,
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **self.clip.text_encoder.dict(), context=context,
+            **self.clip.text_encoder.dict(),
+            context=context,
         )

         def _lora_loader():
             for lora in self.clip.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}), context=context)
+                lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return

-        #loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

         ti_list = []
         for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
@@ -116,15 +110,18 @@ class CompelInvocation(BaseInvocation):
                 )
             except ModelNotFoundException:
                 # print(e)
-                #import traceback
-                #print(traceback.format_exc())
-                print(f"Warn: trigger: \"{trigger}\" not found")
+                # import traceback
+                # print(traceback.format_exc())
+                print(f'Warn: trigger: "{trigger}" not found')

-        with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
-                ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
-                ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),\
-                text_encoder_info as text_encoder:
-
+        with ModelPatcher.apply_lora_text_encoder(
+            text_encoder_info.context.model, _lora_loader()
+        ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
+            tokenizer,
+            ti_manager,
+        ), ModelPatcher.apply_clip_skip(
+            text_encoder_info.context.model, self.clip.skipped_layers
+        ), text_encoder_info as text_encoder:
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
@@ -139,14 +136,12 @@ class CompelInvocation(BaseInvocation):
             if context.services.configuration.log_tokenization:
                 log_tokenization_for_prompt_object(prompt, tokenizer)

-            c, options = compel.build_conditioning_tensor_for_prompt_object(
-                prompt)
+            c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)

             ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
-                tokens_count_including_eos_bos=get_max_token_count(
-                    tokenizer, conjunction),
-                cross_attention_control_args=options.get(
-                    "cross_attention_control", None),)
+                tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
+                cross_attention_control_args=options.get("cross_attention_control", None),
+            )

             c = c.detach().to("cpu")

@ -168,24 +163,26 @@ class CompelInvocation(BaseInvocation):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDXLPromptInvocationBase:
|
class SDXLPromptInvocationBase:
|
||||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(), context=context,
|
**clip_field.tokenizer.dict(),
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**clip_field.text_encoder.dict(), context=context,
|
**clip_field.text_encoder.dict(),
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||||
**lora.dict(exclude={"weight"}), context=context)
|
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
ti_list = []
|
ti_list = []
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||||
@ -201,15 +198,18 @@ class SDXLPromptInvocationBase:
|
|||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
#import traceback
|
# import traceback
|
||||||
#print(traceback.format_exc())
|
# print(traceback.format_exc())
|
||||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
print(f'Warn: trigger: "{trigger}" not found')
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
|
||||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
|
||||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
|
|
||||||
text_encoder_info as text_encoder:
|
|
||||||
|
|
||||||
|
with ModelPatcher.apply_lora_text_encoder(
|
||||||
|
text_encoder_info.context.model, _lora_loader()
|
||||||
|
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||||
|
tokenizer,
|
||||||
|
ti_manager,
|
||||||
|
), ModelPatcher.apply_clip_skip(
|
||||||
|
text_encoder_info.context.model, clip_field.skipped_layers
|
||||||
|
), text_encoder_info as text_encoder:
|
||||||
text_inputs = tokenizer(
|
text_inputs = tokenizer(
|
||||||
prompt,
|
prompt,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
@ -241,21 +241,22 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(), context=context,
|
**clip_field.tokenizer.dict(),
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**clip_field.text_encoder.dict(), context=context,
|
**clip_field.text_encoder.dict(),
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
||||||
**lora.dict(exclude={"weight"}), context=context)
|
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
ti_list = []
|
ti_list = []
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||||
@ -271,22 +272,25 @@ class SDXLPromptInvocationBase:
|
|||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
#import traceback
|
# import traceback
|
||||||
#print(traceback.format_exc())
|
# print(traceback.format_exc())
|
||||||
print(f"Warn: trigger: \"{trigger}\" not found")
|
print(f'Warn: trigger: "{trigger}" not found')
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
|
||||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
|
||||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
|
|
||||||
text_encoder_info as text_encoder:
|
|
||||||
|
|
||||||
|
with ModelPatcher.apply_lora_text_encoder(
|
||||||
|
text_encoder_info.context.model, _lora_loader()
|
||||||
|
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||||
|
tokenizer,
|
||||||
|
ti_manager,
|
||||||
|
), ModelPatcher.apply_clip_skip(
|
||||||
|
text_encoder_info.context.model, clip_field.skipped_layers
|
||||||
|
), text_encoder_info as text_encoder:
|
||||||
compel = Compel(
|
compel = Compel(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
textual_inversion_manager=ti_manager,
|
textual_inversion_manager=ti_manager,
|
||||||
dtype_for_device_getter=torch_dtype,
|
dtype_for_device_getter=torch_dtype,
|
||||||
truncate_long_prompts=True, # TODO:
|
truncate_long_prompts=True, # TODO:
|
||||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||||
requires_pooled=True,
|
requires_pooled=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -320,6 +324,7 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
return c, c_pooled, ec
|
return c, c_pooled, ec
|
||||||
|
|
||||||
|
|
||||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
@ -339,13 +344,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
||||||
"title": "SDXL Prompt (Compel)",
|
|
||||||
"tags": ["prompt", "compel"],
|
|
||||||
"type_hints": {
|
|
||||||
"model": "model"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -360,9 +359,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
target_size = (self.target_height, self.target_width)
|
target_size = (self.target_height, self.target_width)
|
||||||
|
|
||||||
add_time_ids = torch.tensor([
|
add_time_ids = torch.tensor([original_size + crop_coords + target_size])
|
||||||
original_size + crop_coords + target_size
|
|
||||||
])
|
|
||||||
|
|
||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(
|
||||||
conditionings=[
|
conditionings=[
|
||||||
@ -384,12 +381,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
|
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
|
||||||
|
|
||||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||||
original_width: int = Field(1024, description="")
|
original_width: int = Field(1024, description="")
|
||||||
original_height: int = Field(1024, description="")
|
original_height: int = Field(1024, description="")
|
||||||
crop_top: int = Field(0, description="")
|
crop_top: int = Field(0, description="")
|
||||||
@ -403,9 +401,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
"ui": {
|
"ui": {
|
||||||
"title": "SDXL Refiner Prompt (Compel)",
|
"title": "SDXL Refiner Prompt (Compel)",
|
||||||
"tags": ["prompt", "compel"],
|
"tags": ["prompt", "compel"],
|
||||||
"type_hints": {
|
"type_hints": {"model": "model"},
|
||||||
"model": "model"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -416,9 +412,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
|
|
||||||
add_time_ids = torch.tensor([
|
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
|
||||||
original_size + crop_coords + (self.aesthetic_score,)
|
|
||||||
])
|
|
||||||
|
|
||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(
|
||||||
conditionings=[
|
conditionings=[
|
||||||
@ -426,7 +420,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
embeds=c2,
|
embeds=c2,
|
||||||
pooled_embeds=c2_pooled,
|
pooled_embeds=c2_pooled,
|
||||||
add_time_ids=add_time_ids,
|
add_time_ids=add_time_ids,
|
||||||
extra_conditioning=ec2, # or None
|
extra_conditioning=ec2, # or None
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -440,6 +434,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Pass unmodified prompt to conditioning without compel processing."""
|
"""Pass unmodified prompt to conditioning without compel processing."""
|
||||||
|
|
||||||
@ -459,13 +454,7 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"title": "SDXL Prompt (Raw)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
||||||
"title": "SDXL Prompt (Raw)",
|
|
||||||
"tags": ["prompt", "compel"],
|
|
||||||
"type_hints": {
|
|
||||||
"model": "model"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -480,9 +469,7 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
target_size = (self.target_height, self.target_width)
|
target_size = (self.target_height, self.target_width)
|
||||||
|
|
||||||
add_time_ids = torch.tensor([
|
add_time_ids = torch.tensor([original_size + crop_coords + target_size])
|
||||||
original_size + crop_coords + target_size
|
|
||||||
])
|
|
||||||
|
|
||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(
|
||||||
conditionings=[
|
conditionings=[
|
||||||
@ -504,12 +491,13 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
|
type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
|
||||||
|
|
||||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||||
original_width: int = Field(1024, description="")
|
original_width: int = Field(1024, description="")
|
||||||
original_height: int = Field(1024, description="")
|
original_height: int = Field(1024, description="")
|
||||||
crop_top: int = Field(0, description="")
|
crop_top: int = Field(0, description="")
|
||||||
@ -523,9 +511,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
"ui": {
|
"ui": {
|
||||||
"title": "SDXL Refiner Prompt (Raw)",
|
"title": "SDXL Refiner Prompt (Raw)",
|
||||||
"tags": ["prompt", "compel"],
|
"tags": ["prompt", "compel"],
|
||||||
"type_hints": {
|
"type_hints": {"model": "model"},
|
||||||
"model": "model"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -536,9 +522,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
|
|
||||||
add_time_ids = torch.tensor([
|
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
|
||||||
original_size + crop_coords + (self.aesthetic_score,)
|
|
||||||
])
|
|
||||||
|
|
||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(
|
||||||
conditionings=[
|
conditionings=[
|
||||||
@ -546,7 +530,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
embeds=c2,
|
embeds=c2,
|
||||||
pooled_embeds=c2_pooled,
|
pooled_embeds=c2_pooled,
|
||||||
add_time_ids=add_time_ids,
|
add_time_ids=add_time_ids,
|
||||||
extra_conditioning=ec2, # or None
|
extra_conditioning=ec2, # or None
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -563,11 +547,14 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
|
|
||||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||||
"""Clip skip node output"""
|
"""Clip skip node output"""
|
||||||
|
|
||||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
type: Literal["clip_skip_output"] = "clip_skip_output"
|
||||||
clip: ClipField = Field(None, description="Clip with skipped layers")
|
clip: ClipField = Field(None, description="Clip with skipped layers")
|
||||||
|
|
||||||
|
|
||||||
class ClipSkipInvocation(BaseInvocation):
|
class ClipSkipInvocation(BaseInvocation):
|
||||||
"""Skip layers in clip text_encoder model."""
|
"""Skip layers in clip text_encoder model."""
|
||||||
|
|
||||||
type: Literal["clip_skip"] = "clip_skip"
|
type: Literal["clip_skip"] = "clip_skip"
|
||||||
|
|
||||||
clip: ClipField = Field(None, description="Clip to use")
|
clip: ClipField = Field(None, description="Clip to use")
|
||||||
@ -575,10 +562,7 @@ class ClipSkipInvocation(BaseInvocation):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]},
|
||||||
"title": "CLIP Skip",
|
|
||||||
"tags": ["clip", "skip"]
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||||
@ -589,46 +573,26 @@ class ClipSkipInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
def get_max_token_count(
|
def get_max_token_count(
|
||||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],
|
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
|
||||||
truncate_if_too_long=False) -> int:
|
) -> int:
|
||||||
if type(prompt) is Blend:
|
if type(prompt) is Blend:
|
||||||
blend: Blend = prompt
|
blend: Blend = prompt
|
||||||
return max(
|
return max([get_max_token_count(tokenizer, p, truncate_if_too_long) for p in blend.prompts])
|
||||||
[
|
|
||||||
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
|
||||||
for p in blend.prompts
|
|
||||||
]
|
|
||||||
)
|
|
||||||
elif type(prompt) is Conjunction:
|
elif type(prompt) is Conjunction:
|
||||||
conjunction: Conjunction = prompt
|
conjunction: Conjunction = prompt
|
||||||
return sum(
|
return sum([get_max_token_count(tokenizer, p, truncate_if_too_long) for p in conjunction.prompts])
|
||||||
[
|
|
||||||
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
|
||||||
for p in conjunction.prompts
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return len(
|
return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
|
||||||
get_tokens_for_prompt_object(
|
|
||||||
tokenizer, prompt, truncate_if_too_long))
|
|
||||||
|
|
||||||
|
|
||||||
def get_tokens_for_prompt_object(
|
def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
|
||||||
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
|
|
||||||
) -> List[str]:
|
|
||||||
if type(parsed_prompt) is Blend:
|
if type(parsed_prompt) is Blend:
|
||||||
raise ValueError(
|
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
|
||||||
"Blend is not supported here - you need to get tokens for each of its .children"
|
|
||||||
)
|
|
||||||
|
|
||||||
text_fragments = [
|
text_fragments = [
|
||||||
x.text
|
x.text
|
||||||
if type(x) is Fragment
|
if type(x) is Fragment
|
||||||
else (
|
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
|
||||||
" ".join([f.text for f in x.original])
|
|
||||||
if type(x) is CrossAttentionControlSubstitute
|
|
||||||
else str(x)
|
|
||||||
)
|
|
||||||
for x in parsed_prompt.children
|
for x in parsed_prompt.children
|
||||||
]
|
]
|
||||||
text = " ".join(text_fragments)
|
text = " ".join(text_fragments)
|
||||||
@ -639,25 +603,17 @@ def get_tokens_for_prompt_object(
|
|||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
def log_tokenization_for_conjunction(
|
def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
|
||||||
c: Conjunction, tokenizer, display_label_prefix=None
|
|
||||||
):
|
|
||||||
display_label_prefix = display_label_prefix or ""
|
display_label_prefix = display_label_prefix or ""
|
||||||
for i, p in enumerate(c.prompts):
|
for i, p in enumerate(c.prompts):
|
||||||
if len(c.prompts) > 1:
|
if len(c.prompts) > 1:
|
||||||
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
||||||
else:
|
else:
|
||||||
this_display_label_prefix = display_label_prefix
|
this_display_label_prefix = display_label_prefix
|
||||||
log_tokenization_for_prompt_object(
|
log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
|
||||||
p,
|
|
||||||
tokenizer,
|
|
||||||
display_label_prefix=this_display_label_prefix
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def log_tokenization_for_prompt_object(
|
def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
|
||||||
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
|
||||||
):
|
|
||||||
display_label_prefix = display_label_prefix or ""
|
display_label_prefix = display_label_prefix or ""
|
||||||
if type(p) is Blend:
|
if type(p) is Blend:
|
||||||
blend: Blend = p
|
blend: Blend = p
|
||||||
@ -694,13 +650,10 @@ def log_tokenization_for_prompt_object(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
text = " ".join([x.text for x in flattened_prompt.children])
|
text = " ".join([x.text for x in flattened_prompt.children])
|
||||||
log_tokenization_for_text(
|
log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
|
||||||
text, tokenizer, display_label=display_label_prefix
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def log_tokenization_for_text(
|
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||||
text, tokenizer, display_label=None, truncate_if_too_long=False):
|
|
||||||
"""shows how the prompt is tokenized
|
"""shows how the prompt is tokenized
|
||||||
# usually tokens have '</w>' to indicate end-of-word,
|
# usually tokens have '</w>' to indicate end-of-word,
|
||||||
# but for readability it has been replaced with ' '
|
# but for readability it has been replaced with ' '
|
||||||
|
@@ -6,20 +6,29 @@ from typing import Dict, List, Literal, Optional, Union

 import cv2
 import numpy as np
-from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,
-                            LeresDetector, LineartAnimeDetector,
-                            LineartDetector, MediapipeFaceDetector,
-                            MidasDetector, MLSDdetector, NormalBaeDetector,
-                            OpenposeDetector, PidiNetDetector, SamDetector,
-                            ZoeDetector)
+from controlnet_aux import (
+    CannyDetector,
+    ContentShuffleDetector,
+    HEDdetector,
+    LeresDetector,
+    LineartAnimeDetector,
+    LineartDetector,
+    MediapipeFaceDetector,
+    MidasDetector,
+    MLSDdetector,
+    NormalBaeDetector,
+    OpenposeDetector,
+    PidiNetDetector,
+    SamDetector,
+    ZoeDetector,
+)
 from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
 from pydantic import BaseModel, Field, validator

 from ...backend.model_management import BaseModelType, ModelType
 from ..models.image import ImageCategory, ImageField, ResourceOrigin
-from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
-                             InvocationConfig, InvocationContext)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
 from ..models.image import ImageOutput, PILInvocationConfig

 CONTROLNET_DEFAULT_MODELS = [
@@ -34,7 +43,6 @@ CONTROLNET_DEFAULT_MODELS = [
     "lllyasviel/sd-controlnet-scribble",
     "lllyasviel/sd-controlnet-normal",
     "lllyasviel/sd-controlnet-mlsd",
-
     #############################################
     # lllyasviel sd v1.5, ControlNet v1.1 models
     #############################################
@@ -56,7 +64,6 @@ CONTROLNET_DEFAULT_MODELS = [
     "lllyasviel/control_v11e_sd15_shuffle",
     "lllyasviel/control_v11e_sd15_ip2p",
     "lllyasviel/control_v11f1e_sd15_tile",
-
     #################################################
     # thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
     ##################################################
@@ -71,7 +78,6 @@ CONTROLNET_DEFAULT_MODELS = [
     "thibaud/controlnet-sd21-lineart-diffusers",
     "thibaud/controlnet-sd21-normalbae-diffusers",
     "thibaud/controlnet-sd21-ade20k-diffusers",
-
     ##############################################
     # ControlNetMediaPipeface, ControlNet v1.1
     ##############################################
@@ -83,10 +89,17 @@ CONTROLNET_DEFAULT_MODELS = [
 ]

 CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
-CONTROLNET_MODE_VALUES = Literal[tuple(
-    ["balanced", "more_prompt", "more_control", "unbalanced"])]
-CONTROLNET_RESIZE_VALUES = Literal[tuple(
-    ["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])]
+CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
+CONTROLNET_RESIZE_VALUES = Literal[
+    tuple(
+        [
+            "just_resize",
+            "crop_resize",
+            "fill_resize",
+            "just_resize_simple",
+        ]
+    )
+]


 class ControlNetModelField(BaseModel):
@ -98,21 +111,17 @@ class ControlNetModelField(BaseModel):
|
|||||||
|
|
||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
image: ImageField = Field(default=None, description="The control image")
|
image: ImageField = Field(default=None, description="The control image")
|
||||||
control_model: Optional[ControlNetModelField] = Field(
|
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
|
||||||
default=None, description="The ControlNet model to use")
|
|
||||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||||
control_weight: Union[float, List[float]] = Field(
|
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||||
default=1, description="The weight given to the ControlNet")
|
|
||||||
begin_step_percent: float = Field(
|
begin_step_percent: float = Field(
|
||||||
default=0, ge=0, le=1,
|
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||||
description="When the ControlNet is first applied (% of total steps)")
|
)
|
||||||
end_step_percent: float = Field(
|
end_step_percent: float = Field(
|
||||||
default=1, ge=0, le=1,
|
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||||
description="When the ControlNet is last applied (% of total steps)")
|
)
|
||||||
control_mode: CONTROLNET_MODE_VALUES = Field(
|
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||||
default="balanced", description="The control mode to use")
|
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(
|
|
||||||
default="just_resize", description="The resize mode to use")
|
|
||||||
|
|
||||||
@validator("control_weight")
|
@validator("control_weight")
|
||||||
def validate_control_weight(cls, v):
|
def validate_control_weight(cls, v):
|
||||||
@@ -120,11 +129,10 @@ class ControlField(BaseModel):
         if isinstance(v, list):
             for i in v:
                 if i < -1 or i > 2:
-                    raise ValueError(
-                        'Control weights must be within -1 to 2 range')
+                    raise ValueError("Control weights must be within -1 to 2 range")
         else:
             if v < -1 or v > 2:
-                raise ValueError('Control weights must be within -1 to 2 range')
+                raise ValueError("Control weights must be within -1 to 2 range")
         return v

     class Config:
@@ -136,12 +144,13 @@ class ControlField(BaseModel):
                     "control_model": "controlnet_model",
                     # "control_weight": "number",
                 }
-            }
+            },
         }

+
 class ControlOutput(BaseInvocationOutput):
     """node output for ControlNet info"""

     # fmt: off
     type: Literal["control_output"] = "control_output"
     control: ControlField = Field(default=None, description="The control info")
@ -150,6 +159,7 @@ class ControlOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
class ControlNetInvocation(BaseInvocation):
|
class ControlNetInvocation(BaseInvocation):
|
||||||
"""Collects ControlNet info to pass to other nodes"""
|
"""Collects ControlNet info to pass to other nodes"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["controlnet"] = "controlnet"
|
type: Literal["controlnet"] = "controlnet"
|
||||||
# Inputs
|
# Inputs
|
||||||
@ -176,7 +186,7 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
# "cfg_scale": "float",
|
# "cfg_scale": "float",
|
||||||
"cfg_scale": "number",
|
"cfg_scale": "number",
|
||||||
"control_weight": "float",
|
"control_weight": "float",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -205,10 +215,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"title": "Image Processor", "tags": ["image", "processor"]},
|
||||||
"title": "Image Processor",
|
|
||||||
"tags": ["image", "processor"]
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
@ -233,7 +240,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image_category=ImageCategory.CONTROL,
|
image_category=ImageCategory.CONTROL,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
is_intermediate=self.is_intermediate
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
"""Builds an ImageOutput and its ImageField"""
|
"""Builds an ImageOutput and its ImageField"""
|
||||||
@ -248,9 +255,9 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CannyImageProcessorInvocation(
|
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
ImageProcessorInvocation, PILInvocationConfig):
|
|
||||||
"""Canny edge detection for ControlNet"""
|
"""Canny edge detection for ControlNet"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["canny_image_processor"] = "canny_image_processor"
|
type: Literal["canny_image_processor"] = "canny_image_processor"
|
||||||
# Input
|
# Input
|
||||||
@ -260,22 +267,18 @@ class CannyImageProcessorInvocation(
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]},
|
||||||
"title": "Canny Processor",
|
|
||||||
"tags": ["controlnet", "canny", "image", "processor"]
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
canny_processor = CannyDetector()
|
canny_processor = CannyDetector()
|
||||||
processed_image = canny_processor(
|
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
|
||||||
image, self.low_threshold, self.high_threshold)
|
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class HedImageProcessorInvocation(
|
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
ImageProcessorInvocation, PILInvocationConfig):
|
|
||||||
"""Applies HED edge detection to image"""
|
"""Applies HED edge detection to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["hed_image_processor"] = "hed_image_processor"
|
type: Literal["hed_image_processor"] = "hed_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@ -288,27 +291,25 @@ class HedImageProcessorInvocation(
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]},
|
||||||
"title": "Softedge(HED) Processor",
|
|
||||||
"tags": ["controlnet", "softedge", "hed", "image", "processor"]
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
||||||
processed_image = hed_processor(image,
|
processed_image = hed_processor(
|
||||||
detect_resolution=self.detect_resolution,
|
image,
|
||||||
image_resolution=self.image_resolution,
|
detect_resolution=self.detect_resolution,
|
||||||
# safe not supported in controlnet_aux v0.0.3
|
image_resolution=self.image_resolution,
|
||||||
# safe=self.safe,
|
# safe not supported in controlnet_aux v0.0.3
|
||||||
scribble=self.scribble,
|
# safe=self.safe,
|
||||||
)
|
scribble=self.scribble,
|
||||||
|
)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class LineartImageProcessorInvocation(
|
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
ImageProcessorInvocation, PILInvocationConfig):
|
|
||||||
"""Applies line art processing to image"""
|
"""Applies line art processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@ -319,24 +320,20 @@ class LineartImageProcessorInvocation(
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]},
|
||||||
"title": "Lineart Processor",
|
|
||||||
"tags": ["controlnet", "lineart", "image", "processor"]
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
lineart_processor = LineartDetector.from_pretrained(
|
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
"lllyasviel/Annotators")
|
|
||||||
processed_image = lineart_processor(
|
processed_image = lineart_processor(
|
||||||
image, detect_resolution=self.detect_resolution,
|
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
|
||||||
image_resolution=self.image_resolution, coarse=self.coarse)
|
)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class LineartAnimeImageProcessorInvocation(
|
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
ImageProcessorInvocation, PILInvocationConfig):
|
|
||||||
"""Applies line art anime processing to image"""
|
"""Applies line art anime processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@ -348,23 +345,23 @@ class LineartAnimeImageProcessorInvocation(
|
|||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {
|
||||||
"title": "Lineart Anime Processor",
|
"title": "Lineart Anime Processor",
|
||||||
"tags": ["controlnet", "lineart", "anime", "image", "processor"]
|
"tags": ["controlnet", "lineart", "anime", "image", "processor"],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
processor = LineartAnimeDetector.from_pretrained(
|
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
"lllyasviel/Annotators")
|
processed_image = processor(
|
||||||
processed_image = processor(image,
|
image,
|
||||||
detect_resolution=self.detect_resolution,
|
detect_resolution=self.detect_resolution,
|
||||||
image_resolution=self.image_resolution,
|
image_resolution=self.image_resolution,
|
||||||
)
|
)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class OpenposeImageProcessorInvocation(
|
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
ImageProcessorInvocation, PILInvocationConfig):
|
|
||||||
"""Applies Openpose processing to image"""
|
"""Applies Openpose processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@ -375,25 +372,23 @@ class OpenposeImageProcessorInvocation(
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]},
|
||||||
"title": "Openpose Processor",
|
|
||||||
"tags": ["controlnet", "openpose", "image", "processor"]
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
        openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = openpose_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            hand_and_face=self.hand_and_face,
        )
        return processed_image


class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies Midas depth processing to image"""

    # fmt: off
    type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
    # Inputs

@ -405,26 +400,24 @@ class MidasDepthImageProcessorInvocation(

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]},
        }

    def run_processor(self, image):
        midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = midas_processor(
            image,
            a=np.pi * self.a_mult,
            bg_th=self.bg_th,
            # dept_and_normal not supported in controlnet_aux v0.0.3
            # depth_and_normal=self.depth_and_normal,
        )
        return processed_image


class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies NormalBae processing to image"""

    # fmt: off
    type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
    # Inputs

@ -434,24 +427,20 @@ class NormalbaeImageProcessorInvocation(

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]},
        }

    def run_processor(self, image):
        normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = normalbae_processor(
            image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
        )
        return processed_image


class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies MLSD processing to image"""

    # fmt: off
    type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
    # Inputs

@ -463,24 +452,24 @@ class MlsdImageProcessorInvocation(

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]},
        }

    def run_processor(self, image):
        mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
        processed_image = mlsd_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            thr_v=self.thr_v,
            thr_d=self.thr_d,
        )
        return processed_image


class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies PIDI processing to image"""

    # fmt: off
    type: Literal["pidi_image_processor"] = "pidi_image_processor"
    # Inputs

@ -492,25 +481,24 @@ class PidiImageProcessorInvocation(

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]},
        }

    def run_processor(self, image):
        pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = pidi_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            safe=self.safe,
            scribble=self.scribble,
        )
        return processed_image


class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies content shuffle processing to image"""

    # fmt: off
    type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
    # Inputs
@ -525,48 +513,45 @@ class ContentShuffleImageProcessorInvocation(

        schema_extra = {
            "ui": {
                "title": "Content Shuffle Processor",
                "tags": ["controlnet", "contentshuffle", "image", "processor"],
            },
        }

    def run_processor(self, image):
        content_shuffle_processor = ContentShuffleDetector()
        processed_image = content_shuffle_processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
            h=self.h,
            w=self.w,
            f=self.f,
        )
        return processed_image


# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies Zoe depth processing to image"""

    # fmt: off
    type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]},
        }

    def run_processor(self, image):
        zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = zoe_depth_processor(image)
        return processed_image


class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies mediapipe face processing to image"""

    # fmt: off
    type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
    # Inputs

@ -576,26 +561,22 @@ class MediapipeFaceProcessorInvocation(

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]},
        }

    def run_processor(self, image):
        # MediaPipeFaceDetector throws an error if image has alpha channel
        # so convert to RGB if needed
        if image.mode == "RGBA":
            image = image.convert("RGB")
        mediapipe_face_processor = MediapipeFaceDetector()
        processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
        return processed_image


class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies leres processing to image"""

    # fmt: off
    type: Literal["leres_image_processor"] = "leres_image_processor"
    # Inputs

@ -608,24 +589,23 @@ class LeresImageProcessorInvocation(

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]},
        }

    def run_processor(self, image):
        leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
        processed_image = leres_processor(
            image,
            thr_a=self.thr_a,
            thr_b=self.thr_b,
            boost=self.boost,
            detect_resolution=self.detect_resolution,
            image_resolution=self.image_resolution,
        )
        return processed_image


class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):

    # fmt: off
    type: Literal["tile_image_processor"] = "tile_image_processor"
    # Inputs

@ -637,16 +617,17 @@ class TileResamplerProcessorInvocation(

        schema_extra = {
            "ui": {
                "title": "Tile Resample Processor",
                "tags": ["controlnet", "tile", "resample", "image", "processor"],
            },
        }

    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
    def tile_resample(
        self,
        np_img: np.ndarray,
        res=512,  # never used?
        down_sampling_rate=1.0,
    ):
        np_img = HWC3(np_img)
        if down_sampling_rate < 1.1:
            return np_img

@ -658,36 +639,41 @@ class TileResamplerProcessorInvocation(

    def run_processor(self, img):
        np_img = np.array(img, dtype=np.uint8)
        processed_np_image = self.tile_resample(
            np_img,
            # res=self.tile_size,
            down_sampling_rate=self.down_sampling_rate,
        )
        processed_image = Image.fromarray(processed_np_image)
        return processed_image


class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
    """Applies segment anything processing to image"""

    # fmt: off
    type: Literal["segment_anything_processor"] = "segment_anything_processor"
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Segment Anything Processor",
                "tags": ["controlnet", "segment", "anything", "sam", "image", "processor"],
            },
        }

    def run_processor(self, image):
        # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
        segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
            "ybelkada/segment-anything", subfolder="checkpoints"
        )
        np_img = np.array(image, dtype=np.uint8)
        processed_image = segment_anything_processor(np_img)
        return processed_image


class SamDetectorReproducibleColors(SamDetector):

    # overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
    # base class show_anns() method randomizes colors,
    # which seems to also lead to non-reproducible image generation

@ -695,19 +681,15 @@ class SamDetectorReproducibleColors(SamDetector):

    def show_anns(self, anns: List[Dict]):
        if len(anns) == 0:
            return
        sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
        h, w = anns[0]["segmentation"].shape
        final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
        palette = ade_palette()
        for i, ann in enumerate(sorted_anns):
            m = ann["segmentation"]
            img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
            # doing modulo just in case number of annotated regions exceeds number of colors in palette
            ann_color = palette[i % len(palette)]
            img[:, :] = ann_color
            final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
        return np.array(final_img, dtype=np.uint8)
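
A minimal, self-contained sketch of the reproducible-coloring idea used above, with an illustrative three-color palette standing in for ade_palette() (the names and palette here are illustrative, not the repository's API):

import numpy as np
from PIL import Image

# Illustrative palette; ade_palette() plays this role in the diff above.
PALETTE = [(120, 120, 120), (180, 120, 120), (6, 230, 230)]


def color_masks_deterministically(masks: list) -> np.ndarray:
    """Paint each boolean HxW mask with a palette color chosen by its area rank.

    Sorting by area and indexing the palette with `i % len(PALETTE)` yields the
    same colors on every run, unlike random per-region colors.
    """
    masks = sorted(masks, key=lambda m: m.sum(), reverse=True)
    h, w = masks[0].shape
    canvas = Image.new("RGB", (w, h), (0, 0, 0))
    for i, m in enumerate(masks):
        layer = np.zeros((h, w, 3), dtype=np.uint8)
        layer[:, :] = PALETTE[i % len(PALETTE)]
        canvas.paste(Image.fromarray(layer, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
    return np.array(canvas, dtype=np.uint8)
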
@ -37,10 +37,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:

@ -6,8 +6,7 @@ from typing import Literal, Optional, get_args

import torch
from pydantic import Field

from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods

@ -25,13 +24,12 @@ from contextlib import contextmanager, ExitStack, ContextDecorator

SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"


from .latent import get_scheduler


class OldModelContext(ContextDecorator):
    model: StableDiffusionGeneratorPipeline
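
The DEFAULT_INFILL_METHOD fallback above works because typing.get_args exposes the members of a dynamically built Literal; a standalone sketch of that pattern, with a made-up options tuple rather than the real infill_methods():

from typing import Literal, get_args


def available_methods() -> tuple:
    # Hypothetical stand-in for infill_methods(); the real list depends on which
    # optional backends (e.g. patchmatch) imported successfully at runtime.
    return ("tile", "solid")  # note: no "patchmatch" here


METHODS = Literal[tuple(available_methods())]  # dynamic Literal, same trick as INFILL_METHODS
DEFAULT = "patchmatch" if "patchmatch" in get_args(METHODS) else "tile"

assert DEFAULT == "tile"  # falls back because "patchmatch" is not among the Literal's members
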
@ -44,6 +42,7 @@ class OldModelContext(ContextDecorator):

    def __exit__(self, *exc):
        return False


class OldModelInfo:
    name: str
    hash: str

@ -64,20 +63,34 @@ class InpaintInvocation(BaseInvocation):

    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
    seed: int = Field(
        ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed
    )
    steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
    width: int = Field(
        default=512,
        multiple_of=8,
        gt=0,
        description="The width of the resulting image",
    )
    height: int = Field(
        default=512,
        multiple_of=8,
        gt=0,
        description="The height of the resulting image",
    )
    cfg_scale: float = Field(
        default=7.5,
        ge=1,
        description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",
    )
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
    unet: UNetField = Field(default=None, description="UNet model")
    vae: VaeField = Field(default=None, description="Vae model")

    # Inputs
    image: Optional[ImageField] = Field(description="The input image")
    strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image")
    fit: bool = Field(
        default=True,
        description="Whether or not the result should be fit to the aspect ratio of the input image",

@ -86,18 +99,10 @@ class InpaintInvocation(BaseInvocation):

    # Inputs
    mask: Optional[ImageField] = Field(description="The mask")
    seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
    seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
    seam_strength: float = Field(default=0.75, gt=0, le=1, description="The seam inpaint strength")
    seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
    tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
    infill_method: INFILL_METHODS = Field(
        default=DEFAULT_INFILL_METHOD,
        description="The method used to infill empty regions (px)",

@ -128,10 +133,7 @@ class InpaintInvocation(BaseInvocation):

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"tags": ["stable-diffusion", "image"], "title": "Inpaint"},
        }

    def dispatch_progress(

@ -162,18 +164,23 @@ class InpaintInvocation(BaseInvocation):

        def _lora_loader():
            for lora in self.unet.loras:
                lora_info = context.services.model_manager.get_model(
                    **lora.dict(exclude={"weight"}),
                    context=context,
                )
                yield (lora_info.context.model, lora.weight)
                del lora_info
            return

        unet_info = context.services.model_manager.get_model(
            **self.unet.unet.dict(),
            context=context,
        )
        vae_info = context.services.model_manager.get_model(
            **self.vae.vae.dict(),
            context=context,
        )

        with vae_info as vae, ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
            device = context.services.model_manager.mgr.cache.execution_device
            dtype = context.services.model_manager.mgr.cache.precision
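
One note on the reformatted with-statement above: chaining several context managers on a single with line is equivalent to the old backslash-continued form. A generic sketch with dummy managers (not the model-manager API):

from contextlib import contextmanager


@contextmanager
def managed(name: str):
    # Dummy context manager standing in for vae_info / ModelPatcher.apply_lora_unet / unet_info
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")


# All three managers are entered left to right and exited in reverse order.
with managed("vae") as vae, managed("lora"), managed("unet") as unet:
    print(vae, unet)
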
@ -197,21 +204,11 @@ class InpaintInvocation(BaseInvocation):

        )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = None if self.image is None else context.services.images.get_pil_image(self.image.image_name)
        mask = None if self.mask is None else context.services.images.get_pil_image(self.mask.image_name)

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        scheduler = get_scheduler(
@ -9,9 +9,13 @@ from pathlib import Path

from typing import Union
from invokeai.app.invocations.metadata import CoreMetadata
from ..models.image import (
    ImageCategory,
    ImageField,
    ResourceOrigin,
    PILInvocationConfig,
    ImageOutput,
    MaskOutput,
)
from .baseinvocation import (
    BaseInvocation,
    InvocationContext,

@ -20,6 +24,7 @@ from .baseinvocation import (

from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark


class LoadImageInvocation(BaseInvocation):
    """Load an image and provide it as output."""

@ -34,10 +39,7 @@ class LoadImageInvocation(BaseInvocation):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Load Image", "tags": ["image", "load"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:

@ -56,16 +58,11 @@ class ShowImageInvocation(BaseInvocation):

    type: Literal["show_image"] = "show_image"

    # Inputs
    image: Optional[ImageField] = Field(default=None, description="The image to show")

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Show Image", "tags": ["image", "show"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:

@ -98,18 +95,13 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Crop Image", "tags": ["image", "crop"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)

        image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0))
        image_crop.paste(image, (-self.x, -self.y))

        image_dto = context.services.images.create(

@ -144,21 +136,14 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Paste Image", "tags": ["image", "paste"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        base_image = context.services.images.get_pil_image(self.base_image.image_name)
        image = context.services.images.get_pil_image(self.image.image_name)
        mask = (
            None if self.mask is None else ImageOps.invert(context.services.images.get_pil_image(self.mask.image_name))
        )
        # TODO: probably shouldn't invert mask here... should user be required to do it?
@ -167,9 +152,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):

        max_x = max(base_image.width, image.width + self.x)
        max_y = max(base_image.height, image.height + self.y)

        new_image = Image.new(mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0))
        new_image.paste(base_image, (abs(min_x), abs(min_y)))
        new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)

@ -202,10 +185,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]},
        }

    def invoke(self, context: InvocationContext) -> MaskOutput:

@ -244,10 +224,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Multiply Images", "tags": ["image", "multiply"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:

@ -288,10 +265,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Image Channel", "tags": ["image", "channel"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:

@ -331,10 +305,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Convert Image", "tags": ["image", "convert"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:

@ -357,6 +328,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):

            height=image_dto.height,
        )


class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
    """Blurs an image"""

@ -371,19 +343,14 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Blur Image", "tags": ["image", "blur"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)

        blur = (
            ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
        )
        blur_image = image.filter(blur)

@ -438,10 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Resize Image", "tags": ["image", "resize"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:

@ -484,10 +448,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Scale Image", "tags": ["image", "scale"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:

@ -532,10 +493,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:

@ -561,6 +519,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):

            height=image_dto.height,
        )


class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
    """Inverse linear interpolation of all pixels of an image"""

@ -577,7 +536,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):

        schema_extra = {
            "ui": {
                "title": "Image Inverse Linear Interpolation",
                "tags": ["image", "linear", "interpolation", "inverse"],
            },
        }
@ -585,12 +544,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):

        image = context.services.images.get_pil_image(self.image.image_name)

        image_arr = numpy.asarray(image, dtype=numpy.float32)
        image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255

        ilerp_image = Image.fromarray(numpy.uint8(image_arr))
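
The collapsed expression above maps pixel values from the [min, max] range back to [0, 255], clamping anything outside that range; a tiny numeric check of the same formula with example bounds:

import numpy as np

image_arr = np.array([[0.0, 64.0, 96.0, 255.0]], dtype=np.float32)
lo, hi = 64, 128  # example values standing in for self.min / self.max

# Same expression as in the invocation: clamp below at 0, scale to [0, 1], clamp above at 1, then scale to [0, 255]
out = np.minimum(np.maximum(image_arr - lo, 0) / float(hi - lo), 1) * 255

print(out)  # [[  0.    0.  127.5 255. ]]
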
@ -609,6 +563,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):

            height=image_dto.height,
        )


class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
    """Add blur to NSFW-flagged images"""

@ -622,22 +577,19 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)

        logger = context.services.logger
        logger.debug("Running NSFW checker")
        if SafetyChecker.has_nsfw_concept(image):
            logger.info("A potentially NSFW image has been detected. Image will be blurred.")
            blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
            caution = self._get_caution_img()
            blurry_image.paste(caution, (0, 0), caution)
            image = blurry_image

        image_dto = context.services.images.create(

@ -649,20 +601,22 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):

            is_intermediate=self.is_intermediate,
            metadata=self.metadata.dict() if self.metadata else None,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )

    def _get_caution_img(self) -> Image:
        import invokeai.app.assets.images as image_assets

        caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
        return caution.resize((caution.width // 2, caution.height // 2))


class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
    """Add an invisible watermark to an image"""

    # fmt: off
    type: Literal["img_watermark"] = "img_watermark"

@ -675,10 +629,7 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:

@ -699,6 +650,3 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):

            width=image_dto.width,
            height=image_dto.height,
        )
@ -30,9 +30,7 @@ def infill_methods() -> list[str]:

INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"


def infill_patchmatch(im: Image.Image) -> Image.Image:

@ -44,9 +42,7 @@ def infill_patchmatch(im: Image.Image) -> Image.Image:

        return im

    # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
    im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
    im_patched = Image.fromarray(im_patched_np, mode="RGB")
    return im_patched

@ -68,9 +64,7 @@ def get_tile_images(image: np.ndarray, width=8, height=8):

    )


def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
    # Only fill if there's an alpha layer
    if im.mode != "RGBA":
        return im

@ -103,9 +97,7 @@ def tile_fill_missing(

    # Find all invalid tiles and replace with a random valid tile
    replace_count = (tiles_mask == False).sum()
    rng = np.random.default_rng(seed=seed)
    tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]

    # Convert back to an image
    tiles_all = tiles_all.reshape(tshape)

@ -126,9 +118,7 @@ class InfillColorInvocation(BaseInvocation):

    """Infills transparent areas of an image with a solid color"""

    type: Literal["infill_rgba"] = "infill_rgba"
    image: Optional[ImageField] = Field(default=None, description="The image to infill")
    color: ColorField = Field(
        default=ColorField(r=127, g=127, b=127, a=255),
        description="The color to use to infill",

@ -136,10 +126,7 @@ class InfillColorInvocation(BaseInvocation):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:

@ -171,9 +158,7 @@ class InfillTileInvocation(BaseInvocation):

    type: Literal["infill_tile"] = "infill_tile"

    image: Optional[ImageField] = Field(default=None, description="The image to infill")
    tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
    seed: int = Field(
        ge=0,

@ -184,18 +169,13 @@ class InfillTileInvocation(BaseInvocation):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)

        infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
        infilled.paste(image, (0, 0), image.split()[-1])

        image_dto = context.services.images.create(

@ -219,16 +199,11 @@ class InfillPatchMatchInvocation(BaseInvocation):

    type: Literal["infill_patchmatch"] = "infill_patchmatch"

    image: Optional[ImageField] = Field(default=None, description="The image to infill")

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
@ -17,15 +17,16 @@ from invokeai.backend.model_management.models.base import ModelType

from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
    ConditioningData,
    ControlNetData,
    StableDiffusionGeneratorPipeline,
    image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput

@ -46,8 +47,7 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())

class LatentsField(BaseModel):
    """A latents field used for passing latents between invocations"""

    latents_name: Optional[str] = Field(default=None, description="The name of the latents")

    class Config:
        schema_extra = {"required": ["latents_name"]}

@ -55,14 +55,15 @@ class LatentsField(BaseModel):


class LatentsOutput(BaseInvocationOutput):
    """Base class for invocations that output latents"""

    # fmt: off
    type: Literal["latents_output"] = "latents_output"

    # Inputs
    latents: LatentsField = Field(default=None, description="The output latents")
    width: int = Field(description="The width of the latents in pixels")
    height: int = Field(description="The height of the latents in pixels")
    # fmt: on


def build_latents_output(latents_name: str, latents: torch.Tensor):
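
The # fmt: off / # fmt: on pair around the field block above tells Black to leave that region exactly as written (the rest of this commit is Black-formatted); a minimal illustration of the markers, with a made-up mapping:

# fmt: off
LATENT_FIELD_HINTS = {
    "latents": "latents_name",
    "width":   "width",   # manual alignment survives because Black skips this block
    "height":  "height",
}
# fmt: on


def describe(hints: dict) -> str:
    return ", ".join(f"{k} -> {v}" for k, v in hints.items())


print(describe(LATENT_FIELD_HINTS))
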
@ -73,9 +74,7 @@ def build_latents_output(latents_name: str, latents: torch.Tensor):

    )


SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]


def get_scheduler(
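
SAMPLER_NAME_VALUES builds a Literal out of the scheduler map's keys so that pydantic rejects unknown scheduler names at validation time; a minimal sketch with a toy map standing in for SCHEDULER_MAP:

from typing import Literal

from pydantic import BaseModel, ValidationError

TOY_SCHEDULER_MAP = {"ddim": object(), "euler": object()}  # stand-in for SCHEDULER_MAP
SamplerName = Literal[tuple(TOY_SCHEDULER_MAP.keys())]  # dynamic Literal; static checkers may flag this


class Params(BaseModel):
    scheduler: SamplerName = "euler"


print(Params().scheduler)  # euler
try:
    Params(scheduler="not-a-scheduler")
except ValidationError:
    print("rejected unknown scheduler name")
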
@ -83,11 +82,10 @@ def get_scheduler(
|
|||||||
scheduler_info: ModelInfo,
|
scheduler_info: ModelInfo,
|
||||||
scheduler_name: str,
|
scheduler_name: str,
|
||||||
) -> Scheduler:
|
) -> Scheduler:
|
||||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||||
scheduler_name, SCHEDULER_MAP['ddim']
|
|
||||||
)
|
|
||||||
orig_scheduler_info = context.services.model_manager.get_model(
|
orig_scheduler_info = context.services.model_manager.get_model(
|
||||||
**scheduler_info.dict(), context=context,
|
**scheduler_info.dict(),
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
with orig_scheduler_info as orig_scheduler:
|
with orig_scheduler_info as orig_scheduler:
|
||||||
scheduler_config = orig_scheduler.config
|
scheduler_config = orig_scheduler.config
|
||||||
@ -102,7 +100,7 @@ def get_scheduler(
|
|||||||
scheduler = scheduler_class.from_config(scheduler_config)
|
scheduler = scheduler_class.from_config(scheduler_config)
|
||||||
|
|
||||||
# hack copied over from generate.py
|
# hack copied over from generate.py
|
||||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
if not hasattr(scheduler, "uses_inpainting_model"):
|
||||||
scheduler.uses_inpainting_model = lambda: False
|
scheduler.uses_inpainting_model = lambda: False
|
||||||
return scheduler
|
return scheduler
|
||||||
|
|
||||||
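The Literal type above is built at runtime from SCHEDULER_MAP, and unknown scheduler names fall back to the "ddim" entry. A minimal sketch of both idioms, separate from the commit and using a hypothetical two-entry map:

# Sketch only, not part of the commit; SCHEDULER_MAP here is a stand-in for the real mapping.
from typing import Literal, get_args

SCHEDULER_MAP = {"ddim": ("DDIMScheduler", {}), "euler": ("EulerDiscreteScheduler", {})}

# Subscripting Literal with a tuple expands it element-wise, so the allowed
# names track whatever keys the map currently defines.
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
assert get_args(SAMPLER_NAME_VALUES) == ("ddim", "euler")

# dict.get with a second argument falls back instead of raising KeyError.
scheduler_class, extra_config = SCHEDULER_MAP.get("not-a-scheduler", SCHEDULER_MAP["ddim"])
assert scheduler_class == "DDIMScheduler"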
@ -123,8 +121,8 @@ class TextToLatentsInvocation(BaseInvocation):
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
    unet: UNetField = Field(default=None, description="UNet submodel")
    control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
    #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    #seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on

    @validator("cfg_scale")
@ -133,10 +131,10 @@ class TextToLatentsInvocation(BaseInvocation):
        if isinstance(v, list):
            for i in v:
                if i < 1:
                    raise ValueError('cfg_scale must be greater than 1')
                    raise ValueError("cfg_scale must be greater than 1")
        else:
            if v < 1:
                raise ValueError('cfg_scale must be greater than 1')
                raise ValueError("cfg_scale must be greater than 1")
        return v

    # Schema customisation
@ -149,8 +147,8 @@ class TextToLatentsInvocation(BaseInvocation):
                    "model": "model",
                    "control": "control",
                    # "cfg_scale": "float",
                    "cfg_scale": "number"
                    "cfg_scale": "number",
                }
                },
            },
        }

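The cfg_scale validator accepts either a single value or a per-step list and rejects anything below 1. A small illustration, separate from the commit, written as a plain function rather than the pydantic model:

# Sketch only, not part of the commit; mirrors the @validator("cfg_scale") logic above.
from typing import List, Union

def validate_cfg_scale(v: Union[float, List[float]]) -> Union[float, List[float]]:
    if isinstance(v, list):
        for i in v:
            if i < 1:
                raise ValueError("cfg_scale must be greater than 1")
    else:
        if v < 1:
            raise ValueError("cfg_scale must be greater than 1")
    return v

assert validate_cfg_scale(7.5) == 7.5
assert validate_cfg_scale([4.0, 7.5]) == [4.0, 7.5]
# validate_cfg_scale(0.5) would raise ValueError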
@ -190,16 +188,14 @@ class TextToLatentsInvocation(BaseInvocation):
                threshold=0.0, # threshold,
                warmup=0.2, # warmup,
                h_symmetry_time_pct=None, # h_symmetry_time_pct,
                v_symmetry_time_pct=None # v_symmetry_time_pct,
                v_symmetry_time_pct=None, # v_symmetry_time_pct,
            ),
        )

        conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
            scheduler,

            # for ddim scheduler
            eta=0.0, # ddim_eta

            # for ancestral and sde schedulers
            generator=torch.Generator(device=unet.device).manual_seed(0),
        )
@ -247,7 +243,6 @@ class TextToLatentsInvocation(BaseInvocation):
        exit_stack: ExitStack,
        do_classifier_free_guidance: bool = True,
    ) -> List[ControlNetData]:

        # assuming fixed dimensional scaling of 8:1 for image:latents
        control_height_resize = latents_shape[2] * 8
        control_width_resize = latents_shape[3] * 8
@ -261,7 +256,7 @@ class TextToLatentsInvocation(BaseInvocation):
            control_list = control_input
        else:
            control_list = None
        if (control_list is None):
        if control_list is None:
            control_data = None
            # from above handling, any control that is not None should now be of type list[ControlField]
        else:
@ -281,9 +276,7 @@ class TextToLatentsInvocation(BaseInvocation):

                control_models.append(control_model)
                control_image_field = control_info.image
                input_image = context.services.images.get_pil_image(
                    control_image_field.image_name
                )
                input_image = context.services.images.get_pil_image(control_image_field.image_name)
                # self.image.image_type, self.image.image_name
                # FIXME: still need to test with different widths, heights, devices, dtypes
                # and add in batch_size, num_images_per_prompt?
@ -321,9 +314,7 @@ class TextToLatentsInvocation(BaseInvocation):
        noise = context.services.latents.get(self.noise.latents_name)

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(
            context.graph_execution_state_id
        )
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        def step_callback(state: PipelineIntermediateState):
@ -332,19 +323,20 @@ class TextToLatentsInvocation(BaseInvocation):
        def _lora_loader():
            for lora in self.unet.loras:
                lora_info = context.services.model_manager.get_model(
                    **lora.dict(exclude={"weight"}), context=context,
                    **lora.dict(exclude={"weight"}),
                    context=context,
                )
                yield (lora_info.context.model, lora.weight)
                del lora_info
            return

        unet_info = context.services.model_manager.get_model(
            **self.unet.unet.dict(), context=context,
            **self.unet.unet.dict(),
            context=context,
        )
        with ExitStack() as exit_stack,\
                ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
                unet_info as unet:
        with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
            unet_info.context.model, _lora_loader()
        ), unet_info as unet:

            noise = noise.to(device=unet.device, dtype=unet.dtype)

            scheduler = get_scheduler(
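The reformatted `with` statement above stacks three context managers in one statement; they enter left to right and unwind in reverse. A self-contained sketch, separate from the commit, with stand-in managers that only print:

# Sketch only, not part of the commit; "lora-patch" and "unet" are illustrative names.
from contextlib import ExitStack, contextmanager

@contextmanager
def manager(name: str):
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")

# Comma-separated managers behave exactly as if they were nested with-blocks.
with ExitStack() as stack, manager("lora-patch") as patch, manager("unet") as unet:
    stack.callback(print, "ExitStack callbacks run last on the way out")
    print("inside:", patch, unet)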
@ -357,7 +349,9 @@ class TextToLatentsInvocation(BaseInvocation):
            conditioning_data = self.get_conditioning_data(context, scheduler, unet)

            control_data = self.prep_control_data(
                model=pipeline, context=context, control_input=self.control,
                model=pipeline,
                context=context,
                control_input=self.control,
                latents_shape=noise.shape,
                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                do_classifier_free_guidance=True,
@ -378,7 +372,7 @@ class TextToLatentsInvocation(BaseInvocation):
            result_latents = result_latents.to("cpu")
            torch.cuda.empty_cache()

            name = f'{context.graph_execution_state_id}__{self.id}'
            name = f"{context.graph_execution_state_id}__{self.id}"
            context.services.latents.save(name, result_latents)
            return build_latents_output(latents_name=name, latents=result_latents)

@ -389,11 +383,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
    type: Literal["l2l"] = "l2l"

    # Inputs
    latents: Optional[LatentsField] = Field(
        description="The latents to use as a base image")
    latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
    strength: float = Field(
        default=0.7, ge=0, le=1,
        description="The strength of the latents to use")
    strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")

    # Schema customisation
    class Config(InvocationConfig):
@ -405,7 +396,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
                    "model": "model",
                    "control": "control",
                    "cfg_scale": "number",
                }
                },
            },
        }

@ -415,9 +406,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
        latent = context.services.latents.get(self.latents.latents_name)

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(
            context.graph_execution_state_id
        )
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        def step_callback(state: PipelineIntermediateState):
@ -426,19 +415,20 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
        def _lora_loader():
            for lora in self.unet.loras:
                lora_info = context.services.model_manager.get_model(
                    **lora.dict(exclude={"weight"}), context=context,
                    **lora.dict(exclude={"weight"}),
                    context=context,
                )
                yield (lora_info.context.model, lora.weight)
                del lora_info
            return

        unet_info = context.services.model_manager.get_model(
            **self.unet.unet.dict(), context=context,
            **self.unet.unet.dict(),
            context=context,
        )
        with ExitStack() as exit_stack,\
                ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
                unet_info as unet:
        with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
            unet_info.context.model, _lora_loader()
        ), unet_info as unet:

            noise = noise.to(device=unet.device, dtype=unet.dtype)
            latent = latent.to(device=unet.device, dtype=unet.dtype)

@ -452,7 +442,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
            conditioning_data = self.get_conditioning_data(context, scheduler, unet)

            control_data = self.prep_control_data(
                model=pipeline, context=context, control_input=self.control,
                model=pipeline,
                context=context,
                control_input=self.control,
                latents_shape=noise.shape,
                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                do_classifier_free_guidance=True,
@ -460,8 +452,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
            )

            # TODO: Verify the noise is the right size
            initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
                latent, device=unet.device, dtype=latent.dtype
            )
            initial_latents = (
                latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
            )

            timesteps, _ = pipeline.get_img2img_timesteps(
@ -477,14 +469,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
                num_inference_steps=self.steps,
                conditioning_data=conditioning_data,
                control_data=control_data, # list[ControlNetData]
                callback=step_callback
                callback=step_callback,
            )

            # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
            result_latents = result_latents.to("cpu")
            torch.cuda.empty_cache()

            name = f'{context.graph_execution_state_id}__{self.id}'
            name = f"{context.graph_execution_state_id}__{self.id}"
            context.services.latents.save(name, result_latents)
            return build_latents_output(latents_name=name, latents=result_latents)

@ -496,14 +488,13 @@ class LatentsToImageInvocation(BaseInvocation):
    type: Literal["l2i"] = "l2i"

    # Inputs
    latents: Optional[LatentsField] = Field(
        description="The latents to generate an image from")
    vae: VaeField = Field(default=None, description="Vae submodel")
    tiled: bool = Field(
        default=False,
        description="Decode latents by overlapping tiles(less memory consumption)")
    fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
    metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
    latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
    vae: VaeField = Field(default=None, description="Vae submodel")
    tiled: bool = Field(default=False, description="Decode latents by overlapping tiles(less memory consumption)")
    fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
    metadata: Optional[CoreMetadata] = Field(
        default=None, description="Optional core metadata to be written to the image"
    )

    # Schema customisation
    class Config(InvocationConfig):
@ -519,7 +510,8 @@ class LatentsToImageInvocation(BaseInvocation):
        latents = context.services.latents.get(self.latents.latents_name)

        vae_info = context.services.model_manager.get_model(
            **self.vae.vae.dict(), context=context,
            **self.vae.vae.dict(),
            context=context,
        )

        with vae_info as vae:
@ -586,8 +578,7 @@ class LatentsToImageInvocation(BaseInvocation):
        )


LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear",
                                     "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]


class ResizeLatentsInvocation(BaseInvocation):
@ -596,36 +587,30 @@ class ResizeLatentsInvocation(BaseInvocation):
    type: Literal["lresize"] = "lresize"

    # Inputs
    latents: Optional[LatentsField] = Field(
        description="The latents to resize")
    width: Union[int, None] = Field(default=512,
        ge=64, multiple_of=8, description="The width to resize to (px)")
    height: Union[int, None] = Field(default=512,
        ge=64, multiple_of=8, description="The height to resize to (px)")
    mode: LATENTS_INTERPOLATION_MODE = Field(
        default="bilinear", description="The interpolation mode")
    latents: Optional[LatentsField] = Field(description="The latents to resize")
    width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
    height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
    antialias: bool = Field(
        default=False,
        description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
        default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
    )

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Resize Latents",
                "tags": ["latents", "resize"]
            },
            "ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
        }

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        # TODO:
        device=choose_torch_device()
        device = choose_torch_device()

        resized_latents = torch.nn.functional.interpolate(
            latents.to(device), size=(self.height // 8, self.width // 8),
            mode=self.mode, antialias=self.antialias
            if self.mode in ["bilinear", "bicubic"] else False,
            latents.to(device),
            size=(self.height // 8, self.width // 8),
            mode=self.mode,
            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
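The resize above works in latent space, so target pixel dimensions are divided by 8, and antialiasing is only passed for the bilinear and bicubic modes. A short sketch, separate from the commit, with arbitrary example shapes:

# Sketch only, not part of the commit; shapes and mode are example values.
import torch

latents = torch.randn(1, 4, 64, 64)   # latent tensor for a 512x512 image
target_px = (768, 512)                 # desired (height, width) in pixels
mode = "bilinear"

resized = torch.nn.functional.interpolate(
    latents,
    size=(target_px[0] // 8, target_px[1] // 8),  # 8:1 pixel-to-latent scaling
    mode=mode,
    antialias=True if mode in ["bilinear", "bicubic"] else False,
)
print(resized.shape)  # torch.Size([1, 4, 96, 64])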
@ -644,35 +629,30 @@ class ScaleLatentsInvocation(BaseInvocation):
    type: Literal["lscale"] = "lscale"

    # Inputs
    latents: Optional[LatentsField] = Field(
        description="The latents to scale")
    scale_factor: float = Field(
        gt=0, description="The factor by which to scale the latents")
    mode: LATENTS_INTERPOLATION_MODE = Field(
        default="bilinear", description="The interpolation mode")
    latents: Optional[LatentsField] = Field(description="The latents to scale")
    scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
    antialias: bool = Field(
        default=False,
        description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
        default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
    )

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Scale Latents",
                "tags": ["latents", "scale"]
            },
            "ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
        }

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        # TODO:
        device=choose_torch_device()
        device = choose_torch_device()

        # resizing
        resized_latents = torch.nn.functional.interpolate(
            latents.to(device), scale_factor=self.scale_factor, mode=self.mode,
            antialias=self.antialias
            if self.mode in ["bilinear", "bicubic"] else False,
            latents.to(device),
            scale_factor=self.scale_factor,
            mode=self.mode,
            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
@ -693,19 +673,13 @@ class ImageToLatentsInvocation(BaseInvocation):
    # Inputs
    image: Optional[ImageField] = Field(description="The image to encode")
    vae: VaeField = Field(default=None, description="Vae submodel")
    tiled: bool = Field(
        default=False,
        description="Encode latents by overlaping tiles(less memory consumption)")
    fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
    tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
    fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Image To Latents",
                "tags": ["latents", "image"]
            },
            "ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
        }

    @torch.no_grad()
@ -715,9 +689,10 @@ class ImageToLatentsInvocation(BaseInvocation):
        # )
        image = context.services.images.get_pil_image(self.image.image_name)

        #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
        # vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
        vae_info = context.services.model_manager.get_model(
            **self.vae.vae.dict(), context=context,
            **self.vae.vae.dict(),
            context=context,
        )

        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
@ -744,12 +719,12 @@ class ImageToLatentsInvocation(BaseInvocation):
                vae.post_quant_conv.to(orig_dtype)
                vae.decoder.conv_in.to(orig_dtype)
                vae.decoder.mid_block.to(orig_dtype)
            #else:
            # else:
            #    latents = latents.float()

            else:
                vae.to(dtype=torch.float16)
                #latents = latents.half()
                # latents = latents.half()

            if self.tiled:
                vae.enable_tiling()
@ -760,9 +735,7 @@ class ImageToLatentsInvocation(BaseInvocation):
            image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
            with torch.inference_mode():
                image_tensor_dist = vae.encode(image_tensor).latent_dist
                latents = image_tensor_dist.sample().to(
                    dtype=vae.dtype
                ) # FIXME: uses torch.randn. make reproducible!
                latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!

                latents = vae.config.scaling_factor * latents
                latents = latents.to(dtype=orig_dtype)
@ -54,10 +54,7 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Add",
                "tags": ["math", "add"]
            },
            "ui": {"title": "Add", "tags": ["math", "add"]},
        }

    def invoke(self, context: InvocationContext) -> IntOutput:
@ -75,10 +72,7 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Subtract",
                "tags": ["math", "subtract"]
            },
            "ui": {"title": "Subtract", "tags": ["math", "subtract"]},
        }

    def invoke(self, context: InvocationContext) -> IntOutput:
@ -96,10 +90,7 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Multiply",
                "tags": ["math", "multiply"]
            },
            "ui": {"title": "Multiply", "tags": ["math", "multiply"]},
        }

    def invoke(self, context: InvocationContext) -> IntOutput:
@ -117,10 +108,7 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Divide",
                "tags": ["math", "divide"]
            },
            "ui": {"title": "Divide", "tags": ["math", "divide"]},
        }

    def invoke(self, context: InvocationContext) -> IntOutput:
@ -140,10 +128,7 @@ class RandomIntInvocation(BaseInvocation):

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Random Integer",
                "tags": ["math", "random", "integer"]
            },
            "ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]},
        }

    def invoke(self, context: InvocationContext) -> IntOutput:
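Each math node's Config.schema_extra adds a "ui" block to the node's JSON schema. A minimal sketch, separate from the commit, of how that surfaces through pydantic v1's schema(); the model below is a hypothetical stand-in, not the real invocation class:

# Sketch only, not part of the commit.
from pydantic import BaseModel, Field

class AddLikeInvocation(BaseModel):
    a: int = Field(default=0, description="The first number")
    b: int = Field(default=0, description="The second number")

    class Config:
        schema_extra = {
            "ui": {"title": "Add", "tags": ["math", "add"]},
        }

# The extra keys are merged into the top level of the generated schema,
# where a client can read the node title and tags.
print(AddLikeInvocation.schema()["ui"])  # {'title': 'Add', 'tags': ['math', 'add']}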
@ -11,6 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField


class LoRAMetadataField(BaseModel):
    """LoRA metadata for an image generated in InvokeAI."""

@ -37,9 +38,7 @@ class CoreMetadata(BaseModel):
|
|||||||
description="The number of skipped CLIP layers",
|
description="The number of skipped CLIP layers",
|
||||||
)
|
)
|
||||||
model: MainModelField = Field(description="The main model used for inference")
|
model: MainModelField = Field(description="The main model used for inference")
|
||||||
controlnets: list[ControlField] = Field(
|
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
||||||
description="The ControlNets used for inference"
|
|
||||||
)
|
|
||||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||||
vae: Union[VAEModelField, None] = Field(
|
vae: Union[VAEModelField, None] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@ -51,38 +50,24 @@ class CoreMetadata(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="The strength used for latents-to-latents",
|
description="The strength used for latents-to-latents",
|
||||||
)
|
)
|
||||||
init_image: Union[str, None] = Field(
|
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
|
||||||
default=None, description="The name of the initial image"
|
|
||||||
)
|
|
||||||
|
|
||||||
# SDXL
|
# SDXL
|
||||||
positive_style_prompt: Union[str, None] = Field(
|
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
|
||||||
default=None, description="The positive style prompt parameter"
|
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
|
||||||
)
|
|
||||||
negative_style_prompt: Union[str, None] = Field(
|
|
||||||
default=None, description="The negative style prompt parameter"
|
|
||||||
)
|
|
||||||
|
|
||||||
# SDXL Refiner
|
# SDXL Refiner
|
||||||
refiner_model: Union[MainModelField, None] = Field(
|
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
|
||||||
default=None, description="The SDXL Refiner model used"
|
|
||||||
)
|
|
||||||
refiner_cfg_scale: Union[float, None] = Field(
|
refiner_cfg_scale: Union[float, None] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The classifier-free guidance scale parameter used for the refiner",
|
description="The classifier-free guidance scale parameter used for the refiner",
|
||||||
)
|
)
|
||||||
refiner_steps: Union[int, None] = Field(
|
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
|
||||||
default=None, description="The number of steps used for the refiner"
|
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
|
||||||
)
|
|
||||||
refiner_scheduler: Union[str, None] = Field(
|
|
||||||
default=None, description="The scheduler used for the refiner"
|
|
||||||
)
|
|
||||||
refiner_aesthetic_store: Union[float, None] = Field(
|
refiner_aesthetic_store: Union[float, None] = Field(
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
default=None, description="The aesthetic score used for the refiner"
|
||||||
)
|
)
|
||||||
refiner_start: Union[float, None] = Field(
|
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
||||||
default=None, description="The start value used for refiner denoising"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageMetadata(BaseModel):
|
class ImageMetadata(BaseModel):
|
||||||
@ -92,9 +77,7 @@ class ImageMetadata(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="The image's core metadata, if it was created in the Linear or Canvas UI",
|
description="The image's core metadata, if it was created in the Linear or Canvas UI",
|
||||||
)
|
)
|
||||||
graph: Optional[dict] = Field(
|
graph: Optional[dict] = Field(default=None, description="The graph that created the image")
|
||||||
default=None, description="The graph that created the image"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataAccumulatorOutput(BaseInvocationOutput):
|
class MetadataAccumulatorOutput(BaseInvocationOutput):
|
||||||
@ -126,50 +109,34 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
|||||||
description="The number of skipped CLIP layers",
|
description="The number of skipped CLIP layers",
|
||||||
)
|
)
|
||||||
model: MainModelField = Field(description="The main model used for inference")
|
model: MainModelField = Field(description="The main model used for inference")
|
||||||
controlnets: list[ControlField] = Field(
|
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
||||||
description="The ControlNets used for inference"
|
|
||||||
)
|
|
||||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||||
strength: Union[float, None] = Field(
|
strength: Union[float, None] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The strength used for latents-to-latents",
|
description="The strength used for latents-to-latents",
|
||||||
)
|
)
|
||||||
init_image: Union[str, None] = Field(
|
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
|
||||||
default=None, description="The name of the initial image"
|
|
||||||
)
|
|
||||||
vae: Union[VAEModelField, None] = Field(
|
vae: Union[VAEModelField, None] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The VAE used for decoding, if the main model's default was not used",
|
description="The VAE used for decoding, if the main model's default was not used",
|
||||||
)
|
)
|
||||||
|
|
||||||
# SDXL
|
# SDXL
|
||||||
positive_style_prompt: Union[str, None] = Field(
|
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
|
||||||
default=None, description="The positive style prompt parameter"
|
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
|
||||||
)
|
|
||||||
negative_style_prompt: Union[str, None] = Field(
|
|
||||||
default=None, description="The negative style prompt parameter"
|
|
||||||
)
|
|
||||||
|
|
||||||
# SDXL Refiner
|
# SDXL Refiner
|
||||||
refiner_model: Union[MainModelField, None] = Field(
|
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
|
||||||
default=None, description="The SDXL Refiner model used"
|
|
||||||
)
|
|
||||||
refiner_cfg_scale: Union[float, None] = Field(
|
refiner_cfg_scale: Union[float, None] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The classifier-free guidance scale parameter used for the refiner",
|
description="The classifier-free guidance scale parameter used for the refiner",
|
||||||
)
|
)
|
||||||
refiner_steps: Union[int, None] = Field(
|
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
|
||||||
default=None, description="The number of steps used for the refiner"
|
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
|
||||||
)
|
|
||||||
refiner_scheduler: Union[str, None] = Field(
|
|
||||||
default=None, description="The scheduler used for the refiner"
|
|
||||||
)
|
|
||||||
refiner_aesthetic_store: Union[float, None] = Field(
|
refiner_aesthetic_store: Union[float, None] = Field(
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
default=None, description="The aesthetic score used for the refiner"
|
||||||
)
|
)
|
||||||
refiner_start: Union[float, None] = Field(
|
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
||||||
default=None, description="The start value used for refiner denoising"
|
|
||||||
)
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
|
@ -4,17 +4,14 @@ from typing import List, Literal, Optional, Union
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||||
InvocationConfig, InvocationContext)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
class ModelInfo(BaseModel):
|
||||||
model_name: str = Field(description="Info to load submodel")
|
model_name: str = Field(description="Info to load submodel")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
model_type: ModelType = Field(description="Info to load submodel")
|
model_type: ModelType = Field(description="Info to load submodel")
|
||||||
submodel: Optional[SubModelType] = Field(
|
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||||
default=None, description="Info to load submodel"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LoraInfo(ModelInfo):
|
class LoraInfo(ModelInfo):
|
||||||
@ -33,6 +30,7 @@ class ClipField(BaseModel):
|
|||||||
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
||||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||||
|
|
||||||
|
|
||||||
class VaeField(BaseModel):
|
class VaeField(BaseModel):
|
||||||
# TODO: better naming?
|
# TODO: better naming?
|
||||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||||
@ -49,6 +47,7 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
|||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
class MainModelField(BaseModel):
|
class MainModelField(BaseModel):
|
||||||
"""Main model field"""
|
"""Main model field"""
|
||||||
|
|
||||||
@ -62,6 +61,7 @@ class LoRAModelField(BaseModel):
|
|||||||
model_name: str = Field(description="Name of the LoRA model")
|
model_name: str = Field(description="Name of the LoRA model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
class MainModelLoaderInvocation(BaseInvocation):
|
class MainModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
@ -180,7 +180,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LoraLoaderOutput(BaseInvocationOutput):
|
class LoraLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
@ -197,9 +197,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
type: Literal["lora_loader"] = "lora_loader"
|
type: Literal["lora_loader"] = "lora_loader"
|
||||||
|
|
||||||
lora: Union[LoRAModelField, None] = Field(
|
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
|
||||||
default=None, description="Lora model name"
|
|
||||||
)
|
|
||||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
||||||
|
|
||||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||||
@ -228,14 +226,10 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
):
|
):
|
||||||
raise Exception(f"Unkown lora name: {lora_name}!")
|
raise Exception(f"Unkown lora name: {lora_name}!")
|
||||||
|
|
||||||
if self.unet is not None and any(
|
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||||
lora.model_name == lora_name for lora in self.unet.loras
|
|
||||||
):
|
|
||||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||||
|
|
||||||
if self.clip is not None and any(
|
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||||
lora.model_name == lora_name for lora in self.clip.loras
|
|
||||||
):
|
|
||||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||||
|
|
||||||
output = LoraLoaderOutput()
|
output = LoraLoaderOutput()
|
||||||
|
@ -12,16 +12,37 @@ import matplotlib.pyplot as plt
|
|||||||
|
|
||||||
from easing_functions import (
|
from easing_functions import (
|
||||||
LinearInOut,
|
LinearInOut,
|
||||||
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
|
QuadEaseInOut,
|
||||||
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
|
QuadEaseIn,
|
||||||
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
|
QuadEaseOut,
|
||||||
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
|
CubicEaseInOut,
|
||||||
SineEaseInOut, SineEaseIn, SineEaseOut,
|
CubicEaseIn,
|
||||||
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
|
CubicEaseOut,
|
||||||
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
|
QuarticEaseInOut,
|
||||||
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
|
QuarticEaseIn,
|
||||||
BackEaseIn, BackEaseInOut, BackEaseOut,
|
QuarticEaseOut,
|
||||||
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
|
QuinticEaseInOut,
|
||||||
|
QuinticEaseIn,
|
||||||
|
QuinticEaseOut,
|
||||||
|
SineEaseInOut,
|
||||||
|
SineEaseIn,
|
||||||
|
SineEaseOut,
|
||||||
|
CircularEaseIn,
|
||||||
|
CircularEaseInOut,
|
||||||
|
CircularEaseOut,
|
||||||
|
ExponentialEaseInOut,
|
||||||
|
ExponentialEaseIn,
|
||||||
|
ExponentialEaseOut,
|
||||||
|
ElasticEaseIn,
|
||||||
|
ElasticEaseInOut,
|
||||||
|
ElasticEaseOut,
|
||||||
|
BackEaseIn,
|
||||||
|
BackEaseInOut,
|
||||||
|
BackEaseOut,
|
||||||
|
BounceEaseIn,
|
||||||
|
BounceEaseInOut,
|
||||||
|
BounceEaseOut,
|
||||||
|
)
|
||||||
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -45,17 +66,12 @@ class FloatLinearRangeInvocation(BaseInvocation):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
|
||||||
"title": "Linear Range (Float)",
|
|
||||||
"tags": ["math", "float", "linear", "range"]
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||||
return FloatCollectionOutput(
|
return FloatCollectionOutput(collection=param_list)
|
||||||
collection=param_list
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
EASING_FUNCTIONS_MAP = {
|
EASING_FUNCTIONS_MAP = {
|
||||||
@ -92,9 +108,7 @@ EASING_FUNCTIONS_MAP = {
|
|||||||
"BounceInOut": BounceEaseInOut,
|
"BounceInOut": BounceEaseInOut,
|
||||||
}
|
}
|
||||||
|
|
||||||
EASING_FUNCTION_KEYS: Any = Literal[
|
EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
||||||
tuple(list(EASING_FUNCTIONS_MAP.keys()))
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||||
@ -123,13 +137,9 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
|
||||||
"title": "Param Easing By Step",
|
|
||||||
"tags": ["param", "step", "easing"]
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
log_diagnostics = False
|
log_diagnostics = False
|
||||||
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
|
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
|
||||||
@ -170,12 +180,13 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
# and create reverse copy of list[1:end-1]
|
# and create reverse copy of list[1:end-1]
|
||||||
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
|
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
|
||||||
|
|
||||||
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
|
base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
|
||||||
if log_diagnostics: context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
if log_diagnostics:
|
||||||
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
|
context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
||||||
easing_function = easing_class(start=self.start_value,
|
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
|
||||||
end=self.end_value,
|
easing_function = easing_class(
|
||||||
duration=base_easing_duration - 1)
|
start=self.start_value, end=self.end_value, duration=base_easing_duration - 1
|
||||||
|
)
|
||||||
base_easing_vals = list()
|
base_easing_vals = list()
|
||||||
for step_index in range(base_easing_duration):
|
for step_index in range(base_easing_duration):
|
||||||
easing_val = easing_function.ease(step_index)
|
easing_val = easing_function.ease(step_index)
|
||||||
@ -214,9 +225,7 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
#
|
#
|
||||||
|
|
||||||
else: # no mirroring (default)
|
else: # no mirroring (default)
|
||||||
easing_function = easing_class(start=self.start_value,
|
easing_function = easing_class(start=self.start_value, end=self.end_value, duration=num_easing_steps - 1)
|
||||||
end=self.end_value,
|
|
||||||
duration=num_easing_steps - 1)
|
|
||||||
for step_index in range(num_easing_steps):
|
for step_index in range(num_easing_steps):
|
||||||
step_val = easing_function.ease(step_index)
|
step_val = easing_function.ease(step_index)
|
||||||
easing_list.append(step_val)
|
easing_list.append(step_val)
|
||||||
@ -240,13 +249,11 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
ax = plt.gca()
|
ax = plt.gca()
|
||||||
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
plt.savefig(buf, format='png')
|
plt.savefig(buf, format="png")
|
||||||
buf.seek(0)
|
buf.seek(0)
|
||||||
im = PIL.Image.open(buf)
|
im = PIL.Image.open(buf)
|
||||||
im.show()
|
im.show()
|
||||||
buf.close()
|
buf.close()
|
||||||
|
|
||||||
# output array of size steps, each entry list[i] is param value for step i
|
# output array of size steps, each entry list[i] is param value for step i
|
||||||
return FloatCollectionOutput(
|
return FloatCollectionOutput(collection=param_list)
|
||||||
collection=param_list
|
|
||||||
)
|
|
||||||
|
@ -4,67 +4,63 @@ from typing import Literal
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||||
InvocationConfig, InvocationContext)
|
|
||||||
from .math import FloatOutput, IntOutput
|
from .math import FloatOutput, IntOutput
|
||||||
|
|
||||||
# Pass-through parameter nodes - used by subgraphs
|
# Pass-through parameter nodes - used by subgraphs
|
||||||
|
|
||||||
|
|
||||||
class ParamIntInvocation(BaseInvocation):
|
class ParamIntInvocation(BaseInvocation):
|
||||||
"""An integer parameter"""
|
"""An integer parameter"""
|
||||||
#fmt: off
|
|
||||||
|
# fmt: off
|
||||||
type: Literal["param_int"] = "param_int"
|
type: Literal["param_int"] = "param_int"
|
||||||
a: int = Field(default=0, description="The integer value")
|
a: int = Field(default=0, description="The integer value")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
|
||||||
"tags": ["param", "integer"],
|
}
|
||||||
"title": "Integer Parameter"
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
return IntOutput(a=self.a)
|
return IntOutput(a=self.a)
|
||||||
|
|
||||||
|
|
||||||
class ParamFloatInvocation(BaseInvocation):
|
class ParamFloatInvocation(BaseInvocation):
|
||||||
"""A float parameter"""
|
"""A float parameter"""
|
||||||
#fmt: off
|
|
||||||
|
# fmt: off
|
||||||
type: Literal["param_float"] = "param_float"
|
type: Literal["param_float"] = "param_float"
|
||||||
param: float = Field(default=0.0, description="The float value")
|
param: float = Field(default=0.0, description="The float value")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"tags": ["param", "float"], "title": "Float Parameter"},
|
||||||
"tags": ["param", "float"],
|
}
|
||||||
"title": "Float Parameter"
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||||
return FloatOutput(param=self.param)
|
return FloatOutput(param=self.param)
|
||||||
|
|
||||||
|
|
||||||
class StringOutput(BaseInvocationOutput):
|
class StringOutput(BaseInvocationOutput):
|
||||||
"""A string output"""
|
"""A string output"""
|
||||||
|
|
||||||
type: Literal["string_output"] = "string_output"
|
type: Literal["string_output"] = "string_output"
|
||||||
text: str = Field(default=None, description="The output string")
|
text: str = Field(default=None, description="The output string")
|
||||||
|
|
||||||
|
|
||||||
class ParamStringInvocation(BaseInvocation):
|
class ParamStringInvocation(BaseInvocation):
|
||||||
"""A string parameter"""
|
"""A string parameter"""
|
||||||
type: Literal['param_string'] = 'param_string'
|
|
||||||
text: str = Field(default='', description='The string value')
|
type: Literal["param_string"] = "param_string"
|
||||||
|
text: str = Field(default="", description="The string value")
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"tags": ["param", "string"], "title": "String Parameter"},
|
||||||
"tags": ["param", "string"],
|
}
|
||||||
"title": "String Parameter"
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||||
return StringOutput(text=self.text)
|
return StringOutput(text=self.text)
|
||||||
|
|
@ -7,19 +7,21 @@ from pydantic import Field, validator
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator


class PromptOutput(BaseInvocationOutput):
    """Base class for invocations that output a prompt"""

    #fmt: off
    # fmt: off
    type: Literal["prompt"] = "prompt"

    prompt: str = Field(default=None, description="The output prompt")
    #fmt: on
    # fmt: on

    class Config:
        schema_extra = {
            'required': [
            "required": [
                'type',
                "type",
                'prompt',
                "prompt",
            ]
        }

@ -44,16 +46,11 @@ class DynamicPromptInvocation(BaseInvocation):
    type: Literal["dynamic_prompt"] = "dynamic_prompt"
    prompt: str = Field(description="The prompt to parse with dynamicprompts")
    max_prompts: int = Field(default=1, description="The number of prompts to generate")
    combinatorial: bool = Field(
        default=False, description="Whether to use the combinatorial generator"
    )
    combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator")

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Dynamic Prompt",
                "tags": ["prompt", "dynamic"]
            },
            "ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]},
        }

    def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
@ -65,10 +62,11 @@ class DynamicPromptInvocation(BaseInvocation):
|
|||||||
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
|
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
|
||||||
|
|
||||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||||
|
|
||||||
|
|
||||||
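A hedged aside on the dynamicprompts generators imported above: the node simply delegates to generate(), with combinatorial selecting the generator and max_prompts passed through as num_images (the wildcard prompt below is only an illustration).

# Standalone sketch, not part of the diff.
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator

generator = RandomPromptGenerator()  # CombinatorialPromptGenerator() when combinatorial=True
prompts = generator.generate("a {red|blue} house", num_images=3)
print(prompts)  # up to three expanded prompt strings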
class PromptsFromFileInvocation(BaseInvocation):
    """Loads prompts from a text file"""

    # fmt: off
    type: Literal['prompt_from_file'] = 'prompt_from_file'

@ -78,14 +76,11 @@ class PromptsFromFileInvocation(BaseInvocation):
    post_prompt: Optional[str] = Field(description="String to append to each prompt")
    start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
    max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Prompts From File", "tags": ["prompt", "file"]},
        }

    @validator("file_path")

@ -103,11 +98,13 @@ class PromptsFromFileInvocation(BaseInvocation):
        with open(file_path) as f:
            for i, line in enumerate(f):
                if i >= start_line and i < end_line:
                    prompts.append((pre_prompt or "") + line.strip() + (post_prompt or ""))
                if i >= end_line:
                    break
        return prompts

    def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
        prompts = self.promptsFromFile(
            self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
        )
        return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
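A minimal sketch of the same line-window behaviour outside the node graph, assuming a plain text file of one prompt per line (the helper name and file name are illustrative; the real node also validates file_path and applies pre/post prompts):

# Standalone sketch, not part of the diff.
def read_prompt_window(path: str, start_line: int = 1, max_prompts: int = 1) -> list:
    start = start_line
    end = start + max_prompts if max_prompts > 0 else float("inf")
    prompts = []
    with open(path) as f:
        for i, line in enumerate(f):
            if start <= i < end:
                prompts.append(line.strip())
            if i >= end:
                break
    return prompts

# read_prompt_window("prompts.txt", start_line=2, max_prompts=3) -> at most three stripped lines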
@ -7,13 +7,13 @@ from pydantic import Field, validator
from ...backend.model_management import ModelType, SubModelType
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext

from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
from .compel import ConditioningField
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output


class SDXLModelLoaderOutput(BaseInvocationOutput):
    """SDXL base model loader output"""

@ -26,16 +26,19 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
    vae: VaeField = Field(default=None, description="Vae submodel")
    # fmt: on


class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
    """SDXL refiner model loader output"""

    # fmt: off
    type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
    unet: UNetField = Field(default=None, description="UNet submodel")
    clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
    vae: VaeField = Field(default=None, description="Vae submodel")
    # fmt: on
    # fmt: on


class SDXLModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl base model, outputting its submodels."""

@ -125,8 +128,10 @@ class SDXLModelLoaderInvocation(BaseInvocation):
            ),
        )


class SDXLRefinerModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl refiner model, outputting its submodels."""

    type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"

    model: MainModelField = Field(description="The model to load")

@ -196,7 +201,8 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
            ),
        ),
    )


# Text to image
class SDXLTextToLatentsInvocation(BaseInvocation):
    """Generates latents from conditionings."""

@ -213,9 +219,9 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
    unet: UNetField = Field(default=None, description="UNet submodel")
    denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
    # control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
    # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on

    @validator("cfg_scale")

@ -224,10 +230,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
        if isinstance(v, list):
            for i in v:
                if i < 1:
                    raise ValueError("cfg_scale must be greater than 1")
        else:
            if v < 1:
                raise ValueError("cfg_scale must be greater than 1")
        return v

    # Schema customisation

@ -237,10 +243,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
                "title": "SDXL Text To Latents",
                "tags": ["latents"],
                "type_hints": {
                    "model": "model",
                    # "cfg_scale": "float",
                    "cfg_scale": "number",
                },
            },
        }

@ -265,9 +271,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
    # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
        latents = context.services.latents.get(self.noise.latents_name)

@ -293,14 +297,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
        latents = latents * scheduler.init_noise_sigma

        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
        do_classifier_free_guidance = True
        cross_attention_kwargs = None
        with unet_info as unet:
            extra_step_kwargs = dict()
            if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
                extra_step_kwargs.update(

@ -350,10 +350,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
                        # del noise_pred_uncond
                        # del noise_pred_text

                    # if do_classifier_free_guidance and guidance_rescale > 0.0:
                    #     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    #     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
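The hunk above is the standard classifier-free guidance update; as a hedged sketch, the same arithmetic in isolation (the tensor shapes are placeholders):

# Standalone sketch, not part of the diff.
import torch

def apply_cfg(noise_pred: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # noise_pred stacks the unconditional and text-conditioned predictions along dim 0
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)

# apply_cfg(torch.randn(2, 4, 64, 64), cfg_scale=7.5).shape == torch.Size([1, 4, 64, 64])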
@ -364,7 +364,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                        progress_bar.update()
                        self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                        # if callback is not None and i % callback_steps == 0:
                        #     callback(i, t, latents)
        else:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)

@ -378,13 +378,13 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
            with tqdm(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    # expand the latents if we are doing classifier free guidance
                    # latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                    latent_model_input = scheduler.scale_model_input(latents, t)

                    # import gc
                    # gc.collect()
                    # torch.cuda.empty_cache()

                    # predict the noise residual

@ -411,42 +411,41 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
                    # perform guidance
                    noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)

                    # del noise_pred_text
                    # del noise_pred_uncond
                    # import gc
                    # gc.collect()
                    # torch.cuda.empty_cache()

                    # if do_classifier_free_guidance and guidance_rescale > 0.0:
                    #     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    #     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                    # del noise_pred
                    # import gc
                    # gc.collect()
                    # torch.cuda.empty_cache()

                    # call the callback, if provided
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                        progress_bar.update()
                        self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                        # if callback is not None and i % callback_steps == 0:
                        #     callback(i, t, latents)

        #################

        latents = latents.to("cpu")
        torch.cuda.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=latents)


class SDXLLatentsToLatentsInvocation(BaseInvocation):
    """Generates latents from conditionings."""

@ -466,9 +465,9 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
    denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
    denoising_end: float = Field(default=1.0, ge=0, le=1, description="")

    # control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
    # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on

    @validator("cfg_scale")

@ -477,10 +476,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
        if isinstance(v, list):
            for i in v:
                if i < 1:
                    raise ValueError("cfg_scale must be greater than 1")
        else:
            if v < 1:
                raise ValueError("cfg_scale must be greater than 1")
        return v

    # Schema customisation

@ -490,10 +489,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
                "title": "SDXL Latents to Latents",
                "tags": ["latents"],
                "type_hints": {
                    "model": "model",
                    # "cfg_scale": "float",
                    "cfg_scale": "number",
                },
            },
        }

@ -518,9 +517,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
    # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
        latents = context.services.latents.get(self.latents.latents_name)

@ -545,7 +542,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
        scheduler.set_timesteps(num_inference_steps)

        t_start = int(round(self.denoising_start * num_inference_steps))
        timesteps = scheduler.timesteps[t_start * scheduler.order :]
        num_inference_steps = num_inference_steps - t_start

        # apply noise(if provided)

@ -555,12 +552,12 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
            del noise

        unet_info = context.services.model_manager.get_model(
            **self.unet.unet.dict(),
            context=context,
        )
        do_classifier_free_guidance = True
        cross_attention_kwargs = None
        with unet_info as unet:
            # apply scheduler extra args
            extra_step_kwargs = dict()
            if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):

@ -611,10 +608,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
                        # del noise_pred_uncond
                        # del noise_pred_text

                    # if do_classifier_free_guidance and guidance_rescale > 0.0:
                    #     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    #     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

@ -625,7 +622,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                        progress_bar.update()
                        self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                        # if callback is not None and i % callback_steps == 0:
                        #     callback(i, t, latents)
        else:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)

@ -639,13 +636,13 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
            with tqdm(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    # expand the latents if we are doing classifier free guidance
                    # latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                    latent_model_input = scheduler.scale_model_input(latents, t)

                    # import gc
                    # gc.collect()
                    # torch.cuda.empty_cache()

                    # predict the noise residual

@ -672,38 +669,36 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
                    # perform guidance
                    noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)

                    # del noise_pred_text
                    # del noise_pred_uncond
                    # import gc
                    # gc.collect()
                    # torch.cuda.empty_cache()

                    # if do_classifier_free_guidance and guidance_rescale > 0.0:
                    #     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    #     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                    # del noise_pred
                    # import gc
                    # gc.collect()
                    # torch.cuda.empty_cache()

                    # call the callback, if provided
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                        progress_bar.update()
                        self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
                        # if callback is not None and i % callback_steps == 0:
                        #     callback(i, t, latents)

        #################

        latents = latents.to("cpu")
        torch.cuda.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=latents)
@ -29,16 +29,11 @@ class ESRGANInvocation(BaseInvocation):
    type: Literal["esrgan"] = "esrgan"
    image: Union[ImageField, None] = Field(default=None, description="The input image")
    model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]},
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:

@ -108,9 +103,7 @@ class ESRGANInvocation(BaseInvocation):
        upscaled_image, img_mode = upsampler.enhance(cv_image)

        # back to PIL
        pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")

        image_dto = context.services.images.create(
            image=pil_image,
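For context, a hedged sketch of the OpenCV to PIL round trip used in the hunk above (the zero array stands in for the upsampler output; cv2, numpy and Pillow are assumed available):

# Standalone sketch, not part of the diff.
import cv2 as cv
import numpy as np
from PIL import Image

upscaled_bgr = np.zeros((64, 64, 3), dtype=np.uint8)  # placeholder for upsampler.enhance() output
pil_image = Image.fromarray(cv.cvtColor(upscaled_bgr, cv.COLOR_BGR2RGB)).convert("RGBA")
print(pil_image.mode, pil_image.size)  # RGBA (64, 64)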
@ -1,3 +1,4 @@
class CanceledException(Exception):
    """Execution canceled by user."""

    pass
@ -8,6 +8,7 @@ from ..invocations.baseinvocation import (
    InvocationConfig,
)


class ImageField(BaseModel):
    """An image field used for passing image objects between invocations"""

@ -34,6 +35,7 @@ class ProgressImage(BaseModel):
    height: int = Field(description="The effective height of the image in pixels")
    dataURL: str = Field(description="The image data as a b64 data URL")


class PILInvocationConfig(BaseModel):
    """Helper class to provide all PIL invocations with additional config"""

@ -44,6 +46,7 @@ class PILInvocationConfig(BaseModel):
        },
    }


class ImageOutput(BaseInvocationOutput):
    """Base class for invocations that output an image"""

@ -76,6 +79,7 @@ class MaskOutput(BaseInvocationOutput):
            ]
        }


class ResourceOrigin(str, Enum, metaclass=MetaEnum):
    """The origin of a resource (eg image).

@ -132,5 +136,3 @@ class InvalidImageCategoryException(ValueError):
    def __init__(self, message="Invalid image category."):
        super().__init__(message)
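A hedged sketch of producing the kind of base64 data URL that ProgressImage.dataURL describes (the image, its size and the PNG encoding are illustrative assumptions; the application has its own helper for this):

# Standalone sketch, not part of the diff.
import base64
import io

from PIL import Image

img = Image.new("RGB", (64, 64), "gray")
buf = io.BytesIO()
img.save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("ascii")
# data_url could populate ProgressImage(width=64, height=64, dataURL=data_url)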
@ -207,9 +207,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
            raise e
        finally:
            self._lock.release()
        return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)

    def get_all_board_image_names_for_board(self, board_id: str) -> list[str]:
        try:

@ -102,9 +102,7 @@ class BoardImagesService(BoardImagesServiceABC):
        self,
        board_id: str,
    ) -> list[str]:
        return self._services.board_image_records.get_all_board_image_names_for_board(board_id)

    def get_board_for_image(
        self,

@ -114,9 +112,7 @@ class BoardImagesService(BoardImagesServiceABC):
        return board_id


def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
    """Converts a board record to a board DTO."""
    return BoardDTO(
        **board_record.dict(exclude={"cover_image_name"}),
@ -15,9 +15,7 @@ from pydantic import BaseModel, Field, Extra
class BoardChanges(BaseModel, extra=Extra.forbid):
    board_name: Optional[str] = Field(description="The board's new name.")
    cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.")


class BoardRecordNotFoundException(Exception):

@ -292,9 +290,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
            count = cast(int, self._cursor.fetchone()[0])

            return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)

        except sqlite3.Error as e:
            self._conn.rollback()
@ -108,16 +108,12 @@ class BoardService(BoardServiceABC):
    def get_dto(self, board_id: str) -> BoardDTO:
        board_record = self._services.board_records.get(board_id)
        cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
        if cover_image:
            cover_image_name = cover_image.image_name
        else:
            cover_image_name = None
        image_count = self._services.board_image_records.get_image_count_for_board(board_id)
        return board_record_to_dto(board_record, cover_image_name, image_count)

    def update(

@ -126,60 +122,44 @@ class BoardService(BoardServiceABC):
        changes: BoardChanges,
    ) -> BoardDTO:
        board_record = self._services.board_records.update(board_id, changes)
        cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
        if cover_image:
            cover_image_name = cover_image.image_name
        else:
            cover_image_name = None

        image_count = self._services.board_image_records.get_image_count_for_board(board_id)
        return board_record_to_dto(board_record, cover_image_name, image_count)

    def delete(self, board_id: str) -> None:
        self._services.board_records.delete(board_id)

    def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
        board_records = self._services.board_records.get_many(offset, limit)
        board_dtos = []
        for r in board_records.items:
            cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
            if cover_image:
                cover_image_name = cover_image.image_name
            else:
                cover_image_name = None

            image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
            board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))

        return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))

    def get_all(self) -> list[BoardDTO]:
        board_records = self._services.board_records.get_all()
        board_dtos = []
        for r in board_records:
            cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
            if cover_image:
                cover_image_name = cover_image.image_name
            else:
                cover_image_name = None

            image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
            board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))

        return board_dtos
@ -1,6 +1,6 @@
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team

"""Invokeai configuration system.

Arguments and fields are taken from the pydantic definition of the
model. Defaults can be set by creating a yaml configuration file that

@ -158,7 +158,7 @@ two configs are kept in separate sections of the config file:
   outdir: outputs
   ...

"""
from __future__ import annotations
import argparse
import pydoc
@ -170,64 +170,68 @@ from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args

INIT_FILE = Path("invokeai.yaml")
MODEL_CORE = Path("models/core")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")


class InvokeAISettings(BaseSettings):
    """
    Runtime configuration settings in which default values are
    read from an omegaconf .yaml file.
    """

    initconf: ClassVar[DictConfig] = None
    argparse_groups: ClassVar[Dict] = {}

    def parse_args(self, argv: list = sys.argv[1:]):
        parser = self.get_parser()
        opt = parser.parse_args(argv)
        for name in self.__fields__:
            if name not in self._excluded():
                setattr(self, name, getattr(opt, name))

    def to_yaml(self) -> str:
        """
        Return a YAML string representing our settings. This can be used
        as the contents of `invokeai.yaml` to restore settings later.
        """
        cls = self.__class__
        type = get_args(get_type_hints(cls)["type"])[0]
        field_dict = dict({type: dict()})
        for name, field in self.__fields__.items():
            if name in cls._excluded_from_yaml():
                continue
            category = field.field_info.extra.get("category") or "Uncategorized"
            value = getattr(self, name)
            if category not in field_dict[type]:
                field_dict[type][category] = dict()
            # keep paths as strings to make it easier to read
            field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
        conf = OmegaConf.create(field_dict)
        return OmegaConf.to_yaml(conf)

    @classmethod
    def add_parser_arguments(cls, parser):
        if "type" in get_type_hints(cls):
            settings_stanza = get_args(get_type_hints(cls)["type"])[0]
        else:
            settings_stanza = "Uncategorized"

        env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()

        initconf = (
            cls.initconf.get(settings_stanza)
            if cls.initconf and settings_stanza in cls.initconf
            else OmegaConf.create()
        )

        # create an upcase version of the environment in
        # order to achieve case-insensitive environment
        # variables (the way Windows does)
        upcase_environ = dict()
        for key, value in os.environ.items():
            upcase_environ[key.upper()] = value

        fields = cls.__fields__
@ -237,8 +241,8 @@ class InvokeAISettings(BaseSettings):
            if name not in cls._excluded():
                current_default = field.default

                category = field.field_info.extra.get("category", "Uncategorized")
                env_name = env_prefix + "_" + name
                if category in initconf and name in initconf.get(category):
                    field.default = initconf.get(category).get(name)
                if env_name.upper() in upcase_environ:

@ -248,15 +252,15 @@ class InvokeAISettings(BaseSettings):
                field.default = current_default

    @classmethod
    def cmd_name(self, command_field: str = "type") -> str:
        hints = get_type_hints(self)
        if command_field in hints:
            return get_args(hints[command_field])[0]
        else:
            return "Uncategorized"

    @classmethod
    def get_parser(cls) -> ArgumentParser:
        parser = PagingArgumentParser(
            prog=cls.cmd_name(),
            description=cls.__doc__,

@ -269,24 +273,41 @@ class InvokeAISettings(BaseSettings):
        parser.add_parser(cls.cmd_name(), help=cls.__doc__)

    @classmethod
    def _excluded(self) -> List[str]:
        # internal fields that shouldn't be exposed as command line options
        return ["type", "initconf"]

    @classmethod
    def _excluded_from_yaml(self) -> List[str]:
        # combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
        return [
            "type",
            "initconf",
            "gpu_mem_reserved",
            "max_loaded_models",
            "version",
            "from_file",
            "model",
            "restore",
            "root",
            "nsfw_checker",
        ]

    class Config:
        env_file_encoding = "utf-8"
        arbitrary_types_allowed = True
        case_sensitive = True

    @classmethod
    def add_field_argument(cls, command_parser, name: str, field, default_override=None):
        field_type = get_type_hints(cls).get(name)
        default = (
            default_override
            if default_override is not None
            else field.default
            if field.default_factory is None
            else field.default_factory()
        )
        if category := field.field_info.extra.get("category"):
            if category not in cls.argparse_groups:
                cls.argparse_groups[category] = command_parser.add_argument_group(category)
@ -315,10 +336,10 @@ class InvokeAISettings(BaseSettings):
            argparse_group.add_argument(
                f"--{name}",
                dest=name,
                nargs="*",
                type=field.type_,
                default=default,
                action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
                help=field.field_info.description,
            )
        else:

@ -327,31 +348,35 @@ class InvokeAISettings(BaseSettings):
                dest=name,
                type=field.type_,
                default=default,
                action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
                help=field.field_info.description,
            )


def _find_root() -> Path:
    venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
    if os.environ.get("INVOKEAI_ROOT"):
        root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
    elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE, MODEL_CORE]]):
        root = (venv.parent).resolve()
    else:
        root = Path("~/invokeai").expanduser().resolve()
    return root


class InvokeAIAppConfig(InvokeAISettings):
    """
    Generate images using Stable Diffusion. Use "invokeai" to launch
    the command-line client (recommended for experts only), or
    "invokeai-web" to launch the web server. Global options
    can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
    setting environment variables INVOKEAI_<setting>.
    """

    singleton_config: ClassVar[InvokeAIAppConfig] = None
    singleton_init: ClassVar[Dict] = None

    # fmt: off
    type: Literal["InvokeAI"] = "InvokeAI"
    host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
    port : int = Field(default=9090, description="Port to bind to", category='Web Server')
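A hedged aside on add_field_argument above: boolean settings become paired flags through argparse.BooleanOptionalAction (Python 3.9+); the option name below is only an illustration, not a field taken from this diff.

# Standalone sketch, not part of the diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--example_flag",
    dest="example_flag",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="illustrative boolean setting",
)
print(parser.parse_args(["--example_flag"]).example_flag)     # True
print(parser.parse_args(["--no-example_flag"]).example_flag)  # False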
@ -399,16 +424,16 @@ setting environment variables INVOKEAI_<setting>.
    log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")

    version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
    # fmt: on

    def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
        """
        Update settings with contents of init file, environment, and
        command-line settings.
        :param conf: alternate Omegaconf dictionary object
        :param argv: alternate sys.argv list
        :param clobber: overwrite any initialization parameters passed during initialization
        """
        # Set the runtime root directory. We parse command-line switches here
        # in order to pick up the --root_dir option.
        super().parse_args(argv)
@ -425,135 +450,139 @@ setting environment variables INVOKEAI_<setting>.
if self.singleton_init and not clobber:
hints = get_type_hints(self.__class__)
for k in self.singleton_init:
setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
setattr(self, k, parse_obj_as(hints[k], self.singleton_init[k]))

@classmethod
def get_config(cls,**kwargs)->InvokeAIAppConfig:
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
'''
"""
This returns a singleton InvokeAIAppConfig configuration object.
'''
"""
if cls.singleton_config is None \
if (
or type(cls.singleton_config)!=cls \
cls.singleton_config is None
or (kwargs and cls.singleton_init != kwargs):
or type(cls.singleton_config) != cls
or (kwargs and cls.singleton_init != kwargs)
):
cls.singleton_config = cls(**kwargs)
cls.singleton_init = kwargs
return cls.singleton_config
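get_config() above caches one InvokeAIAppConfig per process and rebuilds it only when called with new keyword overrides. A short illustrative use; the port override is an assumption, not something this diff requires:

# Illustrative use of the singleton accessor above; not part of this commit.
config = InvokeAIAppConfig.get_config()          # first call constructs the instance
assert InvokeAIAppConfig.get_config() is config  # later calls return the cached object
other = InvokeAIAppConfig.get_config(port=9292)  # new kwargs replace the singleton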

@property
def root_path(self)->Path:
def root_path(self) -> Path:
'''
"""
Path to the runtime root directory
'''
"""
if self.root:
return Path(self.root).expanduser().absolute()
else:
return self.find_root()

@property
def root_dir(self)->Path:
def root_dir(self) -> Path:
'''
"""
Alias for above.
'''
"""
return self.root_path
|
|
||||||
def _resolve(self,partial_path:Path)->Path:
|
def _resolve(self, partial_path: Path) -> Path:
|
||||||
return (self.root_path / partial_path).resolve()
|
return (self.root_path / partial_path).resolve()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def init_file_path(self)->Path:
|
def init_file_path(self) -> Path:
|
||||||
'''
|
"""
|
||||||
Path to invokeai.yaml
|
Path to invokeai.yaml
|
||||||
'''
|
"""
|
||||||
return self._resolve(INIT_FILE)
|
return self._resolve(INIT_FILE)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_path(self)->Path:
|
def output_path(self) -> Path:
|
||||||
'''
|
"""
|
||||||
Path to defaults outputs directory.
|
Path to defaults outputs directory.
|
||||||
'''
|
"""
|
||||||
return self._resolve(self.outdir)
|
return self._resolve(self.outdir)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db_path(self)->Path:
|
def db_path(self) -> Path:
|
||||||
'''
|
"""
|
||||||
Path to the invokeai.db file.
|
Path to the invokeai.db file.
|
||||||
'''
|
"""
|
||||||
return self._resolve(self.db_dir) / DB_FILE
|
return self._resolve(self.db_dir) / DB_FILE
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_conf_path(self)->Path:
|
def model_conf_path(self) -> Path:
|
||||||
'''
|
"""
|
||||||
Path to models configuration file.
|
Path to models configuration file.
|
||||||
'''
|
"""
|
||||||
return self._resolve(self.conf_path)
|
return self._resolve(self.conf_path)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def legacy_conf_path(self)->Path:
|
def legacy_conf_path(self) -> Path:
|
||||||
'''
|
"""
|
||||||
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
|
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
|
||||||
'''
|
"""
|
||||||
return self._resolve(self.legacy_conf_dir)
|
return self._resolve(self.legacy_conf_dir)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def models_path(self)->Path:
|
def models_path(self) -> Path:
|
||||||
'''
|
"""
|
||||||
Path to the models directory
|
Path to the models directory
|
||||||
'''
|
"""
|
||||||
return self._resolve(self.models_dir)
|
return self._resolve(self.models_dir)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def autoconvert_path(self)->Path:
|
def autoconvert_path(self) -> Path:
|
||||||
'''
|
"""
|
||||||
Path to the directory containing models to be imported automatically at startup.
|
Path to the directory containing models to be imported automatically at startup.
|
||||||
'''
|
"""
|
||||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||||
|
|
||||||
# the following methods support legacy calls leftover from the Globals era
|
# the following methods support legacy calls leftover from the Globals era
|
||||||
@property
|
@property
|
||||||
def full_precision(self)->bool:
|
def full_precision(self) -> bool:
|
||||||
"""Return true if precision set to float32"""
|
"""Return true if precision set to float32"""
|
||||||
return self.precision=='float32'
|
return self.precision == "float32"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def disable_xformers(self)->bool:
|
def disable_xformers(self) -> bool:
|
||||||
"""Return true if xformers_enabled is false"""
|
"""Return true if xformers_enabled is false"""
|
||||||
return not self.xformers_enabled
|
return not self.xformers_enabled
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def try_patchmatch(self)->bool:
|
def try_patchmatch(self) -> bool:
|
||||||
"""Return true if patchmatch true"""
|
"""Return true if patchmatch true"""
|
||||||
return self.patchmatch
|
return self.patchmatch
|
||||||
|
|
||||||
@property
def nsfw_checker(self)->bool:
def nsfw_checker(self) -> bool:
""" NSFW node is always active and disabled from Web UI"""
"""NSFW node is always active and disabled from Web UI"""
return True

@property
def invisible_watermark(self)->bool:
def invisible_watermark(self) -> bool:
""" invisible watermark node is always active and disabled from Web UI"""
"""invisible watermark node is always active and disabled from Web UI"""
return True
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_root()->Path:
|
def find_root() -> Path:
|
||||||
'''
|
"""
|
||||||
Choose the runtime root directory when not specified on command line or
|
Choose the runtime root directory when not specified on command line or
|
||||||
init file.
|
init file.
|
||||||
'''
|
"""
|
||||||
return _find_root()
|
return _find_root()
|
||||||
|
|
||||||
|
|
||||||
class PagingArgumentParser(argparse.ArgumentParser):
|
class PagingArgumentParser(argparse.ArgumentParser):
|
||||||
'''
|
"""
|
||||||
A custom ArgumentParser that uses pydoc to page its output.
|
A custom ArgumentParser that uses pydoc to page its output.
|
||||||
It also supports reading defaults from an init file.
|
It also supports reading defaults from an init file.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def print_help(self, file=None):
|
def print_help(self, file=None):
|
||||||
text = self.format_help()
|
text = self.format_help()
|
||||||
pydoc.pager(text)
|
pydoc.pager(text)
|
||||||
|
|
||||||
def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
|
|
||||||
'''
|
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
|
||||||
|
"""
|
||||||
Legacy function which returns InvokeAIAppConfig.get_config()
|
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||||
'''
|
"""
|
||||||
return InvokeAIAppConfig.get_config(**kwargs)
|
return InvokeAIAppConfig.get_config(**kwargs)
|
||||||
|
@ -7,57 +7,80 @@ from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Gr
from .item_storage import ItemStorageABC


default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"


def create_text_to_image() -> LibraryGraph:
return LibraryGraph(
id=default_text_to_image_graph_id,
name='t2i',
name="t2i",
description='Converts text to an image',
description="Converts text to an image",
graph=Graph(
nodes={
'width': ParamIntInvocation(id='width', a=512),
"width": ParamIntInvocation(id="width", a=512),
'height': ParamIntInvocation(id='height', a=512),
"height": ParamIntInvocation(id="height", a=512),
'seed': ParamIntInvocation(id='seed', a=-1),
"seed": ParamIntInvocation(id="seed", a=-1),
'3': NoiseInvocation(id='3'),
"3": NoiseInvocation(id="3"),
'4': CompelInvocation(id='4'),
"4": CompelInvocation(id="4"),
'5': CompelInvocation(id='5'),
"5": CompelInvocation(id="5"),
'6': TextToLatentsInvocation(id='6'),
"6": TextToLatentsInvocation(id="6"),
'7': LatentsToImageInvocation(id='7'),
"7": LatentsToImageInvocation(id="7"),
'8': ImageNSFWBlurInvocation(id='8'),
"8": ImageNSFWBlurInvocation(id="8"),
},
edges=[
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
Edge(
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
source=EdgeConnection(node_id="width", field="a"),
Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
destination=EdgeConnection(node_id="3", field="width"),
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
),
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
Edge(
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
source=EdgeConnection(node_id="height", field="a"),
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
destination=EdgeConnection(node_id="3", field="height"),
Edge(source=EdgeConnection(node_id='7', field='image'), destination=EdgeConnection(node_id='8', field='image')),
),
]
Edge(
source=EdgeConnection(node_id="seed", field="a"),
destination=EdgeConnection(node_id="3", field="seed"),
),
Edge(
source=EdgeConnection(node_id="3", field="noise"),
destination=EdgeConnection(node_id="6", field="noise"),
),
Edge(
source=EdgeConnection(node_id="6", field="latents"),
destination=EdgeConnection(node_id="7", field="latents"),
),
Edge(
source=EdgeConnection(node_id="4", field="conditioning"),
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
),
Edge(
source=EdgeConnection(node_id="5", field="conditioning"),
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
),
Edge(
source=EdgeConnection(node_id="7", field="image"),
destination=EdgeConnection(node_id="8", field="image"),
),
],
),
exposed_inputs=[
ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
ExposedNodeInput(node_path='width', field='a', alias='width'),
ExposedNodeInput(node_path="width", field="a", alias="width"),
ExposedNodeInput(node_path='height', field='a', alias='height'),
ExposedNodeInput(node_path="height", field="a", alias="height"),
ExposedNodeInput(node_path='seed', field='a', alias='seed'),
ExposedNodeInput(node_path="seed", field="a", alias="seed"),
],
exposed_outputs=[
exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
ExposedNodeOutput(node_path='8', field='image', alias='image')
)
])
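create_text_to_image() above wires the default text-to-image graph and publishes its tunable fields under aliases (positive_prompt, negative_prompt, width, height, seed in; image out). A rough sketch of resolving one of those aliases back to its node, assuming only the helpers visible in this diff and not part of the commit itself:

# Rough sketch, not part of this commit: resolve an exposed alias to its node.
g = create_text_to_image()
aliases = {e.alias: (e.node_path, e.field) for e in g.exposed_inputs}
node_path, field = aliases["positive_prompt"]   # ("4", "prompt") for the default graph
prompt_node = g.graph.get_node(node_path)       # the CompelInvocation with id "4"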
|
|
||||||
|
|
||||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||||
|
|
||||||
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
||||||
graphs: list[LibraryGraph] = list()
|
graphs: list[LibraryGraph] = list()
|
||||||
|
|
||||||
# text_to_image = graph_library.get(default_text_to_image_graph_id)
|
# text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||||
|
|
||||||
# # TODO: Check if the graph is the same as the default one, and if not, update it
|
# # TODO: Check if the graph is the same as the default one, and if not, update it
|
||||||
# #if text_to_image is None:
|
# #if text_to_image is None:
|
||||||
text_to_image = create_text_to_image()
|
text_to_image = create_text_to_image()
|
||||||
|
@ -44,9 +44,7 @@ class EventServiceBase:
|
|||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
node=node,
|
node=node,
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
progress_image=progress_image.dict()
|
progress_image=progress_image.dict() if progress_image is not None else None,
|
||||||
if progress_image is not None
|
|
||||||
else None,
|
|
||||||
step=step,
|
step=step,
|
||||||
total_steps=total_steps,
|
total_steps=total_steps,
|
||||||
),
|
),
|
||||||
@ -90,9 +88,7 @@ class EventServiceBase:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_invocation_started(
|
def emit_invocation_started(self, graph_execution_state_id: str, node: dict, source_node_id: str) -> None:
|
||||||
self, graph_execution_state_id: str, node: dict, source_node_id: str
|
|
||||||
) -> None:
|
|
||||||
"""Emitted when an invocation has started"""
|
"""Emitted when an invocation has started"""
|
||||||
self.__emit_session_event(
|
self.__emit_session_event(
|
||||||
event_name="invocation_started",
|
event_name="invocation_started",
|
||||||
|
@ -28,6 +28,7 @@ from ..invocations.baseinvocation import (
|
|||||||
# in 3.10 this would be "from types import NoneType"
|
# in 3.10 this would be "from types import NoneType"
|
||||||
NoneType = type(None)
|
NoneType = type(None)
|
||||||
|
|
||||||
|
|
||||||
class EdgeConnection(BaseModel):
|
class EdgeConnection(BaseModel):
|
||||||
node_id: str = Field(description="The id of the node for this edge connection")
|
node_id: str = Field(description="The id of the node for this edge connection")
|
||||||
field: str = Field(description="The field for this connection")
|
field: str = Field(description="The field for this connection")
|
||||||
@ -61,6 +62,7 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
|
|||||||
node_input_field = node_inputs.get(field) or None
|
node_input_field = node_inputs.get(field) or None
|
||||||
return node_input_field
|
return node_input_field
|
||||||
|
|
||||||
|
|
||||||
def is_union_subtype(t1, t2):
|
def is_union_subtype(t1, t2):
|
||||||
t1_args = get_args(t1)
|
t1_args = get_args(t1)
|
||||||
t2_args = get_args(t2)
|
t2_args = get_args(t2)
|
||||||
@ -71,6 +73,7 @@ def is_union_subtype(t1, t2):
|
|||||||
# t1 is a Union, check that all of its types are in t2_args
|
# t1 is a Union, check that all of its types are in t2_args
|
||||||
return all(arg in t2_args for arg in t1_args)
|
return all(arg in t2_args for arg in t1_args)
|
||||||
|
|
||||||
|
|
||||||
def is_list_or_contains_list(t):
|
def is_list_or_contains_list(t):
|
||||||
t_args = get_args(t)
|
t_args = get_args(t)
|
||||||
|
|
||||||
@ -154,15 +157,17 @@ class GraphInvocationOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
'required': [
|
"required": [
|
||||||
'type',
|
"type",
|
||||||
'image',
|
"image",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fill this out and move to invocations
|
# TODO: Fill this out and move to invocations
|
||||||
class GraphInvocation(BaseInvocation):
|
class GraphInvocation(BaseInvocation):
|
||||||
"""Execute a graph"""
|
"""Execute a graph"""
|
||||||
|
|
||||||
type: Literal["graph"] = "graph"
|
type: Literal["graph"] = "graph"
|
||||||
|
|
||||||
# TODO: figure out how to create a default here
|
# TODO: figure out how to create a default here
|
||||||
@ -182,23 +187,21 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
'required': [
|
"required": [
|
||||||
'type',
|
"type",
|
||||||
'item',
|
"item",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fill this out and move to invocations
|
# TODO: Fill this out and move to invocations
|
||||||
class IterateInvocation(BaseInvocation):
|
class IterateInvocation(BaseInvocation):
|
||||||
"""Iterates over a list of items"""
|
"""Iterates over a list of items"""
|
||||||
|
|
||||||
type: Literal["iterate"] = "iterate"
|
type: Literal["iterate"] = "iterate"
|
||||||
|
|
||||||
collection: list[Any] = Field(
|
collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
|
||||||
description="The list of items to iterate over", default_factory=list
|
index: int = Field(description="The index, will be provided on executed iterators", default=0)
|
||||||
)
|
|
||||||
index: int = Field(
|
|
||||||
description="The index, will be provided on executed iterators", default=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
||||||
"""Produces the outputs as values"""
|
"""Produces the outputs as values"""
|
||||||
@ -212,12 +215,13 @@ class CollectInvocationOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
'required': [
|
"required": [
|
||||||
'type',
|
"type",
|
||||||
'collection',
|
"collection",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class CollectInvocation(BaseInvocation):
|
class CollectInvocation(BaseInvocation):
|
||||||
"""Collects values into a collection"""
|
"""Collects values into a collection"""
|
||||||
|
|
||||||
@ -269,9 +273,7 @@ class Graph(BaseModel):
|
|||||||
if node_path in self.nodes:
|
if node_path in self.nodes:
|
||||||
return (self, node_path)
|
return (self, node_path)
|
||||||
|
|
||||||
node_id = (
|
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||||
node_path if "." not in node_path else node_path[: node_path.index(".")]
|
|
||||||
)
|
|
||||||
if node_id not in self.nodes:
|
if node_id not in self.nodes:
|
||||||
raise NodeNotFoundError(f"Node {node_path} not found in graph")
|
raise NodeNotFoundError(f"Node {node_path} not found in graph")
|
||||||
|
|
||||||
@ -333,9 +335,7 @@ class Graph(BaseModel):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Validate all edges reference nodes in the graph
|
# Validate all edges reference nodes in the graph
|
||||||
node_ids = set(
|
node_ids = set([e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges])
|
||||||
[e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]
|
|
||||||
)
|
|
||||||
if not all((self.has_node(node_id) for node_id in node_ids)):
|
if not all((self.has_node(node_id) for node_id in node_ids)):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -361,22 +361,14 @@ class Graph(BaseModel):
|
|||||||
# Validate all iterators
|
# Validate all iterators
|
||||||
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
|
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
|
||||||
if not all(
|
if not all(
|
||||||
(
|
(self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))
|
||||||
self._is_iterator_connection_valid(n.id)
|
|
||||||
for n in self.nodes.values()
|
|
||||||
if isinstance(n, IterateInvocation)
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Validate all collectors
|
# Validate all collectors
|
||||||
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
|
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
|
||||||
if not all(
|
if not all(
|
||||||
(
|
(self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))
|
||||||
self._is_collector_connection_valid(n.id)
|
|
||||||
for n in self.nodes.values()
|
|
||||||
if isinstance(n, CollectInvocation)
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -395,48 +387,51 @@ class Graph(BaseModel):
|
|||||||
# Validate that an edge to this node+field doesn't already exist
|
# Validate that an edge to this node+field doesn't already exist
|
||||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||||
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
|
raise InvalidEdgeError(
|
||||||
|
f"Edge to node {edge.destination.node_id} field {edge.destination.field} already exists"
|
||||||
|
)
|
||||||
|
|
||||||
# Validate that no cycles would be created
|
# Validate that no cycles would be created
|
||||||
g = self.nx_graph_flat()
|
g = self.nx_graph_flat()
|
||||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||||
if not nx.is_directed_acyclic_graph(g):
|
if not nx.is_directed_acyclic_graph(g):
|
||||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
|
raise InvalidEdgeError(
|
||||||
|
f"Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}"
|
||||||
|
)
|
||||||
|
|
||||||
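The check above flattens the graph to a networkx DiGraph, adds the candidate edge, and rejects it if the result is no longer acyclic. The same test in isolation, using plain networkx and made-up node ids, independent of this codebase:

# Standalone illustration of the acyclicity test used above; plain networkx.
import networkx as nx

g = nx.DiGraph()
g.add_edge("3", "6")   # noise -> denoise
g.add_edge("6", "7")   # denoise -> decode
g.add_edge("7", "3")   # candidate edge that closes a loop
assert not nx.is_directed_acyclic_graph(g)   # such an edge would raise InvalidEdgeError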
# Validate that the field types are compatible
|
# Validate that the field types are compatible
|
||||||
if not are_connections_compatible(
|
if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field):
|
||||||
from_node, edge.source.field, to_node, edge.destination.field
|
raise InvalidEdgeError(
|
||||||
):
|
f"Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||||
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
)
|
||||||
|
|
||||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||||
if not self._is_iterator_connection_valid(
|
if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source):
|
||||||
edge.destination.node_id, new_input=edge.source
|
raise InvalidEdgeError(
|
||||||
):
|
f"Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
)
|
||||||
|
|
||||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||||
if not self._is_iterator_connection_valid(
|
if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination):
|
||||||
edge.source.node_id, new_output=edge.destination
|
raise InvalidEdgeError(
|
||||||
):
|
f"Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
)
|
||||||
|
|
||||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||||
if not self._is_collector_connection_valid(
|
if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source):
|
||||||
edge.destination.node_id, new_input=edge.source
|
raise InvalidEdgeError(
|
||||||
):
|
f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||||
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
)
|
||||||
|
|
||||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||||
if not self._is_collector_connection_valid(
|
if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
|
||||||
edge.source.node_id, new_output=edge.destination
|
raise InvalidEdgeError(
|
||||||
):
|
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||||
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
)
|
||||||
|
|
||||||
|
|
||||||
def has_node(self, node_path: str) -> bool:
|
def has_node(self, node_path: str) -> bool:
|
||||||
"""Determines whether or not a node exists in the graph."""
|
"""Determines whether or not a node exists in the graph."""
|
||||||
@ -465,17 +460,13 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
# Ensure the node type matches the new node
|
# Ensure the node type matches the new node
|
||||||
if type(node) != type(new_node):
|
if type(node) != type(new_node):
|
||||||
raise TypeError(
|
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
|
||||||
f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure the new id is either the same or is not in the graph
|
# Ensure the new id is either the same or is not in the graph
|
||||||
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
|
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
|
||||||
new_path = self._get_node_path(new_node.id, prefix=prefix)
|
new_path = self._get_node_path(new_node.id, prefix=prefix)
|
||||||
if new_node.id != node.id and self.has_node(new_path):
|
if new_node.id != node.id and self.has_node(new_path):
|
||||||
raise NodeAlreadyInGraphError(
|
raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph")
|
||||||
"Node with id {new_node.id} already exists in graph"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set the new node in the graph
|
# Set the new node in the graph
|
||||||
graph.nodes[new_node.id] = new_node
|
graph.nodes[new_node.id] = new_node
|
||||||
@ -497,9 +488,7 @@ class Graph(BaseModel):
|
|||||||
graph.add_edge(
|
graph.add_edge(
|
||||||
Edge(
|
Edge(
|
||||||
source=edge.source,
|
source=edge.source,
|
||||||
destination=EdgeConnection(
|
destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field),
|
||||||
node_id=new_graph_node_path, field=edge.destination.field
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -512,16 +501,12 @@ class Graph(BaseModel):
|
|||||||
)
|
)
|
||||||
graph.add_edge(
|
graph.add_edge(
|
||||||
Edge(
|
Edge(
|
||||||
source=EdgeConnection(
|
source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field),
|
||||||
node_id=new_graph_node_path, field=edge.source.field
|
destination=edge.destination,
|
||||||
),
|
|
||||||
destination=edge.destination
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_input_edges(
|
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]:
|
||||||
self, node_path: str, field: Optional[str] = None
|
|
||||||
) -> list[Edge]:
|
|
||||||
"""Gets all input edges for a node"""
|
"""Gets all input edges for a node"""
|
||||||
edges = self._get_input_edges_and_graphs(node_path)
|
edges = self._get_input_edges_and_graphs(node_path)
|
||||||
|
|
||||||
@ -538,7 +523,7 @@ class Graph(BaseModel):
|
|||||||
destination=EdgeConnection(
|
destination=EdgeConnection(
|
||||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||||
field=e.destination.field,
|
field=e.destination.field,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
for _, prefix, e in filtered_edges
|
for _, prefix, e in filtered_edges
|
||||||
]
|
]
|
||||||
@ -550,32 +535,20 @@ class Graph(BaseModel):
|
|||||||
edges = list()
|
edges = list()
|
||||||
|
|
||||||
# Return any input edges that appear in this graph
|
# Return any input edges that appear in this graph
|
||||||
edges.extend(
|
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
|
||||||
[(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
|
|
||||||
)
|
|
||||||
|
|
||||||
node_id = (
|
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||||
node_path if "." not in node_path else node_path[: node_path.index(".")]
|
|
||||||
)
|
|
||||||
node = self.nodes[node_id]
|
node = self.nodes[node_id]
|
||||||
|
|
||||||
if isinstance(node, GraphInvocation):
|
if isinstance(node, GraphInvocation):
|
||||||
graph = node.graph
|
graph = node.graph
|
||||||
graph_path = (
|
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
|
||||||
node.id
|
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
|
||||||
if prefix is None or prefix == ""
|
|
||||||
else self._get_node_path(node.id, prefix=prefix)
|
|
||||||
)
|
|
||||||
graph_edges = graph._get_input_edges_and_graphs(
|
|
||||||
node_path[(len(node_id) + 1) :], prefix=graph_path
|
|
||||||
)
|
|
||||||
edges.extend(graph_edges)
|
edges.extend(graph_edges)
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
def _get_output_edges(
|
def _get_output_edges(self, node_path: str, field: str) -> list[Edge]:
|
||||||
self, node_path: str, field: str
|
|
||||||
) -> list[Edge]:
|
|
||||||
"""Gets all output edges for a node"""
|
"""Gets all output edges for a node"""
|
||||||
edges = self._get_output_edges_and_graphs(node_path)
|
edges = self._get_output_edges_and_graphs(node_path)
|
||||||
|
|
||||||
@ -592,7 +565,7 @@ class Graph(BaseModel):
|
|||||||
destination=EdgeConnection(
|
destination=EdgeConnection(
|
||||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||||
field=e.destination.field,
|
field=e.destination.field,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
for _, prefix, e in filtered_edges
|
for _, prefix, e in filtered_edges
|
||||||
]
|
]
|
||||||
@ -604,25 +577,15 @@ class Graph(BaseModel):
|
|||||||
edges = list()
|
edges = list()
|
||||||
|
|
||||||
# Return any input edges that appear in this graph
|
# Return any input edges that appear in this graph
|
||||||
edges.extend(
|
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
|
||||||
[(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
|
|
||||||
)
|
|
||||||
|
|
||||||
node_id = (
|
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||||
node_path if "." not in node_path else node_path[: node_path.index(".")]
|
|
||||||
)
|
|
||||||
node = self.nodes[node_id]
|
node = self.nodes[node_id]
|
||||||
|
|
||||||
if isinstance(node, GraphInvocation):
|
if isinstance(node, GraphInvocation):
|
||||||
graph = node.graph
|
graph = node.graph
|
||||||
graph_path = (
|
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
|
||||||
node.id
|
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
|
||||||
if prefix is None or prefix == ""
|
|
||||||
else self._get_node_path(node.id, prefix=prefix)
|
|
||||||
)
|
|
||||||
graph_edges = graph._get_output_edges_and_graphs(
|
|
||||||
node_path[(len(node_id) + 1) :], prefix=graph_path
|
|
||||||
)
|
|
||||||
edges.extend(graph_edges)
|
edges.extend(graph_edges)
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
@ -646,12 +609,8 @@ class Graph(BaseModel):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||||
input_field = get_output_field(
|
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
|
||||||
self.get_node(inputs[0].node_id), inputs[0].field
|
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
||||||
)
|
|
||||||
output_fields = list(
|
|
||||||
[get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Input type must be a list
|
# Input type must be a list
|
||||||
if get_origin(input_field) != list:
|
if get_origin(input_field) != list:
|
||||||
@ -659,12 +618,7 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
# Validate that all outputs match the input type
|
# Validate that all outputs match the input type
|
||||||
input_field_item_type = get_args(input_field)[0]
|
input_field_item_type = get_args(input_field)[0]
|
||||||
if not all(
|
if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
|
||||||
(
|
|
||||||
are_connection_types_compatible(input_field_item_type, f)
|
|
||||||
for f in output_fields
|
|
||||||
)
|
|
||||||
):
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@ -684,35 +638,21 @@ class Graph(BaseModel):
|
|||||||
outputs.append(new_output)
|
outputs.append(new_output)
|
||||||
|
|
||||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||||
input_fields = list(
|
input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
|
||||||
[get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
|
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
||||||
)
|
|
||||||
output_fields = list(
|
|
||||||
[get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate that all inputs are derived from or match a single type
|
# Validate that all inputs are derived from or match a single type
|
||||||
input_field_types = set(
|
input_field_types = set(
|
||||||
[
|
[
|
||||||
t
|
t
|
||||||
for input_field in input_fields
|
for input_field in input_fields
|
||||||
for t in (
|
for t in ([input_field] if get_origin(input_field) == None else get_args(input_field))
|
||||||
[input_field]
|
|
||||||
if get_origin(input_field) == None
|
|
||||||
else get_args(input_field)
|
|
||||||
)
|
|
||||||
if t != NoneType
|
if t != NoneType
|
||||||
]
|
]
|
||||||
) # Get unique types
|
) # Get unique types
|
||||||
type_tree = nx.DiGraph()
|
type_tree = nx.DiGraph()
|
||||||
type_tree.add_nodes_from(input_field_types)
|
type_tree.add_nodes_from(input_field_types)
|
||||||
type_tree.add_edges_from(
|
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
|
||||||
[
|
|
||||||
e
|
|
||||||
for e in itertools.permutations(input_field_types, 2)
|
|
||||||
if issubclass(e[1], e[0])
|
|
||||||
]
|
|
||||||
)
|
|
||||||
type_degrees = type_tree.in_degree(type_tree.nodes)
|
type_degrees = type_tree.in_degree(type_tree.nodes)
|
||||||
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
|
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
|
||||||
return False # There is more than one root type
|
return False # There is more than one root type
|
||||||
@ -729,9 +669,7 @@ class Graph(BaseModel):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Verify that all outputs match the input type (are a base class or the same class)
|
# Verify that all outputs match the input type (are a base class or the same class)
|
||||||
if not all(
|
if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)):
|
||||||
(issubclass(input_root_type, get_args(f)[0]) for f in output_fields)
|
|
||||||
):
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@ -751,9 +689,7 @@ class Graph(BaseModel):
|
|||||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||||
return g
|
return g
|
||||||
|
|
||||||
def nx_graph_flat(
|
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
|
||||||
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
|
|
||||||
) -> nx.DiGraph:
|
|
||||||
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
|
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
|
||||||
g = nx_graph or nx.DiGraph()
|
g = nx_graph or nx.DiGraph()
|
||||||
|
|
||||||
@ -762,26 +698,18 @@ class Graph(BaseModel):
|
|||||||
[
|
[
|
||||||
self._get_node_path(n.id, prefix)
|
self._get_node_path(n.id, prefix)
|
||||||
for n in self.nodes.values()
|
for n in self.nodes.values()
|
||||||
if not isinstance(n, GraphInvocation)
|
if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)
|
||||||
and not isinstance(n, IterateInvocation)
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Expand graph nodes
|
# Expand graph nodes
|
||||||
for sgn in (
|
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
|
||||||
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
|
|
||||||
):
|
|
||||||
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||||
|
|
||||||
# TODO: figure out if iteration nodes need to be expanded
|
# TODO: figure out if iteration nodes need to be expanded
|
||||||
|
|
||||||
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
|
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
|
||||||
g.add_edges_from(
|
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
|
||||||
[
|
|
||||||
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
|
|
||||||
for e in unique_edges
|
|
||||||
]
|
|
||||||
)
|
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
@ -800,23 +728,19 @@ class GraphExecutionState(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Nodes that have been executed
|
# Nodes that have been executed
|
||||||
executed: set[str] = Field(
|
executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set)
|
||||||
description="The set of node ids that have been executed", default_factory=set
|
|
||||||
)
|
|
||||||
executed_history: list[str] = Field(
|
executed_history: list[str] = Field(
|
||||||
description="The list of node ids that have been executed, in order of execution",
|
description="The list of node ids that have been executed, in order of execution",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
|
||||||
# The results of executed nodes
|
# The results of executed nodes
|
||||||
results: dict[
|
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
|
||||||
str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]
|
description="The results of node executions", default_factory=dict
|
||||||
] = Field(description="The results of node executions", default_factory=dict)
|
)
|
||||||
|
|
||||||
# Errors raised when executing nodes
|
# Errors raised when executing nodes
|
||||||
errors: dict[str, str] = Field(
|
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||||
description="Errors raised when executing nodes", default_factory=dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# Map of prepared/executed nodes to their original nodes
|
# Map of prepared/executed nodes to their original nodes
|
||||||
prepared_source_mapping: dict[str, str] = Field(
|
prepared_source_mapping: dict[str, str] = Field(
|
||||||
@ -832,16 +756,16 @@ class GraphExecutionState(BaseModel):
|
|||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
'required': [
|
"required": [
|
||||||
'id',
|
"id",
|
||||||
'graph',
|
"graph",
|
||||||
'execution_graph',
|
"execution_graph",
|
||||||
'executed',
|
"executed",
|
||||||
'executed_history',
|
"executed_history",
|
||||||
'results',
|
"results",
|
||||||
'errors',
|
"errors",
|
||||||
'prepared_source_mapping',
|
"prepared_source_mapping",
|
||||||
'source_prepared_mapping',
|
"source_prepared_mapping",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -899,9 +823,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
"""Returns true if the graph has any errors"""
|
"""Returns true if the graph has any errors"""
|
||||||
return len(self.errors) > 0
|
return len(self.errors) > 0
|
||||||
|
|
||||||
def _create_execution_node(
|
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
||||||
self, node_path: str, iteration_node_map: list[tuple[str, str]]
|
|
||||||
) -> list[str]:
|
|
||||||
"""Prepares an iteration node and connects all edges, returning the new node id"""
|
"""Prepares an iteration node and connects all edges, returning the new node id"""
|
||||||
|
|
||||||
node = self.graph.get_node(node_path)
|
node = self.graph.get_node(node_path)
|
||||||
@ -911,20 +833,12 @@ class GraphExecutionState(BaseModel):
|
|||||||
# If this is an iterator node, we must create a copy for each iteration
|
# If this is an iterator node, we must create a copy for each iteration
|
||||||
if isinstance(node, IterateInvocation):
|
if isinstance(node, IterateInvocation):
|
||||||
# Get input collection edge (should error if there are no inputs)
|
# Get input collection edge (should error if there are no inputs)
|
||||||
input_collection_edge = next(
|
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection")))
|
||||||
iter(self.graph._get_input_edges(node_path, "collection"))
|
|
||||||
)
|
|
||||||
input_collection_prepared_node_id = next(
|
input_collection_prepared_node_id = next(
|
||||||
n[1]
|
n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
|
||||||
for n in iteration_node_map
|
|
||||||
if n[0] == input_collection_edge.source.node_id
|
|
||||||
)
|
|
||||||
input_collection_prepared_node_output = self.results[
|
|
||||||
input_collection_prepared_node_id
|
|
||||||
]
|
|
||||||
input_collection = getattr(
|
|
||||||
input_collection_prepared_node_output, input_collection_edge.source.field
|
|
||||||
)
|
)
|
||||||
|
input_collection_prepared_node_output = self.results[input_collection_prepared_node_id]
|
||||||
|
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
|
||||||
self_iteration_count = len(input_collection)
|
self_iteration_count = len(input_collection)
|
||||||
|
|
||||||
new_nodes = list()
|
new_nodes = list()
|
||||||
@ -939,9 +853,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
# For collect nodes, this may contain multiple inputs to the same field
|
# For collect nodes, this may contain multiple inputs to the same field
|
||||||
new_edges = list()
|
new_edges = list()
|
||||||
for edge in input_edges:
|
for edge in input_edges:
|
||||||
for input_node_id in (
|
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
|
||||||
n[1] for n in iteration_node_map if n[0] == edge.source.node_id
|
|
||||||
):
|
|
||||||
new_edge = Edge(
|
new_edge = Edge(
|
||||||
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
|
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
|
||||||
destination=EdgeConnection(node_id="", field=edge.destination.field),
|
destination=EdgeConnection(node_id="", field=edge.destination.field),
|
||||||
@ -982,11 +894,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
def _iterator_graph(self) -> nx.DiGraph:
|
def _iterator_graph(self) -> nx.DiGraph:
|
||||||
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
||||||
g = self.graph.nx_graph_flat()
|
g = self.graph.nx_graph_flat()
|
||||||
collectors = (
|
collectors = (n for n in self.graph.nodes if isinstance(self.graph.get_node(n), CollectInvocation))
|
||||||
n
|
|
||||||
for n in self.graph.nodes
|
|
||||||
if isinstance(self.graph.get_node(n), CollectInvocation)
|
|
||||||
)
|
|
||||||
for c in collectors:
|
for c in collectors:
|
||||||
g.remove_edges_from(list(g.in_edges(c)))
|
g.remove_edges_from(list(g.in_edges(c)))
|
||||||
return g
|
return g
|
||||||
@ -994,11 +902,7 @@ class GraphExecutionState(BaseModel):
|
|||||||
def _get_node_iterators(self, node_id: str) -> list[str]:
|
def _get_node_iterators(self, node_id: str) -> list[str]:
|
||||||
"""Gets iterators for a node"""
|
"""Gets iterators for a node"""
|
||||||
g = self._iterator_graph()
|
g = self._iterator_graph()
|
||||||
iterators = [
|
iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
|
||||||
n
|
|
||||||
for n in nx.ancestors(g, node_id)
|
|
||||||
if isinstance(self.graph.get_node(n), IterateInvocation)
|
|
||||||
]
|
|
||||||
return iterators
|
return iterators
|
||||||
|
|
||||||
def _prepare(self) -> Optional[str]:
|
def _prepare(self) -> Optional[str]:
|
||||||
@ -1045,29 +949,18 @@ class GraphExecutionState(BaseModel):
|
|||||||
if isinstance(next_node, CollectInvocation):
|
if isinstance(next_node, CollectInvocation):
|
||||||
# Collapse all iterator input mappings and create a single execution node for the collect invocation
|
# Collapse all iterator input mappings and create a single execution node for the collect invocation
|
||||||
all_iteration_mappings = list(
|
all_iteration_mappings = list(
|
||||||
itertools.chain(
|
itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))
|
||||||
*(
|
|
||||||
((s, p) for p in self.source_prepared_mapping[s])
|
|
||||||
for s in next_node_parents
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
|
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
|
||||||
create_results = self._create_execution_node(
|
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
|
||||||
next_node_id, all_iteration_mappings
|
|
||||||
)
|
|
||||||
if create_results is not None:
|
if create_results is not None:
|
||||||
new_node_ids.extend(create_results)
|
new_node_ids.extend(create_results)
|
||||||
else: # Iterators or normal nodes
|
else: # Iterators or normal nodes
|
||||||
# Get all iterator combinations for this node
|
# Get all iterator combinations for this node
|
||||||
# Will produce a list of lists of prepared iterator nodes, from which results can be iterated
|
# Will produce a list of lists of prepared iterator nodes, from which results can be iterated
|
||||||
iterator_nodes = self._get_node_iterators(next_node_id)
|
iterator_nodes = self._get_node_iterators(next_node_id)
|
||||||
iterator_nodes_prepared = [
|
iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
|
||||||
list(self.source_prepared_mapping[n]) for n in iterator_nodes
|
iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))
|
||||||
]
|
|
||||||
iterator_node_prepared_combinations = list(
|
|
||||||
itertools.product(*iterator_nodes_prepared)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Select the correct prepared parents for each iteration
|
# Select the correct prepared parents for each iteration
|
||||||
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
|
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
|
||||||
@ -1096,31 +989,16 @@ class GraphExecutionState(BaseModel):
|
|||||||
return next(iter(prepared_nodes))
|
return next(iter(prepared_nodes))
|
||||||
|
|
||||||
# Check if the requested node is an iterator
|
# Check if the requested node is an iterator
|
||||||
prepared_iterator = next(
|
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
|
||||||
(n for n in prepared_nodes if n in prepared_iterator_nodes), None
|
|
||||||
)
|
|
||||||
if prepared_iterator is not None:
|
if prepared_iterator is not None:
|
||||||
return prepared_iterator
|
return prepared_iterator
|
||||||
|
|
||||||
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
|
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
|
||||||
iterator_source_node_mapping = [
|
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
|
||||||
(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes
|
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
|
||||||
]
|
|
||||||
parent_iterators = [
|
|
||||||
itn
|
|
||||||
for itn in iterator_source_node_mapping
|
|
||||||
if nx.has_path(graph, itn[1], source_node_path)
|
|
||||||
]
|
|
||||||
|
|
||||||
return next(
|
return next(
|
||||||
(
|
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
|
||||||
n
|
|
||||||
for n in prepared_nodes
|
|
||||||
if all(
|
|
||||||
nx.has_path(execution_graph, pit[0], n)
|
|
||||||
for pit in parent_iterators
|
|
||||||
)
|
|
||||||
),
|
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1130,13 +1008,13 @@ class GraphExecutionState(BaseModel):
|
|||||||
|
|
||||||
# Depth-first search with pre-order traversal is a depth-first topological sort
|
# Depth-first search with pre-order traversal is a depth-first topological sort
|
||||||
sorted_nodes = nx.dfs_preorder_nodes(g)
|
sorted_nodes = nx.dfs_preorder_nodes(g)
|
||||||
|
|
||||||
next_node = next(
|
next_node = next(
|
||||||
(
|
(
|
||||||
n
|
n
|
||||||
for n in sorted_nodes
|
for n in sorted_nodes
|
||||||
if n not in self.executed # the node must not already be executed...
|
if n not in self.executed # the node must not already be executed...
|
||||||
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
|
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@ -1221,15 +1099,18 @@ class ExposedNodeOutput(BaseModel):
|
|||||||
field: str = Field(description="The field name of the output")
|
field: str = Field(description="The field name of the output")
|
||||||
alias: str = Field(description="The alias of the output")
|
alias: str = Field(description="The alias of the output")
|
||||||
|
|
||||||
|
|
||||||
class LibraryGraph(BaseModel):
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
graph: Graph = Field(description="The graph")
name: str = Field(description="The name of the graph")
description: str = Field(description="The description of the graph")
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
exposed_outputs: list[ExposedNodeOutput] = Field(
description="The outputs exposed by this graph", default_factory=list
)

@validator('exposed_inputs', 'exposed_outputs')
@validator("exposed_inputs", "exposed_outputs")
def validate_exposed_aliases(cls, v):
if len(v) != len(set(i.alias for i in v)):
raise ValueError("Duplicate exposed alias")
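validate_exposed_aliases above insists that every exposed input and output of a LibraryGraph carries a unique alias. For illustration only, with hypothetical values that do not appear in this diff:

# Hypothetical values illustrating the rule enforced by validate_exposed_aliases.
exposed = [
    ExposedNodeInput(node_path="4", field="prompt", alias="prompt"),
    ExposedNodeInput(node_path="5", field="prompt", alias="prompt"),  # duplicate alias
]
# Passing these as exposed_inputs to LibraryGraph(...) would raise
# ValueError("Duplicate exposed alias") during validation.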
@ -1237,23 +1118,27 @@ class LibraryGraph(BaseModel):
|
|||||||
|
|
||||||
@root_validator
|
@root_validator
|
||||||
def validate_exposed_nodes(cls, values):
|
def validate_exposed_nodes(cls, values):
|
||||||
graph = values['graph']
|
graph = values["graph"]
|
||||||
|
|
||||||
# Validate exposed inputs
|
# Validate exposed inputs
|
||||||
for exposed_input in values['exposed_inputs']:
|
for exposed_input in values["exposed_inputs"]:
|
||||||
if not graph.has_node(exposed_input.node_path):
|
if not graph.has_node(exposed_input.node_path):
|
||||||
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
||||||
node = graph.get_node(exposed_input.node_path)
|
node = graph.get_node(exposed_input.node_path)
|
||||||
if get_input_field(node, exposed_input.field) is None:
|
if get_input_field(node, exposed_input.field) is None:
|
||||||
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
|
raise ValueError(
|
||||||
|
f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}"
|
||||||
|
)
|
||||||
|
|
||||||
# Validate exposed outputs
|
# Validate exposed outputs
|
||||||
for exposed_output in values['exposed_outputs']:
|
for exposed_output in values["exposed_outputs"]:
|
||||||
if not graph.has_node(exposed_output.node_path):
|
if not graph.has_node(exposed_output.node_path):
|
||||||
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
||||||
node = graph.get_node(exposed_output.node_path)
|
node = graph.get_node(exposed_output.node_path)
|
||||||
if get_output_field(node, exposed_output.field) is None:
|
if get_output_field(node, exposed_output.field) is None:
|
||||||
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
|
raise ValueError(
|
||||||
|
f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}"
|
||||||
|
)
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@ -85,9 +85,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config

self.__output_folder: Path = (
output_folder if isinstance(output_folder, Path) else Path(output_folder)
)
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__thumbnails_folder = self.__output_folder / "thumbnails"

# Validate required output folders at launch
@ -120,7 +118,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image_path = self.get_path(image_name)

pnginfo = PngImagePlugin.PngInfo()

if metadata is not None:
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
if graph is not None:
@ -183,9 +181,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
def __set_cache(self, image_name: Path, image: PILImageType):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:
@ -426,9 +426,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally:
self._lock.release()

return OffsetPaginatedResults(
items=images, offset=offset, limit=limit, total=count
)
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)

def delete(self, image_name: str) -> None:
try:
@ -466,7 +464,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally:
self._lock.release()


def delete_intermediates(self) -> list[str]:
try:
self._lock.acquire()
@ -505,9 +502,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
is_intermediate: bool = False,
) -> datetime:
try:
metadata_json = (
None if metadata is None else json.dumps(metadata)
)
metadata_json = None if metadata is None else json.dumps(metadata)
self._lock.acquire()
self._cursor.execute(
"""--sql
@ -217,12 +217,8 @@ class ImageService(ImageServiceABC):
session_id=session_id,
)
if board_id is not None:
self._services.board_image_records.add_image_to_board(
board_id=board_id, image_name=image_name
)
self._services.image_files.save(
image_name=image_name, image=image, metadata=metadata, graph=graph
)
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, graph=graph)
image_dto = self.get_dto(image_name)

return image_dto
@ -297,9 +293,7 @@ class ImageService(ImageServiceABC):
if not image_record.session_id:
return ImageMetadata()

session_raw = self._services.graph_execution_manager.get_raw(
image_record.session_id
)
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
graph = None

if session_raw:
@ -364,9 +358,7 @@ class ImageService(ImageServiceABC):
r,
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
self._services.board_image_records.get_board_for_image(
r.image_name
),
self._services.board_image_records.get_board_for_image(r.image_name),
),
results.items,
)
@ -398,11 +390,7 @@ class ImageService(ImageServiceABC):

def delete_images_on_board(self, board_id: str):
try:
image_names = (
self._services.board_image_records.get_all_board_image_names_for_board(
board_id
)
)
image_names = self._services.board_image_records.get_all_board_image_names_for_board(board_id)
for image_name in image_names:
self._services.image_files.delete(image_name)
self._services.image_records.delete_many(image_names)
@ -7,6 +7,7 @@ from queue import Queue
from pydantic import BaseModel, Field
from typing import Optional


class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
invocation_id: str = Field(description="The ID of the node being invoked")
@ -45,9 +46,11 @@ class MemoryInvocationQueue(InvocationQueueABC):
def get(self) -> InvocationQueueItem:
item = self.__queue.get()

while isinstance(item, InvocationQueueItem) \
and item.graph_execution_state_id in self.__cancellations \
and self.__cancellations[item.graph_execution_state_id] > item.timestamp:
while (
isinstance(item, InvocationQueueItem)
and item.graph_execution_state_id in self.__cancellations
and self.__cancellations[item.graph_execution_state_id] > item.timestamp
):
item = self.__queue.get()

# Clear old items
@ -7,6 +7,7 @@ from .graph import Graph, GraphExecutionState
from .invocation_queue import InvocationQueueItem
from .invocation_services import InvocationServices


class Invoker:
"""The invoker, used to execute invocations"""

@ -16,9 +17,7 @@ class Invoker:
self.services = services
self._start()

def invoke(
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> Optional[str]:
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
@ -9,13 +9,15 @@ T = TypeVar("T", bound=BaseModel)

class PaginatedResults(GenericModel, Generic[T]):
"""Paginated results"""
#fmt: off
# fmt: off
items: list[T] = Field(description="Items")
page: int = Field(description="Current Page")
pages: int = Field(description="Total number of pages")
per_page: int = Field(description="Number of items per page")
total: int = Field(description="Total number of items in result")
#fmt: on
# fmt: on


class ItemStorageABC(ABC, Generic[T]):
_on_changed_callbacks: list[Callable[[T], None]]
@ -48,9 +50,7 @@ class ItemStorageABC(ABC, Generic[T]):
pass

@abstractmethod
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[T]:
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
pass

def on_changed(self, on_changed: Callable[[T], None]) -> None:
@ -7,6 +7,7 @@ from typing import Dict, Union, Optional

import torch


class LatentsStorageBase(ABC):
"""Responsible for storing and retrieving latents."""

@ -25,7 +26,7 @@ class LatentsStorageBase(ABC):

class ForwardCacheLatentsStorage(LatentsStorageBase):
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""

__cache: Dict[str, torch.Tensor]
__cache_ids: Queue
__max_cache_size: int
@ -87,8 +88,6 @@ class DiskLatentsStorage(LatentsStorageBase):
def delete(self, name: str) -> None:
latent_path = self.get_path(name)
latent_path.unlink()

def get_path(self, name: str) -> Path:
return self.__output_folder / name
@ -103,7 +103,7 @@ class ModelManagerServiceBase(ABC):
}
"""
pass

@abstractmethod
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
@ -125,7 +125,7 @@ class ModelManagerServiceBase(ABC):
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False
clobber: bool = False,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with an
@ -148,12 +148,12 @@ class ModelManagerServiceBase(ABC):
Update the named model with a dictionary of attributes. Will fail with a
ModelNotFoundException if the name does not already exist.

On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass

@abstractmethod
def del_model(
self,
@ -169,21 +169,20 @@ class ModelManagerServiceBase(ABC):
pass

@abstractmethod
def rename_model(self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str,
):
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str,
):
"""
Rename the indicated model.
"""
pass

@abstractmethod
def list_checkpoint_configs(
self
)->List[Path]:
def list_checkpoint_configs(self) -> List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
@ -194,7 +193,7 @@ class ModelManagerServiceBase(ABC):
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
model_type: Union[ModelType.Main, ModelType.Vae],
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
@ -211,11 +210,12 @@ class ModelManagerServiceBase(ABC):
pass

@abstractmethod
def heuristic_import(self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
)->dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
def heuristic_import(
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
@ -230,19 +230,23 @@ class ModelManagerServiceBase(ABC):
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
'''
"""
pass

@abstractmethod
def merge_models(
self,
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None
model_names: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
@ -250,27 +254,27 @@ class ModelManagerServiceBase(ABC):
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
pass

@abstractmethod
def search_for_models(self, directory: Path)->List[Path]:
def search_for_models(self, directory: Path) -> List[Path]:
"""
Return list of all models found in the designated directory.
"""
pass

@abstractmethod
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
pass

@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
@ -280,9 +284,11 @@ class ModelManagerServiceBase(ABC):
"""
pass


# simple implementation
class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory"""

def __init__(
self,
config: InvokeAIAppConfig,
@ -298,17 +304,17 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
config_file = config.model_conf_path
|
config_file = config.model_conf_path
|
||||||
else:
|
else:
|
||||||
config_file = config.root_dir / "configs/models.yaml"
|
config_file = config.root_dir / "configs/models.yaml"
|
||||||
|
|
||||||
logger.debug(f'Config file={config_file}')
|
logger.debug(f"Config file={config_file}")
|
||||||
|
|
||||||
device = torch.device(choose_torch_device())
|
device = torch.device(choose_torch_device())
|
||||||
device_name = torch.cuda.get_device_name() if device==torch.device('cuda') else ''
|
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
|
||||||
logger.info(f'GPU device = {device} {device_name}')
|
logger.info(f"GPU device = {device} {device_name}")
|
||||||
|
|
||||||
precision = config.precision
|
precision = config.precision
|
||||||
if precision == "auto":
|
if precision == "auto":
|
||||||
precision = choose_precision(device)
|
precision = choose_precision(device)
|
||||||
dtype = torch.float32 if precision == 'float32' else torch.float16
|
dtype = torch.float32 if precision == "float32" else torch.float16
|
||||||
|
|
||||||
# this is transitional backward compatibility
|
# this is transitional backward compatibility
|
||||||
# support for the deprecated `max_loaded_models`
|
# support for the deprecated `max_loaded_models`
|
||||||
@ -316,9 +322,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
# cache size is set to 2.5 GB times
|
# cache size is set to 2.5 GB times
|
||||||
# the number of max_loaded_models. Otherwise
|
# the number of max_loaded_models. Otherwise
|
||||||
# use new `max_cache_size` config setting
|
# use new `max_cache_size` config setting
|
||||||
max_cache_size = config.max_cache_size \
|
max_cache_size = config.max_cache_size if hasattr(config, "max_cache_size") else config.max_loaded_models * 2.5
|
||||||
if hasattr(config,'max_cache_size') \
|
|
||||||
else config.max_loaded_models * 2.5
|
|
||||||
|
|
||||||
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
||||||
|
|
||||||
@ -332,7 +336,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
sequential_offload=sequential_offload,
|
sequential_offload=sequential_offload,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
logger.info('Model manager service initialized')
|
logger.info("Model manager service initialized")
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
@ -371,7 +375,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
model_info=model_info
|
model_info=model_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_info
|
return model_info
|
||||||
@ -405,9 +409,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
return self.mgr.model_names()
|
return self.mgr.model_names()
|
||||||
|
|
||||||
def list_models(
|
def list_models(
|
||||||
self,
|
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
|
||||||
base_model: Optional[BaseModelType] = None,
|
|
||||||
model_type: Optional[ModelType] = None
|
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Return a list of models.
|
Return a list of models.
|
||||||
@ -418,9 +420,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
"""
|
"""
|
||||||
Return information about the model using the same format as list_models()
|
Return information about the model using the same format as list_models()
|
||||||
"""
|
"""
|
||||||
return self.mgr.list_model(model_name=model_name,
|
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type)
|
|
||||||
|
|
||||||
def add_model(
|
def add_model(
|
||||||
self,
|
self,
|
||||||
@ -429,7 +429,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_attributes: dict,
|
model_attributes: dict,
|
||||||
clobber: bool = False,
|
clobber: bool = False,
|
||||||
)->None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
@ -437,7 +437,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
with an assertion error if provided attributes are incorrect or
|
with an assertion error if provided attributes are incorrect or
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
"""
|
"""
|
||||||
self.logger.debug(f'add/update model {model_name}')
|
self.logger.debug(f"add/update model {model_name}")
|
||||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||||
|
|
||||||
def update_model(
|
def update_model(
|
||||||
@ -450,15 +450,15 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with a
|
Update the named model with a dictionary of attributes. Will fail with a
|
||||||
ModelNotFoundException exception if the name does not already exist.
|
ModelNotFoundException exception if the name does not already exist.
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
with an assertion error if provided attributes are incorrect or
|
with an assertion error if provided attributes are incorrect or
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
"""
|
"""
|
||||||
self.logger.debug(f'update model {model_name}')
|
self.logger.debug(f"update model {model_name}")
|
||||||
if not self.model_exists(model_name, base_model, model_type):
|
if not self.model_exists(model_name, base_model, model_type):
|
||||||
raise ModelNotFoundException(f"Unknown model {model_name}")
|
raise ModelNotFoundException(f"Unknown model {model_name}")
|
||||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||||
|
|
||||||
def del_model(
|
def del_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -470,7 +470,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
then the underlying weight file or diffusers directory will be deleted
|
then the underlying weight file or diffusers directory will be deleted
|
||||||
as well.
|
as well.
|
||||||
"""
|
"""
|
||||||
self.logger.debug(f'delete model {model_name}')
|
self.logger.debug(f"delete model {model_name}")
|
||||||
self.mgr.del_model(model_name, base_model, model_type)
|
self.mgr.del_model(model_name, base_model, model_type)
|
||||||
self.mgr.commit()
|
self.mgr.commit()
|
||||||
|
|
||||||
@ -478,8 +478,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
model_type: Union[ModelType.Main, ModelType.Vae],
|
||||||
convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
convert_dest_directory: Optional[Path] = Field(
|
||||||
|
default=None, description="Optional directory location for merged model"
|
||||||
|
),
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
@ -494,10 +496,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||||
directory already in place.
|
directory already in place.
|
||||||
"""
|
"""
|
||||||
self.logger.debug(f'convert model {model_name}')
|
self.logger.debug(f"convert model {model_name}")
|
||||||
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
||||||
|
|
||||||
def commit(self, conf_file: Optional[Path]=None):
|
def commit(self, conf_file: Optional[Path] = None):
|
||||||
"""
|
"""
|
||||||
Write current configuration out to the indicated file.
|
Write current configuration out to the indicated file.
|
||||||
If no conf_file is provided, then replaces the
|
If no conf_file is provided, then replaces the
|
||||||
@ -524,7 +526,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
model_info=model_info
|
model_info=model_info,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
context.services.events.emit_model_load_started(
|
context.services.events.emit_model_load_started(
|
||||||
@ -535,16 +537,16 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def logger(self):
|
def logger(self):
|
||||||
return self.mgr.logger
|
return self.mgr.logger
|
||||||
|
|
||||||
def heuristic_import(self,
|
def heuristic_import(
|
||||||
items_to_import: set[str],
|
self,
|
||||||
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
items_to_import: set[str],
|
||||||
)->dict[str, AddModelResult]:
|
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
) -> dict[str, AddModelResult]:
|
||||||
|
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
successfully imported items.
|
successfully imported items.
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||||
@ -559,18 +561,24 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
The result is a set of successfully installed models. Each element
|
The result is a set of successfully installed models. Each element
|
||||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||||
that model.
|
that model.
|
||||||
'''
|
"""
|
||||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||||
|
|
||||||
def merge_models(
|
def merge_models(
|
||||||
self,
|
self,
|
||||||
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
model_names: List[str] = Field(
|
||||||
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
||||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
),
|
||||||
alpha: Optional[float] = 0.5,
|
base_model: Union[BaseModelType, str] = Field(
|
||||||
interp: Optional[MergeInterpolationMethod] = None,
|
default=None, description="Base model shared by all models to be merged"
|
||||||
force: Optional[bool] = False,
|
),
|
||||||
merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||||
|
alpha: Optional[float] = 0.5,
|
||||||
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
|
force: Optional[bool] = False,
|
||||||
|
merge_dest_directory: Optional[Path] = Field(
|
||||||
|
default=None, description="Optional directory location for merged model"
|
||||||
|
),
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Merge two to three diffusrs pipeline models and save as a new model.
|
Merge two to three diffusrs pipeline models and save as a new model.
|
||||||
@ -578,25 +586,25 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
:param base_model: Base model to use for all models
|
:param base_model: Base model to use for all models
|
||||||
:param merged_model_name: Name of destination merged model
|
:param merged_model_name: Name of destination merged model
|
||||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||||
:param interp: Interpolation method. None (default)
|
:param interp: Interpolation method. None (default)
|
||||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||||
"""
|
"""
|
||||||
merger = ModelMerger(self.mgr)
|
merger = ModelMerger(self.mgr)
|
||||||
try:
|
try:
|
||||||
result = merger.merge_diffusion_models_and_save(
|
result = merger.merge_diffusion_models_and_save(
|
||||||
model_names = model_names,
|
model_names=model_names,
|
||||||
base_model = base_model,
|
base_model=base_model,
|
||||||
merged_model_name = merged_model_name,
|
merged_model_name=merged_model_name,
|
||||||
alpha = alpha,
|
alpha=alpha,
|
||||||
interp = interp,
|
interp=interp,
|
||||||
force = force,
|
force=force,
|
||||||
merge_dest_directory=merge_dest_directory,
|
merge_dest_directory=merge_dest_directory,
|
||||||
)
|
)
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
raise ValueError(e)
|
raise ValueError(e)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def search_for_models(self, directory: Path)->List[Path]:
|
def search_for_models(self, directory: Path) -> List[Path]:
|
||||||
"""
|
"""
|
||||||
Return list of all models found in the designated directory.
|
Return list of all models found in the designated directory.
|
||||||
"""
|
"""
|
||||||
@ -605,28 +613,29 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
|
|
||||||
def sync_to_config(self):
|
def sync_to_config(self):
|
||||||
"""
|
"""
|
||||||
Re-read models.yaml, rescan the models directory, and reimport models
|
Re-read models.yaml, rescan the models directory, and reimport models
|
||||||
in the autoimport directories. Call after making changes outside the
|
in the autoimport directories. Call after making changes outside the
|
||||||
model manager API.
|
model manager API.
|
||||||
"""
|
"""
|
||||||
return self.mgr.sync_to_config()
|
return self.mgr.sync_to_config()
|
||||||
|
|
||||||
def list_checkpoint_configs(self)->List[Path]:
|
def list_checkpoint_configs(self) -> List[Path]:
|
||||||
"""
|
"""
|
||||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||||
"""
|
"""
|
||||||
config = self.mgr.app_config
|
config = self.mgr.app_config
|
||||||
conf_path = config.legacy_conf_path
|
conf_path = config.legacy_conf_path
|
||||||
root_path = config.root_path
|
root_path = config.root_path
|
||||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
|
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|
||||||
|
|
||||||
def rename_model(self,
|
def rename_model(
|
||||||
model_name: str,
|
self,
|
||||||
base_model: BaseModelType,
|
model_name: str,
|
||||||
model_type: ModelType,
|
base_model: BaseModelType,
|
||||||
new_name: str = None,
|
model_type: ModelType,
|
||||||
new_base: BaseModelType = None,
|
new_name: str = None,
|
||||||
):
|
new_base: BaseModelType = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Rename the indicated model. Can provide a new name and/or a new base.
|
Rename the indicated model. Can provide a new name and/or a new base.
|
||||||
:param model_name: Current name of the model
|
:param model_name: Current name of the model
|
||||||
@ -635,10 +644,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
:param new_name: New name for the model
|
:param new_name: New name for the model
|
||||||
:param new_base: New base for the model
|
:param new_base: New base for the model
|
||||||
"""
|
"""
|
||||||
self.mgr.rename_model(base_model = base_model,
|
self.mgr.rename_model(
|
||||||
model_type = model_type,
|
base_model=base_model,
|
||||||
model_name = model_name,
|
model_type=model_type,
|
||||||
new_name = new_name,
|
model_name=model_name,
|
||||||
new_base = new_base,
|
new_name=new_name,
|
||||||
)
|
new_base=new_base,
|
||||||
|
)
|
||||||
|
@ -11,30 +11,20 @@ class BoardRecord(BaseModel):
|
|||||||
"""The unique ID of the board."""
|
"""The unique ID of the board."""
|
||||||
board_name: str = Field(description="The name of the board.")
|
board_name: str = Field(description="The name of the board.")
|
||||||
"""The name of the board."""
|
"""The name of the board."""
|
||||||
created_at: Union[datetime, str] = Field(
|
created_at: Union[datetime, str] = Field(description="The created timestamp of the board.")
|
||||||
description="The created timestamp of the board."
|
|
||||||
)
|
|
||||||
"""The created timestamp of the image."""
|
"""The created timestamp of the image."""
|
||||||
updated_at: Union[datetime, str] = Field(
|
updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
|
||||||
description="The updated timestamp of the board."
|
|
||||||
)
|
|
||||||
"""The updated timestamp of the image."""
|
"""The updated timestamp of the image."""
|
||||||
deleted_at: Union[datetime, str, None] = Field(
|
deleted_at: Union[datetime, str, None] = Field(description="The deleted timestamp of the board.")
|
||||||
description="The deleted timestamp of the board."
|
|
||||||
)
|
|
||||||
"""The updated timestamp of the image."""
|
"""The updated timestamp of the image."""
|
||||||
cover_image_name: Optional[str] = Field(
|
cover_image_name: Optional[str] = Field(description="The name of the cover image of the board.")
|
||||||
description="The name of the cover image of the board."
|
|
||||||
)
|
|
||||||
"""The name of the cover image of the board."""
|
"""The name of the cover image of the board."""
|
||||||
|
|
||||||
|
|
||||||
class BoardDTO(BoardRecord):
|
class BoardDTO(BoardRecord):
|
||||||
"""Deserialized board record with cover image URL and image count."""
|
"""Deserialized board record with cover image URL and image count."""
|
||||||
|
|
||||||
cover_image_name: Optional[str] = Field(
|
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
|
||||||
description="The name of the board's cover image."
|
|
||||||
)
|
|
||||||
"""The URL of the thumbnail of the most recent image in the board."""
|
"""The URL of the thumbnail of the most recent image in the board."""
|
||||||
image_count: int = Field(description="The number of images in the board.")
|
image_count: int = Field(description="The number of images in the board.")
|
||||||
"""The number of images in the board."""
|
"""The number of images in the board."""
|
||||||
|
@ -20,17 +20,11 @@ class ImageRecord(BaseModel):
|
|||||||
"""The actual width of the image in px. This may be different from the width in metadata."""
|
"""The actual width of the image in px. This may be different from the width in metadata."""
|
||||||
height: int = Field(description="The height of the image in px.")
|
height: int = Field(description="The height of the image in px.")
|
||||||
"""The actual height of the image in px. This may be different from the height in metadata."""
|
"""The actual height of the image in px. This may be different from the height in metadata."""
|
||||||
created_at: Union[datetime.datetime, str] = Field(
|
created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the image.")
|
||||||
description="The created timestamp of the image."
|
|
||||||
)
|
|
||||||
"""The created timestamp of the image."""
|
"""The created timestamp of the image."""
|
||||||
updated_at: Union[datetime.datetime, str] = Field(
|
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the image.")
|
||||||
description="The updated timestamp of the image."
|
|
||||||
)
|
|
||||||
"""The updated timestamp of the image."""
|
"""The updated timestamp of the image."""
|
||||||
deleted_at: Union[datetime.datetime, str, None] = Field(
|
deleted_at: Union[datetime.datetime, str, None] = Field(description="The deleted timestamp of the image.")
|
||||||
description="The deleted timestamp of the image."
|
|
||||||
)
|
|
||||||
"""The deleted timestamp of the image."""
|
"""The deleted timestamp of the image."""
|
||||||
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
||||||
"""Whether this is an intermediate image."""
|
"""Whether this is an intermediate image."""
|
||||||
@ -55,18 +49,14 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
|||||||
- `is_intermediate`: change the image's `is_intermediate` flag
|
- `is_intermediate`: change the image's `is_intermediate` flag
|
||||||
"""
|
"""
|
||||||
|
|
||||||
image_category: Optional[ImageCategory] = Field(
|
image_category: Optional[ImageCategory] = Field(description="The image's new category.")
|
||||||
description="The image's new category."
|
|
||||||
)
|
|
||||||
"""The image's new category."""
|
"""The image's new category."""
|
||||||
session_id: Optional[StrictStr] = Field(
|
session_id: Optional[StrictStr] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The image's new session ID.",
|
description="The image's new session ID.",
|
||||||
)
|
)
|
||||||
"""The image's new session ID."""
|
"""The image's new session ID."""
|
||||||
is_intermediate: Optional[StrictBool] = Field(
|
is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.")
|
||||||
default=None, description="The image's new `is_intermediate` flag."
|
|
||||||
)
|
|
||||||
"""The image's new `is_intermediate` flag."""
|
"""The image's new `is_intermediate` flag."""
|
||||||
|
|
||||||
|
|
||||||
@ -84,9 +74,7 @@ class ImageUrlsDTO(BaseModel):
|
|||||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||||
"""Deserialized image record, enriched for the frontend."""
|
"""Deserialized image record, enriched for the frontend."""
|
||||||
|
|
||||||
board_id: Optional[str] = Field(
|
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
|
||||||
description="The id of the board the image belongs to, if one exists."
|
|
||||||
)
|
|
||||||
"""The id of the board the image belongs to, if one exists."""
|
"""The id of the board the image belongs to, if one exists."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -110,12 +98,8 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
|
|
||||||
# TODO: do we really need to handle default values here? ideally the data is the correct shape...
|
# TODO: do we really need to handle default values here? ideally the data is the correct shape...
|
||||||
image_name = image_dict.get("image_name", "unknown")
|
image_name = image_dict.get("image_name", "unknown")
|
||||||
image_origin = ResourceOrigin(
|
image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value))
|
||||||
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
|
image_category = ImageCategory(image_dict.get("image_category", ImageCategory.GENERAL.value))
|
||||||
)
|
|
||||||
image_category = ImageCategory(
|
|
||||||
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
|
||||||
)
|
|
||||||
width = image_dict.get("width", 0)
|
width = image_dict.get("width", 0)
|
||||||
height = image_dict.get("height", 0)
|
height = image_dict.get("height", 0)
|
||||||
session_id = image_dict.get("session_id", None)
|
session_id = image_dict.get("session_id", None)
|
||||||
|
@ -8,6 +8,8 @@ from .invoker import InvocationProcessorABC, Invoker
|
|||||||
from ..models.exceptions import CanceledException
|
from ..models.exceptions import CanceledException
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
|
|
||||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||||
__invoker_thread: Thread
|
__invoker_thread: Thread
|
||||||
__stop_event: Event
|
__stop_event: Event
|
||||||
@ -24,9 +26,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
target=self.__process,
|
target=self.__process,
|
||||||
kwargs=dict(stop_event=self.__stop_event),
|
kwargs=dict(stop_event=self.__stop_event),
|
||||||
)
|
)
|
||||||
self.__invoker_thread.daemon = (
|
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
|
||||||
True # TODO: make async and do not use threads
|
|
||||||
)
|
|
||||||
self.__invoker_thread.start()
|
self.__invoker_thread.start()
|
||||||
|
|
||||||
def stop(self, *args, **kwargs) -> None:
|
def stop(self, *args, **kwargs) -> None:
|
||||||
@ -47,10 +47,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
graph_execution_state = (
|
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
|
||||||
self.__invoker.services.graph_execution_manager.get(
|
queue_item.graph_execution_state_id
|
||||||
queue_item.graph_execution_state_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
|
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
|
||||||
@ -60,11 +58,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
error=traceback.format_exc(),
|
error=traceback.format_exc(),
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
invocation = graph_execution_state.execution_graph.get_node(
|
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
|
||||||
queue_item.invocation_id
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
|
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
|
||||||
self.__invoker.services.events.emit_invocation_retrieval_error(
|
self.__invoker.services.events.emit_invocation_retrieval_error(
|
||||||
@ -82,7 +78,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
self.__invoker.services.events.emit_invocation_started(
|
self.__invoker.services.events.emit_invocation_started(
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.dict(),
|
||||||
source_node_id=source_node_id
|
source_node_id=source_node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Invoke
|
# Invoke
|
||||||
@ -95,18 +91,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check queue to see if this is canceled, and skip if so
|
# Check queue to see if this is canceled, and skip if so
|
||||||
if self.__invoker.services.queue.is_canceled(
|
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||||
graph_execution_state.id
|
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Save outputs and history
|
# Save outputs and history
|
||||||
graph_execution_state.complete(invocation.id, outputs)
|
graph_execution_state.complete(invocation.id, outputs)
|
||||||
|
|
||||||
# Save the state changes
|
# Save the state changes
|
||||||
self.__invoker.services.graph_execution_manager.set(
|
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||||
graph_execution_state
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send complete event
|
# Send complete event
|
||||||
self.__invoker.services.events.emit_invocation_complete(
|
self.__invoker.services.events.emit_invocation_complete(
|
||||||
@ -130,9 +122,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
graph_execution_state.set_node_error(invocation.id, error)
|
graph_execution_state.set_node_error(invocation.id, error)
|
||||||
|
|
||||||
# Save the state changes
|
# Save the state changes
|
||||||
self.__invoker.services.graph_execution_manager.set(
|
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||||
graph_execution_state
|
|
||||||
)
|
|
||||||
|
|
||||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||||
# Send error event
|
# Send error event
|
||||||
@ -147,9 +137,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Check queue to see if this is canceled, and skip if so
|
# Check queue to see if this is canceled, and skip if so
|
||||||
if self.__invoker.services.queue.is_canceled(
|
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||||
graph_execution_state.id
|
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Queue any further commands if invoking all
|
# Queue any further commands if invoking all
|
||||||
@ -164,12 +152,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
node=invocation.dict(),
|
node=invocation.dict(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=traceback.format_exc()
|
error=traceback.format_exc(),
|
||||||
)
|
)
|
||||||
elif is_complete:
|
elif is_complete:
|
||||||
self.__invoker.services.events.emit_graph_execution_complete(
|
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
||||||
graph_execution_state.id
|
|
||||||
)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||||
|
@ -66,9 +66,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def get(self, id: str) -> Optional[T]:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
result = self._cursor.fetchone()
finally:
self._lock.release()
@ -81,9 +79,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def get_raw(self, id: str) -> Optional[str]:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
result = self._cursor.fetchone()
finally:
self._lock.release()
@ -96,9 +92,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def delete(self, id: str):
try:
self._lock.acquire()
self._cursor.execute(
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
self._cursor.execute(f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),))
self._conn.commit()
finally:
self._lock.release()
@ -122,13 +116,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):

pageCount = int(count / per_page) + 1

return PaginatedResults[T](
items=items, page=page, pages=pageCount, per_page=per_page, total=count
)
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)

def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[T]:
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
try:
self._lock.acquire()
self._cursor.execute(
@ -149,6 +139,4 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):

pageCount = int(count / per_page) + 1

return PaginatedResults[T](
items=items, page=page, pages=pageCount, per_page=per_page, total=count
)
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
@ -17,16 +17,8 @@ from controlnet_aux.util import HWC3, resize_image
 # If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.

 lvmin_kernels_raw = [
-np.array([
-[-1, -1, -1],
-[0, 1, 0],
-[1, 1, 1]
-], dtype=np.int32),
-np.array([
-[0, -1, -1],
-[1, 1, -1],
-[0, 1, 0]
-], dtype=np.int32)
+np.array([[-1, -1, -1], [0, 1, 0], [1, 1, 1]], dtype=np.int32),
+np.array([[0, -1, -1], [1, 1, -1], [0, 1, 0]], dtype=np.int32),
 ]

 lvmin_kernels = []
@ -36,16 +28,8 @@ lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
 lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]

 lvmin_prunings_raw = [
-np.array([
-[-1, -1, -1],
-[-1, 1, -1],
-[0, 0, -1]
-], dtype=np.int32),
-np.array([
-[-1, -1, -1],
-[-1, 1, -1],
-[-1, 0, 0]
-], dtype=np.int32)
+np.array([[-1, -1, -1], [-1, 1, -1], [0, 0, -1]], dtype=np.int32),
+np.array([[-1, -1, -1], [-1, 1, -1], [-1, 0, 0]], dtype=np.int32),
 ]

 lvmin_prunings = []
@ -99,10 +83,10 @@ def nake_nms(x):
 ################################################################################
 # FIXME: not using yet, if used in the future will most likely require modification of preprocessors
 def pixel_perfect_resolution(
 image: np.ndarray,
 target_H: int,
 target_W: int,
 resize_mode: str,
 ) -> int:
 """
 Calculate the estimated resolution for resizing an image while preserving aspect ratio.
@ -135,7 +119,7 @@ def pixel_perfect_resolution(

 if resize_mode == "fill_resize":
 estimation = min(k0, k1) * float(min(raw_H, raw_W))
 else:  # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
 estimation = max(k0, k1) * float(min(raw_H, raw_W))

 # print(f"Pixel Perfect Computation:")
@ -154,13 +138,7 @@ def pixel_perfect_resolution(
 # modified for InvokeAI
 ###########################################################################
 # def detectmap_proc(detected_map, module, resize_mode, h, w):
-def np_img_resize(
-np_img: np.ndarray,
-resize_mode: str,
-h: int,
-w: int,
-device: torch.device = torch.device('cpu')
-):
+def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device: torch.device = torch.device("cpu")):
 # if 'inpaint' in module:
 # np_img = np_img.astype(np.float32)
 # else:
@ -184,15 +162,14 @@ def np_img_resize(
 # below is very boring but do not change these. If you change these Apple or Mac may fail.
 y = torch.from_numpy(y)
 y = y.float() / 255.0
-y = rearrange(y, 'h w c -> 1 c h w')
+y = rearrange(y, "h w c -> 1 c h w")
 y = y.clone()
 # y = y.to(devices.get_device_for("controlnet"))
 y = y.to(device)
 y = y.clone()
 return y

-def high_quality_resize(x: np.ndarray,
-size):
+def high_quality_resize(x: np.ndarray, size):
 # Written by lvmin
 # Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
 inpaint_mask = None
@ -244,7 +221,7 @@ def np_img_resize(
 return y

 # if resize_mode == external_code.ResizeMode.RESIZE:
 if resize_mode == "just_resize":  # RESIZE
 np_img = high_quality_resize(np_img, (w, h))
 np_img = safe_numpy(np_img)
 return get_pytorch_control(np_img), np_img
@ -270,20 +247,21 @@ def np_img_resize(
 new_h, new_w, _ = np_img.shape
 pad_h = max(0, (h - new_h) // 2)
 pad_w = max(0, (w - new_w) // 2)
-high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = np_img
+high_quality_background[pad_h : pad_h + new_h, pad_w : pad_w + new_w] = np_img
 np_img = high_quality_background
 np_img = safe_numpy(np_img)
 return get_pytorch_control(np_img), np_img
 else:  # resize_mode == "crop_resize" (INNER_FIT)
 k = max(k0, k1)
 np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
 new_h, new_w, _ = np_img.shape
 pad_h = max(0, (new_h - h) // 2)
 pad_w = max(0, (new_w - w) // 2)
-np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w]
+np_img = np_img[pad_h : pad_h + h, pad_w : pad_w + w]
 np_img = safe_numpy(np_img)
 return get_pytorch_control(np_img), np_img


 def prepare_control_image(
 # image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
 # but now should be able to assume that image is a single PIL.Image, which simplifies things
@ -301,15 +279,17 @@ def prepare_control_image(
 resize_mode="just_resize_simple",
 ):
 # FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
-if (resize_mode == "just_resize_simple" or
-resize_mode == "crop_resize_simple" or
-resize_mode == "fill_resize_simple"):
+if (
+resize_mode == "just_resize_simple"
+or resize_mode == "crop_resize_simple"
+or resize_mode == "fill_resize_simple"
+):
 image = image.convert("RGB")
-if (resize_mode == "just_resize_simple"):
+if resize_mode == "just_resize_simple":
 image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
-elif (resize_mode == "crop_resize_simple"): # not yet implemented
+elif resize_mode == "crop_resize_simple":  # not yet implemented
 pass
-elif (resize_mode == "fill_resize_simple"): # not yet implemented
+elif resize_mode == "fill_resize_simple":  # not yet implemented
 pass
 nimage = np.array(image)
 nimage = nimage[None, :]
@ -320,7 +300,7 @@ def prepare_control_image(
 timage = torch.from_numpy(nimage)

 # use fancy lvmin controlnet resizing
-elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"):
+elif resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize":
 nimage = np.array(image)
 timage, nimage = np_img_resize(
 np_img=nimage,
@ -336,7 +316,7 @@ def prepare_control_image(
 exit(1)

 timage = timage.to(device=device, dtype=dtype)
-cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
+cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
 if do_classifier_free_guidance and not cfg_injection:
 timage = torch.cat([timage] * 2)
 return timage
@ -9,19 +9,16 @@ from ...backend.stable_diffusion import PipelineIntermediateState
 from invokeai.app.services.config import InvokeAIAppConfig


-def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix = None):
+def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
 latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors

 if smooth_matrix is not None:
 latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
-latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1,1,3,3)), padding=1)
+latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1, 1, 3, 3)), padding=1)
 latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0)

 latents_ubyte = (
-((latent_image + 1) / 2)
-.clamp(0, 1)  # change scale from -1..1 to 0..1
-.mul(0xFF)  # to 0..255
-.byte()
+((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte()  # change scale from -1..1 to 0..1  # to 0..255
 ).cpu()

 return Image.fromarray(latents_ubyte.numpy())
@ -92,6 +89,7 @@ def stable_diffusion_step_callback(
 total_steps=node["steps"],
 )


 def stable_diffusion_xl_step_callback(
 context: InvocationContext,
 node: dict,
@ -106,9 +104,9 @@ def stable_diffusion_xl_step_callback(
 sdxl_latent_rgb_factors = torch.tensor(
 [
 # R G B
-[ 0.3816, 0.4930, 0.5320],
+[0.3816, 0.4930, 0.5320],
 [-0.3753, 0.1631, 0.1739],
-[ 0.1770, 0.3588, -0.2048],
+[0.1770, 0.3588, -0.2048],
 [-0.4350, -0.2644, -0.4289],
 ],
 dtype=sample.dtype,
@ -117,9 +115,9 @@ def stable_diffusion_xl_step_callback(

 sdxl_smooth_matrix = torch.tensor(
 [
-#[ 0.0478, 0.1285, 0.0478],
-#[ 0.1285, 0.2948, 0.1285],
-#[ 0.0478, 0.1285, 0.0478],
+# [ 0.0478, 0.1285, 0.0478],
+# [ 0.1285, 0.2948, 0.1285],
+# [ 0.0478, 0.1285, 0.0478],
 [0.0358, 0.0964, 0.0358],
 [0.0964, 0.4711, 0.0964],
 [0.0358, 0.0964, 0.0358],
@ -143,4 +141,4 @@ def stable_diffusion_xl_step_callback(
 progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
 step=step,
 total_steps=total_steps,
 )
@ -1,15 +1,6 @@
 """
 Initialization file for invokeai.backend
 """
-from .generator import (
-InvokeAIGeneratorBasicParams,
-InvokeAIGenerator,
-InvokeAIGeneratorOutput,
-Img2Img,
-Inpaint
-)
-from .model_management import (
-ModelManager, ModelCache, BaseModelType,
-ModelType, SubModelType, ModelInfo
-)
+from .generator import InvokeAIGeneratorBasicParams, InvokeAIGenerator, InvokeAIGeneratorOutput, Img2Img, Inpaint
+from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo
 from .model_management.models import SilenceWarnings
@ -33,61 +33,66 @@ from ..stable_diffusion.schedulers import SCHEDULER_MAP

 downsampling = 8


 @dataclass
 class InvokeAIGeneratorBasicParams:
-seed: Optional[int]=None
-width: int=512
-height: int=512
-cfg_scale: float=7.5
-steps: int=20
-ddim_eta: float=0.0
-scheduler: str='ddim'
-precision: str='float16'
-perlin: float=0.0
-threshold: float=0.0
-seamless: bool=False
-seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
-h_symmetry_time_pct: Optional[float]=None
-v_symmetry_time_pct: Optional[float]=None
+seed: Optional[int] = None
+width: int = 512
+height: int = 512
+cfg_scale: float = 7.5
+steps: int = 20
+ddim_eta: float = 0.0
+scheduler: str = "ddim"
+precision: str = "float16"
+perlin: float = 0.0
+threshold: float = 0.0
+seamless: bool = False
+seamless_axes: List[str] = field(default_factory=lambda: ["x", "y"])
+h_symmetry_time_pct: Optional[float] = None
+v_symmetry_time_pct: Optional[float] = None
 variation_amount: float = 0.0
-with_variations: list=field(default_factory=list)
+with_variations: list = field(default_factory=list)


 @dataclass
 class InvokeAIGeneratorOutput:
-'''
+"""
 InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
 operation, including the image, its seed, the model name used to generate the image
 and the model hash, as well as all the generate() parameters that went into
 generating the image (in .params, also available as attributes)
-'''
+"""

 image: Image.Image
 seed: int
 model_hash: str
 attention_maps_images: List[Image.Image]
 params: Namespace


 # we are interposing a wrapper around the original Generator classes so that
 # old code that calls Generate will continue to work.
 class InvokeAIGenerator(metaclass=ABCMeta):
-def __init__(self,
-model_info: dict,
-params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
-**kwargs,
-):
-self.model_info=model_info
-self.params=params
+def __init__(
+self,
+model_info: dict,
+params: InvokeAIGeneratorBasicParams = InvokeAIGeneratorBasicParams(),
+**kwargs,
+):
+self.model_info = model_info
+self.params = params
 self.kwargs = kwargs

 def generate(
 self,
 conditioning: tuple,
 scheduler,
-callback: Optional[Callable]=None,
-step_callback: Optional[Callable]=None,
-iterations: int=1,
+callback: Optional[Callable] = None,
+step_callback: Optional[Callable] = None,
+iterations: int = 1,
 **keyword_args,
-)->Iterator[InvokeAIGeneratorOutput]:
-'''
+) -> Iterator[InvokeAIGeneratorOutput]:
+"""
 Return an iterator across the indicated number of generations.
 Each time the iterator is called it will return an InvokeAIGeneratorOutput
 object. Use like this:
@ -107,7 +112,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
 for o in outputs:
 print(o.image, o.seed)

-'''
+"""
 generator_args = dataclasses.asdict(self.params)
 generator_args.update(keyword_args)

@ -118,22 +123,21 @@ class InvokeAIGenerator(metaclass=ABCMeta):
 gen_class = self._generator_class()
 generator = gen_class(model, self.params.precision, **self.kwargs)
 if self.params.variation_amount > 0:
-generator.set_variation(generator_args.get('seed'),
-generator_args.get('variation_amount'),
-generator_args.get('with_variations')
-)
+generator.set_variation(
+generator_args.get("seed"),
+generator_args.get("variation_amount"),
+generator_args.get("with_variations"),
+)

 if isinstance(model, DiffusionPipeline):
 for component in [model.unet, model.vae]:
-configure_model_padding(component,
-generator_args.get('seamless',False),
-generator_args.get('seamless_axes')
-)
+configure_model_padding(
+component, generator_args.get("seamless", False), generator_args.get("seamless_axes")
+)
 else:
-configure_model_padding(model,
-generator_args.get('seamless',False),
-generator_args.get('seamless_axes')
-)
+configure_model_padding(
+model, generator_args.get("seamless", False), generator_args.get("seamless_axes")
+)

 iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
 for i in iteration_count:
@ -147,66 +151,66 @@ class InvokeAIGenerator(metaclass=ABCMeta):
 image=results[0][0],
 seed=results[0][1],
 attention_maps_images=results[0][2],
-model_hash = model_hash,
-params=Namespace(model_name=model_name,**generator_args),
+model_hash=model_hash,
+params=Namespace(model_name=model_name, **generator_args),
 )
 if callback:
 callback(output)
 yield output

 @classmethod
-def schedulers(self)->List[str]:
-'''
+def schedulers(self) -> List[str]:
+"""
 Return list of all the schedulers that we currently handle.
-'''
+"""
 return list(SCHEDULER_MAP.keys())

 def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
 return generator_class(model, self.params.precision)

 @classmethod
-def _generator_class(cls)->Type[Generator]:
-'''
+def _generator_class(cls) -> Type[Generator]:
+"""
 In derived classes return the name of the generator to apply.
 If you don't override will return the name of the derived
 class, which nicely parallels the generator class names.
-'''
+"""
 return Generator


 # ------------------------------------
 class Img2Img(InvokeAIGenerator):
-def generate(self,
-init_image: Union[Image.Image, torch.FloatTensor],
-strength: float=0.75,
-**keyword_args
-)->Iterator[InvokeAIGeneratorOutput]:
-return super().generate(init_image=init_image,
-strength=strength,
-**keyword_args
-)
+def generate(
+self, init_image: Union[Image.Image, torch.FloatTensor], strength: float = 0.75, **keyword_args
+) -> Iterator[InvokeAIGeneratorOutput]:
+return super().generate(init_image=init_image, strength=strength, **keyword_args)
 @classmethod
 def _generator_class(cls):
 from .img2img import Img2Img

 return Img2Img


 # ------------------------------------
 # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
 class Inpaint(Img2Img):
-def generate(self,
-mask_image: Union[Image.Image, torch.FloatTensor],
-# Seam settings - when 0, doesn't fill seam
-seam_size: int = 96,
-seam_blur: int = 16,
-seam_strength: float = 0.7,
-seam_steps: int = 30,
-tile_size: int = 32,
-inpaint_replace=False,
-infill_method=None,
-inpaint_width=None,
-inpaint_height=None,
-inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
-**keyword_args
-)->Iterator[InvokeAIGeneratorOutput]:
+def generate(
+self,
+mask_image: Union[Image.Image, torch.FloatTensor],
+# Seam settings - when 0, doesn't fill seam
+seam_size: int = 96,
+seam_blur: int = 16,
+seam_strength: float = 0.7,
+seam_steps: int = 30,
+tile_size: int = 32,
+inpaint_replace=False,
+infill_method=None,
+inpaint_width=None,
+inpaint_height=None,
+inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
+**keyword_args,
+) -> Iterator[InvokeAIGeneratorOutput]:
 return super().generate(
 mask_image=mask_image,
 seam_size=seam_size,
@ -219,13 +223,16 @@ class Inpaint(Img2Img):
 inpaint_width=inpaint_width,
 inpaint_height=inpaint_height,
 inpaint_fill=inpaint_fill,
-**keyword_args
+**keyword_args,
 )

 @classmethod
 def _generator_class(cls):
 from .inpaint import Inpaint

 return Inpaint


 class Generator:
 downsampling_factor: int
 latent_channels: int
@ -251,9 +258,7 @@ class Generator:
 Returns a function returning an image derived from the prompt and the initial image
 Return value depends on the seed at the time you call it
 """
-raise NotImplementedError(
-"image_iterator() must be implemented in a descendent class"
-)
+raise NotImplementedError("image_iterator() must be implemented in a descendent class")

 def set_variation(self, seed, variation_amount, with_variations):
 self.seed = seed
@ -280,9 +285,7 @@ class Generator:
 scope = nullcontext
 self.free_gpu_mem = free_gpu_mem
 attention_maps_images = []
-attention_maps_callback = lambda saver: attention_maps_images.append(
-saver.get_stacked_maps_image()
-)
+attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
 make_image = self.get_make_image(
 sampler=sampler,
 init_image=init_image,
@ -327,11 +330,7 @@ class Generator:
 results.append([image, seed, attention_maps_images])

 if image_callback is not None:
-attention_maps_image = (
-None
-if len(attention_maps_images) == 0
-else attention_maps_images[-1]
-)
+attention_maps_image = None if len(attention_maps_images) == 0 else attention_maps_images[-1]
 image_callback(
 image,
 seed,
@ -342,9 +341,7 @@ class Generator:
 seed = self.new_seed()

 # Free up memory from the last generation.
-clear_cuda_cache = (
-kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
-)
+clear_cuda_cache = kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
 if clear_cuda_cache is not None:
 clear_cuda_cache()

@ -371,14 +368,8 @@ class Generator:

 # Get the original alpha channel of the mask if there is one.
 # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
-pil_init_mask = (
-init_mask.getchannel("A")
-if init_mask.mode == "RGBA"
-else init_mask.convert("L")
-)
-pil_init_image = init_image.convert(
-"RGBA"
-)  # Add an alpha channel if one doesn't exist
+pil_init_mask = init_mask.getchannel("A") if init_mask.mode == "RGBA" else init_mask.convert("L")
+pil_init_image = init_image.convert("RGBA")  # Add an alpha channel if one doesn't exist

 # Build an image with only visible pixels from source to use as reference for color-matching.
 init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8)
@ -404,10 +395,7 @@ class Generator:
 np_matched_result[:, :, :] = (
 (
 (
-(
-np_matched_result[:, :, :].astype(np.float32)
-- gen_means[None, None, :]
-)
+(np_matched_result[:, :, :].astype(np.float32) - gen_means[None, None, :])
 / gen_std[None, None, :]
 )
 * init_std[None, None, :]
@ -433,9 +421,7 @@ class Generator:
 else:
 blurred_init_mask = pil_init_mask

-multiplied_blurred_init_mask = ImageChops.multiply(
-blurred_init_mask, self.pil_image.split()[-1]
-)
+multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])

 # Paste original on color-corrected generation (using blurred mask)
 matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
@ -461,10 +447,7 @@ class Generator:

 latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
 latents_ubyte = (
-((latent_image + 1) / 2)
-.clamp(0, 1)  # change scale from -1..1 to 0..1
-.mul(0xFF)  # to 0..255
-.byte()
+((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte()  # change scale from -1..1 to 0..1  # to 0..255
 ).cpu()

 return Image.fromarray(latents_ubyte.numpy())
@ -494,9 +477,7 @@ class Generator:
 temp_height = int((height + 7) / 8) * 8
 noise = torch.stack(
 [
-rand_perlin_2d(
-(temp_height, temp_width), (8, 8), device=self.model.device
-).to(fixdevice)
+rand_perlin_2d((temp_height, temp_width), (8, 8), device=self.model.device).to(fixdevice)
 for _ in range(input_channels)
 ],
 dim=0,
@ -573,8 +554,6 @@ class Generator:
 device=device,
 )
 if self.perlin > 0.0:
-perlin_noise = self.get_perlin_noise(
-width // self.downsampling_factor, height // self.downsampling_factor
-)
+perlin_noise = self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
 x = (1 - self.perlin) * x + self.perlin * perlin_noise
 return x
@ -77,10 +77,7 @@ class Img2Img(Generator):
 callback=step_callback,
 seed=seed,
 )
-if (
-pipeline_output.attention_map_saver is not None
-and attention_maps_callback is not None
-):
+if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
 attention_maps_callback(pipeline_output.attention_map_saver)
 return pipeline.numpy_to_pil(pipeline_output.images)[0]

@ -91,7 +88,5 @@ class Img2Img(Generator):
 x = torch.randn_like(like, device=device)
 if self.perlin > 0.0:
 shape = like.shape
-x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
-shape[3], shape[2]
-)
+x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
 return x
@ -68,15 +68,11 @@ class Inpaint(Img2Img):
 return im

 # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
-im_patched_np = PatchMatch.inpaint(
-im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
-)
+im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
 im_patched = Image.fromarray(im_patched_np, mode="RGB")
 return im_patched

-def tile_fill_missing(
-self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
-) -> Image.Image:
+def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
 # Only fill if there's an alpha layer
 if im.mode != "RGBA":
 return im
@ -127,15 +123,11 @@ class Inpaint(Img2Img):

 return si

-def mask_edge(
-self, mask: Image.Image, edge_size: int, edge_blur: int
-) -> Image.Image:
+def mask_edge(self, mask: Image.Image, edge_size: int, edge_blur: int) -> Image.Image:
 npimg = np.asarray(mask, dtype=np.uint8)

 # Detect any partially transparent regions
-npgradient = np.uint8(
-255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0))
-)
+npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))

 # Detect hard edges
 npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)
@ -144,9 +136,7 @@ class Inpaint(Img2Img):
 npmask = npgradient + npedge

 # Expand
-npmask = cv2.dilate(
-npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2)
-)
+npmask = cv2.dilate(npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2))

 new_mask = Image.fromarray(npmask)

@ -242,25 +232,19 @@ class Inpaint(Img2Img):
 if infill_method == "patchmatch" and PatchMatch.patchmatch_available():
 init_filled = self.infill_patchmatch(self.pil_image.copy())
 elif infill_method == "tile":
-init_filled = self.tile_fill_missing(
-self.pil_image.copy(), seed=self.seed, tile_size=tile_size
-)
+init_filled = self.tile_fill_missing(self.pil_image.copy(), seed=self.seed, tile_size=tile_size)
 elif infill_method == "solid":
 solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
 init_filled = Image.alpha_composite(solid_bg, init_image)
 else:
-raise ValueError(
-f"Non-supported infill type {infill_method}", infill_method
-)
+raise ValueError(f"Non-supported infill type {infill_method}", infill_method)
 init_filled.paste(init_image, (0, 0), init_image.split()[-1])

 # Resize if requested for inpainting
 if inpaint_width and inpaint_height:
 init_filled = init_filled.resize((inpaint_width, inpaint_height))

-debug_image(
-init_filled, "init_filled", debug_status=self.enable_image_debugging
-)
+debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)

 # Create init tensor
 init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
@ -289,9 +273,7 @@ class Inpaint(Img2Img):
 "mask_image AFTER multiply with pil_image",
 debug_status=self.enable_image_debugging,
 )
-mask: torch.FloatTensor = image_resized_to_grid_as_tensor(
-mask_image, normalize=False
-)
+mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
 else:
 mask: torch.FloatTensor = mask_image

@ -302,9 +284,9 @@ class Inpaint(Img2Img):

 # todo: support cross-attention control
 uc, c, _ = conditioning
-conditioning_data = ConditioningData(
-uc, c, cfg_scale
-).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
+conditioning_data = ConditioningData(uc, c, cfg_scale).add_scheduler_args_if_applicable(
+pipeline.scheduler, eta=ddim_eta
+)

 def make_image(x_T: torch.Tensor, seed: int):
 pipeline_output = pipeline.inpaint_from_embeddings(
@ -318,15 +300,10 @@ class Inpaint(Img2Img):
 seed=seed,
 )

-if (
-pipeline_output.attention_map_saver is not None
-and attention_maps_callback is not None
-):
+if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
 attention_maps_callback(pipeline_output.attention_map_saver)

-result = self.postprocess_size_and_mask(
-pipeline.numpy_to_pil(pipeline_output.images)[0]
-)
+result = self.postprocess_size_and_mask(pipeline.numpy_to_pil(pipeline_output.images)[0])

 # Seam paint if this is our first pass (seam_size set to 0 during seam painting)
 if seam_size > 0:
@ -8,9 +8,7 @@ from .txt2mask import Txt2Mask
 from .util import InitImageResizer, make_grid


-def debug_image(
-debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
-):
+def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False):
 if not debug_status:
 return
@ -9,26 +9,26 @@ from PIL import Image
 from imwatermark import WatermarkEncoder
 from invokeai.app.services.config import InvokeAIAppConfig
 import invokeai.backend.util.logging as logger

 config = InvokeAIAppConfig.get_config()


 class InvisibleWatermark:
 """
 Wrapper around InvisibleWatermark module.
 """

 @classmethod
 def invisible_watermark_available(self) -> bool:
 return config.invisible_watermark

 @classmethod
-def add_watermark(self, image: Image, watermark_text:str) -> Image:
+def add_watermark(self, image: Image, watermark_text: str) -> Image:
 if not self.invisible_watermark_available():
 return image
 logger.debug(f'Applying invisible watermark "{watermark_text}"')
 bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
 encoder = WatermarkEncoder()
-encoder.set_watermark('bytes', watermark_text.encode('utf-8'))
-bgr_encoded = encoder.encode(bgr, 'dwtDct')
-return Image.fromarray(
-cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)
-).convert("RGBA")
+encoder.set_watermark("bytes", watermark_text.encode("utf-8"))
+bgr_encoded = encoder.encode(bgr, "dwtDct")
+return Image.fromarray(cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)).convert("RGBA")
@ -7,8 +7,10 @@ be suppressed or deferred
 import numpy as np
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig

 config = InvokeAIAppConfig.get_config()


 class PatchMatch:
 """
 Thin class wrapper around the patchmatch function.
@ -34,9 +34,7 @@ class PngWriter:

 # saves image named _image_ to outdir/name, writing metadata from prompt
 # returns full path of output
-def save_image_and_prompt_to_png(
-self, image, dream_prompt, name, metadata=None, compress_level=6
-):
+def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None, compress_level=6):
 path = os.path.join(self.outdir, name)
 info = PngImagePlugin.PngInfo()
 info.add_text("Dream", dream_prompt)
@ -114,8 +112,6 @@ class PromptFormatter:
 if opt.variation_amount > 0:
 switches.append(f"-v{opt.variation_amount}")
 if opt.with_variations:
-formatted_variations = ",".join(
-f"{seed}:{weight}" for seed, weight in opt.with_variations
-)
+formatted_variations = ",".join(f"{seed}:{weight}" for seed, weight in opt.with_variations)
 switches.append(f"-V{formatted_variations}")
 return " ".join(switches)
@ -9,14 +9,17 @@ from invokeai.backend import SilenceWarnings
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.util.devices import choose_torch_device
 import invokeai.backend.util.logging as logger

 config = InvokeAIAppConfig.get_config()

-CHECKER_PATH = 'core/convert/stable-diffusion-safety-checker'
+CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"


 class SafetyChecker:
 """
 Wrapper around SafetyChecker model.
 """

 safety_checker = None
 feature_extractor = None
 tried_load: bool = False
@ -25,21 +28,19 @@ class SafetyChecker:
 def _load_safety_checker(self):
 if self.tried_load:
 return

 if config.nsfw_checker:
 try:
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from transformers import AutoFeatureExtractor
-self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
-config.models_path / CHECKER_PATH
-)
-self.feature_extractor = AutoFeatureExtractor.from_pretrained(
-config.models_path / CHECKER_PATH)
-logger.info('NSFW checker initialized')
+self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
+self.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
+logger.info("NSFW checker initialized")
 except Exception as e:
-logger.warning(f'Could not load NSFW checker: {str(e)}')
+logger.warning(f"Could not load NSFW checker: {str(e)}")
 else:
-logger.info('NSFW checker loading disabled')
+logger.info("NSFW checker loading disabled")
 self.tried_load = True

 @classmethod
@ -51,7 +52,7 @@ class SafetyChecker:
 def has_nsfw_concept(self, image: Image) -> bool:
 if not self.safety_checker_available():
 return False

 device = choose_torch_device()
 features = self.feature_extractor([image], return_tensors="pt")
 features.to(device)
@ -5,12 +5,8 @@ def _conv_forward_asymmetric(self, input, weight, bias):
 """
 Patch for Conv2d._conv_forward that supports asymmetric padding
 """
-working = nn.functional.pad(
-input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]
-)
-working = nn.functional.pad(
-working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]
-)
+working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
+working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
 return nn.functional.conv2d(
 working,
 weight,
@ -32,18 +28,14 @@ def configure_model_padding(model, seamless, seamless_axes):
 if seamless:
 m.asymmetric_padding_mode = {}
 m.asymmetric_padding = {}
-m.asymmetric_padding_mode["x"] = (
-"circular" if ("x" in seamless_axes) else "constant"
-)
+m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
 m.asymmetric_padding["x"] = (
 m._reversed_padding_repeated_twice[0],
 m._reversed_padding_repeated_twice[1],
 0,
 0,
 )
-m.asymmetric_padding_mode["y"] = (
-"circular" if ("y" in seamless_axes) else "constant"
-)
+m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
 m.asymmetric_padding["y"] = (
 0,
 0,
@ -39,23 +39,18 @@ CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
 CLIPSEG_SIZE = 352
 config = InvokeAIAppConfig.get_config()


 class SegmentedGrayscale(object):
 def __init__(self, image: Image, heatmap: torch.Tensor):
 self.heatmap = heatmap
 self.image = image

 def to_grayscale(self, invert: bool = False) -> Image:
-return self._rescale(
-Image.fromarray(
-np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)
-)
-)
+return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))

 def to_mask(self, threshold: float = 0.5) -> Image:
 discrete_heatmap = self.heatmap.lt(threshold).int()
-return self._rescale(
-Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L")
-)
+return self._rescale(Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L"))

 def to_transparent(self, invert: bool = False) -> Image:
 transparent_image = self.image.copy()
@ -67,11 +62,7 @@ class SegmentedGrayscale(object):

 # unscales and uncrops the 352x352 heatmap so that it matches the image again
 def _rescale(self, heatmap: Image) -> Image:
-size = (
-self.image.width
-if (self.image.width > self.image.height)
-else self.image.height
-)
+size = self.image.width if (self.image.width > self.image.height) else self.image.height
 resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
 return resized_image.crop((0, 0, self.image.width, self.image.height))

@ -87,12 +78,8 @@ class Txt2Mask(object):

 # BUG: we are not doing anything with the device option at this time
 self.device = device
-self.processor = AutoProcessor.from_pretrained(
-CLIPSEG_MODEL, cache_dir=config.cache_dir
-)
-self.model = CLIPSegForImageSegmentation.from_pretrained(
-CLIPSEG_MODEL, cache_dir=config.cache_dir
-)
+self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
+self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)

 @torch.no_grad()
 def segment(self, image, prompt: str) -> SegmentedGrayscale:
@ -107,9 +94,7 @@ class Txt2Mask(object):
 image = ImageOps.exif_transpose(image)
 img = self._scale_and_crop(image)

-inputs = self.processor(
-text=[prompt], images=[img], padding=True, return_tensors="pt"
-)
+inputs = self.processor(text=[prompt], images=[img], padding=True, return_tensors="pt")
 outputs = self.model(**inputs)
 heatmap = torch.sigmoid(outputs.logits)
 return SegmentedGrayscale(image, heatmap)
@ -6,28 +6,31 @@ from invokeai.app.services.config import (
 InvokeAIAppConfig,
 )


 def check_invokeai_root(config: InvokeAIAppConfig):
 try:
-assert config.model_conf_path.exists(), f'{config.model_conf_path} not found'
-assert config.db_path.parent.exists(), f'{config.db_path.parent} not found'
-assert config.models_path.exists(), f'{config.models_path} not found'
+assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
+assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
+assert config.models_path.exists(), f"{config.models_path} not found"
 for model in [
-'CLIP-ViT-bigG-14-laion2B-39B-b160k',
-'bert-base-uncased',
-'clip-vit-large-patch14',
-'sd-vae-ft-mse',
-'stable-diffusion-2-clip',
-'stable-diffusion-safety-checker']:
-path = config.models_path / f'core/convert/{model}'
-assert path.exists(), f'{path} is missing'
+"CLIP-ViT-bigG-14-laion2B-39B-b160k",
+"bert-base-uncased",
+"clip-vit-large-patch14",
+"sd-vae-ft-mse",
+"stable-diffusion-2-clip",
+"stable-diffusion-safety-checker",
+]:
+path = config.models_path / f"core/convert/{model}"
+assert path.exists(), f"{path} is missing"
 except Exception as e:
 print()
-print(f'An exception has occurred: {str(e)}')
-print('== STARTUP ABORTED ==')
-print('** One or more necessary files is missing from your InvokeAI root directory **')
-print('** Please rerun the configuration script to fix this problem. **')
-print('** From the launcher, selection option [7]. **')
-print('** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **')
-input('Press any key to continue...')
+print(f"An exception has occurred: {str(e)}")
+print("== STARTUP ABORTED ==")
+print("** One or more necessary files is missing from your InvokeAI root directory **")
+print("** Please rerun the configuration script to fix this problem. **")
+print("** From the launcher, selection option [7]. **")
+print(
+'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
+)
+input("Press any key to continue...")
 sys.exit(0)
@ -60,9 +60,7 @@ from invokeai.backend.install.model_install_backend import (
|
|||||||
InstallSelections,
|
InstallSelections,
|
||||||
ModelInstall,
|
ModelInstall,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_management.model_probe import (
|
from invokeai.backend.model_management.model_probe import ModelType, BaseModelType
|
||||||
ModelType, BaseModelType
|
|
||||||
)
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
@ -77,7 +75,7 @@ Model_dir = "models"
|
|||||||
Default_config_file = config.model_conf_path
|
Default_config_file = config.model_conf_path
|
||||||
SD_Configs = config.legacy_conf_path
|
SD_Configs = config.legacy_conf_path
|
||||||
|
|
||||||
PRECISION_CHOICES = ['auto','float16','float32']
|
PRECISION_CHOICES = ["auto", "float16", "float32"]
|
||||||
|
|
||||||
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||||
# This is the InvokeAI initialization file, which contains command-line default values.
|
# This is the InvokeAI initialization file, which contains command-line default values.
|
||||||
@ -85,7 +83,8 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
|||||||
# or renaming it and then running invokeai-configure again.
|
# or renaming it and then running invokeai-configure again.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger=InvokeAILogger.getLogger()
|
logger = InvokeAILogger.getLogger()
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
# --------------------------------------------
|
||||||
def postscript(errors: None):
|
def postscript(errors: None):
|
||||||
@ -108,7 +107,9 @@ Add the '--help' argument to see all of the command-line switches available for
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
else:
|
else:
|
||||||
message = "\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
|
message = (
|
||||||
|
"\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
|
||||||
|
)
|
||||||
for err in errors:
|
for err in errors:
|
||||||
message += f"\t - {err}\n"
|
message += f"\t - {err}\n"
|
||||||
message += "Please check the logs above and correct any issues."
|
message += "Please check the logs above and correct any issues."
|
||||||
@ -169,9 +170,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
|||||||
logger.info(f"Installing {label} model file {model_url}...")
|
logger.info(f"Installing {label} model file {model_url}...")
|
||||||
if not os.path.exists(model_dest):
|
if not os.path.exists(model_dest):
|
||||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||||
request.urlretrieve(
|
request.urlretrieve(model_url, model_dest, ProgressBar(os.path.basename(model_dest)))
|
||||||
model_url, model_dest, ProgressBar(os.path.basename(model_dest))
|
|
||||||
)
|
|
||||||
logger.info("...downloaded successfully")
|
logger.info("...downloaded successfully")
|
||||||
else:
|
else:
|
||||||
logger.info("...exists")
|
logger.info("...exists")
|
||||||
@@ -182,90 +181,93 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th


 def download_conversion_models():
-    target_dir = config.root_path / 'models/core/convert'
+    target_dir = config.root_path / "models/core/convert"
     kwargs = dict() # for future use
     try:
-        logger.info('Downloading core tokenizers and text encoders')
+        logger.info("Downloading core tokenizers and text encoders")

         # bert
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=DeprecationWarning)
             bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
-            bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
+            bert.save_pretrained(target_dir / "bert-base-uncased", safe_serialization=True)

         # sd-1
-        repo_id = 'openai/clip-vit-large-patch14'
-        hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14')
-        hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14')
+        repo_id = "openai/clip-vit-large-patch14"
+        hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / "clip-vit-large-patch14")
+        hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / "clip-vit-large-patch14")

         # sd-2
         repo_id = "stabilityai/stable-diffusion-2"
         pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
-        pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
+        pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "tokenizer", safe_serialization=True)

         pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
-        pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
+        pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "text_encoder", safe_serialization=True)

         # sd-xl - tokenizer_2
         repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
-        _, model_name = repo_id.split('/')
+        _, model_name = repo_id.split("/")
         pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
         pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)

         pipeline = CLIPTextConfig.from_pretrained(repo_id, **kwargs)
         pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)

         # VAE
-        logger.info('Downloading stable diffusion VAE')
-        vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
-        vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
+        logger.info("Downloading stable diffusion VAE")
+        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", **kwargs)
+        vae.save_pretrained(target_dir / "sd-vae-ft-mse", safe_serialization=True)

         # safety checking
-        logger.info('Downloading safety checker')
+        logger.info("Downloading safety checker")
         repo_id = "CompVis/stable-diffusion-safety-checker"
-        pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
-        pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
+        pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs)
+        pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)

-        pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
-        pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
+        pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs)
+        pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
     except KeyboardInterrupt:
         raise
     except Exception as e:
         logger.error(str(e))

 # ---------------------------------------------
 def download_realesrgan():
     logger.info("Installing ESRGAN Upscaling models...")
     URLs = [
         dict(
-            url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
-            dest = "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
-            description = "RealESRGAN_x4plus.pth",
+            url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+            dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
+            description="RealESRGAN_x4plus.pth",
         ),
         dict(
-            url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
-            dest = "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
-            description = "RealESRGAN_x4plus_anime_6B.pth",
+            url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+            dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
+            description="RealESRGAN_x4plus_anime_6B.pth",
         ),
         dict(
-            url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
-            dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
-            description = "ESRGAN_SRx4_DF2KOST_official.pth",
+            url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+            dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+            description="ESRGAN_SRx4_DF2KOST_official.pth",
         ),
         dict(
-            url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
-            dest= "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
-            description = "RealESRGAN_x2plus.pth",
+            url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+            dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
+            description="RealESRGAN_x2plus.pth",
         ),
     ]
     for model in URLs:
-        download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
+        download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])


 # ---------------------------------------------
 def download_support_models():
     download_realesrgan()
     download_conversion_models()


 # -------------------------------------
 def get_root(root: str = None) -> str:
     if root:
@ -275,6 +277,7 @@ def get_root(root: str = None) -> str:
|
|||||||
else:
|
else:
|
||||||
return str(config.root_path)
|
return str(config.root_path)
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||||
# for responsive resizing - disabled
|
# for responsive resizing - disabled
|
||||||
@ -283,14 +286,14 @@ class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
def create(self):
|
def create(self):
|
||||||
program_opts = self.parentApp.program_opts
|
program_opts = self.parentApp.program_opts
|
||||||
old_opts = self.parentApp.invokeai_opts
|
old_opts = self.parentApp.invokeai_opts
|
||||||
first_time = not (config.root_path / 'invokeai.yaml').exists()
|
first_time = not (config.root_path / "invokeai.yaml").exists()
|
||||||
access_token = HfFolder.get_token()
|
access_token = HfFolder.get_token()
|
||||||
window_width, window_height = get_terminal_size()
|
window_width, window_height = get_terminal_size()
|
||||||
label = """Configure startup settings. You can come back and change these later.
|
label = """Configure startup settings. You can come back and change these later.
|
||||||
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
|
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
|
||||||
Use cursor arrows to make a checkbox selection, and space to toggle.
|
Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||||
"""
|
"""
|
||||||
for i in textwrap.wrap(label,width=window_width-6):
|
for i in textwrap.wrap(label, width=window_width - 6):
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.FixedText,
|
npyscreen.FixedText,
|
||||||
value=i,
|
value=i,
|
||||||
@ -300,7 +303,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
|||||||
|
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens."""
|
label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens."""
|
||||||
for line in textwrap.wrap(label,width=window_width-6):
|
for line in textwrap.wrap(label, width=window_width - 6):
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.FixedText,
|
npyscreen.FixedText,
|
||||||
value=line,
|
value=line,
|
||||||
@ -343,7 +346,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
|||||||
relx=50,
|
relx=50,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.nextrely -=1
|
self.nextrely -= 1
|
||||||
self.always_use_cpu = self.add_widget_intelligent(
|
self.always_use_cpu = self.add_widget_intelligent(
|
||||||
npyscreen.Checkbox,
|
npyscreen.Checkbox,
|
||||||
name="Force CPU to be used on GPU systems",
|
name="Force CPU to be used on GPU systems",
|
||||||
@ -351,10 +354,8 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
|||||||
relx=80,
|
relx=80,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
precision = old_opts.precision or (
|
precision = old_opts.precision or ("float32" if program_opts.full_precision else "auto")
|
||||||
"float32" if program_opts.full_precision else "auto"
|
self.nextrely += 1
|
||||||
)
|
|
||||||
self.nextrely +=1
|
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.TitleFixedText,
|
npyscreen.TitleFixedText,
|
||||||
name="Floating Point Precision",
|
name="Floating Point Precision",
|
||||||
@ -363,10 +364,10 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
|||||||
color="CONTROL",
|
color="CONTROL",
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.nextrely -=1
|
self.nextrely -= 1
|
||||||
self.precision = self.add_widget_intelligent(
|
self.precision = self.add_widget_intelligent(
|
||||||
SingleSelectColumns,
|
SingleSelectColumns,
|
||||||
columns = 3,
|
columns=3,
|
||||||
name="Precision",
|
name="Precision",
|
||||||
values=PRECISION_CHOICES,
|
values=PRECISION_CHOICES,
|
||||||
value=PRECISION_CHOICES.index(precision),
|
value=PRECISION_CHOICES.index(precision),
|
||||||
@ -398,25 +399,25 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
|||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.autoimport_dirs = {}
|
self.autoimport_dirs = {}
|
||||||
self.autoimport_dirs['autoimport_dir'] = self.add_widget_intelligent(
|
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
|
||||||
FileBox,
|
FileBox,
|
||||||
name=f'Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models',
|
name=f"Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||||
value=str(config.root_path / config.autoimport_dir),
|
value=str(config.root_path / config.autoimport_dir),
|
||||||
select_dir=True,
|
select_dir=True,
|
||||||
must_exist=False,
|
must_exist=False,
|
||||||
use_two_lines=False,
|
use_two_lines=False,
|
||||||
labelColor="GOOD",
|
labelColor="GOOD",
|
||||||
begin_entry_at=32,
|
begin_entry_at=32,
|
||||||
max_height = 3,
|
max_height=3,
|
||||||
scroll_exit=True
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
|
label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
|
||||||
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSES LOCATED AT
|
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSES LOCATED AT
|
||||||
https://huggingface.co/spaces/CompVis/stable-diffusion-license and
|
https://huggingface.co/spaces/CompVis/stable-diffusion-license and
|
||||||
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md
|
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md
|
||||||
"""
|
"""
|
||||||
for i in textwrap.wrap(label,width=window_width-6):
|
for i in textwrap.wrap(label, width=window_width - 6):
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.FixedText,
|
npyscreen.FixedText,
|
||||||
value=i,
|
value=i,
|
||||||
@ -431,11 +432,7 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
|||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
label = (
|
label = "DONE" if program_opts.skip_sd_weights or program_opts.default_only else "NEXT"
|
||||||
"DONE"
|
|
||||||
if program_opts.skip_sd_weights or program_opts.default_only
|
|
||||||
else "NEXT"
|
|
||||||
)
|
|
||||||
self.ok_button = self.add_widget_intelligent(
|
self.ok_button = self.add_widget_intelligent(
|
||||||
CenteredButtonPress,
|
CenteredButtonPress,
|
||||||
name=label,
|
name=label,
|
||||||
@ -454,13 +451,11 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
|||||||
self.editing = False
|
self.editing = False
|
||||||
else:
|
else:
|
||||||
self.editing = True
|
self.editing = True
|
||||||
|
|
||||||
def validate_field_values(self, opt: Namespace) -> bool:
|
def validate_field_values(self, opt: Namespace) -> bool:
|
||||||
bad_fields = []
|
bad_fields = []
|
||||||
if not opt.license_acceptance:
|
if not opt.license_acceptance:
|
||||||
bad_fields.append(
|
bad_fields.append("Please accept the license terms before proceeding to model downloads")
|
||||||
"Please accept the license terms before proceeding to model downloads"
|
|
||||||
)
|
|
||||||
if not Path(opt.outdir).parent.exists():
|
if not Path(opt.outdir).parent.exists():
|
||||||
bad_fields.append(
|
bad_fields.append(
|
||||||
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
|
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
|
||||||
@ -478,11 +473,11 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
|||||||
new_opts = Namespace()
|
new_opts = Namespace()
|
||||||
|
|
||||||
for attr in [
|
for attr in [
|
||||||
"outdir",
|
"outdir",
|
||||||
"free_gpu_mem",
|
"free_gpu_mem",
|
||||||
"max_cache_size",
|
"max_cache_size",
|
||||||
"xformers_enabled",
|
"xformers_enabled",
|
||||||
"always_use_cpu",
|
"always_use_cpu",
|
||||||
]:
|
]:
|
||||||
setattr(new_opts, attr, getattr(self, attr).value)
|
setattr(new_opts, attr, getattr(self, attr).value)
|
||||||
|
|
||||||
@ -495,7 +490,7 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
|||||||
new_opts.hf_token = self.hf_token.value
|
new_opts.hf_token = self.hf_token.value
|
||||||
new_opts.license_acceptance = self.license_acceptance.value
|
new_opts.license_acceptance = self.license_acceptance.value
|
||||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||||
|
|
||||||
return new_opts
|
return new_opts
|
||||||
|
|
||||||
|
|
||||||
@ -534,19 +529,20 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
|
|||||||
editApp.run()
|
editApp.run()
|
||||||
return editApp.new_opts()
|
return editApp.new_opts()
|
||||||
|
|
||||||
|
|
||||||
def default_startup_options(init_file: Path) -> Namespace:
|
def default_startup_options(init_file: Path) -> Namespace:
|
||||||
opts = InvokeAIAppConfig.get_config()
|
opts = InvokeAIAppConfig.get_config()
|
||||||
return opts
|
return opts
|
||||||
|
|
||||||
|
|
||||||
def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
installer = ModelInstall(config)
|
installer = ModelInstall(config)
|
||||||
except omegaconf.errors.ConfigKeyError:
|
except omegaconf.errors.ConfigKeyError:
|
||||||
logger.warning('Your models.yaml file is corrupt or out of date. Reinitializing')
|
logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing")
|
||||||
initialize_rootdir(config.root_path, True)
|
initialize_rootdir(config.root_path, True)
|
||||||
installer = ModelInstall(config)
|
installer = ModelInstall(config)
|
||||||
|
|
||||||
models = installer.all_models()
|
models = installer.all_models()
|
||||||
return InstallSelections(
|
return InstallSelections(
|
||||||
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
|
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
|
||||||
@ -556,55 +552,46 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
|||||||
else list(),
|
else list(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||||
logger.info("Initializing InvokeAI runtime directory")
|
logger.info("Initializing InvokeAI runtime directory")
|
||||||
for name in (
|
for name in ("models", "databases", "text-inversion-output", "text-inversion-training-data", "configs"):
|
||||||
"models",
|
|
||||||
"databases",
|
|
||||||
"text-inversion-output",
|
|
||||||
"text-inversion-training-data",
|
|
||||||
"configs"
|
|
||||||
):
|
|
||||||
os.makedirs(os.path.join(root, name), exist_ok=True)
|
os.makedirs(os.path.join(root, name), exist_ok=True)
|
||||||
for model_type in ModelType:
|
for model_type in ModelType:
|
||||||
Path(root, 'autoimport', model_type.value).mkdir(parents=True, exist_ok=True)
|
Path(root, "autoimport", model_type.value).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
configs_src = Path(configs.__path__[0])
|
configs_src = Path(configs.__path__[0])
|
||||||
configs_dest = root / "configs"
|
configs_dest = root / "configs"
|
||||||
if not os.path.samefile(configs_src, configs_dest):
|
if not os.path.samefile(configs_src, configs_dest):
|
||||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||||
|
|
||||||
dest = root / 'models'
|
dest = root / "models"
|
||||||
for model_base in BaseModelType:
|
for model_base in BaseModelType:
|
||||||
for model_type in ModelType:
|
for model_type in ModelType:
|
||||||
path = dest / model_base.value / model_type.value
|
path = dest / model_base.value / model_type.value
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
path = dest / 'core'
|
path = dest / "core"
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
maybe_create_models_yaml(root)
|
maybe_create_models_yaml(root)
|
||||||
|
|
||||||
|
|
||||||
def maybe_create_models_yaml(root: Path):
|
def maybe_create_models_yaml(root: Path):
|
||||||
models_yaml = root / 'configs' / 'models.yaml'
|
models_yaml = root / "configs" / "models.yaml"
|
||||||
if models_yaml.exists():
|
if models_yaml.exists():
|
||||||
if OmegaConf.load(models_yaml).get('__metadata__'): # up to date
|
if OmegaConf.load(models_yaml).get("__metadata__"): # up to date
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
logger.info('Creating new models.yaml, original saved as models.yaml.orig')
|
logger.info("Creating new models.yaml, original saved as models.yaml.orig")
|
||||||
models_yaml.rename(models_yaml.parent / 'models.yaml.orig')
|
models_yaml.rename(models_yaml.parent / "models.yaml.orig")
|
||||||
|
|
||||||
with open(models_yaml,'w') as yaml_file:
|
with open(models_yaml, "w") as yaml_file:
|
||||||
yaml_file.write(yaml.dump({'__metadata__':
|
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
||||||
{'version':'3.0.0'}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def run_console_ui(
|
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
|
||||||
program_opts: Namespace, initfile: Path = None
|
|
||||||
) -> (Namespace, Namespace):
|
|
||||||
# parse_args() will read from init file if present
|
# parse_args() will read from init file if present
|
||||||
invokeai_opts = default_startup_options(initfile)
|
invokeai_opts = default_startup_options(initfile)
|
||||||
invokeai_opts.root = program_opts.root
|
invokeai_opts.root = program_opts.root
|
||||||
@ -616,8 +603,9 @@ def run_console_ui(
|
|||||||
# the install-models application spawns a subprocess to install
|
# the install-models application spawns a subprocess to install
|
||||||
# models, and will crash unless this is set before running.
|
# models, and will crash unless this is set before running.
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
torch.multiprocessing.set_start_method("spawn")
|
torch.multiprocessing.set_start_method("spawn")
|
||||||
|
|
||||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||||
editApp.run()
|
editApp.run()
|
||||||
if editApp.user_cancelled:
|
if editApp.user_cancelled:
|
||||||
@ -634,39 +622,42 @@ def write_opts(opts: Namespace, init_file: Path):
|
|||||||
# this will load current settings
|
# this will load current settings
|
||||||
new_config = InvokeAIAppConfig.get_config()
|
new_config = InvokeAIAppConfig.get_config()
|
||||||
new_config.root = config.root
|
new_config.root = config.root
|
||||||
|
|
||||||
for key,value in opts.__dict__.items():
|
|
||||||
if hasattr(new_config,key):
|
|
||||||
setattr(new_config,key,value)
|
|
||||||
|
|
||||||
with open(init_file,'w', encoding='utf-8') as file:
|
for key, value in opts.__dict__.items():
|
||||||
|
if hasattr(new_config, key):
|
||||||
|
setattr(new_config, key, value)
|
||||||
|
|
||||||
|
with open(init_file, "w", encoding="utf-8") as file:
|
||||||
file.write(new_config.to_yaml())
|
file.write(new_config.to_yaml())
|
||||||
|
|
||||||
if hasattr(opts,'hf_token') and opts.hf_token:
|
if hasattr(opts, "hf_token") and opts.hf_token:
|
||||||
HfLogin(opts.hf_token)
|
HfLogin(opts.hf_token)
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def default_output_dir() -> Path:
|
def default_output_dir() -> Path:
|
||||||
return config.root_path / "outputs"
|
return config.root_path / "outputs"
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def write_default_options(program_opts: Namespace, initfile: Path):
|
def write_default_options(program_opts: Namespace, initfile: Path):
|
||||||
opt = default_startup_options(initfile)
|
opt = default_startup_options(initfile)
|
||||||
write_opts(opt, initfile)
|
write_opts(opt, initfile)
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
# Here we bring in
|
# Here we bring in
|
||||||
# the legacy Args object in order to parse
|
# the legacy Args object in order to parse
|
||||||
# the old init file and write out the new
|
# the old init file and write out the new
|
||||||
# yaml format.
|
# yaml format.
|
||||||
def migrate_init_file(legacy_format:Path):
|
def migrate_init_file(legacy_format: Path):
|
||||||
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
|
||||||
new = InvokeAIAppConfig.get_config()
|
new = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
||||||
for attr in fields:
|
for attr in fields:
|
||||||
if hasattr(old,attr):
|
if hasattr(old, attr):
|
||||||
setattr(new,attr,getattr(old,attr))
|
setattr(new, attr, getattr(old, attr))
|
||||||
|
|
||||||
# a few places where the field names have changed and we have to
|
# a few places where the field names have changed and we have to
|
||||||
# manually add in the new names/values
|
# manually add in the new names/values
|
||||||
@ -674,40 +665,43 @@ def migrate_init_file(legacy_format:Path):
|
|||||||
new.conf_path = old.conf
|
new.conf_path = old.conf
|
||||||
new.root = legacy_format.parent.resolve()
|
new.root = legacy_format.parent.resolve()
|
||||||
|
|
||||||
invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
|
invokeai_yaml = legacy_format.parent / "invokeai.yaml"
|
||||||
with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
|
with open(invokeai_yaml, "w", encoding="utf-8") as outfile:
|
||||||
outfile.write(new.to_yaml())
|
outfile.write(new.to_yaml())
|
||||||
|
|
||||||
legacy_format.replace(legacy_format.parent / 'invokeai.init.orig')
|
legacy_format.replace(legacy_format.parent / "invokeai.init.orig")
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def migrate_models(root: Path):
|
def migrate_models(root: Path):
|
||||||
from invokeai.backend.install.migrate_to_3 import do_migrate
|
from invokeai.backend.install.migrate_to_3 import do_migrate
|
||||||
|
|
||||||
do_migrate(root, root)
|
do_migrate(root, root)
|
||||||
|
|
||||||
-def migrate_if_needed(opt: Namespace, root: Path)->bool:
+def migrate_if_needed(opt: Namespace, root: Path) -> bool:
     # We check for to see if the runtime directory is correctly initialized.
-    old_init_file = root / 'invokeai.init'
-    new_init_file = root / 'invokeai.yaml'
-    old_hub = root / 'models/hub'
+    old_init_file = root / "invokeai.init"
+    new_init_file = root / "invokeai.yaml"
+    old_hub = root / "models/hub"
     migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()

     if migration_needed:
-        if opt.yes_to_all or \
-            yes_or_no(f'{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?'):
-            logger.info('** Migrating invokeai.init to invokeai.yaml')
+        if opt.yes_to_all or yes_or_no(
+            f"{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?"
+        ):
+            logger.info("** Migrating invokeai.init to invokeai.yaml")
             migrate_init_file(old_init_file)
-            config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
+            config.parse_args(argv=[], conf=OmegaConf.load(new_init_file))

             if old_hub.exists():
                 migrate_models(config.root_path)
         else:
-            print('Cannot continue without conversion. Aborting.')
+            print("Cannot continue without conversion. Aborting.")

     return migration_needed
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||||
@ -764,9 +758,9 @@ def main():
|
|||||||
|
|
||||||
invoke_args = []
|
invoke_args = []
|
||||||
if opt.root:
|
if opt.root:
|
||||||
invoke_args.extend(['--root',opt.root])
|
invoke_args.extend(["--root", opt.root])
|
||||||
if opt.full_precision:
|
if opt.full_precision:
|
||||||
invoke_args.extend(['--precision','float32'])
|
invoke_args.extend(["--precision", "float32"])
|
||||||
config.parse_args(invoke_args)
|
config.parse_args(invoke_args)
|
||||||
logger = InvokeAILogger().getLogger(config=config)
|
logger = InvokeAILogger().getLogger(config=config)
|
||||||
|
|
||||||
@ -782,22 +776,18 @@ def main():
|
|||||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||||
|
|
||||||
models_to_download = default_user_selections(opt)
|
models_to_download = default_user_selections(opt)
|
||||||
new_init_file = config.root_path / 'invokeai.yaml'
|
new_init_file = config.root_path / "invokeai.yaml"
|
||||||
if opt.yes_to_all:
|
if opt.yes_to_all:
|
||||||
write_default_options(opt, new_init_file)
|
write_default_options(opt, new_init_file)
|
||||||
init_options = Namespace(
|
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
|
||||||
precision="float32" if opt.full_precision else "float16"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
||||||
if init_options:
|
if init_options:
|
||||||
write_opts(init_options, new_init_file)
|
write_opts(init_options, new_init_file)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info('\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n')
|
||||||
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
|
|
||||||
)
|
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
if opt.skip_support_models:
|
if opt.skip_support_models:
|
||||||
logger.info("Skipping support models at user's request")
|
logger.info("Skipping support models at user's request")
|
||||||
else:
|
else:
|
||||||
@ -811,7 +801,7 @@ def main():
|
|||||||
|
|
||||||
postscript(errors=errors)
|
postscript(errors=errors)
|
||||||
if not opt.yes_to_all:
|
if not opt.yes_to_all:
|
||||||
input('Press any key to continue...')
|
input("Press any key to continue...")
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\nGoodbye! Come back soon.")
|
print("\nGoodbye! Come back soon.")
|
||||||
|
|
||||||
|
@ -47,17 +47,18 @@ PRECISION_CHOICES = [
|
|||||||
"float16",
|
"float16",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class FileArgumentParser(ArgumentParser):
|
class FileArgumentParser(ArgumentParser):
|
||||||
"""
|
"""
|
||||||
Supports reading defaults from an init file.
|
Supports reading defaults from an init file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def convert_arg_line_to_args(self, arg_line):
|
def convert_arg_line_to_args(self, arg_line):
|
||||||
return shlex.split(arg_line, comments=True)
|
return shlex.split(arg_line, comments=True)
|
||||||
|
|
||||||
|
|
||||||
legacy_parser = FileArgumentParser(
|
legacy_parser = FileArgumentParser(
|
||||||
description=
|
description="""
|
||||||
"""
|
|
||||||
Generate images using Stable Diffusion.
|
Generate images using Stable Diffusion.
|
||||||
Use --web to launch the web interface.
|
Use --web to launch the web interface.
|
||||||
Use --from_file to load prompts from a file path or standard input ("-").
|
Use --from_file to load prompts from a file path or standard input ("-").
|
||||||
@ -65,304 +66,279 @@ Generate images using Stable Diffusion.
|
|||||||
Other command-line arguments are defaults that can usually be overridden
|
Other command-line arguments are defaults that can usually be overridden
|
||||||
prompt the command prompt.
|
prompt the command prompt.
|
||||||
""",
|
""",
|
||||||
fromfile_prefix_chars='@',
|
fromfile_prefix_chars="@",
|
||||||
)
|
)
|
||||||
general_group = legacy_parser.add_argument_group('General')
|
general_group = legacy_parser.add_argument_group("General")
|
||||||
model_group = legacy_parser.add_argument_group('Model selection')
|
model_group = legacy_parser.add_argument_group("Model selection")
|
||||||
file_group = legacy_parser.add_argument_group('Input/output')
|
file_group = legacy_parser.add_argument_group("Input/output")
|
||||||
web_server_group = legacy_parser.add_argument_group('Web server')
|
web_server_group = legacy_parser.add_argument_group("Web server")
|
||||||
render_group = legacy_parser.add_argument_group('Rendering')
|
render_group = legacy_parser.add_argument_group("Rendering")
|
||||||
postprocessing_group = legacy_parser.add_argument_group('Postprocessing')
|
postprocessing_group = legacy_parser.add_argument_group("Postprocessing")
|
||||||
deprecated_group = legacy_parser.add_argument_group('Deprecated options')
|
deprecated_group = legacy_parser.add_argument_group("Deprecated options")
|
||||||
|
|
||||||
deprecated_group.add_argument('--laion400m')
|
deprecated_group.add_argument("--laion400m")
|
||||||
deprecated_group.add_argument('--weights') # deprecated
|
deprecated_group.add_argument("--weights") # deprecated
|
||||||
general_group.add_argument(
|
general_group.add_argument("--version", "-V", action="store_true", help="Print InvokeAI version number")
|
||||||
'--version','-V',
|
|
||||||
action='store_true',
|
|
||||||
help='Print InvokeAI version number'
|
|
||||||
)
|
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--root_dir',
|
"--root_dir",
|
||||||
default=None,
|
default=None,
|
||||||
help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
|
help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--config',
|
"--config",
|
||||||
'-c',
|
"-c",
|
||||||
'-config',
|
"-config",
|
||||||
dest='conf',
|
dest="conf",
|
||||||
default='./configs/models.yaml',
|
default="./configs/models.yaml",
|
||||||
help='Path to configuration file for alternate models.',
|
help="Path to configuration file for alternate models.",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--model',
|
"--model",
|
||||||
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
|
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--weight_dirs',
|
"--weight_dirs",
|
||||||
nargs='+',
|
nargs="+",
|
||||||
type=str,
|
type=str,
|
||||||
help='List of one or more directories that will be auto-scanned for new model weights to import',
|
help="List of one or more directories that will be auto-scanned for new model weights to import",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--png_compression','-z',
|
"--png_compression",
|
||||||
|
"-z",
|
||||||
type=int,
|
type=int,
|
||||||
default=6,
|
default=6,
|
||||||
choices=range(0,9),
|
choices=range(0, 9),
|
||||||
dest='png_compression',
|
dest="png_compression",
|
||||||
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
|
help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'-F',
|
"-F",
|
||||||
'--full_precision',
|
"--full_precision",
|
||||||
dest='full_precision',
|
dest="full_precision",
|
||||||
action='store_true',
|
action="store_true",
|
||||||
help='Deprecated way to set --precision=float32',
|
help="Deprecated way to set --precision=float32",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--max_loaded_models',
|
"--max_loaded_models",
|
||||||
dest='max_loaded_models',
|
dest="max_loaded_models",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=2,
|
||||||
help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
|
help="Maximum number of models to keep in memory for fast switching, including the one in GPU",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--free_gpu_mem',
|
"--free_gpu_mem",
|
||||||
dest='free_gpu_mem',
|
dest="free_gpu_mem",
|
||||||
action='store_true',
|
action="store_true",
|
||||||
help='Force free gpu memory before final decoding',
|
help="Force free gpu memory before final decoding",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--sequential_guidance',
|
"--sequential_guidance",
|
||||||
dest='sequential_guidance',
|
dest="sequential_guidance",
|
||||||
action='store_true',
|
action="store_true",
|
||||||
help="Calculate guidance in serial instead of in parallel, lowering memory requirement "
|
help="Calculate guidance in serial instead of in parallel, lowering memory requirement " "at the expense of speed",
|
||||||
"at the expense of speed",
|
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--xformers',
|
"--xformers",
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
default=True,
|
default=True,
|
||||||
help='Enable/disable xformers support (default enabled if installed)',
|
help="Enable/disable xformers support (default enabled if installed)",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--always_use_cpu",
|
"--always_use_cpu", dest="always_use_cpu", action="store_true", help="Force use of CPU even if GPU is available"
|
||||||
dest="always_use_cpu",
|
|
||||||
action="store_true",
|
|
||||||
help="Force use of CPU even if GPU is available"
|
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--precision',
|
"--precision",
|
||||||
dest='precision',
|
dest="precision",
|
||||||
type=str,
|
type=str,
|
||||||
choices=PRECISION_CHOICES,
|
choices=PRECISION_CHOICES,
|
||||||
metavar='PRECISION',
|
metavar="PRECISION",
|
||||||
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
||||||
default='auto',
|
default="auto",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--ckpt_convert',
|
"--ckpt_convert",
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
dest='ckpt_convert',
|
dest="ckpt_convert",
|
||||||
default=True,
|
default=True,
|
||||||
help='Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.'
|
help="Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--internet',
|
"--internet",
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
dest='internet_available',
|
dest="internet_available",
|
||||||
default=True,
|
default=True,
|
||||||
help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).',
|
help="Indicate whether internet is available for just-in-time model downloading (default: probe automatically).",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--nsfw_checker',
|
"--nsfw_checker",
|
||||||
'--safety_checker',
|
"--safety_checker",
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
dest='safety_checker',
|
dest="safety_checker",
|
||||||
default=False,
|
default=False,
|
||||||
help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
|
help="Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--autoimport',
|
"--autoimport",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
|
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--autoconvert',
|
"--autoconvert",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
|
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
'--patchmatch',
|
"--patchmatch",
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
default=True,
|
default=True,
|
||||||
help='Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.',
|
help="Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.",
|
||||||
)
|
)
|
||||||
file_group.add_argument(
|
file_group.add_argument(
|
||||||
'--from_file',
|
"--from_file",
|
||||||
dest='infile',
|
dest="infile",
|
||||||
type=str,
|
type=str,
|
||||||
help='If specified, load prompts from this file',
|
help="If specified, load prompts from this file",
|
||||||
)
|
)
|
||||||
file_group.add_argument(
|
file_group.add_argument(
|
||||||
'--outdir',
|
"--outdir",
|
||||||
'-o',
|
"-o",
|
||||||
type=str,
|
type=str,
|
||||||
help='Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs',
|
help="Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs",
|
||||||
default='outputs',
|
default="outputs",
|
||||||
)
|
)
|
||||||
file_group.add_argument(
|
file_group.add_argument(
|
||||||
'--prompt_as_dir',
|
"--prompt_as_dir",
|
||||||
'-p',
|
"-p",
|
||||||
action='store_true',
|
action="store_true",
|
||||||
help='Place images in subdirectories named after the prompt.',
|
help="Place images in subdirectories named after the prompt.",
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
'--fnformat',
|
"--fnformat",
|
||||||
default='{prefix}.{seed}.png',
|
default="{prefix}.{seed}.png",
|
||||||
type=str,
|
type=str,
|
||||||
help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
|
help="Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png",
|
||||||
)
|
)
|
||||||
|
render_group.add_argument("-s", "--steps", type=int, default=50, help="Number of steps")
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
'-s',
|
"-W",
|
||||||
'--steps',
|
"--width",
|
||||||
type=int,
|
type=int,
|
||||||
default=50,
|
help="Image width, multiple of 64",
|
||||||
help='Number of steps'
|
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
'-W',
|
"-H",
|
||||||
'--width',
|
"--height",
|
||||||
type=int,
|
type=int,
|
||||||
help='Image width, multiple of 64',
|
help="Image height, multiple of 64",
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
'-H',
|
"-C",
|
||||||
'--height',
|
"--cfg_scale",
|
||||||
type=int,
|
|
||||||
help='Image height, multiple of 64',
|
|
||||||
)
|
|
||||||
render_group.add_argument(
|
|
||||||
'-C',
|
|
||||||
'--cfg_scale',
|
|
||||||
default=7.5,
|
default=7.5,
|
||||||
type=float,
|
type=float,
|
||||||
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
|
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
'--sampler',
|
"--sampler",
|
||||||
'-A',
|
"-A",
|
||||||
'-m',
|
"-m",
|
||||||
dest='sampler_name',
|
dest="sampler_name",
|
||||||
type=str,
|
type=str,
|
||||||
choices=SAMPLER_CHOICES,
|
choices=SAMPLER_CHOICES,
|
||||||
metavar='SAMPLER_NAME',
|
metavar="SAMPLER_NAME",
|
||||||
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
||||||
default='k_lms',
|
default="k_lms",
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
'--log_tokenization',
|
"--log_tokenization", "-t", action="store_true", help="shows how the prompt is split into tokens"
|
||||||
'-t',
|
|
||||||
action='store_true',
|
|
||||||
help='shows how the prompt is split into tokens'
|
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
'-f',
|
"-f",
|
||||||
'--strength',
|
"--strength",
|
||||||
type=float,
|
type=float,
|
||||||
help='img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
|
help="img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely",
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
'-T',
|
"-T",
|
||||||
'-fit',
|
"-fit",
|
||||||
'--fit',
|
"--fit",
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
|
help="If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
render_group.add_argument("--grid", "-g", action=argparse.BooleanOptionalAction, help="generate a grid")
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
'--grid',
|
"--embedding_directory",
|
||||||
'-g',
|
"--embedding_path",
|
||||||
action=argparse.BooleanOptionalAction,
|
dest="embedding_path",
|
||||||
help='generate a grid'
|
default="embeddings",
|
||||||
)
|
|
||||||
render_group.add_argument(
|
|
||||||
'--embedding_directory',
|
|
||||||
'--embedding_path',
|
|
||||||
dest='embedding_path',
|
|
||||||
default='embeddings',
|
|
||||||
type=str,
|
type=str,
|
||||||
help='Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)'
|
help="Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)",
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
'--lora_directory',
|
"--lora_directory",
|
||||||
dest='lora_path',
|
dest="lora_path",
|
||||||
default='loras',
|
default="loras",
|
||||||
type=str,
|
type=str,
|
||||||
help='Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)'
|
help="Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)",
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
'--embeddings',
|
"--embeddings",
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
default=True,
|
default=True,
|
||||||
help='Enable embedding directory (default). Use --no-embeddings to disable.',
|
help="Enable embedding directory (default). Use --no-embeddings to disable.",
|
||||||
)
|
)
|
||||||
|
render_group.add_argument("--enable_image_debugging", action="store_true", help="Generates debugging image to display")
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
'--enable_image_debugging',
|
"--karras_max",
|
||||||
action='store_true',
|
|
||||||
help='Generates debugging image to display'
|
|
||||||
)
|
|
||||||
render_group.add_argument(
|
|
||||||
'--karras_max',
|
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]."
|
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29].",
|
||||||
)
|
)
|
||||||
# Restoration related args
|
# Restoration related args
|
||||||
postprocessing_group.add_argument(
|
postprocessing_group.add_argument(
|
||||||
'--no_restore',
|
"--no_restore",
|
||||||
dest='restore',
|
dest="restore",
|
||||||
action='store_false',
|
action="store_false",
|
||||||
help='Disable face restoration with GFPGAN or codeformer',
|
help="Disable face restoration with GFPGAN or codeformer",
|
||||||
)
|
)
|
||||||
postprocessing_group.add_argument(
|
postprocessing_group.add_argument(
|
||||||
'--no_upscale',
|
"--no_upscale",
|
||||||
dest='esrgan',
|
dest="esrgan",
|
||||||
action='store_false',
|
action="store_false",
|
||||||
help='Disable upscaling with ESRGAN',
|
help="Disable upscaling with ESRGAN",
|
||||||
)
|
)
|
||||||
postprocessing_group.add_argument(
|
postprocessing_group.add_argument(
|
||||||
'--esrgan_bg_tile',
|
"--esrgan_bg_tile",
|
||||||
type=int,
|
type=int,
|
||||||
default=400,
|
default=400,
|
||||||
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
|
help="Tile size for background sampler, 0 for no tile during testing. Default: 400.",
|
||||||
)
|
)
|
||||||
postprocessing_group.add_argument(
|
postprocessing_group.add_argument(
|
||||||
'--esrgan_denoise_str',
|
"--esrgan_denoise_str",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.75,
|
default=0.75,
|
||||||
help='esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75',
|
help="esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75",
|
||||||
)
|
)
|
||||||
postprocessing_group.add_argument(
|
postprocessing_group.add_argument(
|
||||||
'--gfpgan_model_path',
|
"--gfpgan_model_path",
|
||||||
type=str,
|
type=str,
|
||||||
default='./models/gfpgan/GFPGANv1.4.pth',
|
default="./models/gfpgan/GFPGANv1.4.pth",
|
||||||
help='Indicates the path to the GFPGAN model',
|
help="Indicates the path to the GFPGAN model",
|
||||||
)
|
)
|
||||||
web_server_group.add_argument(
|
web_server_group.add_argument(
|
||||||
'--web',
|
"--web",
|
||||||
dest='web',
|
dest="web",
|
||||||
action='store_true',
|
action="store_true",
|
||||||
help='Start in web server mode.',
|
help="Start in web server mode.",
|
||||||
)
|
)
|
||||||
web_server_group.add_argument(
|
web_server_group.add_argument(
|
||||||
'--web_develop',
|
"--web_develop",
|
||||||
dest='web_develop',
|
dest="web_develop",
|
||||||
action='store_true',
|
action="store_true",
|
||||||
help='Start in web server development mode.',
|
help="Start in web server development mode.",
|
||||||
)
|
)
|
||||||
web_server_group.add_argument(
|
web_server_group.add_argument(
|
||||||
"--web_verbose",
|
"--web_verbose",
|
||||||
@ -376,32 +352,27 @@ web_server_group.add_argument(
|
|||||||
help="Additional allowed origins, comma-separated",
|
help="Additional allowed origins, comma-separated",
|
||||||
)
|
)
|
||||||
web_server_group.add_argument(
|
web_server_group.add_argument(
|
||||||
'--host',
|
"--host",
|
||||||
type=str,
|
type=str,
|
||||||
default='127.0.0.1',
|
default="127.0.0.1",
|
||||||
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
|
help="Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.",
|
||||||
)
|
)
|
||||||
|
web_server_group.add_argument("--port", type=int, default="9090", help="Web server: Port to listen on")
|
||||||
web_server_group.add_argument(
|
web_server_group.add_argument(
|
||||||
'--port',
|
"--certfile",
|
||||||
type=int,
|
|
||||||
default='9090',
|
|
||||||
help='Web server: Port to listen on'
|
|
||||||
)
|
|
||||||
web_server_group.add_argument(
|
|
||||||
'--certfile',
|
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help='Web server: Path to certificate file to use for SSL. Use together with --keyfile'
|
help="Web server: Path to certificate file to use for SSL. Use together with --keyfile",
|
||||||
)
|
)
|
||||||
web_server_group.add_argument(
|
web_server_group.add_argument(
|
||||||
'--keyfile',
|
"--keyfile",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help='Web server: Path to private key file to use for SSL. Use together with --certfile'
|
help="Web server: Path to private key file to use for SSL. Use together with --certfile",
|
||||||
)
|
)
|
||||||
web_server_group.add_argument(
|
web_server_group.add_argument(
|
||||||
'--gui',
|
"--gui",
|
||||||
dest='gui',
|
dest="gui",
|
||||||
action='store_true',
|
action="store_true",
|
||||||
help='Start InvokeAI GUI',
|
help="Start InvokeAI GUI",
|
||||||
)
|
)
|
||||||
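Taken together, the web server flags above cover bind address, port, and optional TLS. A hypothetical launch command, assuming the 2.x invokeai entry point and placeholder certificate files:

    invokeai --web --host 0.0.0.0 --port 9090 --certfile server.crt --keyfile server.key

--certfile and --keyfile only take effect as a pair, as their help text notes.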
|
@ -1,7 +1,7 @@
|
|||||||
'''
|
"""
|
||||||
Migrate the models directory and models.yaml file from an existing
|
Migrate the models directory and models.yaml file from an existing
|
||||||
InvokeAI 2.3 installation to 3.0.0.
|
InvokeAI 2.3 installation to 3.0.0.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
@ -29,14 +29,13 @@ from transformers import (
|
|||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.model_management import ModelManager
|
from invokeai.backend.model_management import ModelManager
|
||||||
from invokeai.backend.model_management.model_probe import (
|
from invokeai.backend.model_management.model_probe import ModelProbe, ModelType, BaseModelType, ModelProbeInfo
|
||||||
ModelProbe, ModelType, BaseModelType, ModelProbeInfo
|
|
||||||
)
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
diffusers.logging.set_verbosity_error()
|
diffusers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
|
||||||
# holder for paths that we will migrate
|
# holder for paths that we will migrate
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelPaths:
|
class ModelPaths:
|
||||||
@ -45,81 +44,82 @@ class ModelPaths:
|
|||||||
loras: Path
|
loras: Path
|
||||||
controlnets: Path
|
controlnets: Path
|
||||||
|
|
||||||
|
|
||||||
class MigrateTo3(object):
|
class MigrateTo3(object):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
from_root: Path,
|
self,
|
||||||
to_models: Path,
|
from_root: Path,
|
||||||
model_manager: ModelManager,
|
to_models: Path,
|
||||||
src_paths: ModelPaths,
|
model_manager: ModelManager,
|
||||||
):
|
src_paths: ModelPaths,
|
||||||
|
):
|
||||||
self.root_directory = from_root
|
self.root_directory = from_root
|
||||||
self.dest_models = to_models
|
self.dest_models = to_models
|
||||||
self.mgr = model_manager
|
self.mgr = model_manager
|
||||||
self.src_paths = src_paths
|
self.src_paths = src_paths
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize_yaml(cls, yaml_file: Path):
|
def initialize_yaml(cls, yaml_file: Path):
|
||||||
with open(yaml_file, 'w') as file:
|
with open(yaml_file, "w") as file:
|
||||||
file.write(
|
file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
||||||
yaml.dump(
|
|
||||||
{
|
|
||||||
'__metadata__': {'version':'3.0.0'}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_directory_structure(self):
|
def create_directory_structure(self):
|
||||||
'''
|
"""
|
||||||
Create the basic directory structure for the models folder.
|
Create the basic directory structure for the models folder.
|
||||||
'''
|
"""
|
||||||
for model_base in [BaseModelType.StableDiffusion1,BaseModelType.StableDiffusion2]:
|
for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||||
for model_type in [ModelType.Main, ModelType.Vae, ModelType.Lora,
|
for model_type in [
|
||||||
ModelType.ControlNet,ModelType.TextualInversion]:
|
ModelType.Main,
|
||||||
|
ModelType.Vae,
|
||||||
|
ModelType.Lora,
|
||||||
|
ModelType.ControlNet,
|
||||||
|
ModelType.TextualInversion,
|
||||||
|
]:
|
||||||
path = self.dest_models / model_base.value / model_type.value
|
path = self.dest_models / model_base.value / model_type.value
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
path = self.dest_models / 'core'
|
path = self.dest_models / "core"
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def copy_file(src:Path,dest:Path):
|
def copy_file(src: Path, dest: Path):
|
||||||
'''
|
"""
|
||||||
copy a single file with logging
|
copy a single file with logging
|
||||||
'''
|
"""
|
||||||
if dest.exists():
|
if dest.exists():
|
||||||
logger.info(f'Skipping existing {str(dest)}')
|
logger.info(f"Skipping existing {str(dest)}")
|
||||||
return
|
return
|
||||||
logger.info(f'Copying {str(src)} to {str(dest)}')
|
logger.info(f"Copying {str(src)} to {str(dest)}")
|
||||||
try:
|
try:
|
||||||
shutil.copy(src, dest)
|
shutil.copy(src, dest)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f'COPY FAILED: {str(e)}')
|
logger.error(f"COPY FAILED: {str(e)}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def copy_dir(src:Path,dest:Path):
|
def copy_dir(src: Path, dest: Path):
|
||||||
'''
|
"""
|
||||||
Recursively copy a directory with logging
|
Recursively copy a directory with logging
|
||||||
'''
|
"""
|
||||||
if dest.exists():
|
if dest.exists():
|
||||||
logger.info(f'Skipping existing {str(dest)}')
|
logger.info(f"Skipping existing {str(dest)}")
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f'Copying {str(src)} to {str(dest)}')
|
logger.info(f"Copying {str(src)} to {str(dest)}")
|
||||||
try:
|
try:
|
||||||
shutil.copytree(src, dest)
|
shutil.copytree(src, dest)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f'COPY FAILED: {str(e)}')
|
logger.error(f"COPY FAILED: {str(e)}")
|
||||||
|
|
||||||
def migrate_models(self, src_dir: Path):
|
def migrate_models(self, src_dir: Path):
|
||||||
'''
|
"""
|
||||||
Recursively walk through src directory, probe anything
|
Recursively walk through src directory, probe anything
|
||||||
that looks like a model, and copy the model into the
|
that looks like a model, and copy the model into the
|
||||||
appropriate location within the destination models directory.
|
appropriate location within the destination models directory.
|
||||||
'''
|
"""
|
||||||
directories_scanned = set()
|
directories_scanned = set()
|
||||||
for root, dirs, files in os.walk(src_dir):
|
for root, dirs, files in os.walk(src_dir):
|
||||||
for d in dirs:
|
for d in dirs:
|
||||||
try:
|
try:
|
||||||
model = Path(root,d)
|
model = Path(root, d)
|
||||||
info = ModelProbe().heuristic_probe(model)
|
info = ModelProbe().heuristic_probe(model)
|
||||||
if not info:
|
if not info:
|
||||||
continue
|
continue
|
||||||
@ -136,9 +136,9 @@ class MigrateTo3(object):
|
|||||||
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
|
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
|
||||||
# let them be copied as part of a tree copy operation
|
# let them be copied as part of a tree copy operation
|
||||||
try:
|
try:
|
||||||
if f in {'learned_embeds.bin','pytorch_lora_weights.bin'}:
|
if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}:
|
||||||
continue
|
continue
|
||||||
model = Path(root,f)
|
model = Path(root, f)
|
||||||
if model.parent in directories_scanned:
|
if model.parent in directories_scanned:
|
||||||
continue
|
continue
|
||||||
info = ModelProbe().heuristic_probe(model)
|
info = ModelProbe().heuristic_probe(model)
|
||||||
@ -154,148 +154,146 @@ class MigrateTo3(object):
|
|||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
|
|
||||||
def migrate_support_models(self):
|
def migrate_support_models(self):
|
||||||
'''
|
"""
|
||||||
Copy the clipseg, upscaler, and restoration models to their new
|
Copy the clipseg, upscaler, and restoration models to their new
|
||||||
locations.
|
locations.
|
||||||
'''
|
"""
|
||||||
dest_directory = self.dest_models
|
dest_directory = self.dest_models
|
||||||
if (self.root_directory / 'models/clipseg').exists():
|
if (self.root_directory / "models/clipseg").exists():
|
||||||
self.copy_dir(self.root_directory / 'models/clipseg', dest_directory / 'core/misc/clipseg')
|
self.copy_dir(self.root_directory / "models/clipseg", dest_directory / "core/misc/clipseg")
|
||||||
if (self.root_directory / 'models/realesrgan').exists():
|
if (self.root_directory / "models/realesrgan").exists():
|
||||||
self.copy_dir(self.root_directory / 'models/realesrgan', dest_directory / 'core/upscaling/realesrgan')
|
self.copy_dir(self.root_directory / "models/realesrgan", dest_directory / "core/upscaling/realesrgan")
|
||||||
for d in ['codeformer','gfpgan']:
|
for d in ["codeformer", "gfpgan"]:
|
||||||
path = self.root_directory / 'models' / d
|
path = self.root_directory / "models" / d
|
||||||
if path.exists():
|
if path.exists():
|
||||||
self.copy_dir(path,dest_directory / f'core/face_restoration/{d}')
|
self.copy_dir(path, dest_directory / f"core/face_restoration/{d}")
|
||||||
|
|
||||||
def migrate_tuning_models(self):
|
def migrate_tuning_models(self):
|
||||||
'''
|
"""
|
||||||
Migrate the embeddings, loras and controlnets directories to their new homes.
|
Migrate the embeddings, loras and controlnets directories to their new homes.
|
||||||
'''
|
"""
|
||||||
for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]:
|
for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]:
|
||||||
if not src:
|
if not src:
|
||||||
continue
|
continue
|
||||||
if src.is_dir():
|
if src.is_dir():
|
||||||
logger.info(f'Scanning {src}')
|
logger.info(f"Scanning {src}")
|
||||||
self.migrate_models(src)
|
self.migrate_models(src)
|
||||||
else:
|
else:
|
||||||
logger.info(f'{src} directory not found; skipping')
|
logger.info(f"{src} directory not found; skipping")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
def migrate_conversion_models(self):
|
def migrate_conversion_models(self):
|
||||||
'''
|
"""
|
||||||
Migrate all the models that are needed by the ckpt_to_diffusers conversion
|
Migrate all the models that are needed by the ckpt_to_diffusers conversion
|
||||||
script.
|
script.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
dest_directory = self.dest_models
|
dest_directory = self.dest_models
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
cache_dir = self.root_directory / 'models/hub',
|
cache_dir=self.root_directory / "models/hub",
|
||||||
#local_files_only = True
|
# local_files_only = True
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
logger.info('Migrating core tokenizers and text encoders')
|
logger.info("Migrating core tokenizers and text encoders")
|
||||||
target_dir = dest_directory / 'core' / 'convert'
|
target_dir = dest_directory / "core" / "convert"
|
||||||
|
|
||||||
self._migrate_pretrained(BertTokenizerFast,
|
self._migrate_pretrained(
|
||||||
repo_id='bert-base-uncased',
|
BertTokenizerFast, repo_id="bert-base-uncased", dest=target_dir / "bert-base-uncased", **kwargs
|
||||||
dest = target_dir / 'bert-base-uncased',
|
)
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
# sd-1
|
# sd-1
|
||||||
repo_id = 'openai/clip-vit-large-patch14'
|
repo_id = "openai/clip-vit-large-patch14"
|
||||||
self._migrate_pretrained(CLIPTokenizer,
|
self._migrate_pretrained(
|
||||||
repo_id= repo_id,
|
CLIPTokenizer, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", **kwargs
|
||||||
dest= target_dir / 'clip-vit-large-patch14',
|
)
|
||||||
**kwargs)
|
self._migrate_pretrained(
|
||||||
self._migrate_pretrained(CLIPTextModel,
|
CLIPTextModel, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", force=True, **kwargs
|
||||||
repo_id = repo_id,
|
)
|
||||||
dest = target_dir / 'clip-vit-large-patch14',
|
|
||||||
force = True,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
# sd-2
|
# sd-2
|
||||||
repo_id = "stabilityai/stable-diffusion-2"
|
repo_id = "stabilityai/stable-diffusion-2"
|
||||||
self._migrate_pretrained(CLIPTokenizer,
|
self._migrate_pretrained(
|
||||||
repo_id = repo_id,
|
CLIPTokenizer,
|
||||||
dest = target_dir / 'stable-diffusion-2-clip' / 'tokenizer',
|
repo_id=repo_id,
|
||||||
**{'subfolder':'tokenizer',**kwargs}
|
dest=target_dir / "stable-diffusion-2-clip" / "tokenizer",
|
||||||
)
|
**{"subfolder": "tokenizer", **kwargs},
|
||||||
self._migrate_pretrained(CLIPTextModel,
|
)
|
||||||
repo_id = repo_id,
|
self._migrate_pretrained(
|
||||||
dest = target_dir / 'stable-diffusion-2-clip' / 'text_encoder',
|
CLIPTextModel,
|
||||||
**{'subfolder':'text_encoder',**kwargs}
|
repo_id=repo_id,
|
||||||
)
|
dest=target_dir / "stable-diffusion-2-clip" / "text_encoder",
|
||||||
|
**{"subfolder": "text_encoder", **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
# VAE
|
# VAE
|
||||||
logger.info('Migrating stable diffusion VAE')
|
logger.info("Migrating stable diffusion VAE")
|
||||||
self._migrate_pretrained(AutoencoderKL,
|
self._migrate_pretrained(
|
||||||
repo_id = 'stabilityai/sd-vae-ft-mse',
|
AutoencoderKL, repo_id="stabilityai/sd-vae-ft-mse", dest=target_dir / "sd-vae-ft-mse", **kwargs
|
||||||
dest = target_dir / 'sd-vae-ft-mse',
|
)
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
# safety checking
|
# safety checking
|
||||||
logger.info('Migrating safety checker')
|
logger.info("Migrating safety checker")
|
||||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
self._migrate_pretrained(AutoFeatureExtractor,
|
self._migrate_pretrained(
|
||||||
repo_id = repo_id,
|
AutoFeatureExtractor, repo_id=repo_id, dest=target_dir / "stable-diffusion-safety-checker", **kwargs
|
||||||
dest = target_dir / 'stable-diffusion-safety-checker',
|
)
|
||||||
**kwargs)
|
self._migrate_pretrained(
|
||||||
self._migrate_pretrained(StableDiffusionSafetyChecker,
|
StableDiffusionSafetyChecker,
|
||||||
repo_id = repo_id,
|
repo_id=repo_id,
|
||||||
dest = target_dir / 'stable-diffusion-safety-checker',
|
dest=target_dir / "stable-diffusion-safety-checker",
|
||||||
**kwargs)
|
**kwargs,
|
||||||
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
|
|
||||||
def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
|
def _model_probe_to_path(self, info: ModelProbeInfo) -> Path:
|
||||||
return Path(self.dest_models, info.base_type.value, info.model_type.value)
|
return Path(self.dest_models, info.base_type.value, info.model_type.value)
|
||||||
|
|
||||||
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force:bool=False, **kwargs):
|
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs):
|
||||||
if dest.exists() and not force:
|
if dest.exists() and not force:
|
||||||
logger.info(f'Skipping existing {dest}')
|
logger.info(f"Skipping existing {dest}")
|
||||||
return
|
return
|
||||||
model = model_class.from_pretrained(repo_id, **kwargs)
|
model = model_class.from_pretrained(repo_id, **kwargs)
|
||||||
self._save_pretrained(model, dest, overwrite=force)
|
self._save_pretrained(model, dest, overwrite=force)
|
||||||
|
|
||||||
def _save_pretrained(self, model, dest: Path, overwrite: bool=False):
|
def _save_pretrained(self, model, dest: Path, overwrite: bool = False):
|
||||||
model_name = dest.name
|
model_name = dest.name
|
||||||
if overwrite:
|
if overwrite:
|
||||||
model.save_pretrained(dest, safe_serialization=True)
|
model.save_pretrained(dest, safe_serialization=True)
|
||||||
else:
|
else:
|
||||||
download_path = dest.with_name(f'{model_name}.downloading')
|
download_path = dest.with_name(f"{model_name}.downloading")
|
||||||
model.save_pretrained(download_path, safe_serialization=True)
|
model.save_pretrained(download_path, safe_serialization=True)
|
||||||
download_path.replace(dest)
|
download_path.replace(dest)
|
||||||
|
|
||||||
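_save_pretrained above avoids leaving half-written models behind: when not overwriting, it saves into a sibling <name>.downloading path and only renames it onto the real destination once the save finishes. The same write-then-rename idea in isolation, with an invented file path:

    from pathlib import Path

    def save_atomically(data: bytes, dest: Path):
        # write to a temporary sibling path first...
        tmp = dest.with_name(f"{dest.name}.downloading")
        tmp.write_bytes(data)
        # ...then swap it into place in a single rename; Path.replace overwrites an existing dest
        tmp.replace(dest)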
def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
|
def _download_vae(self, repo_id: str, subfolder: str = None) -> Path:
|
||||||
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
|
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder)
|
||||||
info = ModelProbe().heuristic_probe(vae)
|
info = ModelProbe().heuristic_probe(vae)
|
||||||
_, model_name = repo_id.split('/')
|
_, model_name = repo_id.split("/")
|
||||||
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
|
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
|
||||||
vae.save_pretrained(dest, safe_serialization=True)
|
vae.save_pretrained(dest, safe_serialization=True)
|
||||||
return dest
|
return dest
|
||||||
|
|
||||||
def _vae_path(self, vae: Union[str,dict])->Path:
|
def _vae_path(self, vae: Union[str, dict]) -> Path:
|
||||||
'''
|
"""
|
||||||
Convert 2.3 VAE stanza to a straight path.
|
Convert 2.3 VAE stanza to a straight path.
|
||||||
'''
|
"""
|
||||||
vae_path = None
|
vae_path = None
|
||||||
|
|
||||||
# First get a path
|
# First get a path
|
||||||
if isinstance(vae,str):
|
if isinstance(vae, str):
|
||||||
vae_path = vae
|
vae_path = vae
|
||||||
|
|
||||||
elif isinstance(vae,DictConfig):
|
elif isinstance(vae, DictConfig):
|
||||||
if p := vae.get('path'):
|
if p := vae.get("path"):
|
||||||
vae_path = p
|
vae_path = p
|
||||||
elif repo_id := vae.get('repo_id'):
|
elif repo_id := vae.get("repo_id"):
|
||||||
if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded
|
if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded
|
||||||
vae_path = 'models/core/convert/sd-vae-ft-mse'
|
vae_path = "models/core/convert/sd-vae-ft-mse"
|
||||||
return vae_path
|
return vae_path
|
||||||
else:
|
else:
|
||||||
vae_path = self._download_vae(repo_id, vae.get('subfolder'))
|
vae_path = self._download_vae(repo_id, vae.get("subfolder"))
|
||||||
|
|
||||||
assert vae_path is not None, "Couldn't find VAE for this model"
|
assert vae_path is not None, "Couldn't find VAE for this model"
|
||||||
|
|
||||||
@ -307,152 +305,144 @@ class MigrateTo3(object):
|
|||||||
dest = self._model_probe_to_path(info) / vae_path.name
|
dest = self._model_probe_to_path(info) / vae_path.name
|
||||||
if not dest.exists():
|
if not dest.exists():
|
||||||
if vae_path.is_dir():
|
if vae_path.is_dir():
|
||||||
self.copy_dir(vae_path,dest)
|
self.copy_dir(vae_path, dest)
|
||||||
else:
|
else:
|
||||||
self.copy_file(vae_path,dest)
|
self.copy_file(vae_path, dest)
|
||||||
vae_path = dest
|
vae_path = dest
|
||||||
|
|
||||||
if vae_path.is_relative_to(self.dest_models):
|
if vae_path.is_relative_to(self.dest_models):
|
||||||
rel_path = vae_path.relative_to(self.dest_models)
|
rel_path = vae_path.relative_to(self.dest_models)
|
||||||
return Path('models',rel_path)
|
return Path("models", rel_path)
|
||||||
else:
|
else:
|
||||||
return vae_path
|
return vae_path
|
||||||
|
|
||||||
def migrate_repo_id(self, repo_id: str, model_name: str=None, **extra_config):
|
def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config):
|
||||||
'''
|
"""
|
||||||
Migrate a locally-cached diffusers pipeline identified with a repo_id
|
Migrate a locally-cached diffusers pipeline identified with a repo_id
|
||||||
'''
|
"""
|
||||||
dest_dir = self.dest_models
|
dest_dir = self.dest_models
|
||||||
|
|
||||||
cache = self.root_directory / 'models/hub'
|
cache = self.root_directory / "models/hub"
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
cache_dir = cache,
|
cache_dir=cache,
|
||||||
safety_checker = None,
|
safety_checker=None,
|
||||||
# local_files_only = True,
|
# local_files_only = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
owner,repo_name = repo_id.split('/')
|
owner, repo_name = repo_id.split("/")
|
||||||
model_name = model_name or repo_name
|
model_name = model_name or repo_name
|
||||||
model = cache / '--'.join(['models',owner,repo_name])
|
model = cache / "--".join(["models", owner, repo_name])
|
||||||
|
|
||||||
if len(list(model.glob('snapshots/**/model_index.json')))==0:
|
if len(list(model.glob("snapshots/**/model_index.json"))) == 0:
|
||||||
return
|
return
|
||||||
revisions = [x.name for x in model.glob('refs/*')]
|
revisions = [x.name for x in model.glob("refs/*")]
|
||||||
|
|
||||||
# if an fp16 is available we use that
|
# if an fp16 is available we use that
|
||||||
revision = 'fp16' if len(revisions) > 1 and 'fp16' in revisions else revisions[0]
|
revision = "fp16" if len(revisions) > 1 and "fp16" in revisions else revisions[0]
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, **kwargs)
|
||||||
repo_id,
|
|
||||||
revision=revision,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
info = ModelProbe().heuristic_probe(pipeline)
|
info = ModelProbe().heuristic_probe(pipeline)
|
||||||
if not info:
|
if not info:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||||
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.')
|
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
|
||||||
return
|
return
|
||||||
|
|
||||||
dest = self._model_probe_to_path(info) / model_name
|
dest = self._model_probe_to_path(info) / model_name
|
||||||
self._save_pretrained(pipeline, dest)
|
self._save_pretrained(pipeline, dest)
|
||||||
|
|
||||||
rel_path = Path('models',dest.relative_to(dest_dir))
|
rel_path = Path("models", dest.relative_to(dest_dir))
|
||||||
self._add_model(model_name, info, rel_path, **extra_config)
|
self._add_model(model_name, info, rel_path, **extra_config)
|
||||||
|
|
||||||
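migrate_repo_id above resolves each pipeline inside the 2.3 Hugging Face cache, whose per-repository folders follow the models--<owner>--<name> naming scheme. A quick worked example of that path construction, using a repo id that appears elsewhere in this script and an assumed root directory:

    from pathlib import Path

    cache = Path("/home/user/invokeai-2.3/models/hub")   # illustrative 2.3 root
    owner, repo_name = "stabilityai/stable-diffusion-2".split("/")
    model = cache / "--".join(["models", owner, repo_name])
    print(model)  # /home/user/invokeai-2.3/models/hub/models--stabilityai--stable-diffusion-2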
def migrate_path(self, location: Path, model_name: str=None, **extra_config):
|
def migrate_path(self, location: Path, model_name: str = None, **extra_config):
|
||||||
'''
|
"""
|
||||||
Migrate a model referred to using 'weights' or 'path'
|
Migrate a model referred to using 'weights' or 'path'
|
||||||
'''
|
"""
|
||||||
|
|
||||||
# handle relative paths
|
# handle relative paths
|
||||||
dest_dir = self.dest_models
|
dest_dir = self.dest_models
|
||||||
location = self.root_directory / location
|
location = self.root_directory / location
|
||||||
model_name = model_name or location.stem
|
model_name = model_name or location.stem
|
||||||
|
|
||||||
info = ModelProbe().heuristic_probe(location)
|
info = ModelProbe().heuristic_probe(location)
|
||||||
if not info:
|
if not info:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||||
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.')
|
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# uh oh, weights is in the old models directory - move it into the new one
|
# uh oh, weights is in the old models directory - move it into the new one
|
||||||
if Path(location).is_relative_to(self.src_paths.models):
|
if Path(location).is_relative_to(self.src_paths.models):
|
||||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
|
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
|
||||||
if location.is_dir():
|
if location.is_dir():
|
||||||
self.copy_dir(location,dest)
|
self.copy_dir(location, dest)
|
||||||
else:
|
else:
|
||||||
self.copy_file(location,dest)
|
self.copy_file(location, dest)
|
||||||
location = Path('models', info.base_type.value, info.model_type.value, location.name)
|
location = Path("models", info.base_type.value, info.model_type.value, location.name)
|
||||||
|
|
||||||
self._add_model(model_name, info, location, **extra_config)
|
self._add_model(model_name, info, location, **extra_config)
|
||||||
|
|
||||||
def _add_model(self,
|
def _add_model(self, model_name: str, info: ModelProbeInfo, location: Path, **extra_config):
|
||||||
model_name: str,
|
|
||||||
info: ModelProbeInfo,
|
|
||||||
location: Path,
|
|
||||||
**extra_config):
|
|
||||||
if info.model_type != ModelType.Main:
|
if info.model_type != ModelType.Main:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.mgr.add_model(
|
|
||||||
model_name = model_name,
|
|
||||||
base_model = info.base_type,
|
|
||||||
model_type = info.model_type,
|
|
||||||
clobber = True,
|
|
||||||
model_attributes = {
|
|
||||||
'path': str(location),
|
|
||||||
'description': f'A {info.base_type.value} {info.model_type.value} model',
|
|
||||||
'model_format': info.format,
|
|
||||||
'variant': info.variant_type.value,
|
|
||||||
**extra_config,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def migrate_defined_models(self):
|
|
||||||
'''
|
|
||||||
Migrate models defined in models.yaml
|
|
||||||
'''
|
|
||||||
# find any models referred to in old models.yaml
|
|
||||||
conf = OmegaConf.load(self.root_directory / 'configs/models.yaml')
|
|
||||||
|
|
||||||
for model_name, stanza in conf.items():
|
|
||||||
|
|
||||||
|
self.mgr.add_model(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=info.base_type,
|
||||||
|
model_type=info.model_type,
|
||||||
|
clobber=True,
|
||||||
|
model_attributes={
|
||||||
|
"path": str(location),
|
||||||
|
"description": f"A {info.base_type.value} {info.model_type.value} model",
|
||||||
|
"model_format": info.format,
|
||||||
|
"variant": info.variant_type.value,
|
||||||
|
**extra_config,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def migrate_defined_models(self):
|
||||||
|
"""
|
||||||
|
Migrate models defined in models.yaml
|
||||||
|
"""
|
||||||
|
# find any models referred to in old models.yaml
|
||||||
|
conf = OmegaConf.load(self.root_directory / "configs/models.yaml")
|
||||||
|
|
||||||
|
for model_name, stanza in conf.items():
|
||||||
try:
|
try:
|
||||||
passthru_args = {}
|
passthru_args = {}
|
||||||
|
|
||||||
if vae := stanza.get('vae'):
|
if vae := stanza.get("vae"):
|
||||||
try:
|
try:
|
||||||
passthru_args['vae'] = str(self._vae_path(vae))
|
passthru_args["vae"] = str(self._vae_path(vae))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
|
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
|
||||||
logger.warning(str(e))
|
logger.warning(str(e))
|
||||||
|
|
||||||
if config := stanza.get('config'):
|
if config := stanza.get("config"):
|
||||||
passthru_args['config'] = config
|
passthru_args["config"] = config
|
||||||
|
|
||||||
if description:= stanza.get('description'):
|
if description := stanza.get("description"):
|
||||||
passthru_args['description'] = description
|
passthru_args["description"] = description
|
||||||
|
|
||||||
if repo_id := stanza.get('repo_id'):
|
if repo_id := stanza.get("repo_id"):
|
||||||
logger.info(f'Migrating diffusers model {model_name}')
|
logger.info(f"Migrating diffusers model {model_name}")
|
||||||
self.migrate_repo_id(repo_id, model_name, **passthru_args)
|
self.migrate_repo_id(repo_id, model_name, **passthru_args)
|
||||||
|
|
||||||
elif location := stanza.get('weights'):
|
elif location := stanza.get("weights"):
|
||||||
logger.info(f'Migrating checkpoint model {model_name}')
|
logger.info(f"Migrating checkpoint model {model_name}")
|
||||||
self.migrate_path(Path(location), model_name, **passthru_args)
|
self.migrate_path(Path(location), model_name, **passthru_args)
|
||||||
|
|
||||||
elif location := stanza.get('path'):
|
elif location := stanza.get("path"):
|
||||||
logger.info(f'Migrating diffusers model {model_name}')
|
logger.info(f"Migrating diffusers model {model_name}")
|
||||||
self.migrate_path(Path(location), model_name, **passthru_args)
|
self.migrate_path(Path(location), model_name, **passthru_args)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
|
|
||||||
def migrate(self):
|
def migrate(self):
|
||||||
self.create_directory_structure()
|
self.create_directory_structure()
|
||||||
# the configure script is doing this
|
# the configure script is doing this
|
||||||
@ -461,67 +451,71 @@ class MigrateTo3(object):
|
|||||||
self.migrate_tuning_models()
|
self.migrate_tuning_models()
|
||||||
self.migrate_defined_models()
|
self.migrate_defined_models()
|
||||||
|
|
||||||
def _parse_legacy_initfile(root: Path, initfile: Path)->ModelPaths:
|
|
||||||
'''
|
def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths:
|
||||||
|
"""
|
||||||
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
||||||
'''
|
"""
|
||||||
parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
|
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--embedding_directory',
|
"--embedding_directory",
|
||||||
'--embedding_path',
|
"--embedding_path",
|
||||||
type=Path,
|
type=Path,
|
||||||
dest='embedding_path',
|
dest="embedding_path",
|
||||||
default=Path('embeddings'),
|
default=Path("embeddings"),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--lora_directory',
|
"--lora_directory",
|
||||||
dest='lora_path',
|
dest="lora_path",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=Path('loras'),
|
default=Path("loras"),
|
||||||
)
|
)
|
||||||
opt,_ = parser.parse_known_args([f'@{str(initfile)}'])
|
opt, _ = parser.parse_known_args([f"@{str(initfile)}"])
|
||||||
return ModelPaths(
|
return ModelPaths(
|
||||||
models = root / 'models',
|
models=root / "models",
|
||||||
embeddings = root / str(opt.embedding_path).strip('"'),
|
embeddings=root / str(opt.embedding_path).strip('"'),
|
||||||
loras = root / str(opt.lora_path).strip('"'),
|
loras=root / str(opt.lora_path).strip('"'),
|
||||||
controlnets = root / 'controlnets',
|
controlnets=root / "controlnets",
|
||||||
)
|
)
|
||||||
|
|
||||||
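_parse_legacy_initfile works because a 2.3 invokeai.init file is simply a list of command-line flags, one per line, which argparse can ingest through its @file syntax. A self-contained sketch of the same trick; the file contents and directory names are made up for illustration:

    import argparse
    from pathlib import Path

    Path("invokeai.init").write_text("--embedding_path=my_embeddings\n--lora_directory=my_loras\n")

    parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
    parser.add_argument("--embedding_directory", "--embedding_path", type=Path, dest="embedding_path", default=Path("embeddings"))
    parser.add_argument("--lora_directory", dest="lora_path", type=Path, default=Path("loras"))

    opt, _ = parser.parse_known_args(["@invokeai.init"])
    print(opt.embedding_path, opt.lora_path)  # my_embeddings my_loras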
def _parse_legacy_yamlfile(root: Path, initfile: Path)->ModelPaths:
|
|
||||||
'''
|
def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths:
|
||||||
|
"""
|
||||||
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
||||||
'''
|
"""
|
||||||
# Don't use the config object because it is unforgiving of version updates
|
# Don't use the config object because it is unforgiving of version updates
|
||||||
# Just use omegaconf directly
|
# Just use omegaconf directly
|
||||||
opt = OmegaConf.load(initfile)
|
opt = OmegaConf.load(initfile)
|
||||||
paths = opt.InvokeAI.Paths
|
paths = opt.InvokeAI.Paths
|
||||||
models = paths.get('models_dir','models')
|
models = paths.get("models_dir", "models")
|
||||||
embeddings = paths.get('embedding_dir','embeddings')
|
embeddings = paths.get("embedding_dir", "embeddings")
|
||||||
loras = paths.get('lora_dir','loras')
|
loras = paths.get("lora_dir", "loras")
|
||||||
controlnets = paths.get('controlnet_dir','controlnets')
|
controlnets = paths.get("controlnet_dir", "controlnets")
|
||||||
return ModelPaths(
|
return ModelPaths(
|
||||||
models = root / models,
|
models=root / models,
|
||||||
embeddings = root / embeddings,
|
embeddings=root / embeddings,
|
||||||
loras = root /loras,
|
loras=root / loras,
|
||||||
controlnets = root / controlnets,
|
controlnets=root / controlnets,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
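_parse_legacy_yamlfile deliberately reads the old invokeai.yaml with OmegaConf instead of the stricter config object, so version drift in unrelated keys cannot break the migration. A sketch of the structure it expects; the literal directory values are only examples:

    from omegaconf import OmegaConf

    opt = OmegaConf.create(
        """
        InvokeAI:
          Paths:
            models_dir: models
            embedding_dir: embeddings
            lora_dir: loras
            controlnet_dir: controlnets
        """
    )
    paths = opt.InvokeAI.Paths
    print(paths.get("models_dir", "models"))  # models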
def get_legacy_embeddings(root: Path) -> ModelPaths:
|
def get_legacy_embeddings(root: Path) -> ModelPaths:
|
||||||
path = root / 'invokeai.init'
|
path = root / "invokeai.init"
|
||||||
if path.exists():
|
if path.exists():
|
||||||
return _parse_legacy_initfile(root, path)
|
return _parse_legacy_initfile(root, path)
|
||||||
path = root / 'invokeai.yaml'
|
path = root / "invokeai.yaml"
|
||||||
if path.exists():
|
if path.exists():
|
||||||
return _parse_legacy_yamlfile(root, path)
|
return _parse_legacy_yamlfile(root, path)
|
||||||
|
|
||||||
|
|
||||||
def do_migrate(src_directory: Path, dest_directory: Path):
|
def do_migrate(src_directory: Path, dest_directory: Path):
|
||||||
"""
|
"""
|
||||||
Migrate models from src to dest InvokeAI root directories
|
Migrate models from src to dest InvokeAI root directories
|
||||||
"""
|
"""
|
||||||
config_file = dest_directory / 'configs' / 'models.yaml.3'
|
config_file = dest_directory / "configs" / "models.yaml.3"
|
||||||
dest_models = dest_directory / 'models.3'
|
dest_models = dest_directory / "models.3"
|
||||||
|
|
||||||
version_3 = (dest_directory / 'models' / 'core').exists()
|
version_3 = (dest_directory / "models" / "core").exists()
|
||||||
|
|
||||||
# Here we create the destination models.yaml file.
|
# Here we create the destination models.yaml file.
|
||||||
# If we are writing into a version 3 directory and the
|
# If we are writing into a version 3 directory and the
|
||||||
@ -530,80 +524,80 @@ def do_migrate(src_directory: Path, dest_directory: Path):
|
|||||||
# create a new empty one.
|
# create a new empty one.
|
||||||
if version_3: # write into the dest directory
|
if version_3: # write into the dest directory
|
||||||
try:
|
try:
|
||||||
shutil.copy(dest_directory / 'configs' / 'models.yaml', config_file)
|
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
|
||||||
except:
|
except:
|
||||||
MigrateTo3.initialize_yaml(config_file)
|
MigrateTo3.initialize_yaml(config_file)
|
||||||
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
||||||
(dest_directory / 'models').replace(dest_models)
|
(dest_directory / "models").replace(dest_models)
|
||||||
else:
|
else:
|
||||||
MigrateTo3.initialize_yaml(config_file)
|
MigrateTo3.initialize_yaml(config_file)
|
||||||
mgr = ModelManager(config_file)
|
mgr = ModelManager(config_file)
|
||||||
|
|
||||||
paths = get_legacy_embeddings(src_directory)
|
paths = get_legacy_embeddings(src_directory)
|
||||||
migrator = MigrateTo3(
|
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths)
|
||||||
from_root = src_directory,
|
|
||||||
to_models = dest_models,
|
|
||||||
model_manager = mgr,
|
|
||||||
src_paths = paths
|
|
||||||
)
|
|
||||||
migrator.migrate()
|
migrator.migrate()
|
||||||
print("Migration successful.")
|
print("Migration successful.")
|
||||||
|
|
||||||
if not version_3:
|
if not version_3:
|
||||||
(dest_directory / 'models').replace(src_directory / 'models.orig')
|
(dest_directory / "models").replace(src_directory / "models.orig")
|
||||||
print(f'Original models directory moved to {dest_directory}/models.orig')
|
print(f"Original models directory moved to {dest_directory}/models.orig")
|
||||||
|
|
||||||
(dest_directory / 'configs' / 'models.yaml').replace(src_directory / 'configs' / 'models.yaml.orig')
|
(dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig")
|
||||||
print(f'Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig')
|
print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig")
|
||||||
|
|
||||||
config_file.replace(config_file.with_suffix(''))
|
config_file.replace(config_file.with_suffix(""))
|
||||||
dest_models.replace(dest_models.with_suffix(''))
|
dest_models.replace(dest_models.with_suffix(""))
|
||||||
|
|
||||||
|
|
||||||
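do_migrate stages its output under throwaway names (configs/models.yaml.3 and models.3) and only renames them into place at the very end. Path.with_suffix("") strips just the trailing .3, which is what makes the final swap above work; a two-line check of that pathlib behaviour:

    from pathlib import Path

    print(Path("configs/models.yaml.3").with_suffix(""))  # configs/models.yaml
    print(Path("models.3").with_suffix(""))               # models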
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(prog="invokeai-migrate3",
|
parser = argparse.ArgumentParser(
|
||||||
description="""
|
prog="invokeai-migrate3",
|
||||||
|
description="""
|
||||||
This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
|
This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
|
||||||
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.
|
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.
|
||||||
|
|
||||||
The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.
|
The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.
|
||||||
It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure
|
It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure
|
||||||
script, which will perform a full upgrade in place."""
|
script, which will perform a full upgrade in place.""",
|
||||||
)
|
)
|
||||||
parser.add_argument('--from-directory',
|
parser.add_argument(
|
||||||
dest='src_root',
|
"--from-directory",
|
||||||
type=Path,
|
dest="src_root",
|
||||||
required=True,
|
type=Path,
|
||||||
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")'
|
required=True,
|
||||||
)
|
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")',
|
||||||
parser.add_argument('--to-directory',
|
)
|
||||||
dest='dest_root',
|
parser.add_argument(
|
||||||
type=Path,
|
"--to-directory",
|
||||||
required=True,
|
dest="dest_root",
|
||||||
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")'
|
type=Path,
|
||||||
)
|
required=True,
|
||||||
|
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")',
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
src_root = args.src_root
|
src_root = args.src_root
|
||||||
assert src_root.is_dir(), f"{src_root} is not a valid directory"
|
assert src_root.is_dir(), f"{src_root} is not a valid directory"
|
||||||
assert (src_root / 'models').is_dir(), f"{src_root} does not contain a 'models' subdirectory"
|
assert (src_root / "models").is_dir(), f"{src_root} does not contain a 'models' subdirectory"
|
||||||
assert (src_root / 'models' / 'hub').exists(), f"{src_root} does not contain a version 2.3 models directory"
|
assert (src_root / "models" / "hub").exists(), f"{src_root} does not contain a version 2.3 models directory"
|
||||||
assert (src_root / 'invokeai.init').exists() or (src_root / 'invokeai.yaml').exists(), f"{src_root} does not contain an InvokeAI init file."
|
assert (src_root / "invokeai.init").exists() or (
|
||||||
|
src_root / "invokeai.yaml"
|
||||||
|
).exists(), f"{src_root} does not contain an InvokeAI init file."
|
||||||
|
|
||||||
dest_root = args.dest_root
|
dest_root = args.dest_root
|
||||||
assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
|
assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
config.parse_args(['--root',str(dest_root)])
|
config.parse_args(["--root", str(dest_root)])
|
||||||
|
|
||||||
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
|
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
|
||||||
dest_is_setup = (dest_root / 'models/core').exists() and (dest_root / 'databases').exists()
|
dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
|
||||||
if not dest_is_setup:
|
if not dest_is_setup:
|
||||||
import invokeai.frontend.install.invokeai_configure
|
import invokeai.frontend.install.invokeai_configure
|
||||||
from invokeai.backend.install.invokeai_configure import initialize_rootdir
|
from invokeai.backend.install.invokeai_configure import initialize_rootdir
|
||||||
|
|
||||||
initialize_rootdir(dest_root, True)
|
initialize_rootdir(dest_root, True)
|
||||||
|
|
||||||
do_migrate(src_root,dest_root)
|
do_migrate(src_root, dest_root)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
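As the parser above spells out, the migration script is exposed as invokeai-migrate3 and requires both root directories. A hypothetical invocation, with placeholder directory names:

    invokeai-migrate3 --from-directory ~/invokeai-2.3 --to-directory ~/invokeai-3.0

Per the description above, pointing both options at the same directory is allowed, but running invokeai_configure is the preferred route for an in-place upgrade.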
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ Utility (backend) functions used by model_install.py
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass,field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import List, Dict, Callable, Union, Set
|
from typing import List, Dict, Callable, Union, Set
|
||||||
@ -28,7 +28,7 @@ warnings.filterwarnings("ignore")
|
|||||||
|
|
||||||
# --------------------------globals-----------------------
|
# --------------------------globals-----------------------
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
logger = InvokeAILogger.getLogger(name='InvokeAI')
|
logger = InvokeAILogger.getLogger(name="InvokeAI")
|
||||||
|
|
||||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||||
@ -45,59 +45,63 @@ Config_preamble = """
|
|||||||
|
|
||||||
LEGACY_CONFIGS = {
|
LEGACY_CONFIGS = {
|
||||||
BaseModelType.StableDiffusion1: {
|
BaseModelType.StableDiffusion1: {
|
||||||
ModelVariantType.Normal: 'v1-inference.yaml',
|
ModelVariantType.Normal: "v1-inference.yaml",
|
||||||
ModelVariantType.Inpaint: 'v1-inpainting-inference.yaml',
|
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
||||||
},
|
},
|
||||||
|
|
||||||
BaseModelType.StableDiffusion2: {
|
BaseModelType.StableDiffusion2: {
|
||||||
ModelVariantType.Normal: {
|
ModelVariantType.Normal: {
|
||||||
SchedulerPredictionType.Epsilon: 'v2-inference.yaml',
|
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
|
||||||
SchedulerPredictionType.VPrediction: 'v2-inference-v.yaml',
|
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
|
||||||
},
|
},
|
||||||
ModelVariantType.Inpaint: {
|
ModelVariantType.Inpaint: {
|
||||||
SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml',
|
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
|
||||||
SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml',
|
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
BaseModelType.StableDiffusionXL: {
|
BaseModelType.StableDiffusionXL: {
|
||||||
ModelVariantType.Normal: 'sd_xl_base.yaml',
|
ModelVariantType.Normal: "sd_xl_base.yaml",
|
||||||
},
|
},
|
||||||
|
|
||||||
BaseModelType.StableDiffusionXLRefiner: {
|
BaseModelType.StableDiffusionXLRefiner: {
|
||||||
ModelVariantType.Normal: 'sd_xl_refiner.yaml',
|
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
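LEGACY_CONFIGS above is a nested lookup table: Stable Diffusion 1.x variants map directly to a config filename, while 2.x variants add one more level keyed by the scheduler prediction type. Two illustrative lookups, relying only on the names defined above:

    # SD-2, normal variant, v-prediction scheduler
    print(LEGACY_CONFIGS[BaseModelType.StableDiffusion2][ModelVariantType.Normal][SchedulerPredictionType.VPrediction])
    # -> v2-inference-v.yaml

    # SD-1 inpainting needs no prediction-type key
    print(LEGACY_CONFIGS[BaseModelType.StableDiffusion1][ModelVariantType.Inpaint])
    # -> v1-inpainting-inference.yaml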
@dataclass
|
@dataclass
|
||||||
class ModelInstallList:
|
class ModelInstallList:
|
||||||
'''Class for listing models to be installed/removed'''
|
"""Class for listing models to be installed/removed"""
|
||||||
|
|
||||||
install_models: List[str] = field(default_factory=list)
|
install_models: List[str] = field(default_factory=list)
|
||||||
remove_models: List[str] = field(default_factory=list)
|
remove_models: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class InstallSelections():
|
|
||||||
install_models: List[str]= field(default_factory=list)
|
|
||||||
remove_models: List[str]=field(default_factory=list)
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelLoadInfo():
|
class InstallSelections:
|
||||||
|
install_models: List[str] = field(default_factory=list)
|
||||||
|
remove_models: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelLoadInfo:
|
||||||
name: str
|
name: str
|
||||||
model_type: ModelType
|
model_type: ModelType
|
||||||
base_type: BaseModelType
|
base_type: BaseModelType
|
||||||
path: Path = None
|
path: Path = None
|
||||||
repo_id: str = None
|
repo_id: str = None
|
||||||
description: str = ''
|
description: str = ""
|
||||||
installed: bool = False
|
installed: bool = False
|
||||||
recommended: bool = False
|
recommended: bool = False
|
||||||
default: bool = False
|
default: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ModelInstall(object):
|
class ModelInstall(object):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
config:InvokeAIAppConfig,
|
self,
|
||||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
config: InvokeAIAppConfig,
|
||||||
model_manager: ModelManager = None,
|
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
||||||
access_token:str = None):
|
model_manager: ModelManager = None,
|
||||||
|
access_token: str = None,
|
||||||
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
||||||
self.datasets = OmegaConf.load(Dataset_path)
|
self.datasets = OmegaConf.load(Dataset_path)
|
||||||
@ -105,66 +109,66 @@ class ModelInstall(object):
|
|||||||
self.access_token = access_token or HfFolder.get_token()
|
self.access_token = access_token or HfFolder.get_token()
|
||||||
self.reverse_paths = self._reverse_paths(self.datasets)
|
self.reverse_paths = self._reverse_paths(self.datasets)
|
||||||
|
|
||||||
def all_models(self)->Dict[str,ModelLoadInfo]:
|
def all_models(self) -> Dict[str, ModelLoadInfo]:
|
||||||
'''
|
"""
|
||||||
Return dict of model_key=>ModelLoadInfo objects.
|
Return dict of model_key=>ModelLoadInfo objects.
|
||||||
This method consolidates and simplifies the entries in both
|
This method consolidates and simplifies the entries in both
|
||||||
models.yaml and INITIAL_MODELS.yaml so that they can
|
models.yaml and INITIAL_MODELS.yaml so that they can
|
||||||
be treated uniformly. It also sorts the models alphabetically
|
be treated uniformly. It also sorts the models alphabetically
|
||||||
by their name, to improve the display somewhat.
|
by their name, to improve the display somewhat.
|
||||||
'''
|
"""
|
||||||
model_dict = dict()
|
model_dict = dict()
|
||||||
|
|
||||||
# first populate with the entries in INITIAL_MODELS.yaml
|
# first populate with the entries in INITIAL_MODELS.yaml
|
||||||
for key, value in self.datasets.items():
|
for key, value in self.datasets.items():
|
||||||
name,base,model_type = ModelManager.parse_key(key)
|
name, base, model_type = ModelManager.parse_key(key)
|
||||||
value['name'] = name
|
value["name"] = name
|
||||||
value['base_type'] = base
|
value["base_type"] = base
|
||||||
value['model_type'] = model_type
|
value["model_type"] = model_type
|
||||||
model_dict[key] = ModelLoadInfo(**value)
|
model_dict[key] = ModelLoadInfo(**value)
|
||||||
|
|
||||||
# supplement with entries in models.yaml
|
# supplement with entries in models.yaml
|
||||||
installed_models = self.mgr.list_models()
|
installed_models = self.mgr.list_models()
|
||||||
|
|
||||||
for md in installed_models:
|
for md in installed_models:
|
||||||
base = md['base_model']
|
base = md["base_model"]
|
||||||
model_type = md['model_type']
|
model_type = md["model_type"]
|
||||||
name = md['model_name']
|
name = md["model_name"]
|
||||||
key = ModelManager.create_key(name, base, model_type)
|
key = ModelManager.create_key(name, base, model_type)
|
||||||
if key in model_dict:
|
if key in model_dict:
|
||||||
model_dict[key].installed = True
|
model_dict[key].installed = True
|
||||||
else:
|
else:
|
||||||
model_dict[key] = ModelLoadInfo(
|
model_dict[key] = ModelLoadInfo(
|
||||||
name = name,
|
name=name,
|
||||||
base_type = base,
|
base_type=base,
|
||||||
model_type = model_type,
|
model_type=model_type,
|
||||||
path = value.get('path'),
|
path=value.get("path"),
|
||||||
installed = True,
|
installed=True,
|
||||||
)
|
)
|
||||||
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
|
return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}
|
||||||
|
|
||||||
def list_models(self, model_type):
|
def list_models(self, model_type):
|
||||||
installed = self.mgr.list_models(model_type=model_type)
|
installed = self.mgr.list_models(model_type=model_type)
|
||||||
print(f'Installed models of type `{model_type}`:')
|
print(f"Installed models of type `{model_type}`:")
|
||||||
for i in installed:
|
for i in installed:
|
||||||
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
|
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
|
||||||
|
|
||||||
# logic here a little reversed to maintain backward compatibility
|
# logic here a little reversed to maintain backward compatibility
|
||||||
def starter_models(self, all_models: bool=False)->Set[str]:
|
def starter_models(self, all_models: bool = False) -> Set[str]:
|
||||||
models = set()
|
models = set()
|
||||||
for key, value in self.datasets.items():
|
for key, value in self.datasets.items():
|
||||||
name,base,model_type = ModelManager.parse_key(key)
|
name, base, model_type = ModelManager.parse_key(key)
|
||||||
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
|
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
|
||||||
models.add(key)
|
models.add(key)
|
||||||
return models
|
return models
|
||||||
|
|
||||||
def recommended_models(self)->Set[str]:
|
def recommended_models(self) -> Set[str]:
|
||||||
starters = self.starter_models(all_models=True)
|
starters = self.starter_models(all_models=True)
|
||||||
return set([x for x in starters if self.datasets[x].get('recommended',False)])
|
return set([x for x in starters if self.datasets[x].get("recommended", False)])
|
||||||
|
|
||||||
def default_model(self)->str:
|
def default_model(self) -> str:
|
||||||
starters = self.starter_models()
|
starters = self.starter_models()
|
||||||
defaults = [x for x in starters if self.datasets[x].get('default',False)]
|
defaults = [x for x in starters if self.datasets[x].get("default", False)]
|
||||||
return defaults[0]
|
return defaults[0]
|
||||||
|
|
||||||
def install(self, selections: InstallSelections):
|
def install(self, selections: InstallSelections):
|
||||||
@ -173,54 +177,57 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
job = 1
|
job = 1
|
||||||
jobs = len(selections.remove_models) + len(selections.install_models)
|
jobs = len(selections.remove_models) + len(selections.install_models)
|
||||||
|
|
||||||
# remove requested models
|
# remove requested models
|
||||||
for key in selections.remove_models:
|
for key in selections.remove_models:
|
||||||
name,base,mtype = self.mgr.parse_key(key)
|
name, base, mtype = self.mgr.parse_key(key)
|
||||||
logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]')
|
logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]")
|
||||||
try:
|
try:
|
||||||
self.mgr.del_model(name,base,mtype)
|
self.mgr.del_model(name, base, mtype)
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
logger.warning(e)
|
logger.warning(e)
|
||||||
job += 1
|
job += 1
|
||||||
|
|
||||||
# add requested models
|
# add requested models
|
||||||
for path in selections.install_models:
|
for path in selections.install_models:
|
||||||
logger.info(f'Installing {path} [{job}/{jobs}]')
|
logger.info(f"Installing {path} [{job}/{jobs}]")
|
||||||
try:
|
try:
|
||||||
self.heuristic_import(path)
|
self.heuristic_import(path)
|
||||||
except (ValueError, KeyError) as e:
|
except (ValueError, KeyError) as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
job += 1
|
job += 1
|
||||||
|
|
||||||
dlogging.set_verbosity(verbosity)
|
dlogging.set_verbosity(verbosity)
|
||||||
self.mgr.commit()
|
self.mgr.commit()
|
||||||
|
|
||||||
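install() above drives everything from an InstallSelections object: requested removals are deleted first, then each requested model is routed through heuristic_import, and the manager is committed at the end. A hedged sketch of calling it; the configuration setup and the repo id are illustrative only:

    config = InvokeAIAppConfig.get_config()      # same accessor used at module level above
    installer = ModelInstall(config)
    selections = InstallSelections(
        install_models=["stabilityai/stable-diffusion-2"],  # example Hugging Face repo id
        remove_models=[],
    )
    installer.install(selections)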
def heuristic_import(self,
|
def heuristic_import(
|
||||||
model_path_id_or_url: Union[str,Path],
|
self,
|
||||||
models_installed: Set[Path]=None,
|
model_path_id_or_url: Union[str, Path],
|
||||||
)->Dict[str, AddModelResult]:
|
models_installed: Set[Path] = None,
|
||||||
'''
|
) -> Dict[str, AddModelResult]:
|
||||||
|
"""
|
||||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||||
:param models_installed: Set of installed models, used for recursive invocation
|
:param models_installed: Set of installed models, used for recursive invocation
|
||||||
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
if not models_installed:
|
if not models_installed:
|
||||||
models_installed = dict()
|
models_installed = dict()
|
||||||
|
|
||||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||||
self.current_id = model_path_id_or_url
|
self.current_id = model_path_id_or_url
|
||||||
path = Path(model_path_id_or_url)
|
path = Path(model_path_id_or_url)
|
||||||
# checkpoint file, or similar
|
# checkpoint file, or similar
|
||||||
if path.is_file():
|
if path.is_file():
|
||||||
models_installed.update({str(path):self._install_path(path)})
|
models_installed.update({str(path): self._install_path(path)})
# folders style or similar
|
# folders style or similar
|
||||||
elif path.is_dir() and any([(path/x).exists() for x in \
|
elif path.is_dir() and any(
|
||||||
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
[
|
||||||
]
|
(path / x).exists()
|
||||||
):
|
for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
|
||||||
|
]
|
||||||
|
):
|
||||||
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
|
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
# recursive scan
|
# recursive scan
|
||||||
@ -229,7 +236,7 @@ class ModelInstall(object):
|
|||||||
self.heuristic_import(child, models_installed=models_installed)
|
self.heuristic_import(child, models_installed=models_installed)
# huggingface repo
|
# huggingface repo
|
||||||
elif len(str(model_path_id_or_url).split('/')) == 2:
|
elif len(str(model_path_id_or_url).split("/")) == 2:
|
||||||
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
# a URL
|
# a URL
|
||||||
@ -237,42 +244,43 @@ class ModelInstall(object):
|
|||||||
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
||||||
|
|
||||||
else:
else:
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping")

return models_installed
return models_installed

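For orientation (not part of this commit), a minimal sketch of how the dispatch above is typically exercised; `installer` and the sample paths and IDs are placeholders:

# Illustrative only: heuristic_import() accepts a checkpoint file, a
# diffusers-style folder, a "<owner>/<name>" repo_id, or a URL, and returns
# a dict of AddModelResult entries keyed by the request string.
from pathlib import Path

results = {}
results.update(installer.heuristic_import(Path("models/foo.safetensors")))        # local checkpoint
results.update(installer.heuristic_import(Path("models/my-diffusers-model")))     # folder with model_index.json
results.update(installer.heuristic_import("runwayml/stable-diffusion-v1-5"))      # HuggingFace repo_id
results.update(installer.heuristic_import("https://example.com/model.ckpt"))      # plain URL
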
# install a model from a local path. The optional info parameter is there to prevent
|
# install a model from a local path. The optional info parameter is there to prevent
|
||||||
# the model from being probed twice in the event that it has already been probed.
|
# the model from being probed twice in the event that it has already been probed.
|
||||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
|
def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult:
|
||||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
info = info or ModelProbe().heuristic_probe(path, self.prediction_helper)
|
||||||
if not info:
|
if not info:
|
||||||
logger.warning(f'Unable to parse format of {path}')
|
logger.warning(f"Unable to parse format of {path}")
|
||||||
return None
|
return None
|
||||||
model_name = path.stem if path.is_file() else path.name
|
model_name = path.stem if path.is_file() else path.name
|
||||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||||
attributes = self._make_attributes(path,info)
|
attributes = self._make_attributes(path, info)
|
||||||
return self.mgr.add_model(model_name = model_name,
|
return self.mgr.add_model(
|
||||||
base_model = info.base_type,
|
model_name=model_name,
|
||||||
model_type = info.model_type,
|
base_model=info.base_type,
|
||||||
model_attributes = attributes,
|
model_type=info.model_type,
|
||||||
)
|
model_attributes=attributes,
|
||||||
|
)
|
||||||
|
|
||||||
def _install_url(self, url: str)->AddModelResult:
|
def _install_url(self, url: str) -> AddModelResult:
|
||||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||||
location = download_with_resume(url,Path(staging))
|
location = download_with_resume(url, Path(staging))
|
||||||
if not location:
|
if not location:
|
||||||
logger.error(f'Unable to download {url}. Skipping.')
|
logger.error(f"Unable to download {url}. Skipping.")
|
||||||
info = ModelProbe().heuristic_probe(location)
|
info = ModelProbe().heuristic_probe(location)
|
||||||
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
|
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
|
||||||
models_path = shutil.move(location,dest)
|
models_path = shutil.move(location, dest)
# staged version will be garbage-collected at this time
|
# staged version will be garbage-collected at this time
|
||||||
return self._install_path(Path(models_path), info)
|
return self._install_path(Path(models_path), info)
|
||||||
|
|
||||||
def _install_repo(self, repo_id: str)->AddModelResult:
|
def _install_repo(self, repo_id: str) -> AddModelResult:
|
||||||
hinfo = HfApi().model_info(repo_id)
|
hinfo = HfApi().model_info(repo_id)
# we try to figure out how to download this most economically
|
# we try to figure out how to download this most economically
|
||||||
# list all the files in the repo
|
# list all the files in the repo
|
||||||
files = [x.rfilename for x in hinfo.siblings]
|
files = [x.rfilename for x in hinfo.siblings]
|
||||||
@ -280,42 +288,49 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||||
staging = Path(staging)
|
staging = Path(staging)
|
||||||
if 'model_index.json' in files:
|
if "model_index.json" in files:
|
||||||
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
||||||
else:
|
else:
|
||||||
for suffix in ['safetensors','bin']:
|
for suffix in ["safetensors", "bin"]:
|
||||||
if f'pytorch_lora_weights.{suffix}' in files:
|
if f"pytorch_lora_weights.{suffix}" in files:
|
||||||
location = self._download_hf_model(repo_id, ['pytorch_lora_weights.bin'], staging) # LoRA
|
location = self._download_hf_model(repo_id, ["pytorch_lora_weights.bin"], staging) # LoRA
|
||||||
break
|
break
|
||||||
elif self.config.precision=='float16' and f'diffusion_pytorch_model.fp16.{suffix}' in files: # vae, controlnet or some other standalone
|
elif (
|
||||||
files = ['config.json', f'diffusion_pytorch_model.fp16.{suffix}']
|
self.config.precision == "float16" and f"diffusion_pytorch_model.fp16.{suffix}" in files
|
||||||
|
): # vae, controlnet or some other standalone
|
||||||
|
files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
|
||||||
location = self._download_hf_model(repo_id, files, staging)
|
location = self._download_hf_model(repo_id, files, staging)
|
||||||
break
|
break
|
||||||
elif f'diffusion_pytorch_model.{suffix}' in files:
|
elif f"diffusion_pytorch_model.{suffix}" in files:
|
||||||
files = ['config.json', f'diffusion_pytorch_model.{suffix}']
|
files = ["config.json", f"diffusion_pytorch_model.{suffix}"]
|
||||||
location = self._download_hf_model(repo_id, files, staging)
|
location = self._download_hf_model(repo_id, files, staging)
|
||||||
break
|
break
|
||||||
elif f'learned_embeds.{suffix}' in files:
|
elif f"learned_embeds.{suffix}" in files:
|
||||||
location = self._download_hf_model(repo_id, [f'learned_embeds.{suffix}'], staging)
|
location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
|
||||||
break
|
break
|
||||||
if not location:
|
if not location:
|
||||||
logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.')
|
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
|
||||||
return {}
|
return {}
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
||||||
if not info:
|
if not info:
|
||||||
logger.warning(f'Could not probe {location}. Skipping install.')
|
logger.warning(f"Could not probe {location}. Skipping install.")
|
||||||
return {}
|
return {}
|
||||||
dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location)
|
dest = (
|
||||||
|
self.config.models_path
|
||||||
|
/ info.base_type.value
|
||||||
|
/ info.model_type.value
|
||||||
|
/ self._get_model_name(repo_id, location)
|
||||||
|
)
|
||||||
if dest.exists():
if dest.exists():
shutil.rmtree(dest)
shutil.rmtree(dest)
shutil.copytree(location,dest)
shutil.copytree(location, dest)
return self._install_path(dest, info)
return self._install_path(dest, info)

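As a rough, hedged illustration of the "download economically" idea used in _install_repo above, the repo's file listing can be inspected first and only the relevant files selected; the repo id here is an example:

# Illustrative only: mirror the suffix checks above by listing the repo's
# files and picking the smallest useful subset.
from huggingface_hub import HfApi

files = [s.rfilename for s in HfApi().model_info("runwayml/stable-diffusion-v1-5").siblings]
if "model_index.json" in files:
    wanted = files                                        # full diffusers pipeline
elif "pytorch_lora_weights.safetensors" in files:
    wanted = ["pytorch_lora_weights.safetensors"]         # LoRA
elif "learned_embeds.safetensors" in files:
    wanted = ["learned_embeds.safetensors"]               # textual inversion
else:
    wanted = [f for f in files if f.endswith((".safetensors", ".bin"))]
print(wanted)
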
def _get_model_name(self,path_name: str, location: Path)->str:
|
def _get_model_name(self, path_name: str, location: Path) -> str:
|
||||||
'''
|
"""
|
||||||
Calculate a name for the model - primitive implementation.
|
Calculate a name for the model - primitive implementation.
|
||||||
'''
|
"""
|
||||||
if key := self.reverse_paths.get(path_name):
|
if key := self.reverse_paths.get(path_name):
|
||||||
(name, base, mtype) = ModelManager.parse_key(key)
|
(name, base, mtype) = ModelManager.parse_key(key)
|
||||||
return name
|
return name
|
||||||
@ -324,99 +339,103 @@ class ModelInstall(object):
|
|||||||
else:
|
else:
|
||||||
return location.stem
|
return location.stem
|
||||||
|
|
||||||
def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict:
|
def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict:
|
||||||
model_name = path.name if path.is_dir() else path.stem
|
model_name = path.name if path.is_dir() else path.stem
|
||||||
description = f'{info.base_type.value} {info.model_type.value} model {model_name}'
|
description = f"{info.base_type.value} {info.model_type.value} model {model_name}"
|
||||||
if key := self.reverse_paths.get(self.current_id):
|
if key := self.reverse_paths.get(self.current_id):
|
||||||
if key in self.datasets:
|
if key in self.datasets:
|
||||||
description = self.datasets[key].get('description') or description
|
description = self.datasets[key].get("description") or description
|
||||||
|
|
||||||
rel_path = self.relative_to_root(path)
|
rel_path = self.relative_to_root(path)
|
||||||
|
|
||||||
attributes = dict(
|
attributes = dict(
|
||||||
path = str(rel_path),
|
path=str(rel_path),
|
||||||
description = str(description),
|
description=str(description),
|
||||||
model_format = info.format,
|
model_format=info.format,
|
||||||
)
|
)
|
||||||
legacy_conf = None
|
legacy_conf = None
|
||||||
if info.model_type == ModelType.Main:
|
if info.model_type == ModelType.Main:
|
||||||
attributes.update(dict(variant = info.variant_type,))
|
attributes.update(
|
||||||
if info.format=="checkpoint":
|
dict(
|
||||||
|
variant=info.variant_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if info.format == "checkpoint":
|
||||||
try:
|
try:
|
||||||
possible_conf = path.with_suffix('.yaml')
|
possible_conf = path.with_suffix(".yaml")
|
||||||
if possible_conf.exists():
|
if possible_conf.exists():
|
||||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
legacy_conf = str(self.relative_to_root(possible_conf))
|
||||||
elif info.base_type == BaseModelType.StableDiffusion2:
|
elif info.base_type == BaseModelType.StableDiffusion2:
|
||||||
legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type])
|
legacy_conf = Path(
|
||||||
|
self.config.legacy_conf_dir,
|
||||||
|
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type])
|
legacy_conf = Path(
|
||||||
|
self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]
|
||||||
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
legacy_conf = Path(self.config.legacy_conf_dir, 'v1-inference.yaml') # best guess
|
legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess
|
||||||
|
|
||||||
if info.model_type == ModelType.ControlNet and info.format=="checkpoint":
|
if info.model_type == ModelType.ControlNet and info.format == "checkpoint":
|
||||||
possible_conf = path.with_suffix('.yaml')
|
possible_conf = path.with_suffix(".yaml")
|
||||||
if possible_conf.exists():
|
if possible_conf.exists():
|
||||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
legacy_conf = str(self.relative_to_root(possible_conf))
|
||||||
|
|
||||||
if legacy_conf:
|
if legacy_conf:
|
||||||
attributes.update(
|
attributes.update(dict(config=str(legacy_conf)))
|
||||||
dict(
|
|
||||||
config = str(legacy_conf)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return attributes
|
return attributes
|
||||||
|
|
||||||
def relative_to_root(self, path: Path)->Path:
|
def relative_to_root(self, path: Path) -> Path:
|
||||||
root = self.config.root_path
|
root = self.config.root_path
|
||||||
if path.is_relative_to(root):
|
if path.is_relative_to(root):
|
||||||
return path.relative_to(root)
|
return path.relative_to(root)
|
||||||
else:
|
else:
|
||||||
return path
|
return path
|
||||||
|
|
||||||
def _download_hf_pipeline(self, repo_id: str, staging: Path)->Path:
|
def _download_hf_pipeline(self, repo_id: str, staging: Path) -> Path:
|
||||||
'''
|
"""
|
||||||
This retrieves a StableDiffusion model from cache or remote and then
|
This retrieves a StableDiffusion model from cache or remote and then
|
||||||
does a save_pretrained() to the indicated staging area.
|
does a save_pretrained() to the indicated staging area.
|
||||||
'''
|
"""
|
||||||
_,name = repo_id.split("/")
|
_, name = repo_id.split("/")
|
||||||
revisions = ['fp16','main'] if self.config.precision=='float16' else ['main']
|
revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"]
|
||||||
model = None
|
model = None
|
||||||
for revision in revisions:
|
for revision in revisions:
|
||||||
try:
|
try:
|
||||||
model = DiffusionPipeline.from_pretrained(repo_id,revision=revision,safety_checker=None)
|
model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None)
|
||||||
except: # most errors are due to fp16 not being present. Fix this to catch other errors
|
except: # most errors are due to fp16 not being present. Fix this to catch other errors
|
||||||
pass
|
pass
|
||||||
if model:
|
if model:
|
||||||
break
|
break
|
||||||
if not model:
|
if not model:
|
||||||
logger.error(f'Diffusers model {repo_id} could not be downloaded. Skipping.')
|
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
|
||||||
return None
|
return None
|
||||||
model.save_pretrained(staging / name, safe_serialization=True)
model.save_pretrained(staging / name, safe_serialization=True)
return staging / name
return staging / name

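A compressed, illustrative sketch of the fp16-then-main fallback performed by _download_hf_pipeline above; the repo id and staging path are placeholders:

# Try the fp16 revision first when running at float16, then fall back to main.
from diffusers import DiffusionPipeline

model = None
for revision in ("fp16", "main"):
    try:
        model = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                                  revision=revision, safety_checker=None)
        break
    except Exception:
        continue
if model is not None:
    model.save_pretrained("/tmp/staging/stable-diffusion-v1-5", safe_serialization=True)
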
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path)->Path:
|
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path) -> Path:
|
||||||
_,name = repo_id.split("/")
|
_, name = repo_id.split("/")
|
||||||
location = staging / name
|
location = staging / name
|
||||||
paths = list()
|
paths = list()
|
||||||
for filename in files:
|
for filename in files:
|
||||||
p = hf_download_with_resume(repo_id,
|
p = hf_download_with_resume(
|
||||||
model_dir=location,
|
repo_id, model_dir=location, model_name=filename, access_token=self.access_token
|
||||||
model_name=filename,
|
)
|
||||||
access_token = self.access_token
|
|
||||||
)
|
|
||||||
if p:
|
if p:
|
||||||
paths.append(p)
|
paths.append(p)
|
||||||
else:
|
else:
|
||||||
logger.warning(f'Could not download {filename} from {repo_id}.')
|
logger.warning(f"Could not download {filename} from {repo_id}.")
|
|
return location if len(paths) > 0 else None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _reverse_paths(cls,datasets)->dict:
|
def _reverse_paths(cls, datasets) -> dict:
|
||||||
'''
|
"""
|
||||||
Reverse mapping from repo_id/path to destination name.
|
Reverse mapping from repo_id/path to destination name.
|
||||||
'''
|
"""
|
||||||
return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()}
|
return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()}
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def yes_or_no(prompt: str, default_yes=True):
|
def yes_or_no(prompt: str, default_yes=True):
|
||||||
@ -427,13 +446,12 @@ def yes_or_no(prompt: str, default_yes=True):
|
|||||||
else:
|
else:
|
||||||
return response[0] in ("y", "Y")
|
return response[0] in ("y", "Y")
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def hf_download_from_pretrained(
|
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
|
||||||
model_class: object, model_name: str, destination: Path, **kwargs
|
logger = InvokeAILogger.getLogger("InvokeAI")
|
||||||
):
|
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
|
||||||
logger = InvokeAILogger.getLogger('InvokeAI')
|
|
||||||
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
|
|
||||||
|
|
||||||
model = model_class.from_pretrained(
|
model = model_class.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
resume_download=True,
|
resume_download=True,
|
||||||
@ -442,13 +460,14 @@ def hf_download_from_pretrained(
|
|||||||
model.save_pretrained(destination, safe_serialization=True)
|
model.save_pretrained(destination, safe_serialization=True)
|
||||||
return destination
|
return destination
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def hf_download_with_resume(
|
def hf_download_with_resume(
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
model_dir: str,
|
model_dir: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_dest: Path = None,
|
model_dest: Path = None,
|
||||||
access_token: str = None,
|
access_token: str = None,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
@ -467,9 +486,7 @@ def hf_download_with_resume(
|
|||||||
resp = requests.get(url, headers=header, stream=True)
|
resp = requests.get(url, headers=header, stream=True)
|
||||||
total = int(resp.headers.get("content-length", 0))
|
total = int(resp.headers.get("content-length", 0))
|
||||||
|
|
||||||
if (
|
if resp.status_code == 416: # "range not satisfiable", which means nothing to return
|
||||||
resp.status_code == 416
|
|
||||||
): # "range not satisfiable", which means nothing to return
|
|
||||||
logger.info(f"{model_name}: complete file found. Skipping.")
|
logger.info(f"{model_name}: complete file found. Skipping.")
|
||||||
return model_dest
|
return model_dest
|
||||||
elif resp.status_code == 404:
|
elif resp.status_code == 404:
|
||||||
@ -498,5 +515,3 @@ def hf_download_with_resume(
|
|||||||
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||||
return None
return None
return model_dest
return model_dest

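For context, the resume behaviour in hf_download_with_resume rests on a standard HTTP Range request, where a 416 response means the local copy is already complete; a minimal hedged sketch with a placeholder URL and path:

# Resume a partial download by sending a Range header; a 416 response means
# the file on disk is already complete, mirroring the handling above.
import requests
from pathlib import Path

dest = Path("/tmp/model.safetensors")
existing = dest.stat().st_size if dest.exists() else 0
resp = requests.get("https://example.com/model.safetensors",
                    headers={"Range": f"bytes={existing}-"} if existing else {},
                    stream=True)
if resp.status_code == 416:
    print("already complete")
else:
    with open(dest, "ab") as f:
        for chunk in resp.iter_content(chunk_size=1 << 20):
            f.write(chunk)
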
@ -3,6 +3,12 @@ Initialization file for invokeai.backend.model_management
|
|||||||
"""
|
"""
|
||||||
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException, DuplicateModelException
|
from .models import (
|
||||||
|
BaseModelType,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
ModelVariantType,
|
||||||
|
ModelNotFoundException,
|
||||||
|
DuplicateModelException,
|
||||||
|
)
|
||||||
from .model_merge import ModelMerger, MergeInterpolationMethod
|
from .model_merge import ModelMerger, MergeInterpolationMethod
|
||||||
|
|
||||||
|
@ -56,9 +56,7 @@ from diffusers.schedulers import (
|
|||||||
)
|
)
|
||||||
from diffusers.utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available
|
from diffusers.utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available
|
||||||
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
||||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
|
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||||
LDMBertConfig, LDMBertModel
|
|
||||||
)
|
|
||||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
||||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
@ -85,6 +83,7 @@ if is_accelerate_available():
|
|||||||
logger = InvokeAILogger.getLogger(__name__)
|
logger = InvokeAILogger.getLogger(__name__)
|
||||||
CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().root_path / MODEL_CORE / "convert"
|
CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().root_path / MODEL_CORE / "convert"
|
||||||
|
|
||||||
|
|
||||||
def shave_segments(path, n_shave_prefix_segments=1):
|
def shave_segments(path, n_shave_prefix_segments=1):
|
||||||
"""
|
"""
|
||||||
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
||||||
@ -509,9 +508,7 @@ def convert_ldm_unet_checkpoint(
|
|||||||
|
|
||||||
paths = renew_resnet_paths(resnets)
|
paths = renew_resnet_paths(resnets)
|
||||||
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||||
assign_to_checkpoint(
|
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
||||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(attentions):
|
if len(attentions):
|
||||||
paths = renew_attention_paths(attentions)
|
paths = renew_attention_paths(attentions)
|
||||||
@ -796,7 +793,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
|
|||||||
|
|
||||||
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
|
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
|
||||||
if text_encoder is None:
|
if text_encoder is None:
|
||||||
config = CLIPTextConfig.from_pretrained(CONVERT_MODEL_ROOT / 'clip-vit-large-patch14')
|
config = CLIPTextConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14")
|
||||||
|
|
||||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||||
with ctx():
|
with ctx():
|
||||||
@ -1008,7 +1005,9 @@ def stable_unclip_image_encoder(original_config):
|
|||||||
elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
|
elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
|
||||||
feature_extractor = CLIPImageProcessor()
|
feature_extractor = CLIPImageProcessor()
|
||||||
# InvokeAI doesn't use CLIPVisionModelWithProjection so it isn't in the core - if this code is hit a download will occur
|
# InvokeAI doesn't use CLIPVisionModelWithProjection so it isn't in the core - if this code is hit a download will occur
|
||||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(CONVERT_MODEL_ROOT / "CLIP-ViT-H-14-laion2B-s32B-b79K")
|
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||||
|
CONVERT_MODEL_ROOT / "CLIP-ViT-H-14-laion2B-s32B-b79K"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
|
f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
|
||||||
@ -1071,17 +1070,17 @@ def convert_controlnet_checkpoint(
|
|||||||
extract_ema,
|
extract_ema,
|
||||||
use_linear_projection=None,
|
use_linear_projection=None,
|
||||||
cross_attention_dim=None,
|
cross_attention_dim=None,
|
||||||
precision: torch.dtype=torch.float32,
|
precision: torch.dtype = torch.float32,
|
||||||
):
|
):
|
||||||
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
|
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
|
||||||
ctrlnet_config["upcast_attention"] = upcast_attention
|
ctrlnet_config["upcast_attention"] = upcast_attention
|
||||||
|
|
||||||
ctrlnet_config.pop("sample_size")
|
ctrlnet_config.pop("sample_size")
|
||||||
original_config = ctrlnet_config.copy()
|
original_config = ctrlnet_config.copy()
|
||||||
|
|
||||||
ctrlnet_config.pop('addition_embed_type')
|
ctrlnet_config.pop("addition_embed_type")
|
||||||
ctrlnet_config.pop('addition_time_embed_dim')
|
ctrlnet_config.pop("addition_time_embed_dim")
|
||||||
ctrlnet_config.pop('transformer_layers_per_block')
|
ctrlnet_config.pop("transformer_layers_per_block")
|
||||||
|
|
||||||
if use_linear_projection is not None:
|
if use_linear_projection is not None:
|
||||||
ctrlnet_config["use_linear_projection"] = use_linear_projection
|
ctrlnet_config["use_linear_projection"] = use_linear_projection
|
||||||
@ -1111,6 +1110,7 @@ def convert_controlnet_checkpoint(
|
|||||||
|
|
||||||
return controlnet.to(precision)
|
return controlnet.to(precision)
|
||||||
|
|
||||||
|
|
||||||
# TO DO - PASS PRECISION
|
# TO DO - PASS PRECISION
|
||||||
def download_from_original_stable_diffusion_ckpt(
|
def download_from_original_stable_diffusion_ckpt(
|
||||||
checkpoint_path: str,
|
checkpoint_path: str,
|
||||||
@ -1249,8 +1249,8 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
# "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
|
# "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
|
||||||
while "state_dict" in checkpoint:
|
while "state_dict" in checkpoint:
|
||||||
checkpoint = checkpoint["state_dict"]
|
checkpoint = checkpoint["state_dict"]
|
||||||
|
|
||||||
logger.debug(f'model_type = {model_type}; original_config_file = {original_config_file}')
|
logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}")
|
||||||
|
|
||||||
if original_config_file is None:
|
if original_config_file is None:
|
||||||
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
@ -1258,7 +1258,9 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
|
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
|
||||||
|
|
||||||
# model_type = "v1"
|
# model_type = "v1"
|
||||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
config_url = (
|
||||||
|
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||||
|
)
|
||||||
|
|
||||||
if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
|
if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
|
||||||
# model_type = "v2"
|
# model_type = "v2"
|
||||||
@ -1277,7 +1279,10 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
original_config_file = BytesIO(requests.get(config_url).content)
|
original_config_file = BytesIO(requests.get(config_url).content)
|
||||||
|
|
||||||
original_config = OmegaConf.load(original_config_file)
|
original_config = OmegaConf.load(original_config_file)
|
||||||
if model_version == BaseModelType.StableDiffusion2 and original_config["model"]["params"]["parameterization"] == "v":
|
if (
|
||||||
|
model_version == BaseModelType.StableDiffusion2
|
||||||
|
and original_config["model"]["params"]["parameterization"] == "v"
|
||||||
|
):
|
||||||
prediction_type = "v_prediction"
|
prediction_type = "v_prediction"
|
||||||
upcast_attention = True
|
upcast_attention = True
|
||||||
image_size = 768
|
image_size = 768
|
||||||
@ -1436,7 +1441,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
config_kwargs = {"subfolder": "text_encoder"}
|
config_kwargs = {"subfolder": "text_encoder"}
|
||||||
|
|
||||||
text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
|
text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / 'stable-diffusion-2-clip', subfolder="tokenizer")
|
tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer")
|
||||||
|
|
||||||
if stable_unclip is None:
|
if stable_unclip is None:
|
||||||
if controlnet:
|
if controlnet:
|
||||||
@ -1491,7 +1496,9 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior")
|
prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior")
|
||||||
|
|
||||||
prior_tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14")
|
prior_tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14")
|
||||||
prior_text_model = CLIPTextModelWithProjection.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14")
|
prior_text_model = CLIPTextModelWithProjection.from_pretrained(
|
||||||
|
CONVERT_MODEL_ROOT / "clip-vit-large-patch14"
|
||||||
|
)
|
||||||
|
|
||||||
prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler")
|
prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler")
|
||||||
prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
|
prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
|
||||||
@ -1533,11 +1540,19 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
text_model = convert_ldm_clip_checkpoint(
|
text_model = convert_ldm_clip_checkpoint(
|
||||||
checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
|
checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
|
||||||
)
|
)
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") if tokenizer is None else tokenizer
|
tokenizer = (
|
||||||
|
CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14")
|
||||||
|
if tokenizer is None
|
||||||
|
else tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
if load_safety_checker:
|
if load_safety_checker:
|
||||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker")
|
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||||
feature_extractor = AutoFeatureExtractor.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker")
|
CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker"
|
||||||
|
)
|
||||||
|
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||||
|
CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
safety_checker = None
|
safety_checker = None
|
||||||
feature_extractor = None
|
feature_extractor = None
|
||||||
@ -1567,7 +1582,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
if model_type == "SDXL":
|
if model_type == "SDXL":
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14")
|
tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14")
|
||||||
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
|
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
|
||||||
|
|
||||||
tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||||
tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!")
|
tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!")
|
||||||
|
|
||||||
@ -1577,7 +1592,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
|
checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
pipe = StableDiffusionXLPipeline (
|
pipe = StableDiffusionXLPipeline(
|
||||||
vae=vae.to(precision),
|
vae=vae.to(precision),
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@ -1686,24 +1701,22 @@ def download_controlnet_from_original_ckpt(
|
|||||||
|
|
||||||
return controlnet
|
return controlnet
|
||||||
|
|
||||||
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:
|
|
||||||
vae_config = create_vae_diffusers_config(
|
|
||||||
vae_config, image_size=image_size
|
|
||||||
)
|
|
||||||
|
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:
|
||||||
checkpoint, vae_config
|
vae_config = create_vae_diffusers_config(vae_config, image_size=image_size)
|
||||||
)
|
|
||||||
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||||
|
|
||||||
vae = AutoencoderKL(**vae_config)
|
vae = AutoencoderKL(**vae_config)
|
||||||
vae.load_state_dict(converted_vae_checkpoint)
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
return vae
|
return vae
|
||||||
|
|
||||||
|
|
||||||
def convert_ckpt_to_diffusers(
|
def convert_ckpt_to_diffusers(
|
||||||
checkpoint_path: Union[str, Path],
|
checkpoint_path: Union[str, Path],
|
||||||
dump_path: Union[str, Path],
|
dump_path: Union[str, Path],
|
||||||
use_safetensors: bool=True,
|
use_safetensors: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
|
Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
|
||||||
@ -1717,10 +1730,11 @@ def convert_ckpt_to_diffusers(
|
|||||||
safe_serialization=use_safetensors and is_safetensors_available(),
|
safe_serialization=use_safetensors and is_safetensors_available(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def convert_controlnet_to_diffusers(
|
def convert_controlnet_to_diffusers(
|
||||||
checkpoint_path: Union[str, Path],
|
checkpoint_path: Union[str, Path],
|
||||||
dump_path: Union[str, Path],
|
dump_path: Union[str, Path],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Takes all the arguments of download_controlnet_from_original_ckpt(),
|
Takes all the arguments of download_controlnet_from_original_ckpt(),
|
||||||
|
@ -11,14 +11,15 @@ from diffusers.models import UNet2DConditionModel
|
|||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
class LoRALayerBase:
|
|
||||||
#rank: Optional[int]
|
|
||||||
#alpha: Optional[float]
|
|
||||||
#bias: Optional[torch.Tensor]
|
|
||||||
#layer_key: str
|
|
||||||
|
|
||||||
#@property
|
class LoRALayerBase:
|
||||||
#def scale(self):
|
# rank: Optional[int]
|
||||||
|
# alpha: Optional[float]
|
||||||
|
# bias: Optional[torch.Tensor]
|
||||||
|
# layer_key: str
|
||||||
|
|
||||||
|
# @property
|
||||||
|
# def scale(self):
|
||||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -31,11 +32,7 @@ class LoRALayerBase:
|
|||||||
else:
|
else:
|
||||||
self.alpha = None
|
self.alpha = None
|
||||||
|
|
||||||
if (
|
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||||
"bias_indices" in values
|
|
||||||
and "bias_values" in values
|
|
||||||
and "bias_size" in values
|
|
||||||
):
|
|
||||||
self.bias = torch.sparse_coo_tensor(
|
self.bias = torch.sparse_coo_tensor(
|
||||||
values["bias_indices"],
|
values["bias_indices"],
|
||||||
values["bias_values"],
|
values["bias_values"],
|
||||||
@ -45,13 +42,13 @@ class LoRALayerBase:
|
|||||||
else:
|
else:
|
||||||
self.bias = None
|
self.bias = None
|
||||||
|
|
||||||
self.rank = None # set in layer implementation
|
self.rank = None # set in layer implementation
|
||||||
self.layer_key = layer_key
|
self.layer_key = layer_key
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
module: torch.nn.Module,
|
module: torch.nn.Module,
|
||||||
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
|
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
|
||||||
multiplier: float,
|
multiplier: float,
|
||||||
):
|
):
|
||||||
if type(module) == torch.nn.Conv2d:
|
if type(module) == torch.nn.Conv2d:
|
||||||
@ -71,12 +68,16 @@ class LoRALayerBase:
|
|||||||
|
|
||||||
bias = self.bias if self.bias is not None else 0
|
bias = self.bias if self.bias is not None else 0
|
||||||
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||||
return op(
|
return (
|
||||||
*input_h,
|
op(
|
||||||
(weight + bias).view(module.weight.shape),
|
*input_h,
|
||||||
None,
|
(weight + bias).view(module.weight.shape),
|
||||||
**extra_args,
|
None,
|
||||||
) * multiplier * scale
|
**extra_args,
|
||||||
|
)
|
||||||
|
* multiplier
|
||||||
|
* scale
|
||||||
|
)
|
||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@ -99,9 +100,9 @@ class LoRALayerBase:
|
|||||||
|
|
||||||
# TODO: find and debug lora/locon with bias
|
# TODO: find and debug lora/locon with bias
|
||||||
class LoRALayer(LoRALayerBase):
|
class LoRALayer(LoRALayerBase):
|
||||||
#up: torch.Tensor
|
# up: torch.Tensor
|
||||||
#mid: Optional[torch.Tensor]
|
# mid: Optional[torch.Tensor]
|
||||||
#down: torch.Tensor
|
# down: torch.Tensor
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -151,12 +152,12 @@ class LoRALayer(LoRALayerBase):
|
|||||||
|
|
||||||
|
|
||||||
class LoHALayer(LoRALayerBase):
|
class LoHALayer(LoRALayerBase):
|
||||||
#w1_a: torch.Tensor
|
# w1_a: torch.Tensor
|
||||||
#w1_b: torch.Tensor
|
# w1_b: torch.Tensor
|
||||||
#w2_a: torch.Tensor
|
# w2_a: torch.Tensor
|
||||||
#w2_b: torch.Tensor
|
# w2_b: torch.Tensor
|
||||||
#t1: Optional[torch.Tensor] = None
|
# t1: Optional[torch.Tensor] = None
|
||||||
#t2: Optional[torch.Tensor] = None
|
# t2: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -187,12 +188,8 @@ class LoHALayer(LoRALayerBase):
|
|||||||
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
rebuild1 = torch.einsum(
|
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||||
"i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a
|
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
||||||
)
|
|
||||||
rebuild2 = torch.einsum(
|
|
||||||
"i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a
|
|
||||||
)
|
|
||||||
weight = rebuild1 * rebuild2
weight = rebuild1 * rebuild2

return weight
return weight

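For intuition (not part of the diff): the LoHA weight above is the element-wise product of two low-rank reconstructions; a tiny sketch with toy shapes:

# (w1_a @ w1_b) and (w2_a @ w2_b) are full-size matrices built from low-rank
# factors; their element-wise (Hadamard) product is the effective delta weight.
import torch

rank, out_dim, in_dim = 4, 16, 32
w1_a, w1_b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
w2_a, w2_b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
delta = (w1_a @ w1_b) * (w2_a @ w2_b)   # shape: (out_dim, in_dim)
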
@ -223,20 +220,20 @@ class LoHALayer(LoRALayerBase):
|
|||||||
|
|
||||||
|
|
||||||
class LoKRLayer(LoRALayerBase):
|
class LoKRLayer(LoRALayerBase):
|
||||||
#w1: Optional[torch.Tensor] = None
|
# w1: Optional[torch.Tensor] = None
|
||||||
#w1_a: Optional[torch.Tensor] = None
|
# w1_a: Optional[torch.Tensor] = None
|
||||||
#w1_b: Optional[torch.Tensor] = None
|
# w1_b: Optional[torch.Tensor] = None
|
||||||
#w2: Optional[torch.Tensor] = None
|
# w2: Optional[torch.Tensor] = None
|
||||||
#w2_a: Optional[torch.Tensor] = None
|
# w2_a: Optional[torch.Tensor] = None
|
||||||
#w2_b: Optional[torch.Tensor] = None
|
# w2_b: Optional[torch.Tensor] = None
|
||||||
#t2: Optional[torch.Tensor] = None
|
# t2: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
layer_key: str,
|
layer_key: str,
|
||||||
values: dict,
|
values: dict,
|
||||||
):
|
):
|
||||||
super().__init__(layer_key, values)
|
super().__init__(layer_key, values)
|
||||||
|
|
||||||
if "lokr_w1" in values:
|
if "lokr_w1" in values:
|
||||||
self.w1 = values["lokr_w1"]
|
self.w1 = values["lokr_w1"]
|
||||||
@ -266,7 +263,7 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
elif "lokr_w2_b" in values:
|
elif "lokr_w2_b" in values:
|
||||||
self.rank = values["lokr_w2_b"].shape[0]
|
self.rank = values["lokr_w2_b"].shape[0]
|
||||||
else:
|
else:
|
||||||
self.rank = None # unscaled
|
self.rank = None # unscaled
|
||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self):
|
||||||
w1 = self.w1
|
w1 = self.w1
|
||||||
@ -278,7 +275,7 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
if self.t2 is None:
|
if self.t2 is None:
|
||||||
w2 = self.w2_a @ self.w2_b
|
w2 = self.w2_a @ self.w2_b
|
||||||
else:
|
else:
|
||||||
w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)
|
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
if len(w2.shape) == 4:
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
@ -317,7 +314,7 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
class LoRAModel: #(torch.nn.Module):
|
class LoRAModel: # (torch.nn.Module):
|
||||||
_name: str
|
_name: str
|
||||||
layers: Dict[str, LoRALayer]
|
layers: Dict[str, LoRALayer]
|
||||||
_device: torch.device
|
_device: torch.device
|
||||||
@ -345,7 +342,7 @@ class LoRAModel: #(torch.nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
return self._dtype
|
return self._dtype
|
||||||
|
|
||||||
def to(
|
def to(
|
||||||
self,
|
self,
|
||||||
@ -380,7 +377,7 @@ class LoRAModel: #(torch.nn.Module):
|
|||||||
model = cls(
|
model = cls(
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
name=file_path.stem, # TODO:
|
name=file_path.stem, # TODO:
|
||||||
layers=dict(),
|
layers=dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -392,7 +389,6 @@ class LoRAModel: #(torch.nn.Module):
|
|||||||
state_dict = cls._group_state(state_dict)
|
state_dict = cls._group_state(state_dict)
|
||||||
|
|
||||||
for layer_key, values in state_dict.items():
|
for layer_key, values in state_dict.items():
|
||||||
|
|
||||||
# lora and locon
|
# lora and locon
|
||||||
if "lora_down.weight" in values:
|
if "lora_down.weight" in values:
|
||||||
layer = LoRALayer(layer_key, values)
|
layer = LoRALayer(layer_key, values)
|
||||||
@ -407,9 +403,7 @@ class LoRAModel: #(torch.nn.Module):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
# TODO: diff/ia3/... format
|
# TODO: diff/ia3/... format
|
||||||
print(
|
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
|
||||||
f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# lower memory consumption by removing already parsed layer values
|
# lower memory consumption by removing already parsed layer values
|
||||||
@ -443,9 +437,10 @@ with LoRAHelper.apply_lora_unet(unet, loras):
|
|||||||
# unmodified unet
|
# unmodified unet
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# TODO: rename smth like ModelPatcher and add TI method?
|
# TODO: rename smth like ModelPatcher and add TI method?
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||||
assert "." not in lora_key
|
assert "." not in lora_key
|
||||||
@ -455,10 +450,10 @@ class ModelPatcher:
|
|||||||
|
|
||||||
module = model
|
module = model
|
||||||
module_key = ""
|
module_key = ""
|
||||||
key_parts = lora_key[len(prefix):].split('_')
|
key_parts = lora_key[len(prefix) :].split("_")
|
||||||
|
|
||||||
submodule_name = key_parts.pop(0)
|
submodule_name = key_parts.pop(0)
|
||||||
|
|
||||||
while len(key_parts) > 0:
|
while len(key_parts) > 0:
|
||||||
try:
|
try:
|
||||||
module = module.get_submodule(submodule_name)
|
module = module.get_submodule(submodule_name)
|
||||||
@ -477,7 +472,6 @@ class ModelPatcher:
|
|||||||
applied_loras: List[Tuple[LoRAModel, float]],
|
applied_loras: List[Tuple[LoRAModel, float]],
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
):
|
):
|
||||||
|
|
||||||
def lora_forward(module, input_h, output):
|
def lora_forward(module, input_h, output):
|
||||||
if len(applied_loras) == 0:
|
if len(applied_loras) == 0:
|
||||||
return output
|
return output
|
||||||
@ -491,7 +485,6 @@ class ModelPatcher:
|
|||||||
|
|
||||||
return lora_forward
|
return lora_forward
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def apply_lora_unet(
|
def apply_lora_unet(
|
||||||
@ -502,7 +495,6 @@ class ModelPatcher:
|
|||||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
with cls.apply_lora(unet, loras, "lora_unet_"):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def apply_lora_text_encoder(
|
def apply_lora_text_encoder(
|
||||||
@ -513,7 +505,6 @@ class ModelPatcher:
|
|||||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def apply_lora(
|
def apply_lora(
|
||||||
@ -526,7 +517,7 @@ class ModelPatcher:
|
|||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for lora, lora_weight in loras:
|
for lora, lora_weight in loras:
|
||||||
#assert lora.device.type == "cpu"
|
# assert lora.device.type == "cpu"
|
||||||
for layer_key, layer in lora.layers.items():
|
for layer_key, layer in lora.layers.items():
|
||||||
if not layer_key.startswith(prefix):
|
if not layer_key.startswith(prefix):
|
||||||
continue
|
continue
|
||||||
@ -536,7 +527,7 @@ class ModelPatcher:
|
|||||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||||
|
|
||||||
# enable autocast to calc fp16 loras on cpu
|
# enable autocast to calc fp16 loras on cpu
|
||||||
#with torch.autocast(device_type="cpu"):
|
# with torch.autocast(device_type="cpu"):
|
||||||
layer.to(dtype=torch.float32)
|
layer.to(dtype=torch.float32)
|
||||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||||
layer_weight = layer.get_weight() * lora_weight * layer_scale
|
layer_weight = layer.get_weight() * lora_weight * layer_scale
|
||||||
@ -547,14 +538,13 @@ class ModelPatcher:
|
|||||||
|
|
||||||
module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
|
module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
|
||||||
|
|
||||||
yield # wait for context manager exit
|
yield # wait for context manager exit
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
with torch.no_grad():
with torch.no_grad():
for module_key, weight in original_weights.items():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)
model.get_submodule(module_key).weight.copy_(weight)

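A toy, hedged sketch of the patch-then-restore pattern that apply_lora implements above; the helper below is illustrative and not InvokeAI API:

# Temporarily add a low-rank LoRA delta to a layer's weight, then restore the
# original weight on exit, the same idea as the apply_lora()/finally pair above.
import torch
from contextlib import contextmanager

@contextmanager
def patch_linear_with_lora(linear: torch.nn.Linear, up: torch.Tensor, down: torch.Tensor,
                           weight: float = 1.0, scale: float = 1.0):
    original = linear.weight.detach().clone()
    try:
        with torch.no_grad():
            linear.weight += (up @ down).reshape(linear.weight.shape) * weight * scale
        yield
    finally:
        with torch.no_grad():
            linear.weight.copy_(original)

layer = torch.nn.Linear(8, 8)
with patch_linear_with_lora(layer, torch.randn(8, 4), torch.randn(4, 8)):
    _ = layer(torch.randn(1, 8))   # patched forward pass
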
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def apply_ti(
|
def apply_ti(
|
||||||
@ -602,7 +592,9 @@ class ModelPatcher:
|
|||||||
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
|
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
|
||||||
)
|
)
|
||||||
|
|
||||||
model_embeddings.weight.data[token_id] = embedding.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
model_embeddings.weight.data[token_id] = embedding.to(
|
||||||
|
device=text_encoder.device, dtype=text_encoder.dtype
|
||||||
|
)
|
||||||
ti_tokens.append(token_id)
|
ti_tokens.append(token_id)
|
||||||
|
|
||||||
if len(ti_tokens) > 1:
|
if len(ti_tokens) > 1:
|
||||||
@ -614,7 +606,6 @@ class ModelPatcher:
|
|||||||
if init_tokens_count and new_tokens_added:
|
if init_tokens_count and new_tokens_added:
|
||||||
text_encoder.resize_token_embeddings(init_tokens_count)
|
text_encoder.resize_token_embeddings(init_tokens_count)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def apply_clip_skip(
|
def apply_clip_skip(
|
||||||
@ -633,9 +624,10 @@ class ModelPatcher:
|
|||||||
while len(skipped_layers) > 0:
|
while len(skipped_layers) > 0:
|
||||||
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
|
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionModel:
|
class TextualInversionModel:
|
||||||
name: str
|
name: str
|
||||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_checkpoint(
|
def from_checkpoint(
|
||||||
@ -647,8 +639,8 @@ class TextualInversionModel:
|
|||||||
if not isinstance(file_path, Path):
|
if not isinstance(file_path, Path):
|
||||||
file_path = Path(file_path)
|
file_path = Path(file_path)
|
||||||
|
|
||||||
result = cls() # TODO:
|
result = cls() # TODO:
|
||||||
result.name = file_path.stem # TODO:
|
result.name = file_path.stem # TODO:
|
||||||
|
|
||||||
if file_path.suffix == ".safetensors":
|
if file_path.suffix == ".safetensors":
|
||||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||||
@ -659,7 +651,9 @@ class TextualInversionModel:
|
|||||||
# difference mostly in metadata
|
# difference mostly in metadata
|
||||||
if "string_to_param" in state_dict:
|
if "string_to_param" in state_dict:
|
||||||
if len(state_dict["string_to_param"]) > 1:
|
if len(state_dict["string_to_param"]) > 1:
|
||||||
print(f"Warn: Embedding \"{file_path.name}\" contains multiple tokens, which is not supported. The first token will be used.")
|
print(
|
||||||
|
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first token will be used.'
|
||||||
|
)
|
||||||
|
|
||||||
result.embedding = next(iter(state_dict["string_to_param"].values()))
|
result.embedding = next(iter(state_dict["string_to_param"].values()))
|
||||||
|
|
||||||
@ -688,10 +682,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
self.pad_tokens = dict()
|
self.pad_tokens = dict()
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
def expand_textual_inversion_token_ids_if_necessary(
|
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
||||||
self, token_ids: list[int]
|
|
||||||
) -> list[int]:
|
|
||||||
|
|
||||||
if len(self.pad_tokens) == 0:
|
if len(self.pad_tokens) == 0:
|
||||||
return token_ids
|
return token_ids
|
||||||
|
|
||||||
@ -707,4 +698,3 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
new_token_ids.extend(self.pad_tokens[token_id])
new_token_ids.extend(self.pad_tokens[token_id])

return new_token_ids
return new_token_ids

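The expansion above amounts to splicing a trigger token's pad-token ids into the prompt's id list; a tiny illustration with made-up ids (the exact contents of pad_tokens are an assumption here):

# Made-up numbers: token 500 is a multi-vector embedding whose extra vectors
# were registered as pad tokens 501 and 502.
pad_tokens = {500: [500, 501, 502]}
token_ids = [49406, 320, 500, 49407]
new_token_ids = []
for token_id in token_ids:
    new_token_ids.extend(pad_tokens.get(token_id, [token_id]))
print(new_token_ids)  # [49406, 320, 500, 501, 502, 49407]
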
@ -37,19 +37,22 @@ from .models import BaseModelType, ModelType, SubModelType, ModelBase
|
|||||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||||
|
|
||||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
# amount of GPU memory to hold in reserve for use by generations (GB)
|
||||||
DEFAULT_MAX_VRAM_CACHE_SIZE= 2.75
|
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
||||||
|
|
||||||
# actual size of a gig
|
# actual size of a gig
|
||||||
GIG = 1073741824
|
GIG = 1073741824
|
||||||
|
|
||||||
|
|
||||||
class ModelLocker(object):
|
class ModelLocker(object):
|
||||||
"Forward declaration"
|
"Forward declaration"
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ModelCache(object):
|
class ModelCache(object):
|
||||||
"Forward declaration"
|
"Forward declaration"
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class _CacheRecord:
|
class _CacheRecord:
|
||||||
size: int
|
size: int
|
||||||
model: Any
|
model: Any
|
||||||
@ -79,22 +82,22 @@ class _CacheRecord:
|
|||||||
return self.model.device != self.cache.storage_device
|
return self.model.device != self.cache.storage_device
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class ModelCache(object):
|
class ModelCache(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
|
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
||||||
max_vram_cache_size: float=DEFAULT_MAX_VRAM_CACHE_SIZE,
|
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||||
execution_device: torch.device=torch.device('cuda'),
|
execution_device: torch.device = torch.device("cuda"),
|
||||||
storage_device: torch.device=torch.device('cpu'),
|
storage_device: torch.device = torch.device("cpu"),
|
||||||
precision: torch.dtype=torch.float16,
|
precision: torch.dtype = torch.float16,
|
||||||
sequential_offload: bool=False,
|
sequential_offload: bool = False,
|
||||||
lazy_offloading: bool=True,
|
lazy_offloading: bool = True,
|
||||||
sha_chunksize: int = 16777216,
|
sha_chunksize: int = 16777216,
|
||||||
logger: types.ModuleType = logger
|
logger: types.ModuleType = logger,
|
||||||
):
|
):
|
||||||
'''
"""
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
@ -102,16 +105,16 @@ class ModelCache(object):
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
'''
"""
self.model_infos: Dict[str, ModelBase] = dict()
|
self.model_infos: Dict[str, ModelBase] = dict()
|
||||||
# allow lazy offloading only when vram cache enabled
|
# allow lazy offloading only when vram cache enabled
|
||||||
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||||
self.precision: torch.dtype=precision
|
self.precision: torch.dtype = precision
|
||||||
self.max_cache_size: float=max_cache_size
|
self.max_cache_size: float = max_cache_size
|
||||||
self.max_vram_cache_size: float=max_vram_cache_size
|
self.max_vram_cache_size: float = max_vram_cache_size
|
||||||
self.execution_device: torch.device=execution_device
|
self.execution_device: torch.device = execution_device
|
||||||
self.storage_device: torch.device=storage_device
|
self.storage_device: torch.device = storage_device
|
||||||
self.sha_chunksize=sha_chunksize
|
self.sha_chunksize = sha_chunksize
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
self._cached_models = dict()
|
self._cached_models = dict()
|
||||||
@ -124,7 +127,6 @@ class ModelCache(object):
|
|||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
key = f"{model_path}:{base_model}:{model_type}"
|
key = f"{model_path}:{base_model}:{model_type}"
|
||||||
if submodel_type:
|
if submodel_type:
|
||||||
key += f":{submodel_type}"
|
key += f":{submodel_type}"
|
||||||
@ -163,7 +165,6 @@ class ModelCache(object):
|
|||||||
submodel: Optional[SubModelType] = None,
|
submodel: Optional[SubModelType] = None,
|
||||||
gpu_load: bool = True,
|
gpu_load: bool = True,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
|
||||||
if not isinstance(model_path, Path):
|
if not isinstance(model_path, Path):
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
|
|
||||||
@ -186,7 +187,7 @@ class ModelCache(object):
|
|||||||
# TODO: lock for no copies on simultaneous calls?
|
# TODO: lock for no copies on simultaneous calls?
|
||||||
cache_entry = self._cached_models.get(key, None)
|
cache_entry = self._cached_models.get(key, None)
|
||||||
if cache_entry is None:
|
if cache_entry is None:
|
||||||
self.logger.info(f'Loading model {model_path}, type {base_model}:{model_type}:{submodel}')
|
self.logger.info(f"Loading model {model_path}, type {base_model}:{model_type}:{submodel}")
|
||||||
|
|
||||||
# this will remove older cached models until
|
# this will remove older cached models until
|
||||||
# there is sufficient room to load the requested model
|
# there is sufficient room to load the requested model
|
||||||
@ -196,7 +197,7 @@ class ModelCache(object):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
|
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
|
||||||
if mem_used := model_info.get_size(submodel):
|
if mem_used := model_info.get_size(submodel):
|
||||||
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
|
self.logger.debug(f"CPU RAM used for load: {(mem_used/GIG):.2f} GB")
|
||||||
|
|
||||||
cache_entry = _CacheRecord(self, model, mem_used)
|
cache_entry = _CacheRecord(self, model, mem_used)
|
||||||
self._cached_models[key] = cache_entry
|
self._cached_models[key] = cache_entry
|
||||||
@ -209,13 +210,13 @@ class ModelCache(object):
|
|||||||
|
|
||||||
class ModelLocker(object):
|
class ModelLocker(object):
|
||||||
def __init__(self, cache, key, model, gpu_load, size_needed):
|
def __init__(self, cache, key, model, gpu_load, size_needed):
|
||||||
'''
|
"""
|
||||||
:param cache: The model_cache object
|
:param cache: The model_cache object
|
||||||
:param key: The key of the model to lock in GPU
|
:param key: The key of the model to lock in GPU
|
||||||
:param model: The model to lock
|
:param model: The model to lock
|
||||||
:param gpu_load: True if load into gpu
|
:param gpu_load: True if load into gpu
|
||||||
:param size_needed: Size of the model to load
|
:param size_needed: Size of the model to load
|
||||||
'''
|
"""
|
||||||
self.gpu_load = gpu_load
|
self.gpu_load = gpu_load
|
||||||
self.cache = cache
|
self.cache = cache
|
||||||
self.key = key
|
self.key = key
|
||||||
@ -224,7 +225,7 @@ class ModelCache(object):
|
|||||||
self.cache_entry = self.cache._cached_models[self.key]
|
self.cache_entry = self.cache._cached_models[self.key]
|
||||||
|
|
||||||
def __enter__(self) -> Any:
|
def __enter__(self) -> Any:
|
||||||
if not hasattr(self.model, 'to'):
|
if not hasattr(self.model, "to"):
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
# NOTE that the model has to have the to() method in order for this
|
# NOTE that the model has to have the to() method in order for this
|
||||||
@ -234,22 +235,21 @@ class ModelCache(object):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if self.cache.lazy_offloading:
|
if self.cache.lazy_offloading:
|
||||||
self.cache._offload_unlocked_models(self.size_needed)
|
self.cache._offload_unlocked_models(self.size_needed)
|
||||||
|
|
||||||
if self.model.device != self.cache.execution_device:
|
if self.model.device != self.cache.execution_device:
|
||||||
self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}')
|
self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
|
||||||
with VRAMUsage() as mem:
|
with VRAMUsage() as mem:
|
||||||
self.model.to(self.cache.execution_device) # move into GPU
|
self.model.to(self.cache.execution_device) # move into GPU
|
||||||
self.cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
|
self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
|
||||||
|
|
||||||
self.cache.logger.debug(f'Locking {self.key} in {self.cache.execution_device}')
|
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
|
||||||
self.cache._print_cuda_stats()
|
self.cache._print_cuda_stats()
|
||||||
|
|
||||||
except:
|
except:
|
||||||
self.cache_entry.unlock()
|
self.cache_entry.unlock()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
# TODO: not fully understand
|
# TODO: not fully understand
|
||||||
# in the event that the caller wants the model in RAM, we
|
# in the event that the caller wants the model in RAM, we
|
||||||
# move it into CPU if it is in GPU and not locked
|
# move it into CPU if it is in GPU and not locked
|
||||||
@ -259,7 +259,7 @@ class ModelCache(object):
|
|||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type, value, traceback):
|
||||||
if not hasattr(self.model, 'to'):
|
if not hasattr(self.model, "to"):
|
||||||
return
|
return
|
||||||
|
|
||||||
self.cache_entry.unlock()
|
self.cache_entry.unlock()
|
||||||
@ -277,11 +277,11 @@ class ModelCache(object):
|
|||||||
self,
|
self,
|
||||||
model_path: Union[str, Path],
|
model_path: Union[str, Path],
|
||||||
) -> str:
|
) -> str:
|
||||||
'''
|
"""
|
||||||
Given the HF repo id or path to a model on disk, returns a unique
|
Given the HF repo id or path to a model on disk, returns a unique
|
||||||
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
||||||
:param model_path: Path to model file/directory on disk.
|
:param model_path: Path to model file/directory on disk.
|
||||||
'''
|
"""
|
||||||
return self._local_model_hash(model_path)
|
return self._local_model_hash(model_path)
|
||||||
|
|
||||||
def cache_size(self) -> float:
|
def cache_size(self) -> float:
|
||||||
@ -290,7 +290,7 @@ class ModelCache(object):
|
|||||||
return current_cache_size / GIG
|
return current_cache_size / GIG
|
||||||
|
|
||||||
def _has_cuda(self) -> bool:
|
def _has_cuda(self) -> bool:
|
||||||
return self.execution_device.type == 'cuda'
|
return self.execution_device.type == "cuda"
|
||||||
|
|
||||||
def _print_cuda_stats(self):
|
def _print_cuda_stats(self):
|
||||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||||
@ -306,18 +306,21 @@ class ModelCache(object):
|
|||||||
if model_info.locked:
|
if model_info.locked:
|
||||||
locked_models += 1
|
locked_models += 1
|
||||||
|
|
||||||
self.logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}")
|
self.logger.debug(
|
||||||
|
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}"
|
||||||
|
)
|
||||||
|
|
||||||
def _make_cache_room(self, model_size):
|
def _make_cache_room(self, model_size):
|
||||||
# calculate how much memory this model will require
|
# calculate how much memory this model will require
|
||||||
#multiplier = 2 if self.precision==torch.float32 else 1
|
# multiplier = 2 if self.precision==torch.float32 else 1
|
||||||
bytes_needed = model_size
|
bytes_needed = model_size
|
||||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||||
current_size = sum([m.size for m in self._cached_models.values()])
|
current_size = sum([m.size for m in self._cached_models.values()])
|
||||||
|
|
||||||
if current_size + bytes_needed > maximum_size:
|
if current_size + bytes_needed > maximum_size:
|
||||||
self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
|
self.logger.debug(
|
||||||
|
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB"
|
||||||
|
)
|
||||||
|
|
||||||
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
|
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
|
||||||
|
|
||||||
@ -339,7 +342,7 @@ class ModelCache(object):
|
|||||||
with suppress(RuntimeError):
|
with suppress(RuntimeError):
|
||||||
referrer.clear()
|
referrer.clear()
|
||||||
cleared = True
|
cleared = True
|
||||||
#break
|
# break
|
||||||
|
|
||||||
# repeat if referrers changes(due to frame clear), else exit loop
|
# repeat if referrers changes(due to frame clear), else exit loop
|
||||||
if cleared:
|
if cleared:
|
||||||
@ -348,13 +351,17 @@ class ModelCache(object):
|
|||||||
break
|
break
|
||||||
|
|
||||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||||
self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")
|
self.logger.debug(
|
||||||
|
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}"
|
||||||
|
)
|
||||||
|
|
||||||
# 2 refs:
|
# 2 refs:
|
||||||
# 1 from cache_entry
|
# 1 from cache_entry
|
||||||
# 1 from getrefcount function
|
# 1 from getrefcount function
|
||||||
if not cache_entry.locked and refs <= 2:
|
if not cache_entry.locked and refs <= 2:
|
||||||
self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
|
self.logger.debug(
|
||||||
|
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||||
|
)
|
||||||
current_size -= cache_entry.size
|
current_size -= cache_entry.size
|
||||||
del self._cache_stack[pos]
|
del self._cache_stack[pos]
|
||||||
del self._cached_models[model_key]
|
del self._cached_models[model_key]
|
||||||
@ -368,38 +375,36 @@ class ModelCache(object):
|
|||||||
|
|
||||||
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
||||||
|
|
||||||
def _offload_unlocked_models(self, size_needed: int=0):
|
def _offload_unlocked_models(self, size_needed: int = 0):
|
||||||
reserved = self.max_vram_cache_size * GIG
|
reserved = self.max_vram_cache_size * GIG
|
||||||
vram_in_use = torch.cuda.memory_allocated()
|
vram_in_use = torch.cuda.memory_allocated()
|
||||||
self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB')
|
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
|
||||||
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x:x[1].size):
|
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||||
if vram_in_use <= reserved:
|
if vram_in_use <= reserved:
|
||||||
break
|
break
|
||||||
if not cache_entry.locked and cache_entry.loaded:
|
if not cache_entry.locked and cache_entry.loaded:
|
||||||
self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
|
self.logger.debug(f"Offloading {model_key} from {self.execution_device} into {self.storage_device}")
|
||||||
with VRAMUsage() as mem:
|
with VRAMUsage() as mem:
|
||||||
cache_entry.model.to(self.storage_device)
|
cache_entry.model.to(self.storage_device)
|
||||||
self.logger.debug(f'GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB')
|
self.logger.debug(f"GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB")
|
||||||
vram_in_use += mem.vram_used # note vram_used is negative
|
vram_in_use += mem.vram_used # note vram_used is negative
|
||||||
self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB')
|
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
|
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
|
||||||
sha = hashlib.sha256()
|
sha = hashlib.sha256()
|
||||||
path = Path(model_path)
|
path = Path(model_path)
|
||||||
|
|
||||||
hashpath = path / "checksum.sha256"
|
hashpath = path / "checksum.sha256"
|
||||||
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
|
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
|
||||||
with open(hashpath) as f:
|
with open(hashpath) as f:
|
||||||
hash = f.read()
|
hash = f.read()
|
||||||
return hash
|
return hash
|
||||||
|
|
||||||
self.logger.debug(f'computing hash of model {path.name}')
|
self.logger.debug(f"computing hash of model {path.name}")
|
||||||
for file in list(path.rglob("*.ckpt")) \
|
for file in list(path.rglob("*.ckpt")) + list(path.rglob("*.safetensors")) + list(path.rglob("*.pth")):
|
||||||
+ list(path.rglob("*.safetensors")) \
|
|
||||||
+ list(path.rglob("*.pth")):
|
|
||||||
with open(file, "rb") as f:
|
with open(file, "rb") as f:
|
||||||
while chunk := f.read(self.sha_chunksize):
|
while chunk := f.read(self.sha_chunksize):
|
||||||
sha.update(chunk)
|
sha.update(chunk)
|
||||||
@ -408,11 +413,12 @@ class ModelCache(object):
|
|||||||
f.write(hash)
|
f.write(hash)
|
||||||
return hash
|
return hash
|
||||||
|
|
||||||
|
|
||||||
class VRAMUsage(object):
|
class VRAMUsage(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.vram = None
|
self.vram = None
|
||||||
self.vram_used = 0
|
self.vram_used = 0
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.vram = torch.cuda.memory_allocated()
|
self.vram = torch.cuda.memory_allocated()
|
||||||
return self
|
return self
|
||||||
|
@ -249,20 +249,26 @@ from invokeai.backend.util import CUDA_DEVICE, Chdir
from .model_cache import ModelCache, ModelLocker
from .model_search import ModelSearch
from .models import (
BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES,
BaseModelType,
ModelType,
SubModelType,
ModelError,
SchedulerPredictionType,
MODEL_CLASSES,
ModelConfigBase,
ModelNotFoundException, InvalidModelException,
ModelNotFoundException,
InvalidModelException,
DuplicateModelException,
)

# We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help
# reduce confusion.
CONFIG_FILE_VERSION='3.0.0'
CONFIG_FILE_VERSION = "3.0.0"


@dataclass
class ModelInfo():
class ModelInfo:
context: ModelLocker
name: str
base_model: BaseModelType
@ -275,20 +281,24 @@ class ModelInfo():
def __enter__(self):
return self.context.__enter__()

def __exit__(self,*args, **kwargs):
def __exit__(self, *args, **kwargs):
self.context.__exit__(*args, **kwargs)


class AddModelResult(BaseModel):
name: str = Field(description="The name of the model after installation")
model_type: ModelType = Field(description="The type of model")
base_model: BaseModelType = Field(description="The base model")
config: ModelConfigBase = Field(description="The configuration of the model")


MAX_CACHE_SIZE = 6.0 # GB


class ConfigMeta(BaseModel):
version: str


class ModelManager(object):
"""
High-level interface to model management.
@ -315,12 +325,12 @@ class ModelManager(object):
if isinstance(config, (str, Path)):
self.config_path = Path(config)
if not self.config_path.exists():
logger.warning(f'The file {self.config_path} was not found. Initializing a new file')
logger.warning(f"The file {self.config_path} was not found. Initializing a new file")
self.initialize_model_config(self.config_path)
config = OmegaConf.load(self.config_path)

elif not isinstance(config, DictConfig):
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
raise ValueError("config argument must be an OmegaConf object, a Path or a string")

self.config_meta = ConfigMeta(**config.pop("__metadata__"))
# TODO: metadata not found
@ -330,11 +340,11 @@ class ModelManager(object):
self.logger = logger
self.cache = ModelCache(
max_cache_size=max_cache_size,
max_vram_cache_size = self.app_config.max_vram_cache_size,
max_vram_cache_size=self.app_config.max_vram_cache_size,
execution_device = device_type,
execution_device=device_type,
precision = precision,
precision=precision,
sequential_offload = sequential_offload,
sequential_offload=sequential_offload,
logger = logger,
logger=logger,
)

self._read_models(config)
@ -348,7 +358,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
self.models = dict()
|
self.models = dict()
|
||||||
for model_key, model_config in config.items():
|
for model_key, model_config in config.items():
|
||||||
if model_key.startswith('_'):
|
if model_key.startswith("_"):
|
||||||
continue
|
continue
|
||||||
model_name, base_model, model_type = self.parse_key(model_key)
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
@ -395,7 +405,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
|
def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
|
||||||
base_model_str, model_type_str, model_name = model_key.split('/', 2)
|
base_model_str, model_type_str, model_name = model_key.split("/", 2)
|
||||||
try:
|
try:
|
||||||
model_type = ModelType(model_type_str)
|
model_type = ModelType(model_type_str)
|
||||||
except:
|
except:
|
||||||
@ -414,20 +424,16 @@ class ModelManager(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def initialize_model_config(cls, config_path: Path):
|
def initialize_model_config(cls, config_path: Path):
|
||||||
"""Create empty config file"""
|
"""Create empty config file"""
|
||||||
with open(config_path,'w') as yaml_file:
|
with open(config_path, "w") as yaml_file:
|
||||||
yaml_file.write(yaml.dump({'__metadata__':
|
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
||||||
{'version':'3.0.0'}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel_type: Optional[SubModelType] = None
|
submodel_type: Optional[SubModelType] = None,
|
||||||
)->ModelInfo:
|
) -> ModelInfo:
|
||||||
"""Given a model named identified in models.yaml, return
|
"""Given a model named identified in models.yaml, return
|
||||||
an ModelInfo object describing it.
|
an ModelInfo object describing it.
|
||||||
:param model_name: symbolic name of the model in models.yaml
|
:param model_name: symbolic name of the model in models.yaml
|
||||||
@ -451,7 +457,7 @@ class ModelManager(object):
|
|||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
if model_class.save_to_config:
|
if model_class.save_to_config:
|
||||||
self.models[model_key].error = ModelError.NotFound
|
self.models[model_key].error = ModelError.NotFound
|
||||||
raise Exception(f"Files for model \"{model_key}\" not found")
|
raise Exception(f'Files for model "{model_key}" not found')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.models.pop(model_key, None)
|
self.models.pop(model_key, None)
|
||||||
@ -473,7 +479,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
model_path = model_class.convert_if_required(
|
model_path = model_class.convert_if_required(
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_path=str(model_path), # TODO: refactor str/Path types logic
|
model_path=str(model_path), # TODO: refactor str/Path types logic
|
||||||
output_path=dst_convert_path,
|
output_path=dst_convert_path,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
)
|
)
|
||||||
@ -490,17 +496,17 @@ class ModelManager(object):
|
|||||||
self.cache_keys[model_key] = set()
|
self.cache_keys[model_key] = set()
|
||||||
self.cache_keys[model_key].add(model_context.key)
|
self.cache_keys[model_key].add(model_context.key)
|
||||||
|
|
||||||
model_hash = "<NO_HASH>" # TODO:
|
model_hash = "<NO_HASH>" # TODO:
|
||||||
|
|
||||||
return ModelInfo(
|
return ModelInfo(
|
||||||
context = model_context,
|
context=model_context,
|
||||||
name = model_name,
|
name=model_name,
|
||||||
base_model = base_model,
|
base_model=base_model,
|
||||||
type = submodel_type or model_type,
|
type=submodel_type or model_type,
|
||||||
hash = model_hash,
|
hash=model_hash,
|
||||||
location = model_path, # TODO:
|
location=model_path, # TODO:
|
||||||
precision = self.cache.precision,
|
precision=self.cache.precision,
|
||||||
_cache = self.cache,
|
_cache=self.cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
def model_info(
|
def model_info(
|
||||||
@ -516,7 +522,7 @@ class ModelManager(object):
|
|||||||
if model_key in self.models:
|
if model_key in self.models:
|
||||||
return self.models[model_key].dict(exclude_defaults=True)
|
return self.models[model_key].dict(exclude_defaults=True)
|
||||||
else:
|
else:
|
||||||
return None # TODO: None or empty dict on not found
|
return None # TODO: None or empty dict on not found
|
||||||
|
|
||||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||||
"""
|
"""
|
||||||
@ -526,16 +532,16 @@ class ModelManager(object):
|
|||||||
return [(self.parse_key(x)) for x in self.models.keys()]
|
return [(self.parse_key(x)) for x in self.models.keys()]
|
||||||
|
|
||||||
def list_model(
|
def list_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Returns a dict describing one installed model, using
|
Returns a dict describing one installed model, using
|
||||||
the combined format of the list_models() method.
|
the combined format of the list_models() method.
|
||||||
"""
|
"""
|
||||||
models = self.list_models(base_model,model_type,model_name)
|
models = self.list_models(base_model, model_type, model_name)
|
||||||
return models[0] if models else None
|
return models[0] if models else None
|
||||||
|
|
||||||
def list_models(
|
def list_models(
|
||||||
@ -548,13 +554,17 @@ class ModelManager(object):
|
|||||||
Return a list of models.
|
Return a list of models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
|
model_keys = (
|
||||||
|
[self.create_key(model_name, base_model, model_type)]
|
||||||
|
if model_name
|
||||||
|
else sorted(self.models, key=str.casefold)
|
||||||
|
)
|
||||||
models = []
|
models = []
|
||||||
for model_key in model_keys:
|
for model_key in model_keys:
|
||||||
model_config = self.models.get(model_key)
|
model_config = self.models.get(model_key)
|
||||||
if not model_config:
|
if not model_config:
|
||||||
self.logger.error(f'Unknown model {model_name}')
|
self.logger.error(f"Unknown model {model_name}")
|
||||||
raise ModelNotFoundException(f'Unknown model {model_name}')
|
raise ModelNotFoundException(f"Unknown model {model_name}")
|
||||||
|
|
||||||
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||||
if base_model is not None and cur_base_model != base_model:
|
if base_model is not None and cur_base_model != base_model:
|
||||||
@ -571,8 +581,8 @@ class ModelManager(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# expose paths as absolute to help web UI
|
# expose paths as absolute to help web UI
|
||||||
if path := model_dict.get('path'):
|
if path := model_dict.get("path"):
|
||||||
model_dict['path'] = str(self.app_config.root_path / path)
|
model_dict["path"] = str(self.app_config.root_path / path)
|
||||||
models.append(model_dict)
|
models.append(model_dict)
|
||||||
|
|
||||||
return models
|
return models
|
||||||
@ -641,15 +651,15 @@ class ModelManager(object):
|
|||||||
model_info().
|
model_info().
|
||||||
"""
|
"""
|
||||||
# relativize paths as they go in - this makes it easier to move the root directory around
|
# relativize paths as they go in - this makes it easier to move the root directory around
|
||||||
if path := model_attributes.get('path'):
|
if path := model_attributes.get("path"):
|
||||||
if Path(path).is_relative_to(self.app_config.root_path):
|
if Path(path).is_relative_to(self.app_config.root_path):
|
||||||
model_attributes['path'] = str(Path(path).relative_to(self.app_config.root_path))
|
model_attributes["path"] = str(Path(path).relative_to(self.app_config.root_path))
|
||||||
|
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
model_config = model_class.create_config(**model_attributes)
|
model_config = model_class.create_config(**model_attributes)
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
|
|
||||||
if model_key in self.models and not clobber:
|
if model_key in self.models and not clobber:
|
||||||
raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
|
raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
|
||||||
|
|
||||||
old_model = self.models.pop(model_key, None)
|
old_model = self.models.pop(model_key, None)
|
||||||
@ -675,23 +685,23 @@ class ModelManager(object):
|
|||||||
self.commit()
|
self.commit()
|
||||||
|
|
||||||
return AddModelResult(
|
return AddModelResult(
|
||||||
name = model_name,
|
name=model_name,
|
||||||
model_type = model_type,
|
model_type=model_type,
|
||||||
base_model = base_model,
|
base_model=base_model,
|
||||||
config = model_config,
|
config=model_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def rename_model(
|
def rename_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
new_name: str = None,
|
new_name: str = None,
|
||||||
new_base: BaseModelType = None,
|
new_base: BaseModelType = None,
|
||||||
):
|
):
|
||||||
'''
|
"""
|
||||||
Rename or rebase a model.
|
Rename or rebase a model.
|
||||||
'''
|
"""
|
||||||
if new_name is None and new_base is None:
|
if new_name is None and new_base is None:
|
||||||
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
|
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
|
||||||
return
|
return
|
||||||
@ -710,7 +720,13 @@ class ModelManager(object):
|
|||||||
|
|
||||||
# if this is a model file/directory that we manage ourselves, we need to move it
|
# if this is a model file/directory that we manage ourselves, we need to move it
|
||||||
if old_path.is_relative_to(self.app_config.models_path):
|
if old_path.is_relative_to(self.app_config.models_path):
|
||||||
new_path = self.app_config.root_path / 'models' / BaseModelType(new_base).value / ModelType(model_type).value / new_name
|
new_path = (
|
||||||
|
self.app_config.root_path
|
||||||
|
/ "models"
|
||||||
|
/ BaseModelType(new_base).value
|
||||||
|
/ ModelType(model_type).value
|
||||||
|
/ new_name
|
||||||
|
)
|
||||||
move(old_path, new_path)
|
move(old_path, new_path)
|
||||||
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
|
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
|
||||||
|
|
||||||
@ -726,18 +742,18 @@ class ModelManager(object):
|
|||||||
for cache_id in cache_ids:
|
for cache_id in cache_ids:
|
||||||
self.cache.uncache_model(cache_id)
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
self.models.pop(model_key, None) # delete
|
self.models.pop(model_key, None) # delete
|
||||||
self.models[new_key] = model_cfg
|
self.models[new_key] = model_cfg
|
||||||
self.commit()
|
self.commit()
|
||||||
|
|
||||||
def convert_model (
|
def convert_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
model_type: Union[ModelType.Main, ModelType.Vae],
|
||||||
dest_directory: Optional[Path]=None,
|
dest_directory: Optional[Path] = None,
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
'''
|
"""
|
||||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
version and deleting the original checkpoint file if it is in the models
|
version and deleting the original checkpoint file if it is in the models
|
||||||
directory.
|
directory.
|
||||||
@ -746,7 +762,7 @@ class ModelManager(object):
|
|||||||
:param model_type: Type of model ['vae' or 'main']
|
:param model_type: Type of model ['vae' or 'main']
|
||||||
|
|
||||||
This will raise a ValueError unless the model is a checkpoint.
|
This will raise a ValueError unless the model is a checkpoint.
|
||||||
'''
|
"""
|
||||||
info = self.model_info(model_name, base_model, model_type)
|
info = self.model_info(model_name, base_model, model_type)
|
||||||
if info["model_format"] != "checkpoint":
|
if info["model_format"] != "checkpoint":
|
||||||
raise ValueError(f"not a checkpoint format model: {model_name}")
|
raise ValueError(f"not a checkpoint format model: {model_name}")
|
||||||
@ -754,27 +770,32 @@ class ModelManager(object):
|
|||||||
# We are taking advantage of a side effect of get_model() that converts check points
|
# We are taking advantage of a side effect of get_model() that converts check points
|
||||||
# into cached diffusers directories stored at `location`. It doesn't matter
|
# into cached diffusers directories stored at `location`. It doesn't matter
|
||||||
# what submodeltype we request here, so we get the smallest.
|
# what submodeltype we request here, so we get the smallest.
|
||||||
submodel = {"submodel_type": SubModelType.Scheduler} if model_type==ModelType.Main else {}
|
submodel = {"submodel_type": SubModelType.Scheduler} if model_type == ModelType.Main else {}
|
||||||
model = self.get_model(model_name,
|
model = self.get_model(
|
||||||
base_model,
|
model_name,
|
||||||
model_type,
|
base_model,
|
||||||
**submodel,
|
model_type,
|
||||||
)
|
**submodel,
|
||||||
|
)
|
||||||
checkpoint_path = self.app_config.root_path / info["path"]
|
checkpoint_path = self.app_config.root_path / info["path"]
|
||||||
old_diffusers_path = self.app_config.models_path / model.location
|
old_diffusers_path = self.app_config.models_path / model.location
|
||||||
new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name
|
new_diffusers_path = (
|
||||||
|
dest_directory or self.app_config.models_path / base_model.value / model_type.value
|
||||||
|
) / model_name
|
||||||
if new_diffusers_path.exists():
|
if new_diffusers_path.exists():
|
||||||
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
move(old_diffusers_path,new_diffusers_path)
|
move(old_diffusers_path, new_diffusers_path)
|
||||||
info["model_format"] = "diffusers"
|
info["model_format"] = "diffusers"
|
||||||
info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path))
|
info["path"] = (
|
||||||
info.pop('config')
|
str(new_diffusers_path)
|
||||||
|
if dest_directory
|
||||||
|
else str(new_diffusers_path.relative_to(self.app_config.root_path))
|
||||||
|
)
|
||||||
|
info.pop("config")
|
||||||
|
|
||||||
result = self.add_model(model_name, base_model, model_type,
|
result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True)
|
||||||
model_attributes = info,
|
|
||||||
clobber=True)
|
|
||||||
except:
|
except:
|
||||||
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
||||||
rmtree(new_diffusers_path)
|
rmtree(new_diffusers_path)
|
||||||
@ -798,15 +819,12 @@ class ModelManager(object):
|
|||||||
found_models = []
|
found_models = []
|
||||||
for file in files:
|
for file in files:
|
||||||
location = str(file.resolve()).replace("\\", "/")
|
location = str(file.resolve()).replace("\\", "/")
|
||||||
if (
|
if "model.safetensors" not in location and "diffusion_pytorch_model.safetensors" not in location:
|
||||||
"model.safetensors" not in location
|
|
||||||
and "diffusion_pytorch_model.safetensors" not in location
|
|
||||||
):
|
|
||||||
found_models.append({"name": file.stem, "location": location})
|
found_models.append({"name": file.stem, "location": location})
|
||||||
|
|
||||||
return search_folder, found_models
|
return search_folder, found_models
|
||||||
|
|
||||||
def commit(self, conf_file: Path=None) -> None:
|
def commit(self, conf_file: Path = None) -> None:
|
||||||
"""
|
"""
|
||||||
Write current configuration out to the indicated file.
|
Write current configuration out to the indicated file.
|
||||||
"""
|
"""
|
||||||
@ -824,7 +842,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
yaml_str = OmegaConf.to_yaml(data_to_save)
|
yaml_str = OmegaConf.to_yaml(data_to_save)
|
||||||
config_file_path = conf_file or self.config_path
|
config_file_path = conf_file or self.config_path
|
||||||
assert config_file_path is not None,'no config file path to write to'
|
assert config_file_path is not None, "no config file path to write to"
|
||||||
config_file_path = self.app_config.root_path / config_file_path
|
config_file_path = self.app_config.root_path / config_file_path
|
||||||
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
||||||
try:
|
try:
|
||||||
@ -857,11 +875,10 @@ class ModelManager(object):
|
|||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: Optional[ModelType] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
loaded_files = set()
|
loaded_files = set()
|
||||||
new_models_found = False
|
new_models_found = False
|
||||||
|
|
||||||
self.logger.info(f'Scanning {self.app_config.models_path} for new models')
|
self.logger.info(f"Scanning {self.app_config.models_path} for new models")
|
||||||
with Chdir(self.app_config.root_path):
|
with Chdir(self.app_config.root_path):
|
||||||
for model_key, model_config in list(self.models.items()):
|
for model_key, model_config in list(self.models.items()):
|
||||||
model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||||
@ -887,10 +904,10 @@ class ModelManager(object):
|
|||||||
models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value
|
models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value
|
||||||
|
|
||||||
if not models_dir.exists():
|
if not models_dir.exists():
|
||||||
continue # TODO: or create all folders?
|
continue # TODO: or create all folders?
|
||||||
|
|
||||||
for model_path in models_dir.iterdir():
|
for model_path in models_dir.iterdir():
|
||||||
if model_path not in loaded_files: # TODO: check
|
if model_path not in loaded_files: # TODO: check
|
||||||
model_name = model_path.name if model_path.is_dir() else model_path.stem
|
model_name = model_path.name if model_path.is_dir() else model_path.stem
|
||||||
model_key = self.create_key(model_name, cur_base_model, cur_model_type)
|
model_key = self.create_key(model_name, cur_base_model, cur_model_type)
|
||||||
|
|
||||||
@ -900,7 +917,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
if model_path.is_relative_to(self.app_config.root_path):
|
if model_path.is_relative_to(self.app_config.root_path):
|
||||||
model_path = model_path.relative_to(self.app_config.root_path)
|
model_path = model_path.relative_to(self.app_config.root_path)
|
||||||
|
|
||||||
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
||||||
self.models[model_key] = model_config
|
self.models[model_key] = model_config
|
||||||
new_models_found = True
|
new_models_found = True
|
||||||
@ -916,11 +933,10 @@ class ModelManager(object):
|
|||||||
if (new_models_found or imported_models) and self.config_path:
|
if (new_models_found or imported_models) and self.config_path:
|
||||||
self.commit()
|
self.commit()
|
||||||
|
|
||||||
|
def autoimport(self) -> Dict[str, AddModelResult]:
|
||||||
def autoimport(self)->Dict[str, AddModelResult]:
|
"""
|
||||||
'''
|
|
||||||
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
||||||
'''
|
"""
|
||||||
# avoid circular import
|
# avoid circular import
|
||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
||||||
@ -939,7 +955,9 @@ class ModelManager(object):
|
|||||||
self.new_models_found.update(self.installer.heuristic_import(model))
|
self.new_models_found.update(self.installer.heuristic_import(model))
|
||||||
|
|
||||||
def on_search_completed(self):
|
def on_search_completed(self):
|
||||||
self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
|
self.logger.info(
|
||||||
|
f"Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models"
|
||||||
|
)
|
||||||
|
|
||||||
def models_found(self):
|
def models_found(self):
|
||||||
return self.new_models_found
|
return self.new_models_found
|
||||||
@ -949,31 +967,37 @@ class ModelManager(object):
|
|||||||
# LS: hacky
|
# LS: hacky
|
||||||
# Patch in the SD VAE from core so that it is available for use by the UI
|
# Patch in the SD VAE from core so that it is available for use by the UI
|
||||||
try:
|
try:
|
||||||
self.heuristic_import({config.root_path / 'models/core/convert/sd-vae-ft-mse'})
|
self.heuristic_import({config.root_path / "models/core/convert/sd-vae-ft-mse"})
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
installer = ModelInstall(config = self.app_config,
|
installer = ModelInstall(
|
||||||
model_manager = self,
|
config=self.app_config,
|
||||||
prediction_type_helper = ask_user_for_prediction_type,
|
model_manager=self,
|
||||||
)
|
prediction_type_helper=ask_user_for_prediction_type,
|
||||||
known_paths = {config.root_path / x['path'] for x in self.list_models()}
|
)
|
||||||
directories = {config.root_path / x for x in [config.autoimport_dir,
|
known_paths = {config.root_path / x["path"] for x in self.list_models()}
|
||||||
config.lora_dir,
|
directories = {
|
||||||
config.embedding_dir,
|
config.root_path / x
|
||||||
config.controlnet_dir,
|
for x in [
|
||||||
] if x
|
config.autoimport_dir,
|
||||||
}
|
config.lora_dir,
|
||||||
|
config.embedding_dir,
|
||||||
|
config.controlnet_dir,
|
||||||
|
]
|
||||||
|
if x
|
||||||
|
}
|
||||||
scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
|
scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
|
||||||
scanner.search()
|
scanner.search()
|
||||||
|
|
||||||
return scanner.models_found()
|
return scanner.models_found()
|
||||||
|
|
||||||
def heuristic_import(self,
|
def heuristic_import(
|
||||||
items_to_import: Set[str],
|
self,
|
||||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
items_to_import: Set[str],
|
||||||
)->Dict[str, AddModelResult]:
|
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
||||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
) -> Dict[str, AddModelResult]:
|
||||||
|
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
successfully imported items.
|
successfully imported items.
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||||
@ -992,14 +1016,15 @@ class ModelManager(object):
|
|||||||
May return the following exceptions:
|
May return the following exceptions:
|
||||||
- ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL
|
- ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL
|
||||||
- ValueError - a corresponding model already exists
|
- ValueError - a corresponding model already exists
|
||||||
'''
|
"""
|
||||||
# avoid circular import here
|
# avoid circular import here
|
||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
|
|
||||||
successfully_installed = dict()
|
successfully_installed = dict()
|
||||||
|
|
||||||
installer = ModelInstall(config = self.app_config,
|
installer = ModelInstall(
|
||||||
prediction_type_helper = prediction_type_helper,
|
config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self
|
||||||
model_manager = self)
|
)
|
||||||
for thing in items_to_import:
|
for thing in items_to_import:
|
||||||
installed = installer.heuristic_import(thing)
|
installed = installer.heuristic_import(thing)
|
||||||
successfully_installed.update(installed)
|
successfully_installed.update(installed)
|
||||||
|
@ -17,23 +17,25 @@ import invokeai.backend.util.logging as logger

from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult


class MergeInterpolationMethod(str, Enum):
WeightedSum = "weighted_sum"
Sigmoid = "sigmoid"
InvSigmoid = "inv_sigmoid"
AddDifference = "add_difference"


class ModelMerger(object):
def __init__(self, manager: ModelManager):
self.manager = manager

def merge_diffusion_models(
self,
model_paths: List[Path],
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
@ -58,24 +60,23 @@ class ModelMerger(object):
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_paths,
alpha=alpha,
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
interp=interp.value if interp else None, # diffusers API treats None as "weighted sum"
force=force,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe

def merge_diffusion_models_and_save (
def merge_diffusion_models_and_save(
self,
model_names: List[str],
base_model: Union[BaseModelType,str],
base_model: Union[BaseModelType, str],
merged_model_name: str,
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
merge_dest_directory: Optional[Path] = None,
**kwargs,
) -> AddModelResult:
"""
:param models: up to three models, designated by their InvokeAI models.yaml model name
@ -94,39 +95,45 @@ class ModelMerger(object):
config = self.manager.app_config
base_model = BaseModelType(base_model)
vae = None

for mod in model_names:
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
assert (
info["model_format"] == "diffusers"
), f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
assert len(model_names) <= 2 or \
interp==MergeInterpolationMethod.AddDifference, "When merging three models, only the 'add_difference' merge method is supported"
assert (
len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference
), "When merging three models, only the 'add_difference' merge method is supported"
# pick up the first model's vae
if mod == model_names[0]:
vae = info.get("vae")
model_paths.extend([config.root_path / info["path"]])

merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
logger.debug(f'interp = {interp}, merge_method={merge_method}')
logger.debug(f"interp = {interp}, merge_method={merge_method}")
merged_pipe = self.merge_diffusion_models(
model_paths, alpha, merge_method, force, **kwargs
)
merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, **kwargs)
dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value
dump_path = (
Path(merge_dest_directory)
if merge_dest_directory
else config.models_path / base_model.value / ModelType.Main.value
)
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name

merged_pipe.save_pretrained(dump_path, safe_serialization=1)
attributes = dict(
path = str(dump_path),
path=str(dump_path),
description = f"Merge of models {', '.join(model_names)}",
description=f"Merge of models {', '.join(model_names)}",
model_format = "diffusers",
model_format="diffusers",
variant = ModelVariantType.Normal.value,
variant=ModelVariantType.Normal.value,
vae = vae,
vae=vae,
)
return self.manager.add_model(merged_model_name,
base_model = base_model,
model_type = ModelType.Main,
model_attributes = attributes,
clobber = True
)
return self.manager.add_model(
merged_model_name,
base_model=base_model,
model_type=ModelType.Main,
model_attributes=attributes,
clobber=True,
)
@ -10,12 +10,16 @@ from typing import Callable, Literal, Union, Dict, Optional
|
|||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
from .models import (
|
from .models import (
|
||||||
BaseModelType, ModelType, ModelVariantType,
|
BaseModelType,
|
||||||
SchedulerPredictionType, SilenceWarnings,
|
ModelType,
|
||||||
InvalidModelException
|
ModelVariantType,
|
||||||
|
SchedulerPredictionType,
|
||||||
|
SilenceWarnings,
|
||||||
|
InvalidModelException,
|
||||||
)
|
)
|
||||||
from .models.base import read_checkpoint_meta
|
from .models.base import read_checkpoint_meta
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelProbeInfo(object):
|
class ModelProbeInfo(object):
|
||||||
model_type: ModelType
|
model_type: ModelType
|
||||||
@ -23,70 +27,74 @@ class ModelProbeInfo(object):
|
|||||||
variant_type: ModelVariantType
|
variant_type: ModelVariantType
|
||||||
prediction_type: SchedulerPredictionType
|
prediction_type: SchedulerPredictionType
|
||||||
upcast_attention: bool
|
upcast_attention: bool
|
||||||
format: Literal['diffusers','checkpoint', 'lycoris']
|
format: Literal["diffusers", "checkpoint", "lycoris"]
|
||||||
image_size: int
|
image_size: int
|
||||||
|
|
||||||
|
|
||||||
class ProbeBase(object):
|
class ProbeBase(object):
|
||||||
'''forward declaration'''
|
"""forward declaration"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ModelProbe(object):
|
class ModelProbe(object):
|
||||||
|
|
||||||
PROBES = {
|
PROBES = {
|
||||||
'diffusers': { },
|
"diffusers": {},
|
||||||
'checkpoint': { },
|
"checkpoint": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
CLASS2TYPE = {
|
CLASS2TYPE = {
|
||||||
'StableDiffusionPipeline' : ModelType.Main,
|
"StableDiffusionPipeline": ModelType.Main,
|
||||||
'StableDiffusionInpaintPipeline' : ModelType.Main,
|
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||||
'StableDiffusionXLPipeline' : ModelType.Main,
|
"StableDiffusionXLPipeline": ModelType.Main,
|
||||||
'StableDiffusionXLImg2ImgPipeline' : ModelType.Main,
|
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||||
'AutoencoderKL' : ModelType.Vae,
|
"AutoencoderKL": ModelType.Vae,
|
||||||
'ControlNetModel' : ModelType.ControlNet,
|
"ControlNetModel": ModelType.ControlNet,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_probe(cls,
|
def register_probe(cls, format: Literal["diffusers", "checkpoint"], model_type: ModelType, probe_class: ProbeBase):
|
||||||
format: Literal['diffusers','checkpoint'],
|
|
||||||
model_type: ModelType,
|
|
||||||
probe_class: ProbeBase):
|
|
||||||
cls.PROBES[format][model_type] = probe_class
|
cls.PROBES[format][model_type] = probe_class
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def heuristic_probe(cls,
|
def heuristic_probe(
|
||||||
model: Union[Dict, ModelMixin, Path],
|
cls,
|
||||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
model: Union[Dict, ModelMixin, Path],
|
||||||
)->ModelProbeInfo:
|
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
||||||
if isinstance(model,Path):
|
) -> ModelProbeInfo:
|
||||||
return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper)
|
if isinstance(model, Path):
|
||||||
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
|
return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
|
||||||
|
elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
|
||||||
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
||||||
else:
|
else:
|
||||||
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
|
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def probe(cls,
|
def probe(
|
||||||
model_path: Path,
|
cls,
|
||||||
model: Optional[Union[Dict, ModelMixin]] = None,
|
model_path: Path,
|
||||||
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]] = None)->ModelProbeInfo:
|
model: Optional[Union[Dict, ModelMixin]] = None,
|
||||||
'''
|
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||||
|
) -> ModelProbeInfo:
|
||||||
|
"""
|
||||||
Probe the model at model_path and return sufficient information about it
|
Probe the model at model_path and return sufficient information about it
|
||||||
to place it somewhere in the models directory hierarchy. If the model is
|
to place it somewhere in the models directory hierarchy. If the model is
|
||||||
already loaded into memory, you may provide it as model in order to avoid
|
already loaded into memory, you may provide it as model in order to avoid
|
||||||
opening it a second time. The prediction_type_helper callable is a function that receives
|
opening it a second time. The prediction_type_helper callable is a function that receives
|
||||||
the path to the model and returns the SchedulerPredictionType. It is called to distinguish
|
the path to the model and returns the SchedulerPredictionType. It is called to distinguish
|
||||||
between V2-Base and V2-768 SD models.
|
between V2-Base and V2-768 SD models.
|
||||||
'''
|
"""
|
||||||
if model_path:
|
if model_path:
|
||||||
format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
|
format_type = "diffusers" if model_path.is_dir() else "checkpoint"
|
||||||
else:
|
else:
|
||||||
format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
|
format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"
|
||||||
model_info = None
|
model_info = None
|
||||||
try:
|
try:
|
||||||
model_type = cls.get_model_type_from_folder(model_path, model) \
|
model_type = (
|
||||||
if format_type == 'diffusers' \
|
cls.get_model_type_from_folder(model_path, model)
|
||||||
else cls.get_model_type_from_checkpoint(model_path, model)
|
if format_type == "diffusers"
|
||||||
|
else cls.get_model_type_from_checkpoint(model_path, model)
|
||||||
|
)
|
||||||
probe_class = cls.PROBES[format_type].get(model_type)
|
probe_class = cls.PROBES[format_type].get(model_type)
|
||||||
if not probe_class:
|
if not probe_class:
|
||||||
return None
|
return None
|
||||||
@ -96,17 +104,23 @@ class ModelProbe(object):
|
|||||||
prediction_type = probe.get_scheduler_prediction_type()
|
prediction_type = probe.get_scheduler_prediction_type()
|
||||||
format = probe.get_format()
|
format = probe.get_format()
|
||||||
model_info = ModelProbeInfo(
|
model_info = ModelProbeInfo(
|
||||||
model_type = model_type,
|
model_type=model_type,
|
||||||
base_type = base_type,
|
base_type=base_type,
|
||||||
variant_type = variant_type,
|
variant_type=variant_type,
|
||||||
prediction_type = prediction_type,
|
prediction_type=prediction_type,
|
||||||
upcast_attention = (base_type==BaseModelType.StableDiffusion2 \
|
upcast_attention=(
|
||||||
and prediction_type==SchedulerPredictionType.VPrediction),
|
base_type == BaseModelType.StableDiffusion2
|
||||||
format = format,
|
and prediction_type == SchedulerPredictionType.VPrediction
|
||||||
image_size = 1024 if (base_type in {BaseModelType.StableDiffusionXL,BaseModelType.StableDiffusionXLRefiner}) else \
|
),
|
||||||
768 if (base_type==BaseModelType.StableDiffusion2 \
|
format=format,
|
||||||
and prediction_type==SchedulerPredictionType.VPrediction ) else \
|
image_size=1024
|
||||||
512
|
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
|
||||||
|
else 768
|
||||||
|
if (
|
||||||
|
base_type == BaseModelType.StableDiffusion2
|
||||||
|
and prediction_type == SchedulerPredictionType.VPrediction
|
||||||
|
)
|
||||||
|
else 512,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise
|
raise
|
||||||
@ -115,7 +129,7 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
|
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
|
||||||
if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors','.pth'):
|
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if model_path.name == "learned_embeds.bin":
|
if model_path.name == "learned_embeds.bin":
|
||||||
@ -142,32 +156,32 @@ class ModelProbe(object):
|
|||||||
# diffusers-ti
|
# diffusers-ti
|
||||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
raise InvalidModelException(f"Unable to determine model type for {model_path}")
|
raise InvalidModelException(f"Unable to determine model type for {model_path}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
|
||||||
'''
|
"""
|
||||||
Get the model type of a hugging-face style folder.
|
Get the model type of a hugging-face style folder.
|
||||||
'''
|
"""
|
||||||
class_name = None
|
class_name = None
|
||||||
if model:
|
if model:
|
||||||
class_name = model.__class__.__name__
|
class_name = model.__class__.__name__
|
||||||
else:
|
else:
|
||||||
if (folder_path / 'learned_embeds.bin').exists():
|
if (folder_path / "learned_embeds.bin").exists():
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
if (folder_path / 'pytorch_lora_weights.bin').exists():
|
if (folder_path / "pytorch_lora_weights.bin").exists():
|
||||||
return ModelType.Lora
|
return ModelType.Lora
|
||||||
|
|
||||||
i = folder_path / 'model_index.json'
|
i = folder_path / "model_index.json"
|
||||||
c = folder_path / 'config.json'
|
c = folder_path / "config.json"
|
||||||
config_path = i if i.exists() else c if c.exists() else None
|
config_path = i if i.exists() else c if c.exists() else None
|
||||||
|
|
||||||
if config_path:
|
if config_path:
|
||||||
with open(config_path,'r') as file:
|
with open(config_path, "r") as file:
|
||||||
conf = json.load(file)
|
conf = json.load(file)
|
||||||
class_name = conf['_class_name']
|
class_name = conf["_class_name"]
|
||||||
|
|
||||||
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
||||||
return type
|
return type
|
||||||
@ -176,7 +190,7 @@ class ModelProbe(object):
|
|||||||
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
|
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
|
||||||
with SilenceWarnings():
|
with SilenceWarnings():
|
||||||
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
||||||
cls._scan_model(model_path, model_path)
|
cls._scan_model(model_path, model_path)
|
||||||
@ -186,55 +200,53 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _scan_model(cls, model_name, checkpoint):
|
def _scan_model(cls, model_name, checkpoint):
|
||||||
"""
|
"""
|
||||||
Apply picklescanner to the indicated checkpoint and issue a warning
|
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||||
and option to exit if an infected file is identified.
|
and option to exit if an infected file is identified.
|
||||||
"""
|
"""
|
||||||
# scan model
|
# scan model
|
||||||
scan_result = scan_file_path(checkpoint)
|
scan_result = scan_file_path(checkpoint)
|
||||||
if scan_result.infected_files != 0:
|
if scan_result.infected_files != 0:
|
||||||
raise "The model {model_name} is potentially infected by malware. Aborting import."
|
raise "The model {model_name} is potentially infected by malware. Aborting import."
|
||||||
|
|
||||||
|
|
||||||
###################################################
|
###################################################
|
||||||
# Checkpoint probing
|
# Checkpoint probing
|
||||||
###################################################
|
###################################################
|
||||||
class ProbeBase(object):
|
class ProbeBase(object):
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_variant_type(self)->ModelVariantType:
|
def get_variant_type(self) -> ModelVariantType:
|
||||||
pass
|
|
||||||
|
|
||||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_format(self)->str:
|
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_format(self) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class CheckpointProbeBase(ProbeBase):
|
class CheckpointProbeBase(ProbeBase):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
checkpoint_path: Path,
|
self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None
|
||||||
checkpoint: dict,
|
):
|
||||||
helper: Callable[[Path],SchedulerPredictionType] = None
|
|
||||||
)->BaseModelType:
|
|
||||||
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
|
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
|
||||||
self.checkpoint_path = checkpoint_path
|
self.checkpoint_path = checkpoint_path
|
||||||
self.helper = helper
|
self.helper = helper
|
||||||
|
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_format(self)->str:
|
def get_format(self) -> str:
|
||||||
return 'checkpoint'
|
return "checkpoint"
|
||||||
|
|
||||||
def get_variant_type(self)-> ModelVariantType:
|
def get_variant_type(self) -> ModelVariantType:
|
||||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
|
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
|
||||||
if model_type != ModelType.Main:
|
if model_type != ModelType.Main:
|
||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
state_dict = self.checkpoint.get('state_dict') or self.checkpoint
|
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||||
in_channels = state_dict[
|
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||||
"model.diffusion_model.input_blocks.0.0.weight"
|
|
||||||
].shape[1]
|
|
||||||
if in_channels == 9:
|
if in_channels == 9:
|
||||||
return ModelVariantType.Inpaint
|
return ModelVariantType.Inpaint
|
||||||
elif in_channels == 5:
|
elif in_channels == 5:
|
||||||
@ -242,18 +254,21 @@ class CheckpointProbeBase(ProbeBase):
|
|||||||
elif in_channels == 4:
|
elif in_channels == 4:
|
||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
else:
|
else:
|
||||||
raise InvalidModelException(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
|
raise InvalidModelException(
|
||||||
|
f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
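
The variant heuristic above keys entirely off the input-channel count of the UNet's first conv weight: 9 channels indicates an inpainting model and 4 a normal one; the 5-channel depth case sits in the lines elided by the hunk header and is assumed here. A hedged standalone restatement of that mapping:

# Standalone restatement of the in_channels heuristic; the state-dict key comes from
# the hunk above, and the 5 -> "depth" entry is an assumption for the elided branch.
IN_CHANNELS_TO_VARIANT = {9: "inpaint", 5: "depth", 4: "normal"}

def variant_from_state_dict(state_dict: dict) -> str:
    in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
    try:
        return IN_CHANNELS_TO_VARIANT[in_channels]
    except KeyError:
        raise ValueError(f"Cannot determine variant type (in_channels={in_channels})")
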
class PipelineCheckpointProbe(CheckpointProbeBase):
|
class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
checkpoint = self.checkpoint
|
checkpoint = self.checkpoint
|
||||||
state_dict = self.checkpoint.get('state_dict') or checkpoint
|
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||||
return BaseModelType.StableDiffusion2
|
return BaseModelType.StableDiffusion2
|
||||||
key_name = 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'
|
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
||||||
return BaseModelType.StableDiffusionXL
|
return BaseModelType.StableDiffusionXL
|
||||||
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
||||||
@ -261,35 +276,38 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
|||||||
else:
|
else:
|
||||||
raise InvalidModelException("Cannot determine base type")
|
raise InvalidModelException("Cannot determine base type")
|
||||||
|
|
||||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||||
type = self.get_base_type()
|
type = self.get_base_type()
|
||||||
if type == BaseModelType.StableDiffusion1:
|
if type == BaseModelType.StableDiffusion1:
|
||||||
return SchedulerPredictionType.Epsilon
|
return SchedulerPredictionType.Epsilon
|
||||||
checkpoint = self.checkpoint
|
checkpoint = self.checkpoint
|
||||||
state_dict = self.checkpoint.get('state_dict') or checkpoint
|
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||||
if 'global_step' in checkpoint:
|
if "global_step" in checkpoint:
|
||||||
if checkpoint['global_step'] == 220000:
|
if checkpoint["global_step"] == 220000:
|
||||||
return SchedulerPredictionType.Epsilon
|
return SchedulerPredictionType.Epsilon
|
||||||
elif checkpoint["global_step"] == 110000:
|
elif checkpoint["global_step"] == 110000:
|
||||||
return SchedulerPredictionType.VPrediction
|
return SchedulerPredictionType.VPrediction
|
||||||
if self.checkpoint_path and self.helper \
|
if (
|
||||||
and not self.checkpoint_path.with_suffix('.yaml').exists(): # if a .yaml config file exists, then this step not needed
|
self.checkpoint_path and self.helper and not self.checkpoint_path.with_suffix(".yaml").exists()
|
||||||
|
): # if a .yaml config file exists, then this step not needed
|
||||||
return self.helper(self.checkpoint_path)
|
return self.helper(self.checkpoint_path)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class VaeCheckpointProbe(CheckpointProbeBase):
|
class VaeCheckpointProbe(CheckpointProbeBase):
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
# I can't find any standalone 2.X VAEs to test with!
|
# I can't find any standalone 2.X VAEs to test with!
|
||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
|
|
||||||
class LoRACheckpointProbe(CheckpointProbeBase):
|
|
||||||
def get_format(self)->str:
|
|
||||||
return 'lycoris'
|
|
||||||
|
|
||||||
def get_base_type(self)->BaseModelType:
|
class LoRACheckpointProbe(CheckpointProbeBase):
|
||||||
|
def get_format(self) -> str:
|
||||||
|
return "lycoris"
|
||||||
|
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
checkpoint = self.checkpoint
|
checkpoint = self.checkpoint
|
||||||
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
||||||
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
|
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
|
||||||
@ -307,16 +325,17 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||||
def get_format(self)->str:
|
def get_format(self) -> str:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
checkpoint = self.checkpoint
|
checkpoint = self.checkpoint
|
||||||
if 'string_to_token' in checkpoint:
|
if "string_to_token" in checkpoint:
|
||||||
token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1]
|
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
|
||||||
elif 'emb_params' in checkpoint:
|
elif "emb_params" in checkpoint:
|
||||||
token_dim = checkpoint['emb_params'].shape[-1]
|
token_dim = checkpoint["emb_params"].shape[-1]
|
||||||
else:
|
else:
|
||||||
token_dim = list(checkpoint.values())[0].shape[0]
|
token_dim = list(checkpoint.values())[0].shape[0]
|
||||||
if token_dim == 768:
|
if token_dim == 768:
|
||||||
@ -326,12 +345,14 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class ControlNetCheckpointProbe(CheckpointProbeBase):
|
class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
checkpoint = self.checkpoint
|
checkpoint = self.checkpoint
|
||||||
for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight',
|
for key_name in (
|
||||||
'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight'
|
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||||
):
|
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||||
|
):
|
||||||
if key_name not in checkpoint:
|
if key_name not in checkpoint:
|
||||||
continue
|
continue
|
||||||
if checkpoint[key_name].shape[-1] == 768:
|
if checkpoint[key_name].shape[-1] == 768:
|
||||||
@ -342,56 +363,54 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
|||||||
return self.helper(self.checkpoint_path)
|
return self.helper(self.checkpoint_path)
|
||||||
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
|
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
|
||||||
|
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
# classes for probing folders
|
# classes for probing folders
|
||||||
#######################################################
|
#######################################################
|
||||||
class FolderProbeBase(ProbeBase):
|
class FolderProbeBase(ProbeBase):
|
||||||
def __init__(self,
|
def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used
|
||||||
folder_path: Path,
|
|
||||||
model: ModelMixin = None,
|
|
||||||
helper: Callable=None # not used
|
|
||||||
):
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.folder_path = folder_path
|
self.folder_path = folder_path
|
||||||
|
|
||||||
def get_variant_type(self)->ModelVariantType:
|
def get_variant_type(self) -> ModelVariantType:
|
||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
def get_format(self)->str:
|
def get_format(self) -> str:
|
||||||
return 'diffusers'
|
return "diffusers"
|
||||||
|
|
||||||
|
|
||||||
class PipelineFolderProbe(FolderProbeBase):
|
class PipelineFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
if self.model:
|
if self.model:
|
||||||
unet_conf = self.model.unet.config
|
unet_conf = self.model.unet.config
|
||||||
else:
|
else:
|
||||||
with open(self.folder_path / 'unet' / 'config.json','r') as file:
|
with open(self.folder_path / "unet" / "config.json", "r") as file:
|
||||||
unet_conf = json.load(file)
|
unet_conf = json.load(file)
|
||||||
if unet_conf['cross_attention_dim'] == 768:
|
if unet_conf["cross_attention_dim"] == 768:
|
||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
elif unet_conf['cross_attention_dim'] == 1024:
|
elif unet_conf["cross_attention_dim"] == 1024:
|
||||||
return BaseModelType.StableDiffusion2
|
return BaseModelType.StableDiffusion2
|
||||||
elif unet_conf['cross_attention_dim'] == 1280:
|
elif unet_conf["cross_attention_dim"] == 1280:
|
||||||
return BaseModelType.StableDiffusionXLRefiner
|
return BaseModelType.StableDiffusionXLRefiner
|
||||||
elif unet_conf['cross_attention_dim'] == 2048:
|
elif unet_conf["cross_attention_dim"] == 2048:
|
||||||
return BaseModelType.StableDiffusionXL
|
return BaseModelType.StableDiffusionXL
|
||||||
else:
|
else:
|
||||||
raise InvalidModelException(f'Unknown base model for {self.folder_path}')
|
raise InvalidModelException(f"Unknown base model for {self.folder_path}")
|
||||||
|
|
||||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||||
if self.model:
|
if self.model:
|
||||||
scheduler_conf = self.model.scheduler.config
|
scheduler_conf = self.model.scheduler.config
|
||||||
else:
|
else:
|
||||||
with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
|
with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
|
||||||
scheduler_conf = json.load(file)
|
scheduler_conf = json.load(file)
|
||||||
if scheduler_conf['prediction_type'] == "v_prediction":
|
if scheduler_conf["prediction_type"] == "v_prediction":
|
||||||
return SchedulerPredictionType.VPrediction
|
return SchedulerPredictionType.VPrediction
|
||||||
elif scheduler_conf['prediction_type'] == 'epsilon':
|
elif scheduler_conf["prediction_type"] == "epsilon":
|
||||||
return SchedulerPredictionType.Epsilon
|
return SchedulerPredictionType.Epsilon
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_variant_type(self)->ModelVariantType:
|
def get_variant_type(self) -> ModelVariantType:
|
||||||
# This only works for pipelines! Any kind of
|
# This only works for pipelines! Any kind of
|
||||||
# exception results in our returning the
|
# exception results in our returning the
|
||||||
# "normal" variant type
|
# "normal" variant type
|
||||||
@ -399,11 +418,11 @@ class PipelineFolderProbe(FolderProbeBase):
|
|||||||
if self.model:
|
if self.model:
|
||||||
conf = self.model.unet.config
|
conf = self.model.unet.config
|
||||||
else:
|
else:
|
||||||
config_file = self.folder_path / 'unet' / 'config.json'
|
config_file = self.folder_path / "unet" / "config.json"
|
||||||
with open(config_file,'r') as file:
|
with open(config_file, "r") as file:
|
||||||
conf = json.load(file)
|
conf = json.load(file)
|
||||||
|
|
||||||
in_channels = conf['in_channels']
|
in_channels = conf["in_channels"]
|
||||||
if in_channels == 9:
|
if in_channels == 9:
|
||||||
return ModelVariantType.Inpaint
|
return ModelVariantType.Inpaint
|
||||||
elif in_channels == 5:
|
elif in_channels == 5:
|
||||||
@ -414,60 +433,67 @@ class PipelineFolderProbe(FolderProbeBase):
|
|||||||
pass
|
pass
|
||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
|
|
||||||
class VaeFolderProbe(FolderProbeBase):
|
class VaeFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
config_file = self.folder_path / 'config.json'
|
config_file = self.folder_path / "config.json"
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||||
with open(config_file,'r') as file:
|
with open(config_file, "r") as file:
|
||||||
config = json.load(file)
|
config = json.load(file)
|
||||||
return BaseModelType.StableDiffusionXL \
|
return (
|
||||||
if config.get('scaling_factor',0)==0.13025 and config.get('sample_size') in [512, 1024] \
|
BaseModelType.StableDiffusionXL
|
||||||
|
if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||||
else BaseModelType.StableDiffusion1
|
else BaseModelType.StableDiffusion1
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionFolderProbe(FolderProbeBase):
|
class TextualInversionFolderProbe(FolderProbeBase):
|
||||||
def get_format(self)->str:
|
def get_format(self) -> str:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
path = self.folder_path / 'learned_embeds.bin'
|
path = self.folder_path / "learned_embeds.bin"
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return None
|
return None
|
||||||
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
|
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
|
||||||
return TextualInversionCheckpointProbe(None,checkpoint=checkpoint).get_base_type()
|
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFolderProbe(FolderProbeBase):
|
class ControlNetFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
config_file = self.folder_path / 'config.json'
|
config_file = self.folder_path / "config.json"
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||||
with open(config_file,'r') as file:
|
with open(config_file, "r") as file:
|
||||||
config = json.load(file)
|
config = json.load(file)
|
||||||
# no obvious way to distinguish between sd2-base and sd2-768
|
# no obvious way to distinguish between sd2-base and sd2-768
|
||||||
return BaseModelType.StableDiffusion1 \
|
return (
|
||||||
if config['cross_attention_dim']==768 \
|
BaseModelType.StableDiffusion1 if config["cross_attention_dim"] == 768 else BaseModelType.StableDiffusion2
|
||||||
else BaseModelType.StableDiffusion2
|
)
|
||||||
|
|
||||||
|
|
||||||
class LoRAFolderProbe(FolderProbeBase):
|
class LoRAFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
model_file = None
|
model_file = None
|
||||||
for suffix in ['safetensors','bin']:
|
for suffix in ["safetensors", "bin"]:
|
||||||
base_file = self.folder_path / f'pytorch_lora_weights.{suffix}'
|
base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
|
||||||
if base_file.exists():
|
if base_file.exists():
|
||||||
model_file = base_file
|
model_file = base_file
|
||||||
break
|
break
|
||||||
if not model_file:
|
if not model_file:
|
||||||
raise InvalidModelException('Unknown LoRA format encountered')
|
raise InvalidModelException("Unknown LoRA format encountered")
|
||||||
return LoRACheckpointProbe(model_file,None).get_base_type()
|
return LoRACheckpointProbe(model_file, None).get_base_type()
|
||||||
|
|
||||||
|
|
||||||
############## register probe classes ######
|
############## register probe classes ######
|
||||||
ModelProbe.register_probe('diffusers', ModelType.Main, PipelineFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||||
ModelProbe.register_probe('diffusers', ModelType.Vae, VaeFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
||||||
ModelProbe.register_probe('diffusers', ModelType.Lora, LoRAFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
||||||
ModelProbe.register_probe('diffusers', ModelType.TextualInversion, TextualInversionFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||||
ModelProbe.register_probe('diffusers', ModelType.ControlNet, ControlNetFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||||
ModelProbe.register_probe('checkpoint', ModelType.Main, PipelineCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||||
ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||||
ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||||
ModelProbe.register_probe('checkpoint', ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||||
ModelProbe.register_probe('checkpoint', ModelType.ControlNet, ControlNetCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||||
|
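
With the probes registered above, callers normally go through the single classmethod entry point. A hedged usage sketch — the import path is an assumption based on the file layout, and probe() returns None when no probe is registered for the detected format/type:

from pathlib import Path
from invokeai.backend.model_management.model_probe import ModelProbe  # assumed module path

def describe_model(model_path: str) -> None:
    info = ModelProbe.probe(Path(model_path))  # ModelProbeInfo or None
    if info is None:
        print(f"no probe registered for {model_path}")
        return
    print(
        f"type={info.model_type} base={info.base_type} variant={info.variant_type} "
        f"format={info.format} image_size={info.image_size}"
    )
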
@ -10,8 +10,9 @@ from pathlib import Path
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
|
|
||||||
class ModelSearch(ABC):
|
class ModelSearch(ABC):
|
||||||
def __init__(self, directories: List[Path], logger: types.ModuleType=logger):
|
def __init__(self, directories: List[Path], logger: types.ModuleType = logger):
|
||||||
"""
|
"""
|
||||||
Initialize a recursive model directory search.
|
Initialize a recursive model directory search.
|
||||||
:param directories: List of directory Paths to recurse through
|
:param directories: List of directory Paths to recurse through
|
||||||
@ -56,18 +57,23 @@ class ModelSearch(ABC):
|
|||||||
|
|
||||||
def walk_directory(self, path: Path):
|
def walk_directory(self, path: Path):
|
||||||
for root, dirs, files in os.walk(path):
|
for root, dirs, files in os.walk(path):
|
||||||
if str(Path(root).name).startswith('.'):
|
if str(Path(root).name).startswith("."):
|
||||||
self._pruned_paths.add(root)
|
self._pruned_paths.add(root)
|
||||||
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self._items_scanned += len(dirs) + len(files)
|
self._items_scanned += len(dirs) + len(files)
|
||||||
for d in dirs:
|
for d in dirs:
|
||||||
path = Path(root) / d
|
path = Path(root) / d
|
||||||
if path in self._scanned_paths or path.parent in self._scanned_dirs:
|
if path in self._scanned_paths or path.parent in self._scanned_dirs:
|
||||||
self._scanned_dirs.add(path)
|
self._scanned_dirs.add(path)
|
||||||
continue
|
continue
|
||||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
if any(
|
||||||
|
[
|
||||||
|
(path / x).exists()
|
||||||
|
for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
|
||||||
|
]
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
self.on_model_found(path)
|
self.on_model_found(path)
|
||||||
self._models_found += 1
|
self._models_found += 1
|
||||||
@ -79,18 +85,19 @@ class ModelSearch(ABC):
|
|||||||
path = Path(root) / f
|
path = Path(root) / f
|
||||||
if path.parent in self._scanned_dirs:
|
if path.parent in self._scanned_dirs:
|
||||||
continue
|
continue
|
||||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
||||||
try:
|
try:
|
||||||
self.on_model_found(path)
|
self.on_model_found(path)
|
||||||
self._models_found += 1
|
self._models_found += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.warning(str(e))
|
self.logger.warning(str(e))
|
||||||
|
|
||||||
|
|
||||||
class FindModels(ModelSearch):
|
class FindModels(ModelSearch):
|
||||||
def on_search_started(self):
|
def on_search_started(self):
|
||||||
self.models_found: Set[Path] = set()
|
self.models_found: Set[Path] = set()
|
||||||
|
|
||||||
def on_model_found(self,model: Path):
|
def on_model_found(self, model: Path):
|
||||||
self.models_found.add(model)
|
self.models_found.add(model)
|
||||||
|
|
||||||
def on_search_completed(self):
|
def on_search_completed(self):
|
||||||
@ -99,5 +106,3 @@ class FindModels(ModelSearch):
|
|||||||
def list_models(self) -> List[Path]:
|
def list_models(self) -> List[Path]:
|
||||||
self.search()
|
self.search()
|
||||||
return list(self.models_found)
|
return list(self.models_found)
|
||||||
|
|
||||||
|
|
||||||
|
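
FindModels above is the concrete search used to enumerate importable models: it recurses the given directories, treats diffusers-style folders and known checkpoint suffixes as hits, and collects them as Paths. A hedged usage sketch (import path assumed):

from pathlib import Path
from invokeai.backend.model_management.model_search import FindModels  # assumed module path

def find_models_under(*roots: str) -> list:
    finder = FindModels([Path(r).expanduser() for r in roots])
    return finder.list_models()  # runs search() and returns every model path found

# e.g. find_models_under("~/invokeai/models", "~/Downloads")
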
@ -3,15 +3,24 @@ from enum import Enum
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Literal, get_origin
|
from typing import Literal, get_origin
|
||||||
from .base import (
|
from .base import (
|
||||||
BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase,
|
BaseModelType,
|
||||||
ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings,
|
ModelType,
|
||||||
ModelNotFoundException, InvalidModelException, DuplicateModelException
|
SubModelType,
|
||||||
)
|
ModelBase,
|
||||||
|
ModelConfigBase,
|
||||||
|
ModelVariantType,
|
||||||
|
SchedulerPredictionType,
|
||||||
|
ModelError,
|
||||||
|
SilenceWarnings,
|
||||||
|
ModelNotFoundException,
|
||||||
|
InvalidModelException,
|
||||||
|
DuplicateModelException,
|
||||||
|
)
|
||||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||||
from .sdxl import StableDiffusionXLModel
|
from .sdxl import StableDiffusionXLModel
|
||||||
from .vae import VaeModel
|
from .vae import VaeModel
|
||||||
from .lora import LoRAModel
|
from .lora import LoRAModel
|
||||||
from .controlnet import ControlNetModel # TODO:
|
from .controlnet import ControlNetModel # TODO:
|
||||||
from .textual_inversion import TextualInversionModel
|
from .textual_inversion import TextualInversionModel
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
@ -45,18 +54,19 @@ MODEL_CLASSES = {
|
|||||||
ModelType.ControlNet: ControlNetModel,
|
ModelType.ControlNet: ControlNetModel,
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
},
|
},
|
||||||
#BaseModelType.Kandinsky2_1: {
|
# BaseModelType.Kandinsky2_1: {
|
||||||
# ModelType.Main: Kandinsky2_1Model,
|
# ModelType.Main: Kandinsky2_1Model,
|
||||||
# ModelType.MoVQ: MoVQModel,
|
# ModelType.MoVQ: MoVQModel,
|
||||||
# ModelType.Lora: LoRAModel,
|
# ModelType.Lora: LoRAModel,
|
||||||
# ModelType.ControlNet: ControlNetModel,
|
# ModelType.ControlNet: ControlNetModel,
|
||||||
# ModelType.TextualInversion: TextualInversionModel,
|
# ModelType.TextualInversion: TextualInversionModel,
|
||||||
#},
|
# },
|
||||||
}
|
}
|
||||||
|
|
||||||
MODEL_CONFIGS = list()
|
MODEL_CONFIGS = list()
|
||||||
OPENAPI_MODEL_CONFIGS = list()
|
OPENAPI_MODEL_CONFIGS = list()
|
||||||
|
|
||||||
|
|
||||||
class OpenAPIModelInfoBase(BaseModel):
|
class OpenAPIModelInfoBase(BaseModel):
|
||||||
model_name: str
|
model_name: str
|
||||||
base_model: BaseModelType
|
base_model: BaseModelType
|
||||||
@ -72,27 +82,31 @@ for base_model, models in MODEL_CLASSES.items():
|
|||||||
# LS: sort to get the checkpoint configs first, which makes
|
# LS: sort to get the checkpoint configs first, which makes
|
||||||
# for a better template in the Swagger docs
|
# for a better template in the Swagger docs
|
||||||
for cfg in sorted(model_configs, key=lambda x: str(x)):
|
for cfg in sorted(model_configs, key=lambda x: str(x)):
|
||||||
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
model_name, cfg_name = cfg.__qualname__.split(".")[-2:]
|
||||||
openapi_cfg_name = model_name + cfg_name
|
openapi_cfg_name = model_name + cfg_name
|
||||||
if openapi_cfg_name in vars():
|
if openapi_cfg_name in vars():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
|
api_wrapper = type(
|
||||||
__annotations__ = dict(
|
openapi_cfg_name,
|
||||||
model_type=Literal[model_type.value],
|
(cfg, OpenAPIModelInfoBase),
|
||||||
|
dict(
|
||||||
|
__annotations__=dict(
|
||||||
|
model_type=Literal[model_type.value],
|
||||||
|
),
|
||||||
),
|
),
|
||||||
))
|
)
|
||||||
|
|
||||||
#globals()[openapi_cfg_name] = api_wrapper
|
# globals()[openapi_cfg_name] = api_wrapper
|
||||||
vars()[openapi_cfg_name] = api_wrapper
|
vars()[openapi_cfg_name] = api_wrapper
|
||||||
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
|
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
|
||||||
|
|
||||||
|
|
||||||
def get_model_config_enums():
|
def get_model_config_enums():
|
||||||
enums = list()
|
enums = list()
|
||||||
|
|
||||||
for model_config in MODEL_CONFIGS:
|
for model_config in MODEL_CONFIGS:
|
||||||
|
if hasattr(inspect, "get_annotations"):
|
||||||
if hasattr(inspect,'get_annotations'):
|
|
||||||
fields = inspect.get_annotations(model_config)
|
fields = inspect.get_annotations(model_config)
|
||||||
else:
|
else:
|
||||||
fields = model_config.__annotations__
|
fields = model_config.__annotations__
|
||||||
@ -109,7 +123,9 @@ def get_model_config_enums():
|
|||||||
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||||
enums.append(field)
|
enums.append(field)
|
||||||
|
|
||||||
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
elif get_origin(field) is Literal and all(
|
||||||
|
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
|
||||||
|
):
|
||||||
enums.append(type(field.__args__[0]))
|
enums.append(type(field.__args__[0]))
|
||||||
|
|
||||||
elif field is None:
|
elif field is None:
|
||||||
@ -119,4 +135,3 @@ def get_model_config_enums():
|
|||||||
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
|
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
|
||||||
|
|
||||||
return enums
|
return enums
|
||||||
|
|
||||||
|
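
MODEL_CLASSES above is a two-level registry keyed first by base model, then by model type; the commented-out Kandinsky block shows how a new base would slot in. The usual lookup is a plain nested dict access, sketched here with an invented helper name:

def model_class_for(base_model, model_type):
    # Both keys are enum members (BaseModelType, ModelType) as defined in .base.
    try:
        return MODEL_CLASSES[base_model][model_type]
    except KeyError:
        raise ValueError(f"No model class registered for {base_model}/{model_type}")
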
@ -15,29 +15,35 @@ from contextlib import suppress
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||||
|
|
||||||
|
|
||||||
class DuplicateModelException(Exception):
|
class DuplicateModelException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InvalidModelException(Exception):
|
class InvalidModelException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ModelNotFoundException(Exception):
|
class ModelNotFoundException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BaseModelType(str, Enum):
|
class BaseModelType(str, Enum):
|
||||||
StableDiffusion1 = "sd-1"
|
StableDiffusion1 = "sd-1"
|
||||||
StableDiffusion2 = "sd-2"
|
StableDiffusion2 = "sd-2"
|
||||||
StableDiffusionXL = "sdxl"
|
StableDiffusionXL = "sdxl"
|
||||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||||
#Kandinsky2_1 = "kandinsky-2.1"
|
# Kandinsky2_1 = "kandinsky-2.1"
|
||||||
|
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
class ModelType(str, Enum):
|
||||||
Main = "main"
|
Main = "main"
|
||||||
Vae = "vae"
|
Vae = "vae"
|
||||||
Lora = "lora"
|
Lora = "lora"
|
||||||
ControlNet = "controlnet" # used by model_probe
|
ControlNet = "controlnet" # used by model_probe
|
||||||
TextualInversion = "embedding"
|
TextualInversion = "embedding"
|
||||||
|
|
||||||
|
|
||||||
class SubModelType(str, Enum):
|
class SubModelType(str, Enum):
|
||||||
UNet = "unet"
|
UNet = "unet"
|
||||||
TextEncoder = "text_encoder"
|
TextEncoder = "text_encoder"
|
||||||
@ -47,23 +53,27 @@ class SubModelType(str, Enum):
|
|||||||
Vae = "vae"
|
Vae = "vae"
|
||||||
Scheduler = "scheduler"
|
Scheduler = "scheduler"
|
||||||
SafetyChecker = "safety_checker"
|
SafetyChecker = "safety_checker"
|
||||||
#MoVQ = "movq"
|
# MoVQ = "movq"
|
||||||
|
|
||||||
|
|
||||||
class ModelVariantType(str, Enum):
|
class ModelVariantType(str, Enum):
|
||||||
Normal = "normal"
|
Normal = "normal"
|
||||||
Inpaint = "inpaint"
|
Inpaint = "inpaint"
|
||||||
Depth = "depth"
|
Depth = "depth"
|
||||||
|
|
||||||
|
|
||||||
class SchedulerPredictionType(str, Enum):
|
class SchedulerPredictionType(str, Enum):
|
||||||
Epsilon = "epsilon"
|
Epsilon = "epsilon"
|
||||||
VPrediction = "v_prediction"
|
VPrediction = "v_prediction"
|
||||||
Sample = "sample"
|
Sample = "sample"
|
||||||
|
|
||||||
|
|
||||||
class ModelError(str, Enum):
|
class ModelError(str, Enum):
|
||||||
NotFound = "not_found"
|
NotFound = "not_found"
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigBase(BaseModel):
|
class ModelConfigBase(BaseModel):
|
||||||
path: str # or Path
|
path: str # or Path
|
||||||
description: Optional[str] = Field(None)
|
description: Optional[str] = Field(None)
|
||||||
model_format: Optional[str] = Field(None)
|
model_format: Optional[str] = Field(None)
|
||||||
error: Optional[ModelError] = Field(None)
|
error: Optional[ModelError] = Field(None)
|
||||||
@ -71,13 +81,17 @@ class ModelConfigBase(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
use_enum_values = True
|
use_enum_values = True
|
||||||
|
|
||||||
|
|
||||||
class EmptyConfigLoader(ConfigMixin):
|
class EmptyConfigLoader(ConfigMixin):
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_config(cls, *args, **kwargs):
|
def load_config(cls, *args, **kwargs):
|
||||||
cls.config_name = kwargs.pop("config_name")
|
cls.config_name = kwargs.pop("config_name")
|
||||||
return super().load_config(*args, **kwargs)
|
return super().load_config(*args, **kwargs)
|
||||||
|
|
||||||
T_co = TypeVar('T_co', covariant=True)
|
|
||||||
|
T_co = TypeVar("T_co", covariant=True)
|
||||||
|
|
||||||
|
|
||||||
class classproperty(Generic[T_co]):
|
class classproperty(Generic[T_co]):
|
||||||
def __init__(self, fget: Callable[[Any], T_co]) -> None:
|
def __init__(self, fget: Callable[[Any], T_co]) -> None:
|
||||||
self.fget = fget
|
self.fget = fget
|
||||||
@ -86,12 +100,13 @@ class classproperty(Generic[T_co]):
|
|||||||
return self.fget(owner)
|
return self.fget(owner)
|
||||||
|
|
||||||
def __set__(self, instance: Optional[Any], value: Any) -> None:
|
def __set__(self, instance: Optional[Any], value: Any) -> None:
|
||||||
raise AttributeError('cannot set attribute')
|
raise AttributeError("cannot set attribute")
|
||||||
|
|
||||||
|
|
||||||
class ModelBase(metaclass=ABCMeta):
|
class ModelBase(metaclass=ABCMeta):
|
||||||
#model_path: str
|
# model_path: str
|
||||||
#base_model: BaseModelType
|
# base_model: BaseModelType
|
||||||
#model_type: ModelType
|
# model_type: ModelType
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -110,7 +125,7 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
return None
|
return None
|
||||||
elif any(t is None for t in subtypes):
|
elif any(t is None for t in subtypes):
|
||||||
raise Exception(f"Unsupported definition: {subtypes}")
|
raise Exception(f"Unsupported definition: {subtypes}")
|
||||||
|
|
||||||
if subtypes[0] in ["diffusers", "transformers"]:
|
if subtypes[0] in ["diffusers", "transformers"]:
|
||||||
res_type = sys.modules[subtypes[0]]
|
res_type = sys.modules[subtypes[0]]
|
||||||
subtypes = subtypes[1:]
|
subtypes = subtypes[1:]
|
||||||
@ -119,7 +134,6 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
res_type = sys.modules["diffusers"]
|
res_type = sys.modules["diffusers"]
|
||||||
res_type = getattr(res_type, "pipelines")
|
res_type = getattr(res_type, "pipelines")
|
||||||
|
|
||||||
|
|
||||||
for subtype in subtypes:
|
for subtype in subtypes:
|
||||||
res_type = getattr(res_type, subtype)
|
res_type = getattr(res_type, subtype)
|
||||||
return res_type
|
return res_type
|
||||||
@ -128,7 +142,7 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
def _get_configs(cls):
|
def _get_configs(cls):
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
return cls.__configs
|
return cls.__configs
|
||||||
|
|
||||||
configs = dict()
|
configs = dict()
|
||||||
for name in dir(cls):
|
for name in dir(cls):
|
||||||
if name.startswith("__"):
|
if name.startswith("__"):
|
||||||
@ -138,7 +152,7 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
|
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if hasattr(inspect,'get_annotations'):
|
if hasattr(inspect, "get_annotations"):
|
||||||
fields = inspect.get_annotations(value)
|
fields = inspect.get_annotations(value)
|
||||||
else:
|
else:
|
||||||
fields = value.__annotations__
|
fields = value.__annotations__
|
||||||
@ -151,7 +165,9 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
for model_format in field:
|
for model_format in field:
|
||||||
configs[model_format.value] = value
|
configs[model_format.value] = value
|
||||||
|
|
||||||
elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
elif typing.get_origin(field) is Literal and all(
|
||||||
|
isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__
|
||||||
|
):
|
||||||
for model_format in field.__args__:
|
for model_format in field.__args__:
|
||||||
configs[model_format.value] = value
|
configs[model_format.value] = value
|
||||||
|
|
||||||
@ -203,8 +219,8 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
|
|
||||||
|
|
||||||
class DiffusersModel(ModelBase):
|
class DiffusersModel(ModelBase):
|
||||||
#child_types: Dict[str, Type]
|
# child_types: Dict[str, Type]
|
||||||
#child_sizes: Dict[str, int]
|
# child_sizes: Dict[str, int]
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
super().__init__(model_path, base_model, model_type)
|
super().__init__(model_path, base_model, model_type)
|
||||||
@ -214,7 +230,7 @@ class DiffusersModel(ModelBase):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
config_data = DiffusionPipeline.load_config(self.model_path)
|
config_data = DiffusionPipeline.load_config(self.model_path)
|
||||||
#config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
|
# config_data = json.loads(os.path.join(self.model_path, "model_index.json"))
|
||||||
except Exception:
|
except Exception:
|
||||||
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
|
raise Exception("Invalid diffusers model! (model_index.json not found or invalid)")
|
||||||
|
|
||||||
@ -228,14 +244,12 @@ class DiffusersModel(ModelBase):
|
|||||||
self.child_types[child_name] = child_type
|
self.child_types[child_name] = child_type
|
||||||
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
|
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
|
||||||
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||||
if child_type is None:
|
if child_type is None:
|
||||||
return sum(self.child_sizes.values())
|
return sum(self.child_sizes.values())
|
||||||
else:
|
else:
|
||||||
return self.child_sizes[child_type]
|
return self.child_sizes[child_type]
|
||||||
|
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
torch_dtype: Optional[torch.dtype],
|
torch_dtype: Optional[torch.dtype],
|
||||||
@ -245,7 +259,7 @@ class DiffusersModel(ModelBase):
|
|||||||
if child_type is None:
|
if child_type is None:
|
||||||
raise Exception("Child model type can't be null on diffusers model")
|
raise Exception("Child model type can't be null on diffusers model")
|
||||||
if child_type not in self.child_types:
|
if child_type not in self.child_types:
|
||||||
return None # TODO: or raise
|
return None # TODO: or raise
|
||||||
|
|
||||||
if torch_dtype == torch.float16:
|
if torch_dtype == torch.float16:
|
||||||
variants = ["fp16", None]
|
variants = ["fp16", None]
|
||||||
@ -265,8 +279,8 @@ class DiffusersModel(ModelBase):
|
|||||||
)
|
)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
#print("====ERR LOAD====")
|
# print("====ERR LOAD====")
|
||||||
#print(f"{variant}: {e}")
|
# print(f"{variant}: {e}")
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
|
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
|
||||||
@ -275,15 +289,10 @@ class DiffusersModel(ModelBase):
|
|||||||
self.child_sizes[child_type] = calc_model_size_by_data(model)
|
self.child_sizes[child_type] = calc_model_size_by_data(model)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
#def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
|
# def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
|
||||||
|
|
||||||
|
|
||||||
|
def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, variant: Optional[str] = None):
|
||||||
def calc_model_size_by_fs(
|
|
||||||
model_path: str,
|
|
||||||
subfolder: Optional[str] = None,
|
|
||||||
variant: Optional[str] = None
|
|
||||||
):
|
|
||||||
if subfolder is not None:
|
if subfolder is not None:
|
||||||
model_path = os.path.join(model_path, subfolder)
|
model_path = os.path.join(model_path, subfolder)
|
||||||
|
|
||||||
@ -325,12 +334,12 @@ def calc_model_size_by_fs(
|
|||||||
|
|
||||||
# calculate files size if there is no index file
|
# calculate files size if there is no index file
|
||||||
formats = [
|
formats = [
|
||||||
(".safetensors",), # safetensors
|
(".safetensors",), # safetensors
|
||||||
(".bin",), # torch
|
(".bin",), # torch
|
||||||
(".onnx", ".pb"), # onnx
|
(".onnx", ".pb"), # onnx
|
||||||
(".msgpack",), # flax
|
(".msgpack",), # flax
|
||||||
(".ckpt",), # tf
|
(".ckpt",), # tf
|
||||||
(".h5",), # tf2
|
(".h5",), # tf2
|
||||||
]
|
]
|
||||||
|
|
||||||
for file_format in formats:
|
for file_format in formats:
|
||||||
@ -343,9 +352,9 @@ def calc_model_size_by_fs(
|
|||||||
file_stats = os.stat(os.path.join(model_path, model_file))
|
file_stats = os.stat(os.path.join(model_path, model_file))
|
||||||
model_size += file_stats.st_size
|
model_size += file_stats.st_size
|
||||||
return model_size
|
return model_size
|
||||||
|
|
||||||
#raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
|
# raise NotImplementedError(f"Unknown model structure! Files: {all_files}")
|
||||||
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
|
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
|
||||||
|
|
||||||
|
|
||||||
def calc_model_size_by_data(model) -> int:
|
def calc_model_size_by_data(model) -> int:
|
||||||
@ -364,12 +373,12 @@ def _calc_pipeline_by_data(pipeline) -> int:
|
|||||||
if submodel is not None and isinstance(submodel, torch.nn.Module):
|
if submodel is not None and isinstance(submodel, torch.nn.Module):
|
||||||
res += _calc_model_by_data(submodel)
|
res += _calc_model_by_data(submodel)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def _calc_model_by_data(model) -> int:
|
def _calc_model_by_data(model) -> int:
|
||||||
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
|
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
|
||||||
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
|
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
|
||||||
mem = mem_params + mem_bufs # in bytes
|
mem = mem_params + mem_bufs # in bytes
|
||||||
return mem
|
return mem
|
||||||
|
|
||||||
|
|
||||||
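
The in-memory size helpers above sum nelement() * element_size() over a module's parameters and buffers, so the reported figure is raw tensor bytes with no allocator or activation overhead. A quick self-contained check of that arithmetic:

import torch

def approx_module_bytes(model: torch.nn.Module) -> int:
    params = sum(p.nelement() * p.element_size() for p in model.parameters())
    bufs = sum(b.nelement() * b.element_size() for b in model.buffers())
    return params + bufs

# A 1024x1024 Linear layer holds 1024*1024 weights + 1024 biases in fp32 (4 bytes each).
lin = torch.nn.Linear(1024, 1024)
assert approx_module_bytes(lin) == (1024 * 1024 + 1024) * 4
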
@ -377,11 +386,15 @@ def _fast_safetensors_reader(path: str):
|
|||||||
checkpoint = dict()
|
checkpoint = dict()
|
||||||
device = torch.device("meta")
|
device = torch.device("meta")
|
||||||
with open(path, "rb") as f:
|
with open(path, "rb") as f:
|
||||||
definition_len = int.from_bytes(f.read(8), 'little')
|
definition_len = int.from_bytes(f.read(8), "little")
|
||||||
definition_json = f.read(definition_len)
|
definition_json = f.read(definition_len)
|
||||||
definition = json.loads(definition_json)
|
definition = json.loads(definition_json)
|
||||||
|
|
||||||
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
|
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
|
||||||
|
"pt",
|
||||||
|
"torch",
|
||||||
|
"pytorch",
|
||||||
|
}:
|
||||||
raise Exception("Supported only pytorch safetensors files")
|
raise Exception("Supported only pytorch safetensors files")
|
||||||
definition.pop("__metadata__", None)
|
definition.pop("__metadata__", None)
|
||||||
|
|
||||||
@ -400,6 +413,7 @@ def _fast_safetensors_reader(path: str):
|
|||||||
|
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
||||||
if str(path).endswith(".safetensors"):
|
if str(path).endswith(".safetensors"):
|
||||||
try:
|
try:
|
||||||
@ -411,25 +425,27 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
|||||||
if scan:
|
if scan:
|
||||||
scan_result = scan_file_path(path)
|
scan_result = scan_file_path(path)
|
||||||
if scan_result.infected_files != 0:
|
if scan_result.infected_files != 0:
|
||||||
raise Exception(f"The model file \"{path}\" is potentially infected by malware. Aborting import.")
|
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
|
||||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from diffusers import logging as diffusers_logging
|
from diffusers import logging as diffusers_logging
|
||||||
from transformers import logging as transformers_logging
|
from transformers import logging as transformers_logging
|
||||||
|
|
||||||
|
|
||||||
class SilenceWarnings(object):
|
class SilenceWarnings(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.transformers_verbosity = transformers_logging.get_verbosity()
|
self.transformers_verbosity = transformers_logging.get_verbosity()
|
||||||
self.diffusers_verbosity = diffusers_logging.get_verbosity()
|
self.diffusers_verbosity = diffusers_logging.get_verbosity()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
transformers_logging.set_verbosity_error()
|
transformers_logging.set_verbosity_error()
|
||||||
diffusers_logging.set_verbosity_error()
|
diffusers_logging.set_verbosity_error()
|
||||||
warnings.simplefilter('ignore')
|
warnings.simplefilter("ignore")
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type, value, traceback):
|
||||||
transformers_logging.set_verbosity(self.transformers_verbosity)
|
transformers_logging.set_verbosity(self.transformers_verbosity)
|
||||||
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
diffusers_logging.set_verbosity(self.diffusers_verbosity)
|
||||||
warnings.simplefilter('default')
|
warnings.simplefilter("default")
|
||||||
|
@@ -18,13 +18,15 @@ from .base import (
)
from invokeai.app.services.config import InvokeAIAppConfig


class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"


class ControlNetModel(ModelBase):
- #model_class: Type
- #model_size: int
+ # model_class: Type
+ # model_size: int

class DiffusersConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Diffusers]
@@ -39,7 +41,7 @@ class ControlNetModel(ModelBase):

try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
- #config = json.loads(os.path.join(self.model_path, "config.json"))
+ # config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid controlnet model! (config.json not found or invalid)")

@@ -67,7 +69,7 @@ class ControlNetModel(ModelBase):
raise Exception("There is no child models in controlnet model")

model = None
- for variant in ['fp16',None]:
+ for variant in ["fp16", None]:
try:
model = self.model_class.from_pretrained(
self.model_path,
@@ -79,7 +81,7 @@ class ControlNetModel(ModelBase):
pass
if not model:
raise ModelNotFoundException()

# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@@ -105,29 +107,30 @@ class ControlNetModel(ModelBase):

@classmethod
def convert_if_required(
- cls,
- model_path: str,
- output_path: str,
- config: ModelConfigBase,
- base_model: BaseModelType,
- ) -> str:
- if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
- return _convert_controlnet_ckpt_and_cache(
- model_path = model_path,
- model_config = config.config,
- output_path = output_path,
- base_model = base_model,
- )
- else:
- return model_path

- @classmethod
- def _convert_controlnet_ckpt_and_cache(
cls,
model_path: str,
output_path: str,
+ config: ModelConfigBase,
base_model: BaseModelType,
- model_config: ControlNetModel.CheckpointConfig,
+ ) -> str:
+ if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
+ return _convert_controlnet_ckpt_and_cache(
+ model_path=model_path,
+ model_config=config.config,
+ output_path=output_path,
+ base_model=base_model,
+ )
+ else:
+ return model_path


+ @classmethod
+ def _convert_controlnet_ckpt_and_cache(
+ cls,
+ model_path: str,
+ output_path: str,
+ base_model: BaseModelType,
+ model_config: ControlNetModel.CheckpointConfig,
) -> str:
"""
Convert the controlnet from checkpoint format to diffusers format,
@@ -144,12 +147,13 @@ def _convert_controlnet_ckpt_and_cache(

# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers

convert_controlnet_to_diffusers(
weights,
output_path,
- original_config_file = app_config.root_path / model_config,
- image_size = 512,
- scan_needed = True,
- from_safetensors = weights.suffix == ".safetensors"
+ original_config_file=app_config.root_path / model_config,
+ image_size=512,
+ scan_needed=True,
+ from_safetensors=weights.suffix == ".safetensors",
)
return output_path
@@ -12,18 +12,21 @@ from .base import (
InvalidModelException,
ModelNotFoundException,
)

# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw


class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
Diffusers = "diffusers"


class LoRAModel(ModelBase):
- #model_size: int
+ # model_size: int

class Config(ModelConfigBase):
model_format: LoRAModelFormat  # TODO:

def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora
@@ -15,12 +15,13 @@ from .base import (
)
from omegaconf import OmegaConf


class StableDiffusionXLModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"

- class StableDiffusionXLModel(DiffusersModel):


+ class StableDiffusionXLModel(DiffusersModel):
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
@@ -53,7 +54,7 @@ class StableDiffusionXLModel(DiffusersModel):

else:
checkpoint = read_checkpoint_meta(path)
- checkpoint = checkpoint.get('state_dict', checkpoint)
+ checkpoint = checkpoint.get("state_dict", checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

elif model_format == StableDiffusionXLModelFormat.Diffusers:
@@ -61,7 +62,7 @@ class StableDiffusionXLModel(DiffusersModel):
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
- in_channels = unet_config['in_channels']
+ in_channels = unet_config["in_channels"]

else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
@@ -81,11 +82,10 @@ class StableDiffusionXLModel(DiffusersModel):
if ckpt_config_path is None:
# TO DO: implement picking
pass

return cls.create_config(
path=path,
model_format=model_format,

config=ckpt_config_path,
variant=variant,
)
@@ -114,11 +114,12 @@ class StableDiffusionXLModel(DiffusersModel):
# source code changes, we simply translate here
if isinstance(config, cls.CheckpointConfig):
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache

return _convert_ckpt_and_cache(
version=base_model,
model_config=config,
output_path=output_path,
use_safetensors=False,  # corrupts sdxl models for some reason
)
else:
return model_path
@@ -26,8 +26,8 @@ class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"

- class StableDiffusion1Model(DiffusersModel):

+ class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
@@ -38,7 +38,7 @@ class StableDiffusion1Model(DiffusersModel):
vae: Optional[str] = Field(None)
config: str
variant: ModelVariantType

def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.Main
@@ -59,7 +59,7 @@ class StableDiffusion1Model(DiffusersModel):

else:
checkpoint = read_checkpoint_meta(path)
- checkpoint = checkpoint.get('state_dict', checkpoint)
+ checkpoint = checkpoint.get("state_dict", checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

elif model_format == StableDiffusion1ModelFormat.Diffusers:
@@ -67,7 +67,7 @@ class StableDiffusion1Model(DiffusersModel):
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
- in_channels = unet_config['in_channels']
+ in_channels = unet_config["in_channels"]

else:
raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")
@@ -88,7 +88,6 @@ class StableDiffusion1Model(DiffusersModel):
return cls.create_config(
path=path,
model_format=model_format,

config=ckpt_config_path,
variant=variant,
)
@@ -125,16 +124,17 @@ class StableDiffusion1Model(DiffusersModel):
version=BaseModelType.StableDiffusion1,
model_config=config,
output_path=output_path,
)
else:
return model_path


class StableDiffusion2ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"

- class StableDiffusion2Model(DiffusersModel):

+ class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
@@ -167,7 +167,7 @@ class StableDiffusion2Model(DiffusersModel):

else:
checkpoint = read_checkpoint_meta(path)
- checkpoint = checkpoint.get('state_dict', checkpoint)
+ checkpoint = checkpoint.get("state_dict", checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

elif model_format == StableDiffusion2ModelFormat.Diffusers:
@@ -175,7 +175,7 @@ class StableDiffusion2Model(DiffusersModel):
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
- in_channels = unet_config['in_channels']
+ in_channels = unet_config["in_channels"]

else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
@@ -198,7 +198,6 @@ class StableDiffusion2Model(DiffusersModel):
return cls.create_config(
path=path,
model_format=model_format,

config=ckpt_config_path,
variant=variant,
)
@@ -239,17 +238,19 @@ class StableDiffusion2Model(DiffusersModel):
else:
return model_path


# TODO: rework
# pass precision - currently defaulting to fp16
def _convert_ckpt_and_cache(
version: BaseModelType,
- model_config: Union[StableDiffusion1Model.CheckpointConfig,
- StableDiffusion2Model.CheckpointConfig,
- StableDiffusionXLModel.CheckpointConfig,
- ],
- output_path: str,
- use_save_model: bool=False,
- **kwargs,
+ model_config: Union[
+ StableDiffusion1Model.CheckpointConfig,
+ StableDiffusion2Model.CheckpointConfig,
+ StableDiffusionXLModel.CheckpointConfig,
+ ],
+ output_path: str,
+ use_save_model: bool = False,
+ **kwargs,
) -> str:
"""
Convert the checkpoint model indicated in mconfig into a
@@ -270,13 +271,14 @@ def _convert_ckpt_and_cache(
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from ...util.devices import choose_torch_device, torch_dtype

- model_base_to_model_type = {BaseModelType.StableDiffusion1: 'FrozenCLIPEmbedder',
- BaseModelType.StableDiffusion2: 'FrozenOpenCLIPEmbedder',
- BaseModelType.StableDiffusionXL: 'SDXL',
- BaseModelType.StableDiffusionXLRefiner: 'SDXL-Refiner',
- }
- logger.info(f'Converting {weights} to diffusers format')
+ model_base_to_model_type = {
+ BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
+ BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
+ BaseModelType.StableDiffusionXL: "SDXL",
+ BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
+ }
+ logger.info(f"Converting {weights} to diffusers format")
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
output_path,
@@ -286,12 +288,13 @@ def _convert_ckpt_and_cache(
original_config_file=config_file,
extract_ema=True,
scan_needed=True,
- from_safetensors = weights.suffix == ".safetensors",
- precision = torch_dtype(choose_torch_device()),
+ from_safetensors=weights.suffix == ".safetensors",
+ precision=torch_dtype(choose_torch_device()),
**kwargs,
)
return output_path


def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ckpt_configs = {
BaseModelType.StableDiffusion1: {
@@ -299,7 +302,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: "v2-inference-v.yaml",  # best guess, as we can't differentiate with base(512)
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
},
@@ -321,8 +324,6 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
if config_path.is_relative_to(app_config.root_path):
config_path = config_path.relative_to(app_config.root_path)
return str(config_path)

except:
return None

@@ -11,11 +11,13 @@ from .base import (
ModelNotFoundException,
InvalidModelException,
)

# TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw


class TextualInversionModel(ModelBase):
- #model_size: int
+ # model_size: int

class Config(ModelConfigBase):
model_format: None
@@ -65,7 +67,7 @@ class TextualInversionModel(ModelBase):

if os.path.isdir(path):
if os.path.exists(os.path.join(path, "learned_embeds.bin")):
return None  # diffusers-ti

if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
@@ -22,13 +22,15 @@ from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf


class VaeModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"


class VaeModel(ModelBase):
- #vae_class: Type
- #model_size: int
+ # vae_class: Type
+ # model_size: int

class Config(ModelConfigBase):
model_format: VaeModelFormat
@@ -39,7 +41,7 @@ class VaeModel(ModelBase):

try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
- #config = json.loads(os.path.join(self.model_path, "config.json"))
+ # config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid vae model! (config.json not found or invalid)")

@@ -95,7 +97,7 @@ class VaeModel(ModelBase):
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,  # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
@@ -108,6 +110,7 @@ class VaeModel(ModelBase):
else:
return model_path


# TODO: rework
def _convert_vae_ckpt_and_cache(
weights_path: str,
@@ -138,13 +141,14 @@ def _convert_vae_ckpt_and_cache(
2.1 - 768
"""
image_size = 512

# return cached version if it exists
if output_path.exists():
return output_path

if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
from .stable_diffusion import _select_ckpt_config

# all sd models use same vae settings
config_file = _select_ckpt_config(base_model, ModelVariantType.Normal)
else:
@@ -152,7 +156,8 @@ def _convert_vae_ckpt_and_cache(

# this avoids circular import error
from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
- if weights_path.suffix == '.safetensors':
+ if weights_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
else:
checkpoint = torch.load(weights_path, map_location="cpu")
@@ -161,15 +166,12 @@ def _convert_vae_ckpt_and_cache(
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]

- config = OmegaConf.load(app_config.root_path/config_file)
+ config = OmegaConf.load(app_config.root_path / config_file)

vae_model = convert_ldm_vae_to_diffusers(
- checkpoint = checkpoint,
- vae_config = config,
- image_size = image_size,
- )
- vae_model.save_pretrained(
- output_path,
- safe_serialization=is_safetensors_available()
+ checkpoint=checkpoint,
+ vae_config=config,
+ image_size=image_size,
)
+ vae_model.save_pretrained(output_path, safe_serialization=is_safetensors_available())
return output_path
@@ -47,6 +47,7 @@ from .diffusion import (
)
from .offloading import FullyLoadedModelGroup, ModelGroup


@dataclass
class PipelineIntermediateState:
run_id: str
@@ -72,7 +73,11 @@ class AddsMaskLatents:
initial_image_latents: torch.Tensor

def __call__(
- self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor, **kwargs,
+ self,
+ latents: torch.Tensor,
+ t: torch.Tensor,
+ text_embeddings: torch.Tensor,
+ **kwargs,
) -> torch.Tensor:
model_input = self.add_mask_channels(latents)
return self.forward(model_input, t, text_embeddings, **kwargs)
@@ -80,12 +85,8 @@ class AddsMaskLatents:
def add_mask_channels(self, latents):
batch_size = latents.size(0)
# duplicate mask and latents for each batch
- mask = einops.repeat(
- self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size
- )
- image_latents = einops.repeat(
- self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size
- )
+ mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
+ image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
# add mask and image as additional channels
model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
return model_input
@@ -103,9 +104,7 @@ class AddsMaskGuidance:
noise: torch.Tensor
_debug: Optional[Callable] = None

- def __call__(
- self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning
- ) -> BaseOutput:
+ def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput:
output_class = step_output.__class__  # We'll create a new one with masked data.

# The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it.
@@ -116,11 +115,7 @@ class AddsMaskGuidance:
# Mask anything that has the same shape as prev_sample, return others as-is.
return output_class(
{
- k: (
- self.apply_mask(v, self._t_for_field(k, t))
- if are_like_tensors(prev_sample, v)
- else v
- )
+ k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v)
for k, v in step_output.items()
}
)
@@ -132,9 +127,7 @@ class AddsMaskGuidance:

def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
batch_size = latents.size(0)
- mask = einops.repeat(
- self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size
- )
+ mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
if t.dim() == 0:
# some schedulers expect t to be one-dimensional.
# TODO: file diffusers bug about inconsistency?
@@ -144,12 +137,8 @@ class AddsMaskGuidance:
mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
- mask_latents = einops.repeat(
- mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size
- )
- masked_input = torch.lerp(
- mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)
- )
+ mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
+ masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
if self._debug:
self._debug(masked_input, f"t={t} lerped")
return masked_input
@@ -159,9 +148,7 @@ def trim_to_multiple_of(*args, multiple_of=8):
return tuple((x - x % multiple_of) for x in args)


- def image_resized_to_grid_as_tensor(
- image: PIL.Image.Image, normalize: bool = True, multiple_of=8
- ) -> torch.FloatTensor:
+ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = True, multiple_of=8) -> torch.FloatTensor:
"""

:param image: input image
@@ -211,6 +198,7 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
raise AssertionError("why was that an empty generator?")
return result


@dataclass
class ControlNetData:
model: ControlNetModel = Field(default=None)
@@ -341,9 +329,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# FIXME: can't currently register control module
# control_model=control_model,
)
- self.invokeai_diffuser = InvokeAIDiffuserComponent(
- self.unet, self._unet_forward
- )
+ self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)

self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
self._model_group.install(*self._submodels)
@@ -354,11 +340,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if xformers is available, use it, otherwise use sliced attention.
"""
config = InvokeAIAppConfig.get_config()
- if (
- torch.cuda.is_available()
- and is_xformers_available()
- and not config.disable_xformers
- ):
+ if torch.cuda.is_available() and is_xformers_available() and not config.disable_xformers:
self.enable_xformers_memory_efficient_attention()
else:
if self.device.type == "cpu" or self.device.type == "mps":
@@ -369,9 +351,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
raise ValueError(f"unrecognized device {self.device}")
# input tensor of [1, 4, h/8, w/8]
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
- bytes_per_element_needed_for_baddbmm_duplication = (
- latents.element_size() + 4
- )
+ bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
max_size_required_for_baddbmm = (
16
* latents.size(dim=2)
@@ -380,9 +360,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
* latents.size(dim=3)
* bytes_per_element_needed_for_baddbmm_duplication
)
- if max_size_required_for_baddbmm > (
- mem_free * 3.0 / 4.0
- ):  # 3.3 / 4.0 is from old Invoke code
+ if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):  # 3.3 / 4.0 is from old Invoke code
self.enable_attention_slicing(slice_size="max")
elif torch.backends.mps.is_available():
# diffusers recommends always enabling for mps
@@ -470,7 +448,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
control_data: List[ControlNetData] = None,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False):
- scheduler_device = torch.device('cpu')
+ scheduler_device = torch.device("cpu")
else:
scheduler_device = self._model_group.device_for(self.unet)

@@ -488,7 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id=run_id,
additional_guidance=additional_guidance,
control_data=control_data,

callback=callback,
)
return result.latents, result.attention_map_saver
@@ -511,9 +488,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance = []
extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
):
yield PipelineIntermediateState(
run_id=run_id,
@@ -607,16 +584,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# that are combined at higher level to make control_mode enum
# soft_injection determines whether to do per-layer re-weighting adjustment (if True)
# or default weighting (if False)
- soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
+ soft_injection = control_mode == "more_prompt" or control_mode == "more_control"
# cfg_injection = determines whether to apply ControlNet to only the conditional (if True)
# or the default both conditional and unconditional (if False)
- cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
+ cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"

first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:

if cfg_injection:
control_latent_input = unet_latent_input
else:
@@ -629,7 +605,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
encoder_hidden_states = conditioning_data.text_embeddings
encoder_attention_mask = None
else:
- encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch(
+ (
+ encoder_hidden_states,
+ encoder_attention_mask,
+ ) = self.invokeai_diffuser._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings,
)
@@ -646,9 +625,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight,  # controlnet specific, NOT the guidance scale
encoder_attention_mask=encoder_attention_mask,
guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
)
if cfg_injection:
@@ -678,13 +657,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=step_index,
total_step_count=total_step_count,
down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
mid_block_additional_residual=mid_block_res_sample,  # from controlnet(s)
)

# compute the previous noisy sample x_t -> x_t-1
- step_output = self.scheduler.step(
- noise_pred, timestep, latents, **conditioning_data.scheduler_args
- )
+ step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)

# TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent.
# But the way things are now, scheduler runs _after_ that, so there was
@@ -710,17 +687,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# use of AddsMaskLatents.
latents = AddsMaskLatents(
self._unet_forward,
- mask=torch.ones_like(
- latents[:1, :1], device=latents.device, dtype=latents.dtype
- ),
- initial_image_latents=torch.zeros_like(
- latents[:1], device=latents.device, dtype=latents.dtype
- ),
+ mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
+ initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype),
).add_mask_channels(latents)

# First three args should be positional, not keywords, so torch hooks can see them.
return self.unet(
- latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs,
+ latents,
+ t,
+ text_embeddings,
+ cross_attention_kwargs=cross_attention_kwargs,
**kwargs,
).sample

@@ -774,9 +750,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
result_latents, result_attention_maps = self.latents_from_embeddings(
- latents=initial_latents if strength < 1.0 else torch.zeros_like(
- initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
- ),
+ latents=initial_latents
+ if strength < 1.0
+ else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype),
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
timesteps=timesteps,
@@ -797,14 +773,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
)
return self.check_for_safety(output, dtype=conditioning_data.dtype)

- def get_img2img_timesteps(
- self, num_inference_steps: int, strength: float, device=None
- ) -> (torch.Tensor, int):
+ def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
assert img2img_pipeline.scheduler is self.scheduler

if self.scheduler.config.get("cpu_only", False):
- scheduler_device = torch.device('cpu')
+ scheduler_device = torch.device("cpu")
else:
scheduler_device = self._model_group.device_for(self.unet)

@@ -849,18 +823,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# 6. Prepare latent variables
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
# because we have our own noise function
- init_image_latents = self.non_noised_latents_from_image(
- init_image, device=device, dtype=latents_dtype
- )
+ init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
if seed is not None:
set_seed(seed)
noise = noise_func(init_image_latents)

if mask.dim() == 3:
mask = mask.unsqueeze(0)
- latent_mask = tv_resize(
- mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR
- ).to(device=device, dtype=latents_dtype)
+ latent_mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR).to(
+ device=device, dtype=latents_dtype
+ )

guidance: List[Callable] = []

@@ -868,22 +840,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
# (that's why there's a mask!) but it seems to really want that blanked out.
masked_init_image = init_image * torch.where(mask < 0.5, 1, 0)
- masked_latents = self.non_noised_latents_from_image(
- masked_init_image, device=device, dtype=latents_dtype
- )
+ masked_latents = self.non_noised_latents_from_image(masked_init_image, device=device, dtype=latents_dtype)

# TODO: we should probably pass this in so we don't have to try/finally around setting it.
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
self._unet_forward, latent_mask, masked_latents
)
else:
- guidance.append(
- AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise)
- )
+ guidance.append(AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise))

try:
result_latents, result_attention_maps = self.latents_from_embeddings(
- latents=init_image_latents if strength < 1.0 else torch.zeros_like(
+ latents=init_image_latents
+ if strength < 1.0
+ else torch.zeros_like(
init_image_latents, device=init_image_latents.device, dtype=init_image_latents.dtype
),
num_inference_steps=num_inference_steps,
@@ -914,18 +884,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
with torch.inference_mode():
self._model_group.load(self.vae)
init_latent_dist = self.vae.encode(init_image).latent_dist
- init_latents = init_latent_dist.sample().to(
- dtype=dtype
- )  # FIXME: uses torch.randn. make reproducible!
+ init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!

init_latents = 0.18215 * init_latents
return init_latents

def check_for_safety(self, output, dtype):
with torch.inference_mode():
- screened_images, has_nsfw_concept = self.run_safety_checker(
- output.images, dtype=dtype
- )
+ screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
screened_attention_map_saver = None
if has_nsfw_concept is None or not has_nsfw_concept:
screened_attention_map_saver = output.attention_map_saver
@@ -949,9 +915,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

def debug_latents(self, latents, msg):
from invokeai.backend.image_util import debug_image

with torch.inference_mode():
decoded = self.numpy_to_pil(self.decode_latents(latents))
for i, img in enumerate(decoded):
- debug_image(
- img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
- )
+ debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True)
@ -17,6 +17,7 @@ from torch import nn
|
|||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from ...util import torch_dtype
|
from ...util import torch_dtype
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionType(enum.Enum):
|
class CrossAttentionType(enum.Enum):
|
||||||
SELF = 1
|
SELF = 1
|
||||||
TOKENS = 2
|
TOKENS = 2
|
||||||
@ -55,9 +56,7 @@ class Context:
|
|||||||
if name in self.self_cross_attention_module_identifiers:
|
if name in self.self_cross_attention_module_identifiers:
|
||||||
assert False, f"name {name} cannot appear more than once"
|
assert False, f"name {name} cannot appear more than once"
|
||||||
self.self_cross_attention_module_identifiers.append(name)
|
self.self_cross_attention_module_identifiers.append(name)
|
||||||
for name, module in get_cross_attention_modules(
|
for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
|
||||||
model, CrossAttentionType.TOKENS
|
|
||||||
):
|
|
||||||
if name in self.tokens_cross_attention_module_identifiers:
|
if name in self.tokens_cross_attention_module_identifiers:
|
||||||
assert False, f"name {name} cannot appear more than once"
|
assert False, f"name {name} cannot appear more than once"
|
||||||
self.tokens_cross_attention_module_identifiers.append(name)
|
self.tokens_cross_attention_module_identifiers.append(name)
|
||||||
@ -68,9 +67,7 @@ class Context:
else:
self.tokens_cross_attention_action = Context.Action.SAVE

- def request_apply_saved_attention_maps(
- self, cross_attention_type: CrossAttentionType
- ):
+ def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
if cross_attention_type == CrossAttentionType.SELF:
self.self_cross_attention_action = Context.Action.APPLY
else:
@ -139,9 +136,7 @@ class Context:
saved_attention_dict = self.saved_cross_attention_maps[identifier]
if requested_dim is None:
if saved_attention_dict["dim"] is not None:
- raise RuntimeError(
- f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}"
- )
+ raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
return saved_attention_dict["slices"][0]

if saved_attention_dict["dim"] == requested_dim:
@ -154,21 +149,13 @@ class Context:
if saved_attention_dict["dim"] is None:
whole_saved_attention = saved_attention_dict["slices"][0]
if requested_dim == 0:
- return whole_saved_attention[
- requested_offset : requested_offset + slice_size
- ]
+ return whole_saved_attention[requested_offset : requested_offset + slice_size]
elif requested_dim == 1:
- return whole_saved_attention[
- :, requested_offset : requested_offset + slice_size
- ]
+ return whole_saved_attention[:, requested_offset : requested_offset + slice_size]

- raise RuntimeError(
- f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}"
- )
+ raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")

- def get_slicing_strategy(
- self, identifier: str
- ) -> tuple[Optional[int], Optional[int]]:
+ def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
if saved_attention is None:
return None, None
@ -201,9 +188,7 @@ class InvokeAICrossAttentionMixin:

def set_attention_slice_wrangler(
self,
- wrangler: Optional[
- Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]
- ],
+ wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]],
):
"""
Set custom attention calculator to be called when attention is calculated
@ -219,14 +204,10 @@ class InvokeAICrossAttentionMixin:
"""
self.attention_slice_wrangler = wrangler

- def set_slicing_strategy_getter(
- self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]
- ):
+ def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]):
self.slicing_strategy_getter = getter

- def set_attention_slice_calculated_callback(
- self, callback: Optional[Callable[[torch.Tensor], None]]
- ):
+ def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
self.attention_slice_calculated_callback = callback

def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
@ -247,45 +228,31 @@ class InvokeAICrossAttentionMixin:
)

# calculate attention slice by taking the best scores for each latent pixel
- default_attention_slice = attention_scores.softmax(
- dim=-1, dtype=attention_scores.dtype
- )
+ default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
attention_slice_wrangler = self.attention_slice_wrangler
if attention_slice_wrangler is not None:
- attention_slice = attention_slice_wrangler(
- self, default_attention_slice, dim, offset, slice_size
- )
+ attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice

if self.attention_slice_calculated_callback is not None:
- self.attention_slice_calculated_callback(
- attention_slice, dim, offset, slice_size
- )
+ self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)

hidden_states = torch.bmm(attention_slice, value)
return hidden_states

def einsum_op_slice_dim0(self, q, k, v, slice_size):
- r = torch.zeros(
- q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype
- )
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
- r[i:end] = self.einsum_lowest_level(
- q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size
- )
+ r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r

def einsum_op_slice_dim1(self, q, k, v, slice_size):
- r = torch.zeros(
- q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype
- )
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
- r[:, i:end] = self.einsum_lowest_level(
- q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size
- )
+ r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r

def einsum_op_mps_v1(self, q, k, v):
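Note: the sliced einsum helpers above bound peak memory by computing attention for a chunk of rows at a time rather than all at once. A rough standalone sketch of the same idea, assuming (batch*heads, seq, dim) tensors; the function and variable names here are illustrative, not part of this diff:

import torch

def sliced_attention_dim0(q, k, v, slice_size):
    # process slice_size rows per iteration so the full
    # (batch*heads, q_len, k_len) score tensor is never materialized at once
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    scale = q.shape[-1] ** -0.5
    for i in range(0, q.shape[0], slice_size):
        end = i + slice_size
        scores = torch.bmm(q[i:end], k[i:end].transpose(1, 2)) * scale
        out[i:end] = torch.bmm(scores.softmax(dim=-1), v[i:end])
    return out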
@ -353,6 +320,7 @@ def restore_default_cross_attention(
else:
remove_attention_function(model)


def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
@ -372,7 +340,7 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
if b0 < max_length:
- if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
+ if name == "equal":  # or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
@ -386,16 +354,14 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
else:
# try to re-use an existing slice size
default_slice_size = 4
- slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+ slice_size = next(
+ (p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size
+ )
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))

- def get_cross_attention_modules(
- model, which: CrossAttentionType
- ) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
- cross_attention_class: type = (
- InvokeAIDiffusersCrossAttention
- )
+ def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
+ cross_attention_class: type = InvokeAIDiffusersCrossAttention
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
attention_module_tuples = [
(name, module)
@ -420,9 +386,7 @@ def get_cross_attention_modules(
def inject_attention_function(unet, context: Context):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276

- def attention_slice_wrangler(
- module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size
- ):
+ def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
# memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()

attention_slice = suggested_attention_slice
@ -430,9 +394,7 @@ def inject_attention_function(unet, context: Context):
if context.get_should_save_maps(module.identifier):
# print(module.identifier, "saving suggested_attention_slice of shape",
# suggested_attention_slice.shape, "dim", dim, "offset", offset)
- slice_to_save = (
- attention_slice.to("cpu") if dim is not None else attention_slice
- )
+ slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice
context.save_slice(
module.identifier,
slice_to_save,
@ -442,31 +404,20 @@ def inject_attention_function(unet, context: Context):
)
elif context.get_should_apply_saved_maps(module.identifier):
# print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
- saved_attention_slice = context.get_slice(
- module.identifier, dim, offset, slice_size
- )
+ saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)

# slice may have been offloaded to CPU
- saved_attention_slice = saved_attention_slice.to(
- suggested_attention_slice.device
- )
+ saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)

if context.is_tokens_cross_attention(module.identifier):
index_map = context.cross_attention_index_map
- remapped_saved_attention_slice = torch.index_select(
- saved_attention_slice, -1, index_map
- )
+ remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
this_attention_slice = suggested_attention_slice

- mask = context.cross_attention_mask.to(
- torch_dtype(suggested_attention_slice.device)
- )
+ mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
saved_mask = mask
this_mask = 1 - mask
- attention_slice = (
- remapped_saved_attention_slice * saved_mask
- + this_attention_slice * this_mask
- )
+ attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask
else:
# just use everything
attention_slice = saved_attention_slice
@ -480,14 +431,10 @@ def inject_attention_function(unet, context: Context):
module.identifier = identifier
try:
module.set_attention_slice_wrangler(attention_slice_wrangler)
- module.set_slicing_strategy_getter(
- lambda module: context.get_slicing_strategy(identifier)
- )
+ module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
- print(
- f"TODO: implement set_attention_slice_wrangler for {type(module)}"
- ) # TODO
+ print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")  # TODO
else:
raise

@ -503,9 +450,7 @@ def remove_attention_function(unet):
module.set_slicing_strategy_getter(None)
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
- print(
- f"TODO: implement set_attention_slice_wrangler for {type(module)}"
- )
+ print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
else:
raise

@ -530,9 +475,7 @@ def get_mem_free_total(device):
return mem_free_total


- class InvokeAIDiffusersCrossAttention(
- diffusers.models.attention.Attention, InvokeAICrossAttentionMixin
- ):
+ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
InvokeAICrossAttentionMixin.__init__(self)
@ -641,11 +584,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
# kwargs
swap_cross_attn_context: SwapCrossAttnContext = None,
):
- attention_type = (
- CrossAttentionType.SELF
- if encoder_hidden_states is None
- else CrossAttentionType.TOKENS
- )
+ attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS

# if cross-attention control is not in play, just call through to the base implementation.
if (
@ -654,9 +593,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
or not swap_cross_attn_context.wants_cross_attention_control(attention_type)
):
# print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
- return super().__call__(
- attn, hidden_states, encoder_hidden_states, attention_mask
- )
+ return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
# else:
#     print(f"SwapCrossAttnContext for {attention_type} active")

@ -699,18 +636,10 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
query_slice = query[start_idx:end_idx]
original_key_slice = original_text_key[start_idx:end_idx]
modified_key_slice = modified_text_key[start_idx:end_idx]
- attn_mask_slice = (
- attention_mask[start_idx:end_idx]
- if attention_mask is not None
- else None
- )
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

- original_attn_slice = attn.get_attention_scores(
- query_slice, original_key_slice, attn_mask_slice
- )
- modified_attn_slice = attn.get_attention_scores(
- query_slice, modified_key_slice, attn_mask_slice
- )
+ original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
+ modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)

# because the prompt modifications may result in token sequences shifted forwards or backwards,
# the original attention probabilities must be remapped to account for token index changes in the
@ -722,9 +651,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
# only some tokens taken from the original attention probabilities. this is controlled by the mask.
mask = swap_cross_attn_context.mask
inverse_mask = 1 - mask
- attn_slice = (
- remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
- )
+ attn_slice = remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask

del remapped_original_attn_slice, modified_attn_slice

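Note: the processor above remaps the original prompt's attention to the edited token order and then blends it with the freshly computed attention under a mask. A minimal standalone sketch of that remap-and-blend step, assuming the two slices have matching shapes and mask is 1 where the original probabilities should be kept (names are illustrative, not part of this diff):

import torch

def blend_saved_attention(saved_slice, current_slice, index_map, mask):
    # reorder the saved (original-prompt) attention columns to the edited token order,
    # then keep saved probabilities where mask == 1 and freshly computed ones elsewhere
    remapped = torch.index_select(saved_slice, -1, index_map)
    return remapped * mask + current_slice * (1 - mask)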
@ -744,6 +671,4 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):

class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
def __init__(self):
- super(SwapCrossAttnProcessor, self).__init__(
- slice_size=int(1e9)
- ) # massive slice size = don't slice
+ super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9))  # massive slice size = don't slice
@ -59,9 +59,7 @@ class AttentionMapSaver:
for key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
- this_scale_factor = math.sqrt(
- maps.shape[0] / (latents_width * latents_height)
- )
+ this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
this_maps_height = int(float(latents_height) * this_scale_factor)
this_maps_width = int(float(latents_width) * this_scale_factor)
# and we need to do some dimension juggling
@ -72,9 +70,7 @@ class AttentionMapSaver:

# scale to output size if necessary
if this_scale_factor != 1:
- maps = tv_resize(
- maps, [latents_height, latents_width], InterpolationMode.BICUBIC
- )
+ maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)

# normalize
maps_min = torch.min(maps)
@ -83,9 +79,7 @@ class AttentionMapSaver:
maps_normalized = (maps - maps_min) / maps_range
# expand to (-0.1, 1.1) and clamp
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
- maps_normalized_expanded_clamped = torch.clamp(
- maps_normalized_expanded, 0, 1
- )
+ maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)

# merge together, producing a vertical stack
maps_stacked = torch.reshape(
@ -31,6 +31,7 @@ ModelForwardCallback: TypeAlias = Union[
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]


@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
@ -81,14 +82,12 @@ class InvokeAIDiffuserComponent:
@contextmanager
def custom_attention_context(
cls,
unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
extra_conditioning_info: Optional[ExtraConditioningInfo],
- step_count: int
+ step_count: int,
):
old_attn_processors = None
- if extra_conditioning_info and (
- extra_conditioning_info.wants_cross_attention_control
- ):
+ if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
old_attn_processors = unet.attn_processors
# Load lora conditions into the model
if extra_conditioning_info.wants_cross_attention_control:
@ -116,27 +115,15 @@ class InvokeAIDiffuserComponent:
return
saver.add_attention_maps(slice, key)

- tokens_cross_attention_modules = get_cross_attention_modules(
- self.model, CrossAttentionType.TOKENS
- )
+ tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for identifier, module in tokens_cross_attention_modules:
- key = (
- "down"
- if identifier.startswith("down")
- else "up"
- if identifier.startswith("up")
- else "mid"
- )
+ key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
module.set_attention_slice_calculated_callback(
- lambda slice, dim, offset, slice_size, key=key: callback(
- slice, dim, offset, slice_size, key
- )
+ lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
)

def remove_attention_map_saving(self):
- tokens_cross_attention_modules = get_cross_attention_modules(
- self.model, CrossAttentionType.TOKENS
- )
+ tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
for _, module in tokens_cross_attention_modules:
module.set_attention_slice_calculated_callback(None)

@ -171,10 +158,8 @@ class InvokeAIDiffuserComponent:
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
percent_through = step_index / total_step_count
- cross_attention_control_types_to_do = (
- context.get_active_cross_attention_control_types_for_step(
- percent_through
- )
+ cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
+ percent_through
)

wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
@ -182,7 +167,11 @@ class InvokeAIDiffuserComponent:

if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
- x, sigma, unconditioning, conditioning, **kwargs,
+ x,
+ sigma,
+ unconditioning,
+ conditioning,
+ **kwargs,
)
elif wants_cross_attention_control:
(
@ -201,7 +190,11 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
- x, sigma, unconditioning, conditioning, **kwargs,
+ x,
+ sigma,
+ unconditioning,
+ conditioning,
+ **kwargs,
)

else:
@ -209,12 +202,18 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
- x, sigma, unconditioning, conditioning, **kwargs,
+ x,
+ sigma,
+ unconditioning,
+ conditioning,
+ **kwargs,
)

combined_next_x = self._combine(
# unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
- unconditioned_next_x, conditioned_next_x, guidance_scale
+ unconditioned_next_x,
+ conditioned_next_x,
+ guidance_scale,
)

return combined_next_x
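Note: the body of _combine is not shown in this hunk; assuming it performs the usual classifier-free guidance blend of the two predictions, the calculation amounts to the following sketch (a hypothetical helper, not the diff's code):

def combine(unconditioned_next_x, conditioned_next_x, guidance_scale):
    # move from the unconditioned prediction toward the conditioned one,
    # scaled by the guidance strength
    return unconditioned_next_x + guidance_scale * (conditioned_next_x - unconditioned_next_x)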
@ -229,37 +228,47 @@ class InvokeAIDiffuserComponent:
) -> torch.Tensor:
if postprocessing_settings is not None:
percent_through = step_index / total_step_count
- latents = self.apply_threshold(
- postprocessing_settings, latents, percent_through
- )
- latents = self.apply_symmetry(
- postprocessing_settings, latents, percent_through
- )
+ latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
+ latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
return latents

def _concat_conditionings_for_batch(self, unconditioning, conditioning):
def _pad_conditioning(cond, target_len, encoder_attention_mask):
- conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
+ conditioning_attention_mask = torch.ones(
+ (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
+ )

if cond.shape[1] < max_len:
- conditioning_attention_mask = torch.cat([
- conditioning_attention_mask,
- torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
- ], dim=1)
+ conditioning_attention_mask = torch.cat(
+ [
+ conditioning_attention_mask,
+ torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
+ ],
+ dim=1,
+ )

- cond = torch.cat([
- cond,
- torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
- ], dim=1)
+ cond = torch.cat(
+ [
+ cond,
+ torch.zeros(
+ (cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
+ device=cond.device,
+ dtype=cond.dtype,
+ ),
+ ],
+ dim=1,
+ )

if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
- encoder_attention_mask = torch.cat([
- encoder_attention_mask,
- conditioning_attention_mask,
- ])
+ encoder_attention_mask = torch.cat(
+ [
+ encoder_attention_mask,
+ conditioning_attention_mask,
+ ]
+ )

return cond, encoder_attention_mask

encoder_attention_mask = None
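Note: _pad_conditioning right-pads the shorter prompt embedding with zeros and builds a matching attention mask so both prompts can be batched together. A condensed standalone sketch of that behaviour (names here are illustrative, not part of this diff):

import torch

def pad_conditioning(cond, max_len):
    # cond: (batch, tokens, dim); mask marks real tokens with 1 and padding with 0
    mask = torch.ones(cond.shape[0], cond.shape[1], device=cond.device, dtype=cond.dtype)
    if cond.shape[1] < max_len:
        pad = max_len - cond.shape[1]
        mask = torch.cat([mask, torch.zeros(cond.shape[0], pad, device=cond.device, dtype=cond.dtype)], dim=1)
        cond = torch.cat([cond, torch.zeros(cond.shape[0], pad, cond.shape[2], device=cond.device, dtype=cond.dtype)], dim=1)
    return cond, mask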
@ -277,11 +286,11 @@ class InvokeAIDiffuserComponent:
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)

- both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
- unconditioning, conditioning
- )
+ both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(unconditioning, conditioning)
both_results = self.model_forward_callback(
- x_twice, sigma_twice, both_conditionings,
+ x_twice,
+ sigma_twice,
+ both_conditionings,
encoder_attention_mask=encoder_attention_mask,
**kwargs,
)
@ -312,13 +321,17 @@ class InvokeAIDiffuserComponent:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)

unconditioned_next_x = self.model_forward_callback(
- x, sigma, unconditioning,
+ x,
+ sigma,
+ unconditioning,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
**kwargs,
)
conditioned_next_x = self.model_forward_callback(
- x, sigma, conditioning,
+ x,
+ sigma,
+ conditioning,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
**kwargs,
@ -335,13 +348,15 @@ class InvokeAIDiffuserComponent:
for k in conditioning:
if isinstance(conditioning[k], list):
both_conditionings[k] = [
- torch.cat([unconditioning[k][i], conditioning[k][i]])
- for i in range(len(conditioning[k]))
+ torch.cat([unconditioning[k][i], conditioning[k][i]]) for i in range(len(conditioning[k]))
]
else:
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
- x_twice, sigma_twice, both_conditionings, **kwargs,
+ x_twice,
+ sigma_twice,
+ both_conditionings,
+ **kwargs,
).chunk(2)
return unconditioned_next_x, conditioned_next_x

@ -388,9 +403,7 @@ class InvokeAIDiffuserComponent:
)

# do requested cross attention types for conditioning (positive prompt)
- cross_attn_processor_context.cross_attention_types_to_do = (
- cross_attention_control_types_to_do
- )
+ cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
conditioned_next_x = self.model_forward_callback(
x,
sigma,
@ -414,19 +427,14 @@ class InvokeAIDiffuserComponent:
latents: torch.Tensor,
percent_through: float,
) -> torch.Tensor:
- if (
- postprocessing_settings.threshold is None
- or postprocessing_settings.threshold == 0.0
- ):
+ if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
return latents

threshold = postprocessing_settings.threshold
warmup = postprocessing_settings.warmup

if percent_through < warmup:
- current_threshold = threshold + threshold * 5 * (
- 1 - (percent_through / warmup)
- )
+ current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
else:
current_threshold = threshold

@ -440,18 +448,10 @@ class InvokeAIDiffuserComponent:

if self.debug_thresholding:
std, mean = [i.item() for i in torch.std_mean(latents)]
- outside = torch.count_nonzero(
- (latents < -current_threshold) | (latents > current_threshold)
- )
- logger.info(
- f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
- )
- logger.debug(
- f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
- )
- logger.debug(
- f"{outside / latents.numel() * 100:.2f}% values outside threshold"
- )
+ outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
+ logger.info(f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})")
+ logger.debug(f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}")
+ logger.debug(f"{outside / latents.numel() * 100:.2f}% values outside threshold")

if maxval < current_threshold and minval > -current_threshold:
return latents
@ -464,25 +464,17 @@ class InvokeAIDiffuserComponent:
latents = torch.clone(latents)
maxval = np.clip(maxval * scale, 1, current_threshold)
num_altered += torch.count_nonzero(latents > maxval)
- latents[latents > maxval] = (
- torch.rand_like(latents[latents > maxval]) * maxval
- )
+ latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval

if minval < -current_threshold:
latents = torch.clone(latents)
minval = np.clip(minval * scale, -current_threshold, -1)
num_altered += torch.count_nonzero(latents < minval)
- latents[latents < minval] = (
- torch.rand_like(latents[latents < minval]) * minval
- )
+ latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval

if self.debug_thresholding:
- logger.debug(
- f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
- )
- logger.debug(
- f"{num_altered / latents.numel() * 100:.2f}% values altered"
- )
+ logger.debug(f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})")
+ logger.debug(f"{num_altered / latents.numel() * 100:.2f}% values altered")

return latents

@ -501,15 +493,11 @@ class InvokeAIDiffuserComponent:

# Check for out of bounds
h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
- if h_symmetry_time_pct is not None and (
- h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0
- ):
+ if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0):
h_symmetry_time_pct = None

v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
- if v_symmetry_time_pct is not None and (
- v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0
- ):
+ if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0):
v_symmetry_time_pct = None

dev = latents.device.type
@ -554,9 +542,7 @@ class InvokeAIDiffuserComponent:
def estimate_percent_through(self, step_index, sigma):
if step_index is not None and self.cross_attention_control_context is not None:
# percent_through will never reach 1.0 (but this is intended)
- return float(step_index) / float(
- self.cross_attention_control_context.step_count
- )
+ return float(step_index) / float(self.cross_attention_control_context.step_count)
# find the best possible index of the current sigma in the sigma sequence
smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
@ -567,19 +553,13 @@ class InvokeAIDiffuserComponent:

# todo: make this work
@classmethod
- def apply_conjunction(
- cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale
- ):
+ def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)  # aka sigmas

deltas = None
uncond_latents = None
- weighted_cond_list = (
- c_or_weighted_c_list
- if type(c_or_weighted_c_list) is list
- else [(c_or_weighted_c_list, 1)]
- )
+ weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]

# below is fugly omg
conditionings = [uc] + [c for c, weight in weighted_cond_list]
@ -608,15 +588,11 @@ class InvokeAIDiffuserComponent:
deltas = torch.cat((deltas, latents_b - uncond_latents))

# merge the weighted deltas together into a single merged delta
- per_delta_weights = torch.tensor(
- weights[1:], dtype=deltas.dtype, device=deltas.device
- )
+ per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
normalize = False
if normalize:
per_delta_weights /= torch.sum(per_delta_weights)
- reshaped_weights = per_delta_weights.reshape(
- per_delta_weights.shape + (1, 1, 1)
- )
+ reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)

# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
@ -261,9 +261,7 @@ def srmd_degradation(x, k, sf=3):
year={2018}
}
"""
- x = ndimage.filters.convolve(
- x, np.expand_dims(k, axis=2), mode="wrap"
- ) # 'nearest' | 'mirror'
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")  # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x

@ -389,21 +387,15 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6:  # add color Gaussian noise
- img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
- np.float32
- )
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4:  # add grayscale Gaussian noise
- img = img + np.random.normal(
- 0, noise_level / 255.0, (*img.shape[:2], 1)
- ).astype(np.float32)
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else:  # add noise
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
- img = img + np.random.multivariate_normal(
- [0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
- ).astype(np.float32)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img

@ -413,21 +405,15 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
- img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(
- np.float32
- )
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4:
- img += img * np.random.normal(
- 0, noise_level / 255.0, (*img.shape[:2], 1)
- ).astype(np.float32)
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else:
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
- img += img * np.random.multivariate_normal(
- [0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
- ).astype(np.float32)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img

@ -440,9 +426,7 @@ def add_Poisson_noise(img):
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
- noise_gray = (
- np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
- )
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
@ -451,9 +435,7 @@ def add_Poisson_noise(img):
def add_JPEG_noise(img):
quality_factor = random.randint(30, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
- result, encimg = cv2.imencode(
- ".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
- )
+ result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
@ -540,9 +522,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
- img = ndimage.filters.convolve(
- img, np.expand_dims(k_shifted, axis=2), mode="mirror"
- )
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror")
img = img[0::sf, 0::sf, ...]  # nearest downsampling
img = np.clip(img, 0.0, 1.0)

@ -646,9 +626,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
- image = ndimage.filters.convolve(
- image, np.expand_dims(k_shifted, axis=2), mode="mirror"
- )
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror")
image = image[0::sf, 0::sf, ...]  # nearest downsampling
image = np.clip(image, 0.0, 1.0)

@ -796,9 +774,7 @@ if __name__ == "__main__":
print(i)
img_lq = deg_fn(img)
print(img_lq)
- img_lq_bicubic = albumentations.SmallestMaxSize(
- max_size=h, interpolation=cv2.INTER_CUBIC
- )(image=img)["image"]
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape)
@ -812,7 +788,5 @@ if __name__ == "__main__":
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
)
- img_concat = np.concatenate(
- [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
- )
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
util.imsave(img_concat, str(i) + ".png")
@ -261,9 +261,7 @@ def srmd_degradation(x, k, sf=3):
year={2018}
}
"""
- x = ndimage.filters.convolve(
- x, np.expand_dims(k, axis=2), mode="wrap"
- ) # 'nearest' | 'mirror'
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")  # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x

@ -393,21 +391,15 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6:  # add color Gaussian noise
- img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(
- np.float32
- )
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4:  # add grayscale Gaussian noise
- img = img + np.random.normal(
- 0, noise_level / 255.0, (*img.shape[:2], 1)
- ).astype(np.float32)
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else:  # add noise
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
- img = img + np.random.multivariate_normal(
- [0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
- ).astype(np.float32)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img

@ -417,21 +409,15 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
- img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(
- np.float32
- )
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4:
- img += img * np.random.normal(
- 0, noise_level / 255.0, (*img.shape[:2], 1)
- ).astype(np.float32)
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else:
L = noise_level2 / 255.0
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
- img += img * np.random.multivariate_normal(
- [0, 0, 0], np.abs(L**2 * conv), img.shape[:2]
- ).astype(np.float32)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img

@ -444,9 +430,7 @@ def add_Poisson_noise(img):
|
|||||||
else:
|
else:
|
||||||
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
||||||
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
|
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
|
||||||
noise_gray = (
|
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
||||||
np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
|
||||||
)
|
|
||||||
img += noise_gray[:, :, np.newaxis]
|
img += noise_gray[:, :, np.newaxis]
|
||||||
img = np.clip(img, 0.0, 1.0)
|
img = np.clip(img, 0.0, 1.0)
|
||||||
return img
|
return img
|
||||||
@ -455,9 +439,7 @@ def add_Poisson_noise(img):
def add_JPEG_noise(img):
def add_JPEG_noise(img):
quality_factor = random.randint(80, 95)
quality_factor = random.randint(80, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode(
result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]
)
img = cv2.imdecode(encimg, 1)
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
return img
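The in-memory JPEG round-trip used here can be exercised in isolation; a sketch assuming OpenCV and a uint8 RGB array, with explicit colour conversions in place of the repo's util.single2uint/uint2single helpers:

import cv2
import numpy as np

def jpeg_roundtrip(rgb_uint8, quality=90):
    # Encode to an in-memory JPEG buffer and decode it back, injecting compression artifacts.
    bgr = cv2.cvtColor(rgb_uint8, cv2.COLOR_RGB2BGR)
    ok, enc = cv2.imencode(".jpg", bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    if not ok:
        raise RuntimeError("JPEG encoding failed")
    dec = cv2.imdecode(enc, 1)  # 1 = decode as 3-channel BGR
    return cv2.cvtColor(dec, cv2.COLOR_BGR2RGB)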
@ -544,9 +526,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
img = ndimage.filters.convolve(
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror")
img, np.expand_dims(k_shifted, axis=2), mode="mirror"
)
img = img[0::sf, 0::sf, ...]  # nearest downsampling
img = img[0::sf, 0::sf, ...]  # nearest downsampling
img = np.clip(img, 0.0, 1.0)
img = np.clip(img, 0.0, 1.0)
@ -653,9 +633,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
image = ndimage.filters.convolve(
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror")
image, np.expand_dims(k_shifted, axis=2), mode="mirror"
)
image = image[0::sf, 0::sf, ...]  # nearest downsampling
image = image[0::sf, 0::sf, ...]  # nearest downsampling

image = np.clip(image, 0.0, 1.0)
image = np.clip(image, 0.0, 1.0)
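Both of these hunks rewrap the same blur-then-downsample step; a standalone sketch assuming SciPy, a float HxWx3 image, and a normalized 2-D kernel k (scipy.ndimage.convolve is the non-deprecated spelling of ndimage.filters.convolve used in the diff):

import numpy as np
from scipy import ndimage

def blur_then_downsample(img, k, sf=4):
    # Normalize the kernel so overall brightness is preserved, blur each channel,
    # then keep every sf-th pixel (nearest-neighbour downsampling).
    k = k / k.sum()
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode="mirror")
    return img[0::sf, 0::sf, ...]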
@ -705,9 +683,9 @@ if __name__ == "__main__":
img_lq = deg_fn(img)["image"]
img_lq = deg_fn(img)["image"]
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
print(img_lq)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[
max_size=h, interpolation=cv2.INTER_CUBIC
"image"
)(image=img_hq)["image"]
]
print(img_lq.shape)
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape)
print(img_hq.shape)
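The albumentations call being rewrapped here can be tried on its own; a usage sketch with placeholder inputs (the array shape, target size h, and the final print are illustrative, not taken from the diff):

import albumentations
import cv2
import numpy as np

img_hq = np.zeros((256, 384, 3), dtype=np.uint8)  # placeholder high-quality image
h = 64  # target length of the smaller side
resize = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)
img_lq_bicubic = resize(image=img_hq)["image"]  # transforms return a dict keyed by "image"
print(img_lq_bicubic.shape)  # (64, 96, 3): the smaller side is scaled to h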
@ -721,7 +699,5 @@ if __name__ == "__main__":
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0,
interpolation=0,
)
)
img_concat = np.concatenate(
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
[lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1
)
util.imsave(img_concat, str(i) + ".png")
util.imsave(img_concat, str(i) + ".png")
Some files were not shown because too many files have changed in this diff.