fix incorrect variable/typenames in model_cache

This commit is contained in:
Lincoln Stein 2023-06-10 10:41:48 -04:00
parent 3d2ff7755e
commit 74b43c9bdf
3 changed files with 33 additions and 436 deletions

View File

@ -20,27 +20,19 @@ import gc
import os import os
import sys import sys
import hashlib import hashlib
import json
import warnings import warnings
from contextlib import suppress from contextlib import suppress
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Dict, Union, types, Optional, List, Type, Any from typing import Dict, Union, types, Optional, Type, Any
import torch import torch
import transformers
from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
from diffusers import logging as diffusers_logging from diffusers import logging as diffusers_logging
from huggingface_hub import HfApi, scan_cache_dir
from transformers import logging as transformers_logging from transformers import logging as transformers_logging
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from .lora import LoRAModel, TextualInversionModel from .model_manager import SDModelInfo, ModelType, SubModelType, ModelBase
from .models import MODEL_CLASSES
# Maximum size of the cache, in gigs # Maximum size of the cache, in gigs
@ -113,7 +105,7 @@ class ModelCache(object):
#max_cache_size = 9999 #max_cache_size = 9999
execution_device = torch.device('cuda') execution_device = torch.device('cuda')
self.model_infos: Dict[str, ModelInfoBase] = dict() self.model_infos: Dict[str, SDModelInfo] = dict()
self.lazy_offloading = lazy_offloading self.lazy_offloading = lazy_offloading
#self.sequential_offload: bool=sequential_offload #self.sequential_offload: bool=sequential_offload
self.precision: torch.dtype=precision self.precision: torch.dtype=precision
@ -129,8 +121,8 @@ class ModelCache(object):
def get_key( def get_key(
self, self,
model_path: str, model_path: str,
model_type: SDModelType, model_type: ModelType,
submodel_type: Optional[SDModelType] = None, submodel_type: Optional[ModelType] = None,
): ):
key = f"{model_path}:{model_type}" key = f"{model_path}:{model_type}"
@ -141,11 +133,11 @@ class ModelCache(object):
#def get_model( #def get_model(
# self, # self,
# repo_id_or_path: Union[str, Path], # repo_id_or_path: Union[str, Path],
# model_type: SDModelType = SDModelType.Diffusers, # model_type: ModelType = ModelType.Diffusers,
# subfolder: Path = None, # subfolder: Path = None,
# submodel: SDModelType = None, # submodel: ModelType = None,
# revision: str = None, # revision: str = None,
# attach_model_part: Tuple[SDModelType, str] = (None, None), # attach_model_part: Tuple[ModelType, str] = (None, None),
# gpu_load: bool = True, # gpu_load: bool = True,
#) -> ModelLocker: # ?? what does it return #) -> ModelLocker: # ?? what does it return
def _get_model_info( def _get_model_info(
@ -155,14 +147,14 @@ class ModelCache(object):
): ):
model_info_key = self.get_key( model_info_key = self.get_key(
model_path=model_path, model_path=model_path,
model_type=model_type, model_type=model_class,
submodel_type=None, submodel_type=None,
) )
if model_info_key not in self.model_infos: if model_info_key not in self.model_infos:
self.model_infos[model_info_key] = model_class( self.model_infos[model_info_key] = model_class(
model_path, model_path,
model_type, model_class,
) )
return self.model_infos[model_info_key] return self.model_infos[model_info_key]
@ -188,14 +180,14 @@ class ModelCache(object):
) )
key = self.get_key( key = self.get_key(
model_path=model_path, model_path=model_path,
model_type=model_type, # TODO: model_type=model_class, # TODO:
submodel_type=submodel, submodel_type=submodel,
) )
# TODO: lock for no copies on simultaneous calls? # TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None) cache_entry = self._cached_models.get(key, None)
if cache_entry is None: if cache_entry is None:
self.logger.info(f'Loading model {model_path}, type {model_type}:{submodel}') self.logger.info(f'Loading model {model_path}, type {model_class}:{submodel}')
# this will remove older cached models until # this will remove older cached models until
# there is sufficient room to load the requested model # there is sufficient room to load the requested model
@ -203,7 +195,7 @@ class ModelCache(object):
# clean memory to make MemoryUsage() more accurate # clean memory to make MemoryUsage() more accurate
gc.collect() gc.collect()
model = model_info.get_model(submodel, torch_dtype=self.precision, variant=) model = model_info.get_model(submodel, torch_dtype=self.precision)
if mem_used := model_info.get_size(submodel): if mem_used := model_info.get_size(submodel):
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB') self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')

