mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add directory scanning for loras, controlnets and textual_inversions
This commit is contained in:
parent
6652f3405b
commit
887576d217
@ -15,10 +15,7 @@ InvokeAI:
|
|||||||
conf_path: configs/models.yaml
|
conf_path: configs/models.yaml
|
||||||
legacy_conf_dir: configs/stable-diffusion
|
legacy_conf_dir: configs/stable-diffusion
|
||||||
outdir: outputs
|
outdir: outputs
|
||||||
embedding_dir: embeddings
|
|
||||||
lora_dir: loras
|
|
||||||
autoconvert_dir: null
|
autoconvert_dir: null
|
||||||
gfpgan_model_dir: models/gfpgan/GFPGANv1.4.pth
|
|
||||||
Models:
|
Models:
|
||||||
model: stable-diffusion-1.5
|
model: stable-diffusion-1.5
|
||||||
embeddings: true
|
embeddings: true
|
||||||
@ -171,7 +168,7 @@ from argparse import ArgumentParser
|
|||||||
from omegaconf import OmegaConf, DictConfig
|
from omegaconf import OmegaConf, DictConfig
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydantic import BaseSettings, Field, parse_obj_as
|
from pydantic import BaseSettings, Field, parse_obj_as
|
||||||
from typing import ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args
|
from typing import ClassVar, Dict, List, Literal, Union, get_origin, get_type_hints, get_args
|
||||||
|
|
||||||
INIT_FILE = Path('invokeai.yaml')
|
INIT_FILE = Path('invokeai.yaml')
|
||||||
DB_FILE = Path('invokeai.db')
|
DB_FILE = Path('invokeai.db')
|
||||||
@ -379,18 +376,14 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
||||||
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
|
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
|
||||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||||
embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI textual inversion aembeddings directory', category='Paths')
|
models_dir : Path = Field(default='./models', description='Path to the models directory', category='Paths')
|
||||||
gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.', category='Paths')
|
|
||||||
controlnet_dir : Path = Field(default="controlnets", description='Path to directory of ControlNet models.', category='Paths')
|
|
||||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||||
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
|
|
||||||
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
||||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
||||||
|
|
||||||
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
||||||
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
|
|
||||||
|
|
||||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
||||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||||
@ -492,46 +485,11 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
return self._resolve(self.legacy_conf_dir)
|
return self._resolve(self.legacy_conf_dir)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cache_dir(self)->Path:
|
def models_path(self)->Path:
|
||||||
'''
|
|
||||||
Path to the global cache directory for HuggingFace hub-managed models
|
|
||||||
'''
|
|
||||||
return self.models_dir / "hub"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def models_dir(self)->Path:
|
|
||||||
'''
|
'''
|
||||||
Path to the models directory
|
Path to the models directory
|
||||||
'''
|
'''
|
||||||
return self._resolve("models")
|
return self._resolve(self.models_dir)
|
||||||
|
|
||||||
@property
|
|
||||||
def converted_ckpts_dir(self)->Path:
|
|
||||||
'''
|
|
||||||
Path to the converted models
|
|
||||||
'''
|
|
||||||
return self._resolve("models/converted_ckpts")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def embedding_path(self)->Path:
|
|
||||||
'''
|
|
||||||
Path to the textual inversion embeddings directory.
|
|
||||||
'''
|
|
||||||
return self._resolve(self.embedding_dir) if self.embedding_dir else None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def lora_path(self)->Path:
|
|
||||||
'''
|
|
||||||
Path to the LoRA models directory.
|
|
||||||
'''
|
|
||||||
return self._resolve(self.lora_dir) if self.lora_dir else None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def controlnet_path(self)->Path:
|
|
||||||
'''
|
|
||||||
Path to the controlnet models directory.
|
|
||||||
'''
|
|
||||||
return self._resolve(self.controlnet_dir) if self.controlnet_dir else None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def autoconvert_path(self)->Path:
|
def autoconvert_path(self)->Path:
|
||||||
@ -540,13 +498,6 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
'''
|
'''
|
||||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||||
|
|
||||||
@property
|
|
||||||
def gfpgan_model_path(self)->Path:
|
|
||||||
'''
|
|
||||||
Path to the GFPGAN model.
|
|
||||||
'''
|
|
||||||
return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None
|
|
||||||
|
|
||||||
# the following methods support legacy calls leftover from the Globals era
|
# the following methods support legacy calls leftover from the Globals era
|
||||||
@property
|
@property
|
||||||
def full_precision(self)->bool:
|
def full_precision(self)->bool:
|
||||||
|
@ -162,7 +162,12 @@ class SDModelType(str, Enum):
|
|||||||
Scheduler = "scheduler"
|
Scheduler = "scheduler"
|
||||||
Lora = "lora"
|
Lora = "lora"
|
||||||
TextualInversion = "textual_inversion"
|
TextualInversion = "textual_inversion"
|
||||||
|
ControlNet = "control_net"
|
||||||
|
|
||||||
|
class BaseModel(str, Enum):
|
||||||
|
StableDiffusion1_5 = "SD-1"
|
||||||
|
StableDiffusion2Base = "SD-2-base" # 512 pixels; this will have epsilon parameterization
|
||||||
|
StableDiffusion2 = "SD-2" # 768 pixels; this will have v-prediction parameterization
|
||||||
|
|
||||||
class ModelInfoBase:
|
class ModelInfoBase:
|
||||||
#model_path: str
|
#model_path: str
|
||||||
|
@ -54,15 +54,17 @@ MODELS.YAML
|
|||||||
The general format of a models.yaml section is:
|
The general format of a models.yaml section is:
|
||||||
|
|
||||||
type-of-model/name-of-model:
|
type-of-model/name-of-model:
|
||||||
format: folder|ckpt|safetensors
|
|
||||||
repo_id: owner/repo
|
|
||||||
path: /path/to/local/file/or/directory
|
path: /path/to/local/file/or/directory
|
||||||
|
description: a description
|
||||||
|
format: folder|ckpt|safetensors|pt
|
||||||
|
base: SD-1|SD-2
|
||||||
subfolder: subfolder-name
|
subfolder: subfolder-name
|
||||||
|
|
||||||
The type of model is given in the stanza key, and is one of
|
The type of model is given in the stanza key, and is one of
|
||||||
{diffusers, ckpt, vae, text_encoder, tokenizer, unet, scheduler,
|
{diffusers, ckpt, vae, text_encoder, tokenizer, unet, scheduler,
|
||||||
safety_checker, feature_extractor, lora, textual_inversion}, and
|
safety_checker, feature_extractor, lora, textual_inversion,
|
||||||
correspond to items in the SDModelType enum defined in model_cache.py
|
controlnet}, and correspond to items in the SDModelType enum defined
|
||||||
|
in model_cache.py
|
||||||
|
|
||||||
The format indicates whether the model is organized as a folder with
|
The format indicates whether the model is organized as a folder with
|
||||||
model subdirectories, or is contained in a single checkpoint or
|
model subdirectories, or is contained in a single checkpoint or
|
||||||
@ -80,12 +82,12 @@ This example summarizes the two ways of getting a non-diffuser model:
|
|||||||
|
|
||||||
text_encoder/clip-test-1:
|
text_encoder/clip-test-1:
|
||||||
format: folder
|
format: folder
|
||||||
repo_id: openai/clip-vit-large-patch14
|
path: /path/to/folder
|
||||||
description: Returns standalone CLIPTextModel
|
description: Returns standalone CLIPTextModel
|
||||||
|
|
||||||
text_encoder/clip-test-2:
|
text_encoder/clip-test-2:
|
||||||
format: folder
|
format: folder
|
||||||
repo_id: stabilityai/stable-diffusion-2
|
repo_id: /path/to/folder
|
||||||
subfolder: text_encoder
|
subfolder: text_encoder
|
||||||
description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
|
description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
|
||||||
|
|
||||||
@ -99,6 +101,14 @@ model. Use the `submodel` parameter to select which part:
|
|||||||
print(type(my_vae))
|
print(type(my_vae))
|
||||||
# "AutoencoderKL"
|
# "AutoencoderKL"
|
||||||
|
|
||||||
|
DIRECTORY_SCANNING:
|
||||||
|
|
||||||
|
Loras, textual_inversion and controlnet models are usually not listed
|
||||||
|
explicitly in models.yaml, but are added to the in-memory data
|
||||||
|
structure at initialization time by scanning the models directory. The
|
||||||
|
in-memory data structure can be resynchronized by calling
|
||||||
|
`manager.scan_models_directory`.
|
||||||
|
|
||||||
DISAMBIGUATION:
|
DISAMBIGUATION:
|
||||||
|
|
||||||
You may wish to use the same name for a related family of models. To
|
You may wish to use the same name for a related family of models. To
|
||||||
@ -107,12 +117,12 @@ separated by "/". Example:
|
|||||||
|
|
||||||
tokenizer/clip-large:
|
tokenizer/clip-large:
|
||||||
format: tokenizer
|
format: tokenizer
|
||||||
repo_id: openai/clip-vit-large-patch14
|
path: /path/to/folder
|
||||||
description: Returns standalone tokenizer
|
description: Returns standalone tokenizer
|
||||||
|
|
||||||
text_encoder/clip-large:
|
text_encoder/clip-large:
|
||||||
format: text_encoder
|
format: text_encoder
|
||||||
repo_id: openai/clip-vit-large-patch14
|
path: /path/to/folder
|
||||||
description: Returns standalone text encoder
|
description: Returns standalone text encoder
|
||||||
|
|
||||||
You can now use the `model_type` argument to indicate which model you
|
You can now use the `model_type` argument to indicate which model you
|
||||||
@ -126,6 +136,14 @@ OTHER FUNCTIONS:
|
|||||||
Other methods provided by ModelManager support importing, editing,
|
Other methods provided by ModelManager support importing, editing,
|
||||||
converting and deleting models.
|
converting and deleting models.
|
||||||
|
|
||||||
|
IMPORTANT CHANGES AND LIMITATIONS SINCE 2.3:
|
||||||
|
|
||||||
|
1. Only local paths are supported. Repo_ids are no longer accepted. This
|
||||||
|
simplifies the logic.
|
||||||
|
|
||||||
|
2. VAEs can't be swapped in and out at load time. They must be baked
|
||||||
|
into the model when downloaded or converted.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@ -133,17 +151,13 @@ import os
|
|||||||
import re
|
import re
|
||||||
import textwrap
|
import textwrap
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Optional, List, Tuple, Union, types
|
from typing import Callable, Dict, Optional, List, Tuple, Union, types
|
||||||
from shutil import move, rmtree
|
from shutil import rmtree
|
||||||
from typing import Any, Optional, Union, Callable, Dict, List, types
|
|
||||||
|
|
||||||
import safetensors
|
import safetensors
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
@ -156,7 +170,7 @@ from omegaconf.dictconfig import DictConfig
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util import CUDA_DEVICE, ask_user, download_with_resume
|
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
|
||||||
from ..install.model_install_backend import Dataset_path, hf_download_with_resume
|
from ..install.model_install_backend import Dataset_path, hf_download_with_resume
|
||||||
from .model_cache import (ModelCache, ModelLocker, SDModelType,
|
from .model_cache import (ModelCache, ModelLocker, SDModelType,
|
||||||
SilenceWarnings)
|
SilenceWarnings)
|
||||||
@ -197,6 +211,25 @@ class SDLegacyType(Enum):
|
|||||||
|
|
||||||
MAX_CACHE_SIZE = 6.0 # GB
|
MAX_CACHE_SIZE = 6.0 # GB
|
||||||
|
|
||||||
|
|
||||||
|
# layout of the models directory:
|
||||||
|
# models
|
||||||
|
# ├── SD-1
|
||||||
|
# │ ├── controlnet
|
||||||
|
# │ ├── lora
|
||||||
|
# │ ├── diffusers
|
||||||
|
# │ └── textual_inversion
|
||||||
|
# ├── SD-2
|
||||||
|
# │ ├── controlnet
|
||||||
|
# │ ├── lora
|
||||||
|
# │ ├── diffusers
|
||||||
|
# │ └── textual_inversion
|
||||||
|
# └── support
|
||||||
|
# ├── codeformer
|
||||||
|
# ├── gfpgan
|
||||||
|
# └── realesrgan
|
||||||
|
|
||||||
|
|
||||||
class ModelManager(object):
|
class ModelManager(object):
|
||||||
"""
|
"""
|
||||||
High-level interface to model management.
|
High-level interface to model management.
|
||||||
@ -241,6 +274,9 @@ class ModelManager(object):
|
|||||||
)
|
)
|
||||||
self.cache_keys = dict()
|
self.cache_keys = dict()
|
||||||
|
|
||||||
|
# add controlnet, lora and textual_inversion models from disk
|
||||||
|
self.scan_models_directory(include_diffusers=False)
|
||||||
|
|
||||||
def model_exists(
|
def model_exists(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -1258,6 +1294,67 @@ class ModelManager(object):
|
|||||||
if self.config_path:
|
if self.config_path:
|
||||||
self.commit()
|
self.commit()
|
||||||
|
|
||||||
|
def _delete_defunct_models(self):
|
||||||
|
'''
|
||||||
|
Remove models no longer on disk.
|
||||||
|
'''
|
||||||
|
config = self.config
|
||||||
|
|
||||||
|
to_delete = set()
|
||||||
|
for key in config:
|
||||||
|
if 'path' not in config[key]:
|
||||||
|
continue
|
||||||
|
path = self.globals.root_dir / config[key].path
|
||||||
|
if path.exists():
|
||||||
|
continue
|
||||||
|
to_delete.add(key)
|
||||||
|
|
||||||
|
for key in to_delete:
|
||||||
|
self.logger.warn(f'Removing model {key} from in-memory config because its path is no longer on disk')
|
||||||
|
config.pop(key)
|
||||||
|
|
||||||
|
def scan_models_directory(self, include_diffusers:bool=False):
|
||||||
|
'''
|
||||||
|
Scan the models directory for loras, textual_inversions and controlnets
|
||||||
|
and create appropriate entries in the in-memory omegaconf. Diffusers
|
||||||
|
will not be added unless include_diffusers is true.
|
||||||
|
'''
|
||||||
|
self._delete_defunct_models()
|
||||||
|
|
||||||
|
model_directory = self.globals.models_path
|
||||||
|
config = self.config
|
||||||
|
|
||||||
|
for root, dirs, files in os.walk(model_directory):
|
||||||
|
parents = root.split('/')
|
||||||
|
subpaths = parents[parents.index('models')+1:]
|
||||||
|
if len(subpaths) < 2:
|
||||||
|
continue
|
||||||
|
base, model_type, *_ = subpaths
|
||||||
|
|
||||||
|
if model_type == "diffusers" and not include_diffusers:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for d in dirs:
|
||||||
|
config[f'{model_type}/{d}'] = dict(
|
||||||
|
path = os.path.join(root,d),
|
||||||
|
description = f'{model_type} model {d}',
|
||||||
|
format = 'folder',
|
||||||
|
base = base,
|
||||||
|
)
|
||||||
|
|
||||||
|
for f in files:
|
||||||
|
basename = Path(f).stem
|
||||||
|
format = Path(f).suffix[1:]
|
||||||
|
config[f'{model_type}/{basename}'] = dict(
|
||||||
|
path = os.path.join(root,f),
|
||||||
|
description = f'{model_type} model {basename}',
|
||||||
|
format = format,
|
||||||
|
base = base,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
##### NONE OF THE METHODS BELOW WORK NOW BECAUSE OF MODEL DIRECTORY REORGANIZATION
|
||||||
|
##### AND NEED TO BE REWRITTEN
|
||||||
def list_lora_models(self)->Dict[str,bool]:
|
def list_lora_models(self)->Dict[str,bool]:
|
||||||
'''Return a dict of installed lora models; key is either the shortname
|
'''Return a dict of installed lora models; key is either the shortname
|
||||||
defined in INITIAL_MODELS, or the basename of the file in the LoRA
|
defined in INITIAL_MODELS, or the basename of the file in the LoRA
|
||||||
|
50
scripts/scan_models_directory.py
Normal file
50
scripts/scan_models_directory.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
'''
|
||||||
|
Scan the models directory and print out a new models.yaml
|
||||||
|
'''
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Model directory scanner")
|
||||||
|
parser.add_argument('models_directory')
|
||||||
|
args = parser.parse_args()
|
||||||
|
directory = args.models_directory
|
||||||
|
|
||||||
|
conf = OmegaConf.create()
|
||||||
|
conf['_version'] = '3.0.0'
|
||||||
|
|
||||||
|
for root, dirs, files in os.walk(directory):
|
||||||
|
for d in dirs:
|
||||||
|
parents = root.split('/')
|
||||||
|
subpaths = parents[parents.index('models')+1:]
|
||||||
|
if len(subpaths) < 2:
|
||||||
|
continue
|
||||||
|
base, model_type, *_ = subpaths
|
||||||
|
|
||||||
|
conf[f'{model_type}/{d}'] = dict(
|
||||||
|
path = os.path.join(root,d),
|
||||||
|
description = f'{model_type} model {d}',
|
||||||
|
format = 'folder',
|
||||||
|
base = base,
|
||||||
|
)
|
||||||
|
|
||||||
|
for f in files:
|
||||||
|
basename = Path(f).stem
|
||||||
|
format = Path(f).suffix[1:]
|
||||||
|
conf[f'{model_type}/{basename}'] = dict(
|
||||||
|
path = os.path.join(root,f),
|
||||||
|
description = f'{model_type} model {basename}',
|
||||||
|
format = format,
|
||||||
|
base = base,
|
||||||
|
)
|
||||||
|
|
||||||
|
OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
x
Reference in New Issue
Block a user