Compare commits

...

154 Commits

Author SHA1 Message Date
53e1199902 prevent potential infinite recursion on exceptions raised by event handlers 2023-10-12 14:34:35 -04:00
0f9c676fcb remove download queue change_priority() calls completely 2023-10-12 14:03:28 -04:00
a51b165a40 clean up model downloader status locking to avoid race conditions 2023-10-12 13:07:09 -04:00
5f80d4dd07 Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/InvokeAI into lstein/model-manager-refactor 2023-10-11 23:12:20 -04:00
b708aef5cc misc small fixes requested by Ryan 2023-10-11 23:02:22 -04:00
aace679505 Update invokeai/app/services/model_convert.py
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2023-10-11 22:59:47 -04:00
a2079bdd70 Update docs/installation/050_INSTALLING_MODELS.md
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2023-10-11 22:59:35 -04:00
0a0412f75f restore CLI to broken state 2023-10-11 22:57:08 -04:00
e079cc9f07 add back source URL validation to download job hierarchy 2023-10-11 22:42:07 -04:00
76aa19a0f7 first draft of documentation finished 2023-10-11 15:39:59 -04:00
71e7e61c0f add documentation for model record service and loader 2023-10-10 16:30:38 -04:00
67607f053d fix issues with module import order breaking pytest node tests 2023-10-09 22:43:00 -04:00
4bab724288 fix broken import 2023-10-09 16:45:32 -04:00
e50a257198 merge with main 2023-10-09 14:02:19 -04:00
4149d357bf refactor installer class hierarchy 2023-10-09 13:56:28 -04:00
33d4756c48 improve selection of huggingface repo id files to download 2023-10-09 08:53:03 -04:00
3962914f7d merge with main 2023-10-09 00:30:55 -04:00
3644d40e04 Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/InvokeAI into lstein/model-manager-refactor 2023-10-09 00:28:48 -04:00
fe1038665c address all PR 4252 comments from ryan through October 5 2023-10-09 00:28:21 -04:00
a80ff75b52 Update invokeai/app/invocations/model.py
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2023-10-08 22:55:22 -04:00
ce2baa36a9 port support for AutoencoderTiny models 2023-10-08 19:49:03 -04:00
bccfe8b3cc fix some type mismatches introduces by reorg 2023-10-08 19:30:04 -04:00
e5b2bc8532 refactor download queue jobs 2023-10-08 16:39:23 -04:00
a64a34b49a add support for repo_id subfolders 2023-10-08 12:45:06 -04:00
51060543dc support clipvision image encoder downloading 2023-10-07 19:13:41 -04:00
7f68f58cf7 restore printing of version when invokeai-web and invokeai called with --version 2023-10-07 18:23:34 -04:00
432231ea18 merge with main 2023-10-07 16:46:32 -04:00
44216381cb fix conversion call 2023-10-07 15:29:28 -04:00
00e85bcd67 make autoimport directory optional, defaulting to inactive 2023-10-07 14:00:38 -04:00
6303f74616 allow user to select main database or external file for model record/config db 2023-10-07 13:31:21 -04:00
8e06088152 refactor services 2023-10-06 18:10:20 -04:00
9cbc62d8d3 fix reorganized module dependencies 2023-10-04 23:53:29 -04:00
cd5d3e30c7 refactor model_manager_service.py into small functional modules 2023-10-04 23:45:58 -04:00
cb0fdf3394 refactor model install job class hierarchy 2023-10-04 14:51:59 -04:00
a180c0f241 check model hash before and after moving in filesystem 2023-10-04 09:40:15 -04:00
16ec7a323b fix type mismatches in download_manager service 2023-10-04 08:58:49 -04:00
de90d4068b Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/InvokeAI into lstein/model-manager-refactor 2023-10-04 08:42:07 -04:00
4624de0151 Merge branch 'main' into lstein/model-manager-refactor 2023-10-03 22:44:22 -04:00
459f0238dd multiple minor fixes 2023-10-03 22:43:19 -04:00
e3912e8826 replace config.ram_cache_size with config.ram and similarly for vram 2023-10-03 15:36:23 -04:00
062a6ed180 prevent crash on windows due to lack of os.pathconf call 2023-10-03 15:30:07 -04:00
48c3d926b0 make textual inversion training work with new model manager 2023-10-02 22:23:49 -04:00
63f6c12aa3 make merge script read invokeai.yaml when default root passed 2023-10-02 21:22:43 -04:00
c91429d4ab merge with main 2023-10-02 21:11:07 -04:00
230ee18536 do not ignore keyboard interrupt while scanning models 2023-09-30 14:21:39 -04:00
c025c9c4ed speed up model scanning at startup 2023-09-30 13:57:13 -04:00
acaaff4b7e make model merge script work with new model manager 2023-09-30 12:24:39 -04:00
807ae821ea more type mismatch fixes 2023-09-30 10:19:22 -04:00
208d390779 almost all type mismatches fixed 2023-09-29 19:23:08 -04:00
cbf0310a2c add README explaining reorg of tests directory 2023-09-29 01:17:07 -04:00
4555aec17c remove unused code from invokeai.backend.model_manager.storage.yaml 2023-09-29 01:07:18 -04:00
3b832f1db2 fix one more type mismatch in probe module 2023-09-29 00:44:50 -04:00
2f16a2c35d fix migrate script and type mismatches in probe, config and loader 2023-09-29 00:09:07 -04:00
81fce18c73 reorder pytests to prevent fixture race condition 2023-09-28 09:55:20 -04:00
0b75a4fbb5 resolve merge conflicts 2023-09-27 22:51:06 -04:00
2e9a7b0454 Merge branch 'main' into lstein/model-manager-refactor 2023-09-26 00:15:37 -04:00
1d6a4e7ee7 add tests for model installation events 2023-09-26 00:04:27 -04:00
effced8560 added cancel_all and prune model install operations to router API 2023-09-25 17:34:59 -04:00
ac4634000a merge with main & resolve conflicts 2023-09-25 17:02:21 -04:00
f9b92ddc12 resolve conflicts with get_logger() code changes from main 2023-09-24 10:34:06 -04:00
8bc1ca046c allow priority to be set at install job submission time 2023-09-24 10:08:21 -04:00
6edee2d22b automatically convert models.yaml to new format 2023-09-23 17:00:53 -04:00
ab58eb29c5 resolve conflicts with ip-adapter change 2023-09-23 13:00:47 -04:00
d5d517d2fa correctly download the selected version of a civitai model 2023-09-22 22:54:46 -04:00
d2cdbe5c4e configure script now working 2023-09-22 22:15:42 -04:00
07ddd601e1 fix install of models with relative paths 2023-09-22 11:49:18 -04:00
c9cd418ed8 add/delete from command line working; training words downloaded 2023-09-21 18:18:35 -04:00
30aea54f1a remove debug statement 2023-09-21 12:05:51 -04:00
3199409fd3 TUI installer functional; minor cosmetic work needed 2023-09-20 21:41:45 -04:00
3402cf6542 preserve description in metadata when installing a starter model 2023-09-20 20:30:35 -04:00
ed91f48a92 TUI installer more or less working 2023-09-20 17:07:11 -07:00
de666fd7bc move incorrectly placed models into correct directory at startup time 2023-09-19 01:18:03 -04:00
73bc088fa7 blackify 2023-09-19 00:54:14 -04:00
0c8849155e Merge branch 'main' into lstein/model-manager-refactor 2023-09-18 22:38:55 -04:00
d1382f232c fasthash produces same results on windows & linux 2023-09-18 19:38:33 -07:00
151ba02022 fix models.yaml version assertion error in pytests 2023-09-17 17:22:50 -04:00
d051c0868e attempt to fix flake8 lint errors 2023-09-17 17:13:56 -04:00
238d7fa0ee add models.yaml conversion script 2023-09-17 16:26:45 -04:00
f0ce559d28 add install job control to web API 2023-09-17 15:28:37 -04:00
e880f4bcfb add logs to confirm that event info is being sent to bus 2023-09-16 22:38:37 -04:00
539776a15a import_model API now working 2023-09-16 22:17:39 -04:00
c029534243 all methods in router API now tested and working 2023-09-16 19:43:01 -04:00
dc683475d4 loading and conversions of checkpoints working 2023-09-16 16:27:57 -04:00
c090c5f907 update_model and delete_model working; convert is WIP 2023-09-16 12:22:23 -04:00
db7fdc3555 fix more isort issues 2023-09-15 22:22:43 -04:00
b9a90fbd28 blackify and isort 2023-09-15 22:19:29 -04:00
08952b9aa0 Merge branch 'main' into lstein/model-manager-refactor 2023-09-15 22:18:48 -04:00
b7789bb7bb list_models() API call now working 2023-09-15 21:58:28 -04:00
3529925234 services rewritten; starting work on routes 2023-09-15 18:22:24 -04:00
a033ccc776 blackify 2023-09-14 21:12:41 -04:00
716a1b6423 model_manager_service now mostly type correct 2023-09-14 21:12:31 -04:00
171d789646 model loader autoscans models_dir on initialization 2023-09-14 14:07:14 -05:00
ac88863fd2 fix exception traceback reporting 2023-09-14 10:52:26 -05:00
27dcd89c90 merge with main; model_manager_service.py needs to be rewritten 2023-09-13 20:19:14 -04:00
4b932b275d refactor create_download_job; override probe info in install call 2023-09-13 18:53:33 -05:00
6d8b2a7385 pytests mostly working; model_manager_service needs rewriting 2023-09-11 23:47:24 -04:00
7430d87301 loader working 2023-09-10 23:11:25 -04:00
b583bddeb1 loading works -- web app broken 2023-09-10 22:59:58 -04:00
f454304c91 make it possible to pause/resume repo_id downloads 2023-09-10 17:20:47 -04:00
8052f2eb5d Merge branch 'main' into lstein/model-manager-refactor 2023-09-10 13:01:19 -04:00
8636015d92 increase download chunksize for better speed 2023-09-09 22:15:34 -04:00
b7a6a536e6 fix flake8 warnings 2023-09-09 21:26:09 -04:00
b2892f9068 incorporate civitai metadata into model config 2023-09-09 21:17:55 -04:00
3582cfa267 make download manager optional in InvokeAIServices during development 2023-09-09 14:06:36 -04:00
64424c6db0 install of repo_ids records author, tags and license 2023-09-09 14:02:05 -04:00
598fe8101e wire together download and install; now need to write install events 2023-09-09 11:42:07 -04:00
b7ca983f9c blackify 2023-09-07 21:14:24 -04:00
2165d55a67 add checks for malformed URLs and malicious content dispositions 2023-09-07 21:14:10 -04:00
a7aca29765 implement regression tests for pause/cancel/error conditions 2023-09-07 17:06:59 -04:00
79b2423159 last flake8 fix - why is local flake8 not identical to git flake8? 2023-09-07 09:38:15 -04:00
b09e012baa Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/InvokeAI into lstein/model-manager-refactor 2023-09-07 09:20:32 -04:00
c9a016f1a2 more flake8 fixes 2023-09-07 09:20:23 -04:00
d979c50de3 Merge branch 'main' into lstein/model-manager-refactor 2023-09-07 09:17:16 -04:00
11ead34022 fix flake8 warnings 2023-09-07 09:16:56 -04:00
82499d4ef0 fix various typing errors in api dependencies initialization 2023-09-06 23:59:45 -04:00
3448edac1a fix progress reporting for repo_ids 2023-09-06 19:33:04 -04:00
626acd5105 remove unecessary HTTP probe for repo_id model component sizes 2023-09-06 19:18:15 -04:00
404cfe0eb9 add download manager to invoke services 2023-09-06 18:47:30 -04:00
e9074176bd add unit tests for queued model download 2023-09-06 18:25:04 -04:00
ca6d24810c resolve merge conflicts 2023-09-04 21:13:09 -04:00
57552deab2 threaded repo_id download working; error conditions not tested 2023-09-04 21:10:21 -04:00
8f51adc737 chore: black 2023-09-05 10:22:46 +10:00
d1c5990abe merge and resolve conflicts 2023-09-04 18:50:06 -04:00
8fc20925b5 added download manager service and began repo_id download 2023-09-04 18:26:28 -04:00
869f310ae7 download of individual files working 2023-09-02 14:52:21 -04:00
e6512e1b9a add ABC for download manager 2023-08-30 09:08:31 -04:00
8396bf7c99 Merge branch 'main' into lstein/model-manager-refactor 2023-08-29 21:27:19 -04:00
97f2e778ee make ModelSearch pydantic 2023-08-24 13:37:49 -04:00
93cef55964 blackify 2023-08-23 19:53:21 -04:00
055ad0101d merge with main; resolve conflicts 2023-08-23 19:45:25 -04:00
9adc897302 added install module 2023-08-23 19:41:25 -04:00
4b3d54dbc0 install ABC written 2023-08-23 08:44:22 -04:00
6f9bf87a7a reimplement and clean up probe class 2023-08-22 22:24:07 -04:00
f023e342ef added main templates 2023-08-20 21:34:43 -04:00
1784aeb343 fix flake8 errors 2023-08-20 16:38:41 -04:00
0deb3f9e2a Merge branch 'main' into lstein/model-manager-refactor 2023-08-20 16:15:14 -04:00
916cc26193 partial rewrite of checkpoint template creator 2023-08-16 21:21:42 -04:00
e83d00595d module skeleton written 2023-08-14 21:49:32 -04:00
1c7d9dbf40 start installer module 2023-08-14 21:10:45 -04:00
7db71ed42e rename modules 2023-08-14 20:55:30 -04:00
c56fb38855 added ability to force config class returned by make_config() 2023-08-13 19:08:50 -04:00
155d9fcb13 Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/InvokeAI into lstein/model-manager-refactor 2023-08-13 18:49:38 -04:00
81da3d3b23 change model field name "hash" to "id" 2023-08-13 18:49:30 -04:00
51e84e6986 Merge branch 'main' into lstein/model-manager-refactor 2023-08-13 18:17:28 -04:00
1ea0ccb7b9 add SQL backend 2023-08-13 18:15:49 -04:00
5434dcd273 fix test to work with string paths 2023-08-13 13:36:31 -04:00
0c7430048e change paths to str to make json serializable 2023-08-13 13:26:19 -04:00
6c9b9e1787 Merge branch 'main' into lstein/model-manager-refactor 2023-08-12 20:13:53 -04:00
b2894b5270 add class docstring and blackify 2023-08-12 20:13:00 -04:00
32958db6f6 add YAML file storage backend 2023-08-12 20:06:00 -04:00
e8815a1676 rename ModelConfig to ModelConfigFactory 2023-08-12 18:30:14 -04:00
e8edb0d434 add ABC for config storage 2023-08-12 17:50:55 -04:00
b5d97b18f1 blackify 2023-08-12 17:24:03 -04:00
ae56c000fc define model configuration classes 2023-08-12 17:11:34 -04:00
134 changed files with 8638 additions and 3186 deletions

File diff suppressed because it is too large

View File

@ -14,6 +14,7 @@ Once you're setup, for more information, you can review the documentation specif
* #### [InvokeAI Architecture](../ARCHITECTURE.md)
* #### [Frontend Documentation](./contributingToFrontend.md)
* #### [Node Documentation](../INVOCATIONS.md)
* #### [InvokeAI Model Manager](../MODEL_MANAGER.md)
* #### [Local Development](../LOCAL_DEVELOPMENT.md)

View File

@ -207,11 +207,8 @@ if INVOKEAI_ROOT is `/home/fred/invokeai` and the path is
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory |
| `lora_dir` | `autoimport/lora` | At startup time, read and import any LoRA/LyCORIS models found in this directory |
| `embedding_dir` | `autoimport/embedding` | At startup time, read and import any textual inversion (embedding) models found in this directory |
| `controlnet_dir` | `autoimport/controlnet` | At startup time, read and import any ControlNet models found in this directory |
| `conf_path` | `configs/models.yaml` | Location of the `models.yaml` model configuration file |
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory (not recommended)|
| `model_config_db` | `auto` | Location of the model configuration database. Specify `auto` to use the main invokeai.db database, or specify a `.yaml` or `.db` file to store the data externally.|
| `models_dir` | `models` | Location of the directory containing models installed by InvokeAI's model manager |
| `legacy_conf_dir` | `configs/stable-diffusion` | Location of the directory containing the .yaml configuration files for legacy checkpoint models |
| `db_dir` | `databases` | Location of the directory containing InvokeAI's image, schema and session database |
@ -234,6 +231,18 @@ Paths:
# controlnet_dir: null
```
### Model Cache
These options control the size of various caches that InvokeAI uses
during the model loading and conversion process. All units are in GB.
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `disk` | `20.0` | Before loading a model into memory, InvokeAI converts .ckpt and .safetensors models into diffusers format and saves them to disk. This option controls the maximum size of the directory in which these converted models are stored. If set to zero, then only the most recently-used model will be cached. |
| `ram` | `6.0` | After loading a model from disk, it is kept in system RAM until it is needed again. This option controls how much RAM is set aside for this purpose. Larger amounts allow more models to reside in RAM and for InvokeAI to quickly switch between them. |
| `vram` | `0.25` | This allows smaller models to remain in VRAM, speeding up execution modestly. It should be a small number. |
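
For readers updating an existing `invokeai.yaml` by hand, the sketch below shows one way these three cache limits might be expressed and read back. It is illustrative only: the stanza name (`Model Cache`) and its nesting are assumptions, not taken from this diff, and the numbers are simply the defaults from the table above; check the generated `invokeai.yaml` for the authoritative layout.

```python
# Illustrative sketch only -- the "Model Cache" stanza name and nesting are
# assumptions; the numeric values are the defaults from the table above (GB).
import yaml

example = """
Model Cache:
  disk: 20.0   # max size of the converted-diffusers cache on disk
  ram: 6.0     # system RAM reserved for already-loaded models
  vram: 0.25   # VRAM kept resident for small models
"""

cache = yaml.safe_load(example)["Model Cache"]
print(f"disk={cache['disk']} GB  ram={cache['ram']} GB  vram={cache['vram']} GB")
```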
### Logging
These settings control the information, warning, and debugging

View File

@ -123,11 +123,20 @@ installation. Examples:
# (list all controlnet models)
invokeai-model-install --list controlnet
# (install the model at the indicated URL)
# (install the diffusers model using its hugging face repo_id)
invokeai-model-install --add stabilityai/stable-diffusion-xl-base-1.0
# (install a diffusers model that lives in a subfolder)
invokeai-model-install --add stabilityai/stable-diffusion-xl-base-1.0:vae
# (install the checkpoint model at the indicated URL)
invokeai-model-install --add https://civitai.com/api/download/models/128713
# (delete the named model)
invokeai-model-install --delete sd-1/main/analog-diffusion
# (delete the named model if its name is unique)
invokeai-model-install --delete analog-diffusion
# (delete the named model using its fully qualified name)
invokeai-model-install --delete sd-1/main/test_model
```
### Installation via the Web GUI
@ -141,6 +150,24 @@ left-hand panel) and navigate to *Import Models*
wish to install. You may use a URL, HuggingFace repo id, or a path on
your local disk.
There is special scanning for CivitAI URLs which lets
you cut-and-paste either the URL for a CivitAI model page
(e.g. https://civitai.com/models/12345), or the direct download link
for a model (e.g. https://civitai.com/api/download/models/12345).
If the desired model is a HuggingFace diffusers model that is located
in a subfolder of the repository (e.g. vae), then append the subfolder
to the end of the repo_id like this:
```
# a VAE model located in subfolder "vae"
stabilityai/stable-diffusion-xl-base-1.0:vae
# version 2 of the model located in subfolder "v2"
monster-labs/control_v1p_sd15_qrcode_monster:v2
```
3. Alternatively, the *Scan for Models* button allows you to paste in
the path to a folder somewhere on your machine. It will be scanned for
importable models and prompt you to add the ones of your choice.
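
The same import flow can be driven programmatically through the REST route added later in this compare (`import_model` on the models router). The sketch below is hedged: the server address (`http://localhost:9090`), the `/api/v1` mount point, and the `/import` route path are assumptions not shown in this diff; the body fields (`location`, `priority`, `prediction_type`) and the returned `job_id`/`status` fields come from the router code further down.

```python
# Hedged sketch: base URL, /api/v1 mount point and /import path are assumptions;
# request/response fields follow the import_model handler shown later in this diff.
import requests

BASE = "http://localhost:9090/api/v1/models"  # assumed server address and mount point

resp = requests.post(
    f"{BASE}/import",
    json={
        # a HuggingFace diffusers model living in the "vae" subfolder
        "location": "stabilityai/stable-diffusion-xl-base-1.0:vae",
        "priority": 10,  # lower values run before higher ones
    },
    timeout=30,
)
resp.raise_for_status()
job = resp.json()
print(f"queued install job {job['job_id']} ({job['status']}) from {job['source']}")
```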

View File

@ -19,6 +19,7 @@ from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
from ..services.default_graphs import create_system_graphs
from ..services.download_manager import DownloadQueueService
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_file_storage import DiskImageFileStorage
from ..services.invocation_queue import MemoryInvocationQueue
@ -26,7 +27,9 @@ from ..services.invocation_services import InvocationServices
from ..services.invocation_stats import InvocationStatsService
from ..services.invoker import Invoker
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..services.model_manager_service import ModelManagerService
from ..services.model_install_service import ModelInstallService
from ..services.model_loader_service import ModelLoadService
from ..services.model_record_service import ModelRecordServiceBase
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.thread import lock
@ -127,8 +130,12 @@ class ApiDependencies:
)
)
download_queue = DownloadQueueService(event_bus=events)
model_record_store = ModelRecordServiceBase.open(config, conn=db_conn, lock=lock)
model_loader = ModelLoadService(config, model_record_store)
model_installer = ModelInstallService(config, queue=download_queue, store=model_record_store, event_bus=events)
services = InvocationServices(
model_manager=ModelManagerService(config, logger),
events=events,
latents=latents,
images=images,
@ -141,6 +148,10 @@ class ApiDependencies:
configuration=config,
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
download_queue=download_queue,
model_record_store=model_record_store,
model_loader=model_loader,
model_installer=model_installer,
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
session_processor=DefaultSessionProcessor(),
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),

View File

@ -2,35 +2,60 @@
import pathlib
from typing import List, Literal, Optional, Union
from enum import Enum
from typing import Any, List, Literal, Optional, Union
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, parse_obj_as
from starlette.exceptions import HTTPException
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.download_manager import DownloadJobRemoteSource, DownloadJobStatus, UnknownJobIDException
from invokeai.app.services.model_convert import MergeInterpolationMethod, ModelConvert
from invokeai.app.services.model_install_service import ModelInstallJob
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import MergeInterpolationMethod
from invokeai.backend.model_management.models import (
from invokeai.backend.model_manager import (
OPENAPI_MODEL_CONFIGS,
DuplicateModelException,
InvalidModelException,
ModelNotFoundException,
ModelConfigBase,
ModelSearch,
SchedulerPredictionType,
UnknownModelException,
)
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
# NOTE: The generic configuration classes defined in invokeai.backend.model_manager.config
# such as "MainCheckpointConfig" are repackaged by code originally written by Stalker
# into base-specific classes such as `abc.StableDiffusion1ModelCheckpointConfig`
# This is the reason for the calls to dict() followed by pydantic.parse_obj_as()
# There are still numerous mypy errors here because it does not seem to like this
# way of dynamically generating the typing hints below.
InvokeAIModelConfig: Any = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
models: List[InvokeAIModelConfig]
class ModelDownloadStatus(BaseModel):
"""Return information about a background installation job."""
job_id: int
source: str
priority: int
bytes: int
total_bytes: int
status: DownloadJobStatus
class JobControlOperation(str, Enum):
START = "Start"
PAUSE = "Pause"
CANCEL = "Cancel"
@models_router.get(
@ -42,19 +67,22 @@ async def list_models(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Gets a list of models"""
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_record_store
if base_models and len(base_models) > 0:
models_raw = list()
for base_model in base_models:
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
models_raw.extend(
[x.dict() for x in record_store.search_by_name(base_model=base_model, model_type=model_type)]
)
else:
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
models_raw = [x.dict() for x in record_store.search_by_name(model_type=model_type)]
models = parse_obj_as(ModelsList, {"models": models_raw})
return models
@models_router.patch(
"/{base_model}/{model_type}/{model_name}",
"/i/{key}",
operation_id="update_model",
responses={
200: {"description": "The model was updated successfully"},
@ -63,69 +91,36 @@ async def list_models(
409: {"description": "There is already a model corresponding to the new name"},
},
status_code=200,
response_model=UpdateModelResponse,
response_model=InvokeAIModelConfig,
)
async def update_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
key: str = Path(description="Unique key of model"),
info: InvokeAIModelConfig = Body(description="Model configuration"),
) -> InvokeAIModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
info_dict = info.dict()
record_store = ApiDependencies.invoker.services.model_record_store
model_install = ApiDependencies.invoker.services.model_installer
try:
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
# rename operation requested
if info.model_name != model_name or info.base_model != base_model:
ApiDependencies.invoker.services.model_manager.rename_model(
base_model=base_model,
model_type=model_type,
model_name=model_name,
new_name=info.model_name,
new_base=info.base_model,
)
logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
# update information to support an update of attributes
model_name = info.model_name
base_model = info.base_model
new_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
if new_info.get("path") != previous_info.get(
"path"
): # model manager moved model path during rename - don't overwrite it
info.path = new_info.get("path")
# replace empty string values with None/null to avoid phenomenon of vae: ''
info_dict = info.dict()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
model_response = parse_obj_as(UpdateModelResponse, model_raw)
except ModelNotFoundException as e:
new_config = record_store.update_model(key, config=info_dict)
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
try:
# In the event that the model's name, type or base has changed, and the model itself
# resides in the invokeai root models directory, then the next statement will move
# the model file into its new canonical location.
new_config = model_install.sync_model_path(new_config.key)
model_response = parse_obj_as(InvokeAIModelConfig, new_config.dict())
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
raise HTTPException(status_code=409, detail=str(e))
return model_response
@ -141,7 +136,7 @@ async def update_model(
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse,
response_model=ModelDownloadStatus,
)
async def import_model(
location: str = Body(description="A model path, repo_id or URL to import"),
@ -149,30 +144,47 @@ async def import_model(
description="Prediction type for SDv2 checkpoints and rare SDv1 checkpoints",
default=None,
),
) -> ImportModelResponse:
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
priority: Optional[int] = Body(
description="Which import jobs run first. Lower values run before higher ones.",
default=10,
),
) -> ModelDownloadStatus:
"""
Add a model using its local path, repo_id, or remote URL.
items_to_import = {location}
prediction_types = {x.value: x for x in SchedulerPredictionType}
Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has a `job_id` property
that can be used to control the download job.
The priority controls which import jobs run first. Lower values run before
higher ones.
The prediction_type applies to SDv2 models only and can be one of
"v_prediction", "epsilon", or "sample". Default if not provided is
"v_prediction".
Listen on the event bus for a series of `model_event` events with an `id`
matching the returned job id to get the progress, completion status, errors,
and information on the model that was installed.
"""
logger = ApiDependencies.invoker.services.logger
try:
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
installer = ApiDependencies.invoker.services.model_installer
result = installer.install_model(
location,
probe_override={"prediction_type": SchedulerPredictionType(prediction_type) if prediction_type else None},
priority=priority,
)
info = installed_models.get(location)
if not info:
logger.error("Import failed")
raise HTTPException(status_code=415)
logger.info(f"Successfully imported {location}, got {info}")
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.name, base_model=info.base_model, model_type=info.model_type
return ModelDownloadStatus(
job_id=result.id,
source=result.source,
priority=result.priority,
bytes=result.bytes,
total_bytes=result.total_bytes,
status=result.status,
)
return parse_obj_as(ImportModelResponse, model_raw)
except ModelNotFoundException as e:
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except InvalidModelException as e:
@ -189,29 +201,40 @@ async def import_model(
responses={
201: {"description": "The model added successfully"},
404: {"description": "The model could not be found"},
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
response_model=ImportModelResponse,
response_model=InvokeAIModelConfig,
)
async def add_model(
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> ImportModelResponse:
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
info: InvokeAIModelConfig = Body(description="Model configuration"),
) -> InvokeAIModelConfig:
"""
Add a model using the configuration information appropriate for its type. Only local models can be added by path.
This call will block until the model is installed.
"""
logger = ApiDependencies.invoker.services.logger
path = info.path
installer = ApiDependencies.invoker.services.model_installer
record_store = ApiDependencies.invoker.services.model_record_store
try:
ApiDependencies.invoker.services.model_manager.add_model(
info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
)
logger.info(f"Successfully added {info.model_name}")
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except ModelNotFoundException as e:
key = installer.install_path(path)
logger.info(f"Created model {key} for {path}")
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
# update with the provided info
try:
info_dict = info.dict()
new_config = record_store.update_model(key, new_config=info_dict)
return parse_obj_as(InvokeAIModelConfig, new_config.dict())
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
@ -220,33 +243,34 @@ async def add_model(
@models_router.delete(
"/{base_model}/{model_type}/{model_name}",
"/i/{key}",
operation_id="del_model",
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
status_code=204,
response_model=None,
)
async def delete_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
key: str = Path(description="Unique key of model to remove from model registry."),
delete_files: Optional[bool] = Query(description="Delete underlying files and directories as well.", default=False),
) -> Response:
"""Delete Model"""
logger = ApiDependencies.invoker.services.logger
try:
ApiDependencies.invoker.services.model_manager.del_model(
model_name, base_model=base_model, model_type=model_type
)
logger.info(f"Deleted model: {model_name}")
installer = ApiDependencies.invoker.services.model_installer
if delete_files:
installer.delete(key)
else:
installer.unregister(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
except ModelNotFoundException as e:
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@models_router.put(
"/convert/{base_model}/{model_type}/{model_name}",
"/convert/{key}",
operation_id="convert_model",
responses={
200: {"description": "Model converted successfully"},
@ -254,33 +278,26 @@ async def delete_model(
404: {"description": "Model not found"},
},
status_code=200,
response_model=ConvertModelResponse,
response_model=InvokeAIModelConfig,
)
async def convert_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
key: str = Path(description="Unique key of model to convert from checkpoint/safetensors to diffusers format."),
convert_dest_directory: Optional[str] = Query(
default=None, description="Save the converted model to the designated directory"
),
) -> ConvertModelResponse:
) -> InvokeAIModelConfig:
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Converting model: {model_name}")
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
ApiDependencies.invoker.services.model_manager.convert_model(
model_name,
base_model=base_model,
model_type=model_type,
convert_dest_directory=dest,
converter = ModelConvert(
loader=ApiDependencies.invoker.services.model_loader,
installer=ApiDependencies.invoker.services.model_installer,
store=ApiDependencies.invoker.services.model_record_store,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name, base_model=base_model, model_type=model_type
)
response = parse_obj_as(ConvertModelResponse, model_raw)
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
model_config = converter.convert_model(key, dest_directory=dest)
response = parse_obj_as(InvokeAIModelConfig, model_config.dict())
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=f"Model '{key}' not found: {str(e)}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
@ -299,11 +316,12 @@ async def convert_model(
async def search_for_models(
search_path: pathlib.Path = Query(description="Directory path to search for models"),
) -> List[pathlib.Path]:
"""Search for all models in a server-local path."""
if not search_path.is_dir():
raise HTTPException(
status_code=404, detail=f"The search path '{search_path}' does not exist or is not a directory"
)
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
return ModelSearch().search(search_path)
@models_router.get(
@ -317,7 +335,10 @@ async def search_for_models(
)
async def list_ckpt_configs() -> List[pathlib.Path]:
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
config = ApiDependencies.invoker.services.configuration
conf_path = config.legacy_conf_path
root_path = config.root_path
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
@models_router.post(
@ -330,27 +351,32 @@ async def list_ckpt_configs() -> List[pathlib.Path]:
response_model=bool,
)
async def sync_to_config() -> bool:
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
in-memory data structures with disk data structures."""
ApiDependencies.invoker.services.model_manager.sync_to_config()
"""
Synchronize model in-memory data structures with disk.
Call after making changes to models.yaml, autoimport directories
or models directory.
"""
installer = ApiDependencies.invoker.services.model_installer
installer.sync_to_config()
return True
@models_router.put(
"/merge/{base_model}",
"/merge",
operation_id="merge_models",
responses={
200: {"description": "Model converted successfully"},
400: {"description": "Incompatible models"},
404: {"description": "One or more models not found"},
409: {"description": "An identical merged model is already installed"},
},
status_code=200,
response_model=MergeModelResponse,
response_model=InvokeAIModelConfig,
)
async def merge_models(
base_model: BaseModelType = Path(description="Base model"),
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description="Name of destination model"),
keys: List[str] = Body(description="model name", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
force: Optional[bool] = Body(
@ -360,29 +386,147 @@ async def merge_models(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None,
),
) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model"""
) -> InvokeAIModelConfig:
"""Merge the indicated diffusers model."""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
result = ApiDependencies.invoker.services.model_manager.merge_models(
model_names,
base_model,
merged_model_name=merged_model_name or "+".join(model_names),
converter = ModelConvert(
loader=ApiDependencies.invoker.services.model_loader,
installer=ApiDependencies.invoker.services.model_installer,
store=ApiDependencies.invoker.services.model_record_store,
)
result: ModelConfigBase = converter.merge_models(
model_keys=keys,
merged_model_name=merged_model_name,
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
result.name,
base_model=base_model,
model_type=ModelType.Main,
)
response = parse_obj_as(ConvertModelResponse, model_raw)
except ModelNotFoundException:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
response = parse_obj_as(InvokeAIModelConfig, result.dict())
except DuplicateModelException as e:
raise HTTPException(status_code=409, detail=str(e))
except UnknownModelException:
raise HTTPException(status_code=404, detail=f"One or more of the models '{keys}' not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
@models_router.get(
"/jobs",
operation_id="list_install_jobs",
responses={
200: {"description": "The control job was updated successfully"},
400: {"description": "Bad request"},
},
status_code=200,
response_model=List[ModelDownloadStatus],
)
async def list_install_jobs() -> List[ModelDownloadStatus]:
"""List active and pending model installation jobs."""
job_mgr = ApiDependencies.invoker.services.download_queue
jobs = job_mgr.list_jobs()
return [
ModelDownloadStatus(
job_id=x.id,
source=x.source,
priority=x.priority,
bytes=x.bytes,
total_bytes=x.total_bytes,
status=x.status,
)
for x in jobs
if isinstance(x, ModelInstallJob)
]
@models_router.patch(
"/jobs/control/{operation}/{job_id}",
operation_id="control_download_jobs",
responses={
200: {"description": "The control job was updated successfully"},
400: {"description": "Bad request"},
404: {"description": "The job could not be found"},
},
status_code=200,
response_model=ModelDownloadStatus,
)
async def control_download_jobs(
job_id: int = Path(description="Download/install job_id for start, pause and cancel operations"),
operation: JobControlOperation = Path(description="The operation to perform on the job."),
priority_delta: Optional[int] = Body(
description="Change in job priority for priority operations only. Negative numbers increase priority.",
default=None,
),
) -> ModelDownloadStatus:
"""Start, pause, cancel, or change the run priority of a running model install job."""
logger = ApiDependencies.invoker.services.logger
job_mgr = ApiDependencies.invoker.services.download_queue
try:
job = job_mgr.id_to_job(job_id)
if operation == JobControlOperation.START:
job_mgr.start_job(job_id)
elif operation == JobControlOperation.PAUSE:
job_mgr.pause_job(job_id)
elif operation == JobControlOperation.CANCEL:
job_mgr.cancel_job(job_id)
else:
raise ValueError(f"unknown operation {operation}")
bytes = 0
total_bytes = 0
if isinstance(job, DownloadJobRemoteSource):
bytes = job.bytes
total_bytes = job.total_bytes
return ModelDownloadStatus(
job_id=job_id,
source=job.source,
priority=job.priority,
status=job.status,
bytes=bytes,
total_bytes=total_bytes,
)
except UnknownJobIDException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.patch(
"/jobs/cancel_all",
operation_id="cancel_all_download_jobs",
responses={
204: {"description": "All jobs cancelled successfully"},
400: {"description": "Bad request"},
},
)
async def cancel_all_download_jobs():
"""Cancel all model installation jobs."""
logger = ApiDependencies.invoker.services.logger
job_mgr = ApiDependencies.invoker.services.download_queue
logger.info("Cancelling all download jobs.")
job_mgr.cancel_all_jobs()
return Response(status_code=204)
@models_router.patch(
"/jobs/prune",
operation_id="prune_jobs",
responses={
204: {"description": "All completed jobs have been pruned"},
400: {"description": "Bad request"},
},
)
async def prune_jobs():
"""Prune all completed and errored jobs."""
mgr = ApiDependencies.invoker.services.download_queue
mgr.prune_jobs()
return Response(status_code=204)
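
A hedged client-side sketch of the new job routes defined above (`list_install_jobs`, `control_download_jobs`, `prune_jobs`). The base URL and `/api/v1` mount point are the same assumptions as earlier, and the job id used for the cancel call is a placeholder; the route paths, the `Start`/`Pause`/`Cancel` operation names, and the status fields are taken from the handlers in this diff.

```python
# Hedged sketch: base URL and mount point are assumptions; job id 123 is a
# placeholder. Routes and fields follow the handlers defined in this diff.
import requests

BASE = "http://localhost:9090/api/v1/models"  # assumed server address and mount point

# snapshot of active, pending and completed install jobs
for job in requests.get(f"{BASE}/jobs", timeout=30).json():
    pct = 100 * job["bytes"] / job["total_bytes"] if job["total_bytes"] else 0
    print(f"job {job['job_id']}: {job['status']} {pct:.0f}% of {job['source']}")

# cancel a single job by id (placeholder id), then clear finished/errored records
requests.patch(f"{BASE}/jobs/control/Cancel/123", timeout=30)
requests.patch(f"{BASE}/jobs/prune", timeout=30)
```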

View File

@ -151,7 +151,7 @@ def custom_openapi():
invoker_schema["output"] = outputs_ref
invoker_schema["class"] = "invocation"
from invokeai.backend.model_management.models import get_model_config_enums
from invokeai.backend.model_manager.models import get_model_config_enums
for model_config_format_enum in set(get_model_config_enums()):
name = model_config_format_enum.__qualname__
@ -201,6 +201,10 @@ app.mount("/", StaticFiles(directory=Path(web_dir.__path__[0], "dist"), html=Tru
def invoke_api():
if app_config.version:
print(f"InvokeAI version {__version__}")
return
def find_port(port: int):
"""Find a port not in use starting at given port"""
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
@ -252,7 +256,4 @@ def invoke_api():
if __name__ == "__main__":
if app_config.version:
print(f"InvokeAI version {__version__}")
else:
invoke_api()
invoke_api()

View File

@ -10,10 +10,11 @@ from pathlib import Path
from typing import Dict, List, Literal, get_args, get_origin, get_type_hints
import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager import ModelType
from ...backend import ModelManager
from ..invocations.baseinvocation import BaseInvocation
from ..services.invocation_services import InvocationServices
from ..services.model_record_service import ModelRecordServiceBase
from .commands import BaseCommand
# singleton object, class variable
@ -21,11 +22,11 @@ completer = None
class Completer(object):
def __init__(self, model_manager: ModelManager):
def __init__(self, model_record_store: ModelRecordServiceBase):
self.commands = self.get_commands()
self.matches = None
self.linebuffer = None
self.manager = model_manager
self.store = model_record_store
return
def complete(self, text, state):
@ -127,7 +128,7 @@ class Completer(object):
if get_origin(typehint) == Literal:
return get_args(typehint)
if parameter == "model":
return self.manager.model_names()
return [x.name for x in self.store.model_info_by_name(model_type=ModelType.Main)]
def _pre_input_hook(self):
if self.linebuffer:
@ -142,7 +143,7 @@ def set_autocompleter(services: InvocationServices) -> Completer:
if completer:
return completer
completer = Completer(services.model_manager)
completer = Completer(services.model_record_store)
readline.set_completer(completer.complete)
try:

View File

@ -30,6 +30,8 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -38,6 +40,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id
from .services.download_manager import DownloadQueueService
from .services.events import EventServiceBase
from .services.graph import (
Edge,
@ -52,9 +55,12 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from .services.model_manager_service import ModelManagerService
from .services.model_install_service import ModelInstallService
from .services.model_loader_service import ModelLoadService
from .services.model_record_service import ModelRecordServiceBase
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage
from .services.thread import lock
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
@ -228,7 +234,12 @@ def invoke_all(context: CliContext):
def invoke_cli():
if config.version:
print(f"InvokeAI version {__version__}")
return
logger.info(f"InvokeAI version {__version__}")
# get the optional list of invocations to execute on the command line
parser = config.get_parser()
parser.add_argument("commands", nargs="*")
@ -239,8 +250,6 @@ def invoke_cli():
if infile := config.from_file:
sys.stdin = open(infile, "r")
model_manager = ModelManagerService(config, logger)
events = EventServiceBase()
output_folder = config.output_path
@ -254,15 +263,22 @@ def invoke_cli():
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
logger.info(f'InvokeAI database location is "{db_location}"')
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
download_queue = DownloadQueueService(event_bus=events)
model_record_store = ModelRecordServiceBase.open(config, conn=db_conn, lock=None)
model_loader = ModelLoadService(config, model_record_store)
model_installer = ModelInstallService(config, queue=download_queue, store=model_record_store, event_bus=events)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
conn=db_conn, table_name="graph_executions", lock=lock
)
urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock)
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock)
boards = BoardService(
services=BoardServiceDependencies(
@ -297,20 +313,25 @@ def invoke_cli():
)
services = InvocationServices(
model_manager=model_manager,
events=events,
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
images=images,
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs", lock=lock),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
download_queue=download_queue,
model_record_store=model_record_store,
model_loader=model_loader,
model_installer=model_installer,
configuration=config,
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
session_processor=DefaultSessionProcessor(),
)
system_graphs = create_system_graphs(services.graph_library)
@ -478,7 +499,4 @@ def invoke_cli():
if __name__ == "__main__":
if config.version:
print(f"InvokeAI version {__version__}")
else:
invoke_cli()
invoke_cli()

View File

@ -13,8 +13,8 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
SDXLConditioningInfo,
)
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException, ModelType
from ...backend.model_manager import ModelType, UnknownModelException
from ...backend.model_manager.lora import ModelPatcher
from ...backend.util.devices import torch_dtype
from .baseinvocation import (
BaseInvocation,
@ -60,23 +60,23 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.services.model_manager.get_model(
tokenizer_info = context.services.model_loader.get_model(
**self.clip.tokenizer.dict(),
context=context,
)
text_encoder_info = context.services.model_manager.get_model(
text_encoder_info = context.services.model_loader.get_model(
**self.clip.text_encoder.dict(),
context=context,
)
def _lora_loader():
for lora in self.clip.loras:
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
lora_info = context.services.model_loader.get_model(**lora.dict(exclude={"weight"}), context=context)
yield (lora_info.context.model, lora.weight)
del lora_info
return
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
# loras = [(context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
@ -85,7 +85,7 @@ class CompelInvocation(BaseInvocation):
ti_list.append(
(
name,
context.services.model_manager.get_model(
context.services.model_loader.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
@ -93,7 +93,7 @@ class CompelInvocation(BaseInvocation):
).context.model,
)
)
except ModelNotFoundException:
except UnknownModelException:
# print(e)
# import traceback
# print(traceback.format_exc())
@ -159,11 +159,11 @@ class SDXLPromptInvocationBase:
lora_prefix: str,
zero_on_empty: bool,
):
tokenizer_info = context.services.model_manager.get_model(
tokenizer_info = context.services.model_loader.get_model(
**clip_field.tokenizer.dict(),
context=context,
)
text_encoder_info = context.services.model_manager.get_model(
text_encoder_info = context.services.model_loader.get_model(
**clip_field.text_encoder.dict(),
context=context,
)
@ -186,12 +186,12 @@ class SDXLPromptInvocationBase:
def _lora_loader():
for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
lora_info = context.services.model_loader.get_model(**lora.dict(exclude={"weight"}), context=context)
yield (lora_info.context.model, lora.weight)
del lora_info
return
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
# loras = [(context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
@ -200,7 +200,7 @@ class SDXLPromptInvocationBase:
ti_list.append(
(
name,
context.services.model_manager.get_model(
context.services.model_loader.get_model(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
@ -208,7 +208,7 @@ class SDXLPromptInvocationBase:
).context.model,
)
)
except ModelNotFoundException:
except UnknownModelException:
# print(e)
# import traceback
# print(traceback.format_exc())

View File

@ -28,7 +28,7 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from ...backend.model_management import BaseModelType
from ...backend.model_manager import BaseModelType
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,

View File

@ -17,8 +17,8 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
from invokeai.backend.model_manager import BaseModelType, ModelType
from invokeai.backend.model_manager.models.ip_adapter import get_ip_adapter_image_encoder_model_id
class IPAdapterModelField(BaseModel):

View File

@ -37,12 +37,11 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.model_manager import BaseModelType, ModelType, SilenceWarnings
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import BaseModelType
from ...backend.model_management.seamless import set_seamless
from ...backend.model_manager.lora import ModelPatcher
from ...backend.model_manager.seamless import set_seamless
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
@ -133,7 +132,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
if image is not None:
vae_info = context.services.model_manager.get_model(
vae_info = context.services.model_loader.get_model(
**self.vae.vae.dict(),
context=context,
)
@ -166,7 +165,7 @@ def get_scheduler(
seed: int,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.services.model_manager.get_model(
orig_scheduler_info = context.services.model_loader.get_model(
**scheduler_info.dict(),
context=context,
)
@ -362,7 +361,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
context.services.model_loader.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
@ -430,7 +429,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_manager.get_model(
context.services.model_loader.get_model(
model_name=single_ip_adapter.ip_adapter_model.model_name,
model_type=ModelType.IPAdapter,
base_model=single_ip_adapter.ip_adapter_model.base_model,
@ -438,7 +437,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
)
image_encoder_model_info = context.services.model_manager.get_model(
image_encoder_model_info = context.services.model_loader.get_model(
model_name=single_ip_adapter.image_encoder_model.model_name,
model_type=ModelType.CLIPVision,
base_model=single_ip_adapter.image_encoder_model.base_model,
@ -488,7 +487,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_info = context.services.model_manager.get_model(
t2i_adapter_model_info = context.services.model_loader.get_model(
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
model_type=ModelType.T2IAdapter,
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
@ -640,7 +639,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
lora_info = context.services.model_loader.get_model(
**lora.dict(exclude={"weight"}),
context=context,
)
@ -648,7 +647,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
del lora_info
return
unet_info = context.services.model_manager.get_model(
unet_info = context.services.model_loader.get_model(
**self.unet.unet.dict(),
context=context,
)
@ -753,7 +752,7 @@ class LatentsToImageInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
vae_info = context.services.model_manager.get_model(
vae_info = context.services.model_loader.get_model(
**self.vae.vae.dict(),
context=context,
)
@ -978,7 +977,7 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.services.images.get_pil_image(self.image.image_name)
vae_info = context.services.model_manager.get_model(
vae_info = context.services.model_loader.get_model(
**self.vae.vae.dict(),
context=context,
)
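A minimal sketch of the loader-as-context-manager pattern used throughout this file; the field names follow the hunks above, while the decode call itself is illustrative only.

def decode_with_vae(self, context, latents):
    vae_info = context.services.model_loader.get_model(
        **self.vae.vae.dict(),
        context=context,
    )
    with vae_info as vae:            # entering the context loads the weights into the cache
        return vae.decode(latents)   # illustrative; the real node also handles tiling, fp32, etc.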

View File

@ -3,7 +3,8 @@ from typing import List, Optional
from pydantic import BaseModel, Field
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -19,9 +20,7 @@ from .baseinvocation import (
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Info to load submodel")
key: str = Field(description="Unique ID for model")
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
@ -61,16 +60,13 @@ class ModelLoaderOutput(BaseInvocationOutput):
class MainModelField(BaseModel):
"""Main model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")
key: str = Field(description="Unique ID of the model")
class LoRAModelField(BaseModel):
"""LoRA model field"""
model_name: str = Field(description="Name of the LoRA model")
base_model: BaseModelType = Field(description="Base model")
key: str = Field(description="Unique ID for model")
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0")
@ -81,20 +77,15 @@ class MainModelLoaderInvocation(BaseInvocation):
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
"""Load a main model, outputting its submodels."""
key = self.model.key
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
if not context.services.model_record_store.model_exists(key):
raise Exception(f"Unknown model {key}")
"""
if not context.services.model_manager.model_exists(
if not context.services.model_record_store.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
@ -103,7 +94,7 @@ class MainModelLoaderInvocation(BaseInvocation):
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
if not context.services.model_record_store.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
@ -112,7 +103,7 @@ class MainModelLoaderInvocation(BaseInvocation):
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
if not context.services.model_record_store.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
@ -125,30 +116,22 @@ class MainModelLoaderInvocation(BaseInvocation):
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.TextEncoder,
),
loras=[],
@ -156,9 +139,7 @@ class MainModelLoaderInvocation(BaseInvocation):
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
submodel=SubModelType.Vae,
),
),
@ -167,7 +148,7 @@ class MainModelLoaderInvocation(BaseInvocation):
@invocation_output("lora_loader_output")
class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
"""Model loader output."""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@ -187,24 +168,20 @@ class LoraLoaderInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
"""Load a LoRA model."""
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
key = self.lora.key
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unkown lora name: {lora_name}!")
if not context.services.model_record_store.model_exists(key):
raise Exception(f"Unknown lora: {key}!")
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.unet is not None and any(lora.key == key for lora in self.unet.loras):
raise Exception(f'Lora "{key}" already applied to unet')
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip')
if self.clip is not None and any(lora.key == key for lora in self.clip.loras):
raise Exception(f'Lora "{key}" already applied to clip')
output = LoraLoaderOutput()
@ -212,9 +189,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=key,
submodel=None,
weight=self.weight,
)
@ -224,9 +199,7 @@ class LoraLoaderInvocation(BaseInvocation):
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=key,
submodel=None,
weight=self.weight,
)
@ -237,7 +210,7 @@ class LoraLoaderInvocation(BaseInvocation):
@invocation_output("sdxl_lora_loader_output")
class SDXLLoraLoaderOutput(BaseInvocationOutput):
"""SDXL LoRA Loader Output"""
"""SDXL LoRA Loader Output."""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
@ -261,27 +234,22 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
"""Load an SDXL LoRA."""
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
key = self.lora.key
if not context.services.model_record_store.model_exists(key):
raise Exception(f"Unknown lora name: {key}!")
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unknown lora name: {lora_name}!")
if self.unet is not None and any(lora.key == key for lora in self.unet.loras):
raise Exception(f'Lora "{key}" already applied to unet')
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.clip is not None and any(lora.key == key for lora in self.clip.loras):
raise Exception(f'Lora "{key}" already applied to clip')
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip')
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip2')
if self.clip2 is not None and any(lora.key == key for lora in self.clip2.loras):
raise Exception(f'Lora "{key}" already applied to clip2')
output = SDXLLoraLoaderOutput()
@ -289,9 +257,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=key,
submodel=None,
weight=self.weight,
)
@ -301,9 +267,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=key,
submodel=None,
weight=self.weight,
)
@ -313,9 +277,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
output.clip2 = copy.deepcopy(self.clip2)
output.clip2.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
key=key,
submodel=None,
weight=self.weight,
)
@ -325,10 +287,9 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
class VAEModelField(BaseModel):
"""Vae model field"""
"""Vae model field."""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
key: str = Field(description="Unique ID for VAE model")
@invocation_output("vae_loader_output")
@ -340,29 +301,22 @@ class VaeLoaderOutput(BaseInvocationOutput):
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
"""Loads a VAE model, outputting a VaeLoaderOutput."""
vae_model: VAEModelField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, ui_type=UIType.VaeModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
base_model = self.vae_model.base_model
model_name = self.vae_model.model_name
model_type = ModelType.Vae
"""Load a VAE model."""
key = self.vae_model.key
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=model_name,
model_type=model_type,
):
raise Exception(f"Unkown vae name: {model_name}!")
if not context.services.model_record_store.model_exists(key):
raise Exception(f"Unkown vae name: {key}!")
return VaeLoaderOutput(
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
key=key,
)
)
)
@ -370,7 +324,7 @@ class VaeLoaderInvocation(BaseInvocation):
@invocation_output("seamless_output")
class SeamlessModeOutput(BaseInvocationOutput):
"""Modified Seamless Model output"""
"""Modified Seamless Model output."""
unet: Optional[UNetField] = OutputField(description=FieldDescriptions.unet, title="UNet")
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
@ -390,6 +344,7 @@ class SeamlessModeInvocation(BaseInvocation):
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
"""Apply seamless transformation."""
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
unet = copy.deepcopy(self.unet)
vae = copy.deepcopy(self.vae)
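Put together, the key-based pattern this file converges on looks roughly like the sketch below: a node keeps only the record key, validates it against the record store, and fans it out into per-submodel ModelInfo entries. UNetField and ModelInfo are the classes shown in the hunks above; the helper function itself is hypothetical.

from invokeai.backend.model_manager import SubModelType

def build_unet_field(context, key: str) -> "UNetField":
    if not context.services.model_record_store.model_exists(key):
        raise Exception(f"Unknown model {key}")
    return UNetField(
        unet=ModelInfo(key=key, submodel=SubModelType.UNet),
        scheduler=ModelInfo(key=key, submodel=SubModelType.Scheduler),
        loras=[],
    )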

View File

@ -17,7 +17,7 @@ from invokeai.app.invocations.primitives import ConditioningField, ConditioningO
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend import BaseModelType, ModelType, SubModelType
from ...backend.model_management import ONNXModelPatcher
from ...backend.model_manager.lora import ONNXModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util import choose_torch_device
from ..models.image import ImageCategory, ResourceOrigin
@ -62,15 +62,15 @@ class ONNXPromptInvocation(BaseInvocation):
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.services.model_manager.get_model(
tokenizer_info = context.services.model_loader.get_model(
**self.clip.tokenizer.dict(),
)
text_encoder_info = context.services.model_manager.get_model(
text_encoder_info = context.services.model_loader.get_model(
**self.clip.text_encoder.dict(),
)
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
loras = [
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
(context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
for lora in self.clip.loras
]
@ -81,7 +81,7 @@ class ONNXPromptInvocation(BaseInvocation):
ti_list.append(
(
name,
context.services.model_manager.get_model(
context.services.model_loader.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
@ -254,12 +254,12 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
eta=0.0,
)
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
unet_info = context.services.model_loader.get_model(**self.unet.unet.dict())
with unet_info as unet: # , ExitStack() as stack:
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
# loras = [(stack.enter_context(context.services.model_loader.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
loras = [
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
(context.services.model_loader.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
for lora in self.unet.loras
]
@ -345,7 +345,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
if self.vae.vae.submodel != SubModelType.VaeDecoder:
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
vae_info = context.services.model_manager.get_model(
vae_info = context.services.model_loader.get_model(
**self.vae.vae.dict(),
)
@ -418,7 +418,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
model_type = ModelType.ONNX
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
if not context.services.model_record_store.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
@ -426,7 +426,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
"""
if not context.services.model_manager.model_exists(
if not context.services.model_record_store.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
@ -435,7 +435,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
if not context.services.model_record_store.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
@ -444,7 +444,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
if not context.services.model_record_store.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,

View File

@ -1,4 +1,4 @@
from ...backend.model_management import ModelType, SubModelType
from ...backend.model_manager import ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -48,7 +48,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
model_type = ModelType.Main
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
if not context.services.model_record_store.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
@ -137,7 +137,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
model_type = ModelType.Main
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
if not context.services.model_record_store.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,

View File

@ -16,7 +16,7 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.primitives import ImageField
from invokeai.backend.model_management.models.base import BaseModelType
from invokeai.backend.model_manager import BaseModelType
class T2IAdapterModelField(BaseModel):

View File

@ -25,6 +25,7 @@ from pydantic import BaseSettings
class PagingArgumentParser(argparse.ArgumentParser):
"""
A custom ArgumentParser that uses pydoc to page its output.
It also supports reading defaults from an init file.
"""
@ -144,16 +145,6 @@ class InvokeAISettings(BaseSettings):
return [
"type",
"initconf",
"version",
"from_file",
"model",
"root",
"max_cache_size",
"max_vram_cache_size",
"always_use_cpu",
"free_gpu_mem",
"xformers_enabled",
"tiled_decode",
]
class Config:
@ -226,9 +217,7 @@ class InvokeAISettings(BaseSettings):
def int_or_float_or_str(value: str) -> Union[int, float, str]:
"""
Workaround for argparse type checking.
"""
"""Workaround for argparse type checking."""
try:
return int(value)
except Exception as e: # noqa F841
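For illustration, int_or_float_or_str is the kind of helper that gets wired into an ArgumentParser as a type callable so a setting can be given either as a number or as a keyword such as 'auto'. The flag name and default below are placeholders, and the float/str fallback is inferred from the function's name rather than shown in this hunk.

parser = PagingArgumentParser(description="example settings parser")
parser.add_argument(
    "--cache-size",               # placeholder flag
    type=int_or_float_or_str,     # "4" -> 4; per its name, falls back to float, then str
    default="auto",
)
args = parser.parse_args(["--cache-size", "7.5"])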

View File

@ -171,6 +171,7 @@ two configs are kept in separate sections of the config file:
from __future__ import annotations
import os
import sys
from pathlib import Path
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
@ -182,7 +183,9 @@ from .base import InvokeAISettings
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_MAX_VRAM = 0.5
DEFAULT_MAX_DISK_CACHE = 20 # gigs, enough for three sdxl models, or 6 sd-1 models
DEFAULT_RAM_CACHE = 7.5
DEFAULT_VRAM_CACHE = 0.25
class InvokeAIAppConfig(InvokeAISettings):
@ -217,11 +220,8 @@ class InvokeAIAppConfig(InvokeAISettings):
# PATHS
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
controlnet_dir : Path = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths')
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
autoimport_dir : Optional[Path] = Field(default=None, description='Path to a directory of models files to be imported on startup.', category='Paths')
model_config_db : Union[Path, Literal['auto'], None] = Field(default=None, description='Path to a sqlite .db file or .yaml file for storing model config records; "auto" will reuse the main sqlite db', category='Paths')
models_dir : Path = Field(default='models', description='Path to the models directory', category='Paths')
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
@ -241,8 +241,9 @@ class InvokeAIAppConfig(InvokeAISettings):
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
# CACHE
ram : Union[float, Literal["auto"]] = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching", category="Model Cache", )
vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage", category="Model Cache", )
disk : float = Field(default=DEFAULT_MAX_DISK_CACHE, ge=0, description="Maximum size (in GB) for the disk-based diffusers model conversion cache", category="Model Cache", )
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
# DEVICE
@ -254,7 +255,6 @@ class InvokeAIAppConfig(InvokeAISettings):
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", )
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", category="Generation", )
# QUEUE
@ -272,6 +272,10 @@ class InvokeAIAppConfig(InvokeAISettings):
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
controlnet_dir : Path = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths')
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
# fmt: on
@ -312,9 +316,7 @@ class InvokeAIAppConfig(InvokeAISettings):
@classmethod
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
"""
This returns a singleton InvokeAIAppConfig configuration object.
"""
"""This returns a singleton InvokeAIAppConfig configuration object."""
if (
cls.singleton_config is None
or type(cls.singleton_config) is not cls
@ -324,6 +326,29 @@ class InvokeAIAppConfig(InvokeAISettings):
cls.singleton_init = kwargs
return cls.singleton_config
@classmethod
def _excluded_from_yaml(cls) -> List[str]:
el = super()._excluded_from_yaml()
el.extend(
[
"version",
"from_file",
"model",
"root",
"max_cache_size",
"max_vram_cache_size",
"always_use_cpu",
"free_gpu_mem",
"xformers_enabled",
"tiled_decode",
"conf_path",
"lora_dir",
"embedding_dir",
"controlnet_dir",
]
)
return el
@property
def root_path(self) -> Path:
"""
@ -414,7 +439,11 @@ class InvokeAIAppConfig(InvokeAISettings):
return self.max_cache_size or self.ram
@property
def vram_cache_size(self) -> Union[Literal["auto"], float]:
def conversion_cache_size(self) -> float:
return self.disk
@property
def vram_cache_size(self) -> float:
return self.max_vram_cache_size or self.vram
@property
@ -440,9 +469,7 @@ class InvokeAIAppConfig(InvokeAISettings):
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
"""
Legacy function which returns InvokeAIAppConfig.get_config()
"""
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
return InvokeAIAppConfig.get_config(**kwargs)
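A short sketch of reading the reorganized cache settings from the singleton config; the field and property names are those shown in the hunks above, and the commented values assume the new defaults.

from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig.get_config()
config.ram                    # 7.5 GB model RAM cache by default
config.vram                   # 0.25 GB reserved VRAM by default
config.disk                   # 20 GB budget for converted diffusers models
config.vram_cache_size        # legacy max_vram_cache_size if set, otherwise vram
config.conversion_cache_size  # simply returns disk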

View File

@ -0,0 +1,205 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Model download service.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Optional, Union
from pydantic.networks import AnyHttpUrl
from invokeai.backend.model_manager.download import DownloadJobRemoteSource # noqa F401
from invokeai.backend.model_manager.download import ( # noqa F401
DownloadEventHandler,
DownloadJobBase,
DownloadJobPath,
DownloadJobStatus,
DownloadQueueBase,
ModelDownloadQueue,
ModelSourceMetadata,
UnknownJobIDException,
)
if TYPE_CHECKING:
from .events import EventServiceBase
class DownloadQueueServiceBase(ABC):
"""Multithreaded queue for downloading models via URL or repo_id."""
@abstractmethod
def create_download_job(
self,
source: Union[str, Path, AnyHttpUrl],
destdir: Path,
filename: Optional[Path] = None,
start: Optional[bool] = True,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
) -> DownloadJobBase:
"""
Create a download job.
:param source: Source of the download - URL, repo_id or local Path
:param destdir: Directory to download into.
:param filename: Optional name of file, if not provided
will use the content-disposition field to assign the name.
:param start: Immediately start job [True]
:param event_handlers: Optional list of callables, each receiving a DownloadJobBase and acting on it.
:returns: The DownloadJobBase object created for this task.
"""
pass
@abstractmethod
def submit_download_job(
self,
job: DownloadJobBase,
start: Optional[bool] = True,
):
"""
Submit a download job.
:param job: A DownloadJobBase
:param start: Immediately start job [True]
After execution, `job.id` will be set to a non-negative value.
"""
pass
@abstractmethod
def list_jobs(self) -> List[DownloadJobBase]:
"""
List active DownloadJobBases.
:returns List[DownloadJobBase]: List of download jobs whose state is not "completed."
"""
pass
@abstractmethod
def id_to_job(self, id: int) -> DownloadJobBase:
"""
Return the DownloadJobBase corresponding to the given job ID.
:param id: ID of the DownloadJobBase.
Exceptions:
* UnknownJobIDException
"""
pass
@abstractmethod
def start_all_jobs(self):
"""Enqueue all idle and paused jobs."""
pass
@abstractmethod
def pause_all_jobs(self):
"""Pause and dequeue all active jobs."""
pass
@abstractmethod
def cancel_all_jobs(self):
"""Cancel all active and enquedjobs."""
pass
@abstractmethod
def prune_jobs(self):
"""Prune completed and errored queue items from the job list."""
pass
@abstractmethod
def start_job(self, job: DownloadJobBase):
"""Start the job putting it into ENQUEUED state."""
pass
@abstractmethod
def pause_job(self, job: DownloadJobBase):
"""Pause the job, putting it into PAUSED state."""
pass
@abstractmethod
def cancel_job(self, job: DownloadJobBase):
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass
@abstractmethod
def join(self):
"""Wait until all jobs are off the queue."""
pass
class DownloadQueueService(DownloadQueueServiceBase):
"""Multithreaded queue for downloading models via URL or repo_id."""
_event_bus: Optional["EventServiceBase"] = None
_queue: DownloadQueueBase
def __init__(self, event_bus: Optional["EventServiceBase"] = None, **kwargs):
"""
Initialize new DownloadQueueService object.
:param event_bus: EventServiceBase object for reporting progress.
:param **kwargs: Any of the arguments taken by invokeai.backend.model_manager.download.DownloadQueue.
e.g. `max_parallel_dl`.
"""
self._event_bus = event_bus
self._queue = ModelDownloadQueue(**kwargs)
def create_download_job(
self,
source: Union[str, Path, AnyHttpUrl],
destdir: Path,
filename: Optional[Path] = None,
start: Optional[bool] = True,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
) -> DownloadJobBase: # noqa D102
event_handlers = event_handlers or []
if self._event_bus:
event_handlers = [*event_handlers, self._event_bus.emit_model_event]
return self._queue.create_download_job(
source=source,
destdir=destdir,
filename=filename,
start=start,
access_token=access_token,
event_handlers=event_handlers,
)
def submit_download_job(
self,
job: DownloadJobBase,
start: bool = True,
):
return self._queue.submit_download_job(job, start)
def list_jobs(self) -> List[DownloadJobBase]: # noqa D102
return self._queue.list_jobs()
def id_to_job(self, id: int) -> DownloadJobBase: # noqa D102
return self._queue.id_to_job(id)
def start_all_jobs(self): # noqa D102
return self._queue.start_all_jobs()
def pause_all_jobs(self): # noqa D102
return self._queue.pause_all_jobs()
def cancel_all_jobs(self): # noqa D102
return self._queue.cancel_all_jobs()
def prune_jobs(self): # noqa D102
return self._queue.prune_jobs()
def start_job(self, job: DownloadJobBase): # noqa D102
return self._queue.start_job(job)
def pause_job(self, job: DownloadJobBase): # noqa D102
return self._queue.pause_job(job)
def cancel_job(self, job: DownloadJobBase): # noqa D102
return self._queue.cancel_job(job)
def join(self): # noqa D102
return self._queue.join()
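A minimal sketch of driving the new queue service directly; the source URL and destination are placeholders, while the constructor and method names are those defined above.

from pathlib import Path
from invokeai.app.services.download_manager import DownloadQueueService

queue = DownloadQueueService()                        # optionally pass event_bus=...
queue.create_download_job(
    source="https://example.com/model.safetensors",   # placeholder URL
    destdir=Path("/tmp/models"),
    start=True,
)
queue.join()        # block until every job has left the queue
queue.prune_jobs()  # clear completed and errored entries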

View File

@ -3,7 +3,7 @@
from typing import Any, Optional
from invokeai.app.models.image import ProgressImage
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
from invokeai.app.services.model_record_service import BaseModelType, ModelType, SubModelType
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
@ -11,14 +11,17 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueStatus,
)
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.download import DownloadJobBase
from invokeai.backend.model_manager.loader import ModelInfo
from invokeai.backend.util.logging import InvokeAILogger
class EventServiceBase:
queue_event: str = "queue_event"
"""Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event."""
pass
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
@ -153,9 +156,7 @@ class EventServiceBase:
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_key: str,
submodel: SubModelType,
) -> None:
"""Emitted when a model is requested"""
@ -166,9 +167,7 @@ class EventServiceBase:
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_key=model_key,
submodel=submodel,
),
)
@ -179,9 +178,7 @@ class EventServiceBase:
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_key: str,
submodel: SubModelType,
model_info: ModelInfo,
) -> None:
@ -193,9 +190,7 @@ class EventServiceBase:
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_key=model_key,
submodel=submodel,
hash=model_info.hash,
location=str(model_info.location),
@ -312,3 +307,9 @@ class EventServiceBase:
event_name="queue_cleared",
payload=dict(queue_id=queue_id),
)
def emit_model_event(self, job: DownloadJobBase) -> None:
"""Emit event when the status of a download/install job changes."""
self.dispatch( # use dispatch() directly here because we are not a session event.
event_name="model_event", payload=dict(job=job)
)
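Because emit_model_event simply dispatches the job on the bus, any EventServiceBase can be handed to the download queue as its event bus. A toy bus that just prints events might look like the sketch below; the subclass is purely illustrative.

from invokeai.app.services.download_manager import DownloadQueueService
from invokeai.app.services.events import EventServiceBase

class PrintingEventService(EventServiceBase):
    """Toy bus: print every event instead of pushing it onto a session queue."""
    def dispatch(self, event_name, payload) -> None:
        print(event_name, payload)

queue = DownloadQueueService(event_bus=PrintingEventService())
# each download job now gets emit_model_event appended to its handler list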

View File

@ -9,6 +9,7 @@ if TYPE_CHECKING:
from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download_manager import DownloadQueueServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.images import ImageServiceABC
@ -18,7 +19,9 @@ if TYPE_CHECKING:
from invokeai.app.services.invoker import InvocationProcessorABC
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.model_install_service import ModelInstallServiceBase
from invokeai.app.services.model_loader_service import ModelLoadServiceBase
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
@ -35,8 +38,11 @@ class InvocationServices:
graph_library: "ItemStorageABC[LibraryGraph]"
images: "ImageServiceABC"
latents: "LatentsStorageBase"
download_queue: "DownloadQueueServiceBase"
model_record_store: "ModelRecordServiceBase"
model_loader: "ModelLoadServiceBase"
model_installer: "ModelInstallServiceBase"
logger: "Logger"
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
@ -55,7 +61,10 @@ class InvocationServices:
images: "ImageServiceABC",
latents: "LatentsStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase",
model_record_store: "ModelRecordServiceBase",
model_loader: "ModelLoadServiceBase",
model_installer: "ModelInstallServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
@ -72,7 +81,10 @@ class InvocationServices:
self.images = images
self.latents = latents
self.logger = logger
self.model_manager = model_manager
self.download_queue = download_queue
self.model_record_store = model_record_store
self.model_loader = model_loader
self.model_installer = model_installer
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue
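In node code the practical effect is that the old single model_manager handle splits into four; a sketch of the accesses an invocation would now make, using the attribute names declared above.

def invoke(self, context):
    records   = context.services.model_record_store   # configuration records (model_exists, get_model, ...)
    loader    = context.services.model_loader          # loads models into the RAM/VRAM cache
    installer = context.services.model_installer       # registers and installs new models
    downloads = context.services.download_queue        # raw download jobs
    ...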

View File

@ -38,12 +38,12 @@ import psutil
import torch
import invokeai.backend.util.logging as logger
from invokeai.backend.model_management.model_cache import CacheStats
from invokeai.backend.model_manager.cache import CacheStats
from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC
from .model_manager_service import ModelManagerService
from .model_loader_service import ModelLoadServiceBase
# size of GIG in bytes
GIG = 1073741824
@ -174,13 +174,13 @@ class InvocationStatsService(InvocationStatsServiceBase):
graph_id: str
start_time: float
ram_used: int
model_manager: ModelManagerService
model_loader: ModelLoadServiceBase
def __init__(
self,
invocation: BaseInvocation,
graph_id: str,
model_manager: ModelManagerService,
model_loader: ModelLoadServiceBase,
collector: "InvocationStatsServiceBase",
):
"""Initialize statistics for this run."""
@ -189,15 +189,15 @@ class InvocationStatsService(InvocationStatsServiceBase):
self.graph_id = graph_id
self.start_time = 0.0
self.ram_used = 0
self.model_manager = model_manager
self.model_loader = model_loader
def __enter__(self):
self.start_time = time.time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
self.ram_used = psutil.Process().memory_info().rss
if self.model_manager:
self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
if self.model_loader:
self.model_loader.collect_cache_stats(self.collector._cache_stats[self.graph_id])
def __exit__(self, *args):
"""Called on exit from the context."""
@ -208,7 +208,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
)
self.collector.update_invocation_stats(
graph_id=self.graph_id,
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
invocation_type=self.invocation.type,
time_used=time.time() - self.start_time,
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
)
@ -217,12 +217,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
model_manager: ModelManagerService,
model_loader: ModelLoadServiceBase,
) -> StatsContext:
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
self._cache_stats[graph_execution_state_id] = CacheStats()
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
return self.StatsContext(invocation, graph_execution_state_id, model_loader, self)
def reset_all_stats(self):
"""Zero all statistics"""

View File

@ -0,0 +1,192 @@
# Copyright 2023 Lincoln Stein and the InvokeAI Team
"""
Convert and merge models.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import move, rmtree
from typing import List, Optional
from pydantic import Field
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from .config import InvokeAIAppConfig
from .model_install_service import ModelInstallServiceBase
from .model_loader_service import ModelInfo, ModelLoadServiceBase
from .model_record_service import ModelConfigBase, ModelRecordServiceBase, ModelType, SubModelType
class ModelConvertBase(ABC):
"""Convert and merge models."""
@abstractmethod
def __init__(
cls,
loader: ModelLoadServiceBase,
installer: ModelInstallServiceBase,
store: ModelRecordServiceBase,
):
"""Initialize ModelConvert with loader, installer and configuration store."""
pass
@abstractmethod
def convert_model(
self,
key: str,
dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder.
It will delete the cached version as well as the
original checkpoint file if it is in the models directory.
:param key: Unique key of model.
:param dest_directory: Optional place to put converted file. If not specified,
will be stored in the `models_dir`.
This will raise a ValueError unless the model is a checkpoint.
This will raise an UnknownModelException if key is unknown.
"""
pass
def merge_models(
self,
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Merge two to three diffusers pipeline models and save as a new model.
:param model_keys: List of 2-3 model unique keys to merge
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to the 2nd and 3rd models
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
pass
class ModelConvert(ModelConvertBase):
"""Implementation of ModelConvertBase."""
def __init__(
self,
loader: ModelLoadServiceBase,
installer: ModelInstallServiceBase,
store: ModelRecordServiceBase,
):
"""Initialize ModelConvert with loader, installer and configuration store."""
self.loader = loader
self.installer = installer
self.store = store
def convert_model(
self,
key: str,
dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder.
It will delete the cached version as well as the
original checkpoint file if it is in the models directory.
:param key: Unique key of model.
:param dest_directory: Optional place to put converted file. If not specified,
will be stored in the `models_dir`.
This will raise a ValueError unless the model is a checkpoint.
This will raise an UnknownModelException if key is unknown.
"""
new_diffusers_path = None
config = InvokeAIAppConfig.get_config()
try:
info: ModelConfigBase = self.store.get_model(key)
if info.model_format != "checkpoint":
raise ValueError(f"not a checkpoint format model: {info.name}")
# We are taking advantage of a side effect of get_model() that converts checkpoints
# into cached diffusers directories stored at `path`. It doesn't matter
# what submodel type we request here, so we get the smallest.
submodel = {"submodel_type": SubModelType.Scheduler} if info.model_type == ModelType.Main else {}
converted_model: ModelInfo = self.loader.get_model(key, **submodel)
checkpoint_path = config.models_path / info.path
old_diffusers_path = config.models_path / converted_model.location
# new values to write in
update = info.dict()
update.pop("config")
update["model_format"] = "diffusers"
update["path"] = str(converted_model.location)
if dest_directory:
new_diffusers_path = Path(dest_directory) / info.name
if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
move(old_diffusers_path, new_diffusers_path)
update["path"] = new_diffusers_path.as_posix()
self.store.update_model(key, update)
result = self.installer.sync_model_path(key, ignore_hash_change=True)
except Exception as excp:
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
if new_diffusers_path:
rmtree(new_diffusers_path)
raise excp
if checkpoint_path.exists() and checkpoint_path.is_relative_to(config.models_path):
checkpoint_path.unlink()
return result
def merge_models(
self,
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Merge two to three diffusers pipeline models and save as a new model.
:param model_keys: List of 2-3 model unique keys to merge
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to the 2nd and 3rd models
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
pass
merger = ModelMerger(self.store)
try:
if not merged_model_name:
merged_model_name = "+".join([self.store.get_model(x).name for x in model_keys])
raise Exception("not implemented")
result = merger.merge_diffusion_models_and_save(
model_keys=model_keys,
merged_model_name=merged_model_name,
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=merge_dest_directory,
)
except AssertionError as e:
raise ValueError(e)
return result
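A sketch of exercising the converter: `services` stands in for the application's service container, and `key` must refer to an installed checkpoint-format model.

convert = ModelConvert(
    loader=services.model_loader,
    installer=services.model_installer,
    store=services.model_record_store,
)
new_config = convert.convert_model(key)   # ValueError if the model is not a checkpoint,
                                          # UnknownModelException if the key is unknown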

View File

@ -0,0 +1,653 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
import re
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import move, rmtree
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Union
from pydantic import Field
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.backend import get_precision
from invokeai.backend.model_manager.config import (
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_manager.download.model_queue import (
HTTP_RE,
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE,
DownloadJobMetadataURL,
DownloadJobRepoID,
DownloadJobWithMetadata,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.models import InvalidModelException
from invokeai.backend.model_manager.probe import ModelProbe, ModelProbeInfo
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.storage import DuplicateModelException, ModelConfigStore
from invokeai.backend.util import Chdir, InvokeAILogger, Logger
if TYPE_CHECKING:
from .events import EventServiceBase
from .download_manager import (
DownloadEventHandler,
DownloadJobBase,
DownloadJobPath,
DownloadQueueService,
DownloadQueueServiceBase,
ModelSourceMetadata,
)
class ModelInstallJob(DownloadJobBase):
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
model_key: Optional[str] = Field(
description="After model installation, this field will hold its primary key", default=None
)
probe_override: Optional[Dict[str, Any]] = Field(
description="Keys in this dict will override like-named attributes in the automatic probe info",
default=None,
)
class ModelInstallURLJob(DownloadJobMetadataURL, ModelInstallJob):
"""Job for installing URLs."""
class ModelInstallRepoIDJob(DownloadJobRepoID, ModelInstallJob):
"""Job for installing repo ids."""
class ModelInstallPathJob(DownloadJobPath, ModelInstallJob):
"""Job for installing local paths."""
ModelInstallEventHandler = Callable[["ModelInstallJob"], None]
class ModelInstallServiceBase(ABC):
"""Abstract base class for InvokeAI model installation."""
@abstractmethod
def __init__(
self,
config: Optional[InvokeAIAppConfig] = None,
queue: Optional[DownloadQueueServiceBase] = None,
store: Optional[ModelRecordServiceBase] = None,
event_bus: Optional["EventServiceBase"] = None,
event_handlers: List[DownloadEventHandler] = [],
):
"""
Create ModelInstallService object.
:param config: Optional InvokeAIAppConfig. If None passed,
uses the system-wide default app config.
:param queue: Optional DownloadQueueServiceBase object. If None passed,
a default queue object will be created.
:param store: Optional ModelConfigStore. If None passed,
defaults to `configs/models.yaml`.
:param event_bus: InvokeAI event bus for reporting events to.
:param event_handlers: List of event handlers to pass to the queue object.
"""
pass
@property
@abstractmethod
def queue(self) -> DownloadQueueServiceBase:
"""Return the download queue used by the installer."""
pass
@property
@abstractmethod
def store(self) -> ModelRecordServiceBase:
"""Return the storage backend used by the installer."""
pass
@property
@abstractmethod
def config(self) -> InvokeAIAppConfig:
"""Return the app_config used by the installer."""
pass
@abstractmethod
def register_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]]) -> str:
"""
Probe and register the model at model_path.
:param model_path: Filesystem Path to the model.
:param overrides: Dict of attributes that will override probed values.
:returns id: The string ID of the registered model.
"""
pass
@abstractmethod
def install_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> str:
"""
Probe, register and install the model in the models directory.
This involves moving the model from its current location into
the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model.
:param overrides: Dictionary of model probe info fields that, if present, override probed values.
:returns id: The string ID of the installed model.
"""
pass
@abstractmethod
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
priority: int = 10,
start: Optional[bool] = True,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None,
metadata: Optional[ModelSourceMetadata] = None,
access_token: Optional[str] = None,
) -> ModelInstallJob:
"""
Download and install the indicated model.
This will download the model located at `source`,
probe it, and install it into the models directory.
This call is executed asynchronously in a separate
thread, and the returned object is a
invokeai.backend.model_manager.download.DownloadJobBase
object which can be interrogated to get the status of
the download and install process. Call our `wait_for_installs()`
method to wait for all downloads and installations to complete.
:param source: Either a URL or a HuggingFace repo_id.
:param inplace: If True, local paths will not be moved into
the models directory, but registered in place (the default).
:param variant: For HuggingFace models, this optional parameter
specifies which variant to download (e.g. 'fp16')
:param subfolder: When downloading HF repo_ids this can be used to
specify a subfolder of the HF repository to download from.
:param probe_override: Optional dict. Any fields in this dict
will override corresponding probe fields. Use it to override
`base_type`, `model_type`, `format`, `prediction_type` and `image_size`.
:param metadata: Use this to override the fields 'description`,
`author`, `tags`, `source` and `license`.
:returns ModelInstallJob object.
The `inplace` flag does not affect the behavior of downloaded
models, which are always moved into the `models` directory.
Variants recognized by HuggingFace currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
pass
@abstractmethod
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""
Wait for all pending installs to complete.
This will block until all pending downloads have
completed, been cancelled, or errored out. It will
block indefinitely if one or more jobs are in the
paused state.
It will return a dict that maps the source model
path, URL or repo_id to the ID of the installed model.
"""
pass
@abstractmethod
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
"""
Recursively scan directory for new models and register or install them.
:param scan_dir: Path to the directory to scan.
:param install: Install if True, otherwise register in place.
:returns list of IDs: Returns list of IDs of models registered/installed
"""
pass
@abstractmethod
def sync_to_config(self):
"""Synchronize models on disk to those in memory."""
pass
@abstractmethod
def hash(self, model_path: Union[Path, str]) -> str:
"""
Compute and return the fast hash of the model.
:param model_path: Path to the model on disk.
:return str: FastHash of the model for use as an ID.
"""
pass
class ModelInstallService(ModelInstallServiceBase):
"""Model installer class handles installation from a local path."""
_app_config: InvokeAIAppConfig
_logger: Logger
_store: ModelConfigStore
_download_queue: DownloadQueueServiceBase
_async_installs: Dict[Union[str, Path, AnyHttpUrl], Optional[str]]
_installed: Set[str] = Field(default=set)
_tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads
_cached_model_paths: Set[Path] = Field(default=set) # used to speed up directory scanning
_precision: Literal["float16", "float32"] = Field(description="Floating point precision, string form")
_event_bus: Optional["EventServiceBase"] = Field(description="an event bus to send install events to", default=None)
_legacy_configs: Dict[BaseModelType, Dict[ModelVariantType, Union[str, dict]]] = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
},
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
},
}
def __init__(
self,
config: Optional[InvokeAIAppConfig] = None,
queue: Optional[DownloadQueueServiceBase] = None,
store: Optional[ModelRecordServiceBase] = None,
event_bus: Optional["EventServiceBase"] = None,
event_handlers: List[DownloadEventHandler] = [],
): # noqa D107 - use base class docstrings
self._app_config = config or InvokeAIAppConfig.get_config()
self._store = store or ModelRecordServiceBase.open(self._app_config)
self._logger = InvokeAILogger.get_logger(config=self._app_config)
self._event_bus = event_bus
self._precision = get_precision()
self._handlers = event_handlers
if self._event_bus:
self._handlers.append(self._event_bus.emit_model_event)
self._download_queue = queue or DownloadQueueService(event_bus=event_bus)
self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
self._installed = set()
self._tmpdir = None
def start(self, invoker: Any): # Because .processor is giving circular import errors, declaring invoker an 'Any'
"""Call automatically at process start."""
self.sync_to_config()
@property
def queue(self) -> DownloadQueueServiceBase:
"""Return the queue."""
return self._download_queue
@property
def store(self) -> ModelConfigStore:
"""Return the storage backend used by the installer."""
return self._store
@property
def config(self) -> InvokeAIAppConfig:
"""Return the app_config used by the installer."""
return self._app_config
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
priority: int = 10,
start: Optional[bool] = True,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None,
metadata: Optional[ModelSourceMetadata] = None,
access_token: Optional[str] = None,
) -> ModelInstallJob: # noqa D102
queue = self._download_queue
variant = variant or ("fp16" if self._precision == "float16" else None)
job = self._make_download_job(
source, variant=variant, access_token=access_token, subfolder=subfolder, priority=priority
)
handler = (
self._complete_registration_handler
if inplace and Path(source).exists()
else self._complete_installation_handler
)
if isinstance(job, ModelInstallJob):
job.probe_override = probe_override
if metadata and isinstance(job, DownloadJobWithMetadata):
job.metadata = metadata
job.add_event_handler(handler)
self._async_installs[source] = None
queue.submit_download_job(job, start=start)
return job
def register_path(
self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None
) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = self._probe_model(model_path, overrides)
return self._register(model_path, info)
def install_path(
self,
model_path: Union[Path, str],
overrides: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = self._probe_model(model_path, overrides)
dest_path = self._app_config.models_path / info.base_type.value / info.model_type.value / model_path.name
new_path = self._move_model(model_path, dest_path)
new_hash = self.hash(new_path)
assert new_hash == info.hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
return self._register(
new_path,
info,
)
def unregister(self, key: str): # noqa D102
self._store.del_model(key)
def delete(self, key: str): # noqa D102
model = self._store.get_model(key)
path = self._app_config.models_path / model.path
if path.is_dir():
rmtree(path)
else:
path.unlink()
self.unregister(key)
def conditionally_delete(self, key: str): # noqa D102
"""Unregister the model. Delete its files only if they are within our models directory."""
model = self._store.get_model(key)
models_dir = self._app_config.models_path
model_path = models_dir / model.path
if model_path.is_relative_to(models_dir):
self.delete(key)
else:
self.unregister(key)
def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
key: str = FastModelHash.hash(model_path)
model_path = model_path.absolute()
if model_path.is_relative_to(self._app_config.models_path):
model_path = model_path.relative_to(self._app_config.models_path)
registration_data = dict(
path=model_path.as_posix(),
name=model_path.name if model_path.is_dir() else model_path.stem,
base_model=info.base_type,
model_type=info.model_type,
model_format=info.format,
hash=key,
)
# add 'main' specific fields
if info.model_type == ModelType.Main:
if info.variant_type:
registration_data.update(variant=info.variant_type)
if info.format == ModelFormat.Checkpoint:
try:
config_file = self._legacy_configs[info.base_type][info.variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
if prediction_type := info.prediction_type:
config_file = config_file[prediction_type]
else:
self._logger.warning(
f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
)
config_file = config_file[SchedulerPredictionType.VPrediction]
registration_data.update(
config=Path(self._app_config.legacy_conf_dir, str(config_file)).as_posix(),
)
except KeyError as exc:
raise InvalidModelException(
"Configuration file for this checkpoint could not be determined"
) from exc
self._store.add_model(key, registration_data)
return key
def _move_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
new_path.parent.mkdir(parents=True, exist_ok=True)
# if path already exists then we jigger the name to make it unique
counter: int = 1
while new_path.exists():
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
if not path.exists():
new_path = path
counter += 1
return move(old_path, new_path)
def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
if overrides: # used to override probe fields
for key, value in overrides.items():
try:
setattr(info, key, value) # skip validation errors
except Exception:
pass
return info
def _complete_installation_handler(self, job: DownloadJobBase):
assert isinstance(job, ModelInstallJob)
if job.status == "completed":
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
model_id = self.install_path(job.destination, job.probe_override)
info = self._store.get_model(model_id)
info.source = str(job.source)
if isinstance(job, DownloadJobWithMetadata):
metadata: ModelSourceMetadata = job.metadata
info.description = metadata.description or f"Imported model {info.name}"
info.name = metadata.name or info.name
info.author = metadata.author
info.tags = metadata.tags
info.license = metadata.license
info.thumbnail_url = metadata.thumbnail_url
self._store.update_model(model_id, info)
self._async_installs[job.source] = model_id
job.model_key = model_id
elif job.status == "error":
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
elif job.status == "cancelled":
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
jobs = self._download_queue.list_jobs()
if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
self._tmpdir.cleanup()
self._tmpdir = None
def _complete_registration_handler(self, job: DownloadJobBase):
assert isinstance(job, ModelInstallJob)
if job.status == "completed":
self._logger.info(f"{job.source}: Installing in place.")
model_id = self.register_path(job.destination, job.probe_override)
info = self._store.get_model(model_id)
info.source = str(job.source)
info.description = f"Imported model {info.name}"
self._store.update_model(model_id, info)
self._async_installs[job.source] = model_id
job.model_key = model_id
elif job.status == "error":
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
elif job.status == "cancelled":
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
def sync_model_path(self, key: str, ignore_hash_change: bool = False) -> ModelConfigBase:
"""
Move the model into the location indicated by its basetype, type and name.
Call this after updating a model's attributes so that its file or directory
is relocated accordingly. Applies only to models whose paths are within the root `models_dir`
directory.
May raise an UnknownModelException.
"""
model = self._store.get_model(key)
old_path = Path(model.path)
models_dir = self._app_config.models_path
if not old_path.is_relative_to(models_dir):
return model
new_path = models_dir / model.base_model.value / model.model_type.value / model.name
self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
model.hash = self.hash(new_path)
model.path = new_path.relative_to(models_dir).as_posix()
if model.hash != key:
assert (
ignore_hash_change
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
self._logger.info(f"Model has new hash {model.hash}, but will continue to be identified by {key}")
self._store.update_model(key, model)
return model
def _make_download_job(
self,
source: Union[str, Path, AnyHttpUrl],
variant: Optional[str] = None,
subfolder: Optional[str] = None,
access_token: Optional[str] = None,
priority: Optional[int] = 10,
) -> ModelInstallJob:
# Strip stray whitespace from string sources (a common source of error); Path objects are passed through unchanged.
if isinstance(source, str):
source = source.strip()
# In the event that we are being asked to install a path that is already on disk,
# we simply probe and register/install it. The job does not actually do anything, but we
# create one anyway in order to have similar behavior for local files, URLs and repo_ids.
if Path(source).exists(): # a path that is already on disk
destdir = source
return ModelInstallPathJob(source=source, destination=Path(destdir), event_handlers=self._handlers)
# choose a temporary directory inside the models directory
models_dir = self._app_config.models_path
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
cls = ModelInstallJob
if match := re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
cls = ModelInstallRepoIDJob
source = match.group(1)
subfolder = match.group(2) or subfolder
kwargs = dict(variant=variant, subfolder=subfolder)
elif re.match(HTTP_RE, str(source)):
cls = ModelInstallURLJob
kwargs = {}
else:
raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL")
return cls(
source=str(source),
destination=Path(self._tmpdir.name),
access_token=access_token,
priority=priority,
event_handlers=self._handlers,
**kwargs,
)
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""Pause until all installation jobs have completed."""
self._download_queue.join()
id_map = self._async_installs
self._async_installs = dict()
return id_map
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = set([Path(x.path) for x in self._store.all_models()])
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback)
self._installed = set()
search.search(scan_dir)
return list(self._installed)
def scan_models_directory(self):
"""
Scan the models directory for new and missing models.
New models will be added to the storage backend. Models whose files
can no longer be found will be unregistered.
"""
defunct_models = set()
installed = set()
with Chdir(self._app_config.models_path):
self._logger.info("Checking for models that have been moved or deleted from disk")
for model_config in self._store.all_models():
path = Path(model_config.path)
if not path.exists():
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering")
defunct_models.add(model_config.key)
for key in defunct_models:
self.unregister(key)
self._logger.info(f"Scanning {self._app_config.models_path} for new models")
for cur_base_model in BaseModelType:
for cur_model_type in ModelType:
models_dir = Path(cur_base_model.value, cur_model_type.value)
installed.update(self.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
def sync_to_config(self):
"""Synchronize models on disk to those in memory."""
self.scan_models_directory()
if autoimport := self._app_config.autoimport_dir:
self._logger.info("Scanning autoimport directory for new models")
self.scan_directory(self._app_config.root_path / autoimport)
def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
return FastModelHash.hash(model_path)
def _scan_register(self, model: Path) -> bool:
if model in self._cached_model_paths:
return True
try:
id = self.register_path(model)
self.sync_model_path(id) # possibly move it to right place in `models`
self._logger.info(f"Registered {model.name} with id {id}")
self._installed.add(id)
except DuplicateModelException:
pass
return True
def _scan_install(self, model: Path) -> bool:
if model in self._cached_model_paths:
return True
try:
id = self.install_path(model)
self._logger.info(f"Installed {model} with id {id}")
self._installed.add(id)
except DuplicateModelException:
pass
return True
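To make the flow of the installer above concrete, a minimal usage sketch follows; the repo_id, the local file path, and the use of the default app config are illustrative assumptions rather than part of this diff.
# Sketch only: the repo_id and file path below are placeholders.
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_install_service import ModelInstallService

config = InvokeAIAppConfig.get_config()
installer = ModelInstallService(config=config)

# Queue an asynchronous download-and-install job, then block until all queued jobs finish.
installer.install_model("stabilityai/sd-vae-ft-mse")
id_map = installer.wait_for_installs()  # {source: model key, or None if the install failed}

# A model already on disk is probed and registered in place, without being copied into models/.
key = installer.register_path("/path/to/some-model.safetensors")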

View File

@ -0,0 +1,140 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from pydantic import Field
from invokeai.app.models.exceptions import CanceledException
from invokeai.backend.model_manager import ModelConfigStore, SubModelType
from invokeai.backend.model_manager.cache import CacheStats
from invokeai.backend.model_manager.loader import ModelConfigBase, ModelInfo, ModelLoad
from .config import InvokeAIAppConfig
from .model_record_service import ModelRecordServiceBase
if TYPE_CHECKING:
from ..invocations.baseinvocation import InvocationContext
class ModelLoadServiceBase(ABC):
"""Load models into memory."""
@abstractmethod
def __init__(
self,
config: InvokeAIAppConfig,
store: Union[ModelConfigStore, ModelRecordServiceBase],
):
"""
Initialize a ModelLoadService
:param config: InvokeAIAppConfig object
:param store: ModelConfigStore object for fetching model configuration information
"""
pass
@abstractmethod
def get_model(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""Retrieve the indicated model identified by key.
:param key: Unique key returned by the ModelConfigStore module.
:param submodel_type: Submodel to return (required for main models)
:param context: Optional InvocationContext, used in event reporting.
"""
pass
@abstractmethod
def collect_cache_stats(self, cache_stats: CacheStats):
"""Set the CacheStats object into which the model cache will record its statistics."""
pass
# implementation
class ModelLoadService(ModelLoadServiceBase):
"""Responsible for managing models on disk and in memory."""
_loader: ModelLoad
def __init__(
self,
config: InvokeAIAppConfig,
record_store: Union[ModelConfigStore, ModelRecordServiceBase],
):
"""
Initialize a ModelLoadService.
:param config: InvokeAIAppConfig object
:param record_store: ModelRecordServiceBase or ModelConfigStore object for fetching model configuration information
"""
self._loader = ModelLoad(config, record_store)
def get_model(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""
Retrieve the indicated model.
The submodel is required when fetching a main model.
"""
model_info: ModelInfo = self._loader.get_model(key, submodel_type)
# we can emit model loading events if we are executing with access to the invocation context
if context:
self._emit_load_event(
context=context,
model_key=key,
submodel=submodel_type,
model_info=model_info,
)
return model_info
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Set the CacheStats object into which the model cache will record its statistics.
"""
self._loader.collect_cache_stats(cache_stats)
def _emit_load_event(
self,
context: InvocationContext,
model_key: str,
submodel: Optional[SubModelType] = None,
model_info: Optional[ModelInfo] = None,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()
if model_info:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_key=model_key,
submodel=submodel,
model_info=model_info,
)
else:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_key=model_key,
submodel=submodel,
)
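For orientation, here is a minimal sketch of fetching a model through the loader outside of a graph node; the key is a placeholder, and because no InvocationContext is passed, no load events are emitted.
# Sketch only: the key is a placeholder; the record-store wiring mirrors ModelRecordServiceBase.open().
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.backend.model_manager import SubModelType

config = InvokeAIAppConfig.get_config()
record_store = ModelRecordServiceBase.open(config)
loader = ModelLoadService(config, record_store)

# submodel_type is required when the key refers to a main (pipeline) model;
# with no context argument, no model_load_started/completed events are emitted.
model_info = loader.get_model("<model key from the record store>", submodel_type=SubModelType.Vae)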

View File

@ -1,675 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from __future__ import annotations
from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
import torch
from pydantic import Field
from invokeai.app.models.exceptions import CanceledException
from invokeai.backend.model_management import (
AddModelResult,
BaseModelType,
MergeInterpolationMethod,
ModelInfo,
ModelManager,
ModelMerger,
ModelNotFoundException,
ModelType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_management.model_cache import CacheStats
from invokeai.backend.model_management.model_search import FindModels
from ...backend.util import choose_precision, choose_torch_device
from .config import InvokeAIAppConfig
if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
class ModelManagerServiceBase(ABC):
"""Responsible for managing models on disk and in memory"""
@abstractmethod
def __init__(
self,
config: InvokeAIAppConfig,
logger: ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
pass
@abstractmethod
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae)
of a diffusers pipeline."""
pass
@property
@abstractmethod
def logger(self):
pass
@abstractmethod
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
pass
@abstractmethod
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
Uses the same format as the omegaconf stanza.
"""
pass
@abstractmethod
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
"""
Return a dict of models in the format:
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
"""
pass
@abstractmethod
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Return information about the model using the same format as list_models()
"""
pass
@abstractmethod
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
pass
@abstractmethod
def add_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass
@abstractmethod
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
ModelNotFoundException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass
@abstractmethod
def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
"""
pass
@abstractmethod
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str,
):
"""
Rename the indicated model.
"""
pass
@abstractmethod
def list_checkpoint_configs(self) -> List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
pass
@abstractmethod
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae],
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
This will raise a ValueError if the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
pass
@abstractmethod
def heuristic_import(
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilon) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
"""
pass
@abstractmethod
def merge_models(
self,
model_names: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> AddModelResult:
"""
Merge two to three diffusers pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to the 2nd and 3rd models
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
pass
@abstractmethod
def search_for_models(self, directory: Path) -> List[Path]:
"""
Return list of all models found in the designated directory.
"""
pass
@abstractmethod
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
pass
@abstractmethod
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics for graph with graph_id.
"""
pass
@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
original file/database used to initialize the object.
"""
pass
# simple implementation
class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory"""
def __init__(
self,
config: InvokeAIAppConfig,
logger: Logger,
):
"""
Initialize with the path to the models.yaml config file.
Optional parameters are the torch device type, precision, max_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
if config.model_conf_path and config.model_conf_path.exists():
config_file = config.model_conf_path
else:
config_file = config.root_dir / "configs/models.yaml"
logger.debug(f"Config file={config_file}")
device = torch.device(choose_torch_device())
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
logger.info(f"GPU device = {device} {device_name}")
precision = config.precision
if precision == "auto":
precision = choose_precision(device)
dtype = torch.float32 if precision == "float32" else torch.float16
# this is transitional backward compatibility
# support for the deprecated `max_loaded_models`
# configuration value. If present, then the
# cache size is set to 2.5 GB times
# the number of max_loaded_models. Otherwise
# use new `ram_cache_size` config setting
max_cache_size = config.ram_cache_size
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
sequential_offload = config.sequential_guidance
self.mgr = ModelManager(
config=config_file,
device_type=device,
precision=dtype,
max_cache_size=max_cache_size,
sequential_offload=sequential_offload,
logger=logger,
)
logger.info("Model manager service initialized")
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""
Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers model.
"""
# we can emit model loading events if we are executing with access to the invocation context
if context:
self._emit_load_event(
context=context,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)
model_info = self.mgr.get_model(
model_name,
base_model,
model_type,
submodel,
)
if context:
self._emit_load_event(
context=context,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info,
)
return model_info
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
"""
Given a model name, returns True if it is a valid
identifier.
"""
return self.mgr.model_exists(
model_name,
base_model,
model_type,
)
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
"""
return self.mgr.model_info(model_name, base_model, model_type)
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
return self.mgr.model_names()
def list_models(
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
) -> list[dict]:
"""
Return a list of models.
"""
return self.mgr.list_models(base_model, model_type)
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
"""
Return information about the model using the same format as list_models()
"""
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
def add_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f"add/update model {model_name}")
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
ModelNotFoundException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f"update model {model_name}")
if not self.model_exists(model_name, base_model, model_type):
raise ModelNotFoundException(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well.
"""
self.logger.debug(f"delete model {model_name}")
self.mgr.del_model(model_name, base_model, model_type)
self.mgr.commit()
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae],
convert_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
This will raise a ValueError if the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
self.logger.debug(f"convert model {model_name}")
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics for graph with graph_id.
"""
self.mgr.cache.stats = cache_stats
def commit(self, conf_file: Optional[Path] = None):
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
original file/database used to initialize the object.
"""
return self.mgr.commit(conf_file)
def _emit_load_event(
self,
context: InvocationContext,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
model_info: Optional[ModelInfo] = None,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()
if model_info:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info,
)
else:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)
@property
def logger(self):
return self.mgr.logger
def heuristic_import(
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilon) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
"""
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
def merge_models(
self,
model_names: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
merge_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
) -> AddModelResult:
"""
Merge two to three diffusers pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to the 2nd and 3rd models
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
merger = ModelMerger(self.mgr)
try:
result = merger.merge_diffusion_models_and_save(
model_names=model_names,
base_model=base_model,
merged_model_name=merged_model_name,
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=merge_dest_directory,
)
except AssertionError as e:
raise ValueError(e)
return result
def search_for_models(self, directory: Path) -> List[Path]:
"""
Return list of all models found in the designated directory.
"""
search = FindModels([directory], self.logger)
return search.list_models()
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
return self.mgr.sync_to_config()
def list_checkpoint_configs(self) -> List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
config = self.mgr.app_config
conf_path = config.legacy_conf_path
root_path = config.root_path
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: Optional[str] = None,
new_base: Optional[BaseModelType] = None,
):
"""
Rename the indicated model. Can provide a new name and/or a new base.
:param model_name: Current name of the model
:param base_model: Current base of the model
:param model_type: Model type (can't be changed)
:param new_name: New name for the model
:param new_base: New base for the model
"""
self.mgr.rename_model(
base_model=base_model,
model_type=model_type,
model_name=model_name,
new_name=new_name,
new_base=new_base,
)
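The heuristic_import() docstring above describes a prediction_type_helper callback for SD-2 checkpoints. A minimal sketch of such a callback is shown below; the filename-based heuristic in its body is purely an assumption (in practice the helper usually prompts the user).
# Sketch only: the filename heuristic is illustrative; real helpers usually ask the user.
from pathlib import Path

from invokeai.backend.model_management import SchedulerPredictionType


def guess_prediction_type(checkpoint_path: Path) -> SchedulerPredictionType:
    # SD-2 768-pixel models use v-prediction; SD-2 base models use epsilon.
    if "768" in checkpoint_path.name:
        return SchedulerPredictionType.VPrediction
    return SchedulerPredictionType.Epsilon


# manager.heuristic_import({"sd21-768.ckpt"}, prediction_type_helper=guess_prediction_type)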

View File

@ -0,0 +1,130 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from __future__ import annotations
import sqlite3
import threading
from abc import abstractmethod
from pathlib import Path
from typing import Optional, Union
from invokeai.backend.model_manager import ( # noqa F401
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_manager.storage import ( # noqa F401
ModelConfigStore,
ModelConfigStoreSQL,
ModelConfigStoreYAML,
UnknownModelException,
)
from invokeai.backend.util.logging import InvokeAILogger
from .config import InvokeAIAppConfig
class ModelRecordServiceBase(ModelConfigStore):
"""
Responsible for managing model configuration records.
This is an ABC that simply subclasses the ModelConfigStore ABC
in the backend.
"""
@classmethod
@abstractmethod
def from_db_file(cls, db_file: Path) -> ModelRecordServiceBase:
"""
Initialize a new object from a database file.
If the path does not exist, a new sqlite3 db will be initialized.
:param db_file: Path to the database file
"""
pass
@classmethod
def open(
cls, config: InvokeAIAppConfig, conn: Optional[sqlite3.Connection] = None, lock: Optional[threading.Lock] = None
) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]:
"""
Choose either a ModelConfigStoreSQL or a ModelConfigStoreFile backend.
Logic is as follows:
1. if config.model_config_db contains a Path, then
a. if the path looks like a .db file, open a new sqlite3 connection and return a ModelRecordServiceSQL
b. if the path looks like a .yaml file, return a new ModelRecordServiceFile
c. otherwise raise a ValueError
2. if config.model_config_db is the literal 'auto', then use the passed sqlite3 connection and thread lock.
a. if either of these is missing, then we create our own connection to the invokeai.db file, which *should*
be a safe thing to do - sqlite3 will use file-level locking.
3. if config.model_config_db is None, then fall back to config.conf_path, using a yaml file
"""
logger = InvokeAILogger.get_logger()
db = config.model_config_db
if db is None:
return ModelRecordServiceFile.from_db_file(config.model_conf_path)
if str(db) == "auto":
logger.info("Model config storage = main InvokeAI database")
return (
ModelRecordServiceSQL.from_connection(conn, lock)
if (conn and lock)
else ModelRecordServiceSQL.from_db_file(config.db_path)
)
assert isinstance(db, Path)
suffix = db.suffix
if suffix == ".yaml":
logger.info(f"Model config storage = {str(db)}")
return ModelRecordServiceFile.from_db_file(config.root_path / db)
elif suffix == ".db":
logger.info(f"Model config storage = {str(db)}")
return ModelRecordServiceSQL.from_db_file(config.root_path / db)
else:
raise ValueError(
f'Unrecognized model config record db file type {db} in "model_config_db" configuration variable.'
)
class ModelRecordServiceSQL(ModelConfigStoreSQL):
"""
ModelRecordService that uses Sqlite for its backend.
Please see invokeai/backend/model_manager/storage/sql.py for
the implementation.
"""
@classmethod
def from_connection(cls, conn: sqlite3.Connection, lock: threading.Lock) -> ModelRecordServiceSQL:
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
This is the same as the default __init__() constructor.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
return cls(conn, lock)
@classmethod
def from_db_file(cls, db_file: Path) -> ModelRecordServiceSQL: # noqa D102 - docstring in ABC
Path(db_file).parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(db_file, check_same_thread=False)
lock = threading.Lock()
return cls(conn, lock)
class ModelRecordServiceFile(ModelConfigStoreYAML):
"""
ModelRecordService that uses a YAML file for its backend.
Please see invokeai/backend/model_manager/storage/yaml.py for
the implementation.
"""
@classmethod
def from_db_file(cls, db_file: Path) -> ModelRecordServiceFile: # noqa D102 - docstring in ABC
return cls(db_file)
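To summarize the selection logic in ModelRecordServiceBase.open() above, a short sketch of the possible `model_config_db` settings follows; the file names are illustrative.
# Sketch: how the `model_config_db` setting selects a storage backend (file names illustrative).
#   model_config_db: auto                 -> ModelRecordServiceSQL, sharing the main invokeai.db
#   model_config_db: databases/models.db  -> ModelRecordServiceSQL on a dedicated sqlite3 file
#   model_config_db: configs/models.yaml  -> ModelRecordServiceFile on a YAML file
#   (unset)                               -> ModelRecordServiceFile on config.model_conf_path
from invokeai.app.services.config import InvokeAIAppConfig

store = ModelRecordServiceBase.open(InvokeAIAppConfig.get_config())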

View File

@ -97,8 +97,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke
try:
graph_id = graph_execution_state.id
model_manager = self.__invoker.services.model_manager
with statistics.collect_stats(invocation, graph_id, model_manager):
model_loader = self.__invoker.services.model_loader
with statistics.collect_stats(invocation, graph_id, model_loader):
# use the internal invoke_internal(), which wraps the node's invoke() method,
# which handles a few things:
# - nodes that require a value, but get it only from a connection

View File

@ -4,7 +4,7 @@ from PIL import Image
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.models.image import ProgressImage
from ...backend.model_management.models import BaseModelType
from ...backend.model_manager import BaseModelType
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
from ..invocations.baseinvocation import InvocationContext

View File

@ -1,5 +1,15 @@
"""
Initialization file for invokeai.backend
"""
from .model_management import BaseModelType, ModelCache, ModelInfo, ModelManager, ModelType, SubModelType # noqa: F401
from .model_management.models import SilenceWarnings # noqa: F401
from .model_manager import ( # noqa F401
BaseModelType,
DuplicateModelException,
InvalidModelException,
ModelConfigStore,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SilenceWarnings,
SubModelType,
)
from .util.devices import get_precision # noqa F401

View File

@ -8,7 +8,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
def check_invokeai_root(config: InvokeAIAppConfig):
try:
assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
assert config.model_conf_path.parent.exists(), f"{config.model_conf_path.parent} not found"
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
assert config.models_path.exists(), f"{config.models_path} not found"
if not config.ignore_missing_core_models:

View File

@ -0,0 +1,196 @@
"""
Utility (backend) functions used by model_install.py
"""
from pathlib import Path
from typing import Dict, List, Optional
import omegaconf
from huggingface_hub import HfFolder
from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass
from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_install_service import ModelInstallJob, ModelInstallService, ModelSourceMetadata
from invokeai.backend.model_manager import BaseModelType, ModelType
from invokeai.backend.model_manager.download.queue import DownloadJobRemoteSource
# name of the starter models file
INITIAL_MODELS = "INITIAL_MODELS.yaml"
class UnifiedModelInfo(BaseModel):
name: Optional[str] = None
base_model: Optional[BaseModelType] = None
model_type: Optional[ModelType] = None
source: Optional[str] = None
subfolder: Optional[str] = None
description: Optional[str] = None
recommended: bool = False
installed: bool = False
default: bool = False
requires: List[str] = Field(default_factory=list)
@dataclass
class InstallSelections:
install_models: List[UnifiedModelInfo] = Field(default_factory=list)
remove_models: List[str] = Field(default_factory=list)
class TqdmProgress(object):
_bars: Dict[int, tqdm] # the tqdm object
_last: Dict[int, int] # last bytes downloaded
def __init__(self):
self._bars = dict()
self._last = dict()
def job_update(self, job: ModelInstallJob):
if not isinstance(job, DownloadJobRemoteSource):
return
job_id = job.id
if job.status == "running" and job.total_bytes > 0: # job starts running before total bytes known
if job_id not in self._bars:
dest = Path(job.destination).name
self._bars[job_id] = tqdm(
desc=dest,
initial=0,
total=job.total_bytes,
unit="iB",
unit_scale=True,
)
self._last[job_id] = 0
self._bars[job_id].update(job.bytes - self._last[job_id])
self._last[job_id] = job.bytes
class InstallHelper(object):
"""Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""
all_models: Dict[str, UnifiedModelInfo] = dict()
_installer: ModelInstallService
_config: InvokeAIAppConfig
_installed_models: List[str] = []
_starter_models: List[str] = []
_default_model: Optional[str] = None
_initial_models: omegaconf.DictConfig
def __init__(self, config: InvokeAIAppConfig):
self._config = config
self._installer = ModelInstallService(config=config, event_handlers=[TqdmProgress().job_update])
self._initial_models = omegaconf.OmegaConf.load(Path(configs.__path__[0]) / INITIAL_MODELS)
self._initialize_model_lists()
@property
def installer(self) -> ModelInstallService:
return self._installer
def _initialize_model_lists(self):
"""
Initialize our model slots.
Set up the following:
installed_models -- list of installed model keys
starter_models -- list of starter model keys from INITIAL_MODELS
all_models -- dict of key => UnifiedModelInfo
default_model -- key to default model
"""
# previously-installed models
for model in self._installer.store.all_models():
info = UnifiedModelInfo.parse_obj(model.dict())
info.installed = True
key = f"{model.base_model.value}/{model.model_type.value}/{model.name}"
self.all_models[key] = info
self._installed_models.append(key)
for key in self._initial_models.keys():
if key in self.all_models:
# we want to preserve the description
description = self.all_models[key].description or self._initial_models[key].get("description")
self.all_models[key].description = description
else:
base_model, model_type, model_name = key.split("/")
info = UnifiedModelInfo(
name=model_name,
model_type=model_type,
base_model=base_model,
source=self._initial_models[key].source,
description=self._initial_models[key].get("description"),
recommended=self._initial_models[key].get("recommended", False),
default=self._initial_models[key].get("default", False),
subfolder=self._initial_models[key].get("subfolder"),
requires=list(self._initial_models[key].get("requires", [])),
)
self.all_models[key] = info
if not self._default_model:
self._default_model = key
elif self._initial_models[key].get("default", False):
self._default_model = key
self._starter_models.append(key)
def recommended_models(self) -> List[UnifiedModelInfo]:
return [self._to_model(x) for x in self._starter_models if self._to_model(x).recommended]
def installed_models(self) -> List[UnifiedModelInfo]:
return [self._to_model(x) for x in self._installed_models]
def starter_models(self) -> List[UnifiedModelInfo]:
return [self._to_model(x) for x in self._starter_models]
def default_model(self) -> UnifiedModelInfo:
return self._to_model(self._default_model)
def _to_model(self, key: str) -> UnifiedModelInfo:
return self.all_models[key]
def _add_required_models(self, model_list: List[UnifiedModelInfo]):
installed = {x.source for x in self.installed_models()}
reverse_source = {x.source: x for x in self.all_models.values()}
additional_models = []
for model_info in model_list:
for requirement in model_info.requires:
if requirement not in installed:
additional_models.append(reverse_source.get(requirement))
model_list.extend(additional_models)
def add_or_delete(self, selections: InstallSelections):
installer = self._installer
self._add_required_models(selections.install_models)
for model in selections.install_models:
metadata = ModelSourceMetadata(description=model.description, name=model.name)
installer.install_model(
model.source,
subfolder=model.subfolder,
access_token=HfFolder.get_token(),
metadata=metadata,
)
for model in selections.remove_models:
parts = model.split("/")
if len(parts) == 1:
base_model, model_type, model_name = (None, None, model)
else:
base_model, model_type, model_name = parts
matches = installer.store.search_by_name(
base_model=base_model, model_type=model_type, model_name=model_name
)
if len(matches) > 1:
print(f"{model} is ambiguous. Please use base_model/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate.")
elif not matches:
print(f"{model}: unknown model")
else:
for m in matches:
print(f"Deleting {m.model_type}:{m.name}")
installer.conditionally_delete(m.key)
installer.wait_for_installs()
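As a usage sketch, this is roughly how the configure script drives the helper; the selection contents below are placeholders.
# Sketch only: the selection contents are illustrative.
from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig.get_config()
helper = InstallHelper(config)
selections = InstallSelections(
    install_models=helper.recommended_models(),
    remove_models=["sd-1/main/some-old-model"],  # base_model/model_type/model_name
)
helper.add_or_delete(selections)  # installs and deletes, then blocks until downloads finish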

View File

@ -22,7 +22,6 @@ from typing import Any, get_args, get_type_hints
from urllib import request
import npyscreen
import omegaconf
import psutil
import torch
import transformers
@ -38,21 +37,25 @@ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, hf_download_from_pretrained
from invokeai.backend.model_management.model_probe import BaseModelType, ModelType
from invokeai.backend.model_manager import BaseModelType, ModelType
from invokeai.backend.model_manager.storage import ConfigFileVersionMismatchException, migrate_models_store
from invokeai.backend.util import choose_precision, choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.model_install import addModelsForm
# TO DO - Move all the frontend code into invokeai.frontend.install
from invokeai.frontend.install.widgets import (
MIN_COLS,
MIN_LINES,
CenteredButtonPress,
CheckboxWithChanged,
CyclingForm,
FileBox,
MultiSelectColumns,
SingleSelectColumnsSimple,
SingleSelectWithChanged,
WindowTooSmallException,
set_min_terminal_size,
)
@ -82,7 +85,6 @@ GB = 1073741824 # GB in bytes
HAS_CUDA = torch.cuda.is_available()
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
MAX_VRAM /= GB
MAX_RAM = psutil.virtual_memory().total / GB
@ -96,6 +98,8 @@ logger = InvokeAILogger.get_logger()
class DummyWidgetValue(Enum):
"""Dummy widget values."""
zero = 0
true = True
false = False
@ -179,6 +183,22 @@ class ProgressBar:
self.pbar.update(block_size)
# ---------------------------------------------
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
filter = lambda x: "fp16 is not a valid" not in x.getMessage()
logger.addFilter(filter)
try:
model = model_class.from_pretrained(
model_name,
resume_download=True,
**kwargs,
)
model.save_pretrained(destination, safe_serialization=True)
finally:
logger.removeFilter(filter)
return destination
# ---------------------------------------------
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
try:
@ -455,6 +475,25 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
max_width=110,
scroll_exit=True,
)
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="Model disk conversion cache size (GB). This is used to cache safetensors files that need to be converted to diffusers.",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.disk = self.add_widget_intelligent(
npyscreen.Slider,
value=clip(old_opts.disk, range=(0, 100), step=0.5),
out_of=100,
lowest=0.0,
step=0.5,
relx=8,
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
@ -495,6 +534,45 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
)
else:
self.vram = DummyWidgetValue.zero
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Location of the database used to store model path and configuration information:",
editable=False,
color="CONTROL",
)
self.nextrely += 1
if first_time:
old_opts.model_config_db = "auto"
self.model_conf_auto = self.add_widget_intelligent(
CheckboxWithChanged,
value=str(old_opts.model_config_db) == "auto",
name="Main database",
relx=2,
max_width=25,
scroll_exit=True,
)
self.nextrely -= 2
config_db = str(old_opts.model_config_db or old_opts.conf_path)
self.model_conf_override = self.add_widget_intelligent(
FileBox,
value=str(old_opts.root_path / config_db)
if config_db != "auto"
else str(old_opts.root_path / old_opts.conf_path),
name="Specify models config database manually",
select_dir=False,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
# begin_entry_at=40,
relx=30,
max_height=3,
max_width=100,
scroll_exit=True,
hidden=str(old_opts.model_config_db) == "auto",
)
self.model_conf_auto.on_changed = self.show_hide_model_conf_override
self.nextrely += 1
self.outdir = self.add_widget_intelligent(
FileBox,
@ -506,19 +584,21 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
labelColor="GOOD",
begin_entry_at=40,
max_height=3,
max_width=127,
scroll_exit=True,
)
self.autoimport_dirs = {}
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
FileBox,
name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
value=str(config.root_path / config.autoimport_dir),
name="Optional folder to scan for new checkpoints, ControlNets, LoRAs and TI models",
value=str(config.root_path / config.autoimport_dir) if config.autoimport_dir else "",
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=32,
max_height=3,
max_width=127,
scroll_exit=True,
)
self.nextrely += 1
@ -555,6 +635,10 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
self.attention_slice_label.hidden = not show
self.attention_slice_size.hidden = not show
def show_hide_model_conf_override(self, value):
self.model_conf_override.hidden = value
self.model_conf_override.display()
def on_ok(self):
options = self.marshall_arguments()
if self.validate_field_values(options):
@ -590,17 +674,21 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
for attr in [
"ram",
"vram",
"disk",
"outdir",
]:
if hasattr(self, attr):
setattr(new_opts, attr, getattr(self, attr).value)
for attr in self.autoimport_dirs:
if not self.autoimport_dirs[attr].value:
continue
directory = Path(self.autoimport_dirs[attr].value)
if directory.is_relative_to(config.root_path):
directory = directory.relative_to(config.root_path)
setattr(new_opts, attr, directory)
new_opts.model_config_db = "auto" if self.model_conf_auto.value else self.model_conf_override.value
new_opts.hf_token = self.hf_token.value
new_opts.license_acceptance = self.license_acceptance.value
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
@ -615,13 +703,14 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
class EditOptApplication(npyscreen.NPSAppManaged):
def __init__(self, program_opts: Namespace, invokeai_opts: Namespace):
def __init__(self, program_opts: Namespace, invokeai_opts: Namespace, install_helper: InstallHelper):
super().__init__()
self.program_opts = program_opts
self.invokeai_opts = invokeai_opts
self.user_cancelled = False
self.autoload_pending = True
self.install_selections = default_user_selections(program_opts)
self.install_helper = install_helper
self.install_selections = default_user_selections(program_opts, install_helper)
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
@ -644,12 +733,6 @@ class EditOptApplication(npyscreen.NPSAppManaged):
return self.options.marshall_arguments()
def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace:
editApp = EditOptApplication(program_opts, invokeai_opts)
editApp.run()
return editApp.new_opts()
def default_ramcache() -> float:
"""Run a heuristic for the default RAM cache based on installed RAM."""
@ -666,21 +749,12 @@ def default_startup_options(init_file: Path) -> Namespace:
return opts
def default_user_selections(program_opts: Namespace) -> InstallSelections:
try:
installer = ModelInstall(config)
except omegaconf.errors.ConfigKeyError:
logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing")
initialize_rootdir(config.root_path, True)
installer = ModelInstall(config)
models = installer.all_models()
def default_user_selections(program_opts: Namespace, install_helper: InstallHelper) -> InstallSelections:
default_models = (
[install_helper.default_model()] if program_opts.default_only else install_helper.recommended_models()
)
return InstallSelections(
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
if program_opts.default_only
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
if program_opts.yes_to_all
else list(),
install_models=default_models if program_opts.yes_to_all else list(),
)
@ -730,7 +804,7 @@ def maybe_create_models_yaml(root: Path):
# -------------------------------------
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
def run_console_ui(program_opts: Namespace, initfile: Path, install_helper: InstallHelper) -> (Namespace, Namespace):
invokeai_opts = default_startup_options(initfile)
invokeai_opts.root = program_opts.root
@ -739,13 +813,7 @@ def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace
"Could not increase terminal size. Try running again with a larger window or smaller font size."
)
# the install-models application spawns a subprocess to install
# models, and will crash unless this is set before running.
import torch
torch.multiprocessing.set_start_method("spawn")
editApp = EditOptApplication(program_opts, invokeai_opts)
editApp = EditOptApplication(program_opts, invokeai_opts, install_helper)
editApp.run()
if editApp.user_cancelled:
return (None, None)
@ -904,6 +972,7 @@ def main():
if opt.full_precision:
invoke_args.extend(["--precision", "float32"])
config.parse_args(invoke_args)
config.precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
logger = InvokeAILogger().get_logger(config=config)
errors = set()
@ -917,14 +986,22 @@ def main():
# run this unconditionally in case new directories need to be added
initialize_rootdir(config.root_path, opt.yes_to_all)
models_to_download = default_user_selections(opt)
# this will initialize the models.yaml file if not present
try:
install_helper = InstallHelper(config)
except ConfigFileVersionMismatchException:
config.model_config_db = migrate_models_store(config)
install_helper = InstallHelper(config)
models_to_download = default_user_selections(opt, install_helper)
new_init_file = config.root_path / "invokeai.yaml"
if opt.yes_to_all:
write_default_options(opt, new_init_file)
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
else:
init_options, models_to_download = run_console_ui(opt, new_init_file)
init_options, models_to_download = run_console_ui(opt, new_init_file, install_helper)
if init_options:
write_opts(init_options, new_init_file)
else:
@ -939,10 +1016,12 @@ def main():
if opt.skip_sd_weights:
logger.warning("Skipping diffusion weights download per user request")
elif models_to_download:
process_and_execute(opt, models_to_download)
install_helper.add_or_delete(models_to_download)
postscript(errors=errors)
if not opt.yes_to_all:
input("Press any key to continue...")
except WindowTooSmallException as e:

View File

@ -3,13 +3,15 @@ Migrate the models directory and models.yaml file from an existing
InvokeAI 2.3 installation to 3.0.0.
"""
#### NOTE: THIS SCRIPT NO LONGER WORKS WITH REFACTORED MODEL MANAGER, AND WILL NOT BE UPDATED.
import argparse
import os
import shutil
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Union
from typing import Optional, Union
import diffusers
import transformers
@ -21,8 +23,9 @@ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel,
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager
from invokeai.backend.model_management.model_probe import BaseModelType, ModelProbe, ModelProbeInfo, ModelType
from invokeai.app.services.model_install_service import ModelInstallService
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.backend.model_manager import BaseModelType, ModelProbe, ModelProbeInfo, ModelType
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
@ -43,19 +46,14 @@ class MigrateTo3(object):
self,
from_root: Path,
to_models: Path,
model_manager: ModelManager,
installer: ModelInstallService,
src_paths: ModelPaths,
):
self.root_directory = from_root
self.dest_models = to_models
self.mgr = model_manager
self.installer = installer
self.src_paths = src_paths
@classmethod
def initialize_yaml(cls, yaml_file: Path):
with open(yaml_file, "w") as file:
file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
def create_directory_structure(self):
"""
Create the basic directory structure for the models folder.
@ -107,44 +105,10 @@ class MigrateTo3(object):
Recursively walk through src directory, probe anything
that looks like a model, and copy the model into the
appropriate location within the destination models directory.
This is now trivially easy using the installer service.
"""
directories_scanned = set()
for root, dirs, files in os.walk(src_dir, followlinks=True):
for d in dirs:
try:
model = Path(root, d)
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = self._model_probe_to_path(info) / model.name
self.copy_dir(model, dest)
directories_scanned.add(model)
except Exception as e:
logger.error(str(e))
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
for f in files:
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
# let them be copied as part of a tree copy operation
try:
if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}:
continue
model = Path(root, f)
if model.parent in directories_scanned:
continue
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = self._model_probe_to_path(info) / f
self.copy_file(model, dest)
except Exception as e:
logger.error(str(e))
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
self.installer.scan_directory(src_dir)
def migrate_support_models(self):
"""
@ -260,23 +224,21 @@ class MigrateTo3(object):
model.save_pretrained(download_path, safe_serialization=True)
download_path.replace(dest)
def _download_vae(self, repo_id: str, subfolder: str = None) -> Path:
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder)
info = ModelProbe().heuristic_probe(vae)
_, model_name = repo_id.split("/")
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
vae.save_pretrained(dest, safe_serialization=True)
return dest
def _download_vae(self, repo_id: str, subfolder: str = None) -> Optional[Path]:
self.installer.install(repo_id) # bug! We don't support subfolder yet.
ids = self.installer.wait_for_installs()
if key := ids.get(repo_id):
return self.installer.store.get_model(key).path
else:
return None
def _vae_path(self, vae: Union[str, dict]) -> Path:
"""
Convert 2.3 VAE stanza to a straight path.
"""
vae_path = None
def _vae_path(self, vae: Union[str, dict]) -> Optional[Path]:
"""Convert 2.3 VAE stanza to a straight path."""
vae_path: Optional[Path] = None
# First get a path
if isinstance(vae, str):
vae_path = vae
vae_path = Path(vae)
elif isinstance(vae, DictConfig):
if p := vae.get("path"):
@ -284,28 +246,21 @@ class MigrateTo3(object):
elif repo_id := vae.get("repo_id"):
if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded
vae_path = "models/core/convert/sd-vae-ft-mse"
return vae_path
return Path(vae_path)
else:
vae_path = self._download_vae(repo_id, vae.get("subfolder"))
assert vae_path is not None, "Couldn't find VAE for this model"
if vae_path is None:
return None
# if the VAE is in the old models directory, then we must move it into the new
# one. VAEs outside of this directory can stay where they are.
vae_path = Path(vae_path)
if vae_path.is_relative_to(self.src_paths.models):
info = ModelProbe().heuristic_probe(vae_path)
dest = self._model_probe_to_path(info) / vae_path.name
if not dest.exists():
if vae_path.is_dir():
self.copy_dir(vae_path, dest)
else:
self.copy_file(vae_path, dest)
vae_path = dest
if vae_path.is_relative_to(self.dest_models):
rel_path = vae_path.relative_to(self.dest_models)
return Path("models", rel_path)
key = self.installer.install_path(vae_path) # this will move the model
return self.installer.store.get_model(key).path
elif vae_path.is_relative_to(self.dest_models):
key = self.installer.register_path(vae_path) # this will keep the model in place
return self.installer.store.get_model(key).path
else:
return vae_path
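The branch above encodes the general rule of the new installer API: install_path() relocates a model into the managed models tree, while register_path() records it where it already lives. A hedged sketch of choosing between the two (method names taken from this diff, behavior assumed):

from pathlib import Path

def adopt_model(installer, model_path: Path, managed_models_dir: Path) -> Path:
    """Register a model, moving it under management only if it lives outside the models tree (sketch)."""
    if model_path.is_relative_to(managed_models_dir):
        key = installer.register_path(model_path)  # keep the file where it is
    else:
        key = installer.install_path(model_path)   # move it into the managed tree
    return installer.store.get_model(key).path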
@ -501,44 +456,27 @@ def get_legacy_embeddings(root: Path) -> ModelPaths:
return _parse_legacy_yamlfile(root, path)
def do_migrate(src_directory: Path, dest_directory: Path):
def do_migrate(config: InvokeAIAppConfig, src_directory: Path, dest_directory: Path):
"""
Migrate models from src to dest InvokeAI root directories
"""
config_file = dest_directory / "configs" / "models.yaml.3"
dest_models = dest_directory / "models.3"
mm_store = ModelRecordServiceBase.open(config)
mm_install = ModelInstallService(config=config, store=mm_store)
version_3 = (dest_directory / "models" / "core").exists()
# Here we create the destination models.yaml file.
# If we are writing into a version 3 directory and the
# file already exists, then we write into a copy of it to
# avoid deleting its previous customizations. Otherwise we
# create a new empty one.
if version_3: # write into the dest directory
try:
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
except Exception:
MigrateTo3.initialize_yaml(config_file)
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
(dest_directory / "models").replace(dest_models)
else:
MigrateTo3.initialize_yaml(config_file)
mgr = ModelManager(config_file)
if not version_3:
src_directory = (dest_directory / "models").replace(src_directory / "models.orig")
print(f"Original models directory moved to {dest_directory}/models.orig")
paths = get_legacy_embeddings(src_directory)
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths)
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, installer=mm_install, src_paths=paths)
migrator.migrate()
print("Migration successful.")
if not version_3:
(dest_directory / "models").replace(src_directory / "models.orig")
print(f"Original models directory moved to {dest_directory}/models.orig")
(dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig")
print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig")
config_file.replace(config_file.with_suffix(""))
dest_models.replace(dest_models.with_suffix(""))
@ -588,7 +526,7 @@ script, which will perform a full upgrade in place.""",
initialize_rootdir(dest_root, True)
do_migrate(src_root, dest_root)
do_migrate(config, src_root, dest_root)
if __name__ == "__main__":

View File

@ -1,609 +0,0 @@
"""
Utility (backend) functions used by model_install.py
"""
import os
import re
import shutil
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable, Dict, List, Optional, Set, Union
import requests
import torch
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from huggingface_hub import HfApi, HfFolder, hf_hub_url
from omegaconf import OmegaConf
from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType
from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo, SchedulerPredictionType
from invokeai.backend.util import download_with_resume
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore")
# --------------------------globals-----------------------
config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(name="InvokeAI")
# the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
Config_preamble = """
# This file describes the alternative machine learning models
# available to the InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
"""
LEGACY_CONFIGS = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v1-inference.yaml",
SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml",
},
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
},
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
},
}
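The mapping is indexed by base model, then variant, and for SD-1/SD-2 also by prediction type; the SDXL entries stop at the variant level. For example:

# SD-1 inpainting model trained with v-prediction
LEGACY_CONFIGS[BaseModelType.StableDiffusion1][ModelVariantType.Inpaint][SchedulerPredictionType.VPrediction]
# -> "v1-inpainting-inference-v.yaml"

# SDXL base has no prediction-type level
LEGACY_CONFIGS[BaseModelType.StableDiffusionXL][ModelVariantType.Normal]
# -> "sd_xl_base.yaml"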
@dataclass
class InstallSelections:
install_models: List[str] = field(default_factory=list)
remove_models: List[str] = field(default_factory=list)
@dataclass
class ModelLoadInfo:
name: str
model_type: ModelType
base_type: BaseModelType
path: Optional[Path] = None
repo_id: Optional[str] = None
subfolder: Optional[str] = None
description: str = ""
installed: bool = False
recommended: bool = False
default: bool = False
requires: Optional[List[str]] = field(default_factory=list)
class ModelInstall(object):
def __init__(
self,
config: InvokeAIAppConfig,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
model_manager: Optional[ModelManager] = None,
access_token: Optional[str] = None,
):
self.config = config
self.mgr = model_manager or ModelManager(config.model_conf_path)
self.datasets = OmegaConf.load(Dataset_path)
self.prediction_helper = prediction_type_helper
self.access_token = access_token or HfFolder.get_token()
self.reverse_paths = self._reverse_paths(self.datasets)
def all_models(self) -> Dict[str, ModelLoadInfo]:
"""
Return dict of model_key=>ModelLoadInfo objects.
This method consolidates and simplifies the entries in both
models.yaml and INITIAL_MODELS.yaml so that they can
be treated uniformly. It also sorts the models alphabetically
by their name, to improve the display somewhat.
"""
model_dict = dict()
# first populate with the entries in INITIAL_MODELS.yaml
for key, value in self.datasets.items():
name, base, model_type = ModelManager.parse_key(key)
value["name"] = name
value["base_type"] = base
value["model_type"] = model_type
model_info = ModelLoadInfo(**value)
if model_info.subfolder and model_info.repo_id:
model_info.repo_id += f":{model_info.subfolder}"
model_dict[key] = model_info
# supplement with entries in models.yaml
installed_models = [x for x in self.mgr.list_models()]
for md in installed_models:
base = md["base_model"]
model_type = md["model_type"]
name = md["model_name"]
key = ModelManager.create_key(name, base, model_type)
if key in model_dict:
model_dict[key].installed = True
else:
model_dict[key] = ModelLoadInfo(
name=name,
base_type=base,
model_type=model_type,
path=value.get("path"),
installed=True,
)
return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}
def _is_autoloaded(self, model_info: dict) -> bool:
path = model_info.get("path")
if not path:
return False
for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]:
if autodir_path := getattr(self.config, autodir):
autodir_path = self.config.root_path / autodir_path
if Path(path).is_relative_to(autodir_path):
return True
return False
def list_models(self, model_type):
installed = self.mgr.list_models(model_type=model_type)
print()
print(f"Installed models of type `{model_type}`:")
print(f"{'Model Key':50} Model Path")
for i in installed:
print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}")
print()
# logic here is a little reversed to maintain backward compatibility
def starter_models(self, all_models: bool = False) -> Set[str]:
models = set()
for key, value in self.datasets.items():
name, base, model_type = ModelManager.parse_key(key)
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
models.add(key)
return models
def recommended_models(self) -> Set[str]:
starters = self.starter_models(all_models=True)
return set([x for x in starters if self.datasets[x].get("recommended", False)])
def default_model(self) -> str:
starters = self.starter_models()
defaults = [x for x in starters if self.datasets[x].get("default", False)]
return defaults[0]
def install(self, selections: InstallSelections):
verbosity = dlogging.get_verbosity() # quench NSFW nags
dlogging.set_verbosity_error()
job = 1
jobs = len(selections.remove_models) + len(selections.install_models)
# remove requested models
for key in selections.remove_models:
name, base, mtype = self.mgr.parse_key(key)
logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]")
try:
self.mgr.del_model(name, base, mtype)
except FileNotFoundError as e:
logger.warning(e)
job += 1
# add requested models
self._remove_installed(selections.install_models)
self._add_required_models(selections.install_models)
for path in selections.install_models:
logger.info(f"Installing {path} [{job}/{jobs}]")
try:
self.heuristic_import(path)
except (ValueError, KeyError) as e:
logger.error(str(e))
job += 1
dlogging.set_verbosity(verbosity)
self.mgr.commit()
def heuristic_import(
self,
model_path_id_or_url: Union[str, Path],
models_installed: Set[Path] = None,
) -> Dict[str, AddModelResult]:
"""
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
:param models_installed: Set of installed models, used for recursive invocation
Returns a dict mapping each model source to the newly-created models.yaml stanza (an AddModelResult).
"""
if not models_installed:
models_installed = dict()
# A little hack to allow nested routines to retrieve info on the requested ID
self.current_id = model_path_id_or_url
path = Path(model_path_id_or_url)
# checkpoint file, or similar
if path.is_file():
models_installed.update({str(path): self._install_path(path)})
# folders style or similar
elif path.is_dir() and any(
[
(path / x).exists()
for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
]
):
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
# recursive scan
elif path.is_dir():
for child in path.iterdir():
self.heuristic_import(child, models_installed=models_installed)
# huggingface repo
elif len(str(model_path_id_or_url).split("/")) == 2:
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
# a URL
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
else:
raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping")
return models_installed
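In other words, the one entry point covers local files, diffusers folders, whole directories, repo ids and URLs. Illustrative calls against this legacy ModelInstall helper; the paths are placeholders:

installer = ModelInstall(config)

installer.heuristic_import("/data/models/foo.safetensors")         # single checkpoint file
installer.heuristic_import("/data/models/some-diffusers-folder")   # folder containing model_index.json
installer.heuristic_import("/data/models/downloads")               # plain directory, scanned recursively
installer.heuristic_import("runwayml/stable-diffusion-v1-5")       # Hugging Face repo_id
installer.heuristic_import("https://example.com/foo.safetensors")  # direct URL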
def _remove_installed(self, model_list: List[str]):
all_models = self.all_models()
for path in model_list:
key = self.reverse_paths.get(path)
if key and all_models[key].installed:
logger.warning(f"{path} already installed. Skipping.")
model_list.remove(path)
def _add_required_models(self, model_list: List[str]):
additional_models = []
all_models = self.all_models()
for path in model_list:
if not (key := self.reverse_paths.get(path)):
continue
for requirement in all_models[key].requires:
requirement_key = self.reverse_paths.get(requirement)
if not all_models[requirement_key].installed:
additional_models.append(requirement)
model_list.extend(additional_models)
# install a model from a local path. The optional info parameter is there to prevent
# the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult:
info = info or ModelProbe().heuristic_probe(path, self.prediction_helper)
if not info:
logger.warning(f"Unable to parse format of {path}")
return None
model_name = path.stem if path.is_file() else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.')
attributes = self._make_attributes(path, info)
return self.mgr.add_model(
model_name=model_name,
base_model=info.base_type,
model_type=info.model_type,
model_attributes=attributes,
)
def _install_url(self, url: str) -> AddModelResult:
with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url, Path(staging))
if not location:
logger.error(f"Unable to download {url}. Skipping.")
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
dest.parent.mkdir(parents=True, exist_ok=True)
models_path = shutil.move(location, dest)
# staged version will be garbage-collected at this time
return self._install_path(Path(models_path), info)
def _install_repo(self, repo_id: str) -> AddModelResult:
# hack to recover models stored in subfolders --
# Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster
subfolder = None
if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id):
repo_id = match.group(1)
subfolder = match.group(2)
hinfo = HfApi().model_info(repo_id)
# we try to figure out how to download this most economically
# list all the files in the repo
files = [x.rfilename for x in hinfo.siblings]
if subfolder:
files = [x for x in files if x.startswith(f"{subfolder}/")]
prefix = f"{subfolder}/" if subfolder else ""
location = None
with TemporaryDirectory(dir=self.config.models_path) as staging:
staging = Path(staging)
if f"{prefix}model_index.json" in files:
location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder) # pipeline
elif f"{prefix}unet/model.onnx" in files:
location = self._download_hf_model(repo_id, files, staging)
else:
for suffix in ["safetensors", "bin"]:
if f"{prefix}pytorch_lora_weights.{suffix}" in files:
location = self._download_hf_model(
repo_id, ["pytorch_lora_weights.bin"], staging, subfolder=subfolder
) # LoRA
break
elif (
self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files
): # vae, controlnet or some other standalone
files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
break
elif f"{prefix}diffusion_pytorch_model.{suffix}" in files:
files = ["config.json", f"diffusion_pytorch_model.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
break
elif f"{prefix}learned_embeds.{suffix}" in files:
location = self._download_hf_model(
repo_id, [f"learned_embeds.{suffix}"], staging, subfolder=subfolder
)
break
elif (
f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files
): # IP-Adapter
files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
break
elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files:
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
# by InvokeAI for use with IP-Adapters.
files = ["config.json", f"model.{suffix}"]
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
break
if not location:
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
return {}
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
if not info:
logger.warning(f"Could not probe {location}. Skipping install.")
return {}
dest = (
self.config.models_path
/ info.base_type.value
/ info.model_type.value
/ self._get_model_name(repo_id, location)
)
if dest.exists():
shutil.rmtree(dest)
shutil.copytree(location, dest)
return self._install_path(dest, info)
def _get_model_name(self, path_name: str, location: Path) -> str:
"""
Calculate a name for the model - primitive implementation.
"""
if key := self.reverse_paths.get(path_name):
(name, base, mtype) = ModelManager.parse_key(key)
return name
elif location.is_dir():
return location.name
else:
return location.stem
def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict:
model_name = path.name if path.is_dir() else path.stem
description = f"{info.base_type.value} {info.model_type.value} model {model_name}"
if key := self.reverse_paths.get(self.current_id):
if key in self.datasets:
description = self.datasets[key].get("description") or description
rel_path = self.relative_to_root(path, self.config.models_path)
attributes = dict(
path=str(rel_path),
description=str(description),
model_format=info.format,
)
legacy_conf = None
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
attributes.update(
dict(
variant=info.variant_type,
)
)
if info.format == "checkpoint":
try:
possible_conf = path.with_suffix(".yaml")
if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf))
elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
legacy_conf = Path(
self.config.legacy_conf_dir,
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
)
else:
legacy_conf = Path(
self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]
)
except KeyError:
legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess
if info.model_type == ModelType.ControlNet and info.format == "checkpoint":
possible_conf = path.with_suffix(".yaml")
if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf))
if legacy_conf:
attributes.update(dict(config=str(legacy_conf)))
return attributes
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
root = root or self.config.root_path
if path.is_relative_to(root):
return path.relative_to(root)
else:
return path
def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Path:
"""
Retrieve a StableDiffusion model from cache or remote, then
save it with save_pretrained() to the indicated staging area.
"""
_, name = repo_id.split("/")
precision = torch_dtype(choose_torch_device())
variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]
model = None
for variant in variants:
try:
model = DiffusionPipeline.from_pretrained(
repo_id,
variant=variant,
torch_dtype=precision,
safety_checker=None,
subfolder=subfolder,
)
except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors
if "fp16" not in str(e):
print(e)
if model:
break
if not model:
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
return None
model.save_pretrained(staging / name, safe_serialization=True)
return staging / name
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
_, name = repo_id.split("/")
location = staging / name
paths = list()
for filename in files:
filePath = Path(filename)
p = hf_download_with_resume(
repo_id,
model_dir=location / filePath.parent,
model_name=filePath.name,
access_token=self.access_token,
subfolder=filePath.parent / subfolder if subfolder else filePath.parent,
)
if p:
paths.append(p)
else:
logger.warning(f"Could not download {filename} from {repo_id}.")
return location if len(paths) > 0 else None
@classmethod
def _reverse_paths(cls, datasets) -> dict:
"""
Build a reverse mapping from repo_id/path to model key.
"""
return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()}
# -------------------------------------
def yes_or_no(prompt: str, default_yes=True):
default = "y" if default_yes else "n"
response = input(f"{prompt} [{default}] ") or default
if default_yes:
return response[0] not in ("n", "N")
else:
return response[0] in ("y", "Y")
# ---------------------------------------------
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
logger = InvokeAILogger.get_logger("InvokeAI")
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
model = model_class.from_pretrained(
model_name,
resume_download=True,
**kwargs,
)
model.save_pretrained(destination, safe_serialization=True)
return destination
# ---------------------------------------------
def hf_download_with_resume(
repo_id: str,
model_dir: str,
model_name: str,
model_dest: Path = None,
access_token: str = None,
subfolder: str = None,
) -> Path:
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
os.makedirs(model_dir, exist_ok=True)
url = hf_hub_url(repo_id, model_name, subfolder=subfolder)
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
open_mode = "wb"
exist_size = 0
if os.path.exists(model_dest):
exist_size = os.path.getsize(model_dest)
header["Range"] = f"bytes={exist_size}-"
open_mode = "ab"
resp = requests.get(url, headers=header, stream=True)
total = int(resp.headers.get("content-length", 0))
if resp.status_code == 416: # "range not satisfiable", which means nothing to return
logger.info(f"{model_name}: complete file found. Skipping.")
return model_dest
elif resp.status_code == 404:
logger.warning("File not found")
return None
elif resp.status_code != 200:
logger.warning(f"{model_name}: {resp.reason}")
elif exist_size > 0:
logger.info(f"{model_name}: partial file found. Resuming...")
else:
logger.info(f"{model_name}: Downloading...")
try:
with (
open(model_dest, open_mode) as file,
tqdm(
desc=model_name,
initial=exist_size,
total=total + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
) as bar,
):
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
except Exception as e:
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
return None
return model_dest
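The resume logic above is a standard HTTP pattern: send a Range header for the bytes already on disk, append to the partial file, and treat a 416 response as "already complete". A self-contained sketch of the same pattern with requests, independent of the helper above:

import os
import requests

def resume_download(url: str, dest: str, chunk_size: int = 1024) -> str:
    """Download url to dest, resuming from a partial file if one exists (illustrative)."""
    headers = {}
    mode = "wb"
    existing = os.path.getsize(dest) if os.path.exists(dest) else 0
    if existing:
        headers["Range"] = f"bytes={existing}-"  # ask the server for the remainder only
        mode = "ab"
    resp = requests.get(url, headers=headers, stream=True)
    if resp.status_code == 416:  # range not satisfiable: the file is already complete
        return dest
    if existing and resp.status_code != 206:
        existing, mode = 0, "wb"  # server ignored the Range header; start over
    resp.raise_for_status()
    with open(dest, mode) as f:
        for chunk in resp.iter_content(chunk_size=chunk_size):
            f.write(chunk)
    return dest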

View File

@ -8,7 +8,7 @@ from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.model_management.models.base import calc_model_size_by_data
from invokeai.backend.model_manager.models.base import calc_model_size_by_data
from .resampler import Resampler

View File

@ -0,0 +1 @@
The contents of this directory are deprecated. model_manager.py is here only for reference.

View File

@ -1,114 +0,0 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class for recursive directory search for models.
"""
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Set, types
import invokeai.backend.util.logging as logger
class ModelSearch(ABC):
def __init__(self, directories: List[Path], logger: types.ModuleType = logger):
"""
Initialize a recursive model directory search.
:param directories: List of directory Paths to recurse through
:param logger: Logger to use
"""
self.directories = directories
self.logger = logger
self._items_scanned = 0
self._models_found = 0
self._scanned_dirs = set()
self._scanned_paths = set()
self._pruned_paths = set()
@abstractmethod
def on_search_started(self):
"""
Called before the scan starts.
"""
pass
@abstractmethod
def on_model_found(self, model: Path):
"""
Process a found model. Raise an exception if something goes wrong.
:param model: Model to process - could be a directory or checkpoint.
"""
pass
@abstractmethod
def on_search_completed(self):
"""
Perform some activity when the scan is completed. May use the instance
variables _items_scanned and _models_found.
"""
pass
def search(self):
self.on_search_started()
for dir in self.directories:
self.walk_directory(dir)
self.on_search_completed()
def walk_directory(self, path: Path):
for root, dirs, files in os.walk(path, followlinks=True):
if str(Path(root).name).startswith("."):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
continue
self._items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in self._scanned_paths or path.parent in self._scanned_dirs:
self._scanned_dirs.add(path)
continue
if any(
[
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
}
]
):
try:
self.on_model_found(path)
self._models_found += 1
self._scanned_dirs.add(path)
except Exception as e:
self.logger.warning(f"Failed to process '{path}': {e}")
for f in files:
path = Path(root) / f
if path.parent in self._scanned_dirs:
continue
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
try:
self.on_model_found(path)
self._models_found += 1
except Exception as e:
self.logger.warning(f"Failed to process '{path}': {e}")
class FindModels(ModelSearch):
def on_search_started(self):
self.models_found: Set[Path] = set()
def on_model_found(self, model: Path):
self.models_found.add(model)
def on_search_completed(self):
pass
def list_models(self) -> List[Path]:
self.search()
return list(self.models_found)
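FindModels is the smallest useful subclass: it simply accumulates every hit. A usage sketch, plus the shape of a custom subclass that reacts to each model as it is found (directory paths are placeholders):

from pathlib import Path

# collect every model under one or more directories
found = FindModels([Path("/data/models")]).list_models()

# or act on each model as it is discovered
class PrintingSearch(ModelSearch):
    def on_search_started(self):
        print("scan starting")

    def on_model_found(self, model: Path):
        print(f"found {model}")

    def on_search_completed(self):
        print(f"scanned {self._items_scanned} items, found {self._models_found} models")

PrintingSearch([Path("/data/models")]).search()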

View File

@ -1,75 +0,0 @@
# Copyright (c) 2023 The InvokeAI Development Team
"""Utilities used by the Model Manager"""
def lora_token_vector_length(checkpoint: dict) -> int:
"""
Given a checkpoint in memory, return the lora token vector length
:param checkpoint: The checkpoint
"""
def _get_shape_1(key, tensor, checkpoint):
lora_token_vector_length = None
if "." not in key:
return lora_token_vector_length # wrong key format
model_key, lora_key = key.split(".", 1)
# check lora/locon
if lora_key == "lora_down.weight":
lora_token_vector_length = tensor.shape[1]
# check loha (don't worry about hada_t1/hada_t2 as they are used only in 4d shapes)
elif lora_key in ["hada_w1_b", "hada_w2_b"]:
lora_token_vector_length = tensor.shape[1]
# check lokr (don't worry about lokr_t2 as it is used only in 4d shapes)
elif "lokr_" in lora_key:
if model_key + ".lokr_w1" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
elif model_key + "lokr_w1_b" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
else:
return lora_token_vector_length # unknown format
if model_key + ".lokr_w2" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2"]
elif model_key + "lokr_w2_b" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
else:
return lora_token_vector_length # unknown format
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
elif lora_key == "diff":
lora_token_vector_length = tensor.shape[1]
# ia3 can be detected only by shape[0] in text encoder
elif lora_key == "weight" and "lora_unet_" not in model_key:
lora_token_vector_length = tensor.shape[0]
return lora_token_vector_length
lora_token_vector_length = None
lora_te1_length = None
lora_te2_length = None
for key, tensor in checkpoint.items():
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_te") and "_self_attn_" in key:
tmp_length = _get_shape_1(key, tensor, checkpoint)
if key.startswith("lora_te_"):
lora_token_vector_length = tmp_length
elif key.startswith("lora_te1_"):
lora_te1_length = tmp_length
elif key.startswith("lora_te2_"):
lora_te2_length = tmp_length
if lora_te1_length is not None and lora_te2_length is not None:
lora_token_vector_length = lora_te1_length + lora_te2_length
if lora_token_vector_length is not None:
break
return lora_token_vector_length
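In practice the checkpoint argument is just the state dict of a LoRA file. A hedged example of feeding one in with safetensors; the path is a placeholder, and the length-to-base mapping in the comment reflects the usual text encoder widths:

from safetensors.torch import load_file

checkpoint = load_file("/data/models/my_lora.safetensors")  # {key: tensor} dict
length = lora_token_vector_length(checkpoint)
# 768 -> SD-1 text encoder, 1024 -> SD-2, 2048 (768 + 1280) -> SDXL with both encoders
print(f"token vector length: {length}")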

View File

@ -0,0 +1,27 @@
"""Initialization file for invokeai.backend.model_manager.config."""
from .config import ( # noqa F401
BaseModelType,
InvalidModelConfigException,
ModelConfigBase,
ModelConfigFactory,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SilenceWarnings,
SubModelType,
)
# from .install import ModelInstall, ModelInstallJob # noqa F401
# from .loader import ModelInfo, ModelLoad # noqa F401
# from .lora import ModelPatcher, ONNXModelPatcher # noqa F401
from .models import OPENAPI_MODEL_CONFIGS, InvalidModelException, read_checkpoint_meta # noqa F401
from .probe import ModelProbe, ModelProbeInfo # noqa F401
from .search import ModelSearch # noqa F401
from .storage import ( # noqa F401
DuplicateModelException,
ModelConfigStore,
ModelConfigStoreSQL,
ModelConfigStoreYAML,
UnknownModelException,
)

View File

@ -1,5 +1,6 @@
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
@ -25,13 +26,14 @@ import time
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union, types
from typing import Any, Dict, List, Optional, Type, Union
import torch
import invokeai.backend.util.logging as logger
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.util import InvokeAILogger, Logger
from ..util import GIG
from ..util.devices import choose_torch_device
from .models import BaseModelType, ModelBase, ModelType, SubModelType
@ -63,20 +65,10 @@ class CacheStats(object):
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelLocker(object):
"Forward declaration"
pass
class ModelCache(object):
"Forward declaration"
pass
class _CacheRecord:
size: int
model: Any
cache: ModelCache
cache: "ModelCache"
_locks: int
def __init__(self, cache, model: Any, size: int):
@ -112,10 +104,9 @@ class ModelCache(object):
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
logger: types.ModuleType = logger,
logger: Logger = InvokeAILogger.get_logger(),
):
"""
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
@ -123,7 +114,6 @@ class ModelCache(object):
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
"""
self.model_infos: Dict[str, ModelBase] = dict()
@ -138,40 +128,37 @@ class ModelCache(object):
self.logger = logger
# used for stats collection
self.stats = None
self.stats: Optional[CacheStats] = None
self._cached_models = dict()
self._cache_stack = list()
self._cached_models: Dict[str, _CacheRecord] = dict()
self._cache_stack: List[str] = list()
# Note that the combination of model_path and submodel_type
# are sufficient to generate a unique cache key. This key
# is not the same as the unique hash used to identify models
# in invokeai.backend.model_manager.storage
def get_key(
self,
model_path: str,
base_model: BaseModelType,
model_type: ModelType,
model_path: Path,
submodel_type: Optional[SubModelType] = None,
):
key = f"{model_path}:{base_model}:{model_type}"
key = model_path.as_posix()
if submodel_type:
key += f":{submodel_type}"
return key
def _get_model_info(
self,
model_path: str,
model_path: Path,
model_class: Type[ModelBase],
base_model: BaseModelType,
model_type: ModelType,
):
model_info_key = self.get_key(
model_path=model_path,
base_model=base_model,
model_type=model_type,
submodel_type=None,
)
model_info_key = self.get_key(model_path=model_path)
if model_info_key not in self.model_infos:
self.model_infos[model_info_key] = model_class(
model_path,
model_path.as_posix(),
base_model,
model_type,
)
@ -200,12 +187,8 @@ class ModelCache(object):
base_model=base_model,
model_type=model_type,
)
key = self.get_key(
model_path=model_path,
base_model=base_model,
model_type=model_type,
submodel_type=submodel,
)
key = self.get_key(model_path, submodel)
# TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
@ -253,7 +236,7 @@ class ModelCache(object):
self.stats.hits += 1
if self.stats:
self.stats.cache_size = self.max_cache_size * GIG
self.stats.cache_size = int(self.max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[key] = max(
@ -306,8 +289,12 @@ class ModelCache(object):
)
class ModelLocker(object):
"""Context manager that locks models into VRAM."""
def __init__(self, cache, key, model, gpu_load, size_needed):
"""
Initialize a context manager object that locks models into VRAM.
:param cache: The model_cache object
:param key: The key of the model to lock in GPU
:param model: The model to lock
@ -366,18 +353,6 @@ class ModelCache(object):
self._cache_stack.remove(cache_id)
self._cached_models.pop(cache_id, None)
def model_hash(
self,
model_path: Union[str, Path],
) -> str:
"""
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
:param model_path: Path to model file/directory on disk.
"""
return self._local_model_hash(model_path)
def cache_size(self) -> float:
"""Return the current size of the cache, in GB."""
return self._cache_size() / GIG
@ -429,8 +404,8 @@ class ModelCache(object):
refs = sys.getrefcount(cache_entry.model)
# manualy clear local variable references of just finished function calls
# for some reason python don't want to collect it even by gc.collect() immidiately
# Manually clear local variable references of just finished function calls.
# For some reason python doesn't want to garbage collect it even when gc.collect() is called
if refs > 2:
while True:
cleared = False

View File

@ -0,0 +1,366 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Configuration definitions for image generation models.
Typical usage:
from invokeai.backend.model_manager import ModelConfigFactory
raw = dict(path='models/sd-1/main/foo.ckpt',
name='foo',
base_model='sd-1',
model_type='main',
config='configs/stable-diffusion/v1-inference.yaml',
variant='normal',
model_format='checkpoint'
)
config = ModelConfigFactory.make_config(raw)
print(config.name)
Validation errors will raise an InvalidModelConfigException error.
"""
import warnings
from enum import Enum
from typing import List, Literal, Optional, Type, Union
import pydantic
# import these so that we can silence them
from diffusers import logging as diffusers_logging
from omegaconf.listconfig import ListConfig # to support the yaml backend
from pydantic import BaseModel, Extra, Field
from pydantic.error_wrappers import ValidationError
from transformers import logging as transformers_logging
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
class BaseModelType(str, Enum):
"""Base model type."""
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
# Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
"""Model type."""
ONNX = "onnx"
Main = "main"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
class ModelVariantType(str, Enum):
"""Variant type."""
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class ModelFormat(str, Enum):
"""Storage format of model."""
Diffusers = "diffusers"
Checkpoint = "checkpoint"
Lycoris = "lycoris"
Onnx = "onnx"
Olive = "olive"
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
InvokeAI = "invokeai"
class SchedulerPredictionType(str, Enum):
"""Scheduler prediction type."""
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
# TODO: use this
class ModelError(str, Enum):
NotFound = "not_found"
class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""
path: str
name: str
base_model: BaseModelType
model_type: ModelType
model_format: ModelFormat
key: str = Field(
description="key for model derived from original hash", default="<NOKEY>"
) # assigned on the first install
hash: Optional[str] = Field(
description="current hash key for model", default=None
) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(None)
author: Optional[str] = Field(description="Model author")
license: Optional[str] = Field(description="License string")
source: Optional[str] = Field(description="Model download source (URL or repo_id)")
thumbnail_url: Optional[str] = Field(description="URL of thumbnail image")
tags: Optional[List[str]] = Field(description="Descriptive tags") # Set would be better, but not JSON serializable
class Config:
"""Pydantic configuration hint."""
use_enum_values = False
extra = Extra.forbid
validate_assignment = True
@pydantic.validator("tags", pre=True)
@classmethod
def _fix_tags(cls, v):
if isinstance(v, ListConfig): # to support yaml backend
v = list(v)
return v
def update(self, attributes: dict):
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value) # may raise a validation error
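Because validate_assignment is enabled, update() re-validates each field as it is set, so bad values fail immediately rather than lingering in the record. A small illustration; the field values are made up:

config = ModelConfigFactory.make_config(
    dict(
        path="models/sd-1/main/foo",
        name="foo",
        base_model="sd-1",
        model_type="main",
        model_format="diffusers",
    )
)
config.update({"description": "my favorite model"})  # fine; re-validated on assignment
config.update({"base_model": "not-a-real-base"})     # raises a pydantic ValidationError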
class CheckpointConfig(ModelConfigBase):
"""Model config for checkpoint-style models."""
model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file")
class DiffusersConfig(ModelConfigBase):
"""Model config for diffusers-style models."""
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class LoRAConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
model_format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
class VaeCheckpointConfig(ModelConfigBase):
"""Model config for standalone VAE models."""
model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class VaeDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(DiffusersConfig):
"""Model config for ControlNet models (diffusers version)."""
model_format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(CheckpointConfig):
"""Model config for ControlNet models (diffusers version)."""
model_format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class TextualInversionConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
model_format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
class MainConfig(ModelConfigBase):
"""Model config for main models."""
vae: Optional[str] = Field(None)
variant: ModelVariantType = ModelVariantType.Normal
class MainCheckpointConfig(CheckpointConfig, MainConfig):
"""Model config for main checkpoint models."""
class MainDiffusersConfig(DiffusersConfig, MainConfig):
"""Model config for main diffusers models."""
class ONNXSD1Config(MainConfig):
"""Model config for ONNX format models based on sd-1."""
model_format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
class ONNXSD2Config(MainConfig):
"""Model config for ONNX format models based on sd-2."""
model_format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
# No yaml config file for ONNX, so these are part of config
prediction_type: SchedulerPredictionType
upcast_attention: bool
class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
model_format: Literal[ModelFormat.InvokeAI]
class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision."""
model_format: Literal[ModelFormat.Diffusers]
class T2IConfig(ModelConfigBase):
"""Model config for T2I."""
model_format: Literal[ModelFormat.Diffusers]
AnyModelConfig = Union[
ModelConfigBase,
MainCheckpointConfig,
MainDiffusersConfig,
LoRAConfig,
TextualInversionConfig,
ONNXSD1Config,
ONNXSD2Config,
VaeCheckpointConfig,
VaeDiffusersConfig,
ControlNetDiffusersConfig,
ControlNetCheckpointConfig,
IPAdapterConfig,
CLIPVisionDiffusersConfig,
T2IConfig,
]
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""
_class_map: dict = {
ModelFormat.Checkpoint: {
ModelType.Main: MainCheckpointConfig,
ModelType.Vae: VaeCheckpointConfig,
},
ModelFormat.Diffusers: {
ModelType.Main: MainDiffusersConfig,
ModelType.Lora: LoRAConfig,
ModelType.Vae: VaeDiffusersConfig,
ModelType.ControlNet: ControlNetDiffusersConfig,
ModelType.CLIPVision: CLIPVisionDiffusersConfig,
},
ModelFormat.Lycoris: {
ModelType.Lora: LoRAConfig,
},
ModelFormat.Onnx: {
ModelType.ONNX: {
BaseModelType.StableDiffusion1: ONNXSD1Config,
BaseModelType.StableDiffusion2: ONNXSD2Config,
},
},
ModelFormat.Olive: {
ModelType.ONNX: {
BaseModelType.StableDiffusion1: ONNXSD1Config,
BaseModelType.StableDiffusion2: ONNXSD2Config,
},
},
ModelFormat.EmbeddingFile: {
ModelType.TextualInversion: TextualInversionConfig,
},
ModelFormat.EmbeddingFolder: {
ModelType.TextualInversion: TextualInversionConfig,
},
ModelFormat.InvokeAI: {
ModelType.IPAdapter: IPAdapterConfig,
},
}
@classmethod
def make_config(
cls,
model_data: Union[dict, ModelConfigBase],
key: Optional[str] = None,
dest_class: Optional[Type] = None,
) -> AnyModelConfig:
"""
Return the appropriate config object from raw dict values.
:param model_data: A raw dict corresponding to the object fields to be
parsed into a ModelConfigBase object (or descendant), or a ModelConfigBase
object, which will be passed through unchanged.
:param dest_class: The config class to be returned. If not provided, will
be selected automatically.
"""
if isinstance(model_data, ModelConfigBase):
if key:
model_data.key = key
return model_data
try:
model_format = model_data.get("model_format")
model_type = model_data.get("model_type")
model_base = model_data.get("base_model")
class_to_return = dest_class or cls._class_map[model_format][model_type]
if isinstance(class_to_return, dict): # additional level allowed
class_to_return = class_to_return[model_base]
model = class_to_return.parse_obj(model_data)
if key:
model.key = key # ensure consistency
return model
except KeyError as exc:
raise InvalidModelConfigException(
f"Unknown combination of model_format '{model_format}' and model_type '{model_type}'"
) from exc
except ValidationError as exc:
raise InvalidModelConfigException(f"Invalid model configuration passed: {str(exc)}") from exc
# TO DO: Move this somewhere else
class SilenceWarnings(object):
"""Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self):
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
def __exit__(self, type, value, traceback):
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter("default")

View File

@ -19,9 +19,8 @@
import re
from contextlib import nullcontext
from io import BytesIO
from pathlib import Path
from typing import Optional, Union
from typing import Dict, Optional, Union
import requests
import torch
@ -1223,7 +1222,7 @@ def download_from_original_stable_diffusion_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
@ -1272,15 +1271,15 @@ def download_from_original_stable_diffusion_ckpt(
# only refiner xl has embedder and one text embedders
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
original_config_file = BytesIO(requests.get(config_url).content)
original_config_file = requests.get(config_url).text
original_config = OmegaConf.load(original_config_file)
if original_config["model"]["params"].get("use_ema") is not None:
extract_ema = original_config["model"]["params"]["use_ema"]
if original_config.model["params"].get("use_ema") is not None:
extract_ema = original_config.model["params"]["use_ema"]
if (
model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1]
and original_config["model"]["params"].get("parameterization") == "v"
and original_config.model["params"].get("parameterization") == "v"
):
prediction_type = "v_prediction"
upcast_attention = True
@ -1312,11 +1311,11 @@ def download_from_original_stable_diffusion_ckpt(
num_in_channels = 4
if "unet_config" in original_config.model.params:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
original_config.model["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
if (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
"parameterization" in original_config.model["params"]
and original_config.model["params"]["parameterization"] == "v"
):
if prediction_type is None:
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
@ -1437,7 +1436,7 @@ def download_from_original_stable_diffusion_ckpt(
if model_type == "FrozenOpenCLIPEmbedder":
config_name = "stabilityai/stable-diffusion-2"
config_kwargs = {"subfolder": "text_encoder"}
config_kwargs: Dict[str, Union[str, int]] = {"subfolder": "text_encoder"}
text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer")
@ -1664,7 +1663,7 @@ def download_controlnet_from_original_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
@ -1685,7 +1684,7 @@ def download_controlnet_from_original_ckpt(
original_config = OmegaConf.load(original_config_file)
if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
original_config.model["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
if "control_stage_config" not in original_config.model.params:
raise ValueError("`control_stage_config` not present in original config")
@ -1725,7 +1724,7 @@ def convert_ckpt_to_diffusers(
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs)
pipe = download_from_original_stable_diffusion_ckpt(str(checkpoint_path), **kwargs)
pipe.save_pretrained(
dump_path,
@ -1743,6 +1742,6 @@ def convert_controlnet_to_diffusers(
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs)
pipe = download_controlnet_from_original_ckpt(str(checkpoint_path), **kwargs)
pipe.save_pretrained(dump_path, safe_serialization=True)

View File

@ -0,0 +1,11 @@
"""Initialization file for threaded download manager."""
from .base import ( # noqa F401
DownloadEventHandler,
DownloadJobBase,
DownloadJobStatus,
DownloadQueueBase,
UnknownJobIDException,
)
from .model_queue import ModelDownloadQueue, ModelSourceMetadata # noqa F401
from .queue import DownloadJobPath, DownloadJobRemoteSource, DownloadJobURL, DownloadQueue # noqa F401

View File

@ -0,0 +1,260 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""Abstract base class for a multithreaded model download queue."""
import threading
from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Any, Callable, List, Optional, Union
import requests
from pydantic import BaseModel, Field
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
class DownloadJobStatus(str, Enum):
"""State of a download job."""
IDLE = "idle" # not enqueued, will not run
ENQUEUED = "enqueued" # enqueued but not yet active
RUNNING = "running" # actively downloading
PAUSED = "paused" # previously started, now paused
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated by caller
class UnknownJobIDException(Exception):
"""Raised when an invalid Job is referenced."""
DownloadEventHandler = Callable[["DownloadJobBase"], None]
@total_ordering
class DownloadJobBase(BaseModel):
"""Class to monitor and control a model download request."""
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
source: Any = Field(description="Where to download from. Specific types specified in child classes.")
destination: Path = Field(description="Destination of downloaded model on local disk")
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
event_handlers: Optional[List[DownloadEventHandler]] = Field(
description="Callables that will be called whenever job status changes",
default_factory=list,
)
job_started: Optional[float] = Field(description="Timestamp for when the download job started")
job_ended: Optional[float] = Field(description="Timestamp for when the download job ended (completed or errored)")
job_sequence: Optional[int] = Field(
description="Counter that records order in which this job was dequeued (used in unit testing)"
)
preserve_partial_downloads: bool = Field(
description="if true, then preserve partial downloads when cancelled or errored", default=False
)
error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
def add_event_handler(self, handler: DownloadEventHandler):
"""Add an event handler to the end of the handlers list."""
if self.event_handlers is not None:
self.event_handlers.append(handler)
def clear_event_handlers(self):
"""Clear all event handlers."""
self.event_handlers = list()
def cleanup(self, preserve_partial_downloads: bool = False):
"""Possibly do some action when work is finished."""
pass
class Config:
"""Config object for this pydantic class."""
arbitrary_types_allowed = True
validate_assignment = True
def __lt__(self, other: "DownloadJobBase") -> bool:
"""
Return True if self.priority < other.priority.
:param other: The DownloadJobBase that this will be compared against.
"""
if not hasattr(other, "priority"):
return NotImplemented
return self.priority < other.priority
class DownloadQueueBase(ABC):
"""Abstract base class for managing model downloads."""
@abstractmethod
def __init__(
self,
max_parallel_dl: int = 5,
event_handlers: List[DownloadEventHandler] = [],
requests_session: Optional[requests.sessions.Session] = None,
quiet: bool = False,
):
"""
Initialize DownloadQueue.
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param event_handlers: Optional callables that will be called each time a job's status changes.
:param requests_session: Optional requests.sessions.Session object, for unit tests.
:param quiet: If true, don't log the start of download jobs. Useful for subrequests.
"""
pass
@abstractmethod
def create_download_job(
self,
source: Union[str, Path, AnyHttpUrl],
destdir: Path,
priority: int = 10,
start: Optional[bool] = True,
filename: Optional[Path] = None,
variant: Optional[str] = None, # FIXME: variant is only used in one specific subclass
access_token: Optional[str] = None,
event_handlers: List[DownloadEventHandler] = [],
) -> DownloadJobBase:
"""
Create and submit a download job.
:param source: Source of the download - URL, repo_id or Path
:param destdir: Directory to download into.
:param priority: Initial priority for this job [10]
:param filename: Optional name of the file; if not provided, the
content-disposition header will be used to assign the name.
:param start: Immediately start the job [True]
:param variant: Variant to download, such as "fp16" (repo_ids only).
:param event_handlers: Optional callables that will be called whenever job status changes.
:returns: the job; job.id will be a non-negative value after execution
Known variants currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
pass
def submit_download_job(
self,
job: DownloadJobBase,
start: Optional[bool] = True,
):
"""
Submit a download job.
:param job: A DownloadJobBase
:param start: Immediately start job [True]
After execution, `job.id` will be set to a non-negative value.
"""
pass
@abstractmethod
def release(self):
"""
Release resources used by queue.
If threaded downloads are used, then this will stop the threads.
"""
pass
@abstractmethod
def list_jobs(self) -> List[DownloadJobBase]:
"""
List active DownloadJobBases.
:returns List[DownloadJobBase]: List of download jobs whose state is not "completed."
"""
pass
@abstractmethod
def id_to_job(self, id: int) -> DownloadJobBase:
"""
Return the DownloadJobBase corresponding to the given numeric ID.
:param id: ID of the DownloadJobBase.
Exceptions:
* UnknownJobIDException
Note that once a job is completed, id_to_job() may no longer
recognize the job. Call id_to_job() before the job completes
if you wish to keep the job object around after it has
completed work.
"""
pass
@abstractmethod
def start_all_jobs(self):
"""Enqueue all stopped jobs."""
pass
@abstractmethod
def pause_all_jobs(self):
"""Pause and dequeue all active jobs."""
pass
@abstractmethod
def prune_jobs(self):
"""Prune completed and errored queue items from the job list."""
pass
@abstractmethod
def cancel_all_jobs(self, preserve_partial: bool = False):
"""
Cancel all jobs (those in enqueued, running and paused states).
:param preserve_partial: Keep partially downloaded files [False].
"""
pass
@abstractmethod
def start_job(self, job: DownloadJobBase):
"""Start the job putting it into ENQUEUED state."""
pass
@abstractmethod
def pause_job(self, job: DownloadJobBase):
"""Pause the job, putting it into PAUSED state."""
pass
@abstractmethod
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
"""
Cancel the job, clearing partial downloads and putting it into CANCELLED state.
:param preserve_partial: Keep partial downloads [False]
"""
pass
@abstractmethod
def join(self):
"""
Wait until all jobs are off the queue.
Note that once a job is completed, id_to_job() will
no longer recognize the job.
"""
pass
@abstractmethod
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
"""Based on the job type select the download method."""
pass
@abstractmethod
def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
"""
Given a job, translate its source field into a downloadable URL.
Intended to be subclassed to cover various source types.
"""
pass
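Taken together, the abstract API above describes a small event-driven queue: callers submit jobs, registered handlers fire on every status change, and join()/release() tear the queue down. A minimal usage sketch follows; the import path (invokeai.backend.model_manager.download) and the example URL are assumptions made for illustration, not something stated in this diff.

from pathlib import Path

from invokeai.backend.model_manager.download import DownloadJobBase, DownloadQueue  # assumed import path


def on_change(job: DownloadJobBase) -> None:
    # Called by the queue on every status change (enqueued, running, completed, ...).
    print(f"{job.source}: {job.status.value} ({getattr(job, 'bytes', 0)} bytes)")


queue = DownloadQueue(max_parallel_dl=2, event_handlers=[on_change])
job = queue.create_download_job(
    source="https://example.com/some-model.safetensors",  # hypothetical URL
    destdir=Path("/tmp/models"),  # an existing directory; the filename comes from content-disposition
)
queue.join()     # block until every job has left the queue
queue.release()  # shut down the worker threads
print(f"{job.destination}: {job.status.value}")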

View File

@ -0,0 +1,370 @@
import re
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
from huggingface_hub import HfApi, hf_hub_url
from pydantic import BaseModel, Field, parse_obj_as, validator
from pydantic.networks import AnyHttpUrl
from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase
from .queue import HTTP_RE, DownloadJobRemoteSource, DownloadJobURL, DownloadQueue
# regular expressions used to dispatch appropriate downloaders and metadata retrievers
# endpoint for civitai get-model API
CIVITAI_MODEL_DOWNLOAD = r"https://civitai.com/api/download/models/(\d+)"
CIVITAI_MODEL_PAGE = "https://civitai.com/models/"
CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
# Regular expressions to describe repo_ids and http urls
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^([.\w-]+/[.\w-]+)(?::([.\w-]+))?$"
class ModelSourceMetadata(BaseModel):
"""Information collected on a downloadable model from its source site."""
name: Optional[str] = Field(description="Human-readable name of this model")
author: Optional[str] = Field(description="Author/creator of the model")
description: Optional[str] = Field(description="Description of the model")
license: Optional[str] = Field(description="Model license terms")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="URL of a thumbnail image for the model")
tags: Optional[List[str]] = Field(description="List of descriptive tags")
class DownloadJobWithMetadata(DownloadJobRemoteSource):
"""A remote download that has metadata associated with it."""
metadata: ModelSourceMetadata = Field(
description="Metadata describing the model, derived from source", default_factory=ModelSourceMetadata
)
class DownloadJobMetadataURL(DownloadJobWithMetadata, DownloadJobURL):
"""DownloadJobWithMetadata with validation of the source URL."""
class DownloadJobRepoID(DownloadJobWithMetadata):
"""Download repo ids."""
source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)")
subfolder: Optional[str] = Field(
description="Provide when the desired model is in a subfolder of the repo_id's distro", default=None
)
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
subqueue: Optional[DownloadQueueBase] = Field(
description="a subqueue used for downloading the individual files in the repo_id", default=None
)
@validator("source")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def cleanup(self, preserve_partial_downloads: bool = False):
"""Perform action when job is completed."""
if self.subqueue:
self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads)
self.subqueue.release()
class ModelDownloadQueue(DownloadQueue):
"""Subclass of DownloadQueue, able to retrieve metadata from HuggingFace and Civitai."""
def create_download_job(
self,
source: Union[str, Path, AnyHttpUrl],
destdir: Path,
start: bool = True,
priority: int = 10,
filename: Optional[Path] = None,
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handlers: List[DownloadEventHandler] = [],
) -> DownloadJobBase:
"""Create a download job and return its ID."""
cls: Optional[Type[DownloadJobBase]] = None
kwargs: Dict[str, Optional[str]] = dict()
if re.match(HTTP_RE, str(source)):
cls = DownloadJobWithMetadata
kwargs.update(access_token=access_token)
elif re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
cls = DownloadJobRepoID
kwargs.update(
variant=variant,
access_token=access_token,
)
if cls:
job = cls(
source=source,
destination=Path(destdir) / (filename or "."),
event_handlers=event_handlers,
priority=priority,
**kwargs,
)
return self.submit_download_job(job, start)
else:
return super().create_download_job(
source=source,
destdir=destdir,
start=start,
priority=priority,
filename=filename,
variant=variant,
access_token=access_token,
event_handlers=event_handlers,
)
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
"""Based on the job type select the download method."""
if isinstance(job, DownloadJobRepoID):
return self._download_repoid
elif isinstance(job, DownloadJobWithMetadata):
return self._download_with_resume
else:
return super().select_downloader(job)
def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
"""
Fetch metadata from certain well-known URLs.
The metadata, if found, will be stashed in job.metadata.
Return the download URL.
"""
assert isinstance(job, DownloadJobWithMetadata)
metadata = job.metadata
url = job.source
metadata_url = url
model = None
# a Civitai download URL
if match := re.match(CIVITAI_MODEL_DOWNLOAD, str(metadata_url)):
version = match.group(1)
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
metadata.description = metadata.description or (
f"Trigger terms: {(', ').join(resp['trainedWords'])}" if resp["trainedWords"] else resp["description"]
)
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}"
# a Civitai model page with the version
if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, str(metadata_url)):
model = match.group(1)
version = int(match.group(2))
# and without
elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", str(metadata_url)):
model = match.group(1)
version = None
if not model:
return parse_obj_as(AnyHttpUrl, url)
if model:
resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
metadata.author = metadata.author or resp["creator"]["username"]
metadata.tags = metadata.tags or resp["tags"]
metadata.license = (
metadata.license
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
)
if version:
versions = [x for x in resp["modelVersions"] if int(x["id"]) == version]
version_data = versions[0]
else:
version_data = resp["modelVersions"][0] # first one
metadata.thumbnail_url = version_data.get("url") or metadata.thumbnail_url
metadata.description = metadata.description or (
f"Trigger terms: {(', ').join(version_data.get('trainedWords'))}"
if version_data.get("trainedWords")
else version_data.get("description")
)
download_url = version_data["downloadUrl"]
# return the download url
return download_url
def _download_repoid(self, job: DownloadJobBase) -> None:
"""Download a job that holds a huggingface repoid."""
def subdownload_event(subjob: DownloadJobBase):
assert isinstance(subjob, DownloadJobRemoteSource)
assert isinstance(job, DownloadJobRemoteSource)
if job.status != DownloadJobStatus.RUNNING: # do not update if we are cancelled or paused
return
if subjob.status == DownloadJobStatus.RUNNING:
bytes_downloaded[subjob.id] = subjob.bytes
job.bytes = sum(bytes_downloaded.values())
self._update_job_status(job, DownloadJobStatus.RUNNING)
return
if subjob.status == DownloadJobStatus.ERROR:
job.error = subjob.error
job.cleanup()
self._update_job_status(job, DownloadJobStatus.ERROR)
return
if subjob.status == DownloadJobStatus.COMPLETED:
bytes_downloaded[subjob.id] = subjob.bytes
job.bytes = sum(bytes_downloaded.values())
self._update_job_status(job, DownloadJobStatus.RUNNING)
return
assert isinstance(job, DownloadJobRepoID)
self._lock.acquire() # prevent status from being updated while we are setting up subqueue
self._update_job_status(job, DownloadJobStatus.RUNNING)
try:
job.subqueue = self.__class__(
event_handlers=[subdownload_event],
requests_session=self._requests,
quiet=True,
)
repo_id = job.source
variant = job.variant
if not job.metadata:
job.metadata = ModelSourceMetadata()
urls_to_download = self._get_repo_info(
repo_id, variant=variant, metadata=job.metadata, subfolder=job.subfolder
)
if job.destination.name != Path(repo_id).name:
job.destination = job.destination / Path(repo_id).name
bytes_downloaded: Dict[int, int] = dict()
job.total_bytes = 0
for url, subdir, file, size in urls_to_download:
job.total_bytes += size
job.subqueue.create_download_job(
source=url,
destdir=job.destination / subdir,
filename=file,
variant=variant,
access_token=job.access_token,
)
except KeyboardInterrupt as excp:
raise excp
except Exception as excp:
job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR)
self._logger.error(job.error)
finally:
self._lock.release()
if job.subqueue is not None:
job.subqueue.join()
if job.status == DownloadJobStatus.RUNNING:
self._update_job_status(job, DownloadJobStatus.COMPLETED)
def _get_repo_info(
self,
repo_id: str,
metadata: ModelSourceMetadata,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
"""
Given a repo_id and an optional variant, return list of URLs to download to get the model.
The metadata field will be updated with model metadata from HuggingFace.
Known variants currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
sibs = model_info.siblings
paths = []
# unfortunately the HF repo contains both files needed for the model
# as well as anything else the owner thought to include in the directory,
# including checkpoint files, different EMA versions, etc.
# This filters out just the file types needed for the model
for x in sibs:
if x.rfilename.endswith((".json", ".txt")):
paths.append(x.rfilename)
elif x.rfilename.endswith(("learned_embeds.bin", "ip_adapter.bin")):
paths.append(x.rfilename)
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin)$", x.rfilename):
paths.append(x.rfilename)
sizes = {x.rfilename: x.size for x in sibs}
prefix = ""
if subfolder:
prefix = f"{subfolder}/"
paths = [x for x in paths if x.startswith(prefix)]
if f"{prefix}model_index.json" in paths:
url = hf_hub_url(repo_id, filename="model_index.json", subfolder=subfolder)
resp = self._requests.get(url)
resp.raise_for_status() # will raise an HTTPError on non-200 status
submodels = resp.json()
paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels]
paths.insert(0, f"{prefix}model_index.json")
urls = [
(
hf_hub_url(repo_id, filename=x.as_posix()),
x.parent.relative_to(prefix) or Path("."),
Path(x.name),
sizes[x.as_posix()],
)
for x in self._select_variants(paths, variant)
]
if hasattr(model_info, "cardData"):
metadata.license = metadata.license or model_info.cardData.get("license")
metadata.tags = metadata.tags or model_info.tags
metadata.author = metadata.author or model_info.author
return urls
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
result = set()
basenames: Dict[Path, Path] = dict()
for p in paths:
path = Path(p)
if path.suffix == ".onnx":
if variant == "onnx":
result.add(path)
elif path.name.startswith("openvino_model"):
if variant == "openvino":
result.add(path)
elif path.suffix in [".json", ".txt"]:
result.add(path)
elif path.suffix in [".bin", ".safetensors", ".pt"] and variant in ["fp16", None]:
parent = path.parent
suffixes = path.suffixes
if len(suffixes) == 2:
file_variant, suffix = suffixes
basename = parent / Path(path.stem).stem
else:
file_variant = None
suffix = suffixes[0]
basename = parent / path.stem
if previous := basenames.get(basename):
if previous.suffix != ".safetensors" and suffix == ".safetensors":
basenames[basename] = path
if file_variant == f".{variant}":
basenames[basename] = path
elif not variant and not file_variant:
basenames[basename] = path
else:
basenames[basename] = path
else:
continue
for v in basenames.values():
result.add(v)
return result
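For comparison with the plain queue, here is a sketch of the repo_id path through ModelDownloadQueue. The import path and the repo_id are placeholders; what the code above does guarantee is that the returned job carries a ModelSourceMetadata object, filled in from the HuggingFace or Civitai metadata endpoints.

from pathlib import Path

from invokeai.backend.model_manager.download import ModelDownloadQueue  # assumed import path

queue = ModelDownloadQueue()
job = queue.create_download_job(
    source="stabilityai/sd-vae-ft-mse",  # placeholder repo_id; "owner/name:subfolder" is also accepted
    destdir=Path("/tmp/models"),
    variant="fp16",                      # one of: onnx, openvino, fp16, or None
)
queue.join()
queue.release()

# Metadata scraped from the source site rides along on the job.
print(job.metadata.author, job.metadata.license, job.metadata.tags)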

View File

@ -0,0 +1,432 @@
# Copyright (c) 2023, Lincoln D. Stein
"""Implementation of multithreaded download queue for invokeai."""
import os
import re
import shutil
import threading
import time
import traceback
from pathlib import Path
from queue import PriorityQueue
from typing import Callable, Dict, List, Optional, Set, Union
import requests
from pydantic import Field
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from invokeai.backend.util import InvokeAILogger, Logger
from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase, UnknownJobIDException
# Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000
# marker that the queue is done and that thread should exit
STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")
# regular expression for picking up a URL
HTTP_RE = r"^https?://"
class DownloadJobPath(DownloadJobBase):
"""Download from a local Path."""
source: Path = Field(description="Local filesystem Path where model can be found")
class DownloadJobRemoteSource(DownloadJobBase):
"""A DownloadJob from a remote source that provides progress info."""
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total bytes to download")
access_token: Optional[str] = Field(description="access token needed to access this resource")
class DownloadJobURL(DownloadJobRemoteSource):
"""Job declaration for downloading individual URLs."""
source: AnyHttpUrl = Field(description="URL to download")
class DownloadQueue(DownloadQueueBase):
"""Class for queued download of models."""
_jobs: Dict[int, DownloadJobBase]
_worker_pool: Set[threading.Thread]
_queue: PriorityQueue
_lock: threading.RLock # to allow for reentrant locking for method calls
_logger: Logger
_event_handlers: List[DownloadEventHandler] = Field(default_factory=list)
_next_job_id: int = 0
_sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order
_requests: requests.sessions.Session
_quiet: bool = False
def __init__(
self,
max_parallel_dl: int = 5,
event_handlers: List[DownloadEventHandler] = [],
requests_session: Optional[requests.sessions.Session] = None,
quiet: bool = False,
):
"""
Initialize DownloadQueue.
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param event_handlers: Optional callables that will be called each time a job's status changes.
:param requests_session: Optional requests.sessions.Session object, for unit tests.
:param quiet: If true, don't log the start of download jobs. Useful for subrequests.
"""
self._jobs = dict()
self._next_job_id = 0
self._queue = PriorityQueue()
self._worker_pool = set()
self._lock = threading.RLock()
self._logger = InvokeAILogger.get_logger()
self._event_handlers = event_handlers
self._requests = requests_session or requests.Session()
self._quiet = quiet
self._start_workers(max_parallel_dl)
def create_download_job(
self,
source: Union[str, Path, AnyHttpUrl],
destdir: Path,
start: bool = True,
priority: int = 10,
filename: Optional[Path] = None,
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handlers: List[DownloadEventHandler] = [],
) -> DownloadJobBase:
"""Create a download job and return its ID."""
kwargs: Dict[str, Optional[str]] = dict()
cls = DownloadJobBase
if Path(source).exists():
cls = DownloadJobPath
elif re.match(HTTP_RE, str(source)):
cls = DownloadJobURL
kwargs.update(access_token=access_token)
else:
raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
job = cls(
source=source,
destination=Path(destdir) / (filename or "."),
event_handlers=event_handlers,
priority=priority,
**kwargs,
)
return self.submit_download_job(job, start)
def submit_download_job(
self,
job: DownloadJobBase,
start: Optional[bool] = True,
):
"""Submit a job."""
# add the queue's handlers
for handler in self._event_handlers:
job.add_event_handler(handler)
with self._lock:
job.id = self._next_job_id
self._jobs[job.id] = job
self._next_job_id += 1
if start:
self.start_job(job)
return job
def release(self):
"""Signal our threads to exit when queue done."""
for thread in self._worker_pool:
if thread.is_alive():
self._queue.put(STOP_JOB)
def join(self):
"""Wait for all jobs to complete."""
self._queue.join()
def list_jobs(self) -> List[DownloadJobBase]:
"""List all the jobs."""
return list(self._jobs.values())
def prune_jobs(self):
"""Prune completed and errored queue items from the job list."""
with self._lock:
to_delete = set()
try:
for job_id, job in self._jobs.items():
if self._in_terminal_state(job):
to_delete.add(job_id)
for job_id in to_delete:
del self._jobs[job_id]
except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def id_to_job(self, id: int) -> DownloadJobBase:
"""Translate a job ID into a DownloadJobBase object."""
try:
return self._jobs[id]
except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def start_job(self, job: DownloadJobBase):
"""Enqueue (start) the indicated job."""
with self._lock:
try:
assert isinstance(self._jobs[job.id], DownloadJobBase)
self._update_job_status(job, DownloadJobStatus.ENQUEUED)
self._queue.put(job)
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def pause_job(self, job: DownloadJobBase):
"""
Pause (dequeue) the indicated job.
The job can be restarted with start_job() and the download will pick up
from where it left off.
"""
with self._lock:
try:
assert isinstance(self._jobs[job.id], DownloadJobBase)
self._update_job_status(job, DownloadJobStatus.PAUSED)
job.cleanup()
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
"""
Cancel the indicated job.
If it is running it will be stopped.
job.status will be set to DownloadJobStatus.CANCELLED
"""
with self._lock:
try:
assert isinstance(self._jobs[job.id], DownloadJobBase)
job.preserve_partial_downloads = preserve_partial
self._update_job_status(job, DownloadJobStatus.CANCELLED)
job.cleanup()
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def start_all_jobs(self):
"""Start (enqueue) all jobs that are idle or paused."""
with self._lock:
for job in self._jobs.values():
if job.status in [DownloadJobStatus.IDLE, DownloadJobStatus.PAUSED]:
self.start_job(job)
def pause_all_jobs(self):
"""Pause all running jobs."""
with self._lock:
for job in self._jobs.values():
if not self._in_terminal_state(job):
self.pause_job(job)
def cancel_all_jobs(self, preserve_partial: bool = False):
"""Cancel all jobs (those not in enqueued, running or paused state)."""
with self._lock:
for job in self._jobs.values():
if not self._in_terminal_state(job):
self.cancel_job(job, preserve_partial)
def _in_terminal_state(self, job: DownloadJobBase):
return job.status in [
DownloadJobStatus.COMPLETED,
DownloadJobStatus.ERROR,
DownloadJobStatus.CANCELLED,
]
def _start_workers(self, max_workers: int):
"""Start the requested number of worker threads."""
for i in range(0, max_workers):
worker = threading.Thread(target=self._download_next_item, daemon=True)
worker.start()
self._worker_pool.add(worker)
def _download_next_item(self):
"""Worker thread gets next job on priority queue."""
done = False
while not done:
job = self._queue.get()
with self._lock:
job.job_sequence = self._sequence
self._sequence += 1
try:
if job == STOP_JOB: # marker that queue is done
done = True
if job.status == DownloadJobStatus.ENQUEUED:
if not self._quiet:
self._logger.info(f"{job.source}: Downloading to {job.destination}")
do_download = self.select_downloader(job)
do_download(job)
if job.status == DownloadJobStatus.CANCELLED:
self._cleanup_cancelled_job(job)
finally:
self._queue.task_done()
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
"""Based on the job type select the download method."""
if isinstance(job, DownloadJobURL):
return self._download_with_resume
elif isinstance(job, DownloadJobPath):
return self._download_path
else:
raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}")
def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
return job.source
def _download_with_resume(self, job: DownloadJobBase):
"""Do the actual download."""
dest = None
try:
assert isinstance(job, DownloadJobRemoteSource)
url = self.get_url_for_job(job)
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
exist_size = 0
resp = self._requests.get(url, headers=header, stream=True)
content_length = int(resp.headers.get("content-length", 0))
job.total_bytes = content_length
if job.destination.is_dir():
try:
file_name = ""
if match := re.search('filename="(.+)"', resp.headers["Content-Disposition"]):
file_name = match.group(1)
assert file_name != ""
self._validate_filename(
job.destination.as_posix(), file_name
) # will raise a ValueError exception if file_name is suspicious
except ValueError:
self._logger.warning(
f"Invalid filename '{file_name}' returned by source {url}, using last component of URL instead"
)
file_name = os.path.basename(url)
except (KeyError, AssertionError):
file_name = os.path.basename(url)
job.destination = job.destination / file_name
dest = job.destination
else:
dest = job.destination
dest.parent.mkdir(parents=True, exist_ok=True)
if dest.exists():
exist_size = dest.stat().st_size
job.bytes = exist_size
header["Range"] = f"bytes={job.bytes}-"
open_mode = "ab"
resp = self._requests.get(url, headers=header, stream=True) # new request with range
if exist_size > content_length:
self._logger.warning("corrupt existing file found. re-downloading")
os.remove(dest)
exist_size = 0
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
self._logger.warning(f"{dest}: complete file found. Skipping.")
self._update_job_status(job, DownloadJobStatus.COMPLETED)
return
if resp.status_code == 206 or exist_size > 0:
self._logger.warning(f"{dest}: partial file found. Resuming")
elif resp.status_code != 200:
raise HTTPError(resp.reason)
else:
self._logger.debug(f"{job.source}: Downloading {job.destination}")
report_delta = job.total_bytes / 100 # report every 1% change
last_report_bytes = 0
self._update_job_status(job, DownloadJobStatus.RUNNING)
with open(dest, open_mode) as file:
for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
if job.status != DownloadJobStatus.RUNNING: # cancelled, paused or errored
return
job.bytes += file.write(data)
if job.bytes - last_report_bytes >= report_delta:
last_report_bytes = job.bytes
self._update_job_status(job)
if job.status != DownloadJobStatus.RUNNING: # cancelled, paused or errored
return
self._update_job_status(job, DownloadJobStatus.COMPLETED)
except KeyboardInterrupt as excp:
raise excp
except (HTTPError, OSError) as excp:
self._logger.error(f"An error occurred while downloading/installing {job.source}: {str(excp)}")
print(traceback.format_exc())
job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR)
def _validate_filename(self, directory: str, filename: str):
pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260
if "/" in filename:
raise ValueError
if filename.startswith(".."):
raise ValueError
if len(filename) > pc_name_max:
raise ValueError
pc_path_max = os.pathconf(directory, "PC_PATH_MAX") if hasattr(os, "pathconf") else 260  # mirror the PC_NAME_MAX fallback above
if len(os.path.join(directory, filename)) > pc_path_max:
raise ValueError
def _update_job_status(self, job: DownloadJobBase, new_status: Optional[DownloadJobStatus] = None):
"""Optionally change the job status and send an event indicating a change of state."""
with self._lock:
if new_status:
job.status = new_status
if self._in_terminal_state(job) and not self._quiet:
self._logger.info(f"{job.source}: Download job completed with status {job.status.value}")
if new_status == DownloadJobStatus.RUNNING and not job.job_started:
job.job_started = time.time()
elif new_status in [DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR]:
job.job_ended = time.time()
if job.event_handlers:
for handler in job.event_handlers:
try:
handler(job)
except KeyboardInterrupt as excp:
raise excp
except Exception as excp:
job.error = excp
if job.status != DownloadJobStatus.ERROR: # let handlers know, but don't cause infinite recursion
self._update_job_status(job, DownloadJobStatus.ERROR)
def _download_path(self, job: DownloadJobBase):
"""Call when the source is a Path or pathlike object."""
source = Path(job.source).resolve()
destination = Path(job.destination).resolve()
try:
self._update_job_status(job, DownloadJobStatus.RUNNING)
if source != destination:
shutil.move(source, destination)
self._update_job_status(job, DownloadJobStatus.COMPLETED)
except OSError as excp:
job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR)
def _cleanup_cancelled_job(self, job: DownloadJobBase):
job.cleanup(job.preserve_partial_downloads)
if not job.preserve_partial_downloads:
self._logger.warning(f"Cleaning up leftover files from cancelled download job {job.destination}")
dest = Path(job.destination)
try:
if dest.is_file():
dest.unlink()
elif dest.is_dir():
shutil.rmtree(dest.as_posix(), ignore_errors=True)
except OSError as excp:
self._logger.error(excp)
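The heart of the implementation above is _download_with_resume, which uses an HTTP Range header to pick up a partial file where it left off. Stripped of the queue, locking and event machinery, the strategy looks roughly like this standalone sketch (the URL is hypothetical):

import os

import requests


def resume_download(url: str, dest: str, chunk_size: int = 100_000) -> None:
    headers = {}
    mode = "wb"
    if os.path.exists(dest):
        # Ask the server to start where the partial file left off.
        headers["Range"] = f"bytes={os.path.getsize(dest)}-"
        mode = "ab"
    resp = requests.get(url, headers=headers, stream=True)
    if resp.status_code == 416:  # range not satisfiable: the file is already complete
        return
    resp.raise_for_status()
    with open(dest, mode) as f:
        for chunk in resp.iter_content(chunk_size=chunk_size):
            f.write(chunk)


resume_download("https://example.com/big-model.safetensors", "/tmp/big-model.safetensors")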

View File

@ -0,0 +1,68 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.
Usage:
from invokeai.backend.model_manager.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import hashlib
import os
from pathlib import Path
from typing import Dict, Union
from imohash import hashfile
from .models import InvalidModelException
class FastModelHash(object):
"""FastModelHash obect provides one public class method, hash()."""
@classmethod
def hash(cls, model_location: Union[str, Path]) -> str:
"""
Return hexdigest string for model located at model_location.
:param model_location: Path to the model
"""
model_location = Path(model_location)
if model_location.is_file():
return cls._hash_file(model_location)
elif model_location.is_dir():
return cls._hash_dir(model_location)
else:
raise InvalidModelException(f"Not a valid file or directory: {model_location}")
@classmethod
def _hash_file(cls, model_location: Union[str, Path]) -> str:
"""
Fasthash a single file and return its hexdigest.
:param model_location: Path to the model file
"""
# we return md5 hash of the filehash to make it shorter
# cryptographic security not needed here
return hashlib.md5(hashfile(model_location)).hexdigest()
@classmethod
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
components: Dict[str, str] = {}
for root, dirs, files in os.walk(model_location):
for file in files:
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path)
components.update({path: fast_hash})
# hash all the model hashes together, using alphabetic file order
md5 = hashlib.md5()
for path, fast_hash in sorted(components.items()):
md5.update(fast_hash.encode("utf-8"))
return md5.hexdigest()
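Because only tensor files are tallied, edits to the surrounding diffusers config files do not change a directory's digest. A quick sanity check (the path is a placeholder, and the import mirrors the module's own docstring):

from invokeai.backend.model_manager.model_hash import FastModelHash

before = FastModelHash.hash("/path/to/stable-diffusion-v1.5")
# ... hand-edit model_index.json or another .json/.txt file in that directory ...
after = FastModelHash.hash("/path/to/stable-diffusion-v1.5")
assert before == after  # only .ckpt/.safetensors/.bin/.pt/.pth files contribute to the hash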

View File

@ -0,0 +1,250 @@
# Copyright (c) 2023, Lincoln D. Stein
"""Model loader for InvokeAI."""
import hashlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from shutil import move, rmtree
from typing import Optional, Tuple, Union
import torch
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.backend.util import InvokeAILogger, Logger, choose_precision, choose_torch_device
from .cache import CacheStats, ModelCache
from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType
from .models import MODEL_CLASSES, InvalidModelException, ModelBase
from .storage import ModelConfigStore
@dataclass
class ModelInfo:
"""This is a context manager object that is used to intermediate access to a model."""
context: ModelCache.ModelLocker
name: str
base_model: BaseModelType
type: Union[ModelType, SubModelType]
key: str
location: Union[Path, str]
precision: torch.dtype
_cache: Optional[ModelCache] = None
def __enter__(self):
"""Context entry."""
return self.context.__enter__()
def __exit__(self, *args, **kwargs):
"""Context exit."""
self.context.__exit__(*args, **kwargs)
class ModelLoadBase(ABC):
"""Abstract base class for a model loader which works with the ModelConfigStore backend."""
@abstractmethod
def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo:
"""
Return a model given its key.
Given a model key identified in the model configuration backend,
return a ModelInfo object that can be used to retrieve the model.
:param key: model key, as known to the config backend
:param submodel_type: a SubModelType enum indicating the portion of
the model to retrieve (e.g. SubModelType.Vae)
"""
pass
@property
@abstractmethod
def store(self) -> ModelConfigStore:
"""Return the ModelConfigStore object that supports this loader."""
pass
@property
@abstractmethod
def logger(self) -> Logger:
"""Return the current logger."""
pass
@property
@abstractmethod
def config(self) -> InvokeAIAppConfig:
"""Return the config object used by the loader."""
pass
@abstractmethod
def collect_cache_stats(self, cache_stats: CacheStats):
"""Replace cache statistics."""
pass
@abstractmethod
def resolve_model_path(self, path: Union[Path, str]) -> Path:
"""Turn a potentially relative path into an absolute one in the models_dir."""
pass
@property
@abstractmethod
def precision(self) -> torch.dtype:
"""Return torch.fp16 or torch.fp32."""
pass
class ModelLoad(ModelLoadBase):
"""Implementation of ModelLoadBase."""
_app_config: InvokeAIAppConfig
_store: ModelConfigStore
_cache: ModelCache
_logger: Logger
_cache_keys: dict
def __init__(
self,
config: InvokeAIAppConfig,
store: Optional[ModelConfigStore] = None,
):
"""
Initialize ModelLoad object.
:param config: The app's InvokeAIAppConfig object.
"""
self._app_config = config
self._store = store or ModelRecordServiceBase.open(config)
self._logger = InvokeAILogger.get_logger()
self._cache_keys = dict()
device = torch.device(choose_torch_device())
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
precision = choose_precision(device) if config.precision == "auto" else config.precision
dtype = torch.float32 if precision == "float32" else torch.float16
self._logger.info(f"Rendering device = {device} ({device_name})")
self._logger.info(f"Maximum RAM cache size: {config.ram}")
self._logger.info(f"Maximum VRAM cache size: {config.vram}")
self._logger.info(f"Precision: {precision}")
self._cache = ModelCache(
max_cache_size=config.ram,
max_vram_cache_size=config.vram,
lazy_offloading=config.lazy_offload,
execution_device=device,
precision=dtype,
logger=self._logger,
)
@property
def store(self) -> ModelConfigStore:
"""Return the ModelConfigStore instance used by this class."""
return self._store
@property
def precision(self) -> torch.dtype:
"""Return torch.fp16 or torch.fp32."""
return self._cache.precision
@property
def logger(self) -> Logger:
"""Return the current logger."""
return self._logger
@property
def config(self) -> InvokeAIAppConfig:
"""Return the config object."""
return self._app_config
def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo:
"""
Get the ModelInfo corresponding to the model with key "key".
Given a model key identified in the model configuration backend,
return a ModelInfo object that can be used to retrieve the model.
:param key: model key, as known to the config backend
:param submodel_type: a SubModelType enum indicating the portion of
the model to retrieve (e.g. SubModelType.Vae)
"""
model_config = self.store.get_model(key) # May raise an UnknownModelException
if model_config.model_type == "main" and not submodel_type:
raise InvalidModelException("submodel_type is required when loading a main model")
submodel_type = SubModelType(submodel_type) if submodel_type else None
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
if is_submodel_override:
submodel_type = None
model_class = self._get_implementation(model_config.base_model, model_config.model_type)
if not model_path.exists():
raise InvalidModelException(f"Files for model '{key}' not found at {model_path}")
dst_convert_path = self._get_model_convert_cache_path(model_path)
model_path = self.resolve_model_path(
model_class.convert_if_required(
model_config=model_config,
output_path=dst_convert_path,
)
)
model_context = self._cache.get_model(
model_path=model_path,
model_class=model_class,
base_model=model_config.base_model,
model_type=model_config.model_type,
submodel=submodel_type,
)
if key not in self._cache_keys:
self._cache_keys[key] = set()
self._cache_keys[key].add(model_context.key)
return ModelInfo(
context=model_context,
name=model_config.name,
base_model=model_config.base_model,
type=submodel_type or model_config.model_type,
key=model_config.key,
location=model_path,
precision=self._cache.precision,
_cache=self._cache,
)
def collect_cache_stats(self, cache_stats: CacheStats):
"""Save CacheStats object for stats collecting."""
self._cache.stats = cache_stats
def resolve_model_path(self, path: Union[Path, str]) -> Path:
"""Turn a potentially relative path into an absolute one in the models_dir."""
return self._app_config.models_path / path
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
"""Get the concrete implementation class for a specific model type."""
model_class = MODEL_CLASSES[base_model][model_type]
return model_class
def _get_model_convert_cache_path(self, model_path):
return self.resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest())
def _get_model_path(
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, bool]:
"""Extract a model's filesystem path from its config.
:return: The fully qualified Path of the module (or submodule).
"""
model_path = Path(model_config.path)
is_submodel_override = False
# Does the config explicitly override the submodel?
if submodel_type is not None and hasattr(model_config, submodel_type):
submodel_path = getattr(model_config, submodel_type)
if submodel_path is not None and len(submodel_path) > 0:
model_path = getattr(model_config, submodel_type)
is_submodel_override = True
model_path = self.resolve_model_path(model_path)
return model_path, is_submodel_override
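A sketch of the loader in use follows. The ModelLoad and SubModelType import paths and the model key are assumptions made for illustration; the loader's contract is simply that get_model() returns a ModelInfo whose context manager locks the model into the RAM/VRAM cache for the duration of the block.

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.loader import ModelLoad      # assumed import path
from invokeai.backend.model_manager.config import SubModelType   # assumed import path

config = InvokeAIAppConfig.get_config()
loader = ModelLoad(config)  # opens the default model record store when none is passed

# "some-model-key" is a placeholder for the key of an installed main model.
model_info = loader.get_model("some-model-key", submodel_type=SubModelType.Vae)
with model_info as vae:
    # vae is the loaded torch module, held in the cache while the block is active
    print(type(vae), model_info.location, model_info.precision)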

View File

@ -12,7 +12,7 @@ from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer
from .models.lora import LoRAModel
from .models.lora import LoRALayerBase, LoRAModel, LoRAModelRaw
"""
loras = [
@ -87,7 +87,7 @@ class ModelPatcher:
def apply_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
loras: List[Tuple[LoRAModelRaw, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@ -97,7 +97,7 @@ class ModelPatcher:
def apply_sdxl_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
loras: List[Tuple[LoRAModelRaw, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@ -107,7 +107,7 @@ class ModelPatcher:
def apply_sdxl_lora_text_encoder2(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
loras: List[Tuple[LoRAModelRaw, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield
@ -117,7 +117,7 @@ class ModelPatcher:
def apply_lora(
cls,
model: torch.nn.Module,
loras: List[Tuple[LoRAModel, float]],
loras: List[Tuple[LoRAModelRaw, float]],
prefix: str,
):
original_weights = dict()
@ -337,7 +337,7 @@ class ONNXModelPatcher:
def apply_lora(
cls,
model: IAIOnnxRuntimeModel,
loras: List[Tuple[LoRAModel, float]],
loras: List[Tuple[LoRAModelRaw, torch.Tensor]],
prefix: str,
):
from .models.base import IAIOnnxRuntimeModel
@ -348,7 +348,7 @@ class ONNXModelPatcher:
orig_weights = dict()
try:
blended_loras = dict()
blended_loras: Dict[str, torch.Tensor] = dict()
for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():

View File

@ -4,7 +4,7 @@ from typing import Optional
import psutil
import torch
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
from .libc_util import LibcUtil, Struct_mallinfo2
GB = 2**30 # 1 GB

View File

@ -1,5 +1,5 @@
"""
invokeai.backend.model_management.model_merge exports:
invokeai.backend.model_manager.merge exports:
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
@ -9,14 +9,17 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
import warnings
from enum import Enum
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Optional, Set
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_install_service import ModelInstallService
from ...backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType
from . import BaseModelType, ModelConfigBase, ModelConfigStore, ModelType
from .config import MainConfig
class MergeInterpolationMethod(str, Enum):
@ -27,8 +30,18 @@ class MergeInterpolationMethod(str, Enum):
class ModelMerger(object):
def __init__(self, manager: ModelManager):
self.manager = manager
_store: ModelConfigStore
_config: InvokeAIAppConfig
def __init__(self, store: ModelConfigStore, config: Optional[InvokeAIAppConfig] = None):
"""
Initialize a ModelMerger object.
:param store: Underlying storage manager for the running process.
:param config: InvokeAIAppConfig object (if not provided, default will be selected).
"""
self._store = store
self._config = config or InvokeAIAppConfig.get_config()
def merge_diffusion_models(
self,
@ -70,15 +83,14 @@ class ModelMerger(object):
def merge_diffusion_models_and_save(
self,
model_names: List[str],
base_model: Union[BaseModelType, str],
model_keys: List[str],
merged_model_name: str,
alpha: float = 0.5,
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
**kwargs,
) -> AddModelResult:
) -> ModelConfigBase:
"""
:param models: up to three models, designated by their InvokeAI models.yaml model name
:param base_model: base model (must be the same for all merged models!)
@ -92,25 +104,38 @@ class ModelMerger(object):
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_paths = list()
config = self.manager.app_config
base_model = BaseModelType(base_model)
model_paths: List[Path] = list()
model_names = list()
config = self._config
store = self._store
base_models: Set[BaseModelType] = set()
vae = None
for mod in model_names:
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert (
len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference
), "When merging three models, only the 'add_difference' merge method is supported"
for key in model_keys:
info = store.get_model(key)
assert isinstance(info, MainConfig)
model_names.append(info.name)
assert (
info["model_format"] == "diffusers"
), f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
info.model_format == "diffusers"
), f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
assert (
len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference
), "When merging three models, only the 'add_difference' merge method is supported"
info.variant == "normal"
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
# pick up the first model's vae
if mod == model_names[0]:
vae = info.get("vae")
model_paths.extend([(config.root_path / info["path"]).as_posix()])
if key == model_keys[0]:
vae = info.vae
# tally base models used
base_models.add(info.base_model)
model_paths.extend([(config.models_path / info.path).as_posix()])
assert len(base_models) == 1, f"All models to merge must have same base model, but found bases {base_models}"
base_model = base_models.pop()
merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp)
logger.debug(f"interp = {interp}, merge_method={merge_method}")
@ -124,17 +149,19 @@ class ModelMerger(object):
dump_path = (dump_path / merged_model_name).as_posix()
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
attributes = dict(
path=dump_path,
description=f"Merge of models {', '.join(model_names)}",
model_format="diffusers",
variant=ModelVariantType.Normal.value,
vae=vae,
)
return self.manager.add_model(
merged_model_name,
base_model=base_model,
model_type=ModelType.Main,
model_attributes=attributes,
clobber=True,
# register model and get its unique key
installer = ModelInstallService(store=self._store, config=self._config)
key = installer.register_path(dump_path)
# update model's config
model_config = self._store.get_model(key)
model_config.update(
dict(
name=merged_model_name,
description=f"Merge of models {', '.join(model_names)}",
vae=vae,
)
)
self._store.update_model(key, model_config)
return model_config
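Under the new interface, callers identify models by their record keys rather than by name plus base model. A sketch of a three-way merge follows (the keys are placeholders; per the assertion above, three-model merges require the add_difference method):

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger

config = InvokeAIAppConfig.get_config()
store = ModelRecordServiceBase.open(config)

merger = ModelMerger(store, config)
merged = merger.merge_diffusion_models_and_save(
    model_keys=["key-a", "key-b", "key-c"],         # placeholder keys of installed diffusers main models
    merged_model_name="my-merged-model",
    alpha=0.35,
    interp=MergeInterpolationMethod.AddDifference,  # required when merging three models
)
print(merged.key, merged.path)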

View File

@ -1,22 +1,20 @@
import inspect
from enum import Enum
from typing import Literal, get_origin
from typing import Any, Literal, get_origin
from pydantic import BaseModel
from .base import ( # noqa: F401
BaseModelType,
DuplicateModelException,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelError,
ModelNotFoundException,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SilenceWarnings,
SubModelType,
read_checkpoint_meta,
)
from .clip_vision import CLIPVisionModel
from .controlnet import ControlNetModel # TODO:
@ -97,14 +95,12 @@ MODEL_CLASSES = {
# },
}
MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list()
MODEL_CONFIGS: Any = list()
OPENAPI_MODEL_CONFIGS: Any = list()
class OpenAPIModelInfoBase(BaseModel):
model_name: str
base_model: BaseModelType
model_type: ModelType
key: str
for base_model, models in MODEL_CLASSES.items():

View File

@ -1,13 +1,14 @@
import inspect
import json
import os
import shutil
import sys
import typing
import warnings
from abc import ABCMeta, abstractmethod
from contextlib import suppress
from enum import Enum
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union
import numpy as np
@ -15,90 +16,40 @@ import onnx
import safetensors.torch
import torch
from diffusers import ConfigMixin, DiffusionPipeline
from diffusers import logging as diffusers_logging
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from picklescan.scanner import scan_file_path
from pydantic import BaseModel, Field
from transformers import logging as transformers_logging
from invokeai.backend.util import GIG, directory_size
from invokeai.backend.util.logging import InvokeAILogger
from ..config import ( # noqa F401
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
class DuplicateModelException(Exception):
class ModelNotFoundException(Exception):
"""Exception for when a model is not found on the expected path."""
pass
class InvalidModelException(Exception):
"""Exception for when a model is corrupted in some way; for example missing files."""
pass
class ModelNotFoundException(Exception):
pass
class BaseModelType(str, Enum):
Any = "any" # For models that are not associated with any particular base model.
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
# Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
ONNX = "onnx"
Main = "main"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
class SubModelType(str, Enum):
UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
# MoVQ = "movq"
class ModelVariantType(str, Enum):
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class SchedulerPredictionType(str, Enum):
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelError(str, Enum):
NotFound = "not_found"
class ModelConfigBase(BaseModel):
path: str # or Path
description: Optional[str] = Field(None)
model_format: Optional[str] = Field(None)
error: Optional[ModelError] = Field(None)
class Config:
use_enum_values = True
class EmptyConfigLoader(ConfigMixin):
@classmethod
def load_config(cls, *args, **kwargs):
"""Load empty configuration."""
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs)
@ -132,7 +83,7 @@ class ModelBase(metaclass=ABCMeta):
self.base_model = base_model
self.model_type = model_type
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
def _hf_definition_to_type(self, subtypes: List[str]) -> Optional[ModuleType]:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
if all(t is None for t in subtypes):
@ -231,6 +182,15 @@ class ModelBase(metaclass=ABCMeta):
) -> Any:
raise NotImplementedError()
@classmethod
@abstractmethod
def convert_if_required(
cls,
model_config: ModelConfigBase,
output_path: str,
) -> str:
raise NotImplementedError()
class DiffusersModel(ModelBase):
# child_types: Dict[str, Type]
@ -453,22 +413,6 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
return checkpoint
class SilenceWarnings(object):
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self):
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
def __exit__(self, type, value, traceback):
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter("default")
ONNX_WEIGHTS_NAME = "model.onnx"
@ -672,3 +616,34 @@ class IAIOnnxRuntimeModel:
# TODO: session options
return cls(model_path, provider=provider)
def trim_model_convert_cache(cache_path: Path, max_cache_size: int):
current_size = directory_size(cache_path)
logger = InvokeAILogger.get_logger()
if current_size <= max_cache_size:
return
logger.debug(
"Convert cache has gotten too large {(current_size / GIG):4.2f} > {(max_cache_size / GIG):4.2f}G.. Trimming."
)
# For this to work, we make the assumption that the directory contains
# either a 'unet/config.json' file, or a 'config.json' file at top level
def by_atime(path: Path) -> float:
for config in ["unet/config.json", "config.json"]:
sentinel = path / config
if sentinel.exists():
return sentinel.stat().st_atime
return 0.0
# sort by last access time - least recently accessed files will be at the end
lru_models = sorted(cache_path.iterdir(), key=by_atime, reverse=True)
logger.debug(f"cached models in descending atime order: {lru_models}")
while current_size > max_cache_size and len(lru_models) > 0:
next_victim = lru_models.pop()
victim_size = directory_size(next_victim)
logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB")
shutil.rmtree(next_victim)
current_size -= victim_size
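trim_model_convert_cache() is intended to be called before a new conversion is written, so the cache never grows past the configured limit. A sketch, mirroring the call made in the controlnet conversion path later in this diff (the import path is an assumption):

from pathlib import Path

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.models.base import GIG, trim_model_convert_cache  # assumed import path


def make_room_for(weights: Path, convert_cache: Path) -> None:
    # Free enough space in the conversion cache to hold one more converted model.
    app_config = InvokeAIAppConfig.get_config()
    max_cache_size = app_config.conversion_cache_size * GIG
    trim_model_convert_cache(convert_cache, max_cache_size - weights.stat().st_size)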

View File

@ -5,7 +5,7 @@ from typing import Literal, Optional
import torch
from transformers import CLIPVisionModelWithProjection
from invokeai.backend.model_management.models.base import (
from invokeai.backend.model_manager.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,

View File

@ -8,7 +8,9 @@ import torch
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from ..config import ControlNetCheckpointConfig, ControlNetDiffusersConfig
from .base import (
GIG,
BaseModelType,
EmptyConfigLoader,
InvalidModelException,
@ -32,12 +34,11 @@ class ControlNetModel(ModelBase):
# model_class: Type
# model_size: int
class DiffusersConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Diffusers]
class DiffusersConfig(ControlNetDiffusersConfig):
model_format: Literal[ControlNetModelFormat.Diffusers] = ControlNetModelFormat.Diffusers
class CheckpointConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Checkpoint]
config: str
class CheckpointConfig(ControlNetCheckpointConfig):
model_format: Literal[ControlNetModelFormat.Checkpoint] = ControlNetModelFormat.Checkpoint
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
@ -112,27 +113,22 @@ class ControlNetModel(ModelBase):
@classmethod
def convert_if_required(
cls,
model_path: str,
model_config: ModelConfigBase,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
if isinstance(model_config, ControlNetCheckpointConfig):
return _convert_controlnet_ckpt_and_cache(
model_path=model_path,
model_config=config.config,
model_config=model_config,
output_path=output_path,
base_model=base_model,
)
else:
return model_path
return model_config.path
def _convert_controlnet_ckpt_and_cache(
model_path: str,
model_config: ControlNetCheckpointConfig,
output_path: str,
base_model: BaseModelType,
model_config: ControlNetModel.CheckpointConfig,
max_cache_size: int,
) -> str:
"""
Convert the controlnet from checkpoint format to diffusers format,
@ -140,7 +136,7 @@ def _convert_controlnet_ckpt_and_cache(
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_path / model_path
weights = app_config.root_path / model_config.path
output_path = Path(output_path)
logger.info(f"Converting {weights} to diffusers format")
@ -148,6 +144,11 @@ def _convert_controlnet_ckpt_and_cache(
if output_path.exists():
return output_path
# make sufficient size in the cache folder
size_needed = weights.stat().st_size
max_cache_size = app_config.conversion_cache_size * GIG
trim_model_convert_cache(output_path.parent, max_cache_size - size_needed)
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers

View File

@ -1,12 +1,11 @@
import os
import typing
from enum import Enum
from typing import Literal, Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, build_ip_adapter
from invokeai.backend.model_management.models.base import (
from invokeai.backend.model_manager.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,
@ -17,15 +16,12 @@ from invokeai.backend.model_management.models.base import (
classproperty,
)
class IPAdapterModelFormat(str, Enum):
# The custom IP-Adapter model format defined by InvokeAI.
InvokeAI = "invokeai"
from ..config import ModelFormat
class IPAdapterModel(ModelBase):
class InvokeAIConfig(ModelConfigBase):
model_format: Literal[IPAdapterModelFormat.InvokeAI]
model_format: Literal[ModelFormat.InvokeAI]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.IPAdapter
@ -42,7 +38,7 @@ class IPAdapterModel(ModelBase):
model_file = os.path.join(path, "ip_adapter.bin")
image_encoder_config_file = os.path.join(path, "image_encoder.txt")
if os.path.exists(model_file) and os.path.exists(image_encoder_config_file):
return IPAdapterModelFormat.InvokeAI
return ModelFormat.InvokeAI
raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
@ -80,7 +76,7 @@ class IPAdapterModel(ModelBase):
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == IPAdapterModelFormat.InvokeAI:
if format == ModelFormat.InvokeAI:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")

View File

@ -2,11 +2,12 @@ import bisect
import os
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, Literal, Optional, Union
import torch
from safetensors.torch import load_file
from ..config import LoRAConfig
from .base import (
BaseModelType,
InvalidModelException,
@ -27,8 +28,8 @@ class LoRAModelFormat(str, Enum):
class LoRAModel(ModelBase):
# model_size: int
class Config(ModelConfigBase):
model_format: LoRAModelFormat # TODO:
class Config(LoRAConfig):
model_format: Literal[LoRAModelFormat.LyCORIS] # TODO:
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora
@ -80,16 +81,14 @@ class LoRAModel(ModelBase):
@classmethod
def convert_if_required(
cls,
model_path: str,
model_config: ModelConfigBase,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
if cls.detect_format(model_config.path) == LoRAModelFormat.Diffusers:
# TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported")
else:
return model_path
return model_config.path
class LoRALayerBase:

View File

@ -1,14 +1,13 @@
import json
import os
from enum import Enum
from typing import Literal, Optional
from typing import Literal
from omegaconf import OmegaConf
from pydantic import Field
from ..config import MainDiffusersConfig
from .base import (
BaseModelType,
DiffusersModel,
InvalidModelException,
ModelConfigBase,
ModelType,
@ -16,6 +15,7 @@ from .base import (
classproperty,
read_checkpoint_meta,
)
from .stable_diffusion import StableDiffusionModelBase
class StableDiffusionXLModelFormat(str, Enum):
@ -23,18 +23,13 @@ class StableDiffusionXLModelFormat(str, Enum):
Diffusers = "diffusers"
class StableDiffusionXLModel(DiffusersModel):
class StableDiffusionXLModel(StableDiffusionModelBase):
# TODO: check that configs are overwritten properly
class DiffusersConfig(ModelConfigBase):
class DiffusersConfig(MainDiffusersConfig):
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: str
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}
@ -104,26 +99,3 @@ class StableDiffusionXLModel(DiffusersModel):
return StableDiffusionXLModelFormat.Diffusers
else:
return StableDiffusionXLModelFormat.Checkpoint
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
# The convert script adapted from the diffusers package uses
# strings for the base model type. To avoid making too many
# source code changes, we simply translate here
if isinstance(config, cls.CheckpointConfig):
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
return _convert_ckpt_and_cache(
version=base_model,
model_config=config,
output_path=output_path,
use_safetensors=False, # corrupts sdxl models for some reason
)
else:
return model_path

View File

@ -2,7 +2,7 @@ import json
import os
from enum import Enum
from pathlib import Path
from typing import Literal, Optional, Union
from typing import Literal, Optional
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
from omegaconf import OmegaConf
@ -11,6 +11,8 @@ from pydantic import Field
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from ..cache import GIG
from ..config import MainCheckpointConfig, MainDiffusersConfig, SilenceWarnings
from .base import (
BaseModelType,
DiffusersModel,
@ -19,11 +21,10 @@ from .base import (
ModelNotFoundException,
ModelType,
ModelVariantType,
SilenceWarnings,
classproperty,
read_checkpoint_meta,
trim_model_convert_cache,
)
from .sdxl import StableDiffusionXLModel
class StableDiffusion1ModelFormat(str, Enum):
@ -31,17 +32,31 @@ class StableDiffusion1ModelFormat(str, Enum):
Diffusers = "diffusers"
class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class StableDiffusionModelBase(DiffusersModel):
"""Base class that defines common class methodsd."""
class CheckpointConfig(ModelConfigBase):
@classmethod
def convert_if_required(
cls,
model_config: ModelConfigBase,
output_path: str,
) -> str:
if isinstance(model_config, MainCheckpointConfig):
return _convert_ckpt_and_cache(
model_config=model_config,
output_path=output_path,
use_safetensors=False, # corrupts sdxl models for some reason
)
else:
return model_config.path
class StableDiffusion1Model(StableDiffusionModelBase):
class DiffusersConfig(MainDiffusersConfig):
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
class CheckpointConfig(MainCheckpointConfig):
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: str
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1
@ -115,31 +130,13 @@ class StableDiffusion1Model(DiffusersModel):
raise InvalidModelException(f"Not a valid model: {model_path}")
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion1,
model_config=config,
load_safety_checker=False,
output_path=output_path,
)
else:
return model_path
class StableDiffusion2ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion2Model(DiffusersModel):
class StableDiffusion2Model(StableDiffusionModelBase):
# TODO: check that configs are overwritten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
@ -226,33 +223,10 @@ class StableDiffusion2Model(DiffusersModel):
raise InvalidModelException(f"Not a valid model: {model_path}")
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2,
model_config=config,
output_path=output_path,
)
else:
return model_path
# TODO: rework
# pass precision - currently defaulting to fp16
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[
StableDiffusion1Model.CheckpointConfig,
StableDiffusion2Model.CheckpointConfig,
StableDiffusionXLModel.CheckpointConfig,
],
model_config: ModelConfigBase,
output_path: str,
use_save_model: bool = False,
**kwargs,
@ -263,17 +237,22 @@ def _convert_ckpt_and_cache(
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
version = model_config.base_model.value
weights = app_config.models_path / model_config.path
config_file = app_config.root_path / model_config.config
output_path = Path(output_path)
variant = model_config.variant
pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline
max_cache_size = app_config.conversion_cache_size * GIG
# return cached version if it exists
if output_path.exists():
return output_path
# make sufficient size in the cache folder
size_needed = weights.stat().st_size
trim_model_convert_cache(output_path.parent, max_cache_size - size_needed)
# to avoid circular import errors
from ...util.devices import choose_torch_device, torch_dtype
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers

View File

@ -3,6 +3,7 @@ from typing import Literal
from diffusers import OnnxRuntimeModel
from ..config import ONNXSD1Config, ONNXSD2Config
from .base import (
BaseModelType,
DiffusersModel,
@ -21,9 +22,8 @@ class StableDiffusionOnnxModelFormat(str, Enum):
class ONNXStableDiffusion1Model(DiffusersModel):
class Config(ModelConfigBase):
class Config(ONNXSD1Config):
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1
@ -72,19 +72,16 @@ class ONNXStableDiffusion1Model(DiffusersModel):
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
# config: ModelConfigBase, # not used?
# base_model: BaseModelType, # not used?
) -> str:
return model_path
class ONNXStableDiffusion2Model(DiffusersModel):
# TODO: check that configs are overwritten properly
class Config(ModelConfigBase):
class Config(ONNXSD2Config):
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion2

View File

@ -5,7 +5,7 @@ from typing import Literal, Optional
import torch
from diffusers import T2IAdapter
from invokeai.backend.model_management.models.base import (
from .base import (
BaseModelType,
EmptyConfigLoader,
InvalidModelException,

View File

@ -1,8 +1,10 @@
import os
from typing import Optional
from typing import Literal, Optional
import torch
from ..config import ModelFormat, TextualInversionConfig
# TODO: naming
from ..lora import TextualInversionModel as TextualInversionModelRaw
from .base import (
@ -20,8 +22,15 @@ from .base import (
class TextualInversionModel(ModelBase):
# model_size: int
class Config(ModelConfigBase):
model_format: None
class FolderConfig(TextualInversionConfig):
"""Config for embeddings that are represented as a folder containing learned_embeds.bin."""
model_format: Literal[ModelFormat.EmbeddingFolder]
class FileConfig(TextualInversionConfig):
"""Config for embeddings that are contained in safetensors/checkpoint files."""
model_format: Literal[ModelFormat.EmbeddingFile]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.TextualInversion
@ -79,9 +88,7 @@ class TextualInversionModel(ModelBase):
@classmethod
def convert_if_required(
cls,
model_path: str,
model_config: ModelConfigBase,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
return model_path
return model_config.path

View File

@ -1,7 +1,7 @@
import os
from enum import Enum
from pathlib import Path
from typing import Optional
from typing import Literal, Optional
import safetensors
import torch
@ -9,7 +9,9 @@ from omegaconf import OmegaConf
from invokeai.app.services.config import InvokeAIAppConfig
from ..config import VaeCheckpointConfig, VaeDiffusersConfig
from .base import (
GIG,
BaseModelType,
EmptyConfigLoader,
InvalidModelException,
@ -22,6 +24,7 @@ from .base import (
calc_model_size_by_data,
calc_model_size_by_fs,
classproperty,
trim_model_convert_cache,
)
@ -34,8 +37,11 @@ class VaeModel(ModelBase):
# vae_class: Type
# model_size: int
class Config(ModelConfigBase):
model_format: VaeModelFormat
class DiffusersConfig(VaeDiffusersConfig):
model_format: Literal[VaeModelFormat.Diffusers] = VaeModelFormat.Diffusers
class CheckpointConfig(VaeCheckpointConfig):
model_format: Literal[VaeModelFormat.Checkpoint] = VaeModelFormat.Checkpoint
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae
@ -97,28 +103,22 @@ class VaeModel(ModelBase):
@classmethod
def convert_if_required(
cls,
model_path: str,
model_config: ModelConfigBase,
output_path: str,
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
if isinstance(model_config, VaeCheckpointConfig):
return _convert_vae_ckpt_and_cache(
weights_path=model_path,
model_config=model_config,
output_path=output_path,
base_model=base_model,
model_config=config,
)
else:
return model_path
return model_config.path
# TODO: rework
def _convert_vae_ckpt_and_cache(
weights_path: str,
output_path: str,
base_model: BaseModelType,
model_config: ModelConfigBase,
output_path: str,
max_cache_size: int,
) -> str:
"""
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
@ -126,7 +126,7 @@ def _convert_vae_ckpt_and_cache(
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
weights_path = app_config.root_dir / weights_path
weights_path = app_config.root_dir / model_config.path
output_path = Path(output_path)
"""
@ -148,6 +148,12 @@ def _convert_vae_ckpt_and_cache(
if output_path.exists():
return output_path
# make sufficient size in the cache folder
size_needed = weights_path.stat().st_size
max_cache_size = app_config.conversion_cache_size * GIG
trim_model_convert_cache(output_path.parent, max_cache_size - size_needed)
base_model = model_config.base_model
if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
from .stable_diffusion import _select_ckpt_config

View File

@ -1,47 +1,89 @@
# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team
"""
Return descriptive information on Stable Diffusion models.
Module for probing a Stable Diffusion model and returning
its base type, model type, format and variant.
"""
import json
import re
from dataclasses import dataclass
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, Literal, Optional, Union
from typing import Callable, Dict, Optional, Type
import safetensors.torch
import torch
from diffusers import ConfigMixin, ModelMixin
from picklescan.scanner import scan_file_path
from pydantic import BaseModel
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
from .models import (
BaseModelType,
InvalidModelException,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SilenceWarnings,
)
from .models.base import read_checkpoint_meta
from .util import lora_token_vector_length
from .config import BaseModelType, ModelFormat, ModelType, ModelVariantType, SchedulerPredictionType
from .hash import FastModelHash
from .util import lora_token_vector_length, read_checkpoint_meta
@dataclass
class ModelProbeInfo(object):
class InvalidModelException(Exception):
"""Raised when an invalid model is encountered."""
class ModelProbeInfo(BaseModel):
"""Fields describing a probed model."""
model_type: ModelType
base_type: BaseModelType
variant_type: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
image_size: int
format: ModelFormat
hash: str
variant_type: ModelVariantType = ModelVariantType("normal")
prediction_type: Optional[SchedulerPredictionType] = SchedulerPredictionType("v_prediction")
upcast_attention: Optional[bool] = False
image_size: Optional[int] = None
class ProbeBase(object):
"""forward declaration"""
class ModelProbeBase(ABC):
"""Class to probe a checkpoint, safetensors or diffusers folder."""
pass
@classmethod
@abstractmethod
def probe(
cls,
model: Path,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> Optional[ModelProbeInfo]:
"""
Probe model located at path and return ModelProbeInfo object.
:param model: Path to a model checkpoint or folder.
:param prediction_type_helper: An optional Callable that takes the model path
and returns the SchedulerPredictionType.
"""
pass
class ModelProbe(object):
PROBES = {
class ProbeBase(ABC):
"""Base model for probing checkpoint and diffusers-style models."""
@abstractmethod
def get_base_type(self) -> Optional[BaseModelType]:
"""Return the BaseModelType for the model."""
pass
def get_variant_type(self) -> ModelVariantType:
"""Return the ModelVariantType for the model."""
pass
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
"""Return the SchedulerPredictionType for the model."""
pass
def get_format(self) -> str:
"""Return the format for the model."""
pass
class ModelProbe(ModelProbeBase):
"""Class to probe a checkpoint, safetensors or diffusers folder."""
PROBES: Dict[str, dict] = {
"diffusers": {},
"checkpoint": {},
"onnx": {},
@ -52,7 +94,6 @@ class ModelProbe(object):
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"AutoencoderTiny": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
@ -61,58 +102,46 @@ class ModelProbe(object):
}
@classmethod
def register_probe(
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
):
cls.PROBES[format][model_type] = probe_class
def register_probe(cls, format: ModelFormat, model_type: ModelType, probe_class: Type[ProbeBase]):
"""
Register a probe subclass to use when interrogating a model.
@classmethod
def heuristic_probe(
cls,
model: Union[Dict, ModelMixin, Path],
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
) -> ModelProbeInfo:
if isinstance(model, Path):
return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
else:
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
:param format: The ModelFormat of the model to be probed.
:param model_type: The ModelType of the model to be probed.
:param probe_class: The class of the prober (inherits from ProbeBase).
"""
cls.PROBES[format][model_type] = probe_class
@classmethod
def probe(
cls,
model_path: Path,
model: Optional[Union[Dict, ModelMixin]] = None,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> ModelProbeInfo:
"""
Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. If the model is
already loaded into memory, you may provide it as model in order to avoid
opening it a second time. The prediction_type_helper callable is a function that receives
the path to the model and returns the SchedulerPredictionType.
"""
if model_path:
format_type = "diffusers" if model_path.is_dir() else "checkpoint"
else:
format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"
model_info = None
"""Probe model."""
try:
model_type = (
cls.get_model_type_from_folder(model_path, model)
if format_type == "diffusers"
else cls.get_model_type_from_checkpoint(model_path, model)
cls.get_model_type_from_folder(model_path)
if model_path.is_dir()
else cls.get_model_type_from_checkpoint(model_path)
)
format_type = "onnx" if model_type == ModelType.ONNX else format_type
format_type = (
"onnx" if model_type == ModelType.ONNX else "diffusers" if model_path.is_dir() else "checkpoint"
)
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
return None
probe = probe_class(model_path, model, prediction_type_helper)
raise InvalidModelException(f"Unable to determine model type for {model_path}")
probe = probe_class(model_path, prediction_type_helper)
base_type = probe.get_base_type()
variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type()
format = probe.get_format()
hash = FastModelHash.hash(model_path)
model_info = ModelProbeInfo(
model_type=model_type,
base_type=base_type,
@ -123,33 +152,35 @@ class ModelProbe(object):
and prediction_type == SchedulerPredictionType.VPrediction
),
format=format,
image_size=(
1024
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
else (
768
if (
base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction
)
else 512
)
),
hash=hash,
image_size=1024
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
else 768
if (
base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction
)
else 512,
)
except Exception:
raise
raise InvalidModelException(f"Unable to determine model type for {model_path}")
return model_info
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
def get_model_type_from_checkpoint(cls, model_path: Path) -> Optional[ModelType]:
"""
Scan a checkpoint model and return its ModelType.
:param model_path: path to the model checkpoint/safetensors file
"""
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
return None
if model_path.name == "learned_embeds.bin":
return ModelType.TextualInversion
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
ckpt = read_checkpoint_meta(model_path, scan=True)
ckpt = ckpt.get("state_dict", ckpt)
for key in ckpt.keys():
@ -174,39 +205,37 @@ class ModelProbe(object):
raise InvalidModelException(f"Unable to determine model type for {model_path}")
@classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
def get_model_type_from_folder(cls, folder_path: Path) -> Optional[ModelType]:
"""
Get the model type of a hugging-face style folder.
:param folder_path: Path to model folder.
"""
class_name = None
error_hint = None
if model:
class_name = model.__class__.__name__
else:
if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX
if (folder_path / "learned_embeds.bin").exists():
return ModelType.TextualInversion
if (folder_path / "pytorch_lora_weights.bin").exists():
return ModelType.Lora
if (folder_path / "image_encoder.txt").exists():
return ModelType.IPAdapter
if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX
if (folder_path / "learned_embeds.bin").exists():
return ModelType.TextualInversion
if (folder_path / "pytorch_lora_weights.bin").exists():
return ModelType.Lora
if (folder_path / "image_encoder.txt").exists():
return ModelType.IPAdapter
i = folder_path / "model_index.json"
c = folder_path / "config.json"
config_path = i if i.exists() else c if c.exists() else None
i = folder_path / "model_index.json"
c = folder_path / "config.json"
config_path = i if i.exists() else c if c.exists() else None
if config_path:
with open(config_path, "r") as file:
conf = json.load(file)
if "_class_name" in conf:
class_name = conf["_class_name"]
elif "architectures" in conf:
class_name = conf["architectures"][0]
else:
class_name = None
if config_path:
with open(config_path, "r") as file:
conf = json.load(file)
if "_class_name" in conf:
class_name = conf["_class_name"]
elif "architectures" in conf:
class_name = conf["architectures"][0]
else:
error_hint = f"No model_index.json or config.json found in {folder_path}."
class_name = None
else:
error_hint = f"No model_index.json or config.json found in {folder_path}."
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type
@ -219,59 +248,52 @@ class ModelProbe(object):
)
@classmethod
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model_path, model_path)
return torch.load(model_path)
else:
return safetensors.torch.load_file(model_path)
def _scan_and_load_checkpoint(cls, model: Path) -> dict:
if model.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model)
return torch.load(model)
else:
return safetensors.torch.load_file(model)
@classmethod
def _scan_model(cls, model_name, checkpoint):
def _scan_model(cls, model: Path):
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
Scan a model for malicious code.
:param model: Path to the model to be scanned
Raises an Exception if unsafe code is found.
"""
# scan model
scan_result = scan_file_path(checkpoint)
scan_result = scan_file_path(model)
if scan_result.infected_files != 0:
raise "The model {model_name} is potentially infected by malware. Aborting import."
raise InvalidModelException("The model {model_name} is potentially infected by malware. Aborting import.")
# ##################################################3
# Checkpoint probing
# ##################################################3
class ProbeBase(object):
def get_base_type(self) -> BaseModelType:
pass
def get_variant_type(self) -> ModelVariantType:
pass
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
pass
def get_format(self) -> str:
pass
class CheckpointProbeBase(ProbeBase):
def __init__(
self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None
) -> BaseModelType:
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
"""Base class for probing checkpoint-style models."""
def __init__(self, checkpoint_path: Path, helper: Optional[Callable[[Path], SchedulerPredictionType]] = None):
"""Initialize the CheckpointProbeBase object."""
self.checkpoint_path = checkpoint_path
self.checkpoint = ModelProbe._scan_and_load_checkpoint(checkpoint_path)
self.helper = helper
def get_base_type(self) -> BaseModelType:
def get_base_type(self) -> Optional[BaseModelType]:
"""Return the BaseModelType of a checkpoint-style model."""
pass
def get_format(self) -> str:
"""Return the format of a checkpoint-style model."""
return "checkpoint"
def get_variant_type(self) -> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
"""Return the ModelVariantType of a checkpoint-style model."""
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path)
if model_type != ModelType.Main:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
@ -289,7 +311,10 @@ class CheckpointProbeBase(ProbeBase):
class PipelineCheckpointProbe(CheckpointProbeBase):
"""Probe a checkpoint-style main model."""
def get_base_type(self) -> BaseModelType:
"""Return the ModelBaseType for the checkpoint-style main model."""
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
@ -338,16 +363,23 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
class VaeCheckpointProbe(CheckpointProbeBase):
"""Probe a Checkpoint-style VAE model."""
def get_base_type(self) -> BaseModelType:
"""Return the BaseModelType of the VAE model."""
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase):
"""Probe for LoRA Checkpoint Files."""
def get_format(self) -> str:
"""Return the format of the LoRA."""
return "lycoris"
def get_base_type(self) -> BaseModelType:
"""Return the BaseModelType of the LoRA."""
checkpoint = self.checkpoint
token_vector_length = lora_token_vector_length(checkpoint)
@ -358,14 +390,18 @@ class LoRACheckpointProbe(CheckpointProbeBase):
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
raise InvalidModelException(f"Unsupported LoRA type: {self.checkpoint_path}")
class TextualInversionCheckpointProbe(CheckpointProbeBase):
"""TextualInversion checkpoint prober."""
def get_format(self) -> str:
return None
"""Return the format of a TextualInversion emedding."""
return ModelFormat.EmbeddingFile
def get_base_type(self) -> BaseModelType:
"""Return BaseModelType of the checkpoint model."""
checkpoint = self.checkpoint
if "string_to_token" in checkpoint:
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
@ -377,12 +413,14 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1
elif token_dim == 1024:
return BaseModelType.StableDiffusion2
else:
return None
raise InvalidModelException("Unknown base model for {self.checkpoint_path}")
class ControlNetCheckpointProbe(CheckpointProbeBase):
"""Probe checkpoint-based ControlNet models."""
def get_base_type(self) -> BaseModelType:
"""Return the BaseModelType of the model."""
checkpoint = self.checkpoint
for key_name in (
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
@ -394,18 +432,22 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1
elif checkpoint[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
elif self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
class IPAdapterCheckpointProbe(CheckpointProbeBase):
"""Probe IP adapter models."""
def get_base_type(self) -> BaseModelType:
"""Probe base type."""
raise NotImplementedError()
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
"""Probe ClipVision adapter models."""
def get_base_type(self) -> BaseModelType:
"""Probe base type."""
raise NotImplementedError()
@ -418,24 +460,33 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used
self.model = model
"""Class for probing folder-based models."""
def __init__(self, folder_path: Path, helper: Optional[Callable] = None): # not used
"""
Initialize the folder prober.
:param folder_path: Path to the model folder to be probed.
:param helper: Callable for returning the SchedulerPredictionType (unused).
"""
self.folder_path = folder_path
def get_variant_type(self) -> ModelVariantType:
"""Return the model's variant type."""
return ModelVariantType.Normal
def get_format(self) -> str:
"""Return the model's format."""
return "diffusers"
class PipelineFolderProbe(FolderProbeBase):
"""Probe a pipeline (main) folder."""
def get_base_type(self) -> BaseModelType:
if self.model:
unet_conf = self.model.unet.config
else:
with open(self.folder_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
"""Return the BaseModelType of a pipeline folder."""
with open(self.folder_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
@ -448,29 +499,21 @@ class PipelineFolderProbe(FolderProbeBase):
raise InvalidModelException(f"Unknown base model for {self.folder_path}")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
if self.model:
scheduler_conf = self.model.scheduler.config
else:
with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
scheduler_conf = json.load(file)
if scheduler_conf["prediction_type"] == "v_prediction":
return SchedulerPredictionType.VPrediction
elif scheduler_conf["prediction_type"] == "epsilon":
return SchedulerPredictionType.Epsilon
else:
return None
"""Return the SchedulerPredictionType of a diffusers-style sd-2 model."""
with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
scheduler_conf = json.load(file)
prediction_type = scheduler_conf.get("prediction_type", "epsilon")
return SchedulerPredictionType(prediction_type)
def get_variant_type(self) -> ModelVariantType:
"""Return the ModelVariantType for diffusers-style main models."""
# This only works for pipelines! Any kind of
# exception results in our returning the
# "normal" variant type
try:
if self.model:
conf = self.model.unet.config
else:
config_file = self.folder_path / "unet" / "config.json"
with open(config_file, "r") as file:
conf = json.load(file)
config_file = self.folder_path / "unet" / "config.json"
with open(config_file, "r") as file:
conf = json.load(file)
in_channels = conf["in_channels"]
if in_channels == 9:
@ -485,7 +528,10 @@ class PipelineFolderProbe(FolderProbeBase):
class VaeFolderProbe(FolderProbeBase):
"""Class for probing folder-style models."""
def get_base_type(self) -> BaseModelType:
"""Get base type of model."""
if self._config_looks_like_sdxl():
return BaseModelType.StableDiffusionXL
elif self._name_looks_like_sdxl():
@ -515,30 +561,41 @@ class VaeFolderProbe(FolderProbeBase):
class TextualInversionFolderProbe(FolderProbeBase):
"""Probe a HuggingFace-style TextualInversion folder."""
def get_format(self) -> str:
return None
"""Return the format of the TextualInversion."""
return ModelFormat.EmbeddingFolder
def get_base_type(self) -> BaseModelType:
"""Return the ModelBaseType of the HuggingFace-style Textual Inversion Folder."""
path = self.folder_path / "learned_embeds.bin"
if not path.exists():
return None
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
raise InvalidModelException("This textual inversion folder does not contain a learned_embeds.bin file.")
return TextualInversionCheckpointProbe(path).get_base_type()
class ONNXFolderProbe(FolderProbeBase):
"""Probe an ONNX-format folder."""
def get_format(self) -> str:
"""Return the format of the folder (always "onnx")."""
return "onnx"
def get_base_type(self) -> BaseModelType:
"""Return the BaseModelType of the ONNX folder."""
return BaseModelType.StableDiffusion1
def get_variant_type(self) -> ModelVariantType:
"""Return the ModelVariantType of the ONNX folder."""
return ModelVariantType.Normal
class ControlNetFolderProbe(FolderProbeBase):
"""Probe a ControlNet model folder."""
def get_base_type(self) -> BaseModelType:
"""Return the BaseModelType of a ControlNet model folder."""
config_file = self.folder_path / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
@ -549,13 +606,11 @@ class ControlNetFolderProbe(FolderProbeBase):
base_model = (
BaseModelType.StableDiffusion1
if dimension == 768
else (
BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
)
else BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
)
if not base_model:
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
@ -563,7 +618,10 @@ class ControlNetFolderProbe(FolderProbeBase):
class LoRAFolderProbe(FolderProbeBase):
"""Probe a LoRA model folder."""
def get_base_type(self) -> BaseModelType:
"""Get the ModelBaseType of a LoRA model folder."""
model_file = None
for suffix in ["safetensors", "bin"]:
base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
@ -572,14 +630,18 @@ class LoRAFolderProbe(FolderProbeBase):
break
if not model_file:
raise InvalidModelException("Unknown LoRA format encountered")
return LoRACheckpointProbe(model_file, None).get_base_type()
return LoRACheckpointProbe(model_file).get_base_type()
class IPAdapterFolderProbe(FolderProbeBase):
"""Class for probing IP-Adapter models."""
def get_format(self) -> str:
return IPAdapterModelFormat.InvokeAI.value
"""Get format of ip adapter."""
return ModelFormat.InvokeAI.value
def get_base_type(self) -> BaseModelType:
"""Get base type of ip adapter."""
model_file = self.folder_path / "ip_adapter.bin"
if not model_file.exists():
raise InvalidModelException("Unknown IP-Adapter model format.")
@ -597,7 +659,10 @@ class IPAdapterFolderProbe(FolderProbeBase):
class CLIPVisionFolderProbe(FolderProbeBase):
"""Probe for folder-based CLIPVision models."""
def get_base_type(self) -> BaseModelType:
"""Get base type."""
return BaseModelType.Any
@ -622,22 +687,25 @@ class T2IAdapterFolderProbe(FolderProbeBase):
############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
diffusers = ModelFormat("diffusers")
checkpoint = ModelFormat("checkpoint")
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe(diffusers, ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe(diffusers, ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe(diffusers, ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe(diffusers, ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe(diffusers, ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe(diffusers, ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe(diffusers, ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe(diffusers, ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
ModelProbe.register_probe(checkpoint, ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe(checkpoint, ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe(checkpoint, ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe(checkpoint, ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe(checkpoint, ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe(checkpoint, ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe(checkpoint, ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe(checkpoint, ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe(ModelFormat("onnx"), ModelType.ONNX, ONNXFolderProbe)

View File

@ -0,0 +1,198 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class and implementation for recursive directory search for models.
Example usage:
```
from invokeai.backend.model_manager import ModelSearch, ModelProbe
def find_main_models(model: Path) -> bool:
info = ModelProbe.probe(model)
if info.model_type == 'main' and info.base_type == 'sd-1':
return True
else:
return False
search = ModelSearch(on_model_found=find_main_models)
found = search.search('/tmp/models')
print(found) # list of matching model paths
print(search.stats) # search stats
```
"""
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Optional, Set, Union
from pydantic import BaseModel, Field
from invokeai.backend.util import InvokeAILogger, Logger
default_logger = InvokeAILogger.get_logger()
class SearchStats(BaseModel):
items_scanned: int = 0
models_found: int = 0
models_filtered: int = 0
class ModelSearchBase(ABC, BaseModel):
"""
Abstract directory traversal model search class
Usage:
search = ModelSearchBase(
on_search_started = search_started_callback,
on_search_completed = search_completed_callback,
on_model_found = model_found_callback,
)
models_found = search.search('/path/to/directory')
"""
# fmt: off
on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221
# fmt: on
class Config:
underscore_attrs_are_private = True
arbitrary_types_allowed = True
@abstractmethod
def search_started(self):
"""
Called before the scan starts.
Passes the root search directory to the Callable `on_search_started`.
"""
pass
@abstractmethod
def model_found(self, model: Path):
"""
Called when a model is found during search.
:param model: Model to process - could be a directory or checkpoint.
Passes the model's Path to the Callable `on_model_found`.
This Callable receives the path to the model and returns a boolean
to indicate whether the model should be returned in the search
results.
"""
pass
@abstractmethod
def search_completed(self):
"""
Called when the search completes.
Passes the Set of found model Paths to the Callable `on_search_completed`.
"""
pass
@abstractmethod
def search(self, directory: Union[Path, str]) -> Set[Path]:
"""
Recursively search for models in `directory` and return a set of model paths.
If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
Callables will be invoked during the search.
"""
pass
class ModelSearch(ModelSearchBase):
"""
Implementation of ModelSearch with callbacks.
Usage:
search = ModelSearch()
search.on_model_found = lambda path: 'anime' in path.as_posix()
found = search.search('/tmp/models')
# returns all models that have 'anime' in the path
"""
_directory: Path = Field(default=None)
_models_found: Set[Path] = Field(default=None)
_scanned_dirs: Set[Path] = Field(default=None)
_pruned_paths: Set[Path] = Field(default=None)
def search_started(self):
self._models_found = set()
self._scanned_dirs = set()
self._pruned_paths = set()
if self.on_search_started:
self.on_search_started(self._directory)
def model_found(self, model: Path):
self.stats.models_found += 1
if not self.on_model_found:
self.stats.models_filtered += 1
self._models_found.add(model)
return
if self.on_model_found(model):
self.stats.models_filtered += 1
self._models_found.add(model)
def search_completed(self):
if self.on_search_completed:
self.on_search_completed(self._models_found)
def search(self, directory: Union[Path, str]) -> Set[Path]:
self._directory = Path(directory)
self.stats = SearchStats() # zero out
self.search_started() # This will initialize _models_found to empty
self._walk_directory(directory)
self.search_completed()
return self._models_found
def _walk_directory(self, path: Union[Path, str]):
for root, dirs, files in os.walk(path, followlinks=True):
# don't descend into directories that start with a "."
# to avoid the Mac .DS_STORE issue.
if str(Path(root).name).startswith("."):
self._pruned_paths.add(Path(root))
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
continue
self.stats.items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path.parent in self._scanned_dirs:
self._scanned_dirs.add(path)
continue
if any(
[
(path / x).exists()
for x in [
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
]
]
):
self._scanned_dirs.add(path)
try:
self.model_found(path)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path.parent in self._scanned_dirs:
continue
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
try:
self.model_found(path)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
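A minimal usage sketch of the callback-driven search class defined above (the import path is assumed from this refactor; the filter is purely illustrative):

```python
from pathlib import Path

from invokeai.backend.model_manager.search import ModelSearch  # assumed location

def keep_safetensors(model: Path) -> bool:
    # accept diffusers-style folders plus single-file .safetensors checkpoints
    return model.is_dir() or model.suffix == ".safetensors"

search = ModelSearch(on_model_found=keep_safetensors)
found = search.search("/data/models")
print(f"kept {search.stats.models_filtered} of {search.stats.models_found} models found")
```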

View File

@ -0,0 +1,13 @@
"""Initialization file for invokeai.backend.model_manager.storage."""
import pathlib
from ..config import AnyModelConfig # noqa F401
from .base import ( # noqa F401
ConfigFileVersionMismatchException,
DuplicateModelException,
ModelConfigStore,
UnknownModelException,
)
from .migrate import migrate_models_store # noqa F401
from .sql import ModelConfigStoreSQL # noqa F401
from .yaml import ModelConfigStoreYAML # noqa F401

View File

@ -0,0 +1,166 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Abstract base class for storing and retrieving model configuration records.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Set, Union
from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelType
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.2"
class DuplicateModelException(Exception):
"""Raised on an attempt to add a model with the same key twice."""
class InvalidModelException(Exception):
"""Raised when an invalid model is detected."""
class UnknownModelException(Exception):
"""Raised on an attempt to fetch or delete a model with a nonexistent key."""
class ConfigFileVersionMismatchException(Exception):
"""Raised on an attempt to open a config with an incompatible version."""
class ModelConfigStore(ABC):
"""Abstract base class for storage and retrieval of model configs."""
@property
@abstractmethod
def version(self) -> str:
"""Return the config file/database schema version."""
pass
@abstractmethod
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> ModelConfigBase:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
pass
@abstractmethod
def del_model(self, key: str) -> None:
"""
Delete a model.
:param key: Unique key for the model to be deleted
Can raise an UnknownModelException
"""
pass
@abstractmethod
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
pass
@abstractmethod
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the configuration for the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
pass
@abstractmethod
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the database.
:param key: Unique key for the model to be deleted
"""
pass
@abstractmethod
def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Return models containing all of the listed tags.
:param tags: Set of tags to search on.
"""
pass
@abstractmethod
def search_by_path(
self,
path: Union[str, Path],
) -> Optional[AnyModelConfig]:
"""Return the model having the indicated path."""
pass
@abstractmethod
def search_by_name(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
If none of the optional filters are passed, will return all
models in the database.
"""
pass
def all_models(self) -> List[AnyModelConfig]:
"""Return all the model configs in the database."""
return self.search_by_name()
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase:
"""
Return information about a single model using its name, base type and model type.
If more than one model matches, raises a DuplicateModelException.
If no model matches, raises an UnknownModelException
"""
model_configs = self.search_by_name(model_name=model_name, base_model=base_model, model_type=model_type)
if len(model_configs) > 1:
raise DuplicateModelException(
"More than one model share the same name and type: {base_model}/{model_type}/{model_name}"
)
if len(model_configs) == 0:
raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}")
return model_configs[0]
def rename_model(
self,
key: str,
new_name: str,
) -> ModelConfigBase:
"""
Rename the indicated model. Just a special case of update_model().
In some implementations, renaming the model may involve changing where
it is stored on the filesystem. So this is broken out.
:param key: Model key
:param new_name: New name for model
"""
return self.update_model(key, {"name": new_name})

View File

@ -0,0 +1,67 @@
# Copyright (c) 2023 The InvokeAI Development Team
import shutil
from pathlib import Path
from omegaconf import OmegaConf
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from ..config import BaseModelType, MainCheckpointConfig, MainConfig, ModelType
from .base import CONFIG_FILE_VERSION
def migrate_models_store(config: InvokeAIAppConfig) -> Path:
"""Migrate models from v1 models.yaml to v3.2 models.yaml."""
# avoid circular import
from invokeai.backend.model_manager.install import DuplicateModelException, ModelInstall
from invokeai.backend.model_manager.storage import get_config_store
app_config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger()
old_file: Path = app_config.model_conf_path
new_file: Path = old_file.with_name("models3_2.yaml")
old_conf = OmegaConf.load(old_file)
store = get_config_store(new_file)
installer = ModelInstall(store=store)
logger.info(f"Migrating old models file at {old_file} to new {CONFIG_FILE_VERSION} format")
for model_key, stanza in old_conf.items():
if model_key == "__metadata__":
assert (
stanza["version"] == "3.0.0"
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
continue
base_type, model_type, model_name = str(model_key).split("/")
new_key = "<NOKEY>"
try:
path = app_config.models_path / stanza["path"]
new_key = installer.register_path(path)
except DuplicateModelException:
# if model already installed, then we just update its info
models = store.search_by_name(
model_name=model_name, base_model=BaseModelType(base_type), model_type=ModelType(model_type)
)
if len(models) != 1:
continue
new_key = models[0].key
except Exception as excp:
print(str(excp))
if new_key != "<NOKEY>":
model_info = store.get_model(new_key)
if (vae := stanza.get("vae")) and isinstance(model_info, MainConfig):
model_info.vae = (app_config.models_path / vae).as_posix()
if (model_config := stanza.get("config")) and isinstance(model_info, MainCheckpointConfig):
model_info.config = (app_config.root_path / model_config).as_posix()
model_info.description = stanza.get("description")
store.update_model(new_key, model_info)
logger.info(f"Original version of models config file saved as {str(old_file) + '.orig'}")
shutil.move(old_file, str(old_file) + ".orig")
shutil.move(new_file, old_file)
return old_file
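A hypothetical one-shot invocation of the migration helper above, using the exports declared in the storage package's __init__.py earlier in this diff:

```python
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.storage import migrate_models_store

new_path = migrate_models_store(InvokeAIAppConfig.get_config())
print(f"models.yaml migrated in place at {new_path}")
```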

View File

@ -0,0 +1,468 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Implementation of ModelConfigStore using a SQLite3 database
Typical usage:
from invokeai.backend.model_manager import ModelConfigStoreSQL
store = ModelConfigStoreYAML("./configs/models.yaml")
config = dict(
path='/tmp/pokemon.bin',
name='old name',
base_model='sd-1',
model_type='embedding',
model_format='embedding_file',
author='Anonymous',
tags=['sfw','cartoon']
)
# adding - the key becomes the model's "key" field
store.add_model('key1', config)
# updating
config['name'] = 'new name'
store.update_model('key1', config)
# checking for existence
if store.exists('key1'):
print("yes")
# fetching config
new_config = store.get_model('key1')
print(new_config.name, new_config.base_model)
assert new_config.key == 'key1'
# deleting
store.del_model('key1')
# searching
configs = store.search_by_tag({'sfw','oss license'})
configs = store.search_by_name(base_model='sd-2', model_type='main')
"""
import json
import sqlite3
import threading
from pathlib import Path
from typing import List, Optional, Set, Union
from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
from .base import CONFIG_FILE_VERSION, DuplicateModelException, ModelConfigStore, UnknownModelException
class ModelConfigStoreSQL(ModelConfigStore):
"""Implementation of the ModelConfigStore ABC using a YAML file."""
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = lock
with self._lock:
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
assert (
str(self.version) == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
def _create_tables(self) -> None:
"""Create sqlite3 tables."""
# model_config table breaks out the fields that are common to all config objects
# and puts class-specific ones in a serialized json object
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- These 4 fields are enums in python, unrestricted string here
base_model TEXT NOT NULL,
model_type TEXT NOT NULL,
model_name TEXT NOT NULL,
model_path TEXT NOT NULL,
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
# model_tag table 1:M relation between model key and tag(s)
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_tag (
id TEXT NOT NULL,
tag_id INTEGER NOT NULL,
FOREIGN KEY(id) REFERENCES model_config(id),
FOREIGN KEY(tag_id) REFERENCES tags(tag_id),
UNIQUE(id,tag_id)
);
"""
)
# tags table
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tags (
tag_id INTEGER NOT NULL PRIMARY KEY,
tag_text TEXT NOT NULL UNIQUE
);
"""
)
# metadata table
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_manager_metadata (
metadata_key TEXT NOT NULL PRIMARY KEY,
metadata_value TEXT NOT NULL
);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
# Add trigger to remove tags when model is deleted
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_deleted
AFTER DELETE
ON model_config
BEGIN
DELETE from model_tag WHERE id=old.id;
END;
"""
)
# Add our version to the metadata table
self._cursor.execute(
"""--sql
INSERT OR IGNORE into model_manager_metadata (
metadata_key,
metadata_value
)
VALUES (?,?);
""",
("version", CONFIG_FILE_VERSION),
)
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object.
json_serialized = json.dumps(record.dict()) # and turn it into a json string.
with self._lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_config (
id,
base_model,
model_type,
model_name,
model_path,
config
)
VALUES (?,?,?,?,?,?);
""",
(
key,
record.base_model,
record.model_type,
record.name,
record.path,
json_serialized,
),
)
if record.tags:
self._update_tags(key, record.tags)
self._conn.commit()
except sqlite3.IntegrityError as e:
self._conn.rollback()
if "UNIQUE constraint failed" in str(e):
raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
else:
raise e
except sqlite3.Error as e:
self._conn.rollback()
raise e
return self.get_model(key)
@property
def version(self) -> str:
"""Return the version of the database schema."""
with self._lock:
self._cursor.execute(
"""--sql
SELECT metadata_value FROM model_manager_metadata
WHERE metadata_key=?;
""",
("version",),
)
rows = self._cursor.fetchone()
if not rows:
raise KeyError("Models database does not have metadata key 'version'")
return rows[0]
def _update_tags(self, key: str, tags: List[str]) -> None:
"""Update tags for model with key."""
# remove previous tags from this model
self._cursor.execute(
"""--sql
DELETE FROM model_tag
WHERE id=?;
""",
(key,),
)
# NOTE: isn't there a more elegant way of doing this than one tag
# at a time, with a select to get the tag ID?
for tag in tags:
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO tags (
tag_text
)
VALUES (?);
""",
(tag,),
)
self._cursor.execute(
"""--sql
SELECT tag_id
FROM tags
WHERE tag_text = ?
LIMIT 1;
""",
(tag,),
)
tag_id = self._cursor.fetchone()[0]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO model_tag (
id,
tag_id
)
VALUES (?,?);
""",
(key, tag_id),
)
def del_model(self, key: str) -> None:
"""
Delete a model.
:param key: Unique key for the model to be deleted
Can raise an UnknownModelException
"""
with self._lock:
try:
self._cursor.execute(
"""--sql
DELETE FROM model_config
WHERE id=?;
""",
(key,),
)
if self._cursor.rowcount == 0:
raise UnknownModelException
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config object
json_serialized = json.dumps(record.dict()) # and turn it into a json string.
with self._lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_config
SET base_model=?,
model_type=?,
model_name=?,
model_path=?,
config=?
WHERE id=?;
""",
(record.base_model, record.model_type, record.name, record.path, json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException
if record.tags:
self._update_tags(key, record.tags)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
return self.get_model(key)
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the ModelConfigBase instance for the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
with self._lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE id=?;
""",
(key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException
model = ModelConfigFactory.make_config(json.loads(rows[0]))
return model
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the database.
:param key: Unique key of the model to check
"""
count = 0
with self._lock:
try:
self._cursor.execute(
"""--sql
select count(*) FROM model_config
WHERE id=?;
""",
(key,),
)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
raise e
return count > 0
def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""Return models containing all of the listed tags."""
# rather than create a hairy SQL cross-product, we intersect
# tag results in a stepwise fashion at the python level.
results = []
with self._lock:
try:
matches: Set[str] = set()
for tag in tags:
self._cursor.execute(
"""--sql
SELECT a.id FROM model_tag AS a,
tags AS b
WHERE a.tag_id=b.tag_id
AND b.tag_text=?;
""",
(tag,),
)
model_keys = {x[0] for x in self._cursor.fetchall()}
matches = matches.intersection(model_keys) if len(matches) > 0 else model_keys
if matches:
self._cursor.execute(
f"""--sql
SELECT config FROM model_config
WHERE id IN ({','.join('?' * len(matches))});
""",
tuple(matches),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e:
raise e
return results
def search_by_name(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
If none of the optional filters are passed, will return all
models in the database.
"""
results = []
where_clause = []
bindings = []
if model_name:
where_clause.append("model_name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base_model=?")
bindings.append(base_model)
if model_type:
where_clause.append("model_type=?")
bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._lock:
try:
self._cursor.execute(
f"""--sql
select config FROM model_config
{where};
""",
tuple(bindings),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e:
raise e
return results
def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
"""Return the model with the indicated path, or None."""
raise NotImplementedError("search_by_path not implemented in storage.sql")
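For readers wondering what a SQL-backed search_by_path would look like, here is a minimal sketch (editorial illustration only, not part of the committed code). It reuses the model_path column created by _create_tables() above and assumes paths are stored as plain strings:

    def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
        """Return the model with the indicated path, or None (illustrative sketch)."""
        with self._lock:
            self._cursor.execute(
                """--sql
                SELECT config FROM model_config
                WHERE model_path=?;
                """,
                (str(path),),
            )
            row = self._cursor.fetchone()
        return ModelConfigFactory.make_config(json.loads(row[0])) if row else None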

View File

@ -0,0 +1,239 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Implementation of ModelConfigStore using a YAML file.
Typical usage:
from invokeai.backend.model_manager.storage.yaml import ModelConfigStoreYAML
store = ModelConfigStoreYAML("./configs/models.yaml")
config = dict(
path='/tmp/pokemon.bin',
name='old name',
base_model='sd-1',
model_type='embedding',
model_format='embedding_file',
author='Anonymous',
tags=['sfw','cartoon']
)
# adding - the key becomes the model's "key" field
store.add_model('key1', config)
# updating
config['name'] = 'new name'
store.update_model('key1', config)
# checking for existence
if store.exists('key1'):
print("yes")
# fetching config
new_config = store.get_model('key1')
print(new_config.name, new_config.base_model)
assert new_config.key == 'key1'
# deleting
store.del_model('key1')
# searching
configs = store.search_by_tag({'sfw','oss license'})
configs = store.search_by_name(base_model='sd-2', model_type='main')
"""
import threading
from enum import Enum
from pathlib import Path
from typing import List, Optional, Set, Union
import yaml
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
from .base import (
CONFIG_FILE_VERSION,
ConfigFileVersionMismatchException,
DuplicateModelException,
ModelConfigStore,
UnknownModelException,
)
class ModelConfigStoreYAML(ModelConfigStore):
"""Implementation of the ModelConfigStore ABC using a YAML file."""
_filename: Path
_config: DictConfig
_lock: threading.RLock
def __init__(self, config_file: Path):
"""Initialize ModelConfigStore object with a .yaml file."""
super().__init__()
self._filename = Path(config_file).absolute() # don't let chdir mess us up!
self._lock = threading.RLock()
if not self._filename.exists():
self._initialize_yaml()
config = OmegaConf.load(self._filename)
assert isinstance(config, DictConfig)
self._config = config
if str(self.version) != CONFIG_FILE_VERSION:
raise ConfigFileVersionMismatchException
def _initialize_yaml(self):
with self._lock:
self._filename.parent.mkdir(parents=True, exist_ok=True)
with open(self._filename, "w") as yaml_file:
yaml_file.write(yaml.dump({"__metadata__": {"version": CONFIG_FILE_VERSION}}))
def _commit(self):
with self._lock:
newfile = Path(str(self._filename) + ".new")
yaml_str = OmegaConf.to_yaml(self._config)
with open(newfile, "w", encoding="utf-8") as outfile:
outfile.write(yaml_str)
newfile.replace(self._filename)
@property
def version(self) -> str:
"""Return version of this config file/database."""
return self._config.__metadata__.get("version")
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config object
dict_fields = record.dict() # and back to a dict with valid fields
with self._lock:
if key in self._config:
existing_model = self.get_model(key)
raise DuplicateModelException(
f"Can't save {record.name} because a model named '{existing_model.name}' is already stored with the same key '{key}'"
)
self._config[key] = self._fix_enums(dict_fields)
self._commit()
return self.get_model(key)
def _fix_enums(self, original: dict) -> dict:
"""In python 3.9, omegaconf stores incorrectly stringified enums."""
fixed_dict = {}
for key, value in original.items():
fixed_dict[key] = value.value if isinstance(value, Enum) else value
return fixed_dict
def del_model(self, key: str) -> None:
"""
Delete a model.
:param key: Unique key for the model to be deleted
Can raise an UnknownModelException
"""
with self._lock:
if key not in self._config:
raise UnknownModelException(f"Unknown key '{key}' for model config")
self._config.pop(key)
self._commit()
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config, key) # ensure it is a valid config object
dict_fields = record.dict() # and back to a dict with valid fields
with self._lock:
if key not in self._config:
raise UnknownModelException(f"Unknown key '{key}' for model config")
self._config[key] = self._fix_enums(dict_fields)
self._commit()
return self.get_model(key)
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the ModelConfigBase instance for the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
try:
record = self._config[key]
return ModelConfigFactory.make_config(record, key)
except KeyError as e:
raise UnknownModelException(f"Unknown key '{key}' for model config") from e
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the database.
:param key: Unique key of the model to check
"""
return key in self._config
def search_by_tag(self, tags: Set[str]) -> List[ModelConfigBase]:
"""
Return models containing all of the listed tags.
:param tags: Set of tags to search on.
"""
results = []
tags = set(tags)
with self._lock:
for config in self.all_models():
config_tags = set(config.tags or [])
if tags.difference(config_tags): # not all tags in the model
continue
results.append(config)
return results
def search_by_name(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[ModelConfigBase]:
"""
Return models matching name, base and/or type.
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
If none of the optional filters are passed, will return all
models in the database.
"""
results: List[ModelConfigBase] = list()
with self._lock:
for key, record in self._config.items():
if key == "__metadata__":
continue
model = ModelConfigFactory.make_config(record, str(key))
if model_name and model.name != model_name:
continue
if base_model and model.base_model != base_model:
continue
if model_type and model.model_type != model_type:
continue
results.append(model)
return results
def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
"""Return the model with the indicated path, or None."""
with self._lock:
for key, record in self._config.items():
if key == "__metadata__":
continue
model = ModelConfigFactory.make_config(record, str(key))
if model.path == path:
return model
return None
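For orientation, the YAML file that this class maintains ends up shaped roughly as follows (a hypothetical example based on _initialize_yaml() and add_model(); 'key1' and the field values are taken from the module docstring, and the version string is whatever CONFIG_FILE_VERSION holds):

    __metadata__:
      version: <CONFIG_FILE_VERSION>
    key1:
      path: /tmp/pokemon.bin
      name: old name
      base_model: sd-1
      model_type: embedding
      model_format: embedding_file
      author: Anonymous
      tags:
        - sfw
        - cartoon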

View File

@ -0,0 +1,162 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""
Various utilities used by the model manager.
"""
import json
import warnings
from pathlib import Path
from typing import Optional, Union
import safetensors
import torch
from diffusers import logging as diffusers_logging
from picklescan.scanner import scan_file_path
from transformers import logging as transformers_logging
class SilenceWarnings(object):
"""
Context manager that silences warnings from transformers and diffusers.
Usage:
with SilenceWarnings():
do_something_that_generates_warnings()
"""
def __init__(self):
"""Initialize SilenceWarnings context."""
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self):
"""Entry into the context."""
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
def __exit__(self, type, value, traceback):
"""Exit from the context."""
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter("default")
def lora_token_vector_length(checkpoint: dict) -> Optional[int]:
"""
Given a checkpoint in memory, return the lora token vector length.
:param checkpoint: The checkpoint
"""
def _get_shape_1(key, tensor, checkpoint):
lora_token_vector_length = None
if "." not in key:
return lora_token_vector_length # wrong key format
model_key, lora_key = key.split(".", 1)
# check lora/locon
if lora_key == "lora_down.weight":
lora_token_vector_length = tensor.shape[1]
# check loha (don't worry about hada_t1/hada_t2 as they are used only in 4D shapes)
elif lora_key in ["hada_w1_b", "hada_w2_b"]:
lora_token_vector_length = tensor.shape[1]
# check lokr (don't worry about lokr_t2 as it is used only in 4D shapes)
elif "lokr_" in lora_key:
if model_key + ".lokr_w1" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
elif model_key + "lokr_w1_b" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
else:
return lora_token_vector_length # unknown format
if model_key + ".lokr_w2" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2"]
elif model_key + "lokr_w2_b" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
else:
return lora_token_vector_length # unknown format
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
elif lora_key == "diff":
lora_token_vector_length = tensor.shape[1]
# ia3 can be detected only by shape[0] in text encoder
elif lora_key == "weight" and "lora_unet_" not in model_key:
lora_token_vector_length = tensor.shape[0]
return lora_token_vector_length
lora_token_vector_length = None
lora_te1_length = None
lora_te2_length = None
for key, tensor in checkpoint.items():
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_te") and "_self_attn_" in key:
tmp_length = _get_shape_1(key, tensor, checkpoint)
if key.startswith("lora_te_"):
lora_token_vector_length = tmp_length
elif key.startswith("lora_te1_"):
lora_te1_length = tmp_length
elif key.startswith("lora_te2_"):
lora_te2_length = tmp_length
if lora_te1_length is not None and lora_te2_length is not None:
lora_token_vector_length = lora_te1_length + lora_te2_length
if lora_token_vector_length is not None:
break
return lora_token_vector_length
def _fast_safetensors_reader(path: str):
checkpoint = dict()
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), "little")
definition_json = f.read(definition_len)
definition = json.loads(definition_json)
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
"pt",
"torch",
"pytorch",
}:
raise Exception("Supported only pytorch safetensors files")
definition.pop("__metadata__", None)
for key, info in definition.items():
dtype = {
"I8": torch.int8,
"I16": torch.int16,
"I32": torch.int32,
"I64": torch.int64,
"F16": torch.float16,
"F32": torch.float32,
"F64": torch.float64,
}[info["dtype"]]
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
return checkpoint
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"):
try:
checkpoint = _fast_safetensors_reader(str(path))
except Exception:
# TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu")
else:
if scan:
scan_result = scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
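Taken together, these two helpers let a caller inspect a LoRA file without loading its weight data. The sketch below is illustrative only; the file path is hypothetical, and the 768/1024/2048 mapping to SD-1/SD-2/SDXL is a conventional assumption rather than something defined in this module:

    # Hypothetical usage: probe a LoRA's base model from its token vector length.
    checkpoint = read_checkpoint_meta("/tmp/some_lora.safetensors", scan=False)
    length = lora_token_vector_length(checkpoint)
    base = {768: "sd-1", 1024: "sd-2", 2048: "sdxl"}.get(length, "unknown")
    print(f"token vector length={length} -> probable base model={base}")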

View File

@ -11,6 +11,7 @@ import logging
import math
import os
import random
import re
from pathlib import Path
from typing import Optional
@ -41,8 +42,8 @@ from transformers import CLIPTextModel, CLIPTokenizer
# invokeai stuff
from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
from invokeai.app.services.model_manager_service import ModelManagerService
from invokeai.backend.model_management.models import SubModelType
from invokeai.app.services.model_manager_service import BaseModelType, ModelManagerService, ModelType
from invokeai.backend.model_manager import SubModelType
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
@ -66,7 +67,6 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.10.0.dev0")
logger = get_logger(__name__)
@ -114,7 +114,6 @@ def parse_args():
general_group.add_argument(
"--output_dir",
type=Path,
default=f"{config.root}/text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
model_group.add_argument(
@ -550,8 +549,11 @@ def do_textual_inversion_training(
local_rank = env_local_rank
# setting up things the way invokeai expects them
output_dir = output_dir or config.root_path / "text-inversion-output"
print(f"output_dir={output_dir}")
if not os.path.isabs(output_dir):
output_dir = os.path.join(config.root, output_dir)
output_dir = Path(config.root, output_dir)
logging_dir = output_dir / logging_dir
@ -564,14 +566,15 @@ def do_textual_inversion_training(
project_config=accelerator_config,
)
model_manager = ModelManagerService(config, logger)
model_manager = ModelManagerService(config)
# The InvokeAI logger already does this...
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# logging.basicConfig(
# format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
# datefmt="%m/%d/%Y %H:%M:%S",
# level=logging.INFO,
# )
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
@ -603,17 +606,30 @@ def do_textual_inversion_training(
elif output_dir is not None:
os.makedirs(output_dir, exist_ok=True)
known_models = model_manager.model_names()
model_name = model.split("/")[-1]
model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None)
assert model_meta is not None, f"Unknown model: {model}"
model_info = model_manager.model_info(*model_meta)
assert model_info["model_format"] == "diffusers", "This script only works with models of type 'diffusers'"
tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer)
noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler)
text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder)
vae_info = model_manager.get_model(*model_meta, submodel=SubModelType.Vae)
unet_info = model_manager.get_model(*model_meta, submodel=SubModelType.UNet)
if len(model) == 32 and re.match(r"^[0-9a-f]+$", model): # looks like a key, not a model name
model_key = model
else:
parts = model.split("/")
if len(parts) == 3:
base_model, model_type, model_name = parts
else:
model_name = parts[-1]
base_model = BaseModelType("sd-1")
model_type = ModelType.Main
models = model_manager.list_models(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
assert len(models) > 0, f"Unknown model: {model}"
assert len(models) < 2, f"More than one model named {model_name}. Please pass its key instead."
model_key = models[0].key
tokenizer_info = model_manager.get_model(model_key, submodel_type=SubModelType.Tokenizer)
noise_scheduler_info = model_manager.get_model(model_key, submodel_type=SubModelType.Scheduler)
text_encoder_info = model_manager.get_model(model_key, submodel_type=SubModelType.TextEncoder)
vae_info = model_manager.get_model(model_key, submodel_type=SubModelType.Vae)
unet_info = model_manager.get_model(model_key, submodel_type=SubModelType.UNet)
pipeline_args = dict(local_files_only=True)
if tokenizer_name:

View File

@ -1,6 +1,8 @@
"""
Initialization file for invokeai.backend.util
"""
from logging import Logger # noqa: F401
from .attention import auto_detect_slice_size # noqa: F401
from .devices import ( # noqa: F401
CPU_DEVICE,
@ -11,4 +13,13 @@ from .devices import ( # noqa: F401
normalize_device,
torch_dtype,
)
from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401
from .logging import InvokeAILogger # noqa: F401
from .util import ( # noqa: F401
GIG,
Chdir,
ask_user,
directory_size,
download_with_resume,
instantiate_from_config,
url_attachment_name,
)

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import platform
from contextlib import nullcontext
from typing import Union
from typing import Literal, Union
import torch
from packaging import version
@ -42,6 +42,13 @@ def choose_precision(device: torch.device) -> str:
return "float32"
def get_precision() -> Literal["float16", "float32"]:
device = torch.device(choose_torch_device())
precision = choose_precision(device) if config.precision == "auto" else config.precision
assert precision in ["float16", "float32"]
return precision
def torch_dtype(device: torch.device) -> torch.dtype:
if config.full_precision:
return torch.float32

View File

@ -180,6 +180,7 @@ import socket
import urllib.parse
from abc import abstractmethod
from pathlib import Path
from typing import Dict
from invokeai.app.services.config import InvokeAIAppConfig
@ -293,7 +294,7 @@ class InvokeAILegacyLogFormatter(InvokeAIFormatter):
}
def log_fmt(self, levelno: int) -> str:
return self.FORMATS.get(levelno)
return self.FORMATS[levelno]
class InvokeAIPlainLogFormatter(InvokeAIFormatter):
@ -332,7 +333,7 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter):
}
def log_fmt(self, levelno: int) -> str:
return self.FORMATS.get(levelno)
return self.FORMATS[levelno]
LOG_FORMATTERS = {
@ -344,17 +345,19 @@ LOG_FORMATTERS = {
class InvokeAILogger(object):
loggers = dict()
loggers: Dict[str, logging.Logger] = dict()
@classmethod
def get_logger(
cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
) -> logging.Logger:
"""Return a logger appropriately configured for the current InvokeAI configuration."""
if name in cls.loggers:
logger = cls.loggers[name]
logger.handlers.clear()
else:
logger = logging.getLogger(name)
config = config or InvokeAIAppConfig.get_config() # in case None is passed
logger.setLevel(config.log_level.upper()) # yes, strings work here
for ch in cls.get_loggers(config):
logger.addHandler(ch)

View File

@ -6,9 +6,10 @@ import pytest
import torch
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.backend.model_management.model_manager import ModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType
from invokeai.app.services.model_install_service import ModelInstallService
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType, UnknownModelException
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad
@pytest.fixture(scope="session")
@ -24,11 +25,16 @@ def model_installer():
# which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround,
# we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using
# a singleton.
return ModelInstall(InvokeAIAppConfig.get_config(log_level="info"))
#
# REPLY(lstein): Don't use get_config() here. Just use the regular pydantic constructor.
#
config = InvokeAIAppConfig(log_level="info")
model_store = ModelRecordServiceBase.open(config)
return ModelInstallService(store=model_store, config=config)
def install_and_load_model(
model_installer: ModelInstall,
model_installer: ModelInstallService,
model_path_id_or_url: Union[str, Path],
model_name: str,
base_model: BaseModelType,
@ -52,15 +58,19 @@ def install_and_load_model(
ModelInfo
"""
# If the requested model is already installed, return its ModelInfo.
with contextlib.suppress(ModelNotFoundException):
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
loader = ModelLoad(config=model_installer.config, store=model_installer.store)
with contextlib.suppress(UnknownModelException):
model = model_installer.store.model_info_by_name(model_name, base_model, model_type)
return loader.get_model(model.key, submodel_type)
# Install the requested model.
model_installer.heuristic_import(model_path_id_or_url)
model_installer.install(model_path_id_or_url)
model_installer.wait_for_installs()
try:
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
except ModelNotFoundException as e:
model = model_installer.store.model_info_by_name(model_name, base_model, model_type)
return loader.get_model(model.key, submodel_type)
except UnknownModelException as e:
raise Exception(
"Failed to get model info after installing it. There could be a mismatch between the requested model and"
f" the installation id ('{model_path_id_or_url}'). Error: {e}"

View File

@ -2,14 +2,11 @@ import base64
import importlib
import io
import math
import multiprocessing as mp
import os
import re
from collections import abc
from inspect import isfunction
from pathlib import Path
from queue import Queue
from threading import Thread
from typing import Optional
import numpy as np
import requests
@ -21,6 +18,9 @@ import invokeai.backend.util.logging as logger
from .devices import torch_dtype
# actual size of a gig
GIG = 1073741824
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
@ -101,112 +101,6 @@ def get_obj_from_str(string, reload=False):
return getattr(importlib.import_module(module, package=None), cls)
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
# create dummy dataset instance
# run prefetching
if idx_to_fn:
res = func(data, worker_id=idx)
else:
res = func(data)
Q.put([idx, res])
Q.put("Done")
def parallel_data_prefetch(
func: callable,
data,
n_proc,
target_data_type="ndarray",
cpu_intensive=True,
use_worker_id=False,
):
# if target_data_type not in ["ndarray", "list"]:
# raise ValueError(
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
# )
if isinstance(data, np.ndarray) and target_data_type == "list":
raise ValueError("list expected but function got ndarray.")
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
logger.warning(
'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data = list(data.values())
if target_data_type == "ndarray":
data = np.asarray(data)
else:
data = list(data)
else:
raise TypeError(
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
)
if cpu_intensive:
Q = mp.Queue(1000)
proc = mp.Process
else:
Q = Queue(1000)
proc = Thread
# spawn processes
if target_data_type == "ndarray":
arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))]
else:
step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc)
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate([data[i : i + step] for i in range(0, len(data), step)])
]
processes = []
for i in range(n_proc):
p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
processes += [p]
# start processes
logger.info("Start prefetching...")
import time
start = time.time()
gather_res = [[] for _ in range(n_proc)]
try:
for p in processes:
p.start()
k = 0
while k < n_proc:
# get result
res = Q.get()
if res == "Done":
k += 1
else:
gather_res[res[0]] = res[1]
except Exception as e:
logger.error("Exception: ", e)
for p in processes:
p.terminate()
raise e
finally:
for p in processes:
p.join()
logger.info(f"Prefetching complete. [{time.time() - start} sec.]")
if target_data_type == "ndarray":
if not isinstance(gather_res[0], np.ndarray):
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
# order outputs
return np.concatenate(gather_res, axis=0)
elif target_data_type == "list":
out = []
for r in gather_res:
out.extend(r)
return out
else:
return gather_res
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
@ -269,7 +163,7 @@ def ask_user(question: str, answers: list):
# -------------------------------------
def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
def download_with_resume(url: str, dest: Path, access_token: str = None) -> Optional[Path]:
"""
Download a model file.
:param url: https, http or ftp URL
@ -286,10 +180,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
content_length = int(resp.headers.get("content-length", 0))
if dest.is_dir():
try:
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
except AttributeError:
file_name = os.path.basename(url)
file_name = response_attachment(resp) or os.path.basename(url)
dest = dest / file_name
else:
dest.parent.mkdir(parents=True, exist_ok=True)
@ -338,15 +229,24 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
return dest
def url_attachment_name(url: str) -> dict:
def response_attachment(response: requests.Response) -> Optional[str]:
try:
resp = requests.get(url, stream=True)
match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
return match.group(1)
if disposition := response.headers.get("Content-Disposition"):
if match := re.search('filename="(.+)"', disposition):
return match.group(1)
return None
except Exception:
return None
def url_attachment_name(url: str) -> Optional[str]:
resp = requests.get(url)
if resp.ok:
return response_attachment(resp)
else:
return None
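A minimal usage sketch for the two helpers above; the URL is simply the Civitai download link that already appears in INITIAL_MODELS.yaml and is used here only as an example:

    # Hypothetical example: ask the server what filename it would attach.
    name = url_attachment_name("https://civitai.com/api/download/models/63006")
    print(name or "no Content-Disposition filename")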
def download_with_progress_bar(url: str, dest: Path) -> bool:
result = download_with_resume(url, dest, access_token=None)
return result is not None
@ -363,6 +263,19 @@ def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
return image_base64
def directory_size(directory: Path) -> int:
"""
Return the aggregate size (in bytes) of all files and directory entries under a directory.
"""
sum = 0
for root, dirs, files in os.walk(directory):
for f in files:
sum += Path(root, f).stat().st_size
for d in dirs:
sum += Path(root, d).stat().st_size
return sum
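A quick sketch (with a hypothetical path) showing how directory_size() pairs with the GIG constant defined near the top of this module:

    # Report the size of a models directory in GiB.
    models_dir = Path("/opt/invokeai/models")
    print(f"{directory_size(models_dir) / GIG:.2f} GiB")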
class Chdir(object):
"""Context manager to chdir to desired directory and change back after context exits:
Args:

View File

@ -1,156 +1,157 @@
# This file predefines a few models that the user may want to install.
sd-1/main/stable-diffusion-v1-5:
description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
repo_id: runwayml/stable-diffusion-v1-5
source: runwayml/stable-diffusion-v1-5
recommended: True
default: True
sd-1/main/stable-diffusion-v1-5-inpainting:
description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
repo_id: runwayml/stable-diffusion-inpainting
source: runwayml/stable-diffusion-inpainting
recommended: True
sd-2/main/stable-diffusion-2-1:
description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
repo_id: stabilityai/stable-diffusion-2-1
source: stabilityai/stable-diffusion-2-1
recommended: False
sd-2/main/stable-diffusion-2-inpainting:
description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
repo_id: stabilityai/stable-diffusion-2-inpainting
source: stabilityai/stable-diffusion-2-inpainting
recommended: False
sdxl/main/stable-diffusion-xl-base-1-0:
description: Stable Diffusion XL base model (12 GB)
repo_id: stabilityai/stable-diffusion-xl-base-1.0
source: stabilityai/stable-diffusion-xl-base-1.0
recommended: True
sdxl-refiner/main/stable-diffusion-xl-refiner-1-0:
description: Stable Diffusion XL refiner model (12 GB)
repo_id: stabilityai/stable-diffusion-xl-refiner-1.0
source: stabilityai/stable-diffusion-xl-refiner-1.0
recommended: False
sdxl/vae/sdxl-1-0-vae-fix:
description: Fine tuned version of the SDXL-1.0 VAE
repo_id: madebyollin/sdxl-vae-fp16-fix
sdxl/vae/sdxl-vae-fp16-fix:
description: Version of the SDXL-1.0 VAE that works in half precision mode
source: madebyollin/sdxl-vae-fp16-fix
recommended: True
sd-1/main/Analog-Diffusion:
description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
repo_id: wavymulder/Analog-Diffusion
source: wavymulder/Analog-Diffusion
recommended: False
sd-1/main/Deliberate:
description: Versatile model that produces detailed images up to 768px (4.27 GB)
repo_id: XpucT/Deliberate
source: XpucT/Deliberate
recommended: False
sd-1/main/Dungeons-and-Diffusion:
description: Dungeons & Dragons characters (2.13 GB)
repo_id: 0xJustin/Dungeons-and-Diffusion
source: 0xJustin/Dungeons-and-Diffusion
recommended: False
sd-1/main/dreamlike-photoreal-2:
description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
repo_id: dreamlike-art/dreamlike-photoreal-2.0
source: dreamlike-art/dreamlike-photoreal-2.0
recommended: False
sd-1/main/Inkpunk-Diffusion:
description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)
repo_id: Envvi/Inkpunk-Diffusion
source: Envvi/Inkpunk-Diffusion
recommended: False
sd-1/main/openjourney:
description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)
repo_id: prompthero/openjourney
source: prompthero/openjourney
recommended: False
sd-1/main/seek.art_MEGA:
repo_id: coreco/seek.art_MEGA
source: coreco/seek.art_MEGA
description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
recommended: False
sd-1/main/trinart_stable_diffusion_v2:
description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
repo_id: naclbit/trinart_stable_diffusion_v2
source: naclbit/trinart_stable_diffusion_v2
recommended: False
sd-1/controlnet/qrcode_monster:
repo_id: monster-labs/control_v1p_sd15_qrcode_monster
source: monster-labs/control_v1p_sd15_qrcode_monster
subfolder: v2
sd-1/controlnet/canny:
repo_id: lllyasviel/control_v11p_sd15_canny
source: lllyasviel/control_v11p_sd15_canny
recommended: True
sd-1/controlnet/inpaint:
repo_id: lllyasviel/control_v11p_sd15_inpaint
source: lllyasviel/control_v11p_sd15_inpaint
sd-1/controlnet/mlsd:
repo_id: lllyasviel/control_v11p_sd15_mlsd
source: lllyasviel/control_v11p_sd15_mlsd
sd-1/controlnet/depth:
repo_id: lllyasviel/control_v11f1p_sd15_depth
source: lllyasviel/control_v11f1p_sd15_depth
recommended: True
sd-1/controlnet/normal_bae:
repo_id: lllyasviel/control_v11p_sd15_normalbae
source: lllyasviel/control_v11p_sd15_normalbae
sd-1/controlnet/seg:
repo_id: lllyasviel/control_v11p_sd15_seg
source: lllyasviel/control_v11p_sd15_seg
sd-1/controlnet/lineart:
repo_id: lllyasviel/control_v11p_sd15_lineart
source: lllyasviel/control_v11p_sd15_lineart
recommended: True
sd-1/controlnet/lineart_anime:
repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
source: lllyasviel/control_v11p_sd15s2_lineart_anime
sd-1/controlnet/openpose:
repo_id: lllyasviel/control_v11p_sd15_openpose
source: lllyasviel/control_v11p_sd15_openpose
recommended: True
sd-1/controlnet/scribble:
repo_id: lllyasviel/control_v11p_sd15_scribble
source: lllyasviel/control_v11p_sd15_scribble
recommended: False
sd-1/controlnet/softedge:
repo_id: lllyasviel/control_v11p_sd15_softedge
source: lllyasviel/control_v11p_sd15_softedge
sd-1/controlnet/shuffle:
repo_id: lllyasviel/control_v11e_sd15_shuffle
source: lllyasviel/control_v11e_sd15_shuffle
sd-1/controlnet/tile:
repo_id: lllyasviel/control_v11f1e_sd15_tile
source: lllyasviel/control_v11f1e_sd15_tile
sd-1/controlnet/ip2p:
repo_id: lllyasviel/control_v11e_sd15_ip2p
source: lllyasviel/control_v11e_sd15_ip2p
sd-1/t2i_adapter/canny-sd15:
repo_id: TencentARC/t2iadapter_canny_sd15v2
source: TencentARC/t2iadapter_canny_sd15v2
sd-1/t2i_adapter/sketch-sd15:
repo_id: TencentARC/t2iadapter_sketch_sd15v2
source: TencentARC/t2iadapter_sketch_sd15v2
sd-1/t2i_adapter/depth-sd15:
repo_id: TencentARC/t2iadapter_depth_sd15v2
source: TencentARC/t2iadapter_depth_sd15v2
sd-1/t2i_adapter/zoedepth-sd15:
repo_id: TencentARC/t2iadapter_zoedepth_sd15v1
source: TencentARC/t2iadapter_zoedepth_sd15v1
sdxl/t2i_adapter/canny-sdxl:
repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0
source: TencentARC/t2i-adapter-canny-sdxl-1.0
sdxl/t2i_adapter/zoedepth-sdxl:
repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0
source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0
sdxl/t2i_adapter/lineart-sdxl:
repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0
source: TencentARC/t2i-adapter-lineart-sdxl-1.0
sdxl/t2i_adapter/sketch-sdxl:
repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0
source: TencentARC/t2i-adapter-sketch-sdxl-1.0
sd-1/embedding/EasyNegative:
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
recommended: True
sd-1/embedding/ahx-beta-453407d:
repo_id: sd-concepts-library/ahx-beta-453407d
description: A textual inversion to use in the negative prompt to reduce bad anatomy
sd-1/lora/LowRA:
path: https://civitai.com/api/download/models/63006
source: https://civitai.com/api/download/models/63006
recommended: True
description: An embedding that helps generate low-light images
sd-1/lora/Ink scenery:
path: https://civitai.com/api/download/models/83390
source: https://civitai.com/api/download/models/83390
description: Generates India ink-like landscapes
sd-1/ip_adapter/ip_adapter_sd15:
repo_id: InvokeAI/ip_adapter_sd15
source: InvokeAI/ip_adapter_sd15
recommended: True
requires:
- InvokeAI/ip_adapter_sd_image_encoder
description: IP-Adapter for SD 1.5 models
sd-1/ip_adapter/ip_adapter_plus_sd15:
repo_id: InvokeAI/ip_adapter_plus_sd15
source: InvokeAI/ip_adapter_plus_sd15
recommended: False
requires:
- InvokeAI/ip_adapter_sd_image_encoder
description: Refined IP-Adapter for SD 1.5 models
sd-1/ip_adapter/ip_adapter_plus_face_sd15:
repo_id: InvokeAI/ip_adapter_plus_face_sd15
source: InvokeAI/ip_adapter_plus_face_sd15
recommended: False
requires:
- InvokeAI/ip_adapter_sd_image_encoder
description: Refined IP-Adapter for SD 1.5 models, adapted for faces
sdxl/ip_adapter/ip_adapter_sdxl:
repo_id: InvokeAI/ip_adapter_sdxl
source: InvokeAI/ip_adapter_sdxl
recommended: False
requires:
- InvokeAI/ip_adapter_sdxl_image_encoder
description: IP-Adapter for SDXL models
any/clip_vision/ip_adapter_sd_image_encoder:
repo_id: InvokeAI/ip_adapter_sd_image_encoder
source: InvokeAI/ip_adapter_sd_image_encoder
recommended: False
description: Required model for using IP-Adapters with SD-1/2 models
any/clip_vision/ip_adapter_sdxl_image_encoder:
repo_id: InvokeAI/ip_adapter_sdxl_image_encoder
source: InvokeAI/ip_adapter_sdxl_image_encoder
recommended: False
description: Required model for using IP-Adapters with SDXL models
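The pattern for new entries after this change is a base/type/name key plus a source that may be either a HuggingFace repo_id or a direct URL; a hypothetical example (not part of the shipped file):

    sd-1/lora/my-custom-lora:
      source: https://example.com/my-custom-lora.safetensors
      description: Example user-added entry using the new source field
      recommended: False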

View File

@ -0,0 +1,80 @@
model:
base_learning_rate: 7.5e-05
target: invokeai.backend.models.diffusion.ddpm.LatentInpaintDiffusion
params:
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: hybrid # important
monitor: val/loss_simple_ema
scale_factor: 0.18215
finetune_keys: null
scheduler_config: # 10000 warmup steps
target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
personalization_config:
target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ['sculpture']
per_image_tokens: false
num_vectors_per_token: 8
progressive_words: False
unet_config:
target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder

View File

@ -6,28 +6,29 @@
"""
This is the npyscreen frontend to the model installation application.
The work is actually done in backend code in model_install_backend.py.
"""
import argparse
import curses
import logging
import sys
import textwrap
import traceback
from argparse import Namespace
from multiprocessing import Process
from multiprocessing.connection import Connection, Pipe
from dataclasses import dataclass, field
from pathlib import Path
from shutil import get_terminal_size
from typing import Dict, List, Optional, Tuple
import npyscreen
import omegaconf
import torch
from npyscreen import widget
from pydantic import BaseModel
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, SchedulerPredictionType
from invokeai.backend.model_management import ModelManager, ModelType
from invokeai.app.services.model_install_service import ModelInstallJob, ModelInstallService
from invokeai.backend.install.install_helper import InstallHelper, UnifiedModelInfo
from invokeai.backend.model_manager import BaseModelType, ModelType
from invokeai.backend.util import choose_precision, choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.widgets import (
@ -40,7 +41,6 @@ from invokeai.frontend.install.widgets import (
SingleSelectColumns,
TextBox,
WindowTooSmallException,
select_stable_diffusion_config_file,
set_min_terminal_size,
)
@ -56,12 +56,20 @@ NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(
MAX_OTHER_MODELS = 72
@dataclass
class InstallSelections:
install_models: List[UnifiedModelInfo] = field(default_factory=list)
remove_models: List[str] = field(default_factory=list)
def make_printable(s: str) -> str:
"""Replace non-printable characters in a string"""
"""Replace non-printable characters in a string."""
return s.translate(NOPRINT_TRANS_TABLE)
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
"""Main form for interactive TUI."""
# for responsive resizing set to False, but this seems to cause a crash!
FIX_MINIMUM_SIZE_WHEN_CREATED = True
@ -74,17 +82,12 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
super().__init__(parentApp=parentApp, name=name, *args, **keywords)
def create(self):
self.installer = self.parentApp.install_helper.installer
self.model_labels = self._get_model_labels()
self.keypress_timeout = 10
self.counter = 0
self.subprocess_connection = None
if not config.model_conf_path.exists():
with open(config.model_conf_path, "w") as file:
print("# InvokeAI model configuration file", file=file)
self.installer = ModelInstall(config)
self.all_models = self.installer.all_models()
self.starter_models = self.installer.starter_models()
self.model_labels = self._get_model_labels()
window_width, window_height = get_terminal_size()
self.nextrely -= 1
@ -161,15 +164,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.nextrely = bottom_of_table + 1
self.monitor = self.add_widget_intelligent(
BufferBox,
name="Log Messages",
editable=False,
max_height=6,
)
self.nextrely += 1
done_label = "APPLY CHANGES"
back_label = "BACK"
cancel_label = "CANCEL"
current_position = self.nextrely
@ -185,14 +180,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel
)
self.nextrely = current_position
self.ok_button = self.add_widget_intelligent(
npyscreen.ButtonPress,
name=done_label,
relx=(window_width - len(done_label)) // 2,
when_pressed_function=self.on_execute,
)
label = "APPLY CHANGES & EXIT"
label = "APPLY CHANGES"
self.nextrely = current_position
self.done = self.add_widget_intelligent(
npyscreen.ButtonPress,
@ -210,16 +199,15 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
"""Add widgets responsible for selecting diffusers models"""
widgets = dict()
models = self.all_models
starters = self.starter_models
starter_model_labels = self.model_labels
self.installed_models = sorted([x for x in starters if models[x].installed])
all_models = self.all_models # master dict of all models, indexed by key
model_list = [x for x in self.starter_models if all_models[x].model_type in ["main", "vae"]]
model_labels = [self.model_labels[x] for x in model_list]
widgets.update(
label1=self.add_widget_intelligent(
CenteredTitleText,
name="Select from a starter set of Stable Diffusion models from HuggingFace.",
name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.",
editable=False,
labelColor="CAUTION",
)
@ -229,23 +217,24 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
# if user has already installed some initial models, then don't patronize them
# by showing more recommendations
show_recommended = len(self.installed_models) == 0
keys = [x for x in models.keys() if x in starters]
checked = [
model_list.index(x)
for x in model_list
if (show_recommended and all_models[x].recommended) or all_models[x].installed
]
widgets.update(
models_selected=self.add_widget_intelligent(
MultiSelectColumns,
columns=1,
name="Install Starter Models",
values=[starter_model_labels[x] for x in keys],
value=[
keys.index(x)
for x in keys
if (show_recommended and models[x].recommended) or (x in self.installed_models)
],
max_height=len(starters) + 1,
values=model_labels,
value=checked,
max_height=len(model_list) + 1,
relx=4,
scroll_exit=True,
),
models=keys,
models=model_list,
)
self.nextrely += 1
@ -261,7 +250,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
) -> dict[str, npyscreen.widget]:
"""Generic code to create model selection widgets"""
widgets = dict()
model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and x not in exclude]
all_models = self.all_models
model_list = [x for x in all_models if all_models[x].model_type == model_type and x not in exclude]
model_labels = [self.model_labels[x] for x in model_list]
show_recommended = len(self.installed_models) == 0
@ -297,7 +287,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
value=[
model_list.index(x)
for x in model_list
if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed
if (show_recommended and all_models[x].recommended) or all_models[x].installed
],
max_height=len(model_list) // columns + 1,
relx=4,
@ -321,7 +311,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
download_ids=self.add_widget_intelligent(
TextBox,
name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):",
max_height=4,
max_height=6,
scroll_exit=True,
editable=True,
)
@ -349,8 +339,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
def resize(self):
super().resize()
if s := self.starter_pipelines.get("models_selected"):
keys = [x for x in self.all_models.keys() if x in self.starter_models]
s.values = [self.model_labels[x] for x in keys]
s.values = [self.model_labels[x] for x in self.starter_pipelines.get("models")]
def _toggle_tables(self, value=None):
selected_tab = value[0]
@ -382,17 +371,18 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.display()
def _get_model_labels(self) -> dict[str, str]:
"""Return a list of trimmed labels for all models."""
window_width, window_height = get_terminal_size()
checkbox_width = 4
spacing_width = 2
result = dict()
models = self.all_models
label_width = max([len(models[x].name) for x in models])
label_width = max([len(models[x].name) for x in self.starter_models])
description_width = window_width - label_width - checkbox_width - spacing_width
result = dict()
for x in models.keys():
description = models[x].description
for key in self.all_models:
description = models[key].description
description = (
description[0 : description_width - 3] + "..."
if description and len(description) > description_width
@ -400,7 +390,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
if description
else ""
)
result[x] = f"%-{label_width}s %s" % (models[x].name, description)
result[key] = f"%-{label_width}s %s" % (models[key].name, description)
return result
def _get_columns(self) -> int:
@ -411,38 +402,24 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
def confirm_deletions(self, selections: InstallSelections) -> bool:
remove_models = selections.remove_models
if len(remove_models) > 0:
mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models])
mods = "\n".join([self.all_models[x].name for x in remove_models])
return npyscreen.notify_ok_cancel(
f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}"
)
else:
return True
def on_execute(self):
self.marshall_arguments()
app = self.parentApp
if not self.confirm_deletions(app.install_selections):
return
@property
def all_models(self) -> Dict[str, UnifiedModelInfo]:
return self.parentApp.install_helper.all_models
self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True)
self.ok_button.hidden = True
self.display()
@property
def starter_models(self) -> List[str]:
return self.parentApp.install_helper._starter_models
# TO DO: Spawn a worker thread, not a subprocess
parent_conn, child_conn = Pipe()
p = Process(
target=process_and_execute,
kwargs=dict(
opt=app.program_opts,
selections=app.install_selections,
conn_out=child_conn,
),
)
p.start()
child_conn.close()
self.subprocess_connection = parent_conn
self.subprocess = p
app.install_selections = InstallSelections()
@property
def installed_models(self) -> List[str]:
return self.parentApp.install_helper._installed_models
def on_back(self):
self.parentApp.switchFormPrevious()
@ -461,76 +438,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.parentApp.user_cancelled = False
self.editing = False
########## This routine monitors the child process that is performing model installation and removal #####
def while_waiting(self):
"""Called during idle periods. Main task is to update the Log Messages box with messages
from the child process that does the actual installation/removal"""
c = self.subprocess_connection
if not c:
return
monitor_widget = self.monitor.entry_widget
while c.poll():
try:
data = c.recv_bytes().decode("utf-8")
data.strip("\n")
# processing child is requesting user input to select the
# right configuration file
if data.startswith("*need v2 config"):
_, model_path, *_ = data.split(":", 2)
self._return_v2_config(model_path)
# processing child is done
elif data == "*done*":
self._close_subprocess_and_regenerate_form()
break
# update the log message box
else:
data = make_printable(data)
data = data.replace("[A", "")
monitor_widget.buffer(
textwrap.wrap(
data,
width=monitor_widget.width,
subsequent_indent=" ",
),
scroll_end=True,
)
self.display()
except (EOFError, OSError):
self.subprocess_connection = None
def _return_v2_config(self, model_path: str):
c = self.subprocess_connection
model_name = Path(model_path).name
message = select_stable_diffusion_config_file(model_name=model_name)
c.send_bytes(message.encode("utf-8"))
def _close_subprocess_and_regenerate_form(self):
app = self.parentApp
self.subprocess_connection.close()
self.subprocess_connection = None
self.monitor.entry_widget.buffer(["** Action Complete **"])
self.display()
# rebuild the form, saving and restoring some of the fields that need to be preserved.
saved_messages = self.monitor.entry_widget.values
app.main_form = app.addForm(
"MAIN",
addModelsForm,
name="Install Stable Diffusion Models",
multipage=self.multipage,
)
app.switchForm("MAIN")
app.main_form.monitor.entry_widget.values = saved_messages
app.main_form.monitor.entry_widget.buffer([""], scroll_end=True)
# app.main_form.pipeline_models['autoload_directory'].value = autoload_dir
# app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan
def marshall_arguments(self):
"""
Assemble arguments and store as attributes of the application:
@ -561,16 +468,13 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
models_to_install = [x for x in selected if not self.all_models[x].installed]
models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
selections.remove_models.extend(models_to_remove)
selections.install_models.extend(
all_models[x].path or all_models[x].repo_id
for x in models_to_install
if all_models[x].path or all_models[x].repo_id
)
selections.install_models.extend([all_models[x] for x in models_to_install])
# models located in the 'download_ids" section
for section in ui_sections:
if downloads := section.get("download_ids"):
selections.install_models.extend(downloads.value.split())
models = [UnifiedModelInfo(source=x) for x in downloads.value.split()]
selections.install_models.extend(models)
# NOT NEEDED - DONE IN BACKEND NOW
# # special case for the ipadapter_models. If any of the adapters are
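As the replacement lines above show, marshall_arguments() now hands InstallSelections a list of model objects rather than bare paths or repo_ids: checkbox selections pass their UnifiedModelInfo records straight through, and anything typed into the download_ids box is wrapped in UnifiedModelInfo(source=...). A rough sketch of the resulting object, assuming an import location and example sources (both are assumptions, not taken from this diff):

# Assumed import path; the diff only shows these names being used, not where they are defined.
from invokeai.backend.install.install_helper import InstallSelections, UnifiedModelInfo

download_ids = "stabilityai/sd-vae-ft-mse https://example.com/some-model.safetensors"  # illustrative user input
selections = InstallSelections(
    install_models=[UnifiedModelInfo(source=x) for x in download_ids.split()],
    remove_models=[],  # names/keys of installed models that were un-checked
)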
@@ -593,12 +497,12 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
class AddModelApplication(npyscreen.NPSAppManaged):
def __init__(self, opt):
def __init__(self, opt: Namespace, install_helper: InstallHelper):
super().__init__()
self.program_opts = opt
self.user_cancelled = False
# self.autoload_pending = True
self.install_selections = InstallSelections()
self.install_helper = install_helper
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
@@ -610,136 +514,55 @@ class AddModelApplication(npyscreen.NPSAppManaged):
)
class StderrToMessage:
def __init__(self, connection: Connection):
self.connection = connection
def write(self, data: str):
self.connection.send_bytes(data.encode("utf-8"))
def flush(self):
pass
# --------------------------------------------------------
def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType:
if tui_conn:
logger.debug("Waiting for user response...")
return _ask_user_for_pt_tui(model_path, tui_conn)
else:
return _ask_user_for_pt_cmdline(model_path)
def _ask_user_for_pt_cmdline(model_path: Path) -> SchedulerPredictionType:
choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
print(
f"""
Please select the type of the V2 checkpoint named {model_path.name}:
[1] A model based on Stable Diffusion v2 trained on 512 pixel images (SD-2-base)
[2] A model based on Stable Diffusion v2 trained on 768 pixel images (SD-2-768)
[3] Skip this model and come back later.
"""
)
choice = None
ok = False
while not ok:
try:
choice = input("select> ").strip()
choice = choices[int(choice) - 1]
ok = True
except (ValueError, IndexError):
print(f"{choice} is not a valid choice")
except EOFError:
return
return choice
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType:
try:
tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8"))
# note that we don't do any status checking here
response = tui_conn.recv_bytes().decode("utf-8")
if response is None:
return None
elif response == "epsilon":
return SchedulerPredictionType.Epsilon
elif response == "v":
return SchedulerPredictionType.VPrediction
elif response == "abort":
logger.info("Conversion aborted")
return None
else:
return response
except Exception:
return None
# --------------------------------------------------------
def process_and_execute(
opt: Namespace,
selections: InstallSelections,
conn_out: Connection = None,
):
# need to reinitialize config in subprocess
config = InvokeAIAppConfig.get_config()
args = ["--root", opt.root] if opt.root else []
config.parse_args(args)
# set up so that stderr is sent to conn_out
if conn_out:
translator = StderrToMessage(conn_out)
sys.stderr = translator
sys.stdout = translator
logger = InvokeAILogger.get_logger()
logger.handlers.clear()
logger.addHandler(logging.StreamHandler(translator))
installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out))
installer.install(selections)
if conn_out:
conn_out.send_bytes("*done*".encode("utf-8"))
conn_out.close()
def list_models(installer: ModelInstallService, model_type: ModelType):
"""Print out all models of type model_type."""
models = installer.store.search_by_name(model_type=model_type)
print(f"Installed models of type `{model_type}`:")
for model in models:
path = (config.models_path / model.path).resolve()
print(f"{model.name:40}{model.base_model.value:14}{path}")
# --------------------------------------------------------
def select_and_download_models(opt: Namespace):
"""Prompt user for install/delete selections and execute."""
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
config.precision = precision
installer = ModelInstall(config, prediction_type_helper=ask_user_for_prediction_type)
install_helper = InstallHelper(config)
installer = install_helper.installer
if opt.list_models:
installer.list_models(opt.list_models)
list_models(installer, opt.list_models)
elif opt.add or opt.delete:
selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or [])
installer.install(selections)
selections = InstallSelections(
install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or []
)
install_helper.add_or_delete(selections)
elif opt.default_only:
selections = InstallSelections(install_models=installer.default_model())
installer.install(selections)
selections = InstallSelections(install_models=[initial_models.default_model()])
install_helper.add_or_delete(selections)
elif opt.yes_to_all:
selections = InstallSelections(install_models=installer.recommended_models())
installer.install(selections)
selections = InstallSelections(install_models=initial_models.recommended_models())
install_helper.add_or_delete(selections)
# this is where the TUI is called
else:
# needed to support the probe() method running under a subprocess
torch.multiprocessing.set_start_method("spawn")
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
raise WindowTooSmallException(
"Could not increase terminal size. Try running again with a larger window or smaller font size."
)
installApp = AddModelApplication(opt)
installApp = AddModelApplication(opt, install_helper)
try:
installApp.run()
except KeyboardInterrupt as e:
if hasattr(installApp, "main_form"):
if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive():
logger.info("Terminating subprocesses")
installApp.main_form.subprocess.terminate()
installApp.main_form.subprocess = None
raise e
process_and_execute(opt, installApp.install_selections)
print("Aborted...")
sys.exit(-1)
install_helper.add_or_delete(installApp.install_selections)
# -------------------------------------
@@ -753,7 +576,7 @@ def main():
parser.add_argument(
"--delete",
nargs="*",
help="List of names of models to idelete",
help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`",
)
parser.add_argument(
"--full-precision",
@@ -780,14 +603,6 @@ def main():
choices=[x.value for x in ModelType],
help="list installed models",
)
parser.add_argument(
"--config_file",
"-c",
dest="config_file",
type=str,
default=None,
help="path to configuration file to create",
)
parser.add_argument(
"--root_dir",
dest="root",


@@ -19,7 +19,7 @@ from npyscreen import fmPopup
# minimum size for UIs
MIN_COLS = 150
MIN_LINES = 40
MIN_LINES = 45
class WindowTooSmallException(Exception):
@@ -264,6 +264,17 @@ class SingleSelectWithChanged(npyscreen.SelectOne):
self.on_changed(self.value)
class CheckboxWithChanged(npyscreen.Checkbox):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.on_changed = None
def whenToggled(self):
super().whenToggled()
if self.on_changed:
self.on_changed(self.value)
class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged):
"""Row of radio buttons. Spacebar to select."""


@@ -6,21 +6,36 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
import curses
import re
import sys
from argparse import Namespace
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple
import npyscreen
from npyscreen import widget
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import BaseModelType, ModelManager, ModelMerger, ModelType
from invokeai.backend.model_manager import (
BaseModelType,
ModelConfigStore,
ModelFormat,
ModelType,
ModelVariantType,
get_config_store,
)
from invokeai.backend.model_manager.merge import ModelMerger
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
config = InvokeAIAppConfig.get_config()
BASE_TYPES = [
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
]
def _parse_args() -> Namespace:
parser = argparse.ArgumentParser(description="InvokeAI model merging")
@@ -48,7 +63,7 @@ def _parse_args() -> Namespace:
parser.add_argument(
"--base_model",
type=str,
choices=[x.value for x in BaseModelType],
choices=[x[0].value for x in BASE_TYPES],
help="The base model shared by the models to be merged",
)
parser.add_argument(
@@ -106,9 +121,9 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
def create(self):
window_height, window_width = curses.initscr().getmaxyx()
self.model_names = self.get_model_names()
self.current_base = 0
self.models = self.get_models(BASE_TYPES[self.current_base][0])
self.model_names = [x[1] for x in self.models]
max_width = max([len(x) for x in self.model_names])
max_width += 6
horizontal_layout = max_width * 3 < window_width
@@ -128,10 +143,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
self.nextrely += 1
self.base_select = self.add_widget_intelligent(
SingleSelectColumns,
values=[
"Models Built on SD-1.x",
"Models Built on SD-2.x",
],
values=[x[1] for x in BASE_TYPES],
value=[self.current_base],
columns=4,
max_height=2,
@@ -262,19 +274,19 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
sys.exit(0)
def marshall_arguments(self) -> dict:
model_names = self.model_names
model_keys = [x[0] for x in self.models]
models = [
model_names[self.model1.value[0]],
model_names[self.model2.value[0]],
model_keys[self.model1.value[0]],
model_keys[self.model2.value[0]],
]
if self.model3.value[0] > 0:
models.append(model_names[self.model3.value[0] - 1])
models.append(model_keys[self.model3.value[0] - 1])
interp = "add_difference"
else:
interp = self.interpolations[self.merge_method.value[0]]
args = dict(
model_names=models,
model_keys=models,
base_model=tuple(BaseModelType)[self.base_select.value[0]],
alpha=self.alpha.value,
interp=interp,
@@ -309,17 +321,18 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else:
return True
def get_model_names(self, base_model: Optional[BaseModelType] = None) -> List[str]:
model_names = [
info["model_name"]
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
if info["model_format"] == "diffusers"
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
models = [
(x.key, x.name)
for x in self.model_manager.search_by_name(model_type=ModelType.Main, base_model=base_model)
if x.model_format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
]
return sorted(model_names)
return sorted(models, key=lambda x: x[1])
def _populate_models(self, value=None):
base_model = tuple(BaseModelType)[value[0]]
self.model_names = self.get_model_names(base_model)
def _populate_models(self, value: List[int]):
base_model = BASE_TYPES[value[0]][0]
self.models = self.get_models(base_model)
self.model_names = [x[1] for x in self.models]
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
@@ -331,7 +344,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self, model_manager: ModelManager):
def __init__(self, model_manager: ModelConfigStore):
super().__init__()
self.model_manager = model_manager
@@ -341,14 +354,13 @@ class Mergeapp(npyscreen.NPSAppManaged):
def run_gui(args: Namespace):
model_manager = ModelManager(config.model_conf_path)
model_manager: ModelConfigStore = get_config_store(config.model_conf_path)
mergeapp = Mergeapp(model_manager)
mergeapp.run()
args = mergeapp.merge_arguments
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**args)
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
merger = ModelMerger(model_manager, config)
merger.merge_diffusion_models_and_save(**vars(args))
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
def run_cli(args: Namespace):
@@ -361,13 +373,31 @@ def run_cli(args: Namespace):
args.merged_model_name = "+".join(args.model_names)
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
model_manager = ModelManager(config.model_conf_path)
model_manager: ModelConfigStore = get_config_store(config.model_conf_path)
assert (
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
len(model_manager.search_by_name(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**vars(args))
model_keys = []
for name in args.model_names:
if len(name) == 32 and re.match(r"^[0-9a-f]+$", name):
model_keys.append(name)
else:
models = model_manager.search_by_name(
model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
)
assert len(models) > 0, f"{name}: Unknown model"
assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
model_keys.append(models[0].key)
merger.merge_diffusion_models_and_save(
alpha=args.alpha,
model_keys=model_keys,
merged_model_name=args.merged_model_name,
interp=args.interp,
force=args.force,
)
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
@@ -375,6 +405,8 @@ def main():
args = _parse_args()
if args.root_dir:
config.parse_args(["--root", str(args.root_dir)])
else:
config.parse_args([])
try:
if args.front_end:


@@ -22,6 +22,7 @@ from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import ModelConfigStore, ModelType, get_config_store
from ...backend.training import do_textual_inversion_training, parse_args
@@ -275,10 +276,13 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
return True
def get_model_names(self) -> Tuple[List[str], int]:
conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
model_names = [idx for idx in sorted(list(conf.keys())) if conf[idx].get("format", None) == "diffusers"]
defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]]
default = defaults[0] if len(defaults) > 0 else 0
global config
store: ModelConfigStore = get_config_store(config.model_conf_path)
main_models = store.search_by_name(model_type=ModelType.Main)
model_names = [
f"{x.base_model.value}/{x.model_type.value}/{x.name}" for x in main_models if x.model_format == "diffusers"
]
default = 0
return (model_names, default)
def marshall_arguments(self) -> dict:
@@ -384,6 +388,7 @@ def previous_args() -> dict:
def do_front_end(args: Namespace):
global config
saved_args = previous_args()
myapplication = MyApplication(saved_args=saved_args)
myapplication.run()
@@ -399,7 +404,7 @@ def do_front_end(args: Namespace):
save_args(args)
try:
do_textual_inversion_training(InvokeAIAppConfig.get_config(), **args)
do_textual_inversion_training(config, **args)
copy_to_embeddings_folder(args)
except Exception as e:
logger.error("An exception occurred during training. The exception was:")
@@ -413,6 +418,7 @@ def main():
args = parse_args()
config = InvokeAIAppConfig.get_config()
config.parse_args([])
# change root if needed
if args.root_dir:


@@ -35,6 +35,7 @@ stats.html
!.yarn/releases
!.yarn/sdks
!.yarn/versions
.vite
# Yalc
.yalc


@@ -238,7 +238,7 @@ const modelsFilter = <
T extends
| MainModelConfigEntity
| LoRAModelConfigEntity
| OnnxModelConfigEntity,
| OnnxModelConfigEntity
>(
data: EntityState<T> | undefined,
model_type: ModelType,


@@ -243,7 +243,6 @@ export const modelsApi = api.injectEndpoints({
{ type: 'MainModel', id: LIST_TAG },
'Model',
];
if (result) {
tags.push(
...result.ids.map((id) => ({


@@ -49,6 +49,7 @@ dependencies = [
"fastapi==0.88.0",
"fastapi-events==0.8.0",
"huggingface-hub~=0.16.4",
"imohash~=1.0.0",
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
"matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model
@@ -106,6 +107,7 @@ dependencies = [
"pytest>6.0.0",
"pytest-cov",
"pytest-datadir",
"requests-testadapter",
]
"xformers" = [
"xformers~=0.0.19; sys_platform!='darwin'",
@@ -140,7 +142,6 @@ dependencies = [
"invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers"
"invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion"
"invokeai-model-install" = "invokeai.frontend.install.model_install:main"
"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
"invokeai-update" = "invokeai.frontend.install.invokeai_update:main"
"invokeai-metadata" = "invokeai.backend.image_util.invoke_metadata:main"
"invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli"


@@ -0,0 +1,39 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
convert_models_config_to_3.2.py.
This script converts a pre-3.2 models.yaml file into the 3.2 format.
The main difference is that each model is identified by a unique hash,
rather than the concatenation of base, type and name used previously.
In addition, there are more metadata fields attached to each model.
These will mostly be empty after conversion, but will be populated
when new models are downloaded from HuggingFace or Civitai.
"""
import argparse
from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.storage import migrate_models_store
def main():
parser = argparse.ArgumentParser(description="Convert a pre-3.2 models.yaml into the 3.2 version.")
parser.add_argument("--root", type=Path, help="Alternate root directory containing the models.yaml to convert")
parser.add_argument(
"--outfile",
type=Path,
default=Path("./models-3.2.yaml"),
help="File to write to. A file with suffix '.yaml' will use the YAML format. A file with an extension of '.db' will be treated as a SQLite3 database.",
)
args = parser.parse_args()
config_args = ["--root", args.root.as_posix()] if args.root else []
config = InvokeAIAppConfig.get_config()
config.parse_args(config_args)
migrate_models_store(config)
if __name__ == "__main__":
main()
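The converter above is a thin wrapper around migrate_models_store(); a typical invocation would be along the lines of `python convert_models_config_to_3.2.py --outfile models-3.2.yaml`. Per its docstring, the 3.2 store keys each model by a unique hash instead of the base/type/name concatenation, and the dependency list later in this diff adds imohash, which is presumably what computes that fingerprint. A hedged sketch of hashing a model file with imohash (the role of imohash and the path shown are assumptions; the actual key derivation lives in the model_manager backend):

from pathlib import Path

from imohash import hashfile  # added to pyproject.toml in this diff

def model_key(path: Path) -> str:
    # imohash samples the head, middle, and tail of the file plus its size,
    # so even multi-gigabyte checkpoints hash almost instantly.
    return hashfile(path, hexdigest=True)

print(model_key(Path("models/sd-1/main/example.safetensors")))  # illustrative path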


@@ -1,9 +1,19 @@
#!/bin/env python
"""Little command-line utility for probing a model on disk."""
import argparse
import json
import sys
from pathlib import Path
from invokeai.backend.model_management.model_probe import ModelProbe
from invokeai.backend.model_manager import InvalidModelException, ModelProbe, SchedulerPredictionType
def helper(model_path: Path):
print('Warning: guessing "v_prediction" SchedulerPredictionType', file=sys.stderr)
return SchedulerPredictionType.VPrediction
parser = argparse.ArgumentParser(description="Probe model type")
parser.add_argument(
@@ -14,5 +24,8 @@ parser.add_argument(
args = parser.parse_args()
for path in args.model_path:
info = ModelProbe().probe(path)
print(f"{path}: {info}")
try:
info = ModelProbe.probe(path, helper)
print(f"{path}:{json.dumps(info.dict(), sort_keys=True, indent=4)}")
except InvalidModelException as exc:
print(exc)


@@ -49,7 +49,10 @@ def mock_services() -> InvocationServices:
conn=db_conn, table_name="graph_executions", lock=lock
)
return InvocationServices(
model_manager=None, # type: ignore
download_queue=None, # type: ignore
model_loader=None, # type: ignore
model_installer=None, # type: ignore
model_record_store=None, # type: ignore
events=TestEventService(),
logger=logging, # type: ignore
images=None, # type: ignore


@@ -59,7 +59,10 @@ def mock_services() -> InvocationServices:
conn=db_conn, table_name="graph_executions", lock=lock
)
return InvocationServices(
model_manager=None, # type: ignore
download_queue=None, # type: ignore
model_loader=None, # type: ignore
model_installer=None, # type: ignore
model_record_store=None, # type: ignore
events=TestEventService(),
logger=logging, # type: ignore
images=None, # type: ignore


@@ -12,7 +12,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
populate_graph,
prepare_values_to_insert,
)
from tests.nodes.test_nodes import PromptTestInvocation
from .test_nodes import PromptTestInvocation
@pytest.fixture

Some files were not shown because too many files have changed in this diff.