View File

@ -148,21 +148,14 @@ into the model when downloaded or converted.
from __future__ import annotations from __future__ import annotations
import os import os
import re
import textwrap import textwrap
import shutil
import traceback
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto
from packaging import version from packaging import version
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, Optional, List, Tuple, Union, types from typing import Dict, Optional, List, Tuple, Union, types
from shutil import rmtree from shutil import rmtree
import safetensors
import safetensors.torch
import torch import torch
from diffusers import AutoencoderKL
from huggingface_hub import scan_cache_dir from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
@ -170,8 +163,7 @@ from omegaconf.dictconfig import DictConfig
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, download_with_resume from invokeai.backend.util import CUDA_DEVICE, download_with_resume
from ..install.model_install_backend import Dataset_path, hf_download_with_resume from .model_cache import ModelCache, ModelLocker
from .model_cache import ModelCache, ModelLocker, SilenceWarnings
from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES
# We are only starting to number the config file with release 3. # We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help # The config file version doesn't have to start at release version, but it will help
@ -183,7 +175,7 @@ CONFIG_FILE_VERSION='3.0.0'
class SDModelInfo(): class SDModelInfo():
context: ModelLocker context: ModelLocker
name: str name: str
type: SDModelType type: ModelType
hash: str hash: str
location: Union[Path,str] location: Union[Path,str]
precision: torch.dtype precision: torch.dtype
@ -292,7 +284,7 @@ class ModelManager(object):
def parse_key(self, model_key: str) -> Tuple[str, BaseModelType, ModelType]: def parse_key(self, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
base_model_str, model_type_str, model_name = model_key.split('/', 2) base_model_str, model_type_str, model_name = model_key.split('/', 2)
try: try:
model_type = SDModelType(model_type_str) model_type = ModelType(model_type_str)
except: except:
raise Exception(f"Unknown model type: {model_type_str}") raise Exception(f"Unknown model type: {model_type_str}")
@ -313,9 +305,9 @@ class ModelManager(object):
"""Given a model named identified in models.yaml, return """Given a model named identified in models.yaml, return
an SDModelInfo object describing it. an SDModelInfo object describing it.
:param model_name: symbolic name of the model in models.yaml :param model_name: symbolic name of the model in models.yaml
:param model_type: SDModelType enum indicating the type of model to return :param model_type: ModelType enum indicating the type of model to return
:param submodel: an SDModelType enum indicating the portion of :param submode_typel: an ModelType enum indicating the portion of
the model to retrieve (e.g. SDModelType.Vae) the model to retrieve (e.g. ModelType.Vae)
If not provided, the model_type will be read from the `format` field If not provided, the model_type will be read from the `format` field
of the corresponding stanza. If provided, the model_type will be used of the corresponding stanza. If provided, the model_type will be used
@ -334,35 +326,23 @@ class ModelManager(object):
test1_pipeline = mgr.get_model('test1') test1_pipeline = mgr.get_model('test1')
# returns a StableDiffusionGeneratorPipeline # returns a StableDiffusionGeneratorPipeline
test1_vae1 = mgr.get_model('test1', submodel=SDModelType.Vae) test1_vae1 = mgr.get_model('test1', submodel=ModelType.Vae)
# returns the VAE part of a diffusers model as an AutoencoderKL # returns the VAE part of a diffusers model as an AutoencoderKL
test1_vae2 = mgr.get_model('test1', model_type=SDModelType.Diffusers, submodel=SDModelType.Vae) test1_vae2 = mgr.get_model('test1', model_type=ModelType.Diffusers, submodel=ModelType.Vae)
# does the same thing as the previous statement. Note that model_type # does the same thing as the previous statement. Note that model_type
# is for the parent model, and submodel is for the part # is for the parent model, and submodel is for the part
test1_lora = mgr.get_model('test1', model_type=SDModelType.Lora) test1_lora = mgr.get_model('test1', model_type=ModelType.Lora)
# returns a LoRA embed (as a 'dict' of tensors) # returns a LoRA embed (as a 'dict' of tensors)
test1_encoder = mgr.get_modelI('test1', model_type=SDModelType.TextEncoder) test1_encoder = mgr.get_modelI('test1', model_type=ModelType.TextEncoder)
# raises an InvalidModelError # raises an InvalidModelError
""" """
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
model_dir = self.globals.models_path
#if model_type in {
# ModelType.Lora,
# ModelType.ControlNet,
# ModelType.TextualInversion,
# ModelType.Vae,
#}:
if not model_class.has_config: if not model_class.has_config:
#if model_class.Config is None:
# skip config
# load from
# /models/{base_model}/{model_type}/{model_name}
# /models/{base_model}/{model_type}/{model_name}.{ext}
model_config = None model_config = None
for ext in {"pt", "ckpt", "safetensors"}: for ext in {"pt", "ckpt", "safetensors"}:
@ -385,17 +365,14 @@ class ModelManager(object):
) )
model_config = self.config[model_key] model_config = self.config[model_key]
# /models/{base_model}/{model_type}/{name}.ckpt or .safentesors
# /models/{base_model}/{model_type}/{name}/
model_path = model_config.path model_path = model_config.path
# vae/movq override # vae/movq override
# TODO: # TODO:
if submodel is not None and submodel in model_config: if submodel_type is not None and submodel_type in model_config:
model_path = model_config[submodel]["path"] model_path = model_config[submodel_type]["path"]
model_type = submodel model_type = submodel_type
submodel = None submodel_type = None
dst_convert_path = None # TODO: dst_convert_path = None # TODO:
model_path = model_class.convert_if_required( model_path = model_class.convert_if_required(
@ -407,7 +384,7 @@ class ModelManager(object):
model_context = self.cache.get_model( model_context = self.cache.get_model(
model_path, model_path,
model_class, model_class,
submodel, submodel_type,
) )
hash = "<NO_HASH>" # TODO: hash = "<NO_HASH>" # TODO:
@ -416,7 +393,7 @@ class ModelManager(object):
context = model_context, context = model_context,
name = model_name, name = model_name,
base_model = base_model, base_model = base_model,
type = submodel or model_type, type = submodel_type or model_type,
hash = hash, hash = hash,
location = model_path, # TODO: location = model_path, # TODO:
precision = self.cache.precision, precision = self.cache.precision,
@ -480,7 +457,7 @@ class ModelManager(object):
def list_models( def list_models(
self, self,
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,
model_type: Optional[SDModelType] = None, model_type: Optional[ModelType] = None,
) -> Dict[str, Dict[str, str]]: ) -> Dict[str, Dict[str, str]]:
""" """
Return a dict of models, in format [base_model][model_type][model_name] Return a dict of models, in format [base_model][model_type][model_name]
@ -540,7 +517,7 @@ class ModelManager(object):
def del_model( def del_model(
self, self,
model_name: str, model_name: str,
model_type: SDModelType.Diffusers, model_type: ModelType.Diffusers,
delete_files: bool = False, delete_files: bool = False,
): ):
""" """
@ -622,183 +599,6 @@ class ModelManager(object):
self.cache.uncache_model(self.cache_keys[model_key]) self.cache.uncache_model(self.cache_keys[model_key])
del self.cache_keys[model_key] del self.cache_keys[model_key]
# TODO: DELETE OR UPDATE - handled by scan_models_directory()
def import_diffuser_model(
self,
repo_or_path: Union[str, Path],
model_name: str = None,
description: str = None,
vae: dict = None,
commit_to_conf: Path = None,
) -> bool:
"""
Attempts to install the indicated diffuser model and returns True if successful.
"repo_or_path" can be either a repo-id or a path-like object corresponding to the
top of a downloaded diffusers directory.
You can optionally provide a model name and/or description. If not provided,
then these will be derived from the repo name. If you provide a commit_to_conf
path to the configuration file, then the new entry will be committed to the
models.yaml file.
"""
model_name = model_name or Path(repo_or_path).stem
model_description = description or f"Imported diffusers model {model_name}"
new_config = dict(
description=model_description,
vae=vae,
format="diffusers",
)
if isinstance(repo_or_path, Path) and repo_or_path.exists():
new_config.update(path=str(repo_or_path))
else:
new_config.update(repo_id=repo_or_path)
self.add_model(model_name, SDModelType.Diffusers, new_config, True)
if commit_to_conf:
self.commit(commit_to_conf)
return self.create_key(model_name, SDModelType.Diffusers)
# TODO: DELETE OR UPDATE - handled by scan_models_directory()
def import_lora(
self,
path: Path,
model_name: Optional[str] = None,
description: Optional[str] = None,
):
"""
Creates an entry for the indicated lora file. Call
mgr.commit() to write out the configuration to models.yaml
"""
path = Path(path)
model_name = model_name or path.stem
model_description = description or f"LoRA model {model_name}"
self.add_model(
model_name,
SDModelType.Lora,
dict(
format="lora",
weights=str(path),
description=model_description,
),
True
)
# TODO: DELETE OR UPDATE - handled by scan_models_directory()
def import_embedding(
self,
path: Path,
model_name: Optional[str] = None,
description: Optional[str] = None,
):
"""
Creates an entry for the indicated lora file. Call
mgr.commit() to write out the configuration to models.yaml
"""
path = Path(path)
if path.is_directory() and (path / "learned_embeds.bin").exists():
weights = path / "learned_embeds.bin"
else:
weights = path
model_name = model_name or path.stem
model_description = description or f"Textual embedding model {model_name}"
self.add_model(
model_name,
SDModelType.TextualInversion,
dict(
format="textual_inversion",
weights=str(weights),
description=model_description,
),
True
)
def convert_and_import(
self,
ckpt_path: Path,
diffusers_path: Path,
model_name=None,
model_description=None,
vae: dict = None,
vae_path: Path = None,
original_config_file: Path = None,
commit_to_conf: Path = None,
scan_needed: bool = True,
) -> str:
"""
Convert a legacy ckpt weights file to diffuser model and import
into models.yaml.
"""
ckpt_path = self._resolve_path(ckpt_path, "models/ldm/stable-diffusion-v1")
if original_config_file:
original_config_file = self._resolve_path(
original_config_file, "configs/stable-diffusion"
)
new_config = None
if diffusers_path.exists():
self.logger.error(
f"The path {str(diffusers_path)} already exists. Please move or remove it and try again."
)
return
model_name = model_name or diffusers_path.name
model_description = model_description or f"Converted version of {model_name}"
self.logger.debug(f"Converting {model_name} to diffusers (30-60s)")
# to avoid circular import errors
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
try:
# By passing the specified VAE to the conversion function, the autoencoder
# will be built into the model rather than tacked on afterward via the config file
vae_model = None
if vae:
vae_location = self.globals.root_dir / vae.get('path') \
if vae.get('path') \
else vae.get('repo_id')
vae_model = self.cache.get_model(vae_location, SDModelType.Vae).model
vae_path = None
convert_ckpt_to_diffusers(
ckpt_path,
diffusers_path,
extract_ema=True,
original_config_file=original_config_file,
vae=vae_model,
vae_path=vae_path,
scan_needed=scan_needed,
)
self.logger.debug(
f"Success. Converted model is now located at {str(diffusers_path)}"
)
self.logger.debug(f"Writing new config file entry for {model_name}")
new_config = dict(
path=str(diffusers_path),
description=model_description,
format="diffusers",
)
if self.model_exists(model_name, SDModelType.Diffusers):
self.del_model(model_name, SDModelType.Diffusers)
self.add_model(
model_name,
SDModelType.Diffusers,
new_config,
True
)
if commit_to_conf:
self.commit(commit_to_conf)
self.logger.debug(f"Model {model_name} installed")
except Exception as e:
self.logger.warning(f"Conversion failed: {str(e)}")
self.logger.warning(traceback.format_exc())
self.logger.warning(
"If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
)
return model_name
def search_models(self, search_folder): def search_models(self, search_folder):
self.logger.info(f"Finding Models In: {search_folder}") self.logger.info(f"Finding Models In: {search_folder}")
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt") models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
@ -1005,184 +805,3 @@ class ModelManager(object):
) )
##### NONE OF THE METHODS BELOW WORK NOW BECAUSE OF MODEL DIRECTORY REORGANIZATION
##### AND NEED TO BE REWRITTEN
def list_lora_models(self)->Dict[str,bool]:
'''Return a dict of installed lora models; key is either the shortname
defined in INITIAL_MODELS, or the basename of the file in the LoRA
directory. Value is True if installed'''
models = OmegaConf.load(Dataset_path).get('lora') or {}
installed_models = {x: False for x in models.keys()}
dir = self.globals.lora_path
installed_models = dict()
for root, dirs, files in os.walk(dir):
for name in files:
if Path(name).suffix not in ['.safetensors','.ckpt','.pt','.bin']:
continue
if name == 'pytorch_lora_weights.bin':
name = Path(root,name).parent.stem #Path(root,name).stem
else:
name = Path(name).stem
installed_models.update({name: True})
return installed_models
def install_lora_models(self, model_names: list[str], access_token:str=None):
'''Download list of LoRA/LyCORIS models'''
short_names = OmegaConf.load(Dataset_path).get('lora') or {}
for name in model_names:
name = short_names.get(name) or name
# HuggingFace style LoRA
if re.match(r"^[\w.+-]+/([\w.+-]+)$", name):
self.logger.info(f'Downloading LoRA/LyCORIS model {name}')
_,dest_dir = name.split("/")
hf_download_with_resume(
repo_id = name,
model_dir = self.globals.lora_path / dest_dir,
model_name = 'pytorch_lora_weights.bin',
access_token = access_token,
)
elif name.startswith(("http:", "https:", "ftp:")):
download_with_resume(name, self.globals.lora_path)
else:
self.logger.error(f"Unknown repo_id or URL: {name}")
def delete_lora_models(self, model_names: List[str]):
'''Remove the list of lora models'''
for name in model_names:
file_or_directory = self.globals.lora_path / name
if file_or_directory.is_dir():
self.logger.info(f'Purging LoRA/LyCORIS {name}')
shutil.rmtree(str(file_or_directory))
else:
for path in self.globals.lora_path.glob(f'{name}.*'):
self.logger.info(f'Purging LoRA/LyCORIS {name}')
path.unlink()
def list_ti_models(self)->Dict[str,bool]:
'''Return a dict of installed textual models; key is either the shortname
defined in INITIAL_MODELS, or the basename of the file in the LoRA
directory. Value is True if installed'''
models = OmegaConf.load(Dataset_path).get('textual_inversion') or {}
installed_models = {x: False for x in models.keys()}
dir = self.globals.embedding_path
for root, dirs, files in os.walk(dir):
for name in files:
if not Path(name).suffix in ['.bin','.pt','.ckpt','.safetensors']:
continue
if name == 'learned_embeds.bin':
name = Path(root,name).parent.stem #Path(root,name).stem
else:
name = Path(name).stem
installed_models.update({name: True})
return installed_models
def install_ti_models(self, model_names: list[str], access_token: str=None):
'''Download list of textual inversion embeddings'''
short_names = OmegaConf.load(Dataset_path).get('textual_inversion') or {}
for name in model_names:
name = short_names.get(name) or name
if re.match(r"^[\w.+-]+/([\w.+-]+)$", name):
self.logger.info(f'Downloading Textual Inversion embedding {name}')
_,dest_dir = name.split("/")
hf_download_with_resume(
repo_id = name,
model_dir = self.globals.embedding_path / dest_dir,
model_name = 'learned_embeds.bin',
access_token = access_token
)
elif name.startswith(('http:','https:','ftp:')):
download_with_resume(name, self.globals.embedding_path)
else:
self.logger.error(f'{name} does not look like either a HuggingFace repo_id or a downloadable URL')
def delete_ti_models(self, model_names: list[str]):
'''Remove TI embeddings from disk'''
for name in model_names:
file_or_directory = self.globals.embedding_path / name
if file_or_directory.is_dir():
self.logger.info(f'Purging textual inversion embedding {name}')
shutil.rmtree(str(file_or_directory))
else:
for path in self.globals.embedding_path.glob(f'{name}.*'):
self.logger.info(f'Purging textual inversion embedding {name}')
path.unlink()
def list_controlnet_models(self)->Dict[str,bool]:
'''Return a dict of installed controlnet models; key is repo_id or short name
of model (defined in INITIAL_MODELS), and value is True if installed'''
cn_models = OmegaConf.load(Dataset_path).get('controlnet') or {}
installed_models = {x: False for x in cn_models.keys()}
cn_dir = self.globals.controlnet_path
for root, dirs, files in os.walk(cn_dir):
for name in dirs:
if Path(root, name, '.download_complete').exists():
installed_models.update({name.replace('--','/'): True})
return installed_models
def install_controlnet_models(self, model_names: list[str], access_token: str=None):
'''Download list of controlnet models; provide either repo_id or short name listed in INITIAL_MODELS.yaml'''
short_names = OmegaConf.load(Dataset_path).get('controlnet') or {}
dest_dir = self.globals.controlnet_path
dest_dir.mkdir(parents=True,exist_ok=True)
# The model file may be fp32 or fp16, and may be either a
# .bin file or a .safetensors. We try each until we get one,
# preferring 'fp16' if using half precision, and preferring
# safetensors over over bin.
precisions = ['.fp16',''] if self.precision=='float16' else ['']
formats = ['.safetensors','.bin']
possible_filenames = list()
for p in precisions:
for f in formats:
possible_filenames.append(Path(f'diffusion_pytorch_model{p}{f}'))
for directory_name in model_names:
repo_id = short_names.get(directory_name) or directory_name
safe_name = directory_name.replace('/','--')
self.logger.info(f'Downloading ControlNet model {directory_name} ({repo_id})')
hf_download_with_resume(
repo_id = repo_id,
model_dir = dest_dir / safe_name,
model_name = 'config.json',
access_token = access_token
)
path = None
for filename in possible_filenames:
suffix = filename.suffix
dest_filename = Path(f'diffusion_pytorch_model{suffix}')
self.logger.info(f'Checking availability of {directory_name}/{filename}...')
path = hf_download_with_resume(
repo_id = repo_id,
model_dir = dest_dir / safe_name,
model_name = str(filename),
access_token = access_token,
model_dest = Path(dest_dir, safe_name, dest_filename),
)
if path:
(path.parent / '.download_complete').touch()
break
def delete_controlnet_models(self, model_names: List[str]):
'''Remove the list of controlnet models'''
for name in model_names:
safe_name = name.replace('/','--')
directory = self.globals.controlnet_path / safe_name
if directory.exists():
self.logger.info(f'Purging controlnet model {name}')
shutil.rmtree(str(directory))

