mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Configuration and model installer for new model layout (#3547)
# Restore invokeai-configure and invokeai-model-install This PR updates invokeai-configure and invokeai-model-install to work with the new model manager file layout. It addresses a naming issue for `ModelType.Main` (was `ModelType.Pipeline`) requested by @blessedcoolant, and adds back the feature that allows users to dump models into an `autoimport` directory for discovery at startup time.
This commit is contained in:
commit
2d85f9a123
@ -1,13 +1,13 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||||
|
|
||||||
from typing import Annotated, Literal, Optional, Union, Dict
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
from fastapi import Query
|
from fastapi import Query
|
||||||
from fastapi.routing import APIRouter, HTTPException
|
from fastapi.routing import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field, parse_obj_as
|
from pydantic import BaseModel, Field, parse_obj_as
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
|
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
||||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
@ -51,11 +51,14 @@ class CreateModelResponse(BaseModel):
|
|||||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||||
status: str = Field(description="The status of the API response")
|
status: str = Field(description="The status of the API response")
|
||||||
|
|
||||||
|
class ImportModelRequest(BaseModel):
|
||||||
|
name: str = Field(description="A model path, repo_id or URL to import")
|
||||||
|
prediction_type: Optional[Literal['epsilon','v_prediction','sample']] = Field(description='Prediction type for SDv2 checkpoint files')
|
||||||
|
|
||||||
class ConversionRequest(BaseModel):
|
class ConversionRequest(BaseModel):
|
||||||
name: str = Field(description="The name of the new model")
|
name: str = Field(description="The name of the new model")
|
||||||
info: CkptModelInfo = Field(description="The converted model info")
|
info: CkptModelInfo = Field(description="The converted model info")
|
||||||
save_location: str = Field(description="The path to save the converted model weights")
|
save_location: str = Field(description="The path to save the converted model weights")
|
||||||
|
|
||||||
|
|
||||||
class ConvertedModelResponse(BaseModel):
|
class ConvertedModelResponse(BaseModel):
|
||||||
name: str = Field(description="The name of the new model")
|
name: str = Field(description="The name of the new model")
|
||||||
@ -105,6 +108,28 @@ async def update_model(
|
|||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
@models_router.post(
|
||||||
|
"/",
|
||||||
|
operation_id="import_model",
|
||||||
|
responses={200: {"status": "success"}},
|
||||||
|
)
|
||||||
|
async def import_model(
|
||||||
|
model_request: ImportModelRequest
|
||||||
|
) -> None:
|
||||||
|
""" Add Model """
|
||||||
|
items_to_import = set([model_request.name])
|
||||||
|
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||||
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
|
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||||
|
items_to_import = items_to_import,
|
||||||
|
prediction_type_helper = lambda x: prediction_types.get(model_request.prediction_type)
|
||||||
|
)
|
||||||
|
if len(installed_models) > 0:
|
||||||
|
logger.info(f'Successfully imported {model_request.name}')
|
||||||
|
else:
|
||||||
|
logger.error(f'Model {model_request.name} not imported')
|
||||||
|
raise HTTPException(status_code=500, detail=f'Model {model_request.name} not imported')
|
||||||
|
|
||||||
@models_router.delete(
|
@models_router.delete(
|
||||||
"/{model_name}",
|
"/{model_name}",
|
||||||
|
@ -73,7 +73,7 @@ class PipelineModelLoaderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
base_model = self.model.base_model
|
base_model = self.model.base_model
|
||||||
model_name = self.model.model_name
|
model_name = self.model.model_name
|
||||||
model_type = ModelType.Pipeline
|
model_type = ModelType.Main
|
||||||
|
|
||||||
# TODO: not found exceptions
|
# TODO: not found exceptions
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.services.model_manager.model_exists(
|
||||||
|
@ -15,7 +15,7 @@ InvokeAI:
|
|||||||
conf_path: configs/models.yaml
|
conf_path: configs/models.yaml
|
||||||
legacy_conf_dir: configs/stable-diffusion
|
legacy_conf_dir: configs/stable-diffusion
|
||||||
outdir: outputs
|
outdir: outputs
|
||||||
autoconvert_dir: null
|
autoimport_dir: null
|
||||||
Models:
|
Models:
|
||||||
model: stable-diffusion-1.5
|
model: stable-diffusion-1.5
|
||||||
embeddings: true
|
embeddings: true
|
||||||
@ -367,16 +367,19 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
|
|
||||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
|
||||||
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
|
||||||
max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
|
max_loaded_models : int = Field(default=3, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance')
|
||||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
|
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance')
|
||||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||||
|
|
||||||
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
||||||
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
|
autoimport_dir : Path = Field(default='autoimport/main', description='Path to a directory of models files to be imported on startup.', category='Paths')
|
||||||
|
lora_dir : Path = Field(default='autoimport/lora', description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
|
||||||
|
embedding_dir : Path = Field(default='autoimport/embedding', description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
|
||||||
|
controlnet_dir : Path = Field(default='autoimport/controlnet', description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths')
|
||||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||||
models_dir : Path = Field(default='./models', description='Path to the models directory', category='Paths')
|
models_dir : Path = Field(default='models', description='Path to the models directory', category='Paths')
|
||||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||||
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
db_dir : Path = Field(default='databases', description='Path to InvokeAI databases directory', category='Paths')
|
||||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||||
|
@ -7,8 +7,6 @@
|
|||||||
# Coauthor: Kevin Turner http://github.com/keturn
|
# Coauthor: Kevin Turner http://github.com/keturn
|
||||||
#
|
#
|
||||||
import sys
|
import sys
|
||||||
print("Loading Python libraries...\n",file=sys.stderr)
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
@ -16,6 +14,7 @@ import shutil
|
|||||||
import textwrap
|
import textwrap
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
|
import yaml
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import get_terminal_size
|
from shutil import get_terminal_size
|
||||||
@ -25,6 +24,7 @@ from urllib import request
|
|||||||
import npyscreen
|
import npyscreen
|
||||||
import transformers
|
import transformers
|
||||||
from diffusers import AutoencoderKL
|
from diffusers import AutoencoderKL
|
||||||
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from huggingface_hub import HfFolder
|
from huggingface_hub import HfFolder
|
||||||
from huggingface_hub import login as hf_hub_login
|
from huggingface_hub import login as hf_hub_login
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@ -34,6 +34,8 @@ from transformers import (
|
|||||||
CLIPSegForImageSegmentation,
|
CLIPSegForImageSegmentation,
|
||||||
CLIPTextModel,
|
CLIPTextModel,
|
||||||
CLIPTokenizer,
|
CLIPTokenizer,
|
||||||
|
AutoFeatureExtractor,
|
||||||
|
BertTokenizerFast,
|
||||||
)
|
)
|
||||||
import invokeai.configs as configs
|
import invokeai.configs as configs
|
||||||
|
|
||||||
@ -52,12 +54,13 @@ from invokeai.frontend.install.widgets import (
|
|||||||
)
|
)
|
||||||
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
||||||
from invokeai.backend.install.model_install_backend import (
|
from invokeai.backend.install.model_install_backend import (
|
||||||
default_dataset,
|
hf_download_from_pretrained,
|
||||||
download_from_hf,
|
InstallSelections,
|
||||||
hf_download_with_resume,
|
ModelInstall,
|
||||||
recommended_datasets,
|
|
||||||
UserSelections,
|
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.model_management.model_probe import (
|
||||||
|
ModelType, BaseModelType
|
||||||
|
)
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
@ -81,7 +84,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
|||||||
# or renaming it and then running invokeai-configure again.
|
# or renaming it and then running invokeai-configure again.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger=None
|
logger=InvokeAILogger.getLogger()
|
||||||
|
|
||||||
# --------------------------------------------
|
# --------------------------------------------
|
||||||
def postscript(errors: None):
|
def postscript(errors: None):
|
||||||
@ -162,75 +165,91 @@ class ProgressBar:
|
|||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
|
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
|
||||||
try:
|
try:
|
||||||
print(f"Installing {label} model file {model_url}...", end="", file=sys.stderr)
|
logger.info(f"Installing {label} model file {model_url}...")
|
||||||
if not os.path.exists(model_dest):
|
if not os.path.exists(model_dest):
|
||||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||||
request.urlretrieve(
|
request.urlretrieve(
|
||||||
model_url, model_dest, ProgressBar(os.path.basename(model_dest))
|
model_url, model_dest, ProgressBar(os.path.basename(model_dest))
|
||||||
)
|
)
|
||||||
print("...downloaded successfully", file=sys.stderr)
|
logger.info("...downloaded successfully")
|
||||||
else:
|
else:
|
||||||
print("...exists", file=sys.stderr)
|
logger.info("...exists")
|
||||||
except Exception:
|
except Exception:
|
||||||
print("...download failed", file=sys.stderr)
|
logger.info("...download failed")
|
||||||
print(f"Error downloading {label} model", file=sys.stderr)
|
logger.info(f"Error downloading {label} model")
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
def download_conversion_models():
|
||||||
# this will preload the Bert tokenizer fles
|
target_dir = config.root_path / 'models/core/convert'
|
||||||
def download_bert():
|
kwargs = dict() # for future use
|
||||||
print("Installing bert tokenizer...", file=sys.stderr)
|
try:
|
||||||
with warnings.catch_warnings():
|
logger.info('Downloading core tokenizers and text encoders')
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
||||||
from transformers import BertTokenizerFast
|
|
||||||
|
|
||||||
download_from_hf(BertTokenizerFast, "bert-base-uncased")
|
# bert
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||||
|
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
|
||||||
|
bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
|
||||||
|
|
||||||
|
# sd-1
|
||||||
|
repo_id = 'openai/clip-vit-large-patch14'
|
||||||
|
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14')
|
||||||
|
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14')
|
||||||
|
|
||||||
|
# sd-2
|
||||||
|
repo_id = "stabilityai/stable-diffusion-2"
|
||||||
|
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
|
||||||
|
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
|
||||||
|
|
||||||
# ---------------------------------------------
|
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
|
||||||
def download_sd1_clip():
|
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
|
||||||
print("Installing SD1 clip model...", file=sys.stderr)
|
|
||||||
version = "openai/clip-vit-large-patch14"
|
|
||||||
download_from_hf(CLIPTokenizer, version)
|
|
||||||
download_from_hf(CLIPTextModel, version)
|
|
||||||
|
|
||||||
|
# VAE
|
||||||
|
logger.info('Downloading stable diffusion VAE')
|
||||||
|
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
|
||||||
|
vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
|
||||||
|
|
||||||
# ---------------------------------------------
|
# safety checking
|
||||||
def download_sd2_clip():
|
logger.info('Downloading safety checker')
|
||||||
version = "stabilityai/stable-diffusion-2"
|
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
print("Installing SD2 clip model...", file=sys.stderr)
|
pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
|
||||||
download_from_hf(CLIPTokenizer, version, subfolder="tokenizer")
|
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
|
||||||
download_from_hf(CLIPTextModel, version, subfolder="text_encoder")
|
|
||||||
|
|
||||||
|
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
|
||||||
|
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_realesrgan():
|
def download_realesrgan():
|
||||||
print("Installing models from RealESRGAN...", file=sys.stderr)
|
logger.info("Installing models from RealESRGAN...")
|
||||||
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
|
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
|
||||||
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
|
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
|
||||||
|
|
||||||
model_dest = config.root_path / "models/realesrgan/realesr-general-x4v3.pth"
|
model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-x4v3.pth"
|
||||||
wdn_model_dest = config.root_path / "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
wdn_model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||||
|
|
||||||
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
|
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
|
||||||
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
|
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
|
||||||
|
|
||||||
|
|
||||||
def download_gfpgan():
|
def download_gfpgan():
|
||||||
print("Installing GFPGAN models...", file=sys.stderr)
|
logger.info("Installing GFPGAN models...")
|
||||||
for model in (
|
for model in (
|
||||||
[
|
[
|
||||||
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
|
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
|
||||||
"./models/gfpgan/GFPGANv1.4.pth",
|
"./models/core/face_restoration/gfpgan/GFPGANv1.4.pth",
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
|
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
|
||||||
"./models/gfpgan/weights/detection_Resnet50_Final.pth",
|
"./models/core/face_restoration/gfpgan/weights/detection_Resnet50_Final.pth",
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
|
"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
|
||||||
"./models/gfpgan/weights/parsing_parsenet.pth",
|
"./models/core/face_restoration/gfpgan/weights/parsing_parsenet.pth",
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
model_url, model_dest = model[0], config.root_path / model[1]
|
model_url, model_dest = model[0], config.root_path / model[1]
|
||||||
@ -239,70 +258,32 @@ def download_gfpgan():
|
|||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_codeformer():
|
def download_codeformer():
|
||||||
print("Installing CodeFormer model file...", file=sys.stderr)
|
logger.info("Installing CodeFormer model file...")
|
||||||
model_url = (
|
model_url = (
|
||||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
||||||
)
|
)
|
||||||
model_dest = config.root_path / "models/codeformer/codeformer.pth"
|
model_dest = config.root_path / "models/core/face_restoration/codeformer/codeformer.pth"
|
||||||
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
|
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_clipseg():
|
def download_clipseg():
|
||||||
print("Installing clipseg model for text-based masking...", file=sys.stderr)
|
logger.info("Installing clipseg model for text-based masking...")
|
||||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||||
try:
|
try:
|
||||||
download_from_hf(AutoProcessor, CLIPSEG_MODEL)
|
hf_download_from_pretrained(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
|
||||||
download_from_hf(CLIPSegForImageSegmentation, CLIPSEG_MODEL)
|
hf_download_from_pretrained(CLIPSegForImageSegmentation, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error installing clipseg model:")
|
logger.info("Error installing clipseg model:")
|
||||||
print(traceback.format_exc())
|
logger.info(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
def download_support_models():
|
||||||
def download_safety_checker():
|
download_realesrgan()
|
||||||
print("Installing model for NSFW content detection...", file=sys.stderr)
|
download_gfpgan()
|
||||||
try:
|
download_codeformer()
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
download_clipseg()
|
||||||
StableDiffusionSafetyChecker,
|
download_conversion_models()
|
||||||
)
|
|
||||||
from transformers import AutoFeatureExtractor
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
print("Error installing NSFW checker model:")
|
|
||||||
print(traceback.format_exc())
|
|
||||||
return
|
|
||||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
|
||||||
print("AutoFeatureExtractor...", file=sys.stderr)
|
|
||||||
download_from_hf(AutoFeatureExtractor, safety_model_id)
|
|
||||||
print("StableDiffusionSafetyChecker...", file=sys.stderr)
|
|
||||||
download_from_hf(StableDiffusionSafetyChecker, safety_model_id)
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
|
||||||
def download_vaes():
|
|
||||||
print("Installing stabilityai VAE...", file=sys.stderr)
|
|
||||||
try:
|
|
||||||
# first the diffusers version
|
|
||||||
repo_id = "stabilityai/sd-vae-ft-mse"
|
|
||||||
args = dict(
|
|
||||||
cache_dir=config.cache_dir,
|
|
||||||
)
|
|
||||||
if not AutoencoderKL.from_pretrained(repo_id, **args):
|
|
||||||
raise Exception(f"download of {repo_id} failed")
|
|
||||||
|
|
||||||
repo_id = "stabilityai/sd-vae-ft-mse-original"
|
|
||||||
model_name = "vae-ft-mse-840000-ema-pruned.ckpt"
|
|
||||||
# next the legacy checkpoint version
|
|
||||||
if not hf_download_with_resume(
|
|
||||||
repo_id=repo_id,
|
|
||||||
model_name=model_name,
|
|
||||||
model_dir=str(config.root_path / Model_dir / Weights_dir),
|
|
||||||
):
|
|
||||||
raise Exception(f"download of {model_name} failed")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error downloading StabilityAI standard VAE: {str(e)}", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def get_root(root: str = None) -> str:
|
def get_root(root: str = None) -> str:
|
||||||
@ -465,39 +446,19 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
|
|||||||
editable=False,
|
editable=False,
|
||||||
color="CONTROL",
|
color="CONTROL",
|
||||||
)
|
)
|
||||||
self.embedding_dir = self.add_widget_intelligent(
|
self.autoimport_dirs = {}
|
||||||
npyscreen.TitleFilename,
|
for description, config_name, path in autoimport_paths(old_opts):
|
||||||
name=" Textual Inversion Embeddings:",
|
self.autoimport_dirs[config_name] = self.add_widget_intelligent(
|
||||||
value=str(default_embedding_dir()),
|
npyscreen.TitleFilename,
|
||||||
select_dir=True,
|
name=description+':',
|
||||||
must_exist=False,
|
value=str(path),
|
||||||
use_two_lines=False,
|
select_dir=True,
|
||||||
labelColor="GOOD",
|
must_exist=False,
|
||||||
begin_entry_at=32,
|
use_two_lines=False,
|
||||||
scroll_exit=True,
|
labelColor="GOOD",
|
||||||
)
|
begin_entry_at=32,
|
||||||
self.lora_dir = self.add_widget_intelligent(
|
scroll_exit=True
|
||||||
npyscreen.TitleFilename,
|
)
|
||||||
name=" LoRA and LyCORIS:",
|
|
||||||
value=str(default_lora_dir()),
|
|
||||||
select_dir=True,
|
|
||||||
must_exist=False,
|
|
||||||
use_two_lines=False,
|
|
||||||
labelColor="GOOD",
|
|
||||||
begin_entry_at=32,
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
self.controlnet_dir = self.add_widget_intelligent(
|
|
||||||
npyscreen.TitleFilename,
|
|
||||||
name=" ControlNets:",
|
|
||||||
value=str(default_controlnet_dir()),
|
|
||||||
select_dir=True,
|
|
||||||
must_exist=False,
|
|
||||||
use_two_lines=False,
|
|
||||||
labelColor="GOOD",
|
|
||||||
begin_entry_at=32,
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.TitleFixedText,
|
npyscreen.TitleFixedText,
|
||||||
@ -562,10 +523,6 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
|||||||
bad_fields.append(
|
bad_fields.append(
|
||||||
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
|
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
|
||||||
)
|
)
|
||||||
if not Path(opt.embedding_dir).parent.exists():
|
|
||||||
bad_fields.append(
|
|
||||||
f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_dir).parent)} is an existing directory."
|
|
||||||
)
|
|
||||||
if len(bad_fields) > 0:
|
if len(bad_fields) > 0:
|
||||||
message = "The following problems were detected and must be corrected:\n"
|
message = "The following problems were detected and must be corrected:\n"
|
||||||
for problem in bad_fields:
|
for problem in bad_fields:
|
||||||
@ -585,12 +542,15 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
|||||||
"max_loaded_models",
|
"max_loaded_models",
|
||||||
"xformers_enabled",
|
"xformers_enabled",
|
||||||
"always_use_cpu",
|
"always_use_cpu",
|
||||||
"embedding_dir",
|
|
||||||
"lora_dir",
|
|
||||||
"controlnet_dir",
|
|
||||||
]:
|
]:
|
||||||
setattr(new_opts, attr, getattr(self, attr).value)
|
setattr(new_opts, attr, getattr(self, attr).value)
|
||||||
|
|
||||||
|
for attr in self.autoimport_dirs:
|
||||||
|
directory = Path(self.autoimport_dirs[attr].value)
|
||||||
|
if directory.is_relative_to(config.root_path):
|
||||||
|
directory = directory.relative_to(config.root_path)
|
||||||
|
setattr(new_opts, attr, directory)
|
||||||
|
|
||||||
new_opts.hf_token = self.hf_token.value
|
new_opts.hf_token = self.hf_token.value
|
||||||
new_opts.license_acceptance = self.license_acceptance.value
|
new_opts.license_acceptance = self.license_acceptance.value
|
||||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||||
@ -607,7 +567,8 @@ class EditOptApplication(npyscreen.NPSAppManaged):
|
|||||||
self.program_opts = program_opts
|
self.program_opts = program_opts
|
||||||
self.invokeai_opts = invokeai_opts
|
self.invokeai_opts = invokeai_opts
|
||||||
self.user_cancelled = False
|
self.user_cancelled = False
|
||||||
self.user_selections = default_user_selections(program_opts)
|
self.autoload_pending = True
|
||||||
|
self.install_selections = default_user_selections(program_opts)
|
||||||
|
|
||||||
def onStart(self):
|
def onStart(self):
|
||||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||||
@ -642,41 +603,62 @@ def default_startup_options(init_file: Path) -> Namespace:
|
|||||||
opts.nsfw_checker = True
|
opts.nsfw_checker = True
|
||||||
return opts
|
return opts
|
||||||
|
|
||||||
def default_user_selections(program_opts: Namespace) -> UserSelections:
|
def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
||||||
return UserSelections(
|
installer = ModelInstall(config)
|
||||||
install_models=default_dataset()
|
models = installer.all_models()
|
||||||
|
return InstallSelections(
|
||||||
|
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
|
||||||
if program_opts.default_only
|
if program_opts.default_only
|
||||||
else recommended_datasets()
|
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
|
||||||
if program_opts.yes_to_all
|
if program_opts.yes_to_all
|
||||||
else dict(),
|
else list(),
|
||||||
purge_deleted_models=False,
|
# scan_directory=None,
|
||||||
scan_directory=None,
|
# autoscan_on_startup=None,
|
||||||
autoscan_on_startup=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# -------------------------------------
|
||||||
|
def autoimport_paths(config: InvokeAIAppConfig):
|
||||||
|
return [
|
||||||
|
('Checkpoints & diffusers models', 'autoimport_dir', config.root_path / config.autoimport_dir),
|
||||||
|
('LoRA/LyCORIS models', 'lora_dir', config.root_path / config.lora_dir),
|
||||||
|
('Controlnet models', 'controlnet_dir', config.root_path / config.controlnet_dir),
|
||||||
|
('Textual Inversion Embeddings', 'embedding_dir', config.root_path / config.embedding_dir),
|
||||||
|
]
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||||
print("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
|
logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
|
||||||
|
|
||||||
for name in (
|
for name in (
|
||||||
"models",
|
"models",
|
||||||
"configs",
|
|
||||||
"embeddings",
|
|
||||||
"databases",
|
"databases",
|
||||||
"loras",
|
|
||||||
"controlnets",
|
|
||||||
"text-inversion-output",
|
"text-inversion-output",
|
||||||
"text-inversion-training-data",
|
"text-inversion-training-data",
|
||||||
|
"configs"
|
||||||
):
|
):
|
||||||
os.makedirs(os.path.join(root, name), exist_ok=True)
|
os.makedirs(os.path.join(root, name), exist_ok=True)
|
||||||
|
for model_type in ModelType:
|
||||||
|
Path(root, 'autoimport', model_type.value).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
configs_src = Path(configs.__path__[0])
|
configs_src = Path(configs.__path__[0])
|
||||||
configs_dest = root / "configs"
|
configs_dest = root / "configs"
|
||||||
if not os.path.samefile(configs_src, configs_dest):
|
if not os.path.samefile(configs_src, configs_dest):
|
||||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||||
|
|
||||||
|
dest = root / 'models'
|
||||||
|
for model_base in BaseModelType:
|
||||||
|
for model_type in ModelType:
|
||||||
|
path = dest / model_base.value / model_type.value
|
||||||
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
path = dest / 'core'
|
||||||
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
with open(root / 'configs' / 'models.yaml','w') as yaml_file:
|
||||||
|
yaml_file.write(yaml.dump({'__metadata__':
|
||||||
|
{'version':'3.0.0'}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def run_console_ui(
|
def run_console_ui(
|
||||||
program_opts: Namespace, initfile: Path = None
|
program_opts: Namespace, initfile: Path = None
|
||||||
@ -699,7 +681,7 @@ def run_console_ui(
|
|||||||
if editApp.user_cancelled:
|
if editApp.user_cancelled:
|
||||||
return (None, None)
|
return (None, None)
|
||||||
else:
|
else:
|
||||||
return (editApp.new_opts, editApp.user_selections)
|
return (editApp.new_opts, editApp.install_selections)
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
@ -722,18 +704,6 @@ def write_opts(opts: Namespace, init_file: Path):
|
|||||||
def default_output_dir() -> Path:
|
def default_output_dir() -> Path:
|
||||||
return config.root_path / "outputs"
|
return config.root_path / "outputs"
|
||||||
|
|
||||||
# -------------------------------------
|
|
||||||
def default_embedding_dir() -> Path:
|
|
||||||
return config.root_path / "embeddings"
|
|
||||||
|
|
||||||
# -------------------------------------
|
|
||||||
def default_lora_dir() -> Path:
|
|
||||||
return config.root_path / "loras"
|
|
||||||
|
|
||||||
# -------------------------------------
|
|
||||||
def default_controlnet_dir() -> Path:
|
|
||||||
return config.root_path / "controlnets"
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def write_default_options(program_opts: Namespace, initfile: Path):
|
def write_default_options(program_opts: Namespace, initfile: Path):
|
||||||
opt = default_startup_options(initfile)
|
opt = default_startup_options(initfile)
|
||||||
@ -758,14 +728,42 @@ def migrate_init_file(legacy_format:Path):
|
|||||||
new.nsfw_checker = old.safety_checker
|
new.nsfw_checker = old.safety_checker
|
||||||
new.xformers_enabled = old.xformers
|
new.xformers_enabled = old.xformers
|
||||||
new.conf_path = old.conf
|
new.conf_path = old.conf
|
||||||
new.embedding_dir = old.embedding_path
|
new.root = legacy_format.parent.resolve()
|
||||||
|
|
||||||
invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
|
invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
|
||||||
with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
|
with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
|
||||||
outfile.write(new.to_yaml())
|
outfile.write(new.to_yaml())
|
||||||
|
|
||||||
legacy_format.replace(legacy_format.parent / 'invokeai.init.old')
|
legacy_format.replace(legacy_format.parent / 'invokeai.init.orig')
|
||||||
|
|
||||||
|
# -------------------------------------
|
||||||
|
def migrate_models(root: Path):
|
||||||
|
from invokeai.backend.install.migrate_to_3 import do_migrate
|
||||||
|
do_migrate(root, root)
|
||||||
|
|
||||||
|
def migrate_if_needed(opt: Namespace, root: Path)->bool:
|
||||||
|
# We check for to see if the runtime directory is correctly initialized.
|
||||||
|
old_init_file = root / 'invokeai.init'
|
||||||
|
new_init_file = root / 'invokeai.yaml'
|
||||||
|
old_hub = root / 'models/hub'
|
||||||
|
migration_needed = old_init_file.exists() and not new_init_file.exists() or old_hub.exists()
|
||||||
|
|
||||||
|
if migration_needed:
|
||||||
|
if opt.yes_to_all or \
|
||||||
|
yes_or_no(f'{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?'):
|
||||||
|
|
||||||
|
logger.info('** Migrating invokeai.init to invokeai.yaml')
|
||||||
|
migrate_init_file(old_init_file)
|
||||||
|
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
|
||||||
|
|
||||||
|
if old_hub.exists():
|
||||||
|
migrate_models(config.root_path)
|
||||||
|
else:
|
||||||
|
print('Cannot continue without conversion. Aborting.')
|
||||||
|
|
||||||
|
return migration_needed
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||||
@ -831,20 +829,16 @@ def main():
|
|||||||
errors = set()
|
errors = set()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
models_to_download = default_user_selections(opt)
|
# if we do a root migration/upgrade, then we are keeping previous
|
||||||
|
# configuration and we are done.
|
||||||
# We check for to see if the runtime directory is correctly initialized.
|
if migrate_if_needed(opt, config.root_path):
|
||||||
old_init_file = config.root_path / 'invokeai.init'
|
sys.exit(0)
|
||||||
new_init_file = config.root_path / 'invokeai.yaml'
|
|
||||||
if old_init_file.exists() and not new_init_file.exists():
|
|
||||||
print('** Migrating invokeai.init to invokeai.yaml')
|
|
||||||
migrate_init_file(old_init_file)
|
|
||||||
# Load new init file into config
|
|
||||||
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
|
|
||||||
|
|
||||||
if not config.model_conf_path.exists():
|
if not config.model_conf_path.exists():
|
||||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||||
|
|
||||||
|
models_to_download = default_user_selections(opt)
|
||||||
|
new_init_file = config.root_path / 'invokeai.yaml'
|
||||||
if opt.yes_to_all:
|
if opt.yes_to_all:
|
||||||
write_default_options(opt, new_init_file)
|
write_default_options(opt, new_init_file)
|
||||||
init_options = Namespace(
|
init_options = Namespace(
|
||||||
@ -855,29 +849,21 @@ def main():
|
|||||||
if init_options:
|
if init_options:
|
||||||
write_opts(init_options, new_init_file)
|
write_opts(init_options, new_init_file)
|
||||||
else:
|
else:
|
||||||
print(
|
logger.info(
|
||||||
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
|
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
|
||||||
)
|
)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
if opt.skip_support_models:
|
if opt.skip_support_models:
|
||||||
print("\n** SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST **")
|
logger.info("SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST")
|
||||||
else:
|
else:
|
||||||
print("\n** CHECKING/UPDATING SUPPORT MODELS **")
|
logger.info("CHECKING/UPDATING SUPPORT MODELS")
|
||||||
download_bert()
|
download_support_models()
|
||||||
download_sd1_clip()
|
|
||||||
download_sd2_clip()
|
|
||||||
download_realesrgan()
|
|
||||||
download_gfpgan()
|
|
||||||
download_codeformer()
|
|
||||||
download_clipseg()
|
|
||||||
download_safety_checker()
|
|
||||||
download_vaes()
|
|
||||||
|
|
||||||
if opt.skip_sd_weights:
|
if opt.skip_sd_weights:
|
||||||
print("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
|
logger.info("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
|
||||||
elif models_to_download:
|
elif models_to_download:
|
||||||
print("\n** DOWNLOADING DIFFUSION WEIGHTS **")
|
logger.info("\n** DOWNLOADING DIFFUSION WEIGHTS **")
|
||||||
process_and_execute(opt, models_to_download)
|
process_and_execute(opt, models_to_download)
|
||||||
|
|
||||||
postscript(errors=errors)
|
postscript(errors=errors)
|
||||||
|
581
invokeai/backend/install/migrate_to_3.py
Normal file
581
invokeai/backend/install/migrate_to_3.py
Normal file
@ -0,0 +1,581 @@
|
|||||||
|
'''
|
||||||
|
Migrate the models directory and models.yaml file from an existing
|
||||||
|
InvokeAI 2.3 installation to 3.0.0.
|
||||||
|
'''
|
||||||
|
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import shutil
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
import diffusers
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from omegaconf import OmegaConf, DictConfig
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from diffusers import StableDiffusionPipeline, AutoencoderKL
|
||||||
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
|
from transformers import (
|
||||||
|
CLIPTextModel,
|
||||||
|
CLIPTokenizer,
|
||||||
|
AutoFeatureExtractor,
|
||||||
|
BertTokenizerFast,
|
||||||
|
)
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
|
from invokeai.backend.model_management import ModelManager
|
||||||
|
from invokeai.backend.model_management.model_probe import (
|
||||||
|
ModelProbe, ModelType, BaseModelType, SchedulerPredictionType, ModelProbeInfo
|
||||||
|
)
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
transformers.logging.set_verbosity_error()
|
||||||
|
diffusers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
# holder for paths that we will migrate
|
||||||
|
@dataclass
|
||||||
|
class ModelPaths:
|
||||||
|
models: Path
|
||||||
|
embeddings: Path
|
||||||
|
loras: Path
|
||||||
|
controlnets: Path
|
||||||
|
|
||||||
|
class MigrateTo3(object):
|
||||||
|
def __init__(self,
|
||||||
|
root_directory: Path,
|
||||||
|
dest_models: Path,
|
||||||
|
yaml_file: io.TextIOBase,
|
||||||
|
src_paths: ModelPaths,
|
||||||
|
):
|
||||||
|
self.root_directory = root_directory
|
||||||
|
self.dest_models = dest_models
|
||||||
|
self.dest_yaml = yaml_file
|
||||||
|
self.model_names = set()
|
||||||
|
self.src_paths = src_paths
|
||||||
|
|
||||||
|
self._initialize_yaml()
|
||||||
|
|
||||||
|
def _initialize_yaml(self):
|
||||||
|
self.dest_yaml.write(
|
||||||
|
yaml.dump(
|
||||||
|
{
|
||||||
|
'__metadata__':
|
||||||
|
{
|
||||||
|
'version':'3.0.0'}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def unique_name(self,name,info)->str:
|
||||||
|
'''
|
||||||
|
Create a unique name for a model for use within models.yaml.
|
||||||
|
'''
|
||||||
|
done = False
|
||||||
|
key = ModelManager.create_key(name,info.base_type,info.model_type)
|
||||||
|
unique_name = key
|
||||||
|
counter = 1
|
||||||
|
while not done:
|
||||||
|
if unique_name in self.model_names:
|
||||||
|
unique_name = f'{key}-{counter:0>2d}'
|
||||||
|
counter += 1
|
||||||
|
else:
|
||||||
|
done = True
|
||||||
|
self.model_names.add(unique_name)
|
||||||
|
name,_,_ = ModelManager.parse_key(unique_name)
|
||||||
|
return name
|
||||||
|
|
||||||
|
def create_directory_structure(self):
|
||||||
|
'''
|
||||||
|
Create the basic directory structure for the models folder.
|
||||||
|
'''
|
||||||
|
for model_base in [BaseModelType.StableDiffusion1,BaseModelType.StableDiffusion2]:
|
||||||
|
for model_type in [ModelType.Main, ModelType.Vae, ModelType.Lora,
|
||||||
|
ModelType.ControlNet,ModelType.TextualInversion]:
|
||||||
|
path = self.dest_models / model_base.value / model_type.value
|
||||||
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
path = self.dest_models / 'core'
|
||||||
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def copy_file(src:Path,dest:Path):
|
||||||
|
'''
|
||||||
|
copy a single file with logging
|
||||||
|
'''
|
||||||
|
if dest.exists():
|
||||||
|
logger.info(f'Skipping existing {str(dest)}')
|
||||||
|
return
|
||||||
|
logger.info(f'Copying {str(src)} to {str(dest)}')
|
||||||
|
try:
|
||||||
|
shutil.copy(src, dest)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f'COPY FAILED: {str(e)}')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def copy_dir(src:Path,dest:Path):
|
||||||
|
'''
|
||||||
|
Recursively copy a directory with logging
|
||||||
|
'''
|
||||||
|
if dest.exists():
|
||||||
|
logger.info(f'Skipping existing {str(dest)}')
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f'Copying {str(src)} to {str(dest)}')
|
||||||
|
try:
|
||||||
|
shutil.copytree(src, dest)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f'COPY FAILED: {str(e)}')
|
||||||
|
|
||||||
|
def migrate_models(self, src_dir: Path):
|
||||||
|
'''
|
||||||
|
Recursively walk through src directory, probe anything
|
||||||
|
that looks like a model, and copy the model into the
|
||||||
|
appropriate location within the destination models directory.
|
||||||
|
'''
|
||||||
|
for root, dirs, files in os.walk(src_dir):
|
||||||
|
for f in files:
|
||||||
|
# hack - don't copy raw learned_embeds.bin, let them
|
||||||
|
# be copied as part of a tree copy operation
|
||||||
|
if f == 'learned_embeds.bin':
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
model = Path(root,f)
|
||||||
|
info = ModelProbe().heuristic_probe(model)
|
||||||
|
if not info:
|
||||||
|
continue
|
||||||
|
dest = self._model_probe_to_path(info) / f
|
||||||
|
self.copy_file(model, dest)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
for d in dirs:
|
||||||
|
try:
|
||||||
|
model = Path(root,d)
|
||||||
|
info = ModelProbe().heuristic_probe(model)
|
||||||
|
if not info:
|
||||||
|
continue
|
||||||
|
dest = self._model_probe_to_path(info) / model.name
|
||||||
|
self.copy_dir(model, dest)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
|
||||||
|
def migrate_support_models(self):
|
||||||
|
'''
|
||||||
|
Copy the clipseg, upscaler, and restoration models to their new
|
||||||
|
locations.
|
||||||
|
'''
|
||||||
|
dest_directory = self.dest_models
|
||||||
|
if (self.root_directory / 'models/clipseg').exists():
|
||||||
|
self.copy_dir(self.root_directory / 'models/clipseg', dest_directory / 'core/misc/clipseg')
|
||||||
|
if (self.root_directory / 'models/realesrgan').exists():
|
||||||
|
self.copy_dir(self.root_directory / 'models/realesrgan', dest_directory / 'core/upscaling/realesrgan')
|
||||||
|
for d in ['codeformer','gfpgan']:
|
||||||
|
path = self.root_directory / 'models' / d
|
||||||
|
if path.exists():
|
||||||
|
self.copy_dir(path,dest_directory / f'core/face_restoration/{d}')
|
||||||
|
|
||||||
|
def migrate_tuning_models(self):
|
||||||
|
'''
|
||||||
|
Migrate the embeddings, loras and controlnets directories to their new homes.
|
||||||
|
'''
|
||||||
|
for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]:
|
||||||
|
if not src:
|
||||||
|
continue
|
||||||
|
if src.is_dir():
|
||||||
|
logger.info(f'Scanning {src}')
|
||||||
|
self.migrate_models(src)
|
||||||
|
else:
|
||||||
|
logger.info(f'{src} directory not found; skipping')
|
||||||
|
continue
|
||||||
|
|
||||||
|
def migrate_conversion_models(self):
|
||||||
|
'''
|
||||||
|
Migrate all the models that are needed by the ckpt_to_diffusers conversion
|
||||||
|
script.
|
||||||
|
'''
|
||||||
|
|
||||||
|
dest_directory = self.dest_models
|
||||||
|
kwargs = dict(
|
||||||
|
cache_dir = self.root_directory / 'models/hub',
|
||||||
|
#local_files_only = True
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
logger.info('Migrating core tokenizers and text encoders')
|
||||||
|
target_dir = dest_directory / 'core' / 'convert'
|
||||||
|
|
||||||
|
self._migrate_pretrained(BertTokenizerFast,
|
||||||
|
repo_id='bert-base-uncased',
|
||||||
|
dest = target_dir / 'bert-base-uncased',
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
# sd-1
|
||||||
|
repo_id = 'openai/clip-vit-large-patch14'
|
||||||
|
self._migrate_pretrained(CLIPTokenizer,
|
||||||
|
repo_id= repo_id,
|
||||||
|
dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer',
|
||||||
|
**kwargs)
|
||||||
|
self._migrate_pretrained(CLIPTextModel,
|
||||||
|
repo_id = repo_id,
|
||||||
|
dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder',
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
# sd-2
|
||||||
|
repo_id = "stabilityai/stable-diffusion-2"
|
||||||
|
self._migrate_pretrained(CLIPTokenizer,
|
||||||
|
repo_id = repo_id,
|
||||||
|
dest = target_dir / 'stable-diffusion-2-clip' / 'tokenizer',
|
||||||
|
**{'subfolder':'tokenizer',**kwargs}
|
||||||
|
)
|
||||||
|
self._migrate_pretrained(CLIPTextModel,
|
||||||
|
repo_id = repo_id,
|
||||||
|
dest = target_dir / 'stable-diffusion-2-clip' / 'text_encoder',
|
||||||
|
**{'subfolder':'text_encoder',**kwargs}
|
||||||
|
)
|
||||||
|
|
||||||
|
# VAE
|
||||||
|
logger.info('Migrating stable diffusion VAE')
|
||||||
|
self._migrate_pretrained(AutoencoderKL,
|
||||||
|
repo_id = 'stabilityai/sd-vae-ft-mse',
|
||||||
|
dest = target_dir / 'sd-vae-ft-mse',
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
# safety checking
|
||||||
|
logger.info('Migrating safety checker')
|
||||||
|
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
|
self._migrate_pretrained(AutoFeatureExtractor,
|
||||||
|
repo_id = repo_id,
|
||||||
|
dest = target_dir / 'stable-diffusion-safety-checker',
|
||||||
|
**kwargs)
|
||||||
|
self._migrate_pretrained(StableDiffusionSafetyChecker,
|
||||||
|
repo_id = repo_id,
|
||||||
|
dest = target_dir / 'stable-diffusion-safety-checker',
|
||||||
|
**kwargs)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
|
||||||
|
def write_yaml(self, model_name: str, path:Path, info:ModelProbeInfo, **kwargs):
|
||||||
|
'''
|
||||||
|
Write a stanza for a moved model into the new models.yaml file.
|
||||||
|
'''
|
||||||
|
name = self.unique_name(model_name, info)
|
||||||
|
stanza = {
|
||||||
|
f'{info.base_type.value}/{info.model_type.value}/{name}': {
|
||||||
|
'name': model_name,
|
||||||
|
'path': str(path),
|
||||||
|
'description': f'A {info.base_type.value} {info.model_type.value} model',
|
||||||
|
'format': info.format,
|
||||||
|
'image_size': info.image_size,
|
||||||
|
'base': info.base_type.value,
|
||||||
|
'variant': info.variant_type.value,
|
||||||
|
'prediction_type': info.prediction_type.value,
|
||||||
|
'upcast_attention': info.prediction_type == SchedulerPredictionType.VPrediction,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.dest_yaml.write(yaml.dump(stanza))
|
||||||
|
self.dest_yaml.flush()
|
||||||
|
|
||||||
|
def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
|
||||||
|
return Path(self.dest_models, info.base_type.value, info.model_type.value)
|
||||||
|
|
||||||
|
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs):
|
||||||
|
if dest.exists():
|
||||||
|
logger.info(f'Skipping existing {dest}')
|
||||||
|
return
|
||||||
|
model = model_class.from_pretrained(repo_id, **kwargs)
|
||||||
|
self._save_pretrained(model, dest)
|
||||||
|
|
||||||
|
def _save_pretrained(self, model, dest: Path):
|
||||||
|
if dest.exists():
|
||||||
|
logger.info(f'Skipping existing {dest}')
|
||||||
|
return
|
||||||
|
model_name = dest.name
|
||||||
|
download_path = dest.with_name(f'{model_name}.downloading')
|
||||||
|
model.save_pretrained(download_path, safe_serialization=True)
|
||||||
|
download_path.replace(dest)
|
||||||
|
|
||||||
|
def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
|
||||||
|
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
|
||||||
|
info = ModelProbe().heuristic_probe(vae)
|
||||||
|
_, model_name = repo_id.split('/')
|
||||||
|
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
|
||||||
|
vae.save_pretrained(dest, safe_serialization=True)
|
||||||
|
return dest
|
||||||
|
|
||||||
|
def _vae_path(self, vae: Union[str,dict])->Path:
|
||||||
|
'''
|
||||||
|
Convert 2.3 VAE stanza to a straight path.
|
||||||
|
'''
|
||||||
|
vae_path = None
|
||||||
|
|
||||||
|
# First get a path
|
||||||
|
if isinstance(vae,str):
|
||||||
|
vae_path = vae
|
||||||
|
|
||||||
|
elif isinstance(vae,DictConfig):
|
||||||
|
if p := vae.get('path'):
|
||||||
|
vae_path = p
|
||||||
|
elif repo_id := vae.get('repo_id'):
|
||||||
|
if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded
|
||||||
|
vae_path = 'models/core/convert/se-vae-ft-mse'
|
||||||
|
else:
|
||||||
|
vae_path = self._download_vae(repo_id, vae.get('subfolder'))
|
||||||
|
|
||||||
|
assert vae_path is not None, "Couldn't find VAE for this model"
|
||||||
|
|
||||||
|
# if the VAE is in the old models directory, then we must move it into the new
|
||||||
|
# one. VAEs outside of this directory can stay where they are.
|
||||||
|
vae_path = Path(vae_path)
|
||||||
|
if vae_path.is_relative_to(self.src_paths.models):
|
||||||
|
info = ModelProbe().heuristic_probe(vae_path)
|
||||||
|
dest = self._model_probe_to_path(info) / vae_path.name
|
||||||
|
if not dest.exists():
|
||||||
|
self.copy_dir(vae_path,dest)
|
||||||
|
vae_path = dest
|
||||||
|
|
||||||
|
if vae_path.is_relative_to(self.dest_models):
|
||||||
|
rel_path = vae_path.relative_to(self.dest_models)
|
||||||
|
return Path('models',rel_path)
|
||||||
|
else:
|
||||||
|
return vae_path
|
||||||
|
|
||||||
|
def migrate_repo_id(self, repo_id: str, model_name :str=None, **extra_config):
|
||||||
|
'''
|
||||||
|
Migrate a locally-cached diffusers pipeline identified with a repo_id
|
||||||
|
'''
|
||||||
|
dest_dir = self.dest_models
|
||||||
|
|
||||||
|
cache = self.root_directory / 'models/hub'
|
||||||
|
kwargs = dict(
|
||||||
|
cache_dir = cache,
|
||||||
|
safety_checker = None,
|
||||||
|
# local_files_only = True,
|
||||||
|
)
|
||||||
|
|
||||||
|
owner,repo_name = repo_id.split('/')
|
||||||
|
model_name = model_name or repo_name
|
||||||
|
model = cache / '--'.join(['models',owner,repo_name])
|
||||||
|
|
||||||
|
if len(list(model.glob('snapshots/**/model_index.json')))==0:
|
||||||
|
return
|
||||||
|
revisions = [x.name for x in model.glob('refs/*')]
|
||||||
|
|
||||||
|
# if an fp16 is available we use that
|
||||||
|
revision = 'fp16' if len(revisions) > 1 and 'fp16' in revisions else revisions[0]
|
||||||
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
|
repo_id,
|
||||||
|
revision=revision,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
info = ModelProbe().heuristic_probe(pipeline)
|
||||||
|
if not info:
|
||||||
|
return
|
||||||
|
|
||||||
|
dest = self._model_probe_to_path(info) / repo_name
|
||||||
|
self._save_pretrained(pipeline, dest)
|
||||||
|
|
||||||
|
rel_path = Path('models',dest.relative_to(dest_dir))
|
||||||
|
self.write_yaml(model_name, path=rel_path, info=info, **extra_config)
|
||||||
|
|
||||||
|
def migrate_path(self, location: Path, model_name: str=None, **extra_config):
|
||||||
|
'''
|
||||||
|
Migrate a model referred to using 'weights' or 'path'
|
||||||
|
'''
|
||||||
|
|
||||||
|
# handle relative paths
|
||||||
|
dest_dir = self.dest_models
|
||||||
|
location = self.root_directory / location
|
||||||
|
|
||||||
|
info = ModelProbe().heuristic_probe(location)
|
||||||
|
if not info:
|
||||||
|
return
|
||||||
|
|
||||||
|
# uh oh, weights is in the old models directory - move it into the new one
|
||||||
|
if Path(location).is_relative_to(self.src_paths.models):
|
||||||
|
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
|
||||||
|
self.copy_dir(location,dest)
|
||||||
|
location = Path('models', info.base_type.value, info.model_type.value, location.name)
|
||||||
|
model_name = model_name or location.stem
|
||||||
|
model_name = self.unique_name(model_name, info)
|
||||||
|
self.write_yaml(model_name, path=location, info=info, **extra_config)
|
||||||
|
|
||||||
|
def migrate_defined_models(self):
|
||||||
|
'''
|
||||||
|
Migrate models defined in models.yaml
|
||||||
|
'''
|
||||||
|
# find any models referred to in old models.yaml
|
||||||
|
conf = OmegaConf.load(self.root_directory / 'configs/models.yaml')
|
||||||
|
|
||||||
|
for model_name, stanza in conf.items():
|
||||||
|
|
||||||
|
try:
|
||||||
|
passthru_args = {}
|
||||||
|
|
||||||
|
if vae := stanza.get('vae'):
|
||||||
|
try:
|
||||||
|
passthru_args['vae'] = str(self._vae_path(vae))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
|
||||||
|
logger.warning(str(e))
|
||||||
|
|
||||||
|
if config := stanza.get('config'):
|
||||||
|
passthru_args['config'] = config
|
||||||
|
|
||||||
|
if repo_id := stanza.get('repo_id'):
|
||||||
|
logger.info(f'Migrating diffusers model {model_name}')
|
||||||
|
self.migrate_repo_id(repo_id, model_name, **passthru_args)
|
||||||
|
|
||||||
|
elif location := stanza.get('weights'):
|
||||||
|
logger.info(f'Migrating checkpoint model {model_name}')
|
||||||
|
self.migrate_path(Path(location), model_name, **passthru_args)
|
||||||
|
|
||||||
|
elif location := stanza.get('path'):
|
||||||
|
logger.info(f'Migrating diffusers model {model_name}')
|
||||||
|
self.migrate_path(Path(location), model_name, **passthru_args)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
|
||||||
|
def migrate(self):
|
||||||
|
self.create_directory_structure()
|
||||||
|
# the configure script is doing this
|
||||||
|
self.migrate_support_models()
|
||||||
|
self.migrate_conversion_models()
|
||||||
|
self.migrate_tuning_models()
|
||||||
|
self.migrate_defined_models()
|
||||||
|
|
||||||
|
def _parse_legacy_initfile(root: Path, initfile: Path)->ModelPaths:
|
||||||
|
'''
|
||||||
|
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
||||||
|
'''
|
||||||
|
parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
|
||||||
|
parser.add_argument(
|
||||||
|
'--embedding_directory',
|
||||||
|
'--embedding_path',
|
||||||
|
type=Path,
|
||||||
|
dest='embedding_path',
|
||||||
|
default=Path('embeddings'),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--lora_directory',
|
||||||
|
dest='lora_path',
|
||||||
|
type=Path,
|
||||||
|
default=Path('loras'),
|
||||||
|
)
|
||||||
|
opt,_ = parser.parse_known_args([f'@{str(initfile)}'])
|
||||||
|
return ModelPaths(
|
||||||
|
models = root / 'models',
|
||||||
|
embeddings = root / str(opt.embedding_path).strip('"'),
|
||||||
|
loras = root / str(opt.lora_path).strip('"'),
|
||||||
|
controlnets = root / 'controlnets',
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_legacy_yamlfile(root: Path, initfile: Path)->ModelPaths:
|
||||||
|
'''
|
||||||
|
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
||||||
|
'''
|
||||||
|
# Don't use the config object because it is unforgiving of version updates
|
||||||
|
# Just use omegaconf directly
|
||||||
|
opt = OmegaConf.load(initfile)
|
||||||
|
paths = opt.InvokeAI.Paths
|
||||||
|
models = paths.get('models_dir','models')
|
||||||
|
embeddings = paths.get('embedding_dir','embeddings')
|
||||||
|
loras = paths.get('lora_dir','loras')
|
||||||
|
controlnets = paths.get('controlnet_dir','controlnets')
|
||||||
|
return ModelPaths(
|
||||||
|
models = root / models,
|
||||||
|
embeddings = root / embeddings,
|
||||||
|
loras = root /loras,
|
||||||
|
controlnets = root / controlnets,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_legacy_embeddings(root: Path) -> ModelPaths:
|
||||||
|
path = root / 'invokeai.init'
|
||||||
|
if path.exists():
|
||||||
|
return _parse_legacy_initfile(root, path)
|
||||||
|
path = root / 'invokeai.yaml'
|
||||||
|
if path.exists():
|
||||||
|
return _parse_legacy_yamlfile(root, path)
|
||||||
|
|
||||||
|
def do_migrate(src_directory: Path, dest_directory: Path):
|
||||||
|
|
||||||
|
dest_models = dest_directory / 'models-3.0'
|
||||||
|
dest_yaml = dest_directory / 'configs/models.yaml-3.0'
|
||||||
|
|
||||||
|
paths = get_legacy_embeddings(src_directory)
|
||||||
|
|
||||||
|
with open(dest_yaml,'w') as yaml_file:
|
||||||
|
migrator = MigrateTo3(src_directory,
|
||||||
|
dest_models,
|
||||||
|
yaml_file,
|
||||||
|
src_paths = paths,
|
||||||
|
)
|
||||||
|
migrator.migrate()
|
||||||
|
|
||||||
|
shutil.rmtree(dest_directory / 'models.orig', ignore_errors=True)
|
||||||
|
(dest_directory / 'models').replace(dest_directory / 'models.orig')
|
||||||
|
dest_models.replace(dest_directory / 'models')
|
||||||
|
|
||||||
|
(dest_directory /'configs/models.yaml').replace(dest_directory / 'configs/models.yaml.orig')
|
||||||
|
dest_yaml.replace(dest_directory / 'configs/models.yaml')
|
||||||
|
print(f"""Migration successful.
|
||||||
|
Original models directory moved to {dest_directory}/models.orig
|
||||||
|
Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig
|
||||||
|
""")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(prog="invokeai-migrate3",
|
||||||
|
description="""
|
||||||
|
This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
|
||||||
|
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a
|
||||||
|
|
||||||
|
The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.
|
||||||
|
It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure
|
||||||
|
script, which will perform a full upgrade in place."""
|
||||||
|
)
|
||||||
|
parser.add_argument('--from-directory',
|
||||||
|
dest='root_directory',
|
||||||
|
type=Path,
|
||||||
|
required=True,
|
||||||
|
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")'
|
||||||
|
)
|
||||||
|
parser.add_argument('--to-directory',
|
||||||
|
dest='dest_directory',
|
||||||
|
type=Path,
|
||||||
|
required=True,
|
||||||
|
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")'
|
||||||
|
)
|
||||||
|
# TO DO: Implement full directory scanning
|
||||||
|
# parser.add_argument('--all-models',
|
||||||
|
# action="store_true",
|
||||||
|
# help='Migrate all models found in `models` directory, not just those mentioned in models.yaml',
|
||||||
|
# )
|
||||||
|
args = parser.parse_args()
|
||||||
|
root_directory = args.root_directory
|
||||||
|
assert root_directory.is_dir(), f"{root_directory} is not a valid directory"
|
||||||
|
assert (root_directory / 'models').is_dir(), f"{root_directory} does not contain a 'models' subdirectory"
|
||||||
|
assert (root_directory / 'invokeai.init').exists() or (root_directory / 'invokeai.yaml').exists(), f"{root_directory} does not contain an InvokeAI init file."
|
||||||
|
|
||||||
|
dest_directory = args.dest_directory
|
||||||
|
assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory"
|
||||||
|
assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory"
|
||||||
|
assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file."
|
||||||
|
|
||||||
|
do_migrate(root_directory,dest_directory)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2,46 +2,36 @@
|
|||||||
Utility (backend) functions used by model_install.py
|
Utility (backend) functions used by model_install.py
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass,field
|
from dataclasses import dataclass,field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryFile
|
from tempfile import TemporaryDirectory
|
||||||
from typing import List, Dict, Callable
|
from typing import List, Dict, Callable, Union, Set
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from diffusers import AutoencoderKL
|
from diffusers import StableDiffusionPipeline
|
||||||
from huggingface_hub import hf_hub_url, HfFolder
|
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from omegaconf.dictconfig import DictConfig
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import invokeai.configs as configs
|
import invokeai.configs as configs
|
||||||
|
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
|
||||||
|
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
|
||||||
|
from invokeai.backend.util import download_with_resume
|
||||||
from ..util.logging import InvokeAILogger
|
from ..util.logging import InvokeAILogger
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
# --------------------------globals-----------------------
|
# --------------------------globals-----------------------
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
logger = InvokeAILogger.getLogger(name='InvokeAI')
|
||||||
Model_dir = "models"
|
|
||||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
|
||||||
|
|
||||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||||
|
|
||||||
# initial models omegaconf
|
|
||||||
Datasets = None
|
|
||||||
|
|
||||||
# logger
|
|
||||||
logger = InvokeAILogger.getLogger(name='InvokeAI')
|
|
||||||
|
|
||||||
Config_preamble = """
|
Config_preamble = """
|
||||||
# This file describes the alternative machine learning models
|
# This file describes the alternative machine learning models
|
||||||
# available to InvokeAI script.
|
# available to InvokeAI script.
|
||||||
@ -52,6 +42,24 @@ Config_preamble = """
|
|||||||
# was trained on.
|
# was trained on.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
LEGACY_CONFIGS = {
|
||||||
|
BaseModelType.StableDiffusion1: {
|
||||||
|
ModelVariantType.Normal: 'v1-inference.yaml',
|
||||||
|
ModelVariantType.Inpaint: 'v1-inpainting-inference.yaml',
|
||||||
|
},
|
||||||
|
|
||||||
|
BaseModelType.StableDiffusion2: {
|
||||||
|
ModelVariantType.Normal: {
|
||||||
|
SchedulerPredictionType.Epsilon: 'v2-inference.yaml',
|
||||||
|
SchedulerPredictionType.VPrediction: 'v2-inference-v.yaml',
|
||||||
|
},
|
||||||
|
ModelVariantType.Inpaint: {
|
||||||
|
SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml',
|
||||||
|
SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelInstallList:
|
class ModelInstallList:
|
||||||
'''Class for listing models to be installed/removed'''
|
'''Class for listing models to be installed/removed'''
|
||||||
@ -59,133 +67,321 @@ class ModelInstallList:
|
|||||||
remove_models: List[str] = field(default_factory=list)
|
remove_models: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UserSelections():
|
class InstallSelections():
|
||||||
install_models: List[str]= field(default_factory=list)
|
install_models: List[str]= field(default_factory=list)
|
||||||
remove_models: List[str]=field(default_factory=list)
|
remove_models: List[str]=field(default_factory=list)
|
||||||
purge_deleted_models: bool=field(default_factory=list)
|
# scan_directory: Path = None
|
||||||
install_cn_models: List[str] = field(default_factory=list)
|
# autoscan_on_startup: bool=False
|
||||||
remove_cn_models: List[str] = field(default_factory=list)
|
|
||||||
install_lora_models: List[str] = field(default_factory=list)
|
@dataclass
|
||||||
remove_lora_models: List[str] = field(default_factory=list)
|
class ModelLoadInfo():
|
||||||
install_ti_models: List[str] = field(default_factory=list)
|
name: str
|
||||||
remove_ti_models: List[str] = field(default_factory=list)
|
model_type: ModelType
|
||||||
scan_directory: Path = None
|
base_type: BaseModelType
|
||||||
autoscan_on_startup: bool=False
|
path: Path = None
|
||||||
import_model_paths: str=None
|
repo_id: str = None
|
||||||
|
description: str = ''
|
||||||
|
installed: bool = False
|
||||||
|
recommended: bool = False
|
||||||
|
default: bool = False
|
||||||
|
|
||||||
|
class ModelInstall(object):
|
||||||
|
def __init__(self,
|
||||||
|
config:InvokeAIAppConfig,
|
||||||
|
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||||
|
model_manager: ModelManager = None,
|
||||||
|
access_token:str = None):
|
||||||
|
self.config = config
|
||||||
|
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
||||||
|
self.datasets = OmegaConf.load(Dataset_path)
|
||||||
|
self.prediction_helper = prediction_type_helper
|
||||||
|
self.access_token = access_token or HfFolder.get_token()
|
||||||
|
self.reverse_paths = self._reverse_paths(self.datasets)
|
||||||
|
|
||||||
|
def all_models(self)->Dict[str,ModelLoadInfo]:
|
||||||
|
'''
|
||||||
|
Return dict of model_key=>ModelLoadInfo objects.
|
||||||
|
This method consolidates and simplifies the entries in both
|
||||||
|
models.yaml and INITIAL_MODELS.yaml so that they can
|
||||||
|
be treated uniformly. It also sorts the models alphabetically
|
||||||
|
by their name, to improve the display somewhat.
|
||||||
|
'''
|
||||||
|
model_dict = dict()
|
||||||
|
|
||||||
def default_config_file():
|
# first populate with the entries in INITIAL_MODELS.yaml
|
||||||
return config.model_conf_path
|
for key, value in self.datasets.items():
|
||||||
|
name,base,model_type = ModelManager.parse_key(key)
|
||||||
|
value['name'] = name
|
||||||
|
value['base_type'] = base
|
||||||
|
value['model_type'] = model_type
|
||||||
|
model_dict[key] = ModelLoadInfo(**value)
|
||||||
|
|
||||||
def sd_configs():
|
# supplement with entries in models.yaml
|
||||||
return config.legacy_conf_path
|
installed_models = self.mgr.list_models()
|
||||||
|
for md in installed_models:
|
||||||
def initial_models():
|
base = md['base_model']
|
||||||
global Datasets
|
model_type = md['type']
|
||||||
if Datasets:
|
name = md['name']
|
||||||
return Datasets
|
key = ModelManager.create_key(name, base, model_type)
|
||||||
return (Datasets := OmegaConf.load(Dataset_path)['diffusers'])
|
if key in model_dict:
|
||||||
|
model_dict[key].installed = True
|
||||||
def install_requested_models(
|
else:
|
||||||
diffusers: ModelInstallList = None,
|
model_dict[key] = ModelLoadInfo(
|
||||||
controlnet: ModelInstallList = None,
|
name = name,
|
||||||
lora: ModelInstallList = None,
|
base_type = base,
|
||||||
ti: ModelInstallList = None,
|
model_type = model_type,
|
||||||
cn_model_map: Dict[str,str] = None, # temporary - move to model manager
|
path = value.get('path'),
|
||||||
scan_directory: Path = None,
|
installed = True,
|
||||||
external_models: List[str] = None,
|
|
||||||
scan_at_startup: bool = False,
|
|
||||||
precision: str = "float16",
|
|
||||||
purge_deleted: bool = False,
|
|
||||||
config_file_path: Path = None,
|
|
||||||
model_config_file_callback: Callable[[Path],Path] = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Entry point for installing/deleting starter models, or installing external models.
|
|
||||||
"""
|
|
||||||
access_token = HfFolder.get_token()
|
|
||||||
config_file_path = config_file_path or default_config_file()
|
|
||||||
if not config_file_path.exists():
|
|
||||||
open(config_file_path, "w")
|
|
||||||
|
|
||||||
# prevent circular import here
|
|
||||||
from ..model_management import ModelManager
|
|
||||||
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
|
|
||||||
if controlnet:
|
|
||||||
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
|
|
||||||
model_manager.delete_controlnet_models(controlnet.remove_models)
|
|
||||||
|
|
||||||
if lora:
|
|
||||||
model_manager.install_lora_models(lora.install_models, access_token=access_token)
|
|
||||||
model_manager.delete_lora_models(lora.remove_models)
|
|
||||||
|
|
||||||
if ti:
|
|
||||||
model_manager.install_ti_models(ti.install_models, access_token=access_token)
|
|
||||||
model_manager.delete_ti_models(ti.remove_models)
|
|
||||||
|
|
||||||
if diffusers:
|
|
||||||
# TODO: Replace next three paragraphs with calls into new model manager
|
|
||||||
if diffusers.remove_models and len(diffusers.remove_models) > 0:
|
|
||||||
logger.info("Processing requested deletions")
|
|
||||||
for model in diffusers.remove_models:
|
|
||||||
logger.info(f"{model}...")
|
|
||||||
model_manager.del_model(model, delete_files=purge_deleted)
|
|
||||||
model_manager.commit(config_file_path)
|
|
||||||
|
|
||||||
if diffusers.install_models and len(diffusers.install_models) > 0:
|
|
||||||
logger.info("Installing requested models")
|
|
||||||
downloaded_paths = download_weight_datasets(
|
|
||||||
models=diffusers.install_models,
|
|
||||||
access_token=None,
|
|
||||||
precision=precision,
|
|
||||||
)
|
|
||||||
successful = {x:v for x,v in downloaded_paths.items() if v is not None}
|
|
||||||
if len(successful) > 0:
|
|
||||||
update_config_file(successful, config_file_path)
|
|
||||||
if len(successful) < len(diffusers.install_models):
|
|
||||||
unsuccessful = [x for x in downloaded_paths if downloaded_paths[x] is None]
|
|
||||||
logger.warning(f"Some of the model downloads were not successful: {unsuccessful}")
|
|
||||||
|
|
||||||
# due to above, we have to reload the model manager because conf file
|
|
||||||
# was changed behind its back
|
|
||||||
model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)
|
|
||||||
|
|
||||||
external_models = external_models or list()
|
|
||||||
if scan_directory:
|
|
||||||
external_models.append(str(scan_directory))
|
|
||||||
|
|
||||||
if len(external_models) > 0:
|
|
||||||
logger.info("INSTALLING EXTERNAL MODELS")
|
|
||||||
for path_url_or_repo in external_models:
|
|
||||||
try:
|
|
||||||
logger.debug(f'In install_requested_models; callback = {model_config_file_callback}')
|
|
||||||
model_manager.heuristic_import(
|
|
||||||
path_url_or_repo,
|
|
||||||
commit_to_conf=config_file_path,
|
|
||||||
config_file_callback = model_config_file_callback,
|
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
|
||||||
sys.exit(-1)
|
|
||||||
except Exception:
|
def starter_models(self)->Set[str]:
|
||||||
|
models = set()
|
||||||
|
for key, value in self.datasets.items():
|
||||||
|
name,base,model_type = ModelManager.parse_key(key)
|
||||||
|
if model_type==ModelType.Main:
|
||||||
|
models.add(key)
|
||||||
|
return models
|
||||||
|
|
||||||
|
def recommended_models(self)->Set[str]:
|
||||||
|
starters = self.starter_models()
|
||||||
|
return set([x for x in starters if self.datasets[x].get('recommended',False)])
|
||||||
|
|
||||||
|
def default_model(self)->str:
|
||||||
|
starters = self.starter_models()
|
||||||
|
defaults = [x for x in starters if self.datasets[x].get('default',False)]
|
||||||
|
return defaults[0]
|
||||||
|
|
||||||
|
def install(self, selections: InstallSelections):
|
||||||
|
job = 1
|
||||||
|
jobs = len(selections.remove_models) + len(selections.install_models)
|
||||||
|
|
||||||
|
# remove requested models
|
||||||
|
for key in selections.remove_models:
|
||||||
|
name,base,mtype = self.mgr.parse_key(key)
|
||||||
|
logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]')
|
||||||
|
self.mgr.del_model(name,base,mtype)
|
||||||
|
job += 1
|
||||||
|
|
||||||
|
# add requested models
|
||||||
|
for path in selections.install_models:
|
||||||
|
logger.info(f'Installing {path} [{job}/{jobs}]')
|
||||||
|
self.heuristic_install(path)
|
||||||
|
job += 1
|
||||||
|
|
||||||
|
self.mgr.commit()
|
||||||
|
|
||||||
|
def heuristic_install(self,
|
||||||
|
model_path_id_or_url: Union[str,Path],
|
||||||
|
models_installed: Set[Path]=None)->Set[Path]:
|
||||||
|
|
||||||
|
if not models_installed:
|
||||||
|
models_installed = set()
|
||||||
|
|
||||||
|
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||||
|
self.current_id = model_path_id_or_url
|
||||||
|
path = Path(model_path_id_or_url)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# checkpoint file, or similar
|
||||||
|
if path.is_file():
|
||||||
|
models_installed.add(self._install_path(path))
|
||||||
|
|
||||||
|
# folders style or similar
|
||||||
|
elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||||
|
models_installed.add(self._install_path(path))
|
||||||
|
|
||||||
|
# recursive scan
|
||||||
|
elif path.is_dir():
|
||||||
|
for child in path.iterdir():
|
||||||
|
self.heuristic_install(child, models_installed=models_installed)
|
||||||
|
|
||||||
|
# huggingface repo
|
||||||
|
elif len(str(path).split('/')) == 2:
|
||||||
|
models_installed.add(self._install_repo(str(path)))
|
||||||
|
|
||||||
|
# a URL
|
||||||
|
elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
|
||||||
|
models_installed.add(self._install_url(model_path_id_or_url))
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
|
||||||
|
return models_installed
|
||||||
|
|
||||||
|
# install a model from a local path. The optional info parameter is there to prevent
|
||||||
|
# the model from being probed twice in the event that it has already been probed.
|
||||||
|
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path:
|
||||||
|
try:
|
||||||
|
# logger.debug(f'Probing {path}')
|
||||||
|
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||||
|
model_name = path.stem if info.format=='checkpoint' else path.name
|
||||||
|
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||||
|
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||||
|
attributes = self._make_attributes(path,info)
|
||||||
|
self.mgr.add_model(model_name = model_name,
|
||||||
|
base_model = info.base_type,
|
||||||
|
model_type = info.model_type,
|
||||||
|
model_attributes = attributes,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f'{str(e)} Skipping registration.')
|
||||||
|
return path
|
||||||
|
|
||||||
|
def _install_url(self, url: str)->Path:
|
||||||
|
# copy to a staging area, probe, import and delete
|
||||||
|
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||||
|
location = download_with_resume(url,Path(staging))
|
||||||
|
if not location:
|
||||||
|
logger.error(f'Unable to download {url}. Skipping.')
|
||||||
|
info = ModelProbe().heuristic_probe(location)
|
||||||
|
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
|
||||||
|
models_path = shutil.move(location,dest)
|
||||||
|
|
||||||
|
# staged version will be garbage-collected at this time
|
||||||
|
return self._install_path(Path(models_path), info)
|
||||||
|
|
||||||
|
def _install_repo(self, repo_id: str)->Path:
|
||||||
|
hinfo = HfApi().model_info(repo_id)
|
||||||
|
|
||||||
|
# we try to figure out how to download this most economically
|
||||||
|
# list all the files in the repo
|
||||||
|
files = [x.rfilename for x in hinfo.siblings]
|
||||||
|
location = None
|
||||||
|
|
||||||
|
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||||
|
staging = Path(staging)
|
||||||
|
if 'model_index.json' in files:
|
||||||
|
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
||||||
|
else:
|
||||||
|
for suffix in ['safetensors','bin']:
|
||||||
|
if f'pytorch_lora_weights.{suffix}' in files:
|
||||||
|
location = self._download_hf_model(repo_id, ['pytorch_lora_weights.bin'], staging) # LoRA
|
||||||
|
break
|
||||||
|
elif self.config.precision=='float16' and f'diffusion_pytorch_model.fp16.{suffix}' in files: # vae, controlnet or some other standalone
|
||||||
|
files = ['config.json', f'diffusion_pytorch_model.fp16.{suffix}']
|
||||||
|
location = self._download_hf_model(repo_id, files, staging)
|
||||||
|
break
|
||||||
|
elif f'diffusion_pytorch_model.{suffix}' in files:
|
||||||
|
files = ['config.json', f'diffusion_pytorch_model.{suffix}']
|
||||||
|
location = self._download_hf_model(repo_id, files, staging)
|
||||||
|
break
|
||||||
|
elif f'learned_embeds.{suffix}' in files:
|
||||||
|
location = self._download_hf_model(repo_id, ['learned_embeds.suffix'], staging)
|
||||||
|
break
|
||||||
|
if not location:
|
||||||
|
logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.')
|
||||||
|
return
|
||||||
|
|
||||||
|
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
||||||
|
if not info:
|
||||||
|
logger.warning(f'Could not probe {location}. Skipping install.')
|
||||||
|
return
|
||||||
|
dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location)
|
||||||
|
if dest.exists():
|
||||||
|
shutil.rmtree(dest)
|
||||||
|
shutil.copytree(location,dest)
|
||||||
|
return self._install_path(dest, info)
|
||||||
|
|
||||||
|
def _get_model_name(self,path_name: str, location: Path)->str:
|
||||||
|
'''
|
||||||
|
Calculate a name for the model - primitive implementation.
|
||||||
|
'''
|
||||||
|
if key := self.reverse_paths.get(path_name):
|
||||||
|
(name, base, mtype) = ModelManager.parse_key(key)
|
||||||
|
return name
|
||||||
|
else:
|
||||||
|
return location.stem
|
||||||
|
|
||||||
|
def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict:
|
||||||
|
model_name = path.name if path.is_dir() else path.stem
|
||||||
|
description = f'{info.base_type.value} {info.model_type.value} model {model_name}'
|
||||||
|
if key := self.reverse_paths.get(self.current_id):
|
||||||
|
if key in self.datasets:
|
||||||
|
description = self.datasets[key].get('description') or description
|
||||||
|
|
||||||
|
rel_path = self.relative_to_root(path)
|
||||||
|
|
||||||
|
attributes = dict(
|
||||||
|
path = str(rel_path),
|
||||||
|
description = str(description),
|
||||||
|
model_format = info.format,
|
||||||
|
)
|
||||||
|
if info.model_type == ModelType.Main:
|
||||||
|
attributes.update(dict(variant = info.variant_type,))
|
||||||
|
if info.format=="checkpoint":
|
||||||
|
try:
|
||||||
|
possible_conf = path.with_suffix('.yaml')
|
||||||
|
if possible_conf.exists():
|
||||||
|
legacy_conf = str(self.relative_to_root(possible_conf))
|
||||||
|
elif info.base_type == BaseModelType.StableDiffusion2:
|
||||||
|
legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type])
|
||||||
|
else:
|
||||||
|
legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type])
|
||||||
|
except KeyError:
|
||||||
|
legacy_conf = Path(self.config.legacy_conf_dir, 'v1-inference.yaml') # best guess
|
||||||
|
|
||||||
|
attributes.update(
|
||||||
|
dict(
|
||||||
|
config = str(legacy_conf)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return attributes
|
||||||
|
|
||||||
|
def relative_to_root(self, path: Path)->Path:
|
||||||
|
root = self.config.root_path
|
||||||
|
if path.is_relative_to(root):
|
||||||
|
return path.relative_to(root)
|
||||||
|
else:
|
||||||
|
return path
|
||||||
|
|
||||||
|
def _download_hf_pipeline(self, repo_id: str, staging: Path)->Path:
|
||||||
|
'''
|
||||||
|
This retrieves a StableDiffusion model from cache or remote and then
|
||||||
|
does a save_pretrained() to the indicated staging area.
|
||||||
|
'''
|
||||||
|
_,name = repo_id.split("/")
|
||||||
|
revisions = ['fp16','main'] if self.config.precision=='float16' else ['main']
|
||||||
|
model = None
|
||||||
|
for revision in revisions:
|
||||||
|
try:
|
||||||
|
model = StableDiffusionPipeline.from_pretrained(repo_id,revision=revision,safety_checker=None)
|
||||||
|
except: # most errors are due to fp16 not being present. Fix this to catch other errors
|
||||||
pass
|
pass
|
||||||
|
if model:
|
||||||
|
break
|
||||||
|
if not model:
|
||||||
|
logger.error(f'Diffusers model {repo_id} could not be downloaded. Skipping.')
|
||||||
|
return None
|
||||||
|
model.save_pretrained(staging / name, safe_serialization=True)
|
||||||
|
return staging / name
|
||||||
|
|
||||||
if scan_at_startup and scan_directory.is_dir():
|
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path)->Path:
|
||||||
update_autoconvert_dir(scan_directory)
|
_,name = repo_id.split("/")
|
||||||
else:
|
location = staging / name
|
||||||
update_autoconvert_dir(None)
|
paths = list()
|
||||||
|
for filename in files:
|
||||||
def update_autoconvert_dir(autodir: Path):
|
p = hf_download_with_resume(repo_id,
|
||||||
'''
|
model_dir=location,
|
||||||
Update the "autoconvert_dir" option in invokeai.yaml
|
model_name=filename,
|
||||||
'''
|
access_token = self.access_token
|
||||||
invokeai_config_path = config.init_file_path
|
)
|
||||||
conf = OmegaConf.load(invokeai_config_path)
|
if p:
|
||||||
conf.InvokeAI.Paths.autoconvert_dir = str(autodir) if autodir else None
|
paths.append(p)
|
||||||
yaml = OmegaConf.to_yaml(conf)
|
else:
|
||||||
tmpfile = invokeai_config_path.parent / "new_config.tmp"
|
logger.warning(f'Could not download {filename} from {repo_id}.')
|
||||||
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
|
||||||
outfile.write(yaml)
|
return location if len(paths)>0 else None
|
||||||
tmpfile.replace(invokeai_config_path)
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _reverse_paths(cls,datasets)->dict:
|
||||||
|
'''
|
||||||
|
Reverse mapping from repo_id/path to destination name.
|
||||||
|
'''
|
||||||
|
return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()}
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def yes_or_no(prompt: str, default_yes=True):
|
def yes_or_no(prompt: str, default_yes=True):
|
||||||
@ -197,133 +393,19 @@ def yes_or_no(prompt: str, default_yes=True):
|
|||||||
return response[0] in ("y", "Y")
|
return response[0] in ("y", "Y")
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def recommended_datasets() -> List['str']:
|
def hf_download_from_pretrained(
|
||||||
datasets = set()
|
model_class: object, model_name: str, destination: Path, **kwargs
|
||||||
for ds in initial_models().keys():
|
|
||||||
if initial_models()[ds].get("recommended", False):
|
|
||||||
datasets.add(ds)
|
|
||||||
return list(datasets)
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def default_dataset() -> dict:
|
|
||||||
datasets = set()
|
|
||||||
for ds in initial_models().keys():
|
|
||||||
if initial_models()[ds].get("default", False):
|
|
||||||
datasets.add(ds)
|
|
||||||
return list(datasets)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def all_datasets() -> dict:
|
|
||||||
datasets = dict()
|
|
||||||
for ds in initial_models().keys():
|
|
||||||
datasets[ds] = True
|
|
||||||
return datasets
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
# look for legacy model.ckpt in models directory and offer to
|
|
||||||
# normalize its name
|
|
||||||
def migrate_models_ckpt():
|
|
||||||
model_path = os.path.join(config.root_dir, Model_dir, Weights_dir)
|
|
||||||
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
|
|
||||||
return
|
|
||||||
new_name = initial_models()["stable-diffusion-1.4"]["file"]
|
|
||||||
logger.warning(
|
|
||||||
'The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
|
|
||||||
)
|
|
||||||
logger.warning(f"model.ckpt => {new_name}")
|
|
||||||
os.replace(
|
|
||||||
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def download_weight_datasets(
|
|
||||||
models: List[str], access_token: str, precision: str = "float32"
|
|
||||||
):
|
|
||||||
migrate_models_ckpt()
|
|
||||||
successful = dict()
|
|
||||||
for mod in models:
|
|
||||||
logger.info(f"Downloading {mod}:")
|
|
||||||
successful[mod] = _download_repo_or_file(
|
|
||||||
initial_models()[mod], access_token, precision=precision
|
|
||||||
)
|
|
||||||
return successful
|
|
||||||
|
|
||||||
|
|
||||||
def _download_repo_or_file(
|
|
||||||
mconfig: DictConfig, access_token: str, precision: str = "float32"
|
|
||||||
) -> Path:
|
|
||||||
path = None
|
|
||||||
if mconfig["format"] == "ckpt":
|
|
||||||
path = _download_ckpt_weights(mconfig, access_token)
|
|
||||||
else:
|
|
||||||
path = _download_diffusion_weights(mconfig, access_token, precision=precision)
|
|
||||||
if "vae" in mconfig and "repo_id" in mconfig["vae"]:
|
|
||||||
_download_diffusion_weights(
|
|
||||||
mconfig["vae"], access_token, precision=precision
|
|
||||||
)
|
|
||||||
return path
|
|
||||||
|
|
||||||
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
|
||||||
repo_id = mconfig["repo_id"]
|
|
||||||
filename = mconfig["file"]
|
|
||||||
cache_dir = os.path.join(config.root_dir, Model_dir, Weights_dir)
|
|
||||||
return hf_download_with_resume(
|
|
||||||
repo_id=repo_id,
|
|
||||||
model_dir=cache_dir,
|
|
||||||
model_name=filename,
|
|
||||||
access_token=access_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def download_from_hf(
|
|
||||||
model_class: object, model_name: str, **kwargs
|
|
||||||
):
|
):
|
||||||
logger = InvokeAILogger.getLogger('InvokeAI')
|
logger = InvokeAILogger.getLogger('InvokeAI')
|
||||||
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
|
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
|
||||||
|
|
||||||
path = config.cache_dir
|
|
||||||
model = model_class.from_pretrained(
|
model = model_class.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
cache_dir=path,
|
|
||||||
resume_download=True,
|
resume_download=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
model_name = "--".join(("models", *model_name.split("/")))
|
model.save_pretrained(destination, safe_serialization=True)
|
||||||
return path / model_name if model else None
|
return destination
|
||||||
|
|
||||||
|
|
||||||
def _download_diffusion_weights(
|
|
||||||
mconfig: DictConfig, access_token: str, precision: str = "float32"
|
|
||||||
):
|
|
||||||
repo_id = mconfig["repo_id"]
|
|
||||||
model_class = (
|
|
||||||
StableDiffusionGeneratorPipeline
|
|
||||||
if mconfig.get("format", None) == "diffusers"
|
|
||||||
else AutoencoderKL
|
|
||||||
)
|
|
||||||
extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
|
|
||||||
path = None
|
|
||||||
for extra_args in extra_arg_list:
|
|
||||||
try:
|
|
||||||
path = download_from_hf(
|
|
||||||
model_class,
|
|
||||||
repo_id,
|
|
||||||
safety_checker=None,
|
|
||||||
**extra_args,
|
|
||||||
)
|
|
||||||
except OSError as e:
|
|
||||||
if 'Revision Not Found' in str(e):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
logger.error(str(e))
|
|
||||||
if path:
|
|
||||||
break
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def hf_download_with_resume(
|
def hf_download_with_resume(
|
||||||
@ -383,128 +465,3 @@ def hf_download_with_resume(
|
|||||||
return model_dest
|
return model_dest
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def update_config_file(successfully_downloaded: dict, config_file: Path):
|
|
||||||
config_file = (
|
|
||||||
Path(config_file) if config_file is not None else default_config_file()
|
|
||||||
)
|
|
||||||
|
|
||||||
# In some cases (incomplete setup, etc), the default configs directory might be missing.
|
|
||||||
# Create it if it doesn't exist.
|
|
||||||
# this check is ignored if opt.config_file is specified - user is assumed to know what they
|
|
||||||
# are doing if they are passing a custom config file from elsewhere.
|
|
||||||
if config_file is default_config_file() and not config_file.parent.exists():
|
|
||||||
configs_src = Dataset_path.parent
|
|
||||||
configs_dest = default_config_file().parent
|
|
||||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
|
||||||
|
|
||||||
yaml = new_config_file_contents(successfully_downloaded, config_file)
|
|
||||||
|
|
||||||
try:
|
|
||||||
backup = None
|
|
||||||
if os.path.exists(config_file):
|
|
||||||
logger.warning(
|
|
||||||
f"{config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
|
||||||
)
|
|
||||||
backup = config_file.with_suffix(".yaml.orig")
|
|
||||||
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
|
|
||||||
if sys.platform == "win32" and backup.is_file():
|
|
||||||
backup.unlink()
|
|
||||||
config_file.rename(backup)
|
|
||||||
|
|
||||||
with TemporaryFile() as tmp:
|
|
||||||
tmp.write(Config_preamble.encode())
|
|
||||||
tmp.write(yaml.encode())
|
|
||||||
|
|
||||||
with open(str(config_file.expanduser().resolve()), "wb") as new_config:
|
|
||||||
tmp.seek(0)
|
|
||||||
new_config.write(tmp.read())
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating config file {config_file}: {str(e)}")
|
|
||||||
if backup is not None:
|
|
||||||
logger.info("restoring previous config file")
|
|
||||||
## workaround, for WinError 183, see above
|
|
||||||
if sys.platform == "win32" and config_file.is_file():
|
|
||||||
config_file.unlink()
|
|
||||||
backup.rename(config_file)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"Successfully created new configuration file {config_file}")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def new_config_file_contents(
|
|
||||||
successfully_downloaded: dict,
|
|
||||||
config_file: Path,
|
|
||||||
) -> str:
|
|
||||||
if config_file.exists():
|
|
||||||
conf = OmegaConf.load(str(config_file.expanduser().resolve()))
|
|
||||||
else:
|
|
||||||
conf = OmegaConf.create()
|
|
||||||
|
|
||||||
default_selected = None
|
|
||||||
for model in successfully_downloaded:
|
|
||||||
# a bit hacky - what we are doing here is seeing whether a checkpoint
|
|
||||||
# version of the model was previously defined, and whether the current
|
|
||||||
# model is a diffusers (indicated with a path)
|
|
||||||
if conf.get(model) and Path(successfully_downloaded[model]).is_dir():
|
|
||||||
delete_weights(model, conf[model])
|
|
||||||
|
|
||||||
stanza = {}
|
|
||||||
mod = initial_models()[model]
|
|
||||||
stanza["description"] = mod["description"]
|
|
||||||
stanza["repo_id"] = mod["repo_id"]
|
|
||||||
stanza["format"] = mod["format"]
|
|
||||||
# diffusers don't need width and height (probably .ckpt doesn't either)
|
|
||||||
# so we no longer require these in INITIAL_MODELS.yaml
|
|
||||||
if "width" in mod:
|
|
||||||
stanza["width"] = mod["width"]
|
|
||||||
if "height" in mod:
|
|
||||||
stanza["height"] = mod["height"]
|
|
||||||
if "file" in mod:
|
|
||||||
stanza["weights"] = os.path.relpath(
|
|
||||||
successfully_downloaded[model], start=config.root_dir
|
|
||||||
)
|
|
||||||
stanza["config"] = os.path.normpath(
|
|
||||||
os.path.join(sd_configs(), mod["config"])
|
|
||||||
)
|
|
||||||
if "vae" in mod:
|
|
||||||
if "file" in mod["vae"]:
|
|
||||||
stanza["vae"] = os.path.normpath(
|
|
||||||
os.path.join(Model_dir, Weights_dir, mod["vae"]["file"])
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
stanza["vae"] = mod["vae"]
|
|
||||||
if mod.get("default", False):
|
|
||||||
stanza["default"] = True
|
|
||||||
default_selected = True
|
|
||||||
|
|
||||||
conf[model] = stanza
|
|
||||||
|
|
||||||
# if no default model was chosen, then we select the first
|
|
||||||
# one in the list
|
|
||||||
if not default_selected:
|
|
||||||
conf[list(successfully_downloaded.keys())[0]]["default"] = True
|
|
||||||
|
|
||||||
return OmegaConf.to_yaml(conf)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def delete_weights(model_name: str, conf_stanza: dict):
|
|
||||||
if not (weights := conf_stanza.get("weights")):
|
|
||||||
return
|
|
||||||
if re.match("/VAE/", conf_stanza.get("config")):
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"\nThe checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
|
|
||||||
)
|
|
||||||
|
|
||||||
weights = Path(weights)
|
|
||||||
if not weights.is_absolute():
|
|
||||||
weights = config.root_dir / weights
|
|
||||||
try:
|
|
||||||
weights.unlink()
|
|
||||||
except OSError as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
|
@ -4,3 +4,4 @@ Initialization file for invokeai.backend.model_management
|
|||||||
from .model_manager import ModelManager, ModelInfo
|
from .model_manager import ModelManager, ModelInfo
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
|
|
||||||
from .model_manager import ModelManager
|
from .model_manager import ModelManager
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .models import SchedulerPredictionType, BaseModelType, ModelVariantType
|
from .models import BaseModelType, ModelVariantType
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@ -73,7 +73,9 @@ from transformers import (
|
|||||||
|
|
||||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||||
|
|
||||||
MODEL_ROOT = None
|
# TODO: redo in future
|
||||||
|
#CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core" / "convert"
|
||||||
|
CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().root_path / "models" / "core" / "convert"
|
||||||
|
|
||||||
def shave_segments(path, n_shave_prefix_segments=1):
|
def shave_segments(path, n_shave_prefix_segments=1):
|
||||||
"""
|
"""
|
||||||
@ -605,7 +607,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
|||||||
else:
|
else:
|
||||||
vae_state_dict = checkpoint
|
vae_state_dict = checkpoint
|
||||||
|
|
||||||
new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict,config)
|
new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict, config)
|
||||||
return new_checkpoint
|
return new_checkpoint
|
||||||
|
|
||||||
def convert_ldm_vae_state_dict(vae_state_dict, config):
|
def convert_ldm_vae_state_dict(vae_state_dict, config):
|
||||||
@ -828,7 +830,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
|
|||||||
|
|
||||||
|
|
||||||
def convert_ldm_clip_checkpoint(checkpoint):
|
def convert_ldm_clip_checkpoint(checkpoint):
|
||||||
text_model = CLIPTextModel.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14')
|
text_model = CLIPTextModel.from_pretrained(CONVERT_MODEL_ROOT / 'clip-vit-large-patch14')
|
||||||
keys = list(checkpoint.keys())
|
keys = list(checkpoint.keys())
|
||||||
|
|
||||||
text_model_dict = {}
|
text_model_dict = {}
|
||||||
@ -882,7 +884,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
|||||||
|
|
||||||
def convert_open_clip_checkpoint(checkpoint):
|
def convert_open_clip_checkpoint(checkpoint):
|
||||||
text_model = CLIPTextModel.from_pretrained(
|
text_model = CLIPTextModel.from_pretrained(
|
||||||
MODEL_ROOT / 'stable-diffusion-2-clip',
|
CONVERT_MODEL_ROOT / 'stable-diffusion-2-clip',
|
||||||
subfolder='text_encoder',
|
subfolder='text_encoder',
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -949,7 +951,7 @@ def convert_open_clip_checkpoint(checkpoint):
|
|||||||
|
|
||||||
return text_model
|
return text_model
|
||||||
|
|
||||||
def replace_checkpoint_vae(checkpoint, vae_path:str):
|
def replace_checkpoint_vae(checkpoint, vae_path: str):
|
||||||
if vae_path.endswith(".safetensors"):
|
if vae_path.endswith(".safetensors"):
|
||||||
vae_ckpt = load_file(vae_path)
|
vae_ckpt = load_file(vae_path)
|
||||||
else:
|
else:
|
||||||
@ -959,7 +961,7 @@ def replace_checkpoint_vae(checkpoint, vae_path:str):
|
|||||||
new_key = f'first_stage_model.{vae_key}'
|
new_key = f'first_stage_model.{vae_key}'
|
||||||
checkpoint[new_key] = state_dict[vae_key]
|
checkpoint[new_key] = state_dict[vae_key]
|
||||||
|
|
||||||
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int)->AutoencoderKL:
|
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:
|
||||||
vae_config = create_vae_diffusers_config(
|
vae_config = create_vae_diffusers_config(
|
||||||
vae_config, image_size=image_size
|
vae_config, image_size=image_size
|
||||||
)
|
)
|
||||||
@ -979,8 +981,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
original_config_file: str,
|
original_config_file: str,
|
||||||
extract_ema: bool = True,
|
extract_ema: bool = True,
|
||||||
precision: torch.dtype = torch.float32,
|
precision: torch.dtype = torch.float32,
|
||||||
upcast_attention: bool = False,
|
|
||||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon,
|
|
||||||
scan_needed: bool = True,
|
scan_needed: bool = True,
|
||||||
) -> StableDiffusionPipeline:
|
) -> StableDiffusionPipeline:
|
||||||
"""
|
"""
|
||||||
@ -994,8 +994,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
:param checkpoint_path: Path to `.ckpt` file.
|
:param checkpoint_path: Path to `.ckpt` file.
|
||||||
:param original_config_file: Path to `.yaml` config file corresponding to the original architecture.
|
:param original_config_file: Path to `.yaml` config file corresponding to the original architecture.
|
||||||
If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models.
|
If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models.
|
||||||
:param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
|
|
||||||
v1.X and Stable Diffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2.
|
|
||||||
:param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
|
:param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
|
||||||
"euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
|
"euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
|
||||||
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder"]`. :param extract_ema: Only relevant for
|
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder"]`. :param extract_ema: Only relevant for
|
||||||
@ -1003,17 +1001,16 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
|
or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
|
||||||
quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
|
quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
|
||||||
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast
|
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast
|
||||||
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
|
|
||||||
running stable diffusion 2.1.
|
|
||||||
"""
|
"""
|
||||||
config = InvokeAIAppConfig.get_config()
|
if not isinstance(checkpoint_path, Path):
|
||||||
|
checkpoint_path = Path(checkpoint_path)
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
verbosity = dlogging.get_verbosity()
|
verbosity = dlogging.get_verbosity()
|
||||||
dlogging.set_verbosity_error()
|
dlogging.set_verbosity_error()
|
||||||
|
|
||||||
if str(checkpoint_path).endswith(".safetensors"):
|
if checkpoint_path.suffix == ".safetensors":
|
||||||
checkpoint = load_file(checkpoint_path)
|
checkpoint = load_file(checkpoint_path)
|
||||||
else:
|
else:
|
||||||
if scan_needed:
|
if scan_needed:
|
||||||
@ -1026,9 +1023,13 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
|
|
||||||
original_config = OmegaConf.load(original_config_file)
|
original_config = OmegaConf.load(original_config_file)
|
||||||
|
|
||||||
if model_version == BaseModelType.StableDiffusion2 and prediction_type == SchedulerPredictionType.VPrediction:
|
if model_version == BaseModelType.StableDiffusion2 and original_config["model"]["params"]["parameterization"] == "v":
|
||||||
|
prediction_type = "v_prediction"
|
||||||
|
upcast_attention = True
|
||||||
image_size = 768
|
image_size = 768
|
||||||
else:
|
else:
|
||||||
|
prediction_type = "epsilon"
|
||||||
|
upcast_attention = False
|
||||||
image_size = 512
|
image_size = 512
|
||||||
|
|
||||||
#
|
#
|
||||||
@ -1083,7 +1084,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
if model_type == "FrozenOpenCLIPEmbedder":
|
if model_type == "FrozenOpenCLIPEmbedder":
|
||||||
text_model = convert_open_clip_checkpoint(checkpoint)
|
text_model = convert_open_clip_checkpoint(checkpoint)
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
MODEL_ROOT / 'stable-diffusion-2-clip',
|
CONVERT_MODEL_ROOT / 'stable-diffusion-2-clip',
|
||||||
subfolder='tokenizer',
|
subfolder='tokenizer',
|
||||||
)
|
)
|
||||||
pipe = StableDiffusionPipeline(
|
pipe = StableDiffusionPipeline(
|
||||||
@ -1099,9 +1100,9 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
|
|
||||||
elif model_type in ["FrozenCLIPEmbedder", "WeightedFrozenCLIPEmbedder"]:
|
elif model_type in ["FrozenCLIPEmbedder", "WeightedFrozenCLIPEmbedder"]:
|
||||||
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14')
|
tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / 'clip-vit-large-patch14')
|
||||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
|
safety_checker = StableDiffusionSafetyChecker.from_pretrained(CONVERT_MODEL_ROOT / 'stable-diffusion-safety-checker')
|
||||||
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
|
feature_extractor = AutoFeatureExtractor.from_pretrained(CONVERT_MODEL_ROOT / 'stable-diffusion-safety-checker')
|
||||||
pipe = StableDiffusionPipeline(
|
pipe = StableDiffusionPipeline(
|
||||||
vae=vae.to(precision),
|
vae=vae.to(precision),
|
||||||
text_encoder=text_model.to(precision),
|
text_encoder=text_model.to(precision),
|
||||||
@ -1115,7 +1116,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
else:
|
else:
|
||||||
text_config = create_ldm_bert_config(original_config)
|
text_config = create_ldm_bert_config(original_config)
|
||||||
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||||
tokenizer = BertTokenizerFast.from_pretrained(MODEL_ROOT / "bert-base-uncased")
|
tokenizer = BertTokenizerFast.from_pretrained(CONVERT_MODEL_ROOT / "bert-base-uncased")
|
||||||
pipe = LDMTextToImagePipeline(
|
pipe = LDMTextToImagePipeline(
|
||||||
vqvae=vae,
|
vqvae=vae,
|
||||||
bert=text_model,
|
bert=text_model,
|
||||||
@ -1131,7 +1132,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
def convert_ckpt_to_diffusers(
|
def convert_ckpt_to_diffusers(
|
||||||
checkpoint_path: Union[str, Path],
|
checkpoint_path: Union[str, Path],
|
||||||
dump_path: Union[str, Path],
|
dump_path: Union[str, Path],
|
||||||
model_root: Union[str, Path],
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -1139,9 +1139,6 @@ def convert_ckpt_to_diffusers(
|
|||||||
and in addition a path-like object indicating the location of the desired diffusers
|
and in addition a path-like object indicating the location of the desired diffusers
|
||||||
model to be written.
|
model to be written.
|
||||||
"""
|
"""
|
||||||
# setting global here to avoid massive changes late at night
|
|
||||||
global MODEL_ROOT
|
|
||||||
MODEL_ROOT = Path(model_root) / 'core/convert'
|
|
||||||
pipe = load_pipeline_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs)
|
pipe = load_pipeline_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs)
|
||||||
|
|
||||||
pipe.save_pretrained(
|
pipe.save_pretrained(
|
||||||
|
@ -1,118 +0,0 @@
|
|||||||
"""
|
|
||||||
Routines for downloading and installing models.
|
|
||||||
"""
|
|
||||||
import json
|
|
||||||
import safetensors
|
|
||||||
import safetensors.torch
|
|
||||||
import shutil
|
|
||||||
import tempfile
|
|
||||||
import torch
|
|
||||||
import traceback
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from diffusers import ModelMixin
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Callable
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from . import ModelManager
|
|
||||||
from .models import BaseModelType, ModelType, VariantType
|
|
||||||
from .model_probe import ModelProbe, ModelVariantInfo
|
|
||||||
from .model_cache import SilenceWarnings
|
|
||||||
|
|
||||||
class ModelInstall(object):
|
|
||||||
'''
|
|
||||||
This class is able to download and install several different kinds of
|
|
||||||
InvokeAI models. The helper function, if provided, is called on to distinguish
|
|
||||||
between v2-base and v2-768 stable diffusion pipelines. This usually involves
|
|
||||||
asking the user to select the proper type, as there is no way of distinguishing
|
|
||||||
the two type of v2 file programmatically (as far as I know).
|
|
||||||
'''
|
|
||||||
def __init__(self,
|
|
||||||
config: InvokeAIAppConfig,
|
|
||||||
model_base_helper: Callable[[Path],BaseModelType]=None,
|
|
||||||
clobber:bool = False
|
|
||||||
):
|
|
||||||
'''
|
|
||||||
:param config: InvokeAI configuration object
|
|
||||||
:param model_base_helper: A function call that accepts the Path to a checkpoint model and returns a ModelType enum
|
|
||||||
:param clobber: If true, models with colliding names will be overwritten
|
|
||||||
'''
|
|
||||||
self.config = config
|
|
||||||
self.clogger = clobber
|
|
||||||
self.helper = model_base_helper
|
|
||||||
self.prober = ModelProbe()
|
|
||||||
|
|
||||||
def install_checkpoint_file(self, checkpoint: Path)->dict:
|
|
||||||
'''
|
|
||||||
Install the checkpoint file at path and return a
|
|
||||||
configuration entry that can be added to `models.yaml`.
|
|
||||||
Model checkpoints and VAEs will be converted into
|
|
||||||
diffusers before installation. Note that the model manager
|
|
||||||
does not hold entries for anything but diffusers pipelines,
|
|
||||||
and the configuration file stanzas returned from such models
|
|
||||||
can be safely ignored.
|
|
||||||
'''
|
|
||||||
model_info = self.prober.probe(checkpoint, self.helper)
|
|
||||||
if not model_info:
|
|
||||||
raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")
|
|
||||||
|
|
||||||
key = ModelManager.create_key(
|
|
||||||
model_name = checkpoint.stem,
|
|
||||||
base_model = model_info.base_type,
|
|
||||||
model_type = model_info.model_type,
|
|
||||||
)
|
|
||||||
destination_path = self._dest_path(model_info) / checkpoint
|
|
||||||
destination_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
self._check_for_collision(destination_path)
|
|
||||||
stanza = {
|
|
||||||
key: dict(
|
|
||||||
name = checkpoint.stem,
|
|
||||||
description = f'{model_info.model_type} model {checkpoint.stem}',
|
|
||||||
base = model_info.base_model.value,
|
|
||||||
type = model_info.model_type.value,
|
|
||||||
variant = model_info.variant_type.value,
|
|
||||||
path = str(destination_path),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
# non-pipeline; no conversion needed, just copy into right place
|
|
||||||
if model_info.model_type != ModelType.Pipeline:
|
|
||||||
shutil.copyfile(checkpoint, destination_path)
|
|
||||||
stanza[key].update({'format': 'checkpoint'})
|
|
||||||
|
|
||||||
# pipeline - conversion needed here
|
|
||||||
else:
|
|
||||||
destination_path = self._dest_path(model_info) / checkpoint.stem
|
|
||||||
config_file = self._pipeline_type_to_config_file(model_info.model_type)
|
|
||||||
|
|
||||||
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
|
||||||
with SilenceWarnings:
|
|
||||||
convert_ckpt_to_diffusers(
|
|
||||||
checkpoint,
|
|
||||||
destination_path,
|
|
||||||
extract_ema=True,
|
|
||||||
original_config_file=config_file,
|
|
||||||
scan_needed=False,
|
|
||||||
)
|
|
||||||
stanza[key].update({'format': 'folder',
|
|
||||||
'path': destination_path, # no suffix on this
|
|
||||||
})
|
|
||||||
|
|
||||||
return stanza
|
|
||||||
|
|
||||||
|
|
||||||
def _check_for_collision(self, path: Path):
|
|
||||||
if not path.exists():
|
|
||||||
return
|
|
||||||
if self.clobber:
|
|
||||||
shutil.rmtree(path)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Destination {path} already exists. Won't overwrite unless clobber=True.")
|
|
||||||
|
|
||||||
def _staging_directory(self)->tempfile.TemporaryDirectory:
|
|
||||||
return tempfile.TemporaryDirectory(dir=self.config.root_path)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,53 +1,209 @@
|
|||||||
"""This module manages the InvokeAI `models.yaml` file, mapping
|
"""This module manages the InvokeAI `models.yaml` file, mapping
|
||||||
symbolic diffusers model names to the paths and repo_ids used
|
symbolic diffusers model names to the paths and repo_ids used by the
|
||||||
by the underlying `from_pretrained()` call.
|
underlying `from_pretrained()` call.
|
||||||
|
|
||||||
For fetching models, use manager.get_model('symbolic name'). This will
|
SYNOPSIS:
|
||||||
return a ModelInfo object that contains the following attributes:
|
|
||||||
|
|
||||||
* context -- a context manager Generator that loads and locks the
|
|
||||||
model into GPU VRAM and returns the model for use.
|
|
||||||
See below for usage.
|
|
||||||
* name -- symbolic name of the model
|
|
||||||
* type -- SubModelType of the model
|
|
||||||
* hash -- unique hash for the model
|
|
||||||
* location -- path or repo_id of the model
|
|
||||||
* revision -- revision of the model if coming from a repo id,
|
|
||||||
e.g. 'fp16'
|
|
||||||
* precision -- torch precision of the model
|
|
||||||
|
|
||||||
Typical usage:
|
mgr = ModelManager('/home/phi/invokeai/configs/models.yaml')
|
||||||
|
sd1_5 = mgr.get_model('stable-diffusion-v1-5',
|
||||||
|
model_type=ModelType.Main,
|
||||||
|
base_model=BaseModelType.StableDiffusion1,
|
||||||
|
submodel_type=SubModelType.Unet)
|
||||||
|
with sd1_5 as unet:
|
||||||
|
run_some_inference(unet)
|
||||||
|
|
||||||
from invokeai.backend import ModelManager
|
FETCHING MODELS:
|
||||||
|
|
||||||
manager = ModelManager(
|
Models are described using four attributes:
|
||||||
config='./configs/models.yaml',
|
|
||||||
max_cache_size=8
|
|
||||||
) # gigabytes
|
|
||||||
|
|
||||||
model_info = manager.get_model('stable-diffusion-1.5', SubModelType.Diffusers)
|
1) model_name -- the symbolic name for the model
|
||||||
with model_info.context as my_model:
|
|
||||||
my_model.latents_from_embeddings(...)
|
|
||||||
|
|
||||||
The manager uses the underlying ModelCache class to keep
|
2) ModelType -- an enum describing the type of the model. Currently
|
||||||
frequently-used models in RAM and move them into GPU as needed for
|
defined types are:
|
||||||
generation operations. The optional `max_cache_size` argument
|
ModelType.Main -- a full model capable of generating images
|
||||||
indicates the maximum size the cache can grow to, in gigabytes. The
|
ModelType.Vae -- a VAE model
|
||||||
underlying ModelCache object can be accessed using the manager's "cache"
|
ModelType.Lora -- a LoRA or LyCORIS fine-tune
|
||||||
attribute.
|
ModelType.TextualInversion -- a textual inversion embedding
|
||||||
|
ModelType.ControlNet -- a ControlNet model
|
||||||
|
|
||||||
Because the model manager can return multiple different types of
|
3) BaseModelType -- an enum indicating the stable diffusion base model, one of:
|
||||||
models, you may wish to add additional type checking on the class
|
BaseModelType.StableDiffusion1
|
||||||
of model returned. To do this, provide the option `model_type`
|
BaseModelType.StableDiffusion2
|
||||||
parameter:
|
|
||||||
|
|
||||||
model_info = manager.get_model(
|
4) SubModelType (optional) -- an enum that refers to one of the submodels contained
|
||||||
'clip-tokenizer',
|
within the main model. Values are:
|
||||||
model_type=SubModelType.Tokenizer
|
|
||||||
)
|
|
||||||
|
|
||||||
This will raise an InvalidModelError if the format defined in the
|
SubModelType.UNet
|
||||||
config file doesn't match the requested model type.
|
SubModelType.TextEncoder
|
||||||
|
SubModelType.Tokenizer
|
||||||
|
SubModelType.Scheduler
|
||||||
|
SubModelType.SafetyChecker
|
||||||
|
|
||||||
|
To fetch a model, use `manager.get_model()`. This takes the symbolic
|
||||||
|
name of the model, the ModelType, the BaseModelType and the
|
||||||
|
SubModelType. The latter is required for ModelType.Main.
|
||||||
|
|
||||||
|
get_model() will return a ModelInfo object that can then be used in
|
||||||
|
context to retrieve the model and move it into GPU VRAM (on GPU
|
||||||
|
systems).
|
||||||
|
|
||||||
|
A typical example is:
|
||||||
|
|
||||||
|
sd1_5 = mgr.get_model('stable-diffusion-v1-5',
|
||||||
|
model_type=ModelType.Main,
|
||||||
|
base_model=BaseModelType.StableDiffusion1,
|
||||||
|
submodel_type=SubModelType.Unet)
|
||||||
|
with sd1_5 as unet:
|
||||||
|
run_some_inference(unet)
|
||||||
|
|
||||||
|
The ModelInfo object provides a number of useful fields describing the
|
||||||
|
model, including:
|
||||||
|
|
||||||
|
name -- symbolic name of the model
|
||||||
|
base_model -- base model (BaseModelType)
|
||||||
|
type -- model type (ModelType)
|
||||||
|
location -- path to the model file
|
||||||
|
precision -- torch precision of the model
|
||||||
|
hash -- unique sha256 checksum for this model
|
||||||
|
|
||||||
|
SUBMODELS:
|
||||||
|
|
||||||
|
When fetching a main model, you must specify the submodel. Retrieval
|
||||||
|
of full pipelines is not supported.
|
||||||
|
|
||||||
|
vae_info = mgr.get_model('stable-diffusion-1.5',
|
||||||
|
model_type = ModelType.Main,
|
||||||
|
base_model = BaseModelType.StableDiffusion1,
|
||||||
|
submodel_type = SubModelType.Vae
|
||||||
|
)
|
||||||
|
with vae_info as vae:
|
||||||
|
do_something(vae)
|
||||||
|
|
||||||
|
This rule does not apply to controlnets, embeddings, loras and standalone
|
||||||
|
VAEs, which do not have submodels.
|
||||||
|
|
||||||
|
LISTING MODELS
|
||||||
|
|
||||||
|
The model_names() method will return a list of Tuples describing each
|
||||||
|
model it knows about:
|
||||||
|
|
||||||
|
>> mgr.model_names()
|
||||||
|
[
|
||||||
|
('stable-diffusion-1.5', <BaseModelType.StableDiffusion1: 'sd-1'>, <ModelType.Main: 'main'>),
|
||||||
|
('stable-diffusion-2.1', <BaseModelType.StableDiffusion2: 'sd-2'>, <ModelType.Main: 'main'>),
|
||||||
|
('inpaint', <BaseModelType.StableDiffusion1: 'sd-1'>, <ModelType.ControlNet: 'controlnet'>)
|
||||||
|
('Ink scenery', <BaseModelType.StableDiffusion1: 'sd-1'>, <ModelType.Lora: 'lora'>)
|
||||||
|
...
|
||||||
|
]
|
||||||
|
|
||||||
|
The tuple is in the correct order to pass to get_model():
|
||||||
|
|
||||||
|
for m in mgr.model_names():
|
||||||
|
info = get_model(*m)
|
||||||
|
|
||||||
|
In contrast, the list_models() method returns a list of dicts, each
|
||||||
|
providing information about a model defined in models.yaml. For example:
|
||||||
|
|
||||||
|
>>> models = mgr.list_models()
|
||||||
|
>>> json.dumps(models[0])
|
||||||
|
{"path": "/home/lstein/invokeai-main/models/sd-1/controlnet/canny",
|
||||||
|
"model_format": "diffusers",
|
||||||
|
"name": "canny",
|
||||||
|
"base_model": "sd-1",
|
||||||
|
"type": "controlnet"
|
||||||
|
}
|
||||||
|
|
||||||
|
You can filter by model type and base model as shown here:
|
||||||
|
|
||||||
|
|
||||||
|
controlnets = mgr.list_models(model_type=ModelType.ControlNet,
|
||||||
|
base_model=BaseModelType.StableDiffusion1)
|
||||||
|
for c in controlnets:
|
||||||
|
name = c['name']
|
||||||
|
format = c['model_format']
|
||||||
|
path = c['path']
|
||||||
|
type = c['type']
|
||||||
|
# etc
|
||||||
|
|
||||||
|
ADDING AND REMOVING MODELS
|
||||||
|
|
||||||
|
At startup time, the `models` directory will be scanned for
|
||||||
|
checkpoints, diffusers pipelines, controlnets, LoRAs and TI
|
||||||
|
embeddings. New entries will be added to the model manager and defunct
|
||||||
|
ones removed. Anything that is a main model (ModelType.Main) will be
|
||||||
|
added to models.yaml. For scanning to succeed, files need to be in
|
||||||
|
their proper places. For example, a controlnet folder built on the
|
||||||
|
stable diffusion 2 base, will need to be placed in
|
||||||
|
`models/sd-2/controlnet`.
|
||||||
|
|
||||||
|
Layout of the `models` directory:
|
||||||
|
|
||||||
|
models
|
||||||
|
├── sd-1
|
||||||
|
│ ├── controlnet
|
||||||
|
│ ├── lora
|
||||||
|
│ ├── main
|
||||||
|
│ └── embedding
|
||||||
|
├── sd-2
|
||||||
|
│ ├── controlnet
|
||||||
|
│ ├── lora
|
||||||
|
│ ├── main
|
||||||
|
│ └── embedding
|
||||||
|
└── core
|
||||||
|
├── face_reconstruction
|
||||||
|
│ ├── codeformer
|
||||||
|
│ └── gfpgan
|
||||||
|
├── sd-conversion
|
||||||
|
│ ├── clip-vit-large-patch14 - tokenizer, text_encoder subdirs
|
||||||
|
│ ├── stable-diffusion-2 - tokenizer, text_encoder subdirs
|
||||||
|
│ └── stable-diffusion-safety-checker
|
||||||
|
└── upscaling
|
||||||
|
└─── esrgan
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigMeta(BaseModel):Loras, textual_inversion and controlnet models are not listed
|
||||||
|
explicitly in models.yaml, but are added to the in-memory data
|
||||||
|
structure at initialization time by scanning the models directory. The
|
||||||
|
in-memory data structure can be resynchronized by calling
|
||||||
|
`manager.scan_models_directory()`.
|
||||||
|
|
||||||
|
Files and folders placed inside the `autoimport` paths (paths
|
||||||
|
defined in `invokeai.yaml`) will also be scanned for new models at
|
||||||
|
initialization time and added to `models.yaml`. Files will not be
|
||||||
|
moved from this location but preserved in-place. These directories
|
||||||
|
are:
|
||||||
|
|
||||||
|
configuration default description
|
||||||
|
------------- ------- -----------
|
||||||
|
autoimport_dir autoimport/main main models
|
||||||
|
lora_dir autoimport/lora LoRA/LyCORIS models
|
||||||
|
embedding_dir autoimport/embedding TI embeddings
|
||||||
|
controlnet_dir autoimport/controlnet ControlNet models
|
||||||
|
|
||||||
|
In actuality, models located in any of these directories are scanned
|
||||||
|
to determine their type, so it isn't strictly necessary to organize
|
||||||
|
the different types in this way. This entry in `invokeai.yaml` will
|
||||||
|
recursively scan all subdirectories within `autoimport`, scan models
|
||||||
|
files it finds, and import them if recognized.
|
||||||
|
|
||||||
|
Paths:
|
||||||
|
autoimport_dir: autoimport
|
||||||
|
|
||||||
|
A model can be manually added using `add_model()` using the model's
|
||||||
|
name, base model, type and a dict of model attributes. See
|
||||||
|
`invokeai/backend/model_management/models` for the attributes required
|
||||||
|
by each model type.
|
||||||
|
|
||||||
|
A model can be deleted using `del_model()`, providing the same
|
||||||
|
identifying information as `get_model()`
|
||||||
|
|
||||||
|
The `heuristic_import()` method will take a set of strings
|
||||||
|
corresponding to local paths, remote URLs, and repo_ids, probe the
|
||||||
|
object to determine what type of model it is (if any), and import new
|
||||||
|
models into the manager. If passed a directory, it will recursively
|
||||||
|
scan it for models to import. The return value is a set of the models
|
||||||
|
successfully added.
|
||||||
|
|
||||||
MODELS.YAML
|
MODELS.YAML
|
||||||
|
|
||||||
@ -56,93 +212,18 @@ The general format of a models.yaml section is:
|
|||||||
type-of-model/name-of-model:
|
type-of-model/name-of-model:
|
||||||
path: /path/to/local/file/or/directory
|
path: /path/to/local/file/or/directory
|
||||||
description: a description
|
description: a description
|
||||||
format: folder|ckpt|safetensors|pt
|
format: diffusers|checkpoint
|
||||||
base: SD-1|SD-2
|
variant: normal|inpaint|depth
|
||||||
subfolder: subfolder-name
|
|
||||||
|
|
||||||
The type of model is given in the stanza key, and is one of
|
The type of model is given in the stanza key, and is one of
|
||||||
{diffusers, ckpt, vae, text_encoder, tokenizer, unet, scheduler,
|
{main, vae, lora, controlnet, textual}
|
||||||
safety_checker, feature_extractor, lora, textual_inversion,
|
|
||||||
controlnet}, and correspond to items in the SubModelType enum defined
|
|
||||||
in model_cache.py
|
|
||||||
|
|
||||||
The format indicates whether the model is organized as a folder with
|
The format indicates whether the model is organized as a diffusers
|
||||||
model subdirectories, or is contained in a single checkpoint or
|
folder with model subdirectories, or is contained in a single
|
||||||
safetensors file.
|
checkpoint or safetensors file.
|
||||||
|
|
||||||
One, but not both, of repo_id and path are provided. repo_id is the
|
The path points to a file or directory on disk. If a relative path,
|
||||||
HuggingFace repository ID of the model, and path points to the file or
|
the root is the InvokeAI ROOTDIR.
|
||||||
directory on disk.
|
|
||||||
|
|
||||||
If subfolder is provided, then the model exists in a subdirectory of
|
|
||||||
the main model. These are usually named after the model type, such as
|
|
||||||
"unet".
|
|
||||||
|
|
||||||
This example summarizes the two ways of getting a non-diffuser model:
|
|
||||||
|
|
||||||
text_encoder/clip-test-1:
|
|
||||||
format: folder
|
|
||||||
path: /path/to/folder
|
|
||||||
description: Returns standalone CLIPTextModel
|
|
||||||
|
|
||||||
text_encoder/clip-test-2:
|
|
||||||
format: folder
|
|
||||||
repo_id: /path/to/folder
|
|
||||||
subfolder: text_encoder
|
|
||||||
description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
|
|
||||||
|
|
||||||
SUBMODELS:
|
|
||||||
|
|
||||||
It is also possible to fetch an isolated submodel from a diffusers
|
|
||||||
model. Use the `submodel` parameter to select which part:
|
|
||||||
|
|
||||||
vae = manager.get_model('stable-diffusion-1.5',submodel=SubModelType.Vae)
|
|
||||||
with vae.context as my_vae:
|
|
||||||
print(type(my_vae))
|
|
||||||
# "AutoencoderKL"
|
|
||||||
|
|
||||||
DIRECTORY_SCANNING:
|
|
||||||
|
|
||||||
Loras, textual_inversion and controlnet models are usually not listed
|
|
||||||
explicitly in models.yaml, but are added to the in-memory data
|
|
||||||
structure at initialization time by scanning the models directory. The
|
|
||||||
in-memory data structure can be resynchronized by calling
|
|
||||||
`manager.scan_models_directory`.
|
|
||||||
|
|
||||||
DISAMBIGUATION:
|
|
||||||
|
|
||||||
You may wish to use the same name for a related family of models. To
|
|
||||||
do this, disambiguate the stanza key with the model and and format
|
|
||||||
separated by "/". Example:
|
|
||||||
|
|
||||||
tokenizer/clip-large:
|
|
||||||
format: tokenizer
|
|
||||||
path: /path/to/folder
|
|
||||||
description: Returns standalone tokenizer
|
|
||||||
|
|
||||||
text_encoder/clip-large:
|
|
||||||
format: text_encoder
|
|
||||||
path: /path/to/folder
|
|
||||||
description: Returns standalone text encoder
|
|
||||||
|
|
||||||
You can now use the `model_type` argument to indicate which model you
|
|
||||||
want:
|
|
||||||
|
|
||||||
tokenizer = mgr.get('clip-large',model_type=SubModelType.Tokenizer)
|
|
||||||
encoder = mgr.get('clip-large',model_type=SubModelType.TextEncoder)
|
|
||||||
|
|
||||||
OTHER FUNCTIONS:
|
|
||||||
|
|
||||||
Other methods provided by ModelManager support importing, editing,
|
|
||||||
converting and deleting models.
|
|
||||||
|
|
||||||
IMPORTANT CHANGES AND LIMITATIONS SINCE 2.3:
|
|
||||||
|
|
||||||
1. Only local paths are supported. Repo_ids are no longer accepted. This
|
|
||||||
simplifies the logic.
|
|
||||||
|
|
||||||
2. VAEs can't be swapped in and out at load time. They must be baked
|
|
||||||
into the model when downloaded or converted.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@ -151,13 +232,11 @@ import os
|
|||||||
import hashlib
|
import hashlib
|
||||||
import textwrap
|
import textwrap
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from packaging import version
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, List, Tuple, Union, types
|
from typing import Optional, List, Tuple, Union, Set, Callable, types
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import scan_cache_dir
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from omegaconf.dictconfig import DictConfig
|
from omegaconf.dictconfig import DictConfig
|
||||||
|
|
||||||
@ -165,9 +244,13 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
|
from invokeai.backend.util import CUDA_DEVICE, Chdir
|
||||||
from .model_cache import ModelCache, ModelLocker
|
from .model_cache import ModelCache, ModelLocker
|
||||||
from .models import BaseModelType, ModelType, SubModelType, ModelError, MODEL_CLASSES
|
from .models import (
|
||||||
|
BaseModelType, ModelType, SubModelType,
|
||||||
|
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
||||||
|
ModelConfigBase,
|
||||||
|
)
|
||||||
|
|
||||||
# We are only starting to number the config file with release 3.
|
# We are only starting to number the config file with release 3.
|
||||||
# The config file version doesn't have to start at release version, but it will help
|
# The config file version doesn't have to start at release version, but it will help
|
||||||
@ -183,7 +266,6 @@ class ModelInfo():
|
|||||||
hash: str
|
hash: str
|
||||||
location: Union[Path, str]
|
location: Union[Path, str]
|
||||||
precision: torch.dtype
|
precision: torch.dtype
|
||||||
revision: str = None
|
|
||||||
_cache: ModelCache = None
|
_cache: ModelCache = None
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
@ -199,31 +281,6 @@ class InvalidModelError(Exception):
|
|||||||
MAX_CACHE_SIZE = 6.0 # GB
|
MAX_CACHE_SIZE = 6.0 # GB
|
||||||
|
|
||||||
|
|
||||||
# layout of the models directory:
|
|
||||||
# models
|
|
||||||
# ├── sd-1
|
|
||||||
# │ ├── controlnet
|
|
||||||
# │ ├── lora
|
|
||||||
# │ ├── pipeline
|
|
||||||
# │ └── textual_inversion
|
|
||||||
# ├── sd-2
|
|
||||||
# │ ├── controlnet
|
|
||||||
# │ ├── lora
|
|
||||||
# │ ├── pipeline
|
|
||||||
# │ └── textual_inversion
|
|
||||||
# └── core
|
|
||||||
# ├── face_reconstruction
|
|
||||||
# │ ├── codeformer
|
|
||||||
# │ └── gfpgan
|
|
||||||
# ├── sd-conversion
|
|
||||||
# │ ├── clip-vit-large-patch14 - tokenizer, text_encoder subdirs
|
|
||||||
# │ ├── stable-diffusion-2 - tokenizer, text_encoder subdirs
|
|
||||||
# │ └── stable-diffusion-safety-checker
|
|
||||||
# └── upscaling
|
|
||||||
# └─── esrgan
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigMeta(BaseModel):
|
class ConfigMeta(BaseModel):
|
||||||
version: str
|
version: str
|
||||||
|
|
||||||
@ -271,7 +328,7 @@ class ModelManager(object):
|
|||||||
self.models[model_key] = model_class.create_config(**model_config)
|
self.models[model_key] = model_class.create_config(**model_config)
|
||||||
|
|
||||||
# check config version number and update on disk/RAM if necessary
|
# check config version number and update on disk/RAM if necessary
|
||||||
self.globals = InvokeAIAppConfig.get_config()
|
self.app_config = InvokeAIAppConfig.get_config()
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.cache = ModelCache(
|
self.cache = ModelCache(
|
||||||
max_cache_size=max_cache_size,
|
max_cache_size=max_cache_size,
|
||||||
@ -307,7 +364,8 @@ class ModelManager(object):
|
|||||||
) -> str:
|
) -> str:
|
||||||
return f"{base_model}/{model_type}/{model_name}"
|
return f"{base_model}/{model_type}/{model_name}"
|
||||||
|
|
||||||
def parse_key(self, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
|
@classmethod
|
||||||
|
def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
|
||||||
base_model_str, model_type_str, model_name = model_key.split('/', 2)
|
base_model_str, model_type_str, model_name = model_key.split('/', 2)
|
||||||
try:
|
try:
|
||||||
model_type = ModelType(model_type_str)
|
model_type = ModelType(model_type_str)
|
||||||
@ -321,86 +379,44 @@ class ModelManager(object):
|
|||||||
|
|
||||||
return (model_name, base_model, model_type)
|
return (model_name, base_model, model_type)
|
||||||
|
|
||||||
|
def _get_model_cache_path(self, model_path):
|
||||||
|
return self.app_config.models_path / ".cache" / hashlib.md5(str(model_path).encode()).hexdigest()
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel_type: Optional[SubModelType] = None
|
submodel_type: Optional[SubModelType] = None
|
||||||
):
|
)->ModelInfo:
|
||||||
"""Given a model named identified in models.yaml, return
|
"""Given a model named identified in models.yaml, return
|
||||||
an ModelInfo object describing it.
|
an ModelInfo object describing it.
|
||||||
:param model_name: symbolic name of the model in models.yaml
|
:param model_name: symbolic name of the model in models.yaml
|
||||||
:param model_type: ModelType enum indicating the type of model to return
|
:param model_type: ModelType enum indicating the type of model to return
|
||||||
|
:param base_model: BaseModelType enum indicating the base model used by this model
|
||||||
:param submode_typel: an ModelType enum indicating the portion of
|
:param submode_typel: an ModelType enum indicating the portion of
|
||||||
the model to retrieve (e.g. ModelType.Vae)
|
the model to retrieve (e.g. ModelType.Vae)
|
||||||
|
|
||||||
If not provided, the model_type will be read from the `format` field
|
|
||||||
of the corresponding stanza. If provided, the model_type will be used
|
|
||||||
to disambiguate stanzas in the configuration file. The default is to
|
|
||||||
assume a diffusers pipeline. The behavior is illustrated here:
|
|
||||||
|
|
||||||
[models.yaml]
|
|
||||||
diffusers/test1:
|
|
||||||
repo_id: foo/bar
|
|
||||||
description: Typical diffusers pipeline
|
|
||||||
|
|
||||||
lora/test1:
|
|
||||||
repo_id: /tmp/loras/test1.safetensors
|
|
||||||
description: Typical lora file
|
|
||||||
|
|
||||||
test1_pipeline = mgr.get_model('test1')
|
|
||||||
# returns a StableDiffusionGeneratorPipeline
|
|
||||||
|
|
||||||
test1_vae1 = mgr.get_model('test1', submodel=ModelType.Vae)
|
|
||||||
# returns the VAE part of a diffusers model as an AutoencoderKL
|
|
||||||
|
|
||||||
test1_vae2 = mgr.get_model('test1', model_type=ModelType.Diffusers, submodel=ModelType.Vae)
|
|
||||||
# does the same thing as the previous statement. Note that model_type
|
|
||||||
# is for the parent model, and submodel is for the part
|
|
||||||
|
|
||||||
test1_lora = mgr.get_model('test1', model_type=ModelType.Lora)
|
|
||||||
# returns a LoRA embed (as a 'dict' of tensors)
|
|
||||||
|
|
||||||
test1_encoder = mgr.get_modelI('test1', model_type=ModelType.TextEncoder)
|
|
||||||
# raises an InvalidModelError
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
|
|
||||||
# if model not found try to find it (maybe file just pasted)
|
# if model not found try to find it (maybe file just pasted)
|
||||||
if model_key not in self.models:
|
if model_key not in self.models:
|
||||||
# TODO: find by mask or try rescan?
|
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
||||||
path_mask = f"/models/{base_model}/{model_type}/{model_name}*"
|
if model_key not in self.models:
|
||||||
if False: # model_path = next(find_by_mask(path_mask)):
|
|
||||||
model_path = None # TODO:
|
|
||||||
model_config = model_class.probe_config(model_path)
|
|
||||||
self.models[model_key] = model_config
|
|
||||||
else:
|
|
||||||
raise Exception(f"Model not found - {model_key}")
|
raise Exception(f"Model not found - {model_key}")
|
||||||
|
|
||||||
# if it known model check that target path exists (if manualy deleted)
|
|
||||||
else:
|
|
||||||
# logic repeated twice(in rescan too) any way to optimize?
|
|
||||||
if not os.path.exists(self.models[model_key].path):
|
|
||||||
if model_class.save_to_config:
|
|
||||||
self.models[model_key].error = ModelError.NotFound
|
|
||||||
raise Exception(f"Files for model \"{model_key}\" not found")
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.models.pop(model_key, None)
|
|
||||||
raise Exception(f"Model not found - {model_key}")
|
|
||||||
|
|
||||||
# reset model errors?
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
model_config = self.models[model_key]
|
model_config = self.models[model_key]
|
||||||
|
model_path = self.app_config.root_path / model_config.path
|
||||||
|
|
||||||
# /models/{base_model}/{model_type}/{name}.ckpt or .safentesors
|
if not model_path.exists():
|
||||||
# /models/{base_model}/{model_type}/{name}/
|
if model_class.save_to_config:
|
||||||
model_path = model_config.path
|
self.models[model_key].error = ModelError.NotFound
|
||||||
|
raise Exception(f"Files for model \"{model_key}\" not found")
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.models.pop(model_key, None)
|
||||||
|
raise Exception(f"Model not found - {model_key}")
|
||||||
|
|
||||||
# vae/movq override
|
# vae/movq override
|
||||||
# TODO:
|
# TODO:
|
||||||
@ -414,10 +430,10 @@ class ModelManager(object):
|
|||||||
|
|
||||||
# TODO: path
|
# TODO: path
|
||||||
# TODO: is it accurate to use path as id
|
# TODO: is it accurate to use path as id
|
||||||
dst_convert_path = self.globals.models_dir / ".cache" / hashlib.md5(model_path.encode()).hexdigest()
|
dst_convert_path = self._get_model_cache_path(model_path)
|
||||||
model_path = model_class.convert_if_required(
|
model_path = model_class.convert_if_required(
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_path=model_path,
|
model_path=str(model_path), # TODO: refactor str/Path types logic
|
||||||
output_path=dst_convert_path,
|
output_path=dst_convert_path,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
)
|
)
|
||||||
@ -476,11 +492,6 @@ class ModelManager(object):
|
|||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Return a list of models.
|
Return a list of models.
|
||||||
|
|
||||||
Please use model_manager.models() to get all the model names,
|
|
||||||
model_manager.model_info('model-name') to get the stanza for the model
|
|
||||||
named 'model-name', and model_manager.config to get the full OmegaConf
|
|
||||||
object derived from models.yaml
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
models = []
|
models = []
|
||||||
@ -507,7 +518,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
def print_models(self) -> None:
|
def print_models(self) -> None:
|
||||||
"""
|
"""
|
||||||
Print a table of models, their descriptions
|
Print a table of models and their descriptions. This needs to be redone
|
||||||
"""
|
"""
|
||||||
# TODO: redo
|
# TODO: redo
|
||||||
for model_type, model_dict in self.list_models().items():
|
for model_type, model_dict in self.list_models().items():
|
||||||
@ -515,7 +526,7 @@ class ModelManager(object):
|
|||||||
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
|
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
|
||||||
print(line)
|
print(line)
|
||||||
|
|
||||||
# TODO: test when ui implemented
|
# Tested - LS
|
||||||
def del_model(
|
def del_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -525,7 +536,6 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
Delete the named model.
|
Delete the named model.
|
||||||
"""
|
"""
|
||||||
raise Exception("TODO: del_model") # TODO: redo
|
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
model_cfg = self.models.pop(model_key, None)
|
model_cfg = self.models.pop(model_key, None)
|
||||||
|
|
||||||
@ -541,14 +551,18 @@ class ModelManager(object):
|
|||||||
self.cache.uncache_model(cache_id)
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
# if model inside invoke models folder - delete files
|
# if model inside invoke models folder - delete files
|
||||||
if model_cfg.path.startswith("models/") or model_cfg.path.startswith("models\\"):
|
model_path = self.app_config.root_path / model_cfg.path
|
||||||
model_path = self.globals.root_dir / model_cfg.path
|
cache_path = self._get_model_cache_path(model_path)
|
||||||
if model_path.isdir():
|
if cache_path.exists():
|
||||||
shutil.rmtree(str(model_path))
|
rmtree(str(cache_path))
|
||||||
|
|
||||||
|
if model_path.is_relative_to(self.app_config.models_path):
|
||||||
|
if model_path.is_dir():
|
||||||
|
rmtree(str(model_path))
|
||||||
else:
|
else:
|
||||||
model_path.unlink()
|
model_path.unlink()
|
||||||
|
|
||||||
# TODO: test when ui implemented
|
# LS: tested
|
||||||
def add_model(
|
def add_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -569,18 +583,30 @@ class ModelManager(object):
|
|||||||
model_config = model_class.create_config(**model_attributes)
|
model_config = model_class.create_config(**model_attributes)
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
|
|
||||||
assert (
|
if model_key in self.models and not clobber:
|
||||||
clobber or model_key not in self.models
|
raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
|
||||||
), f'attempt to overwrite existing model definition "{model_key}"'
|
|
||||||
|
|
||||||
self.models[model_key] = model_config
|
old_model = self.models.pop(model_key, None)
|
||||||
|
if old_model is not None:
|
||||||
if clobber and model_key in self.cache_keys:
|
# TODO: if path changed and old_model.path inside models folder should we delete this too?
|
||||||
|
|
||||||
|
# remove conversion cache as config changed
|
||||||
|
old_model_path = self.app_config.root_path / old_model.path
|
||||||
|
old_model_cache = self._get_model_cache_path(old_model_path)
|
||||||
|
if old_model_cache.exists():
|
||||||
|
if old_model_cache.is_dir():
|
||||||
|
rmtree(str(old_model_cache))
|
||||||
|
else:
|
||||||
|
old_model_cache.unlink()
|
||||||
|
|
||||||
|
# remove in-memory cache
|
||||||
# note: it not garantie to release memory(model can has other references)
|
# note: it not garantie to release memory(model can has other references)
|
||||||
cache_ids = self.cache_keys.pop(model_key, [])
|
cache_ids = self.cache_keys.pop(model_key, [])
|
||||||
for cache_id in cache_ids:
|
for cache_id in cache_ids:
|
||||||
self.cache.uncache_model(cache_id)
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
|
self.models[model_key] = model_config
|
||||||
|
|
||||||
def search_models(self, search_folder):
|
def search_models(self, search_folder):
|
||||||
self.logger.info(f"Finding Models In: {search_folder}")
|
self.logger.info(f"Finding Models In: {search_folder}")
|
||||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||||
@ -621,7 +647,7 @@ class ModelManager(object):
|
|||||||
yaml_str = OmegaConf.to_yaml(data_to_save)
|
yaml_str = OmegaConf.to_yaml(data_to_save)
|
||||||
config_file_path = conf_file or self.config_path
|
config_file_path = conf_file or self.config_path
|
||||||
assert config_file_path is not None,'no config file path to write to'
|
assert config_file_path is not None,'no config file path to write to'
|
||||||
config_file_path = self.globals.root_dir / config_file_path
|
config_file_path = self.app_config.root_path / config_file_path
|
||||||
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
||||||
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
||||||
outfile.write(self.preamble())
|
outfile.write(self.preamble())
|
||||||
@ -644,42 +670,150 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def scan_models_directory(self):
|
def scan_models_directory(
|
||||||
|
self,
|
||||||
|
base_model: Optional[BaseModelType] = None,
|
||||||
|
model_type: Optional[ModelType] = None,
|
||||||
|
):
|
||||||
loaded_files = set()
|
loaded_files = set()
|
||||||
new_models_found = False
|
new_models_found = False
|
||||||
|
|
||||||
for model_key, model_config in list(self.models.items()):
|
with Chdir(self.app_config.root_path):
|
||||||
model_name, base_model, model_type = self.parse_key(model_key)
|
for model_key, model_config in list(self.models.items()):
|
||||||
model_path = str(self.globals.root / model_config.path)
|
model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||||
if not os.path.exists(model_path):
|
model_path = self.app_config.root_path / model_config.path
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
if not model_path.exists():
|
||||||
if model_class.save_to_config:
|
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||||
model_config.error = ModelError.NotFound
|
if model_class.save_to_config:
|
||||||
|
model_config.error = ModelError.NotFound
|
||||||
|
else:
|
||||||
|
self.models.pop(model_key, None)
|
||||||
else:
|
else:
|
||||||
self.models.pop(model_key, None)
|
loaded_files.add(model_path)
|
||||||
else:
|
|
||||||
loaded_files.add(model_path)
|
|
||||||
|
|
||||||
for base_model in BaseModelType:
|
for cur_base_model in BaseModelType:
|
||||||
for model_type in ModelType:
|
if base_model is not None and cur_base_model != base_model:
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
continue
|
||||||
models_dir = os.path.join(self.globals.models_path, base_model, model_type)
|
|
||||||
|
|
||||||
if not os.path.exists(models_dir):
|
for cur_model_type in ModelType:
|
||||||
continue # TODO: or create all folders?
|
if model_type is not None and cur_model_type != model_type:
|
||||||
|
continue
|
||||||
for entry_name in os.listdir(models_dir):
|
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||||
model_path = os.path.join(models_dir, entry_name)
|
models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value
|
||||||
if model_path not in loaded_files: # TODO: check
|
|
||||||
model_name = Path(model_path).stem
|
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
|
||||||
|
|
||||||
if model_key in self.models:
|
if not models_dir.exists():
|
||||||
raise Exception(f"Model with key {model_key} added twice")
|
continue # TODO: or create all folders?
|
||||||
|
|
||||||
model_config: ModelConfigBase = model_class.probe_config(model_path)
|
for model_path in models_dir.iterdir():
|
||||||
self.models[model_key] = model_config
|
if model_path not in loaded_files: # TODO: check
|
||||||
new_models_found = True
|
model_name = model_path.name if model_path.is_dir() else model_path.stem
|
||||||
|
model_key = self.create_key(model_name, cur_base_model, cur_model_type)
|
||||||
|
|
||||||
if new_models_found:
|
if model_key in self.models:
|
||||||
|
raise Exception(f"Model with key {model_key} added twice")
|
||||||
|
|
||||||
|
if model_path.is_relative_to(self.app_config.root_path):
|
||||||
|
model_path = model_path.relative_to(self.app_config.root_path)
|
||||||
|
try:
|
||||||
|
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
||||||
|
self.models[model_key] = model_config
|
||||||
|
new_models_found = True
|
||||||
|
except NotImplementedError as e:
|
||||||
|
self.logger.warning(e)
|
||||||
|
|
||||||
|
imported_models = self.autoimport()
|
||||||
|
|
||||||
|
if (new_models_found or imported_models) and self.config_path:
|
||||||
self.commit()
|
self.commit()
|
||||||
|
|
||||||
|
def autoimport(self)->set[Path]:
|
||||||
|
'''
|
||||||
|
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
||||||
|
'''
|
||||||
|
# avoid circular import
|
||||||
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
|
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
||||||
|
|
||||||
|
installer = ModelInstall(config = self.app_config,
|
||||||
|
model_manager = self,
|
||||||
|
prediction_type_helper = ask_user_for_prediction_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
installed = set()
|
||||||
|
scanned_dirs = set()
|
||||||
|
|
||||||
|
config = self.app_config
|
||||||
|
known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()}
|
||||||
|
|
||||||
|
for autodir in [config.autoimport_dir,
|
||||||
|
config.lora_dir,
|
||||||
|
config.embedding_dir,
|
||||||
|
config.controlnet_dir]:
|
||||||
|
if autodir is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.logger.info(f'Scanning {autodir} for models to import')
|
||||||
|
|
||||||
|
autodir = self.app_config.root_path / autodir
|
||||||
|
if not autodir.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
items_scanned = 0
|
||||||
|
new_models_found = set()
|
||||||
|
|
||||||
|
for root, dirs, files in os.walk(autodir):
|
||||||
|
items_scanned += len(dirs) + len(files)
|
||||||
|
for d in dirs:
|
||||||
|
path = Path(root) / d
|
||||||
|
if path in known_paths or path.parent in scanned_dirs:
|
||||||
|
scanned_dirs.add(path)
|
||||||
|
continue
|
||||||
|
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
|
||||||
|
new_models_found.update(installer.heuristic_install(path))
|
||||||
|
scanned_dirs.add(path)
|
||||||
|
|
||||||
|
for f in files:
|
||||||
|
path = Path(root) / f
|
||||||
|
if path in known_paths or path.parent in scanned_dirs:
|
||||||
|
continue
|
||||||
|
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||||
|
new_models_found.update(installer.heuristic_install(path))
|
||||||
|
|
||||||
|
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
||||||
|
installed.update(new_models_found)
|
||||||
|
|
||||||
|
return installed
|
||||||
|
|
||||||
|
def heuristic_import(self,
|
||||||
|
items_to_import: Set[str],
|
||||||
|
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||||
|
)->Set[str]:
|
||||||
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
|
successfully imported items.
|
||||||
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
|
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||||
|
|
||||||
|
The prediction type helper is necessary to distinguish between
|
||||||
|
models based on Stable Diffusion 2 Base (requiring
|
||||||
|
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||||
|
(requiring SchedulerPredictionType.VPrediction). It is
|
||||||
|
generally impossible to do this programmatically, so the
|
||||||
|
prediction_type_helper usually asks the user to choose.
|
||||||
|
|
||||||
|
'''
|
||||||
|
# avoid circular import here
|
||||||
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
|
successfully_installed = set()
|
||||||
|
|
||||||
|
installer = ModelInstall(config = self.app_config,
|
||||||
|
prediction_type_helper = prediction_type_helper,
|
||||||
|
model_manager = self)
|
||||||
|
for thing in items_to_import:
|
||||||
|
try:
|
||||||
|
installed = installer.heuristic_install(thing)
|
||||||
|
successfully_installed.update(installed)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f'{thing} could not be imported: {str(e)}')
|
||||||
|
|
||||||
|
self.commit()
|
||||||
|
return successfully_installed
|
||||||
|
@ -1,27 +1,28 @@
|
|||||||
import json
|
import json
|
||||||
import traceback
|
|
||||||
import torch
|
import torch
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel
|
from diffusers import ModelMixin, ConfigMixin
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Literal, Union, Dict
|
from typing import Callable, Literal, Union, Dict
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
from .models import (
|
||||||
from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings
|
BaseModelType, ModelType, ModelVariantType,
|
||||||
|
SchedulerPredictionType, SilenceWarnings,
|
||||||
|
)
|
||||||
|
from .models.base import read_checkpoint_meta
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelVariantInfo(object):
|
class ModelProbeInfo(object):
|
||||||
model_type: ModelType
|
model_type: ModelType
|
||||||
base_type: BaseModelType
|
base_type: BaseModelType
|
||||||
variant_type: ModelVariantType
|
variant_type: ModelVariantType
|
||||||
prediction_type: SchedulerPredictionType
|
prediction_type: SchedulerPredictionType
|
||||||
upcast_attention: bool
|
upcast_attention: bool
|
||||||
format: Literal['folder','checkpoint']
|
format: Literal['diffusers','checkpoint', 'lycoris']
|
||||||
image_size: int
|
image_size: int
|
||||||
|
|
||||||
class ProbeBase(object):
|
class ProbeBase(object):
|
||||||
@ -31,19 +32,19 @@ class ProbeBase(object):
|
|||||||
class ModelProbe(object):
|
class ModelProbe(object):
|
||||||
|
|
||||||
PROBES = {
|
PROBES = {
|
||||||
'folder': { },
|
'diffusers': { },
|
||||||
'checkpoint': { },
|
'checkpoint': { },
|
||||||
}
|
}
|
||||||
|
|
||||||
CLASS2TYPE = {
|
CLASS2TYPE = {
|
||||||
'StableDiffusionPipeline' : ModelType.Pipeline,
|
'StableDiffusionPipeline' : ModelType.Main,
|
||||||
'AutoencoderKL' : ModelType.Vae,
|
'AutoencoderKL' : ModelType.Vae,
|
||||||
'ControlNetModel' : ModelType.ControlNet,
|
'ControlNetModel' : ModelType.ControlNet,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_probe(cls,
|
def register_probe(cls,
|
||||||
format: Literal['folder','file'],
|
format: Literal['diffusers','checkpoint'],
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
probe_class: ProbeBase):
|
probe_class: ProbeBase):
|
||||||
cls.PROBES[format][model_type] = probe_class
|
cls.PROBES[format][model_type] = probe_class
|
||||||
@ -51,8 +52,8 @@ class ModelProbe(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def heuristic_probe(cls,
|
def heuristic_probe(cls,
|
||||||
model: Union[Dict, ModelMixin, Path],
|
model: Union[Dict, ModelMixin, Path],
|
||||||
prediction_type_helper: Callable[[Path],BaseModelType]=None,
|
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||||
)->ModelVariantInfo:
|
)->ModelProbeInfo:
|
||||||
if isinstance(model,Path):
|
if isinstance(model,Path):
|
||||||
return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper)
|
return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper)
|
||||||
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
|
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
|
||||||
@ -64,7 +65,7 @@ class ModelProbe(object):
|
|||||||
def probe(cls,
|
def probe(cls,
|
||||||
model_path: Path,
|
model_path: Path,
|
||||||
model: Union[Dict, ModelMixin] = None,
|
model: Union[Dict, ModelMixin] = None,
|
||||||
prediction_type_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo:
|
prediction_type_helper: Callable[[Path],SchedulerPredictionType] = None)->ModelProbeInfo:
|
||||||
'''
|
'''
|
||||||
Probe the model at model_path and return sufficient information about it
|
Probe the model at model_path and return sufficient information about it
|
||||||
to place it somewhere in the models directory hierarchy. If the model is
|
to place it somewhere in the models directory hierarchy. If the model is
|
||||||
@ -74,23 +75,24 @@ class ModelProbe(object):
|
|||||||
between V2-Base and V2-768 SD models.
|
between V2-Base and V2-768 SD models.
|
||||||
'''
|
'''
|
||||||
if model_path:
|
if model_path:
|
||||||
format = 'folder' if model_path.is_dir() else 'checkpoint'
|
format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
|
||||||
else:
|
else:
|
||||||
format = 'folder' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
|
format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
|
||||||
|
|
||||||
model_info = None
|
model_info = None
|
||||||
try:
|
try:
|
||||||
model_type = cls.get_model_type_from_folder(model_path, model) \
|
model_type = cls.get_model_type_from_folder(model_path, model) \
|
||||||
if format == 'folder' \
|
if format_type == 'diffusers' \
|
||||||
else cls.get_model_type_from_checkpoint(model_path, model)
|
else cls.get_model_type_from_checkpoint(model_path, model)
|
||||||
probe_class = cls.PROBES[format].get(model_type)
|
probe_class = cls.PROBES[format_type].get(model_type)
|
||||||
if not probe_class:
|
if not probe_class:
|
||||||
return None
|
return None
|
||||||
probe = probe_class(model_path, model, prediction_type_helper)
|
probe = probe_class(model_path, model, prediction_type_helper)
|
||||||
base_type = probe.get_base_type()
|
base_type = probe.get_base_type()
|
||||||
variant_type = probe.get_variant_type()
|
variant_type = probe.get_variant_type()
|
||||||
prediction_type = probe.get_scheduler_prediction_type()
|
prediction_type = probe.get_scheduler_prediction_type()
|
||||||
model_info = ModelVariantInfo(
|
format = probe.get_format()
|
||||||
|
model_info = ModelProbeInfo(
|
||||||
model_type = model_type,
|
model_type = model_type,
|
||||||
base_type = base_type,
|
base_type = base_type,
|
||||||
variant_type = variant_type,
|
variant_type = variant_type,
|
||||||
@ -102,32 +104,40 @@ class ModelProbe(object):
|
|||||||
and prediction_type==SchedulerPredictionType.VPrediction \
|
and prediction_type==SchedulerPredictionType.VPrediction \
|
||||||
) else 512,
|
) else 512,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return model_info
|
return model_info
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict)->ModelType:
|
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
|
||||||
if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors'):
|
if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors','.pth'):
|
||||||
return None
|
return None
|
||||||
if model_path.name=='learned_embeds.bin':
|
|
||||||
|
if model_path.name == "learned_embeds.bin":
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
checkpoint = checkpoint or cls._scan_and_load_checkpoint(model_path)
|
|
||||||
state_dict = checkpoint.get("state_dict") or checkpoint
|
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
|
||||||
if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
|
ckpt = ckpt.get("state_dict", ckpt)
|
||||||
return ModelType.Pipeline
|
|
||||||
if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
|
for key in ckpt.keys():
|
||||||
return ModelType.Vae
|
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||||
if "string_to_token" in state_dict or "emb_params" in state_dict:
|
return ModelType.Main
|
||||||
return ModelType.TextualInversion
|
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||||
if any([x.startswith("lora") for x in state_dict.keys()]):
|
return ModelType.Vae
|
||||||
return ModelType.Lora
|
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||||
if any([x.startswith("control_model") for x in state_dict.keys()]):
|
return ModelType.Lora
|
||||||
return ModelType.ControlNet
|
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||||
if any([x.startswith("input_blocks") for x in state_dict.keys()]):
|
return ModelType.ControlNet
|
||||||
return ModelType.ControlNet
|
elif key in {"emb_params", "string_to_param"}:
|
||||||
return None # give up
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
|
else:
|
||||||
|
# diffusers-ti
|
||||||
|
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||||
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
|
raise ValueError("Unable to determine model type")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
||||||
@ -192,11 +202,14 @@ class ProbeBase(object):
|
|||||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_format(self)->str:
|
||||||
|
pass
|
||||||
|
|
||||||
class CheckpointProbeBase(ProbeBase):
|
class CheckpointProbeBase(ProbeBase):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
checkpoint_path: Path,
|
checkpoint_path: Path,
|
||||||
checkpoint: dict,
|
checkpoint: dict,
|
||||||
helper: Callable[[Path],BaseModelType] = None
|
helper: Callable[[Path],SchedulerPredictionType] = None
|
||||||
)->BaseModelType:
|
)->BaseModelType:
|
||||||
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
|
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
|
||||||
self.checkpoint_path = checkpoint_path
|
self.checkpoint_path = checkpoint_path
|
||||||
@ -205,9 +218,12 @@ class CheckpointProbeBase(ProbeBase):
|
|||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self)->BaseModelType:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_format(self)->str:
|
||||||
|
return 'checkpoint'
|
||||||
|
|
||||||
def get_variant_type(self)-> ModelVariantType:
|
def get_variant_type(self)-> ModelVariantType:
|
||||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
|
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
|
||||||
if model_type != ModelType.Pipeline:
|
if model_type != ModelType.Main:
|
||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
state_dict = self.checkpoint.get('state_dict') or self.checkpoint
|
state_dict = self.checkpoint.get('state_dict') or self.checkpoint
|
||||||
in_channels = state_dict[
|
in_channels = state_dict[
|
||||||
@ -246,7 +262,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
|||||||
return SchedulerPredictionType.Epsilon
|
return SchedulerPredictionType.Epsilon
|
||||||
elif checkpoint["global_step"] == 110000:
|
elif checkpoint["global_step"] == 110000:
|
||||||
return SchedulerPredictionType.VPrediction
|
return SchedulerPredictionType.VPrediction
|
||||||
if self.checkpoint_path and self.helper:
|
if self.checkpoint_path and self.helper \
|
||||||
|
and not self.checkpoint_path.with_suffix('.yaml').exists(): # if a .yaml config file exists, then this step not needed
|
||||||
return self.helper(self.checkpoint_path)
|
return self.helper(self.checkpoint_path)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
@ -257,6 +274,9 @@ class VaeCheckpointProbe(CheckpointProbeBase):
|
|||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
|
|
||||||
class LoRACheckpointProbe(CheckpointProbeBase):
|
class LoRACheckpointProbe(CheckpointProbeBase):
|
||||||
|
def get_format(self)->str:
|
||||||
|
return 'lycoris'
|
||||||
|
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self)->BaseModelType:
|
||||||
checkpoint = self.checkpoint
|
checkpoint = self.checkpoint
|
||||||
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
||||||
@ -276,6 +296,9 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||||
|
def get_format(self)->str:
|
||||||
|
return None
|
||||||
|
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self)->BaseModelType:
|
||||||
checkpoint = self.checkpoint
|
checkpoint = self.checkpoint
|
||||||
if 'string_to_token' in checkpoint:
|
if 'string_to_token' in checkpoint:
|
||||||
@ -322,17 +345,16 @@ class FolderProbeBase(ProbeBase):
|
|||||||
def get_variant_type(self)->ModelVariantType:
|
def get_variant_type(self)->ModelVariantType:
|
||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
|
def get_format(self)->str:
|
||||||
|
return 'diffusers'
|
||||||
|
|
||||||
class PipelineFolderProbe(FolderProbeBase):
|
class PipelineFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self)->BaseModelType:
|
||||||
if self.model:
|
if self.model:
|
||||||
unet_conf = self.model.unet.config
|
unet_conf = self.model.unet.config
|
||||||
scheduler_conf = self.model.scheduler.config
|
|
||||||
else:
|
else:
|
||||||
with open(self.folder_path / 'unet' / 'config.json','r') as file:
|
with open(self.folder_path / 'unet' / 'config.json','r') as file:
|
||||||
unet_conf = json.load(file)
|
unet_conf = json.load(file)
|
||||||
with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
|
|
||||||
scheduler_conf = json.load(file)
|
|
||||||
|
|
||||||
if unet_conf['cross_attention_dim'] == 768:
|
if unet_conf['cross_attention_dim'] == 768:
|
||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
elif unet_conf['cross_attention_dim'] == 1024:
|
elif unet_conf['cross_attention_dim'] == 1024:
|
||||||
@ -381,6 +403,9 @@ class VaeFolderProbe(FolderProbeBase):
|
|||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
|
|
||||||
class TextualInversionFolderProbe(FolderProbeBase):
|
class TextualInversionFolderProbe(FolderProbeBase):
|
||||||
|
def get_format(self)->str:
|
||||||
|
return None
|
||||||
|
|
||||||
def get_base_type(self)->BaseModelType:
|
def get_base_type(self)->BaseModelType:
|
||||||
path = self.folder_path / 'learned_embeds.bin'
|
path = self.folder_path / 'learned_embeds.bin'
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
@ -401,16 +426,24 @@ class ControlNetFolderProbe(FolderProbeBase):
|
|||||||
else BaseModelType.StableDiffusion2
|
else BaseModelType.StableDiffusion2
|
||||||
|
|
||||||
class LoRAFolderProbe(FolderProbeBase):
|
class LoRAFolderProbe(FolderProbeBase):
|
||||||
# I've never seen one of these in the wild, so this is a noop
|
def get_base_type(self)->BaseModelType:
|
||||||
pass
|
model_file = None
|
||||||
|
for suffix in ['safetensors','bin']:
|
||||||
|
base_file = self.folder_path / f'pytorch_lora_weights.{suffix}'
|
||||||
|
if base_file.exists():
|
||||||
|
model_file = base_file
|
||||||
|
break
|
||||||
|
if not model_file:
|
||||||
|
raise Exception('Unknown LoRA format encountered')
|
||||||
|
return LoRACheckpointProbe(model_file,None).get_base_type()
|
||||||
|
|
||||||
############## register probe classes ######
|
############## register probe classes ######
|
||||||
ModelProbe.register_probe('folder', ModelType.Pipeline, PipelineFolderProbe)
|
ModelProbe.register_probe('diffusers', ModelType.Main, PipelineFolderProbe)
|
||||||
ModelProbe.register_probe('folder', ModelType.Vae, VaeFolderProbe)
|
ModelProbe.register_probe('diffusers', ModelType.Vae, VaeFolderProbe)
|
||||||
ModelProbe.register_probe('folder', ModelType.Lora, LoRAFolderProbe)
|
ModelProbe.register_probe('diffusers', ModelType.Lora, LoRAFolderProbe)
|
||||||
ModelProbe.register_probe('folder', ModelType.TextualInversion, TextualInversionFolderProbe)
|
ModelProbe.register_probe('diffusers', ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||||
ModelProbe.register_probe('folder', ModelType.ControlNet, ControlNetFolderProbe)
|
ModelProbe.register_probe('diffusers', ModelType.ControlNet, ControlNetFolderProbe)
|
||||||
ModelProbe.register_probe('checkpoint', ModelType.Pipeline, PipelineCheckpointProbe)
|
ModelProbe.register_probe('checkpoint', ModelType.Main, PipelineCheckpointProbe)
|
||||||
ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe)
|
ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe)
|
||||||
ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe)
|
ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe)
|
||||||
ModelProbe.register_probe('checkpoint', ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
ModelProbe.register_probe('checkpoint', ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||||
|
@ -11,21 +11,21 @@ from .textual_inversion import TextualInversionModel
|
|||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
BaseModelType.StableDiffusion1: {
|
BaseModelType.StableDiffusion1: {
|
||||||
ModelType.Pipeline: StableDiffusion1Model,
|
ModelType.Main: StableDiffusion1Model,
|
||||||
ModelType.Vae: VaeModel,
|
ModelType.Vae: VaeModel,
|
||||||
ModelType.Lora: LoRAModel,
|
ModelType.Lora: LoRAModel,
|
||||||
ModelType.ControlNet: ControlNetModel,
|
ModelType.ControlNet: ControlNetModel,
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
},
|
},
|
||||||
BaseModelType.StableDiffusion2: {
|
BaseModelType.StableDiffusion2: {
|
||||||
ModelType.Pipeline: StableDiffusion2Model,
|
ModelType.Main: StableDiffusion2Model,
|
||||||
ModelType.Vae: VaeModel,
|
ModelType.Vae: VaeModel,
|
||||||
ModelType.Lora: LoRAModel,
|
ModelType.Lora: LoRAModel,
|
||||||
ModelType.ControlNet: ControlNetModel,
|
ModelType.ControlNet: ControlNetModel,
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
},
|
},
|
||||||
#BaseModelType.Kandinsky2_1: {
|
#BaseModelType.Kandinsky2_1: {
|
||||||
# ModelType.Pipeline: Kandinsky2_1Model,
|
# ModelType.Main: Kandinsky2_1Model,
|
||||||
# ModelType.MoVQ: MoVQModel,
|
# ModelType.MoVQ: MoVQModel,
|
||||||
# ModelType.Lora: LoRAModel,
|
# ModelType.Lora: LoRAModel,
|
||||||
# ModelType.ControlNet: ControlNetModel,
|
# ModelType.ControlNet: ControlNetModel,
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import typing
|
import typing
|
||||||
import inspect
|
import inspect
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from picklescan.scanner import scan_file_path
|
||||||
import torch
|
import torch
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from diffusers import DiffusionPipeline, ConfigMixin
|
from diffusers import DiffusionPipeline, ConfigMixin
|
||||||
@ -18,7 +21,7 @@ class BaseModelType(str, Enum):
|
|||||||
#Kandinsky2_1 = "kandinsky-2.1"
|
#Kandinsky2_1 = "kandinsky-2.1"
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
class ModelType(str, Enum):
|
||||||
Pipeline = "pipeline"
|
Main = "main"
|
||||||
Vae = "vae"
|
Vae = "vae"
|
||||||
Lora = "lora"
|
Lora = "lora"
|
||||||
ControlNet = "controlnet" # used by model_probe
|
ControlNet = "controlnet" # used by model_probe
|
||||||
@ -56,7 +59,6 @@ class ModelConfigBase(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
use_enum_values = True
|
use_enum_values = True
|
||||||
|
|
||||||
|
|
||||||
class EmptyConfigLoader(ConfigMixin):
|
class EmptyConfigLoader(ConfigMixin):
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_config(cls, *args, **kwargs):
|
def load_config(cls, *args, **kwargs):
|
||||||
@ -124,7 +126,10 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
|
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
fields = inspect.get_annotations(value)
|
if hasattr(inspect,'get_annotations'):
|
||||||
|
fields = inspect.get_annotations(value)
|
||||||
|
else:
|
||||||
|
fields = value.__annotations__
|
||||||
try:
|
try:
|
||||||
field = fields["model_format"]
|
field = fields["model_format"]
|
||||||
except:
|
except:
|
||||||
@ -383,15 +388,18 @@ def _fast_safetensors_reader(path: str):
|
|||||||
|
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
||||||
def read_checkpoint_meta(path: str):
|
if str(path).endswith(".safetensors"):
|
||||||
if path.endswith(".safetensors"):
|
|
||||||
try:
|
try:
|
||||||
checkpoint = _fast_safetensors_reader(path)
|
checkpoint = _fast_safetensors_reader(path)
|
||||||
except:
|
except:
|
||||||
# TODO: create issue for support "meta"?
|
# TODO: create issue for support "meta"?
|
||||||
checkpoint = safetensors.torch.load_file(path, device="cpu")
|
checkpoint = safetensors.torch.load_file(path, device="cpu")
|
||||||
else:
|
else:
|
||||||
|
if scan:
|
||||||
|
scan_result = scan_file_path(path)
|
||||||
|
if scan_result.infected_files != 0:
|
||||||
|
raise Exception(f"The model file \"{path}\" is potentially infected by malware. Aborting import.")
|
||||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
@ -34,17 +34,17 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
class CheckpointConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
|
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: str
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert base_model == BaseModelType.StableDiffusion1
|
assert base_model == BaseModelType.StableDiffusion1
|
||||||
assert model_type == ModelType.Pipeline
|
assert model_type == ModelType.Main
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
base_model=BaseModelType.StableDiffusion1,
|
base_model=BaseModelType.StableDiffusion1,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=ModelType.Main,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -69,7 +69,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
in_channels = unet_config['in_channels']
|
in_channels = unet_config['in_channels']
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
|
raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")
|
raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")
|
||||||
@ -81,6 +81,8 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
else:
|
else:
|
||||||
raise Exception("Unkown stable diffusion 1.* model format")
|
raise Exception("Unkown stable diffusion 1.* model format")
|
||||||
|
|
||||||
|
if ckpt_config_path is None:
|
||||||
|
ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion1, variant)
|
||||||
|
|
||||||
return cls.create_config(
|
return cls.create_config(
|
||||||
path=path,
|
path=path,
|
||||||
@ -109,14 +111,12 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
config: ModelConfigBase,
|
config: ModelConfigBase,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
assert model_path == config.path
|
|
||||||
|
|
||||||
if isinstance(config, cls.CheckpointConfig):
|
if isinstance(config, cls.CheckpointConfig):
|
||||||
return _convert_ckpt_and_cache(
|
return _convert_ckpt_and_cache(
|
||||||
version=BaseModelType.StableDiffusion1,
|
version=BaseModelType.StableDiffusion1,
|
||||||
model_config=config,
|
model_config=config,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
) # TODO: args
|
)
|
||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
@ -131,25 +131,20 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
|
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
prediction_type: SchedulerPredictionType
|
|
||||||
upcast_attention: bool
|
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
|
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: str
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
prediction_type: SchedulerPredictionType
|
|
||||||
upcast_attention: bool
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert base_model == BaseModelType.StableDiffusion2
|
assert base_model == BaseModelType.StableDiffusion2
|
||||||
assert model_type == ModelType.Pipeline
|
assert model_type == ModelType.Main
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
base_model=BaseModelType.StableDiffusion2,
|
base_model=BaseModelType.StableDiffusion2,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=ModelType.Main,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -188,13 +183,8 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
else:
|
else:
|
||||||
raise Exception("Unkown stable diffusion 2.* model format")
|
raise Exception("Unkown stable diffusion 2.* model format")
|
||||||
|
|
||||||
if variant == ModelVariantType.Normal:
|
if ckpt_config_path is None:
|
||||||
prediction_type = SchedulerPredictionType.VPrediction
|
ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion2, variant)
|
||||||
upcast_attention = True
|
|
||||||
|
|
||||||
else:
|
|
||||||
prediction_type = SchedulerPredictionType.Epsilon
|
|
||||||
upcast_attention = False
|
|
||||||
|
|
||||||
return cls.create_config(
|
return cls.create_config(
|
||||||
path=path,
|
path=path,
|
||||||
@ -202,8 +192,6 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
|
|
||||||
config=ckpt_config_path,
|
config=ckpt_config_path,
|
||||||
variant=variant,
|
variant=variant,
|
||||||
prediction_type=prediction_type,
|
|
||||||
upcast_attention=upcast_attention,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@classproperty
|
@classproperty
|
||||||
@ -225,14 +213,12 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
config: ModelConfigBase,
|
config: ModelConfigBase,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
assert model_path == config.path
|
|
||||||
|
|
||||||
if isinstance(config, cls.CheckpointConfig):
|
if isinstance(config, cls.CheckpointConfig):
|
||||||
return _convert_ckpt_and_cache(
|
return _convert_ckpt_and_cache(
|
||||||
version=BaseModelType.StableDiffusion2,
|
version=BaseModelType.StableDiffusion2,
|
||||||
model_config=config,
|
model_config=config,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
) # TODO: args
|
)
|
||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
@ -243,18 +229,18 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
|||||||
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
||||||
},
|
},
|
||||||
BaseModelType.StableDiffusion2: {
|
BaseModelType.StableDiffusion2: {
|
||||||
# code further will manually set upcast_attention and v_prediction
|
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
|
||||||
ModelVariantType.Normal: "v2-inference.yaml",
|
|
||||||
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
|
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
|
||||||
ModelVariantType.Depth: "v2-midas-inference.yaml",
|
ModelVariantType.Depth: "v2-midas-inference.yaml",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
try:
|
try:
|
||||||
# TODO: path
|
config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
|
||||||
#model_config.config = app_config.config_dir / "stable-diffusion" / ckpt_configs[version][model_config.variant]
|
if config_path.is_relative_to(app_config.root_path):
|
||||||
#return InvokeAIAppConfig.get_config().legacy_conf_dir / ckpt_configs[version][variant]
|
config_path = config_path.relative_to(app_config.root_path)
|
||||||
return InvokeAIAppConfig.get_config().root_dir / "configs" / "stable-diffusion" / ckpt_configs[version][variant]
|
return str(config_path)
|
||||||
|
|
||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
@ -273,36 +259,14 @@ def _convert_ckpt_and_cache(
|
|||||||
"""
|
"""
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
if model_config.config is None:
|
|
||||||
model_config.config = _select_ckpt_config(version, model_config.variant)
|
|
||||||
if model_config.config is None:
|
|
||||||
raise Exception(f"Model variant {model_config.variant} not supported for {version}")
|
|
||||||
|
|
||||||
|
|
||||||
weights = app_config.root_path / model_config.path
|
weights = app_config.root_path / model_config.path
|
||||||
config_file = app_config.root_path / model_config.config
|
config_file = app_config.root_path / model_config.config
|
||||||
output_path = Path(output_path)
|
output_path = Path(output_path)
|
||||||
|
|
||||||
if version == BaseModelType.StableDiffusion1:
|
|
||||||
upcast_attention = False
|
|
||||||
prediction_type = SchedulerPredictionType.Epsilon
|
|
||||||
|
|
||||||
elif version == BaseModelType.StableDiffusion2:
|
|
||||||
upcast_attention = model_config.upcast_attention
|
|
||||||
prediction_type = model_config.prediction_type
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise Exception(f"Unknown model provided: {version}")
|
|
||||||
|
|
||||||
|
|
||||||
# return cached version if it exists
|
# return cached version if it exists
|
||||||
if output_path.exists():
|
if output_path.exists():
|
||||||
return output_path
|
return output_path
|
||||||
|
|
||||||
# TODO: I think that it more correctly to convert with embedded vae
|
|
||||||
# as if user will delete custom vae he will got not embedded but also custom vae
|
|
||||||
#vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
|
|
||||||
|
|
||||||
# to avoid circular import errors
|
# to avoid circular import errors
|
||||||
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||||
with SilenceWarnings():
|
with SilenceWarnings():
|
||||||
@ -313,9 +277,6 @@ def _convert_ckpt_and_cache(
|
|||||||
model_variant=model_config.variant,
|
model_variant=model_config.variant,
|
||||||
original_config_file=config_file,
|
original_config_file=config_file,
|
||||||
extract_ema=True,
|
extract_ema=True,
|
||||||
upcast_attention=upcast_attention,
|
|
||||||
prediction_type=prediction_type,
|
|
||||||
scan_needed=True,
|
scan_needed=True,
|
||||||
model_root=app_config.models_path,
|
|
||||||
)
|
)
|
||||||
return output_path
|
return output_path
|
||||||
|
@ -16,6 +16,7 @@ from .util import (
|
|||||||
download_with_resume,
|
download_with_resume,
|
||||||
instantiate_from_config,
|
instantiate_from_config,
|
||||||
url_attachment_name,
|
url_attachment_name,
|
||||||
|
Chdir
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -381,3 +381,18 @@ def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
|
|||||||
buffered.getvalue()
|
buffered.getvalue()
|
||||||
).decode("UTF-8")
|
).decode("UTF-8")
|
||||||
return image_base64
|
return image_base64
|
||||||
|
|
||||||
|
class Chdir(object):
|
||||||
|
'''Context manager to chdir to desired directory and change back after context exits:
|
||||||
|
Args:
|
||||||
|
path (Path): The path to the cwd
|
||||||
|
'''
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.path = path
|
||||||
|
self.original = Path().absolute()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
os.chdir(self.path)
|
||||||
|
|
||||||
|
def __exit__(self,*args):
|
||||||
|
os.chdir(self.original)
|
||||||
|
@ -1,107 +1,92 @@
|
|||||||
# This file predefines a few models that the user may want to install.
|
# This file predefines a few models that the user may want to install.
|
||||||
diffusers:
|
sd-1/main/stable-diffusion-v1-5:
|
||||||
stable-diffusion-1.5:
|
description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
|
||||||
description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
|
repo_id: runwayml/stable-diffusion-v1-5
|
||||||
repo_id: runwayml/stable-diffusion-v1-5
|
recommended: True
|
||||||
format: diffusers
|
default: True
|
||||||
vae:
|
sd-1/main/stable-diffusion-inpainting:
|
||||||
repo_id: stabilityai/sd-vae-ft-mse
|
description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
|
||||||
recommended: True
|
repo_id: runwayml/stable-diffusion-inpainting
|
||||||
default: True
|
recommended: True
|
||||||
sd-inpainting-1.5:
|
sd-2/main/stable-diffusion-2-1:
|
||||||
description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
|
description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
|
||||||
repo_id: runwayml/stable-diffusion-inpainting
|
repo_id: stabilityai/stable-diffusion-2-1
|
||||||
format: diffusers
|
recommended: True
|
||||||
vae:
|
sd-2/main/stable-diffusion-2-inpainting:
|
||||||
repo_id: stabilityai/sd-vae-ft-mse
|
description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
|
||||||
recommended: True
|
repo_id: stabilityai/stable-diffusion-2-inpainting
|
||||||
stable-diffusion-2.1:
|
recommended: False
|
||||||
description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
|
sd-1/main/Analog-Diffusion:
|
||||||
repo_id: stabilityai/stable-diffusion-2-1
|
description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
|
||||||
format: diffusers
|
repo_id: wavymulder/Analog-Diffusion
|
||||||
recommended: True
|
recommended: false
|
||||||
sd-inpainting-2.0:
|
sd-1/main/Deliberate:
|
||||||
description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
|
description: Versatile model that produces detailed images up to 768px (4.27 GB)
|
||||||
repo_id: stabilityai/stable-diffusion-2-inpainting
|
repo_id: XpucT/Deliberate
|
||||||
format: diffusers
|
recommended: False
|
||||||
recommended: False
|
sd-1/main/Dungeons-and-Diffusion:
|
||||||
analog-diffusion-1.0:
|
description: Dungeons & Dragons characters (2.13 GB)
|
||||||
description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
|
repo_id: 0xJustin/Dungeons-and-Diffusion
|
||||||
repo_id: wavymulder/Analog-Diffusion
|
recommended: False
|
||||||
format: diffusers
|
sd-1/main/dreamlike-photoreal-2:
|
||||||
recommended: false
|
description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
|
||||||
deliberate-1.0:
|
repo_id: dreamlike-art/dreamlike-photoreal-2.0
|
||||||
description: Versatile model that produces detailed images up to 768px (4.27 GB)
|
recommended: False
|
||||||
format: diffusers
|
sd-1/main/Inkpunk-Diffusion:
|
||||||
repo_id: XpucT/Deliberate
|
description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)
|
||||||
recommended: False
|
repo_id: Envvi/Inkpunk-Diffusion
|
||||||
d&d-diffusion-1.0:
|
recommended: False
|
||||||
description: Dungeons & Dragons characters (2.13 GB)
|
sd-1/main/openjourney:
|
||||||
format: diffusers
|
description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)
|
||||||
repo_id: 0xJustin/Dungeons-and-Diffusion
|
repo_id: prompthero/openjourney
|
||||||
recommended: False
|
recommended: False
|
||||||
dreamlike-photoreal-2.0:
|
sd-1/main/portraitplus:
|
||||||
description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
|
description: An SD-1.5 model trained on close range portraits of people; prompt with "portrait+" (2.13 GB)
|
||||||
format: diffusers
|
repo_id: wavymulder/portraitplus
|
||||||
repo_id: dreamlike-art/dreamlike-photoreal-2.0
|
recommended: False
|
||||||
recommended: False
|
sd-1/main/seek.art_MEGA:
|
||||||
inkpunk-1.0:
|
repo_id: coreco/seek.art_MEGA
|
||||||
description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)
|
description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
|
||||||
format: diffusers
|
recommended: False
|
||||||
repo_id: Envvi/Inkpunk-Diffusion
|
sd-1/main/trinart_stable_diffusion_v2:
|
||||||
recommended: False
|
description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
|
||||||
openjourney-4.0:
|
repo_id: naclbit/trinart_stable_diffusion_v2
|
||||||
description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)
|
recommended: False
|
||||||
format: diffusers
|
sd-1/main/waifu-diffusion:
|
||||||
repo_id: prompthero/openjourney
|
description: An SD-1.5 model trained on 680k anime/manga-style images (2.13 GB)
|
||||||
vae:
|
repo_id: hakurei/waifu-diffusion
|
||||||
repo_id: stabilityai/sd-vae-ft-mse
|
recommended: False
|
||||||
recommended: False
|
sd-1/controlnet/canny:
|
||||||
portrait-plus-1.0:
|
repo_id: lllyasviel/control_v11p_sd15_canny
|
||||||
description: An SD-1.5 model trained on close range portraits of people; prompt with "portrait+" (2.13 GB)
|
sd-1/controlnet/inpaint:
|
||||||
format: diffusers
|
repo_id: lllyasviel/control_v11p_sd15_inpaint
|
||||||
repo_id: wavymulder/portraitplus
|
sd-1/controlnet/mlsd:
|
||||||
recommended: False
|
repo_id: lllyasviel/control_v11p_sd15_mlsd
|
||||||
seek-art-mega-1.0:
|
sd-1/controlnet/depth:
|
||||||
description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
|
repo_id: lllyasviel/control_v11f1p_sd15_depth
|
||||||
repo_id: coreco/seek.art_MEGA
|
sd-1/controlnet/normal_bae:
|
||||||
format: diffusers
|
repo_id: lllyasviel/control_v11p_sd15_normalbae
|
||||||
vae:
|
sd-1/controlnet/seg:
|
||||||
repo_id: stabilityai/sd-vae-ft-mse
|
repo_id: lllyasviel/control_v11p_sd15_seg
|
||||||
recommended: False
|
sd-1/controlnet/lineart:
|
||||||
trinart-2.0:
|
repo_id: lllyasviel/control_v11p_sd15_lineart
|
||||||
description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
|
sd-1/controlnet/lineart_anime:
|
||||||
repo_id: naclbit/trinart_stable_diffusion_v2
|
repo_id: lllyasviel/control_v11p_sd15s2_lineart_anime
|
||||||
format: diffusers
|
sd-1/controlnet/scribble:
|
||||||
vae:
|
repo_id: lllyasviel/control_v11p_sd15_scribble
|
||||||
repo_id: stabilityai/sd-vae-ft-mse
|
sd-1/controlnet/softedge:
|
||||||
recommended: False
|
repo_id: lllyasviel/control_v11p_sd15_softedge
|
||||||
waifu-diffusion-1.4:
|
sd-1/controlnet/shuffle:
|
||||||
description: An SD-1.5 model trained on 680k anime/manga-style images (2.13 GB)
|
repo_id: lllyasviel/control_v11e_sd15_shuffle
|
||||||
repo_id: hakurei/waifu-diffusion
|
sd-1/controlnet/tile:
|
||||||
format: diffusers
|
repo_id: lllyasviel/control_v11f1e_sd15_tile
|
||||||
vae:
|
sd-1/controlnet/ip2p:
|
||||||
repo_id: stabilityai/sd-vae-ft-mse
|
repo_id: lllyasviel/control_v11e_sd15_ip2p
|
||||||
recommended: False
|
sd-1/embedding/EasyNegative:
|
||||||
controlnet:
|
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
|
||||||
canny: lllyasviel/control_v11p_sd15_canny
|
sd-1/embedding/ahx-beta-453407d:
|
||||||
inpaint: lllyasviel/control_v11p_sd15_inpaint
|
repo_id: sd-concepts-library/ahx-beta-453407d
|
||||||
mlsd: lllyasviel/control_v11p_sd15_mlsd
|
sd-1/lora/LowRA:
|
||||||
depth: lllyasviel/control_v11f1p_sd15_depth
|
path: https://civitai.com/api/download/models/63006
|
||||||
normal_bae: lllyasviel/control_v11p_sd15_normalbae
|
sd-1/lora/Ink scenery:
|
||||||
seg: lllyasviel/control_v11p_sd15_seg
|
path: https://civitai.com/api/download/models/83390
|
||||||
lineart: lllyasviel/control_v11p_sd15_lineart
|
|
||||||
lineart_anime: lllyasviel/control_v11p_sd15s2_lineart_anime
|
|
||||||
scribble: lllyasviel/control_v11p_sd15_scribble
|
|
||||||
softedge: lllyasviel/control_v11p_sd15_softedge
|
|
||||||
shuffle: lllyasviel/control_v11e_sd15_shuffle
|
|
||||||
tile: lllyasviel/control_v11f1e_sd15_tile
|
|
||||||
ip2p: lllyasviel/control_v11e_sd15_ip2p
|
|
||||||
textual_inversion:
|
|
||||||
'EasyNegative': https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
|
|
||||||
'ahx-beta-453407d': sd-concepts-library/ahx-beta-453407d
|
|
||||||
lora:
|
|
||||||
'LowRA': https://civitai.com/api/download/models/63006
|
|
||||||
'Ink scenery': https://civitai.com/api/download/models/83390
|
|
||||||
'sd-model-finetuned-lora-t4': sayakpaul/sd-model-finetuned-lora-t4
|
|
||||||
|
|
||||||
|
159
invokeai/configs/stable-diffusion/v2-inpainting-inference-v.yaml
Normal file
159
invokeai/configs/stable-diffusion/v2-inpainting-inference-v.yaml
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 5.0e-05
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
parameterization: "v"
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false
|
||||||
|
conditioning_key: hybrid
|
||||||
|
scale_factor: 0.18215
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
finetune_keys: null
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 9
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_head_channels: 64 # need to fix for flash-attn
|
||||||
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1024
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
#attn_type: "vanilla-xformers"
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: [ ]
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
|
params:
|
||||||
|
freeze: True
|
||||||
|
layer: "penultimate"
|
||||||
|
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: ldm.data.laion.WebDataModuleFromConfig
|
||||||
|
params:
|
||||||
|
tar_base: null # for concat as in LAION-A
|
||||||
|
p_unsafe_threshold: 0.1
|
||||||
|
filter_word_list: "data/filters.yaml"
|
||||||
|
max_pwatermark: 0.45
|
||||||
|
batch_size: 8
|
||||||
|
num_workers: 6
|
||||||
|
multinode: True
|
||||||
|
min_size: 512
|
||||||
|
train:
|
||||||
|
shards:
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
|
||||||
|
shuffle: 10000
|
||||||
|
image_key: jpg
|
||||||
|
image_transforms:
|
||||||
|
- target: torchvision.transforms.Resize
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
interpolation: 3
|
||||||
|
- target: torchvision.transforms.RandomCrop
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
postprocess:
|
||||||
|
target: ldm.data.laion.AddMask
|
||||||
|
params:
|
||||||
|
mode: "512train-large"
|
||||||
|
p_drop: 0.25
|
||||||
|
# NOTE use enough shards to avoid empty validation loops in workers
|
||||||
|
validation:
|
||||||
|
shards:
|
||||||
|
- "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
|
||||||
|
shuffle: 0
|
||||||
|
image_key: jpg
|
||||||
|
image_transforms:
|
||||||
|
- target: torchvision.transforms.Resize
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
interpolation: 3
|
||||||
|
- target: torchvision.transforms.CenterCrop
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
postprocess:
|
||||||
|
target: ldm.data.laion.AddMask
|
||||||
|
params:
|
||||||
|
mode: "512train-large"
|
||||||
|
p_drop: 0.25
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
find_unused_parameters: True
|
||||||
|
modelcheckpoint:
|
||||||
|
params:
|
||||||
|
every_n_train_steps: 5000
|
||||||
|
|
||||||
|
callbacks:
|
||||||
|
metrics_over_trainsteps_checkpoint:
|
||||||
|
params:
|
||||||
|
every_n_train_steps: 10000
|
||||||
|
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
enable_autocast: False
|
||||||
|
disabled: False
|
||||||
|
batch_frequency: 1000
|
||||||
|
max_images: 4
|
||||||
|
increase_log_steps: False
|
||||||
|
log_first_step: False
|
||||||
|
log_images_kwargs:
|
||||||
|
use_ema_scope: False
|
||||||
|
inpaint: False
|
||||||
|
plot_progressive_rows: False
|
||||||
|
plot_diffusion_rows: False
|
||||||
|
N: 4
|
||||||
|
unconditional_guidance_scale: 5.0
|
||||||
|
unconditional_guidance_label: [""]
|
||||||
|
ddim_steps: 50 # todo check these out for depth2img,
|
||||||
|
ddim_eta: 0.0 # todo check these out for depth2img,
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
||||||
|
val_check_interval: 5000000
|
||||||
|
num_sanity_val_steps: 0
|
||||||
|
accumulate_grad_batches: 1
|
158
invokeai/configs/stable-diffusion/v2-inpainting-inference.yaml
Normal file
158
invokeai/configs/stable-diffusion/v2-inpainting-inference.yaml
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 5.0e-05
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false
|
||||||
|
conditioning_key: hybrid
|
||||||
|
scale_factor: 0.18215
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
finetune_keys: null
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
use_checkpoint: True
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 9
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_head_channels: 64 # need to fix for flash-attn
|
||||||
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1024
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
#attn_type: "vanilla-xformers"
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: [ ]
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||||
|
params:
|
||||||
|
freeze: True
|
||||||
|
layer: "penultimate"
|
||||||
|
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: ldm.data.laion.WebDataModuleFromConfig
|
||||||
|
params:
|
||||||
|
tar_base: null # for concat as in LAION-A
|
||||||
|
p_unsafe_threshold: 0.1
|
||||||
|
filter_word_list: "data/filters.yaml"
|
||||||
|
max_pwatermark: 0.45
|
||||||
|
batch_size: 8
|
||||||
|
num_workers: 6
|
||||||
|
multinode: True
|
||||||
|
min_size: 512
|
||||||
|
train:
|
||||||
|
shards:
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
|
||||||
|
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
|
||||||
|
shuffle: 10000
|
||||||
|
image_key: jpg
|
||||||
|
image_transforms:
|
||||||
|
- target: torchvision.transforms.Resize
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
interpolation: 3
|
||||||
|
- target: torchvision.transforms.RandomCrop
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
postprocess:
|
||||||
|
target: ldm.data.laion.AddMask
|
||||||
|
params:
|
||||||
|
mode: "512train-large"
|
||||||
|
p_drop: 0.25
|
||||||
|
# NOTE use enough shards to avoid empty validation loops in workers
|
||||||
|
validation:
|
||||||
|
shards:
|
||||||
|
- "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
|
||||||
|
shuffle: 0
|
||||||
|
image_key: jpg
|
||||||
|
image_transforms:
|
||||||
|
- target: torchvision.transforms.Resize
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
interpolation: 3
|
||||||
|
- target: torchvision.transforms.CenterCrop
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
postprocess:
|
||||||
|
target: ldm.data.laion.AddMask
|
||||||
|
params:
|
||||||
|
mode: "512train-large"
|
||||||
|
p_drop: 0.25
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
find_unused_parameters: True
|
||||||
|
modelcheckpoint:
|
||||||
|
params:
|
||||||
|
every_n_train_steps: 5000
|
||||||
|
|
||||||
|
callbacks:
|
||||||
|
metrics_over_trainsteps_checkpoint:
|
||||||
|
params:
|
||||||
|
every_n_train_steps: 10000
|
||||||
|
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
enable_autocast: False
|
||||||
|
disabled: False
|
||||||
|
batch_frequency: 1000
|
||||||
|
max_images: 4
|
||||||
|
increase_log_steps: False
|
||||||
|
log_first_step: False
|
||||||
|
log_images_kwargs:
|
||||||
|
use_ema_scope: False
|
||||||
|
inpaint: False
|
||||||
|
plot_progressive_rows: False
|
||||||
|
plot_diffusion_rows: False
|
||||||
|
N: 4
|
||||||
|
unconditional_guidance_scale: 5.0
|
||||||
|
unconditional_guidance_label: [""]
|
||||||
|
ddim_steps: 50 # todo check these out for depth2img,
|
||||||
|
ddim_eta: 0.0 # todo check these out for depth2img,
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
||||||
|
val_check_interval: 5000000
|
||||||
|
num_sanity_val_steps: 0
|
||||||
|
accumulate_grad_batches: 1
|
@ -11,7 +11,6 @@ The work is actually done in backend code in model_install_backend.py.
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import curses
|
import curses
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
import traceback
|
import traceback
|
||||||
@ -20,28 +19,22 @@ from multiprocessing import Process
|
|||||||
from multiprocessing.connection import Connection, Pipe
|
from multiprocessing.connection import Connection, Pipe
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import get_terminal_size
|
from shutil import get_terminal_size
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import npyscreen
|
import npyscreen
|
||||||
import torch
|
import torch
|
||||||
from npyscreen import widget
|
from npyscreen import widget
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
from invokeai.backend.install.model_install_backend import (
|
from invokeai.backend.install.model_install_backend import (
|
||||||
Dataset_path,
|
|
||||||
default_config_file,
|
|
||||||
default_dataset,
|
|
||||||
install_requested_models,
|
|
||||||
recommended_datasets,
|
|
||||||
ModelInstallList,
|
ModelInstallList,
|
||||||
UserSelections,
|
InstallSelections,
|
||||||
|
ModelInstall,
|
||||||
|
SchedulerPredictionType,
|
||||||
)
|
)
|
||||||
from invokeai.backend import ModelManager
|
from invokeai.backend.model_management import ModelManager, ModelType
|
||||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
from invokeai.frontend.install.widgets import (
|
from invokeai.frontend.install.widgets import (
|
||||||
CenteredTitleText,
|
CenteredTitleText,
|
||||||
MultiSelectColumns,
|
MultiSelectColumns,
|
||||||
@ -58,6 +51,7 @@ from invokeai.frontend.install.widgets import (
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
logger = InvokeAILogger.getLogger()
|
||||||
|
|
||||||
# build a table mapping all non-printable characters to None
|
# build a table mapping all non-printable characters to None
|
||||||
# for stripping control characters
|
# for stripping control characters
|
||||||
@ -71,8 +65,8 @@ def make_printable(s:str)->str:
|
|||||||
return s.translate(NOPRINT_TRANS_TABLE)
|
return s.translate(NOPRINT_TRANS_TABLE)
|
||||||
|
|
||||||
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||||
# for responsive resizing - disabled
|
# for responsive resizing set to False, but this seems to cause a crash!
|
||||||
# FIX_MINIMUM_SIZE_WHEN_CREATED = False
|
FIX_MINIMUM_SIZE_WHEN_CREATED = True
|
||||||
|
|
||||||
# for persistence
|
# for persistence
|
||||||
current_tab = 0
|
current_tab = 0
|
||||||
@ -90,25 +84,10 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
if not config.model_conf_path.exists():
|
if not config.model_conf_path.exists():
|
||||||
with open(config.model_conf_path,'w') as file:
|
with open(config.model_conf_path,'w') as file:
|
||||||
print('# InvokeAI model configuration file',file=file)
|
print('# InvokeAI model configuration file',file=file)
|
||||||
model_manager = ModelManager(config.model_conf_path)
|
self.installer = ModelInstall(config)
|
||||||
|
self.all_models = self.installer.all_models()
|
||||||
self.starter_models = OmegaConf.load(Dataset_path)['diffusers']
|
self.starter_models = self.installer.starter_models()
|
||||||
self.installed_diffusers_models = self.list_additional_diffusers_models(
|
self.model_labels = self._get_model_labels()
|
||||||
model_manager,
|
|
||||||
self.starter_models,
|
|
||||||
)
|
|
||||||
self.installed_cn_models = model_manager.list_controlnet_models()
|
|
||||||
self.installed_lora_models = model_manager.list_lora_models()
|
|
||||||
self.installed_ti_models = model_manager.list_ti_models()
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.existing_models = OmegaConf.load(default_config_file())
|
|
||||||
except:
|
|
||||||
self.existing_models = dict()
|
|
||||||
|
|
||||||
self.starter_model_list = list(self.starter_models.keys())
|
|
||||||
self.installed_models = dict()
|
|
||||||
|
|
||||||
window_width, window_height = get_terminal_size()
|
window_width, window_height = get_terminal_size()
|
||||||
|
|
||||||
self.nextrely -= 1
|
self.nextrely -= 1
|
||||||
@ -141,39 +120,37 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
scroll_exit = True,
|
scroll_exit = True,
|
||||||
)
|
)
|
||||||
self.tabs.on_changed = self._toggle_tables
|
self.tabs.on_changed = self._toggle_tables
|
||||||
|
|
||||||
top_of_table = self.nextrely
|
top_of_table = self.nextrely
|
||||||
self.starter_diffusers_models = self.add_starter_diffusers()
|
self.starter_pipelines = self.add_starter_pipelines()
|
||||||
bottom_of_table = self.nextrely
|
bottom_of_table = self.nextrely
|
||||||
|
|
||||||
self.nextrely = top_of_table
|
self.nextrely = top_of_table
|
||||||
self.diffusers_models = self.add_diffusers_widgets(
|
self.pipeline_models = self.add_pipeline_widgets(
|
||||||
predefined_models=self.installed_diffusers_models,
|
model_type=ModelType.Main,
|
||||||
model_type='Diffusers',
|
|
||||||
window_width=window_width,
|
window_width=window_width,
|
||||||
|
exclude = self.starter_models
|
||||||
)
|
)
|
||||||
|
# self.pipeline_models['autoload_pending'] = True
|
||||||
bottom_of_table = max(bottom_of_table,self.nextrely)
|
bottom_of_table = max(bottom_of_table,self.nextrely)
|
||||||
|
|
||||||
self.nextrely = top_of_table
|
self.nextrely = top_of_table
|
||||||
self.controlnet_models = self.add_model_widgets(
|
self.controlnet_models = self.add_model_widgets(
|
||||||
predefined_models=self.installed_cn_models,
|
model_type=ModelType.ControlNet,
|
||||||
model_type='ControlNet',
|
|
||||||
window_width=window_width,
|
window_width=window_width,
|
||||||
)
|
)
|
||||||
bottom_of_table = max(bottom_of_table,self.nextrely)
|
bottom_of_table = max(bottom_of_table,self.nextrely)
|
||||||
|
|
||||||
self.nextrely = top_of_table
|
self.nextrely = top_of_table
|
||||||
self.lora_models = self.add_model_widgets(
|
self.lora_models = self.add_model_widgets(
|
||||||
predefined_models=self.installed_lora_models,
|
model_type=ModelType.Lora,
|
||||||
model_type="LoRA/LyCORIS",
|
|
||||||
window_width=window_width,
|
window_width=window_width,
|
||||||
)
|
)
|
||||||
bottom_of_table = max(bottom_of_table,self.nextrely)
|
bottom_of_table = max(bottom_of_table,self.nextrely)
|
||||||
|
|
||||||
self.nextrely = top_of_table
|
self.nextrely = top_of_table
|
||||||
self.ti_models = self.add_model_widgets(
|
self.ti_models = self.add_model_widgets(
|
||||||
predefined_models=self.installed_ti_models,
|
model_type=ModelType.TextualInversion,
|
||||||
model_type="Textual Inversion Embeddings",
|
|
||||||
window_width=window_width,
|
window_width=window_width,
|
||||||
)
|
)
|
||||||
bottom_of_table = max(bottom_of_table,self.nextrely)
|
bottom_of_table = max(bottom_of_table,self.nextrely)
|
||||||
@ -184,7 +161,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
BufferBox,
|
BufferBox,
|
||||||
name='Log Messages',
|
name='Log Messages',
|
||||||
editable=False,
|
editable=False,
|
||||||
max_height = 16,
|
max_height = 10,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
@ -197,13 +174,14 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
rely=-3,
|
rely=-3,
|
||||||
when_pressed_function=self.on_back,
|
when_pressed_function=self.on_back,
|
||||||
)
|
)
|
||||||
self.ok_button = self.add_widget_intelligent(
|
else:
|
||||||
npyscreen.ButtonPress,
|
self.ok_button = self.add_widget_intelligent(
|
||||||
name=done_label,
|
npyscreen.ButtonPress,
|
||||||
relx=(window_width - len(done_label)) // 2,
|
name=done_label,
|
||||||
rely=-3,
|
relx=(window_width - len(done_label)) // 2,
|
||||||
when_pressed_function=self.on_execute
|
rely=-3,
|
||||||
)
|
when_pressed_function=self.on_execute
|
||||||
|
)
|
||||||
|
|
||||||
label = "APPLY CHANGES & EXIT"
|
label = "APPLY CHANGES & EXIT"
|
||||||
self.done = self.add_widget_intelligent(
|
self.done = self.add_widget_intelligent(
|
||||||
@ -220,18 +198,15 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
self._toggle_tables([self.current_tab])
|
self._toggle_tables([self.current_tab])
|
||||||
|
|
||||||
############# diffusers tab ##########
|
############# diffusers tab ##########
|
||||||
def add_starter_diffusers(self)->dict[str, npyscreen.widget]:
|
def add_starter_pipelines(self)->dict[str, npyscreen.widget]:
|
||||||
'''Add widgets responsible for selecting diffusers models'''
|
'''Add widgets responsible for selecting diffusers models'''
|
||||||
widgets = dict()
|
widgets = dict()
|
||||||
|
models = self.all_models
|
||||||
starter_model_labels = self._get_starter_model_labels()
|
starters = self.starter_models
|
||||||
recommended_models = [
|
starter_model_labels = self.model_labels
|
||||||
x
|
|
||||||
for x in self.starter_model_list
|
|
||||||
if self.starter_models[x].get("recommended", False)
|
|
||||||
]
|
|
||||||
self.installed_models = sorted(
|
self.installed_models = sorted(
|
||||||
[x for x in list(self.starter_models.keys()) if x in self.existing_models]
|
[x for x in starters if models[x].installed]
|
||||||
)
|
)
|
||||||
|
|
||||||
widgets.update(
|
widgets.update(
|
||||||
@ -246,55 +221,46 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
self.nextrely -= 1
|
self.nextrely -= 1
|
||||||
# if user has already installed some initial models, then don't patronize them
|
# if user has already installed some initial models, then don't patronize them
|
||||||
# by showing more recommendations
|
# by showing more recommendations
|
||||||
show_recommended = not self.existing_models
|
show_recommended = len(self.installed_models)==0
|
||||||
|
keys = [x for x in models.keys() if x in starters]
|
||||||
widgets.update(
|
widgets.update(
|
||||||
models_selected = self.add_widget_intelligent(
|
models_selected = self.add_widget_intelligent(
|
||||||
MultiSelectColumns,
|
MultiSelectColumns,
|
||||||
columns=1,
|
columns=1,
|
||||||
name="Install Starter Models",
|
name="Install Starter Models",
|
||||||
values=starter_model_labels,
|
values=[starter_model_labels[x] for x in keys],
|
||||||
value=[
|
value=[
|
||||||
self.starter_model_list.index(x)
|
keys.index(x)
|
||||||
for x in self.starter_model_list
|
for x in keys
|
||||||
if (show_recommended and x in recommended_models)\
|
if (show_recommended and models[x].recommended) \
|
||||||
or (x in self.existing_models)
|
or (x in self.installed_models)
|
||||||
],
|
],
|
||||||
max_height=len(starter_model_labels) + 1,
|
max_height=len(starters) + 1,
|
||||||
relx=4,
|
relx=4,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
),
|
||||||
|
models = keys,
|
||||||
)
|
)
|
||||||
|
|
||||||
widgets.update(
|
|
||||||
purge_deleted = self.add_widget_intelligent(
|
|
||||||
npyscreen.Checkbox,
|
|
||||||
name="Purge unchecked diffusers models from disk",
|
|
||||||
value=False,
|
|
||||||
scroll_exit=True,
|
|
||||||
relx=4,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
widgets['purge_deleted'].when_value_edited = lambda: self.sync_purge_buttons(widgets['purge_deleted'])
|
|
||||||
|
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
return widgets
|
return widgets
|
||||||
|
|
||||||
############# Add a set of model install widgets ########
|
############# Add a set of model install widgets ########
|
||||||
def add_model_widgets(self,
|
def add_model_widgets(self,
|
||||||
predefined_models: dict[str,bool],
|
model_type: ModelType,
|
||||||
model_type: str,
|
|
||||||
window_width: int=120,
|
window_width: int=120,
|
||||||
install_prompt: str=None,
|
install_prompt: str=None,
|
||||||
add_purge_deleted: bool=False,
|
exclude: set=set(),
|
||||||
)->dict[str,npyscreen.widget]:
|
)->dict[str,npyscreen.widget]:
|
||||||
'''Generic code to create model selection widgets'''
|
'''Generic code to create model selection widgets'''
|
||||||
widgets = dict()
|
widgets = dict()
|
||||||
model_list = sorted(predefined_models.keys())
|
model_list = [x for x in self.all_models if self.all_models[x].model_type==model_type and not x in exclude]
|
||||||
|
model_labels = [self.model_labels[x] for x in model_list]
|
||||||
if len(model_list) > 0:
|
if len(model_list) > 0:
|
||||||
max_width = max([len(x) for x in model_list])
|
max_width = max([len(x) for x in model_labels])
|
||||||
columns = window_width // (max_width+8) # 8 characters for "[x] " and padding
|
columns = window_width // (max_width+8) # 8 characters for "[x] " and padding
|
||||||
columns = min(len(model_list),columns) or 1
|
columns = min(len(model_list),columns) or 1
|
||||||
prompt = install_prompt or f"Select the desired {model_type} models to install. Unchecked models will be purged from disk."
|
prompt = install_prompt or f"Select the desired {model_type.value.title()} models to install. Unchecked models will be purged from disk."
|
||||||
|
|
||||||
widgets.update(
|
widgets.update(
|
||||||
label1 = self.add_widget_intelligent(
|
label1 = self.add_widget_intelligent(
|
||||||
@ -310,31 +276,19 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
MultiSelectColumns,
|
MultiSelectColumns,
|
||||||
columns=columns,
|
columns=columns,
|
||||||
name=f"Install {model_type} Models",
|
name=f"Install {model_type} Models",
|
||||||
values=model_list,
|
values=model_labels,
|
||||||
value=[
|
value=[
|
||||||
model_list.index(x)
|
model_list.index(x)
|
||||||
for x in model_list
|
for x in model_list
|
||||||
if predefined_models[x]
|
if self.all_models[x].installed
|
||||||
],
|
],
|
||||||
max_height=len(model_list)//columns + 1,
|
max_height=len(model_list)//columns + 1,
|
||||||
relx=4,
|
relx=4,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
),
|
||||||
|
models = model_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_purge_deleted:
|
|
||||||
self.nextrely += 1
|
|
||||||
widgets.update(
|
|
||||||
purge_deleted = self.add_widget_intelligent(
|
|
||||||
npyscreen.Checkbox,
|
|
||||||
name="Purge unchecked diffusers models from disk",
|
|
||||||
value=False,
|
|
||||||
scroll_exit=True,
|
|
||||||
relx=4,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
widgets['purge_deleted'].when_value_edited = lambda: self.sync_purge_buttons(widgets['purge_deleted'])
|
|
||||||
|
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
widgets.update(
|
widgets.update(
|
||||||
download_ids = self.add_widget_intelligent(
|
download_ids = self.add_widget_intelligent(
|
||||||
@ -348,63 +302,33 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
return widgets
|
return widgets
|
||||||
|
|
||||||
### Tab for arbitrary diffusers widgets ###
|
### Tab for arbitrary diffusers widgets ###
|
||||||
def add_diffusers_widgets(self,
|
def add_pipeline_widgets(self,
|
||||||
predefined_models: dict[str,bool],
|
model_type: ModelType=ModelType.Main,
|
||||||
model_type: str='Diffusers',
|
window_width: int=120,
|
||||||
window_width: int=120,
|
**kwargs,
|
||||||
)->dict[str,npyscreen.widget]:
|
)->dict[str,npyscreen.widget]:
|
||||||
'''Similar to add_model_widgets() but adds some additional widgets at the bottom
|
'''Similar to add_model_widgets() but adds some additional widgets at the bottom
|
||||||
to support the autoload directory'''
|
to support the autoload directory'''
|
||||||
widgets = self.add_model_widgets(
|
widgets = self.add_model_widgets(
|
||||||
predefined_models,
|
model_type = model_type,
|
||||||
'Diffusers',
|
window_width = window_width,
|
||||||
window_width,
|
install_prompt=f"Additional {model_type.value.title()} models already installed.",
|
||||||
install_prompt="Additional diffusers models already installed.",
|
**kwargs,
|
||||||
add_purge_deleted=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
label = "Directory to scan for models to automatically import (<tab> autocompletes):"
|
|
||||||
self.nextrely += 1
|
|
||||||
widgets.update(
|
|
||||||
autoload_directory = self.add_widget_intelligent(
|
|
||||||
FileBox,
|
|
||||||
max_height=3,
|
|
||||||
name=label,
|
|
||||||
value=str(config.autoconvert_dir) if config.autoconvert_dir else None,
|
|
||||||
select_dir=True,
|
|
||||||
must_exist=True,
|
|
||||||
use_two_lines=False,
|
|
||||||
labelColor="DANGER",
|
|
||||||
begin_entry_at=len(label)+1,
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
widgets.update(
|
|
||||||
autoscan_on_startup = self.add_widget_intelligent(
|
|
||||||
npyscreen.Checkbox,
|
|
||||||
name="Scan and import from this directory each time InvokeAI starts",
|
|
||||||
value=config.autoconvert_dir is not None,
|
|
||||||
relx=4,
|
|
||||||
scroll_exit=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return widgets
|
return widgets
|
||||||
|
|
||||||
def sync_purge_buttons(self,checkbox):
|
|
||||||
value = checkbox.value
|
|
||||||
self.starter_diffusers_models['purge_deleted'].value = value
|
|
||||||
self.diffusers_models['purge_deleted'].value = value
|
|
||||||
|
|
||||||
def resize(self):
|
def resize(self):
|
||||||
super().resize()
|
super().resize()
|
||||||
if (s := self.starter_diffusers_models.get("models_selected")):
|
if (s := self.starter_pipelines.get("models_selected")):
|
||||||
s.values = self._get_starter_model_labels()
|
keys = [x for x in self.all_models.keys() if x in self.starter_models]
|
||||||
|
s.values = [self.model_labels[x] for x in keys]
|
||||||
|
|
||||||
def _toggle_tables(self, value=None):
|
def _toggle_tables(self, value=None):
|
||||||
selected_tab = value[0]
|
selected_tab = value[0]
|
||||||
widgets = [
|
widgets = [
|
||||||
self.starter_diffusers_models,
|
self.starter_pipelines,
|
||||||
self.diffusers_models,
|
self.pipeline_models,
|
||||||
self.controlnet_models,
|
self.controlnet_models,
|
||||||
self.lora_models,
|
self.lora_models,
|
||||||
self.ti_models,
|
self.ti_models,
|
||||||
@ -412,34 +336,38 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
|
|
||||||
for group in widgets:
|
for group in widgets:
|
||||||
for k,v in group.items():
|
for k,v in group.items():
|
||||||
v.hidden = True
|
try:
|
||||||
v.editable = False
|
v.hidden = True
|
||||||
|
v.editable = False
|
||||||
|
except:
|
||||||
|
pass
|
||||||
for k,v in widgets[selected_tab].items():
|
for k,v in widgets[selected_tab].items():
|
||||||
v.hidden = False
|
try:
|
||||||
if not isinstance(v,(npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
|
v.hidden = False
|
||||||
v.editable = True
|
if not isinstance(v,(npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
|
||||||
|
v.editable = True
|
||||||
|
except:
|
||||||
|
pass
|
||||||
self.__class__.current_tab = selected_tab # for persistence
|
self.__class__.current_tab = selected_tab # for persistence
|
||||||
self.display()
|
self.display()
|
||||||
|
|
||||||
def _get_starter_model_labels(self) -> List[str]:
|
def _get_model_labels(self) -> dict[str,str]:
|
||||||
window_width, window_height = get_terminal_size()
|
window_width, window_height = get_terminal_size()
|
||||||
label_width = 25
|
|
||||||
checkbox_width = 4
|
checkbox_width = 4
|
||||||
spacing_width = 2
|
spacing_width = 2
|
||||||
|
|
||||||
|
models = self.all_models
|
||||||
|
label_width = max([len(models[x].name) for x in models])
|
||||||
description_width = window_width - label_width - checkbox_width - spacing_width
|
description_width = window_width - label_width - checkbox_width - spacing_width
|
||||||
im = self.starter_models
|
|
||||||
names = self.starter_model_list
|
|
||||||
descriptions = [
|
|
||||||
im[x].description[0 : description_width - 3] + "..."
|
|
||||||
if len(im[x].description) > description_width
|
|
||||||
else im[x].description
|
|
||||||
for x in names
|
|
||||||
]
|
|
||||||
return [
|
|
||||||
f"%-{label_width}s %s" % (names[x], descriptions[x])
|
|
||||||
for x in range(0, len(names))
|
|
||||||
]
|
|
||||||
|
|
||||||
|
result = dict()
|
||||||
|
for x in models.keys():
|
||||||
|
description = models[x].description
|
||||||
|
description = description[0 : description_width - 3] + "..." \
|
||||||
|
if description and len(description) > description_width \
|
||||||
|
else description if description else ''
|
||||||
|
result[x] = f"%-{label_width}s %s" % (models[x].name, description)
|
||||||
|
return result
|
||||||
|
|
||||||
def _get_columns(self) -> int:
|
def _get_columns(self) -> int:
|
||||||
window_width, window_height = get_terminal_size()
|
window_width, window_height = get_terminal_size()
|
||||||
@ -467,7 +395,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
target = process_and_execute,
|
target = process_and_execute,
|
||||||
kwargs=dict(
|
kwargs=dict(
|
||||||
opt = app.program_opts,
|
opt = app.program_opts,
|
||||||
selections = app.user_selections,
|
selections = app.install_selections,
|
||||||
conn_out = child_conn,
|
conn_out = child_conn,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -475,8 +403,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
child_conn.close()
|
child_conn.close()
|
||||||
self.subprocess_connection = parent_conn
|
self.subprocess_connection = parent_conn
|
||||||
self.subprocess = p
|
self.subprocess = p
|
||||||
app.user_selections = UserSelections()
|
app.install_selections = InstallSelections()
|
||||||
# process_and_execute(app.opt, app.user_selections)
|
# process_and_execute(app.opt, app.install_selections)
|
||||||
|
|
||||||
def on_back(self):
|
def on_back(self):
|
||||||
self.parentApp.switchFormPrevious()
|
self.parentApp.switchFormPrevious()
|
||||||
@ -492,7 +420,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
self.parentApp.setNextForm(None)
|
self.parentApp.setNextForm(None)
|
||||||
self.parentApp.user_cancelled = False
|
self.parentApp.user_cancelled = False
|
||||||
self.editing = False
|
self.editing = False
|
||||||
|
|
||||||
########## This routine monitors the child process that is performing model installation and removal #####
|
########## This routine monitors the child process that is performing model installation and removal #####
|
||||||
def while_waiting(self):
|
def while_waiting(self):
|
||||||
'''Called during idle periods. Main task is to update the Log Messages box with messages
|
'''Called during idle periods. Main task is to update the Log Messages box with messages
|
||||||
@ -548,8 +476,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
|
|
||||||
# rebuild the form, saving and restoring some of the fields that need to be preserved.
|
# rebuild the form, saving and restoring some of the fields that need to be preserved.
|
||||||
saved_messages = self.monitor.entry_widget.values
|
saved_messages = self.monitor.entry_widget.values
|
||||||
autoload_dir = self.diffusers_models['autoload_directory'].value
|
# autoload_dir = str(config.root_path / self.pipeline_models['autoload_directory'].value)
|
||||||
autoscan = self.diffusers_models['autoscan_on_startup'].value
|
# autoscan = self.pipeline_models['autoscan_on_startup'].value
|
||||||
|
|
||||||
app.main_form = app.addForm(
|
app.main_form = app.addForm(
|
||||||
"MAIN", addModelsForm, name="Install Stable Diffusion Models", multipage=self.multipage,
|
"MAIN", addModelsForm, name="Install Stable Diffusion Models", multipage=self.multipage,
|
||||||
@ -558,23 +486,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
|
|
||||||
app.main_form.monitor.entry_widget.values = saved_messages
|
app.main_form.monitor.entry_widget.values = saved_messages
|
||||||
app.main_form.monitor.entry_widget.buffer([''],scroll_end=True)
|
app.main_form.monitor.entry_widget.buffer([''],scroll_end=True)
|
||||||
app.main_form.diffusers_models['autoload_directory'].value = autoload_dir
|
# app.main_form.pipeline_models['autoload_directory'].value = autoload_dir
|
||||||
app.main_form.diffusers_models['autoscan_on_startup'].value = autoscan
|
# app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan
|
||||||
|
|
||||||
###############################################################
|
|
||||||
|
|
||||||
def list_additional_diffusers_models(self,
|
|
||||||
manager: ModelManager,
|
|
||||||
starters:dict
|
|
||||||
)->dict[str,bool]:
|
|
||||||
'''Return a dict of all the currently installed models that are not on the starter list'''
|
|
||||||
model_info = manager.list_models()
|
|
||||||
additional_models = {
|
|
||||||
x:True for x in model_info \
|
|
||||||
if model_info[x]['format']=='diffusers' \
|
|
||||||
and x not in starters
|
|
||||||
}
|
|
||||||
return additional_models
|
|
||||||
|
|
||||||
def marshall_arguments(self):
|
def marshall_arguments(self):
|
||||||
"""
|
"""
|
||||||
@ -586,89 +499,40 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
.autoscan_on_startup: True if invokeai should scan and import at startup time
|
.autoscan_on_startup: True if invokeai should scan and import at startup time
|
||||||
.import_model_paths: list of URLs, repo_ids and file paths to import
|
.import_model_paths: list of URLs, repo_ids and file paths to import
|
||||||
"""
|
"""
|
||||||
# we're using a global here rather than storing the result in the parentapp
|
selections = self.parentApp.install_selections
|
||||||
# due to some bug in npyscreen that is causing attributes to be lost
|
all_models = self.all_models
|
||||||
selections = self.parentApp.user_selections
|
|
||||||
|
|
||||||
# Starter models to install/remove
|
# Defined models (in INITIAL_CONFIG.yaml or models.yaml) to add/remove
|
||||||
starter_models = dict(
|
ui_sections = [self.starter_pipelines, self.pipeline_models,
|
||||||
map(
|
self.controlnet_models, self.lora_models, self.ti_models]
|
||||||
lambda x: (self.starter_model_list[x], True),
|
for section in ui_sections:
|
||||||
self.starter_diffusers_models['models_selected'].value,
|
if not 'models_selected' in section:
|
||||||
)
|
continue
|
||||||
)
|
selected = set([section['models'][x] for x in section['models_selected'].value])
|
||||||
selections.purge_deleted_models = self.starter_diffusers_models['purge_deleted'].value or \
|
models_to_install = [x for x in selected if not self.all_models[x].installed]
|
||||||
self.diffusers_models['purge_deleted'].value
|
models_to_remove = [x for x in section['models'] if x not in selected and self.all_models[x].installed]
|
||||||
|
selections.remove_models.extend(models_to_remove)
|
||||||
selections.install_models = [x for x in starter_models if x not in self.existing_models]
|
selections.install_models.extend(all_models[x].path or all_models[x].repo_id \
|
||||||
selections.remove_models = [x for x in self.starter_model_list if x in self.existing_models and x not in starter_models]
|
for x in models_to_install if all_models[x].path or all_models[x].repo_id)
|
||||||
|
|
||||||
# "More" models
|
# models located in the 'download_ids" section
|
||||||
selections.import_model_paths = self.diffusers_models['download_ids'].value.split()
|
for section in ui_sections:
|
||||||
if diffusers_selected := self.diffusers_models.get('models_selected'):
|
if downloads := section.get('download_ids'):
|
||||||
selections.remove_models.extend([x
|
selections.install_models.extend(downloads.value.split())
|
||||||
for x in diffusers_selected.values
|
|
||||||
if self.installed_diffusers_models[x]
|
|
||||||
and diffusers_selected.values.index(x) not in diffusers_selected.value
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: REFACTOR THIS REPETITIVE CODE
|
|
||||||
if cn_models_selected := self.controlnet_models.get('models_selected'):
|
|
||||||
selections.install_cn_models = [cn_models_selected.values[x]
|
|
||||||
for x in cn_models_selected.value
|
|
||||||
if not self.installed_cn_models[cn_models_selected.values[x]]
|
|
||||||
]
|
|
||||||
selections.remove_cn_models = [x
|
|
||||||
for x in cn_models_selected.values
|
|
||||||
if self.installed_cn_models[x]
|
|
||||||
and cn_models_selected.values.index(x) not in cn_models_selected.value
|
|
||||||
]
|
|
||||||
if (additional_cns := self.controlnet_models['download_ids'].value.split()):
|
|
||||||
valid_cns = [x for x in additional_cns if '/' in x]
|
|
||||||
selections.install_cn_models.extend(valid_cns)
|
|
||||||
|
|
||||||
# same thing, for LoRAs
|
|
||||||
if loras_selected := self.lora_models.get('models_selected'):
|
|
||||||
selections.install_lora_models = [loras_selected.values[x]
|
|
||||||
for x in loras_selected.value
|
|
||||||
if not self.installed_lora_models[loras_selected.values[x]]
|
|
||||||
]
|
|
||||||
selections.remove_lora_models = [x
|
|
||||||
for x in loras_selected.values
|
|
||||||
if self.installed_lora_models[x]
|
|
||||||
and loras_selected.values.index(x) not in loras_selected.value
|
|
||||||
]
|
|
||||||
if (additional_loras := self.lora_models['download_ids'].value.split()):
|
|
||||||
selections.install_lora_models.extend(additional_loras)
|
|
||||||
|
|
||||||
# same thing, for TIs
|
|
||||||
# TODO: refactor
|
|
||||||
if tis_selected := self.ti_models.get('models_selected'):
|
|
||||||
selections.install_ti_models = [tis_selected.values[x]
|
|
||||||
for x in tis_selected.value
|
|
||||||
if not self.installed_ti_models[tis_selected.values[x]]
|
|
||||||
]
|
|
||||||
selections.remove_ti_models = [x
|
|
||||||
for x in tis_selected.values
|
|
||||||
if self.installed_ti_models[x]
|
|
||||||
and tis_selected.values.index(x) not in tis_selected.value
|
|
||||||
]
|
|
||||||
|
|
||||||
if (additional_tis := self.ti_models['download_ids'].value.split()):
|
|
||||||
selections.install_ti_models.extend(additional_tis)
|
|
||||||
|
|
||||||
# load directory and whether to scan on startup
|
# load directory and whether to scan on startup
|
||||||
selections.scan_directory = self.diffusers_models['autoload_directory'].value
|
# if self.parentApp.autoload_pending:
|
||||||
selections.autoscan_on_startup = self.diffusers_models['autoscan_on_startup'].value
|
# selections.scan_directory = str(config.root_path / self.pipeline_models['autoload_directory'].value)
|
||||||
|
# self.parentApp.autoload_pending = False
|
||||||
|
# selections.autoscan_on_startup = self.pipeline_models['autoscan_on_startup'].value
|
||||||
|
|
||||||
class AddModelApplication(npyscreen.NPSAppManaged):
|
class AddModelApplication(npyscreen.NPSAppManaged):
|
||||||
def __init__(self,opt):
|
def __init__(self,opt):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.program_opts = opt
|
self.program_opts = opt
|
||||||
self.user_cancelled = False
|
self.user_cancelled = False
|
||||||
self.user_selections = UserSelections()
|
# self.autoload_pending = True
|
||||||
|
self.install_selections = InstallSelections()
|
||||||
|
|
||||||
def onStart(self):
|
def onStart(self):
|
||||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||||
@ -687,26 +551,22 @@ class StderrToMessage():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
def ask_user_for_config_file(model_path: Path,
|
def ask_user_for_prediction_type(model_path: Path,
|
||||||
tui_conn: Connection=None
|
tui_conn: Connection=None
|
||||||
)->Path:
|
)->SchedulerPredictionType:
|
||||||
if tui_conn:
|
if tui_conn:
|
||||||
logger.debug('Waiting for user response...')
|
logger.debug('Waiting for user response...')
|
||||||
return _ask_user_for_cf_tui(model_path, tui_conn)
|
return _ask_user_for_pt_tui(model_path, tui_conn)
|
||||||
else:
|
else:
|
||||||
return _ask_user_for_cf_cmdline(model_path)
|
return _ask_user_for_pt_cmdline(model_path)
|
||||||
|
|
||||||
def _ask_user_for_cf_cmdline(model_path):
|
def _ask_user_for_pt_cmdline(model_path: Path)->SchedulerPredictionType:
|
||||||
choices = [
|
choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
|
||||||
config.legacy_conf_path / x
|
|
||||||
for x in ['v2-inference.yaml','v2-inference-v.yaml']
|
|
||||||
]
|
|
||||||
choices.extend([None])
|
|
||||||
print(
|
print(
|
||||||
f"""
|
f"""
|
||||||
Please select the type of the V2 checkpoint named {model_path.name}:
|
Please select the type of the V2 checkpoint named {model_path.name}:
|
||||||
[1] A Stable Diffusion v2.x base model (512 pixels; there should be no 'parameterization:' line in its yaml file)
|
[1] A model based on Stable Diffusion v2 trained on 512 pixel images (SD-2-base)
|
||||||
[2] A Stable Diffusion v2.x v-predictive model (768 pixels; look for a 'parameterization: "v"' line in its yaml file)
|
[2] A model based on Stable Diffusion v2 trained on 768 pixel images (SD-2-768)
|
||||||
[3] Skip this model and come back later.
|
[3] Skip this model and come back later.
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
@ -723,7 +583,7 @@ Please select the type of the V2 checkpoint named {model_path.name}:
|
|||||||
return
|
return
|
||||||
return choice
|
return choice
|
||||||
|
|
||||||
def _ask_user_for_cf_tui(model_path: Path, tui_conn: Connection)->Path:
|
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->SchedulerPredictionType:
|
||||||
try:
|
try:
|
||||||
tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
|
tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
|
||||||
# note that we don't do any status checking here
|
# note that we don't do any status checking here
|
||||||
@ -731,20 +591,20 @@ def _ask_user_for_cf_tui(model_path: Path, tui_conn: Connection)->Path:
|
|||||||
if response is None:
|
if response is None:
|
||||||
return None
|
return None
|
||||||
elif response == 'epsilon':
|
elif response == 'epsilon':
|
||||||
return config.legacy_conf_path / 'v2-inference.yaml'
|
return SchedulerPredictionType.epsilon
|
||||||
elif response == 'v':
|
elif response == 'v':
|
||||||
return config.legacy_conf_path / 'v2-inference-v.yaml'
|
return SchedulerPredictionType.VPrediction
|
||||||
elif response == 'abort':
|
elif response == 'abort':
|
||||||
logger.info('Conversion aborted')
|
logger.info('Conversion aborted')
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return Path(response)
|
return response
|
||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
def process_and_execute(opt: Namespace,
|
def process_and_execute(opt: Namespace,
|
||||||
selections: UserSelections,
|
selections: InstallSelections,
|
||||||
conn_out: Connection=None,
|
conn_out: Connection=None,
|
||||||
):
|
):
|
||||||
# set up so that stderr is sent to conn_out
|
# set up so that stderr is sent to conn_out
|
||||||
@ -755,34 +615,14 @@ def process_and_execute(opt: Namespace,
|
|||||||
logger = InvokeAILogger.getLogger()
|
logger = InvokeAILogger.getLogger()
|
||||||
logger.handlers.clear()
|
logger.handlers.clear()
|
||||||
logger.addHandler(logging.StreamHandler(translator))
|
logger.addHandler(logging.StreamHandler(translator))
|
||||||
|
|
||||||
models_to_install = selections.install_models
|
|
||||||
models_to_remove = selections.remove_models
|
|
||||||
directory_to_scan = selections.scan_directory
|
|
||||||
scan_at_startup = selections.autoscan_on_startup
|
|
||||||
potential_models_to_install = selections.import_model_paths
|
|
||||||
|
|
||||||
install_requested_models(
|
installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x,conn_out))
|
||||||
diffusers = ModelInstallList(models_to_install, models_to_remove),
|
installer.install(selections)
|
||||||
controlnet = ModelInstallList(selections.install_cn_models, selections.remove_cn_models),
|
|
||||||
lora = ModelInstallList(selections.install_lora_models, selections.remove_lora_models),
|
|
||||||
ti = ModelInstallList(selections.install_ti_models, selections.remove_ti_models),
|
|
||||||
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
|
|
||||||
external_models=potential_models_to_install,
|
|
||||||
scan_at_startup=scan_at_startup,
|
|
||||||
precision="float32"
|
|
||||||
if opt.full_precision
|
|
||||||
else choose_precision(torch.device(choose_torch_device())),
|
|
||||||
purge_deleted=selections.purge_deleted_models,
|
|
||||||
config_file_path=Path(opt.config_file) if opt.config_file else config.model_conf_path,
|
|
||||||
model_config_file_callback = lambda x: ask_user_for_config_file(x,conn_out)
|
|
||||||
)
|
|
||||||
|
|
||||||
if conn_out:
|
if conn_out:
|
||||||
conn_out.send_bytes('*done*'.encode('utf-8'))
|
conn_out.send_bytes('*done*'.encode('utf-8'))
|
||||||
conn_out.close()
|
conn_out.close()
|
||||||
|
|
||||||
|
|
||||||
def do_listings(opt)->bool:
|
def do_listings(opt)->bool:
|
||||||
"""List installed models of various sorts, and return
|
"""List installed models of various sorts, and return
|
||||||
True if any were requested."""
|
True if any were requested."""
|
||||||
@ -813,39 +653,34 @@ def select_and_download_models(opt: Namespace):
|
|||||||
if opt.full_precision
|
if opt.full_precision
|
||||||
else choose_precision(torch.device(choose_torch_device()))
|
else choose_precision(torch.device(choose_torch_device()))
|
||||||
)
|
)
|
||||||
|
config.precision = precision
|
||||||
if do_listings(opt):
|
helper = lambda x: ask_user_for_prediction_type(x)
|
||||||
pass
|
# if do_listings(opt):
|
||||||
# this processes command line additions/removals
|
# pass
|
||||||
elif opt.diffusers or opt.controlnets or opt.textual_inversions or opt.loras:
|
|
||||||
action = 'remove_models' if opt.delete else 'install_models'
|
installer = ModelInstall(config, prediction_type_helper=helper)
|
||||||
diffusers_args = {'diffusers':ModelInstallList(remove_models=opt.diffusers or [])} \
|
if opt.add or opt.delete:
|
||||||
if opt.delete \
|
selections = InstallSelections(
|
||||||
else {'external_models':opt.diffusers or []}
|
install_models = opt.add or [],
|
||||||
install_requested_models(
|
remove_models = opt.delete or []
|
||||||
**diffusers_args,
|
|
||||||
controlnet=ModelInstallList(**{action:opt.controlnets or []}),
|
|
||||||
ti=ModelInstallList(**{action:opt.textual_inversions or []}),
|
|
||||||
lora=ModelInstallList(**{action:opt.loras or []}),
|
|
||||||
precision=precision,
|
|
||||||
purge_deleted=True,
|
|
||||||
model_config_file_callback=lambda x: ask_user_for_config_file(x),
|
|
||||||
)
|
)
|
||||||
|
installer.install(selections)
|
||||||
elif opt.default_only:
|
elif opt.default_only:
|
||||||
install_requested_models(
|
selections = InstallSelections(
|
||||||
diffusers=ModelInstallList(install_models=default_dataset()),
|
install_models = installer.default_model()
|
||||||
precision=precision,
|
|
||||||
)
|
)
|
||||||
|
installer.install(selections)
|
||||||
elif opt.yes_to_all:
|
elif opt.yes_to_all:
|
||||||
install_requested_models(
|
selections = InstallSelections(
|
||||||
diffusers=ModelInstallList(install_models=recommended_datasets()),
|
install_models = installer.recommended_models()
|
||||||
precision=precision,
|
|
||||||
)
|
)
|
||||||
|
installer.install(selections)
|
||||||
|
|
||||||
# this is where the TUI is called
|
# this is where the TUI is called
|
||||||
else:
|
else:
|
||||||
# needed because the torch library is loaded, even though we don't use it
|
# needed because the torch library is loaded, even though we don't use it
|
||||||
torch.multiprocessing.set_start_method("spawn")
|
# currently commented out because it has started generating errors (?)
|
||||||
|
# torch.multiprocessing.set_start_method("spawn")
|
||||||
|
|
||||||
# the third argument is needed in the Windows 11 environment in
|
# the third argument is needed in the Windows 11 environment in
|
||||||
# order to launch and resize a console window running this program
|
# order to launch and resize a console window running this program
|
||||||
@ -861,35 +696,20 @@ def select_and_download_models(opt: Namespace):
|
|||||||
installApp.main_form.subprocess.terminate()
|
installApp.main_form.subprocess.terminate()
|
||||||
installApp.main_form.subprocess = None
|
installApp.main_form.subprocess = None
|
||||||
raise e
|
raise e
|
||||||
process_and_execute(opt, installApp.user_selections)
|
process_and_execute(opt, installApp.install_selections)
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--diffusers",
|
"--add",
|
||||||
nargs="*",
|
nargs="*",
|
||||||
help="List of URLs or repo_ids of diffusers to install/delete",
|
help="List of URLs, local paths or repo_ids of models to install",
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--loras",
|
|
||||||
nargs="*",
|
|
||||||
help="List of URLs or repo_ids of LoRA/LyCORIS models to install/delete",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--controlnets",
|
|
||||||
nargs="*",
|
|
||||||
help="List of URLs or repo_ids of controlnet models to install/delete",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--textual-inversions",
|
|
||||||
nargs="*",
|
|
||||||
help="List of URLs or repo_ids of textual inversion embeddings to install/delete",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--delete",
|
"--delete",
|
||||||
action="store_true",
|
nargs="*",
|
||||||
help="Delete models listed on command line rather than installing them",
|
help="List of names of models to idelete",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--full-precision",
|
"--full-precision",
|
||||||
@ -909,7 +729,7 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--default_only",
|
"--default_only",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="only install the default model",
|
help="Only install the default model",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--list-models",
|
"--list-models",
|
||||||
|
@ -17,8 +17,8 @@ from shutil import get_terminal_size
|
|||||||
from curses import BUTTON2_CLICKED,BUTTON3_CLICKED
|
from curses import BUTTON2_CLICKED,BUTTON3_CLICKED
|
||||||
|
|
||||||
# minimum size for UIs
|
# minimum size for UIs
|
||||||
MIN_COLS = 120
|
MIN_COLS = 130
|
||||||
MIN_LINES = 50
|
MIN_LINES = 40
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def set_terminal_size(columns: int, lines: int, launch_command: str=None):
|
def set_terminal_size(columns: int, lines: int, launch_command: str=None):
|
||||||
@ -73,6 +73,12 @@ def _set_terminal_size_unix(width: int, height: int):
|
|||||||
import fcntl
|
import fcntl
|
||||||
import termios
|
import termios
|
||||||
|
|
||||||
|
# These terminals accept the size command and report that the
|
||||||
|
# size changed, but they lie!!!
|
||||||
|
for bad_terminal in ['TERMINATOR_UUID', 'ALACRITTY_WINDOW_ID']:
|
||||||
|
if os.environ.get(bad_terminal):
|
||||||
|
return
|
||||||
|
|
||||||
winsize = struct.pack("HHHH", height, width, 0, 0)
|
winsize = struct.pack("HHHH", height, width, 0, 0)
|
||||||
fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, winsize)
|
fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, winsize)
|
||||||
sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width))
|
sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width))
|
||||||
@ -87,6 +93,12 @@ def set_min_terminal_size(min_cols: int, min_lines: int, launch_command: str=Non
|
|||||||
lines = max(term_lines, min_lines)
|
lines = max(term_lines, min_lines)
|
||||||
set_terminal_size(cols, lines, launch_command)
|
set_terminal_size(cols, lines, launch_command)
|
||||||
|
|
||||||
|
# did it work?
|
||||||
|
term_cols, term_lines = get_terminal_size()
|
||||||
|
if term_cols < cols or term_lines < lines:
|
||||||
|
print(f'This window is too small for optimal display. For best results please enlarge it.')
|
||||||
|
input('After resizing, press any key to continue...')
|
||||||
|
|
||||||
class IntSlider(npyscreen.Slider):
|
class IntSlider(npyscreen.Slider):
|
||||||
def translate_value(self):
|
def translate_value(self):
|
||||||
stri = "%2d / %2d" % (self.value, self.out_of)
|
stri = "%2d / %2d" % (self.value, self.out_of)
|
||||||
@ -390,13 +402,12 @@ def select_stable_diffusion_config_file(
|
|||||||
wrap:bool =True,
|
wrap:bool =True,
|
||||||
model_name:str='Unknown',
|
model_name:str='Unknown',
|
||||||
):
|
):
|
||||||
message = "Please select the correct base model for the V2 checkpoint named {model_name}. Press <CANCEL> to skip installation."
|
message = f"Please select the correct base model for the V2 checkpoint named '{model_name}'. Press <CANCEL> to skip installation."
|
||||||
title = "CONFIG FILE SELECTION"
|
title = "CONFIG FILE SELECTION"
|
||||||
options=[
|
options=[
|
||||||
"An SD v2.x base model (512 pixels; no 'parameterization:' line in its yaml file)",
|
"An SD v2.x base model (512 pixels; no 'parameterization:' line in its yaml file)",
|
||||||
"An SD v2.x v-predictive model (768 pixels; 'parameterization: \"v\"' line in its yaml file)",
|
"An SD v2.x v-predictive model (768 pixels; 'parameterization: \"v\"' line in its yaml file)",
|
||||||
"Skip installation for now and come back later",
|
"Skip installation for now and come back later",
|
||||||
"Enter config file path manually",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
F = ConfirmCancelPopup(
|
F = ConfirmCancelPopup(
|
||||||
@ -418,35 +429,17 @@ def select_stable_diffusion_config_file(
|
|||||||
mlw.values = message
|
mlw.values = message
|
||||||
|
|
||||||
choice = F.add(
|
choice = F.add(
|
||||||
SingleSelectWithChanged,
|
npyscreen.SelectOne,
|
||||||
values = options,
|
values = options,
|
||||||
value = [0],
|
value = [0],
|
||||||
max_height = len(options)+1,
|
max_height = len(options)+1,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
file = F.add(
|
|
||||||
FileBox,
|
|
||||||
name='Path to config file',
|
|
||||||
max_height=3,
|
|
||||||
hidden=True,
|
|
||||||
must_exist=True,
|
|
||||||
scroll_exit=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def toggle_visible(value):
|
|
||||||
value = value[0]
|
|
||||||
if value==3:
|
|
||||||
file.hidden=False
|
|
||||||
else:
|
|
||||||
file.hidden=True
|
|
||||||
F.display()
|
|
||||||
|
|
||||||
choice.on_changed = toggle_visible
|
|
||||||
|
|
||||||
F.editw = 1
|
F.editw = 1
|
||||||
F.edit()
|
F.edit()
|
||||||
if not F.value:
|
if not F.value:
|
||||||
return None
|
return None
|
||||||
assert choice.value[0] in range(0,4),'invalid choice'
|
assert choice.value[0] in range(0,3),'invalid choice'
|
||||||
choices = ['epsilon','v','abort',file.value]
|
choices = ['epsilon','v','abort']
|
||||||
return choices[choice.value[0]]
|
return choices[choice.value[0]]
|
||||||
|
@ -48,7 +48,7 @@ const App = ({
|
|||||||
const isApplicationReady = useIsApplicationReady();
|
const isApplicationReady = useIsApplicationReady();
|
||||||
|
|
||||||
const { data: pipelineModels } = useListModelsQuery({
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
model_type: 'pipeline',
|
model_type: 'main',
|
||||||
});
|
});
|
||||||
const { data: controlnetModels } = useListModelsQuery({
|
const { data: controlnetModels } = useListModelsQuery({
|
||||||
model_type: 'controlnet',
|
model_type: 'controlnet',
|
||||||
|
@ -23,7 +23,7 @@ const ModelInputFieldComponent = (
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { data: pipelineModels } = useListModelsQuery({
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
model_type: 'pipeline',
|
model_type: 'main',
|
||||||
});
|
});
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
|
@ -24,7 +24,7 @@ const ModelSelect = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { data: pipelineModels } = useListModelsQuery({
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
model_type: 'pipeline',
|
model_type: 'main',
|
||||||
});
|
});
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
|
@ -120,6 +120,7 @@ dependencies = [
|
|||||||
"invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers"
|
"invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers"
|
||||||
"invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion"
|
"invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion"
|
||||||
"invokeai-model-install" = "invokeai.frontend.install:invokeai_model_install"
|
"invokeai-model-install" = "invokeai.frontend.install:invokeai_model_install"
|
||||||
|
"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
|
||||||
"invokeai-update" = "invokeai.frontend.install:invokeai_update"
|
"invokeai-update" = "invokeai.frontend.install:invokeai_update"
|
||||||
"invokeai-metadata" = "invokeai.frontend.CLI.sd_metadata:print_metadata"
|
"invokeai-metadata" = "invokeai.frontend.CLI.sd_metadata:print_metadata"
|
||||||
"invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli"
|
"invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli"
|
||||||
|
4
scripts/invokeai-migrate3
Normal file
4
scripts/invokeai-migrate3
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from invokeai.backend.install.migrate_to_3 import main
|
||||||
|
|
||||||
|
if __name__=='__main__':
|
||||||
|
main()
|
3
scripts/invokeai-model-install.py
Normal file
3
scripts/invokeai-model-install.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from invokeai.frontend.install.model_install import main
|
||||||
|
main()
|
||||||
|
|
@ -1,278 +0,0 @@
|
|||||||
'''
|
|
||||||
Migrate the models directory and models.yaml file from an existing
|
|
||||||
InvokeAI 2.3 installation to 3.0.0.
|
|
||||||
'''
|
|
||||||
|
|
||||||
import io
|
|
||||||
import os
|
|
||||||
import argparse
|
|
||||||
import shutil
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
import transformers
|
|
||||||
import diffusers
|
|
||||||
import warnings
|
|
||||||
from pathlib import Path
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from diffusers import StableDiffusionPipeline, AutoencoderKL
|
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
|
||||||
from transformers import (
|
|
||||||
CLIPTextModel,
|
|
||||||
CLIPTokenizer,
|
|
||||||
AutoFeatureExtractor,
|
|
||||||
BertTokenizerFast,
|
|
||||||
)
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.model_management.model_probe import (
|
|
||||||
ModelProbe, ModelType, BaseModelType
|
|
||||||
)
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
|
||||||
transformers.logging.set_verbosity_error()
|
|
||||||
diffusers.logging.set_verbosity_error()
|
|
||||||
|
|
||||||
def create_directory_structure(dest: Path):
|
|
||||||
for model_base in [BaseModelType.StableDiffusion1,BaseModelType.StableDiffusion2]:
|
|
||||||
for model_type in [ModelType.Pipeline, ModelType.Vae, ModelType.Lora,
|
|
||||||
ModelType.ControlNet,ModelType.TextualInversion]:
|
|
||||||
path = dest / model_base.value / model_type.value
|
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
|
||||||
path = dest / 'core'
|
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
def copy_file(src:Path,dest:Path):
|
|
||||||
logger.info(f'Copying {str(src)} to {str(dest)}')
|
|
||||||
try:
|
|
||||||
shutil.copy(src, dest)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f'COPY FAILED: {str(e)}')
|
|
||||||
|
|
||||||
def copy_dir(src:Path,dest:Path):
|
|
||||||
logger.info(f'Copying {str(src)} to {str(dest)}')
|
|
||||||
try:
|
|
||||||
shutil.copytree(src, dest)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f'COPY FAILED: {str(e)}')
|
|
||||||
|
|
||||||
def migrate_models(src_dir: Path, dest_dir: Path):
|
|
||||||
for root, dirs, files in os.walk(src_dir):
|
|
||||||
for f in files:
|
|
||||||
# hack - don't copy raw learned_embeds.bin, let them
|
|
||||||
# be copied as part of a tree copy operation
|
|
||||||
if f == 'learned_embeds.bin':
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
model = Path(root,f)
|
|
||||||
info = ModelProbe().heuristic_probe(model)
|
|
||||||
if not info:
|
|
||||||
continue
|
|
||||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value, f)
|
|
||||||
copy_file(model, dest)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
for d in dirs:
|
|
||||||
try:
|
|
||||||
model = Path(root,d)
|
|
||||||
info = ModelProbe().heuristic_probe(model)
|
|
||||||
if not info:
|
|
||||||
continue
|
|
||||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value, model.name)
|
|
||||||
copy_dir(model, dest)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
|
|
||||||
def migrate_support_models(dest_directory: Path):
|
|
||||||
if Path('./models/clipseg').exists():
|
|
||||||
copy_dir(Path('./models/clipseg'),dest_directory / 'core/misc/clipseg')
|
|
||||||
if Path('./models/realesrgan').exists():
|
|
||||||
copy_dir(Path('./models/realesrgan'),dest_directory / 'core/upscaling/realesrgan')
|
|
||||||
for d in ['codeformer','gfpgan']:
|
|
||||||
path = Path('./models',d)
|
|
||||||
if path.exists():
|
|
||||||
copy_dir(path,dest_directory / f'core/face_restoration/{d}')
|
|
||||||
|
|
||||||
def migrate_conversion_models(dest_directory: Path):
|
|
||||||
# These are needed for the conversion script
|
|
||||||
kwargs = dict(
|
|
||||||
cache_dir = Path('./models/hub'),
|
|
||||||
#local_files_only = True
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
logger.info('Migrating core tokenizers and text encoders')
|
|
||||||
target_dir = dest_directory / 'core' / 'convert'
|
|
||||||
|
|
||||||
# bert
|
|
||||||
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
|
|
||||||
bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
|
|
||||||
|
|
||||||
# sd-1
|
|
||||||
repo_id = 'openai/clip-vit-large-patch14'
|
|
||||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
|
|
||||||
pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14', safe_serialization=True)
|
|
||||||
|
|
||||||
pipeline = CLIPTextModel.from_pretrained(repo_id, **kwargs)
|
|
||||||
pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14', safe_serialization=True)
|
|
||||||
|
|
||||||
# sd-2
|
|
||||||
repo_id = "stabilityai/stable-diffusion-2"
|
|
||||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
|
|
||||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
|
|
||||||
|
|
||||||
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
|
|
||||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
|
|
||||||
|
|
||||||
# VAE
|
|
||||||
logger.info('Migrating stable diffusion VAE')
|
|
||||||
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
|
|
||||||
vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
|
|
||||||
|
|
||||||
# safety checking
|
|
||||||
logger.info('Migrating safety checker')
|
|
||||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
|
||||||
pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
|
|
||||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
|
|
||||||
|
|
||||||
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
|
|
||||||
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
|
|
||||||
def migrate_tuning_models(dest: Path):
|
|
||||||
for subdir in ['embeddings','loras','controlnets']:
|
|
||||||
src = Path('.',subdir)
|
|
||||||
if not src.is_dir():
|
|
||||||
logger.info(f'{subdir} directory not found; skipping')
|
|
||||||
continue
|
|
||||||
logger.info(f'Scanning {subdir}')
|
|
||||||
migrate_models(src, dest)
|
|
||||||
|
|
||||||
def migrate_pipelines(dest_dir: Path, dest_yaml: io.TextIOBase):
|
|
||||||
cache = Path('./models/hub')
|
|
||||||
kwargs = dict(
|
|
||||||
cache_dir = cache,
|
|
||||||
local_files_only = True,
|
|
||||||
safety_checker = None,
|
|
||||||
)
|
|
||||||
for model in cache.glob('models--*'):
|
|
||||||
if len(list(model.glob('snapshots/**/model_index.json')))==0:
|
|
||||||
continue
|
|
||||||
_,owner,repo_name=model.name.split('--')
|
|
||||||
repo_id = f'{owner}/{repo_name}'
|
|
||||||
revisions = [x.name for x in model.glob('refs/*')]
|
|
||||||
for revision in revisions:
|
|
||||||
logger.info(f'Migrating {repo_id}, revision {revision}')
|
|
||||||
try:
|
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
|
||||||
repo_id,
|
|
||||||
revision=revision,
|
|
||||||
**kwargs)
|
|
||||||
info = ModelProbe().heuristic_probe(pipeline)
|
|
||||||
if not info:
|
|
||||||
continue
|
|
||||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value, f'{repo_name}-{revision}')
|
|
||||||
pipeline.save_pretrained(dest, safe_serialization=True)
|
|
||||||
rel_path = Path('models',dest.relative_to(dest_dir))
|
|
||||||
stanza = {
|
|
||||||
f'{info.base_type.value}/{info.model_type.value}/{repo_name}-{revision}':
|
|
||||||
{
|
|
||||||
'name': repo_name,
|
|
||||||
'path': str(rel_path),
|
|
||||||
'description': f'diffusers model {repo_id}',
|
|
||||||
'format': 'diffusers',
|
|
||||||
'image_size': info.image_size,
|
|
||||||
'base': info.base_type.value,
|
|
||||||
'variant': info.variant_type.value,
|
|
||||||
'prediction_type': info.prediction_type.value,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
print(yaml.dump(stanza),file=dest_yaml,end="")
|
|
||||||
dest_yaml.flush()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f'Could not load the "{revision}" version of {repo_id}. Skipping.')
|
|
||||||
|
|
||||||
def migrate_checkpoints(dest_dir: Path, dest_yaml: io.TextIOBase):
|
|
||||||
# find any checkpoints referred to in old models.yaml
|
|
||||||
conf = OmegaConf.load('./configs/models.yaml')
|
|
||||||
orig_models_dir = Path.cwd() / 'models'
|
|
||||||
for model_name, stanza in conf.items():
|
|
||||||
if stanza.get('format') and stanza['format'] == 'ckpt':
|
|
||||||
try:
|
|
||||||
logger.info(f'Migrating checkpoint model {model_name}')
|
|
||||||
weights = orig_models_dir.parent / stanza['weights']
|
|
||||||
config = stanza['config']
|
|
||||||
info = ModelProbe().heuristic_probe(weights)
|
|
||||||
if not info:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# uh oh, weights is in the old models directory - move it into the new one
|
|
||||||
if Path(weights).is_relative_to(orig_models_dir):
|
|
||||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value,weights.name)
|
|
||||||
copy_file(weights,dest)
|
|
||||||
weights = Path('models', info.base_type.value, info.model_type.value,weights.name)
|
|
||||||
stanza = {
|
|
||||||
f'{info.base_type.value}/{info.model_type.value}/{model_name}':
|
|
||||||
{
|
|
||||||
'name': model_name,
|
|
||||||
'path': str(weights),
|
|
||||||
'description': f'checkpoint model {model_name}',
|
|
||||||
'format': 'checkpoint',
|
|
||||||
'image_size': info.image_size,
|
|
||||||
'base': info.base_type.value,
|
|
||||||
'variant': info.variant_type.value,
|
|
||||||
'config': config
|
|
||||||
}
|
|
||||||
}
|
|
||||||
print(yaml.dump(stanza),file=dest_yaml,end="")
|
|
||||||
dest_yaml.flush()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Model directory migrator")
|
|
||||||
parser.add_argument('root_directory',
|
|
||||||
help='Root directory (containing "models", "embeddings", "controlnets" and "loras")'
|
|
||||||
)
|
|
||||||
parser.add_argument('--dest-directory',
|
|
||||||
default='./models-3.0',
|
|
||||||
help='Destination for new models directory',
|
|
||||||
)
|
|
||||||
parser.add_argument('--dest-yaml',
|
|
||||||
default='./models.yaml-3.0',
|
|
||||||
help='Destination for new models.yaml file',
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
root_directory = Path(args.root_directory)
|
|
||||||
assert root_directory.is_dir(), f"{root_directory} is not a valid directory"
|
|
||||||
assert (root_directory / 'models').is_dir(), f"{root_directory} does not contain a 'models' subdirectory"
|
|
||||||
|
|
||||||
dest_directory = Path(args.dest_directory).resolve()
|
|
||||||
dest_yaml = Path(args.dest_yaml).resolve()
|
|
||||||
|
|
||||||
os.chdir(root_directory)
|
|
||||||
with open(dest_yaml,'w') as yaml_file:
|
|
||||||
print(yaml.dump({'__metadata__':
|
|
||||||
{'version':'3.0.0'}
|
|
||||||
}
|
|
||||||
),file=yaml_file,end=""
|
|
||||||
)
|
|
||||||
create_directory_structure(dest_directory)
|
|
||||||
migrate_support_models(dest_directory)
|
|
||||||
migrate_conversion_models(dest_directory)
|
|
||||||
migrate_tuning_models(dest_directory)
|
|
||||||
migrate_pipelines(dest_directory,yaml_file)
|
|
||||||
migrate_checkpoints(dest_directory,yaml_file)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user