Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Compare commits: 7 commits, v2.3-lates ... invokeai-b
Commit SHA1s: dbd2161601, 1f83ac2eae, f7bb68d01c, 8cddf9c5b3, 9b546ccf06, 73dbf73a95, 18a1f3893f
.github/CODEOWNERS (vendored): 34 changed lines
@@ -1,13 +1,13 @@
 # continuous integration
-/.github/workflows/ @lstein @blessedcoolant
+/.github/workflows/ @mauwii @lstein @blessedcoolant
 
 # documentation
-/docs/ @lstein @blessedcoolant
-mkdocs.yml @lstein @ebr
+/docs/ @lstein @mauwii @blessedcoolant
+mkdocs.yml @mauwii @lstein
 
 # installation and configuration
-/pyproject.toml @lstein @ebr
-/docker/ @lstein
+/pyproject.toml @mauwii @lstein @ebr
+/docker/ @mauwii
 /scripts/ @ebr @lstein @blessedcoolant
 /installer/ @ebr @lstein
 ldm/invoke/config @lstein @ebr
@@ -21,13 +21,13 @@ invokeai/configs @lstein @ebr @blessedcoolant
 
 # generation and model management
 /ldm/*.py @lstein @blessedcoolant
-/ldm/generate.py @lstein @gregghelt2
+/ldm/generate.py @lstein @keturn
 /ldm/invoke/args.py @lstein @blessedcoolant
 /ldm/invoke/ckpt* @lstein @blessedcoolant
 /ldm/invoke/ckpt_generator @lstein @blessedcoolant
 /ldm/invoke/CLI.py @lstein @blessedcoolant
-/ldm/invoke/config @lstein @ebr @blessedcoolant
-/ldm/invoke/generator @gregghelt2 @damian0815
+/ldm/invoke/config @lstein @ebr @mauwii @blessedcoolant
+/ldm/invoke/generator @keturn @damian0815
 /ldm/invoke/globals.py @lstein @blessedcoolant
 /ldm/invoke/merge_diffusers.py @lstein @blessedcoolant
 /ldm/invoke/model_manager.py @lstein @blessedcoolant
@@ -36,17 +36,17 @@ invokeai/configs @lstein @ebr @blessedcoolant
 /ldm/invoke/restoration @lstein @blessedcoolant
 
 # attention, textual inversion, model configuration
-/ldm/models @damian0815 @gregghelt2 @blessedcoolant
+/ldm/models @damian0815 @keturn @blessedcoolant
 /ldm/modules/textual_inversion_manager.py @lstein @blessedcoolant
-/ldm/modules/attention.py @damian0815 @gregghelt2
-/ldm/modules/diffusionmodules @damian0815 @gregghelt2
-/ldm/modules/distributions @damian0815 @gregghelt2
-/ldm/modules/ema.py @damian0815 @gregghelt2
+/ldm/modules/attention.py @damian0815 @keturn
+/ldm/modules/diffusionmodules @damian0815 @keturn
+/ldm/modules/distributions @damian0815 @keturn
+/ldm/modules/ema.py @damian0815 @keturn
 /ldm/modules/embedding_manager.py @lstein
-/ldm/modules/encoders @damian0815 @gregghelt2
-/ldm/modules/image_degradation @damian0815 @gregghelt2
-/ldm/modules/losses @damian0815 @gregghelt2
-/ldm/modules/x_transformer.py @damian0815 @gregghelt2
+/ldm/modules/encoders @damian0815 @keturn
+/ldm/modules/image_degradation @damian0815 @keturn
+/ldm/modules/losses @damian0815 @keturn
+/ldm/modules/x_transformer.py @damian0815 @keturn
 
 # Nodes
 apps/ @Kyle0654 @jpphoto
.gitignore (vendored): 2 changed lines
@@ -233,3 +233,5 @@ installer/install.sh
 installer/update.bat
 installer/update.sh
 
+# no longer stored in source directory
+models
@@ -41,16 +41,6 @@ Windows systems). If the `loras` folder does not already exist, just
 create it. The vast majority of LoRA models use the Kohya file format,
 which is a type of `.safetensors` file.
 
-!!! warning "LoRA Naming Restrictions"
-
-    InvokeAI will only recognize LoRA files that contain the
-    characters a-z, A-Z, 0-9 and the underscore character
-    _. Other characters, including the hyphen, will cause the
-    LoRA file not to load. These naming restrictions may be
-    relaxed in the future, but for now you will need to rename
-    files that contain hyphens, commas, brackets, and other
-    non-word characters.
-
 You may change where InvokeAI looks for the `loras` folder by passing the
 `--lora_directory` option to the `invoke.sh`/`invoke.bat` launcher, or
 by placing the option in `invokeai.init`. For example:
@@ -33,11 +33,6 @@ title: Overview
 Restore mangled faces and make images larger with upscaling. Also see
 the [Embiggen Upscaling Guide](EMBIGGEN.md).
 
-- The [Using LoRA Models](LORAS.md)
-
-  Add custom subjects and styles using HuggingFace's repository of
-  embeddings.
-
 - The [Concepts Library](CONCEPTS.md)
 
   Add custom subjects and styles using HuggingFace's repository of
@@ -79,7 +79,7 @@ title: Manual Installation, Linux
 and obtaining an access token for downloading. It will then download and
 install the weights files for you.
 
-Please look [here](../020_INSTALL_MANUAL.md) for a manual process for doing
+Please look [here](../INSTALL_MANUAL.md) for a manual process for doing
 the same thing.
 
 7. Start generating images!
@@ -75,7 +75,7 @@ Note that you will need NVIDIA drivers, Python 3.10, and Git installed beforehan
 obtaining an access token for downloading. It will then download and install the
 weights files for you.
 
-Please look [here](../020_INSTALL_MANUAL.md) for a manual process for doing the
+Please look [here](../INSTALL_MANUAL.md) for a manual process for doing the
 same thing.
 
 8. Start generating images!
@@ -1,5 +0,0 @@
-mkdocs
-mkdocs-material>=8, <9
-mkdocs-git-revision-date-localized-plugin
-mkdocs-redirects==1.2.0
-
@@ -243,15 +243,16 @@ class InvokeAiInstance:
 
         # Note that we're installing pinned versions of torch and
         # torchvision here, which *should* correspond to what is
-        # in pyproject.toml.
+        # in pyproject.toml. This is to prevent torch 2.0 from
+        # being installed and immediately uninstalled and replaced with 1.13
         pip = local[self.pip]
 
         (
             pip[
                 "install",
                 "--require-virtualenv",
-                "torch~=2.0.0",
-                "torchvision>=0.14.1",
+                "torch~=1.13.1",
+                "torchvision~=0.14.1",
                 "--force-reinstall",
                 "--find-links" if find_links is not None else None,
                 find_links,
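For reference, a minimal standalone sketch of the same plumbum-based pip invocation, assuming plumbum is installed; the pip path, pins, and helper name are illustrative, not the installer's API:

from typing import Optional
from plumbum import FG, local

def install_pinned(pip_path: str = "pip", find_links: Optional[str] = None) -> None:
    pip = local[pip_path]
    args = [
        "install",
        "--require-virtualenv",
        "torch~=1.13.1",        # keep in sync with the pin in pyproject.toml
        "torchvision~=0.14.1",
    ]
    if find_links is not None:
        args += ["--find-links", find_links]
    # run in the foreground so pip's progress output stays visible
    pip[args] & FG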
@@ -25,11 +25,12 @@ from invokeai.backend.modules.parameters import parameters_to_command
 import invokeai.frontend.dist as frontend
 from ldm.generate import Generate
 from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
-from ldm.invoke.concepts_lib import get_hf_concepts_lib
+from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
 from ldm.invoke.conditioning import (
     get_tokens_for_prompt_object,
     get_prompt_structure,
     split_weighted_subprompts,
+    get_tokenizer,
 )
 from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
 from ldm.invoke.generator.inpaint import infill_methods
@@ -37,11 +38,11 @@ from ldm.invoke.globals import (
     Globals,
     global_converted_ckpts_dir,
     global_models_dir,
+    global_lora_models_dir,
 )
 from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
 from compel.prompt_parser import Blend
 from ldm.invoke.merge_diffusers import merge_diffusion_models
-from ldm.modules.lora_manager import LoraManager
 
 # Loading Arguments
 opt = Args()
@@ -523,12 +524,20 @@ class InvokeAIWebServer:
         @socketio.on("getLoraModels")
         def get_lora_models():
             try:
-                model = self.generate.model
-                lora_mgr = LoraManager(model)
-                loras = lora_mgr.list_compatible_loras()
+                lora_path = global_lora_models_dir()
+                loras = []
+                for root, _, files in os.walk(lora_path):
+                    models = [
+                        Path(root, x)
+                        for x in files
+                        if Path(x).suffix in [".ckpt", ".pt", ".safetensors"]
+                    ]
+                    loras = loras + models
+
                 found_loras = []
-                for lora in sorted(loras, key=str.casefold):
-                    found_loras.append({"name":lora,"location":str(loras[lora])})
+                for lora in sorted(loras, key=lambda s: s.stem.lower()):
+                    location = str(lora.resolve()).replace("\\", "/")
+                    found_loras.append({"name": lora.stem, "location": location})
                 socketio.emit("foundLoras", found_loras)
             except Exception as e:
                 self.handle_exceptions(e)
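A standalone sketch of the same directory-scan approach, runnable outside the web server; the example path is an assumption, not the value returned by global_lora_models_dir():

import os
from pathlib import Path

def find_lora_files(lora_path: str) -> list:
    loras = []
    for root, _, files in os.walk(lora_path):
        loras += [
            Path(root, f)
            for f in files
            if Path(f).suffix in {".ckpt", ".pt", ".safetensors"}
        ]
    return [
        {"name": p.stem, "location": str(p.resolve()).replace("\\", "/")}
        for p in sorted(loras, key=lambda p: p.stem.lower())
    ]

# Example (hypothetical path): find_lora_files(os.path.expanduser("~/invokeai/loras"))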
@@ -538,7 +547,7 @@ class InvokeAIWebServer:
             try:
                 local_triggers = self.generate.model.textual_inversion_manager.get_all_trigger_strings()
                 locals = [{'name': x} for x in sorted(local_triggers, key=str.casefold)]
-                concepts = get_hf_concepts_lib().list_concepts(minimum_likes=5)
+                concepts = HuggingFaceConceptsLibrary().list_concepts(minimum_likes=5)
                 concepts = [{'name': f'<{x}>'} for x in sorted(concepts, key=str.casefold) if f'<{x}>' not in local_triggers]
                 socketio.emit("foundTextualInversionTriggers", {'local_triggers': locals, 'huggingface_concepts': concepts})
             except Exception as e:
@@ -1305,7 +1314,7 @@ class InvokeAIWebServer:
                 None
                 if type(parsed_prompt) is Blend
                 else get_tokens_for_prompt_object(
-                    self.generate.model.tokenizer, parsed_prompt
+                    get_tokenizer(self.generate.model), parsed_prompt
                 )
             )
             attention_maps_image_base64_url = (
@@ -80,8 +80,7 @@ trinart-2.0:
     repo_id: stabilityai/sd-vae-ft-mse
   recommended: False
 waifu-diffusion-1.4:
-  description: An SD-2.1 model trained on 5.4M anime/manga-style images (4.27 GB)
-  revision: main
+  description: An SD-1.5 model trained on 680k anime/manga-style images (2.13 GB)
   repo_id: hakurei/waifu-diffusion
   format: diffusers
   vae:
File diff suppressed because one or more lines are too long
invokeai/frontend/dist/index.html (vendored): 2 changed lines
@@ -5,7 +5,7 @@
     <meta name="viewport" content="width=device-width, initial-scale=1.0" />
     <title>InvokeAI - A Stable Diffusion Toolkit</title>
     <link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" />
-    <script type="module" crossorigin src="./assets/index-b12e648e.js"></script>
+    <script type="module" crossorigin src="./assets/index-f56b39bc.js"></script>
     <link rel="stylesheet" href="./assets/index-2ab0eb58.css">
   </head>
 
@@ -33,10 +33,6 @@ import {
   setIntermediateImage,
 } from 'features/gallery/store/gallerySlice';
 
-import {
-  getLoraModels,
-  getTextualInversionTriggers,
-} from 'app/socketio/actions';
 import type { RootState } from 'app/store';
 import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
 import {
@@ -467,8 +463,6 @@ const makeSocketIOListeners = (
       const { model_name, model_list } = data;
       dispatch(setModelList(model_list));
      dispatch(setCurrentStatus(i18n.t('common.statusModelChanged')));
-      dispatch(getLoraModels());
-      dispatch(getTextualInversionTriggers());
       dispatch(setIsProcessing(false));
       dispatch(setIsCancelable(true));
       dispatch(
File diff suppressed because one or more lines are too long
@@ -13,16 +13,11 @@ import time
 import traceback
 from typing import List
 
-import warnings
-with warnings.catch_warnings():
-    warnings.filterwarnings("ignore", category=UserWarning)
-    import torch
-
 import cv2
 import diffusers
 import numpy as np
 import skimage
+import torch
 import transformers
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.utils.import_utils import is_xformers_available
@@ -638,8 +633,9 @@ class Generate:
         except RuntimeError:
             # Clear the CUDA cache on an exception
             self.clear_cuda_cache()
-            print("** Could not generate image.")
-            raise
+            print(traceback.format_exc(), file=sys.stderr)
+            print(">> Could not generate image.")
 
         toc = time.time()
         print("\n>> Usage stats:")
@@ -984,15 +980,13 @@ class Generate:
             seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
         if self.embedding_path and not model_data.get("ti_embeddings_loaded"):
             print(f'>> Loading embeddings from {self.embedding_path}')
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", category=UserWarning)
-                for root, _, files in os.walk(self.embedding_path):
-                    for name in files:
-                        ti_path = os.path.join(root, name)
-                        self.model.textual_inversion_manager.load_textual_inversion(
-                            ti_path, defer_injecting_tokens=True
-                        )
-            model_data["ti_embeddings_loaded"] = True
+            for root, _, files in os.walk(self.embedding_path):
+                for name in files:
+                    ti_path = os.path.join(root, name)
+                    self.model.textual_inversion_manager.load_textual_inversion(
+                        ti_path, defer_injecting_tokens=True
+                    )
+            model_data["ti_embeddings_loaded"] = True
             print(
                 f'>> Textual inversion triggers: {", ".join(sorted(self.model.textual_inversion_manager.get_all_trigger_strings()))}'
             )
@@ -9,6 +9,7 @@ from pathlib import Path
 from typing import Union
 
 import click
 
 from compel import PromptParser
+
 if sys.platform == "darwin":
@@ -16,6 +17,8 @@ if sys.platform == "darwin":
 
 import pyparsing # type: ignore
 
+print(f'DEBUG: [1] All system modules imported', file=sys.stderr)
+
 import ldm.invoke
 
 from ..generate import Generate
@@ -30,13 +33,21 @@ from .pngwriter import PngWriter, retrieve_metadata, write_metadata
 from .readline import Completer, get_completer
 from ..util import url_attachment_name
 
+print(f'DEBUG: [2] All invokeai modules imported', file=sys.stderr)
+
 # global used in multiple functions (fix)
 infile = None
 
 def main():
     """Initialize command-line parsers and the diffusion model"""
     global infile
 
+    print('DEBUG: [3] Entered main()', file=sys.stderr)
+    print('DEBUG: INVOKEAI ENVIRONMENT:')
+    print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
+    print("\n".join([f'{x}:{os.environ[x]}' for x in os.environ.keys()]))
+    print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
+
     opt = Args()
     args = opt.parse_args()
     if not args:
@@ -65,9 +76,13 @@ def main():
     Globals.sequential_guidance = args.sequential_guidance
     Globals.ckpt_convert = True # always true as of 2.3.4 for LoRA support
 
+    print(f'DEBUG: [4] Globals initialized', file=sys.stderr)
+
     # run any post-install patches needed
     run_patches()
 
+    print(f'DEBUG: [5] Patches run', file=sys.stderr)
+
     print(f">> Internet connectivity is {Globals.internet_available}")
 
     if not args.conf:
@@ -83,8 +98,9 @@ def main():
     # loading here to avoid long delays on startup
     # these two lines prevent a horrible warning message from appearing
     # when the frozen CLIP tokenizer is imported
+    print(f'DEBUG: [6] Importing torch modules', file=sys.stderr)
 
     import transformers # type: ignore
 
     from ldm.generate import Generate
 
     transformers.logging.set_verbosity_error()
@@ -92,6 +108,7 @@ def main():
 
     diffusers.logging.set_verbosity_error()
 
+    print(f'DEBUG: [7] loading restoration models', file=sys.stderr)
     # Loading Face Restoration and ESRGAN Modules
     gfpgan, codeformer, esrgan = load_face_restoration(opt)
 
@@ -113,6 +130,7 @@ def main():
     Globals.lora_models_dir = opt.lora_path
 
     # migrate legacy models
+    print(f'DEBUG: [8] migrating models', file=sys.stderr)
     ModelManager.migrate_models()
 
     # load the infile as a list of lines
@@ -130,6 +148,7 @@ def main():
 
     model = opt.model or retrieve_last_used_model()
 
+    print(f'DEBUG: [9] Creating generate object', file=sys.stderr)
     # creating a Generate object:
     try:
         gen = Generate(
@@ -156,6 +175,7 @@ def main():
         print(">> changed to seamless tiling mode")
 
     # preload the model
+    print(f'DEBUG: [10] Loading default model', file=sys.stderr)
     try:
         gen.load_model()
     except KeyError:
@@ -203,6 +223,7 @@ def main():
 # TODO: main_loop() has gotten busy. Needs to be refactored.
 def main_loop(gen, opt, completer):
     """prompt/read/execute loop"""
+    print(f'DEBUG: [11] In main loop', file=sys.stderr)
     global infile
     done = False
     doneAfterInFile = infile is not None
@@ -1321,15 +1342,16 @@ def install_missing_config_files():
     install ckpt configuration files that may have been added to the
     distro after original root directory configuration
     """
-    import invokeai.configs as conf
-    from shutil import copyfile
+    pass
+    # import invokeai.configs as conf
+    # from shutil import copyfile
 
-    root_configs = Path(global_config_dir(), 'stable-diffusion')
-    repo_configs = Path(conf.__path__[0], 'stable-diffusion')
-    for src in repo_configs.iterdir():
-        dest = root_configs / src.name
-        if not dest.exists():
-            copyfile(src,dest)
+    # root_configs = Path(global_config_dir(), 'stable-diffusion')
+    # repo_configs = Path(conf.__path__[0], 'stable-diffusion')
+    # for src in repo_configs.iterdir():
+    #     dest = root_configs / src.name
+    #     if not dest.exists():
+    #         copyfile(src,dest)
 
 def do_version_update(root_version: version.Version, app_version: Union[str, version.Version]):
     """
@@ -1,3 +1 @@
-__version__='2.3.5.post2'
-
-
+__version__='2.3.4'
@@ -620,10 +620,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     for key in keys:
         if key.startswith(vae_key):
             vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
-    new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict,config)
-    return new_checkpoint
 
-def convert_ldm_vae_state_dict(vae_state_dict, config):
     new_checkpoint = {}
 
     new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
@@ -12,14 +12,6 @@ from urllib import request, error as ul_error
 from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
 from ldm.invoke.globals import Globals
 
-singleton = None
-
-def get_hf_concepts_lib():
-    global singleton
-    if singleton is None:
-        singleton = HuggingFaceConceptsLibrary()
-    return singleton
-
 class HuggingFaceConceptsLibrary(object):
     def __init__(self, root=None):
         '''
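The deleted get_hf_concepts_lib() helper was a plain module-level singleton accessor. A generic sketch of that pattern with an illustrative class name (not InvokeAI code), for comparison with the direct HuggingFaceConceptsLibrary() instantiation the callers now use:

_instance = None

class ExpensiveLibrary:
    def __init__(self):
        print("constructed once")

def get_library() -> ExpensiveLibrary:
    # lazily build a single shared instance on first use
    global _instance
    if _instance is None:
        _instance = ExpensiveLibrary()
    return _instance

assert get_library() is get_library()  # same object every time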
@@ -15,10 +15,19 @@ from compel import Compel
 from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser, \
     Conjunction
 from .devices import torch_dtype
-from .generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ldm.invoke.globals import Globals
 
+def get_tokenizer(model) -> CLIPTokenizer:
+    # TODO remove legacy ckpt fallback handling
+    return (getattr(model, 'tokenizer', None) # diffusers
+            or model.cond_stage_model.tokenizer) # ldm
+
+def get_text_encoder(model) -> Any:
+    # TODO remove legacy ckpt fallback handling
+    return (getattr(model, 'text_encoder', None) # diffusers
+            or UnsqueezingLDMTransformer(model.cond_stage_model.transformer)) # ldm
+
 class UnsqueezingLDMTransformer:
     def __init__(self, ldm_transformer):
         self.ldm_transformer = ldm_transformer
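An illustrative, self-contained sketch of the getattr-with-fallback pattern that get_tokenizer()/get_text_encoder() rely on; the classes here are stand-ins, not InvokeAI or diffusers types:

class DiffusersLikeModel:
    tokenizer = "diffusers-tokenizer"

class LegacyCondStage:
    tokenizer = "ldm-tokenizer"

class LdmLikeModel:
    cond_stage_model = LegacyCondStage()

def pick_tokenizer(model):
    # getattr with a None default lets one expression serve both model layouts
    return getattr(model, "tokenizer", None) or model.cond_stage_model.tokenizer

assert pick_tokenizer(DiffusersLikeModel()) == "diffusers-tokenizer"
assert pick_tokenizer(LdmLikeModel()) == "ldm-tokenizer"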
@@ -32,15 +41,15 @@ class UnsqueezingLDMTransformer:
         return insufficiently_unsqueezed_tensor.unsqueeze(0)
 
 
-def get_uc_and_c_and_ec(prompt_string,
-                        model: StableDiffusionGeneratorPipeline,
-                        log_tokens=False, skip_normalize_legacy_blend=False):
+def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
     # lazy-load any deferred textual inversions.
     # this might take a couple of seconds the first time a textual inversion is used.
     model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
 
-    compel = Compel(tokenizer=model.tokenizer,
-                    text_encoder=model.text_encoder,
+    tokenizer = get_tokenizer(model)
+    text_encoder = get_text_encoder(model)
+    compel = Compel(tokenizer=tokenizer,
+                    text_encoder=text_encoder,
                     textual_inversion_manager=model.textual_inversion_manager,
                     dtype_for_device_getter=torch_dtype)
 
@@ -69,20 +78,14 @@ def get_uc_and_c_and_ec(prompt_string,
     negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
     negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
 
-    tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
     if log_tokens or getattr(Globals, "log_tokenization", False):
-        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
+        log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
 
-    # some LoRA models also mess with the text encoder, so they must be active while compel builds conditioning tensors
-    lora_conditioning_ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
-                                                                           lora_conditions=lora_conditions)
-    with InvokeAIDiffuserComponent.custom_attention_context(model.unet,
-                                                            extra_conditioning_info=lora_conditioning_ec,
-                                                            step_count=-1):
-        c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
-        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
+    c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
+    uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
+    tokens_count = get_max_token_count(tokenizer, positive_prompt)
 
-    # now build the "real" ec
     ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
                                                          cross_attention_control_args=options.get(
                                                              'cross_attention_control', None),
@@ -6,7 +6,6 @@ import os
 import platform
 import psutil
 import requests
-import pkg_resources
 from rich import box, print
 from rich.console import Console, group
 from rich.panel import Panel
@@ -40,7 +39,7 @@ def invokeai_is_running()->bool:
             if matches:
                 print(f':exclamation: [bold red]An InvokeAI instance appears to be running as process {p.pid}[/red bold]')
                 return True
-        except (psutil.AccessDenied,psutil.NoSuchProcess):
+        except psutil.AccessDenied:
             continue
     return False
 
@@ -73,20 +72,10 @@ def welcome(versions: dict):
     )
     console.line()
 
-def get_extras():
-    extras = ''
-    try:
-        dist = pkg_resources.get_distribution('xformers')
-        extras = '[xformers]'
-    except pkg_resources.DistributionNotFound:
-        pass
-    return extras
-
 def main():
     versions = get_versions()
     if invokeai_is_running():
         print(f':exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]')
-        input('Press any key to continue...')
         return
 
     welcome(versions)
@@ -105,15 +94,13 @@ def main():
     elif choice=='4':
         branch = Prompt.ask('Enter an InvokeAI branch name')
 
-    extras = get_extras()
-
     print(f':crossed_fingers: Upgrading to [yellow]{tag if tag else release}[/yellow]')
     if release:
-        cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip' --use-pep517 --upgrade"
+        cmd = f'pip install {INVOKE_AI_SRC}/{release}.zip --use-pep517 --upgrade'
    elif tag:
-        cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip' --use-pep517 --upgrade"
+        cmd = f'pip install {INVOKE_AI_TAG}/{tag}.zip --use-pep517 --upgrade'
     else:
-        cmd = f"pip install 'invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip' --use-pep517 --upgrade"
+        cmd = f'pip install {INVOKE_AI_BRANCH}/{branch}.zip --use-pep517 --upgrade'
     print('')
     print('')
     if os.system(cmd)==0:
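For context on what was dropped: the old updater probed for an installed xformers package and, if found, re-installed invokeai with the matching pip "extras" marker. A minimal sketch of that detection using pkg_resources from setuptools, as the removed code did (package and extra names are the ones from that code):

import pkg_resources

def get_extras(package: str = "xformers", extra: str = "[xformers]") -> str:
    try:
        pkg_resources.get_distribution(package)
        return extra
    except pkg_resources.DistributionNotFound:
        return ""

# e.g. build an install spec such as f"invokeai{get_extras()} @ <source-zip-url>"
# where <source-zip-url> is a placeholder for the release/tag/branch archive.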
@@ -196,6 +196,16 @@ class addModelsForm(npyscreen.FormMultiPage):
             scroll_exit=True,
         )
         self.nextrely += 1
+        self.convert_models = self.add_widget_intelligent(
+            npyscreen.TitleSelectOne,
+            name="== CONVERT IMPORTED MODELS INTO DIFFUSERS==",
+            values=["Keep original format", "Convert to diffusers"],
+            value=0,
+            begin_entry_at=4,
+            max_height=4,
+            hidden=True, # will appear when imported models box is edited
+            scroll_exit=True,
+        )
         self.cancel = self.add_widget_intelligent(
             npyscreen.ButtonPress,
             name="CANCEL",
@@ -230,6 +240,8 @@ class addModelsForm(npyscreen.FormMultiPage):
             self.show_directory_fields.addVisibleWhenSelected(i)
 
         self.show_directory_fields.when_value_edited = self._clear_scan_directory
+        self.import_model_paths.when_value_edited = self._show_hide_convert
+        self.autoload_directory.when_value_edited = self._show_hide_convert
 
     def resize(self):
         super().resize()
@@ -240,6 +252,13 @@ class addModelsForm(npyscreen.FormMultiPage):
         if not self.show_directory_fields.value:
             self.autoload_directory.value = ""
 
+    def _show_hide_convert(self):
+        model_paths = self.import_model_paths.value or ""
+        autoload_directory = self.autoload_directory.value or ""
+        self.convert_models.hidden = (
+            len(model_paths) == 0 and len(autoload_directory) == 0
+        )
+
     def _get_starter_model_labels(self) -> List[str]:
         window_width, window_height = get_terminal_size()
         label_width = 25
@@ -299,6 +318,7 @@ class addModelsForm(npyscreen.FormMultiPage):
         .scan_directory: Path to a directory of models to scan and import
         .autoscan_on_startup: True if invokeai should scan and import at startup time
         .import_model_paths: list of URLs, repo_ids and file paths to import
+        .convert_to_diffusers: if True, convert legacy checkpoints into diffusers
         """
         # we're using a global here rather than storing the result in the parentapp
         # due to some bug in npyscreen that is causing attributes to be lost
@@ -334,6 +354,7 @@ class addModelsForm(npyscreen.FormMultiPage):
 
         # URLs and the like
         selections.import_model_paths = self.import_model_paths.value.split()
+        selections.convert_to_diffusers = self.convert_models.value[0] == 1
 
 
 class AddModelApplication(npyscreen.NPSAppManaged):
@@ -346,6 +367,7 @@ class AddModelApplication(npyscreen.NPSAppManaged):
             scan_directory=None,
             autoscan_on_startup=None,
             import_model_paths=None,
+            convert_to_diffusers=None,
         )
 
     def onStart(self):
@@ -365,6 +387,7 @@ def process_and_execute(opt: Namespace, selections: Namespace):
     directory_to_scan = selections.scan_directory
     scan_at_startup = selections.autoscan_on_startup
     potential_models_to_install = selections.import_model_paths
+    convert_to_diffusers = selections.convert_to_diffusers
 
     install_requested_models(
         install_initial_models=models_to_install,
@@ -372,6 +395,7 @@ def process_and_execute(opt: Namespace, selections: Namespace):
         scan_directory=Path(directory_to_scan) if directory_to_scan else None,
         external_models=potential_models_to_install,
         scan_at_startup=scan_at_startup,
+        convert_to_diffusers=convert_to_diffusers,
         precision="float32"
         if opt.full_precision
         else choose_precision(torch.device(choose_torch_device())),
@@ -11,7 +11,6 @@ from tempfile import TemporaryFile
 
 import requests
 from diffusers import AutoencoderKL
-from diffusers import logging as dlogging
 from huggingface_hub import hf_hub_url
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
@@ -69,6 +68,7 @@ def install_requested_models(
     scan_directory: Path = None,
     external_models: List[str] = None,
     scan_at_startup: bool = False,
+    convert_to_diffusers: bool = False,
     precision: str = "float16",
     purge_deleted: bool = False,
     config_file_path: Path = None,
@@ -114,16 +114,17 @@ def install_requested_models(
             try:
                 model_manager.heuristic_import(
                     path_url_or_repo,
+                    convert=convert_to_diffusers,
                     config_file_callback=_pick_configuration_file,
                     commit_to_conf=config_file_path
                 )
             except KeyboardInterrupt:
                 sys.exit(-1)
-            except Exception as e:
-                print(f'An exception has occurred: {str(e)}')
+            except Exception:
+                pass
 
     if scan_at_startup and scan_directory.is_dir():
-        argument = '--autoconvert'
+        argument = '--autoconvert' if convert_to_diffusers else '--autoimport'
         initfile = Path(Globals.root, Globals.initfile)
         replacement = Path(Globals.root, f'{Globals.initfile}.new')
         directory = str(scan_directory).replace('\\','/')
@@ -295,21 +296,13 @@ def _download_diffusion_weights(
     mconfig: DictConfig, access_token: str, precision: str = "float32"
 ):
     repo_id = mconfig["repo_id"]
-    revision = mconfig.get('revision',None)
     model_class = (
         StableDiffusionGeneratorPipeline
         if mconfig.get("format", None) == "diffusers"
         else AutoencoderKL
     )
-    extra_arg_list = [{"revision": revision}] if revision \
-        else [{"revision": "fp16"}, {}] if precision == "float16" \
-        else [{}]
+    extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
     path = None
 
-    # quench safety checker warnings
-    verbosity = dlogging.get_verbosity()
-    dlogging.set_verbosity_error()
-
     for extra_args in extra_arg_list:
         try:
             path = download_from_hf(
@@ -325,7 +318,6 @@ def _download_diffusion_weights(
             print(f"An unexpected error occurred while downloading the model: {e})")
         if path:
             break
-    dlogging.set_verbosity(verbosity)
     return path
 
 
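A generic sketch of the try-each-kwargs fallback loop used above: attempt each candidate argument set in order and stop at the first success. download_fn stands in for a helper such as download_from_hf; it is a placeholder, not a real API:

def download_with_fallbacks(download_fn, repo_id: str, precision: str = "float32"):
    extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
    path = None
    for extra_args in extra_arg_list:
        try:
            path = download_fn(repo_id, **extra_args)
        except Exception as e:
            # e.g. the repo has no fp16 revision; fall through to the next candidate
            print(f"download with {extra_args} failed: {e}")
        if path:
            break
    return path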
@@ -456,8 +448,6 @@ def new_config_file_contents(
         stanza["description"] = mod["description"]
         stanza["repo_id"] = mod["repo_id"]
         stanza["format"] = mod["format"]
-        if "revision" in mod:
-            stanza["revision"] = mod["revision"]
         # diffusers don't need width and height (probably .ckpt doesn't either)
         # so we no longer require these in INITIAL_MODELS.yaml
         if "width" in mod:
@@ -482,9 +472,10 @@ def new_config_file_contents(
 
         conf[model] = stanza
 
-    # if no default model was chosen, then we select the first one in the list
+    # if no default model was chosen, then we select the first
+    # one in the list
     if not default_selected:
-        conf[list(conf.keys())[0]]["default"] = True
+        conf[list(successfully_downloaded.keys())[0]]["default"] = True
 
     return OmegaConf.to_yaml(conf)
 
@@ -32,7 +32,8 @@ def expand_prompts(
     template_file: Path,
     run_invoke: bool = False,
     invoke_model: str = None,
-    invoke_outdir: Path = None,
+    invoke_outdir: str = None,
+    invoke_root: str = None,
     processes_per_gpu: int = 1,
 ):
     """
@@ -61,6 +62,8 @@ def expand_prompts(
         invokeai_args = [shutil.which("invokeai"), "--from_file", "-"]
         if invoke_model:
             invokeai_args.extend(("--model", invoke_model))
+        if invoke_root:
+            invokeai_args.extend(("--root", invoke_root))
         if invoke_outdir:
             outdir = os.path.expanduser(invoke_outdir)
             invokeai_args.extend(("--outdir", outdir))
@@ -79,6 +82,11 @@ def expand_prompts(
         )
         import ldm.invoke.CLI
 
+        print(f'DEBUG: BATCH PARENT ENVIRONMENT:')
+        print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
+        print("\n".join([f'{x}:{os.environ[x]}' for x in os.environ.keys()]))
+        print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
+
         parent_conn, child_conn = Pipe()
         children = set()
         for i in range(processes_to_launch):
@@ -111,6 +119,13 @@ def expand_prompts(
         for p in children:
             p.terminate()
 
+def _dummy_cli_main():
+    counter = 0
+    while line := sys.stdin.readline():
+        print(f'[{counter}] {os.getpid()} got command {line.rstrip()}\n')
+        counter += 1
+        time.sleep(1)
+
 def _get_fn_format(directory:str, sequence:int)->str:
     """
     Get a filename that doesn't exceed filename length restrictions
@@ -179,9 +194,9 @@ def _run_invoke(
     os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}"
     sys.argv = args
     sys.stdin = MessageToStdin(conn_in)
-    sys.stdout = FilterStream(sys.stdout, include=re.compile("^\[\d+\]"))
-    with open(logfile, "w") as stderr, redirect_stderr(stderr):
-        entry_point()
+    # sys.stdout = FilterStream(sys.stdout, include=re.compile("^\[\d+\]"))
+    # with open(logfile, "w") as stderr, redirect_stderr(stderr):
+    entry_point()
 
 
 def _filter_output(stream: TextIOBase):
@@ -238,6 +253,10 @@ def main():
         default=1,
         help="When executing invokeai, how many parallel processes to execute per CUDA GPU.",
     )
+    parser.add_argument(
+        '--root_dir',
+        default=None,
+        help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai' )
     opt = parser.parse_args()
 
     if opt.example:
@@ -261,6 +280,7 @@ def main():
         run_invoke=opt.invoke,
         invoke_model=opt.model,
         invoke_outdir=opt.outdir,
+        invoke_root=opt.root,
         processes_per_gpu=opt.processes_per_gpu,
     )
 
@@ -400,15 +400,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @property
     def _submodels(self) -> Sequence[torch.nn.Module]:
         module_names, _, _ = self.extract_init_dict(dict(self.config))
-        submodels = []
-        for name in module_names.keys():
-            if hasattr(self, name):
-                value = getattr(self, name)
-            else:
-                value = getattr(self.config, name)
-            if isinstance(value, torch.nn.Module):
-                submodels.append(value)
-        return submodels
+        values = [getattr(self, name) for name in module_names.keys()]
+        return [m for m in values if isinstance(m, torch.nn.Module)]
 
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
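A toy, self-contained illustration of the attribute-filtering pattern the simplified _submodels uses (the container class is a stand-in, not the diffusers pipeline); requires torch:

import torch

class Container:
    def __init__(self):
        self.unet = torch.nn.Linear(4, 4)
        self.vae = torch.nn.Conv2d(3, 3, 1)
        self.scheduler = "not a module"

def submodules(obj, names):
    values = [getattr(obj, name) for name in names]
    return [m for m in values if isinstance(m, torch.nn.Module)]

print(len(submodules(Container(), ["unet", "vae", "scheduler"])))  # prints 2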
@@ -474,12 +467,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if additional_guidance is None:
             additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
-        with InvokeAIDiffuserComponent.custom_attention_context(self.invokeai_diffuser.model,
-                                                                extra_conditioning_info=extra_conditioning_info,
-                                                                step_count=len(self.scheduler.timesteps)
+        with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
+                                                             step_count=len(self.scheduler.timesteps)
         ):
 
-            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.config.num_train_timesteps,
+            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
                                             latents=latents)
 
             batch_size = latents.shape[0]
@@ -763,7 +755,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     @property
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
-        return self.unet.config.in_channels
+        return self.unet.in_channels
 
     def decode_latents(self, latents):
         # Explicit call to get the vae loaded, since `decode` isn't the forward method.
@@ -9,6 +9,7 @@ from __future__ import annotations
 import contextlib
 import gc
 import hashlib
+import io
 import os
 import re
 import sys
@@ -30,10 +31,11 @@ from huggingface_hub import scan_cache_dir
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 from picklescan.scanner import scan_file_path
 
 from ldm.invoke.devices import CPU_DEVICE
 from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
 from ldm.invoke.globals import Globals, global_cache_dir
-from ldm.util import ask_user, download_with_resume, url_attachment_name
+from ldm.util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name
 
 
 class SDLegacyType(Enum):
@@ -368,9 +370,14 @@ class ModelManager(object):
         print(
             f">> Converting legacy checkpoint {model_name} into a diffusers model..."
         )
-        from .ckpt_to_diffuser import (
-            load_pipeline_from_original_stable_diffusion_ckpt,
-        )
+        from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt
+        # try:
+        #     if self.list_models()[self.current_model]['status'] == 'active':
+        #         self.offload_model(self.current_model)
+        # except Exception:
+        #     pass
+
         if self._has_cuda():
             torch.cuda.empty_cache()
         pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
@@ -416,9 +423,9 @@ class ModelManager(object):
         pipeline_args.update(cache_dir=global_cache_dir("hub"))
         if using_fp16:
             pipeline_args.update(torch_dtype=torch.float16)
-        revision = mconfig.get('revision') or ('fp16' if using_fp16 else None)
-        fp_args_list = [{"revision": revision}] if revision else []
-        fp_args_list.append({})
+            fp_args_list = [{"revision": "fp16"}, {}]
+        else:
+            fp_args_list = [{}]
 
         verbosity = dlogging.get_verbosity()
         dlogging.set_verbosity_error()
@@ -432,7 +439,7 @@ class ModelManager(object):
                     **fp_args,
                 )
             except OSError as e:
-                if 'Revision Not Found' in str(e):
+                if str(e).startswith("fp16 is not a valid"):
                     pass
                 else:
                     print(
@@ -1155,7 +1162,7 @@ class ModelManager(object):
         return self.device.type == "cuda"
 
     def _diffuser_sha256(
-        self, name_or_path: Union[str, Path], chunksize=16777216
+        self, name_or_path: Union[str, Path], chunksize=4096
    ) -> Union[str, bytes]:
         path = None
         if isinstance(name_or_path, Path):
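Editor's note: `_diffuser_sha256` hashes model files in fixed-size blocks; the only change here is the block size (16 MiB on one side, 4 KiB on the other). A generic, self-contained sketch of chunked hashing, independent of the InvokeAI class:

```python
import hashlib
from pathlib import Path

def sha256_of_file(path: Path, chunksize: int = 16_777_216) -> str:
    """Hash a file in `chunksize` blocks so large checkpoints never sit in RAM all at once."""
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunksize):
            sha.update(chunk)
    return sha.hexdigest()
```

Smaller chunks lower peak memory slightly at the cost of more read calls; for multi-gigabyte checkpoints the larger block size is usually faster.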
@@ -1229,17 +1236,6 @@ class ModelManager(object):
         return vae_path
 
     def _load_vae(self, vae_config) -> AutoencoderKL:
-        using_fp16 = self.precision == "float16"
-        dtype = torch.float16 if using_fp16 else torch.float32
-
-        # Handle the common case of a user shoving a VAE .ckpt into
-        # the vae field for a diffusers. We convert it into diffusers
-        # format and use it.
-        if isinstance(vae_config,(str,Path)):
-            return self.convert_vae(vae_config).to(dtype=dtype)
-        elif isinstance(vae_config,DictConfig) and (vae_path := vae_config.get('path')):
-            return self.convert_vae(vae_path).to(dtype=dtype)
-
         vae_args = {}
         try:
             name_or_path = self.model_name_or_path(vae_config)
@@ -1247,6 +1243,7 @@ class ModelManager(object):
             return None
         if name_or_path is None:
             return None
+        using_fp16 = self.precision == "float16"
 
         vae_args.update(
             cache_dir=global_cache_dir("hub"),
@@ -1286,32 +1283,6 @@ class ModelManager(object):
 
         return vae
 
-    @staticmethod
-    def convert_vae(vae_path: Union[Path,str])->AutoencoderKL:
-        print(" | A checkpoint VAE was detected. Converting to diffusers format.")
-        vae_path = Path(Globals.root,vae_path).resolve()
-
-        from .ckpt_to_diffuser import (
-            create_vae_diffusers_config,
-            convert_ldm_vae_state_dict,
-        )
-
-        vae_path = Path(vae_path)
-        if vae_path.suffix in ['.pt','.ckpt']:
-            vae_state_dict = torch.load(vae_path, map_location="cpu")
-        else:
-            vae_state_dict = safetensors.torch.load_file(vae_path)
-        if 'state_dict' in vae_state_dict:
-            vae_state_dict = vae_state_dict['state_dict']
-        # TODO: see if this works with 1.x inpaint models and 2.x models
-        config_file_path = Path(Globals.root,"configs/stable-diffusion/v1-inference.yaml")
-        original_conf = OmegaConf.load(config_file_path)
-        vae_config = create_vae_diffusers_config(original_conf, image_size=512) # TODO: fix
-        diffusers_vae = convert_ldm_vae_state_dict(vae_state_dict,vae_config)
-        vae = AutoencoderKL(**vae_config)
-        vae.load_state_dict(diffusers_vae)
-        return vae
-
     @staticmethod
     def _delete_model_from_cache(repo_id):
         cache_info = scan_cache_dir(global_cache_dir("diffusers"))
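Editor's note: the deleted `convert_vae` helper accepted either a pickled `.pt`/`.ckpt` VAE or a `.safetensors` file, and unwrapped a nested `state_dict` key before handing the weights to the diffusers conversion routines. A stripped-down sketch of just that loading step (the function name is the editor's, not InvokeAI's; the diffusers conversion itself is omitted):

```python
from pathlib import Path
import safetensors.torch
import torch

def load_vae_state_dict(vae_path: Path) -> dict:
    """Load a raw VAE state dict from a legacy checkpoint or a safetensors file."""
    if vae_path.suffix in (".pt", ".ckpt"):
        state = torch.load(vae_path, map_location="cpu")
    else:
        state = safetensors.torch.load_file(vae_path)
    # Some checkpoints nest the weights under a "state_dict" key.
    return state.get("state_dict", state)
```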
@@ -13,7 +13,7 @@ import re
 import atexit
 from typing import List
 from ldm.invoke.args import Args
-from ldm.invoke.concepts_lib import get_hf_concepts_lib
+from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
 from ldm.invoke.globals import Globals
 from ldm.modules.lora_manager import LoraManager
 
@@ -287,7 +287,7 @@ class Completer(object):
     def _concept_completions(self, text, state):
         if self.concepts is None:
             # cache Concepts() instance so we can check for updates in concepts_list during runtime.
-            self.concepts = get_hf_concepts_lib()
+            self.concepts = HuggingFaceConceptsLibrary()
             self.embedding_terms.update(set(self.concepts.list_concepts()))
         else:
             self.embedding_terms.update(set(self.concepts.list_concepts()))
@@ -14,6 +14,7 @@ from torch import nn
 
 from compel.cross_attention_control import Arguments
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.cross_attention import AttnProcessor
 from ldm.invoke.devices import torch_dtype
 
 
@@ -162,7 +163,7 @@ class Context:
 
 class InvokeAICrossAttentionMixin:
     """
-    Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
+    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
     through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
     and dymamic slicing strategy selection.
     """
@@ -177,7 +178,7 @@ class InvokeAICrossAttentionMixin:
         Set custom attention calculator to be called when attention is calculated
         :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
                          which returns either the suggested_attention_slice or an adjusted equivalent.
-            `module` is the current Attention module for which the callback is being invoked.
+            `module` is the current CrossAttention module for which the callback is being invoked.
             `suggested_attention_slice` is the default-calculated attention slice
             `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
             If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
@@ -287,7 +288,16 @@ class InvokeAICrossAttentionMixin:
             return self.einsum_op_tensor_mem(q, k, v, 32)
 
 
-def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
+def restore_default_cross_attention(model, is_running_diffusers: bool, processors_to_restore: Optional[AttnProcessor]=None):
+    if is_running_diffusers:
+        unet = model
+        unet.set_attn_processor(processors_to_restore or CrossAttnProcessor())
+    else:
+        remove_attention_function(model)
+
+
+def override_cross_attention(model, context: Context, is_running_diffusers = False):
     """
     Inject attention parameters and functions into the passed in model to enable cross attention editing.
 
@@ -313,19 +323,26 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
 
     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
-    old_attn_processors = unet.attn_processors
-    if torch.backends.mps.is_available():
-        # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
-        unet.set_attn_processor(SwapCrossAttnProcessor())
+    if is_running_diffusers:
+        unet = model
+        old_attn_processors = unet.attn_processors
+        if torch.backends.mps.is_available():
+            # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
+            unet.set_attn_processor(SwapCrossAttnProcessor())
+        else:
+            # try to re-use an existing slice size
+            default_slice_size = 4
+            slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+            unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
     else:
-        # try to re-use an existing slice size
-        default_slice_size = 4
-        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
-        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
+        context.register_cross_attention_modules(model)
+        inject_attention_function(model, context)
 
 
 def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
-    from ldm.modules.attention import CrossAttention # avoid circular import # TODO: rename as in diffusers?
+    from ldm.modules.attention import CrossAttention # avoid circular import
     cross_attention_class: type = InvokeAIDiffusersCrossAttention if isinstance(model,UNet2DConditionModel) else CrossAttention
     which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
     attention_module_tuples = [(name,module) for name, module in model.named_modules() if
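Editor's note: both versions of this setup code lean on the diffusers attention-processor API: read the current processors from `unet.attn_processors`, install replacements with `unet.set_attn_processor(...)`, and put the originals back when done. A hedged, self-contained sketch of that save-and-restore dance (the replacement processor is a placeholder argument, not a specific InvokeAI class):

```python
from contextlib import contextmanager

@contextmanager
def swapped_attn_processors(unet, new_processor):
    """Temporarily replace a diffusers UNet's attention processors, then restore the originals."""
    old_processors = unet.attn_processors        # dict of module name -> processor
    unet.set_attn_processor(new_processor)       # e.g. a swap/sliced cross-attention processor
    try:
        yield unet
    finally:
        unet.set_attn_processor(old_processors)  # set_attn_processor also accepts the saved dict
```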
@@ -431,7 +448,7 @@ def get_mem_free_total(device):
 
 
 
-class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
+class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -456,8 +473,8 @@ class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
     """
     # base implementation
 
-class AttnProcessor:
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+class CrossAttnProcessor:
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
@@ -486,7 +503,7 @@ from dataclasses import field, dataclass
 
 import torch
 
-from diffusers.models.attention_processor import Attention, AttnProcessor, SlicedAttnProcessor
+from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
 
 
 @dataclass
@@ -531,7 +548,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
 
     # TODO: dynamically pick slice size based on memory conditions
 
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
                  # kwargs
                  swap_cross_attn_context: SwapCrossAttnContext=None):
 
@@ -12,6 +12,17 @@ class DDIMSampler(Sampler):
         self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
                                                            model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
 
+    def prepare_to_sample(self, t_enc, **kwargs):
+        super().prepare_to_sample(t_enc, **kwargs)
+
+        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
+        all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
+
+        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
+        else:
+            self.invokeai_diffuser.restore_default_cross_attention()
+
     # This is the central routine
     @torch.no_grad()
@@ -38,6 +38,15 @@ class CFGDenoiser(nn.Module):
                                                            model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
 
 
+    def prepare_to_sample(self, t_enc, **kwargs):
+
+        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
+
+        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = t_enc)
+        else:
+            self.invokeai_diffuser.restore_default_cross_attention()
+
     def forward(self, x, sigma, uncond, cond, cond_scale):
         next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
@@ -14,6 +14,17 @@ class PLMSSampler(Sampler):
     def __init__(self, model, schedule='linear', device=None, **kwargs):
         super().__init__(model,schedule,model.num_timesteps, device)
 
+    def prepare_to_sample(self, t_enc, **kwargs):
+        super().prepare_to_sample(t_enc, **kwargs)
+
+        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
+        all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
+
+        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+            self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
+        else:
+            self.invokeai_diffuser.restore_default_cross_attention()
+
     # this is the essential routine
     @torch.no_grad()
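Editor's note: the same `prepare_to_sample` override is added to the DDIM, k-diffusion and PLMS samplers above. The shared branch could be expressed once; the sketch below is the editor's hedged refactor of that repeated logic, not code from either branch:

```python
def apply_cross_attention_control(diffuser, extra_conditioning_info, step_count):
    """Shared branch used by the three sampler prepare_to_sample overrides above (sketch only)."""
    if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
        diffuser.override_attention_processors(extra_conditioning_info, step_count=step_count)
    else:
        diffuser.restore_default_cross_attention()
```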
@@ -1,17 +1,18 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from math import ceil
-from typing import Callable, Optional, Union, Any
+from typing import Callable, Optional, Union, Any, Dict
 
 import numpy as np
 import torch
-from diffusers import UNet2DConditionModel
+from diffusers.models.cross_attention import AttnProcessor
 from typing_extensions import TypeAlias
 
 from ldm.invoke.globals import Globals
 from ldm.models.diffusion.cross_attention_control import (
     Arguments,
-    setup_cross_attention_control_attention_processors,
+    restore_default_cross_attention,
+    override_cross_attention,
     Context,
     get_cross_attention_modules,
     CrossAttentionType,
@@ -83,45 +84,66 @@ class InvokeAIDiffuserComponent:
         self.cross_attention_control_context = None
         self.sequential_guidance = Globals.sequential_guidance
 
-    @classmethod
     @contextmanager
     def custom_attention_context(
-        clss,
-        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
-        extra_conditioning_info: Optional[ExtraConditioningInfo],
-        step_count: int
+        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
     ):
-        old_attn_processors = None
+        old_attn_processor = None
         if extra_conditioning_info and (
             extra_conditioning_info.wants_cross_attention_control
             | extra_conditioning_info.has_lora_conditions
         ):
-            old_attn_processors = unet.attn_processors
-            # Load lora conditions into the model
-            if extra_conditioning_info.has_lora_conditions:
-                for condition in extra_conditioning_info.lora_conditions:
-                    condition()  # target model is stored in condition state for some reason
-            if extra_conditioning_info.wants_cross_attention_control:
-                cross_attention_control_context = Context(
-                    arguments=extra_conditioning_info.cross_attention_control_args,
-                    step_count=step_count,
-                )
-                setup_cross_attention_control_attention_processors(
-                    unet,
-                    cross_attention_control_context,
-                )
+            old_attn_processor = self.override_attention_processors(
+                extra_conditioning_info, step_count=step_count
+            )
 
         try:
             yield None
         finally:
-            if old_attn_processors is not None:
-                unet.set_attn_processor(old_attn_processors)
+            if old_attn_processor is not None:
+                self.restore_default_cross_attention(old_attn_processor)
             if extra_conditioning_info and extra_conditioning_info.has_lora_conditions:
                 for lora_condition in extra_conditioning_info.lora_conditions:
                     lora_condition.unload()
         # TODO resuscitate attention map saving
         # self.remove_attention_map_saving()
 
+    def override_attention_processors(
+        self, conditioning: ExtraConditioningInfo, step_count: int
+    ) -> Dict[str, AttnProcessor]:
+        """
+        setup cross attention .swap control. for diffusers this replaces the attention processor, so
+        the previous attention processor is returned so that the caller can restore it later.
+        """
+        old_attn_processors = self.model.attn_processors
+
+        # Load lora conditions into the model
+        if conditioning.has_lora_conditions:
+            for condition in conditioning.lora_conditions:
+                condition(self.model)
+
+        if conditioning.wants_cross_attention_control:
+            self.cross_attention_control_context = Context(
+                arguments=conditioning.cross_attention_control_args,
+                step_count=step_count,
+            )
+            override_cross_attention(
+                self.model,
+                self.cross_attention_control_context,
+                is_running_diffusers=self.is_running_diffusers,
+            )
+        return old_attn_processors
+
+    def restore_default_cross_attention(
+        self, processors_to_restore: Optional[dict[str, "AttnProcessor"]] = None
+    ):
+        self.cross_attention_control_context = None
+        restore_default_cross_attention(
+            self.model,
+            is_running_diffusers=self.is_running_diffusers,
+            processors_to_restore=processors_to_restore,
+        )
+
     def setup_attention_map_saving(self, saver: AttentionMapSaver):
         def callback(slice, dim, offset, slice_size, key):
             if dim is not None:
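Editor's note: the protocol in this hunk is "override returns the previous processors, restore takes them back, and the context manager guarantees the restore". A toy, self-contained model of that protocol, written by the editor purely to illustrate the control flow (none of these names are InvokeAI's real state):

```python
from contextlib import contextmanager

class AttnController:
    """Toy model of the override/restore protocol in the hunk above (not InvokeAI code)."""
    def __init__(self):
        self.processor = "default"

    def override_attention_processors(self, conditioning, step_count):
        old = self.processor
        self.processor = f"swap(steps={step_count})"
        return old                               # caller keeps this to restore later

    def restore_default_cross_attention(self, processors_to_restore=None):
        self.processor = processors_to_restore or "default"

    @contextmanager
    def custom_attention_context(self, conditioning, step_count):
        old = self.override_attention_processors(conditioning, step_count) if conditioning else None
        try:
            yield
        finally:
            if old is not None:
                self.restore_default_cross_attention(old)

controller = AttnController()
with controller.custom_attention_context(conditioning="swap prompt", step_count=30):
    assert controller.processor.startswith("swap")
assert controller.processor == "default"
```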
@@ -6,7 +6,7 @@ from torch import nn
 
 import sys
 
-from ldm.invoke.concepts_lib import get_hf_concepts_lib
+from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
 from ldm.data.personalized import per_img_token_list
 from transformers import CLIPTokenizer
 from functools import partial
@@ -39,7 +39,7 @@ class EmbeddingManager(nn.Module):
         super().__init__()
 
         self.embedder = embedder
-        self.concepts_library=get_hf_concepts_lib()
+        self.concepts_library=HuggingFaceConceptsLibrary()
 
         self.string_to_token_dict = {}
         self.string_to_param_dict = nn.ParameterDict()
@@ -1,16 +1,15 @@
-import json
+import re
 from pathlib import Path
 from typing import Optional
 
 import torch
+from compel import Compel
 from diffusers.models import UNet2DConditionModel
-from filelock import FileLock, Timeout
 from safetensors.torch import load_file
 from torch.utils.hooks import RemovableHandle
 from transformers import CLIPTextModel
 
-from ..invoke.globals import global_lora_models_dir, Globals
-from ..invoke.devices import choose_torch_device
+from ldm.invoke.devices import choose_torch_device
 
 """
 This module supports loading LoRA weights trained with https://github.com/kohya-ss/sd-scripts
@@ -18,11 +17,6 @@ To be removed once support for diffusers LoRA weights is well supported
 """
 
 
-class IncompatibleModelException(Exception):
-    "Raised when there is an attempt to load a LoRA into a model that is incompatible with it"
-    pass
-
-
 class LoRALayer:
     lora_name: str
     name: str
@@ -37,14 +31,18 @@ class LoRALayer:
         self.name = name
         self.scale = alpha / rank if (alpha and rank) else 1.0
 
-    def forward(self, lora, input_h):
+    def forward(self, lora, input_h, output):
         if self.mid is None:
-            weight = self.up(self.down(*input_h))
+            output = (
+                output
+                + self.up(self.down(*input_h)) * lora.multiplier * self.scale
+            )
         else:
-            weight = self.up(self.mid(self.down(*input_h)))
-        return weight * lora.multiplier * self.scale
+            output = (
+                output
+                + self.up(self.mid(self.down(*input_h))) * lora.multiplier * self.scale
+            )
+        return output
 
 class LoHALayer:
     lora_name: str
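Editor's note: the `LoRALayer.forward` change folds the scaled low-rank correction directly into the wrapped module's output (`output + up(down(x)) * multiplier * scale`) instead of returning only the delta. A self-contained sketch of that residual update with plain `torch.nn.Linear` modules; the dimensions and values are arbitrary editor-chosen examples:

```python
import torch

torch.manual_seed(0)
in_features, out_features, rank = 16, 16, 4
multiplier, alpha = 1.0, 4.0
scale = alpha / rank

down = torch.nn.Linear(in_features, rank, bias=False)   # "lora_down"
up = torch.nn.Linear(rank, out_features, bias=False)    # "lora_up"

x = torch.randn(2, in_features)
base_output = torch.zeros(2, out_features)              # stand-in for the wrapped module's output

# New-style forward: add the scaled low-rank correction onto the existing output.
output = base_output + up(down(x)) * multiplier * scale
print(output.shape)   # torch.Size([2, 16])
```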
@@ -66,7 +64,8 @@ class LoHALayer:
         self.name = name
         self.scale = alpha / rank if (alpha and rank) else 1.0
 
-    def forward(self, lora, input_h):
+    def forward(self, lora, input_h, output):
 
         if type(self.org_module) == torch.nn.Conv2d:
             op = torch.nn.functional.conv2d
             extra_args = dict(
@@ -81,87 +80,21 @@ class LoHALayer:
             extra_args = {}
 
         if self.t1 is None:
-            weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
+            weight = ((self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b))
 
         else:
-            rebuild1 = torch.einsum(
-                "i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a
-            )
-            rebuild2 = torch.einsum(
-                "i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a
-            )
+            rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', self.t1, self.w1_b, self.w1_a)
+            rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', self.t2, self.w2_b, self.w2_a)
             weight = rebuild1 * rebuild2
 
         bias = self.bias if self.bias is not None else 0
-        return op(
+        return output + op(
             *input_h,
             (weight + bias).view(self.org_module.weight.shape),
             None,
             **extra_args,
         ) * lora.multiplier * self.scale
 
-class LoKRLayer:
-    lora_name: str
-    name: str
-    scale: float
-
-    w1: Optional[torch.Tensor] = None
-    w1_a: Optional[torch.Tensor] = None
-    w1_b: Optional[torch.Tensor] = None
-    w2: Optional[torch.Tensor] = None
-    w2_a: Optional[torch.Tensor] = None
-    w2_b: Optional[torch.Tensor] = None
-    t2: Optional[torch.Tensor] = None
-    bias: Optional[torch.Tensor] = None
-
-    org_module: torch.nn.Module
-
-    def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
-        self.lora_name = lora_name
-        self.name = name
-        self.scale = alpha / rank if (alpha and rank) else 1.0
-
-    def forward(self, lora, input_h):
-
-        if type(self.org_module) == torch.nn.Conv2d:
-            op = torch.nn.functional.conv2d
-            extra_args = dict(
-                stride=self.org_module.stride,
-                padding=self.org_module.padding,
-                dilation=self.org_module.dilation,
-                groups=self.org_module.groups,
-            )
-
-        else:
-            op = torch.nn.functional.linear
-            extra_args = {}
-
-        w1 = self.w1
-        if w1 is None:
-            w1 = self.w1_a @ self.w1_b
-
-        w2 = self.w2
-        if w2 is None:
-            if self.t2 is None:
-                w2 = self.w2_a @ self.w2_b
-            else:
-                w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)
-
-
-        if len(w2.shape) == 4:
-            w1 = w1.unsqueeze(2).unsqueeze(2)
-        w2 = w2.contiguous()
-        weight = torch.kron(w1, w2).reshape(self.org_module.weight.shape)
-
-
-        bias = self.bias if self.bias is not None else 0
-        return op(
-            *input_h,
-            (weight + bias).view(self.org_module.weight.shape),
-            None,
-            **extra_args
-        ) * lora.multiplier * self.scale
-
-
 class LoRAModuleWrapper:
     unet: UNet2DConditionModel
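Editor's note: both sides of the LoHA hunk build the effective weight as an element-wise (Hadamard) product of two low-rank factors, `(w1_a @ w1_b) * (w2_a @ w2_b)`, falling back to `torch.einsum` when tucker tensors `t1`/`t2` are present. A small self-contained sketch of the rank-factored case; shapes are arbitrary examples chosen by the editor:

```python
import torch

torch.manual_seed(0)
out_features, in_features, rank = 8, 8, 2

w1_a = torch.randn(out_features, rank)
w1_b = torch.randn(rank, in_features)
w2_a = torch.randn(out_features, rank)
w2_b = torch.randn(rank, in_features)

# Effective LoHA weight: Hadamard product of two rank-`rank` matrices.
weight = (w1_a @ w1_b) * (w2_a @ w2_b)
print(weight.shape)   # torch.Size([8, 8])

x = torch.randn(3, in_features)
delta = torch.nn.functional.linear(x, weight)   # applied like the `op(...)` call in the hunk
print(delta.shape)    # torch.Size([3, 8])
```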
@@ -178,22 +111,12 @@ class LoRAModuleWrapper:
         self.applied_loras = {}
         self.loaded_loras = {}
 
-        self.UNET_TARGET_REPLACE_MODULE = [
-            "Transformer2DModel",
-            "Attention",
-            "ResnetBlock2D",
-            "Downsample2D",
-            "Upsample2D",
-            "SpatialTransformer",
-        ]
-        self.TEXT_ENCODER_TARGET_REPLACE_MODULE = [
-            "ResidualAttentionBlock",
-            "CLIPAttention",
-            "CLIPMLP",
-        ]
+        self.UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D", "SpatialTransformer"]
+        self.TEXT_ENCODER_TARGET_REPLACE_MODULE = ["ResidualAttentionBlock", "CLIPAttention", "CLIPMLP"]
         self.LORA_PREFIX_UNET = "lora_unet"
         self.LORA_PREFIX_TEXT_ENCODER = "lora_te"
 
 
     def find_modules(
         prefix, root_module: torch.nn.Module, target_replace_modules
     ) -> dict[str, torch.nn.Module]:
@@ -224,6 +147,7 @@ class LoRAModuleWrapper:
             self.LORA_PREFIX_UNET, unet, self.UNET_TARGET_REPLACE_MODULE
         )
 
+
     def lora_forward_hook(self, name):
         wrapper = self
 
@@ -235,7 +159,7 @@ class LoRAModuleWrapper:
                 layer = lora.layers.get(name, None)
                 if layer is None:
                     continue
-                output += layer.forward(lora, input_h)
+                output = layer.forward(lora, input_h, output)
             return output
 
         return lora_forward
@@ -256,7 +180,6 @@ class LoRAModuleWrapper:
     def clear_loaded_loras(self):
         self.loaded_loras.clear()
 
-
 class LoRA:
     name: str
     layers: dict[str, LoRALayer]
@@ -282,6 +205,7 @@ class LoRA:
             state_dict_groupped[stem] = dict()
             state_dict_groupped[stem][leaf] = value
 
+
         for stem, values in state_dict_groupped.items():
             if stem.startswith(self.wrapper.LORA_PREFIX_TEXT_ENCODER):
                 wrapped = self.wrapper.text_modules.get(stem, None)
@@ -302,59 +226,34 @@ class LoRA:
             if "alpha" in values:
                 alpha = values["alpha"].item()
 
-            if (
-                "bias_indices" in values
-                and "bias_values" in values
-                and "bias_size" in values
-            ):
+            if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
                 bias = torch.sparse_coo_tensor(
                     values["bias_indices"],
                     values["bias_values"],
                     tuple(values["bias_size"]),
                 ).to(device=self.device, dtype=self.dtype)
 
 
             # lora and locon
             if "lora_down.weight" in values:
                 value_down = values["lora_down.weight"]
                 value_mid = values.get("lora_mid.weight", None)
                 value_up = values["lora_up.weight"]
 
                 if type(wrapped) == torch.nn.Conv2d:
                     if value_mid is not None:
-                        layer_down = torch.nn.Conv2d(
-                            value_down.shape[1], value_down.shape[0], (1, 1), bias=False
-                        )
-                        layer_mid = torch.nn.Conv2d(
-                            value_mid.shape[1],
-                            value_mid.shape[0],
-                            wrapped.kernel_size,
-                            wrapped.stride,
-                            wrapped.padding,
-                            bias=False,
-                        )
+                        layer_down = torch.nn.Conv2d(value_down.shape[1], value_down.shape[0], (1, 1), bias=False)
+                        layer_mid = torch.nn.Conv2d(value_mid.shape[1], value_mid.shape[0], wrapped.kernel_size, wrapped.stride, wrapped.padding, bias=False)
                     else:
-                        layer_down = torch.nn.Conv2d(
-                            value_down.shape[1],
-                            value_down.shape[0],
-                            wrapped.kernel_size,
-                            wrapped.stride,
-                            wrapped.padding,
-                            bias=False,
-                        )
-                        layer_mid = None
+                        layer_down = torch.nn.Conv2d(value_down.shape[1], value_down.shape[0], wrapped.kernel_size, wrapped.stride, wrapped.padding, bias=False)
+                        layer_mid = None
 
-                    layer_up = torch.nn.Conv2d(
-                        value_up.shape[1], value_up.shape[0], (1, 1), bias=False
-                    )
+                    layer_up = torch.nn.Conv2d(value_up.shape[1], value_up.shape[0], (1, 1), bias=False)
 
                 elif type(wrapped) == torch.nn.Linear:
-                    layer_down = torch.nn.Linear(
-                        value_down.shape[1], value_down.shape[0], bias=False
-                    )
-                    layer_mid = None
-                    layer_up = torch.nn.Linear(
-                        value_up.shape[1], value_up.shape[0], bias=False
-                    )
+                    layer_down = torch.nn.Linear(value_down.shape[1], value_down.shape[0], bias=False)
+                    layer_mid = None
+                    layer_up = torch.nn.Linear(value_up.shape[1], value_up.shape[0], bias=False)
 
                 else:
                     print(
@@ -362,90 +261,52 @@ class LoRA:
                     )
                     return
 
 
                 with torch.no_grad():
                     layer_down.weight.copy_(value_down)
                     if layer_mid is not None:
                         layer_mid.weight.copy_(value_mid)
                     layer_up.weight.copy_(value_up)
 
 
                 layer_down.to(device=self.device, dtype=self.dtype)
                 if layer_mid is not None:
                     layer_mid.to(device=self.device, dtype=self.dtype)
                 layer_up.to(device=self.device, dtype=self.dtype)
 
 
                 rank = value_down.shape[0]
 
                 layer = LoRALayer(self.name, stem, rank, alpha)
-                # layer.bias = bias  # TODO: find and debug lora/locon with bias
+                #layer.bias = bias # TODO: find and debug lora/locon with bias
                 layer.down = layer_down
                 layer.mid = layer_mid
                 layer.up = layer_up
 
             # loha
             elif "hada_w1_b" in values:
 
                 rank = values["hada_w1_b"].shape[0]
 
                 layer = LoHALayer(self.name, stem, rank, alpha)
                 layer.org_module = wrapped
                 layer.bias = bias
 
-                layer.w1_a = values["hada_w1_a"].to(
-                    device=self.device, dtype=self.dtype
-                )
-                layer.w1_b = values["hada_w1_b"].to(
-                    device=self.device, dtype=self.dtype
-                )
-                layer.w2_a = values["hada_w2_a"].to(
-                    device=self.device, dtype=self.dtype
-                )
-                layer.w2_b = values["hada_w2_b"].to(
-                    device=self.device, dtype=self.dtype
-                )
+                layer.w1_a = values["hada_w1_a"].to(device=self.device, dtype=self.dtype)
+                layer.w1_b = values["hada_w1_b"].to(device=self.device, dtype=self.dtype)
+                layer.w2_a = values["hada_w2_a"].to(device=self.device, dtype=self.dtype)
+                layer.w2_b = values["hada_w2_b"].to(device=self.device, dtype=self.dtype)
 
                 if "hada_t1" in values:
-                    layer.t1 = values["hada_t1"].to(
-                        device=self.device, dtype=self.dtype
-                    )
+                    layer.t1 = values["hada_t1"].to(device=self.device, dtype=self.dtype)
                 else:
                     layer.t1 = None
 
                 if "hada_t2" in values:
-                    layer.t2 = values["hada_t2"].to(
-                        device=self.device, dtype=self.dtype
-                    )
+                    layer.t2 = values["hada_t2"].to(device=self.device, dtype=self.dtype)
                 else:
                     layer.t2 = None
 
-            # lokr
-            elif "lokr_w1_b" in values or "lokr_w1" in values:
-
-                if "lokr_w1_b" in values:
-                    rank = values["lokr_w1_b"].shape[0]
-                elif "lokr_w2_b" in values:
-                    rank = values["lokr_w2_b"].shape[0]
-                else:
-                    rank = None  # unscaled
-
-                layer = LoKRLayer(self.name, stem, rank, alpha)
-                layer.org_module = wrapped
-                layer.bias = bias
-
-                if "lokr_w1" in values:
-                    layer.w1 = values["lokr_w1"].to(device=self.device, dtype=self.dtype)
-                else:
-                    layer.w1_a = values["lokr_w1_a"].to(device=self.device, dtype=self.dtype)
-                    layer.w1_b = values["lokr_w1_b"].to(device=self.device, dtype=self.dtype)
-
-                if "lokr_w2" in values:
-                    layer.w2 = values["lokr_w2"].to(device=self.device, dtype=self.dtype)
-                else:
-                    layer.w2_a = values["lokr_w2_a"].to(device=self.device, dtype=self.dtype)
-                    layer.w2_b = values["lokr_w2_b"].to(device=self.device, dtype=self.dtype)
-
-                if "lokr_t2" in values:
-                    layer.t2 = values["lokr_t2"].to(device=self.device, dtype=self.dtype)
-
-
             else:
                 print(
                     f">> Encountered unknown lora layer module in {self.name}: {stem} - {type(wrapped).__name__}"
@@ -456,25 +317,14 @@ class LoRA:
 
 
 class KohyaLoraManager:
-    def __init__(self, pipe):
-        self.vector_length_cache_path = self.lora_path / '.vectorlength.cache'
+    def __init__(self, pipe, lora_path):
         self.unet = pipe.unet
+        self.lora_path = lora_path
         self.wrapper = LoRAModuleWrapper(pipe.unet, pipe.text_encoder)
         self.text_encoder = pipe.text_encoder
         self.device = torch.device(choose_torch_device())
         self.dtype = pipe.unet.dtype
 
-    @classmethod
-    @property
-    def lora_path(cls)->Path:
-        return Path(global_lora_models_dir())
-
-    @classmethod
-    @property
-    def vector_length_cache_path(cls)->Path:
-        return cls.lora_path / '.vectorlength.cache'
-
     def load_lora_module(self, name, path_file, multiplier: float = 1.0):
         print(f" | Found lora {name} at {path_file}")
         if path_file.suffix == ".safetensors":
@@ -482,9 +332,6 @@ class KohyaLoraManager:
         else:
             checkpoint = torch.load(path_file, map_location="cpu")
 
-        if not self.check_model_compatibility(checkpoint):
-            raise IncompatibleModelException
-
         lora = LoRA(name, self.device, self.dtype, self.wrapper, multiplier)
         lora.load_from_dict(checkpoint)
         self.wrapper.loaded_loras[name] = lora
@@ -492,14 +339,12 @@ class KohyaLoraManager:
         return lora
 
     def apply_lora_model(self, name, mult: float = 1.0):
-        path_file = None
         for suffix in ["ckpt", "safetensors", "pt"]:
-            path_files = [x for x in Path(self.lora_path).glob(f"**/{name}.{suffix}")]
-            if len(path_files):
-                path_file = path_files[0]
+            path_file = Path(self.lora_path, f"{name}.{suffix}")
+            if path_file.exists():
                 print(f" | Loading lora {path_file.name} with weight {mult}")
                 break
-        if not path_file:
+        if not path_file.exists():
             print(f" ** Unable to find lora: {name}")
             return
 
@@ -510,89 +355,13 @@ class KohyaLoraManager:
         lora.multiplier = mult
         self.wrapper.applied_loras[name] = lora
 
-    def unload_applied_lora(self, lora_name: str) -> bool:
-        """If the indicated LoRA has previously been applied then
-        unload it and return True. Return False if the LoRA was
-        not previously applied (for status reporting)
-        """
+    def unload_applied_lora(self, lora_name: str):
         if lora_name in self.wrapper.applied_loras:
             del self.wrapper.applied_loras[lora_name]
-            return True
-        return False
 
-    def unload_lora(self, lora_name: str) -> bool:
+    def unload_lora(self, lora_name: str):
         if lora_name in self.wrapper.loaded_loras:
             del self.wrapper.loaded_loras[lora_name]
-            return True
-        return False
 
     def clear_loras(self):
         self.wrapper.clear_applied_loras()
 
-    def check_model_compatibility(self, checkpoint) -> bool:
-        """Checks whether the LoRA checkpoint is compatible with the token vector
-        length of the model that this manager is associated with.
-        """
-        model_token_vector_length = (
-            self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
-        )
-        lora_token_vector_length = self.vector_length_from_checkpoint(checkpoint)
-        return model_token_vector_length == lora_token_vector_length
-
-    @staticmethod
-    def vector_length_from_checkpoint(checkpoint: dict) -> int:
-        """Return the vector token length for the passed LoRA checkpoint object.
-        This is used to determine which SD model version the LoRA was based on.
-        768 -> SDv1
-        1024-> SDv2
-        """
-        key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
-        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
-        lora_token_vector_length = (
-            checkpoint[key1].shape[1]
-            if key1 in checkpoint
-            else checkpoint[key2].shape[0]
-            if key2 in checkpoint
-            else 768
-        )
-        return lora_token_vector_length
-
-    @classmethod
-    def vector_length_from_checkpoint_file(self, checkpoint_path: Path) -> int:
-        with LoraVectorLengthCache(self.vector_length_cache_path) as cache:
-            if str(checkpoint_path) not in cache:
-                if checkpoint_path.suffix == ".safetensors":
-                    checkpoint = load_file(
-                        checkpoint_path.absolute().as_posix(), device="cpu"
-                    )
-                else:
-                    checkpoint = torch.load(checkpoint_path, map_location="cpu")
-                cache[str(checkpoint_path)] = KohyaLoraManager.vector_length_from_checkpoint(
-                    checkpoint
-                )
-            return cache[str(checkpoint_path)]
-
-class LoraVectorLengthCache(object):
-    def __init__(self, cache_path: Path):
-        self.cache_path = cache_path
-        self.lock = FileLock(Path(cache_path.parent, ".cachelock"))
-        self.cache = {}
-
-    def __enter__(self):
-        self.lock.acquire(timeout=10)
-        try:
-            if self.cache_path.exists():
-                with open(self.cache_path, "r") as json_file:
-                    self.cache = json.load(json_file)
-        except Timeout:
-            print(
-                "** Can't acquire lock on lora vector length cache. Operations will be slower"
-            )
-        except (json.JSONDecodeError, OSError):
-            self.cache_path.unlink()
-        return self.cache
-
-    def __exit__(self, type, value, traceback):
-        with open(self.cache_path, "w") as json_file:
-            json.dump(self.cache, json_file)
-        self.lock.release()
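Editor's note: the block deleted above inferred the base model family from the LoRA checkpoint itself: a text-encoder down-projection width of 768 implies an SD 1.x base and 1024 implies SD 2.x. A compact sketch of that check, using the same keys and fallback the removed helper inspected:

```python
def lora_token_vector_length(checkpoint: dict) -> int:
    """Guess the CLIP embedding width (768 = SD 1.x, 1024 = SD 2.x) from a LoRA state dict."""
    key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
    key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
    if key1 in checkpoint:
        return checkpoint[key1].shape[1]
    if key2 in checkpoint:
        return checkpoint[key2].shape[0]
    return 768  # same default the removed helper fell back to
```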
@@ -1,101 +1,66 @@
 import os
-from diffusers import StableDiffusionPipeline
 from pathlib import Path
 
-from diffusers import UNet2DConditionModel, StableDiffusionPipeline
 from ldm.invoke.globals import global_lora_models_dir
-from .kohya_lora_manager import KohyaLoraManager, IncompatibleModelException
+from .kohya_lora_manager import KohyaLoraManager
 from typing import Optional, Dict
 
 class LoraCondition:
     name: str
     weight: float
 
-    def __init__(self,
-                 name,
-                 weight: float = 1.0,
-                 unet: UNet2DConditionModel=None,   # for diffusers format LoRAs
-                 kohya_manager: Optional[KohyaLoraManager]=None,  # for KohyaLoraManager-compatible LoRAs
-    ):
+    def __init__(self, name, weight: float = 1.0, kohya_manager: Optional[KohyaLoraManager]=None):
         self.name = name
         self.weight = weight
         self.kohya_manager = kohya_manager
-        self.unet = unet
 
-    def __call__(self):
+    def __call__(self, model):
         # TODO: make model able to load from huggingface, rather then just local files
         path = Path(global_lora_models_dir(), self.name)
         if path.is_dir():
-            if not self.unet:
-                print(f"  ** Unable to load diffusers-format LoRA {self.name}: unet is None")
-                return
-            if self.unet.load_attn_procs:
+            if model.load_attn_procs:
                 file = Path(path, "pytorch_lora_weights.bin")
                 if file.is_file():
                     print(f">> Loading LoRA: {path}")
-                    self.unet.load_attn_procs(path.absolute().as_posix())
+                    model.load_attn_procs(path.absolute().as_posix())
                 else:
                     print(f" ** Unable to find valid LoRA at: {path}")
             else:
                 print(" ** Invalid Model to load LoRA")
         elif self.kohya_manager:
-            try:
-                self.kohya_manager.apply_lora_model(self.name,self.weight)
-            except IncompatibleModelException:
-                print(f"  ** LoRA {self.name} is incompatible with this model; will generate without the LoRA applied.")
+            self.kohya_manager.apply_lora_model(self.name,self.weight)
         else:
             print(" ** Unable to load LoRA")
 
     def unload(self):
-        if self.kohya_manager and self.kohya_manager.unload_applied_lora(self.name):
+        if self.kohya_manager:
             print(f'>> unloading LoRA {self.name}')
+            self.kohya_manager.unload_applied_lora(self.name)
 
 class LoraManager:
-    def __init__(self, pipe: StableDiffusionPipeline):
+    def __init__(self, pipe):
         # Kohya class handles lora not generated through diffusers
-        self.kohya = KohyaLoraManager(pipe)
-        self.unet = pipe.unet
+        self.kohya = KohyaLoraManager(pipe, global_lora_models_dir())
 
     def set_loras_conditions(self, lora_weights: list):
         conditions = []
         if len(lora_weights) > 0:
             for lora in lora_weights:
-                conditions.append(LoraCondition(lora.model, lora.weight, self.unet, self.kohya))
+                conditions.append(LoraCondition(lora.model, lora.weight, self.kohya))
 
             if len(conditions) > 0:
                 return conditions
 
         return None
 
-    def list_compatible_loras(self)->Dict[str, Path]:
-        '''
-        List all the LoRAs in the global lora directory that
-        are compatible with the current model. Return a dictionary
-        of the lora basename and its path.
-        '''
-        model_length = self.kohya.text_encoder.get_input_embeddings().weight.data[0].shape[0]
-        return self.list_loras(model_length)
-
-    @staticmethod
-    def list_loras(token_vector_length:int=None)->Dict[str, Path]:
-        '''List the LoRAS in the global lora directory.
-        If token_vector_length is provided, then only return
-        LoRAS that have the indicated length:
-        768: v1 models
-        1024: v2 models
-        '''
+    @classmethod
+    def list_loras(self)->Dict[str, Path]:
         path = Path(global_lora_models_dir())
         models_found = dict()
         for root,_,files in os.walk(path):
             for x in files:
                 name = Path(x).stem
                 suffix = Path(x).suffix
-                if suffix not in [".ckpt", ".pt", ".safetensors"]:
-                    continue
-                path = Path(root,x)
-                if token_vector_length is None:
-                    models_found[name]=Path(root,x) # unconditional addition
-                elif token_vector_length == KohyaLoraManager.vector_length_from_checkpoint_file(path):
-                    models_found[name]=Path(root,x) # conditional on the base model matching
+                if suffix in [".ckpt", ".pt", ".safetensors"]:
+                    models_found[name]=Path(root,x)
         return models_found
 
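Editor's note: `list_loras` in the hunk above walks the global LoRA directory and keeps files whose suffix is `.ckpt`, `.pt`, or `.safetensors`; the two sides differ only in whether they additionally filter by token-vector length. A standalone sketch of the directory scan, with an editor-chosen function signature:

```python
import os
from pathlib import Path
from typing import Dict

def list_lora_files(lora_dir: Path) -> Dict[str, Path]:
    """Map LoRA basenames to paths for every checkpoint-like file under `lora_dir`."""
    models_found: Dict[str, Path] = {}
    for root, _, files in os.walk(lora_dir):
        for filename in files:
            p = Path(root, filename)
            if p.suffix in (".ckpt", ".pt", ".safetensors"):
                models_found[p.stem] = p
    return models_found
```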
@@ -3,16 +3,14 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Union
 
-import warnings
-with warnings.catch_warnings():
-    warnings.filterwarnings("ignore", category=UserWarning)
-    import safetensors.torch
-    import torch
+import safetensors.torch
+import torch
 
 from picklescan.scanner import scan_file_path
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from compel.embeddings_provider import BaseTextualInversionManager
-from ldm.invoke.concepts_lib import get_hf_concepts_lib
+from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
 
 
 @dataclass
 class TextualInversion:
@@ -36,7 +34,7 @@ class TextualInversionManager(BaseTextualInversionManager):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.full_precision = full_precision
-        self.hf_concepts_library = get_hf_concepts_lib()
+        self.hf_concepts_library = HuggingFaceConceptsLibrary()
         self.trigger_to_sourcefile = dict()
         default_textual_inversions: list[TextualInversion] = []
         self.textual_inversions = default_textual_inversions
@@ -32,9 +32,9 @@ dependencies = [
     "albumentations",
     "click",
     "clip_anytorch",
-    "compel~=1.1.5",
+    "compel~=1.1.0",
     "datasets",
-    "diffusers[torch]~=0.16.1",
+    "diffusers[torch]~=0.14",
     "dnspython==2.2.1",
     "einops",
     "eventlet",
@@ -76,7 +76,7 @@ dependencies = [
     "taming-transformers-rom1504",
     "test-tube>=0.7.5",
     "torch-fidelity",
-    "torch~=2.0.0",
+    "torch~=1.13.1",
     "torchmetrics",
     "torchvision>=0.14.1",
     "transformers~=4.26",
@@ -108,7 +108,7 @@ requires-python = ">=3.9, <3.11"
 "test" = ["pytest-cov", "pytest>6.0.0"]
 "xformers" = [
     "triton; sys_platform=='linux'",
-    "xformers~=0.0.19; sys_platform!='darwin'",
+    "xformers~=0.0.16; sys_platform!='darwin'",
 ]
 
 [project.scripts]