model manager defaults to consistent values of device and precision

This commit is contained in:
Lincoln Stein
2023-03-09 01:09:54 -05:00
parent 5d37fa6e36
commit b679a6ba37
3 changed files with 47 additions and 17 deletions

View File

@ -34,8 +34,7 @@ from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util import CPU_DEVICE, ask_user, download_with_resume
from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
V1 = 1
@ -51,23 +50,28 @@ VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
}
class ModelManager(object):
'''
Model manager handles loading, caching, importing, deleting, converting, and editing models.
'''
def __init__(
self,
config: OmegaConf,
device_type: torch.device = CPU_DEVICE,
config: OmegaConf|Path,
device_type: torch.device = CUDA_DEVICE,
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload=False,
):
"""
Initialize with the path to the models.yaml config file,
the torch device type, and precision. The optional
min_avail_mem argument specifies how much unused system
(CPU) memory to preserve. The cache of models in RAM will
grow until this value is approached. Default is 2G.
Initialize with the path to the models.yaml config file or
an initialized OmegaConf dictionary. Optional parameters
are the torch device type, precision, max_loaded_models,
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
# prevent nasty-looking CLIP log message
transformers.logging.set_verbosity_error()
if not isinstance(config, DictConfig):
config = OmegaConf.load(config)
self.config = config
self.precision = precision
self.device = torch.device(device_type)
@ -557,7 +561,7 @@ class ModelManager(object):
"""
model_name = model_name or Path(repo_or_path).stem
model_description = (
model_description or f"Imported diffusers model {model_name}"
description or f"Imported diffusers model {model_name}"
)
new_config = dict(
description=model_description,