View File

@ -5,9 +5,6 @@ import safetensors.torch
from diffusers.utils import is_safetensors_available from diffusers.utils import is_safetensors_available
class BaseModelType(str, Enum): class BaseModelType(str, Enum):
#StableDiffusion1_5 = "stable_diffusion_1_5"
#StableDiffusion2 = "stable_diffusion_2"
#StableDiffusion2Base = "stable_diffusion_2_base"
# TODO: maybe then add sample size(512/768)? # TODO: maybe then add sample size(512/768)?
StableDiffusion1_5 = "SD-1" StableDiffusion1_5 = "SD-1"
StableDiffusion2Base = "SD-2-base" # 512 pixels; this will have epsilon parameterization StableDiffusion2Base = "SD-2-base" # 512 pixels; this will have epsilon parameterization
@ -18,7 +15,6 @@ class ModelType(str, Enum):
Pipeline = "pipeline" Pipeline = "pipeline"
Classifier = "classifier" Classifier = "classifier"
Vae = "vae" Vae = "vae"
Lora = "lora" Lora = "lora"
ControlNet = "controlnet" ControlNet = "controlnet"
TextualInversion = "embedding" TextualInversion = "embedding"
@ -420,8 +416,6 @@ class ClassifierModel(ModelBase):
model_path = Path(model_path) model_path = Path(model_path)
return model_path return model_path
class VaeModel(ModelBase): class VaeModel(ModelBase):
#vae_class: Type #vae_class: Type
#model_size: int #model_size: int
@ -548,14 +542,6 @@ class TextualInversionModel(ModelBase):
model_path = Path(model_path) model_path = Path(model_path)
return model_path return model_path
def calc_model_size_by_fs( def calc_model_size_by_fs(
model_path: str, model_path: str,
subfolder: Optional[str] = None, subfolder: Optional[str] = None,