mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model manager defaults to consistent values of device and precision
This commit is contained in:
@ -34,8 +34,7 @@ from picklescan.scanner import scan_file_path
|
||||
from invokeai.backend.globals import Globals, global_cache_dir
|
||||
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
from ..util import CPU_DEVICE, ask_user, download_with_resume
|
||||
|
||||
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
||||
|
||||
class SDLegacyType(Enum):
|
||||
V1 = 1
|
||||
@ -51,23 +50,28 @@ VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
|
||||
}
|
||||
|
||||
class ModelManager(object):
|
||||
'''
|
||||
Model manager handles loading, caching, importing, deleting, converting, and editing models.
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
config: OmegaConf,
|
||||
device_type: torch.device = CPU_DEVICE,
|
||||
config: OmegaConf|Path,
|
||||
device_type: torch.device = CUDA_DEVICE,
|
||||
precision: str = "float16",
|
||||
max_loaded_models=DEFAULT_MAX_MODELS,
|
||||
sequential_offload=False,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file,
|
||||
the torch device type, and precision. The optional
|
||||
min_avail_mem argument specifies how much unused system
|
||||
(CPU) memory to preserve. The cache of models in RAM will
|
||||
grow until this value is approached. Default is 2G.
|
||||
Initialize with the path to the models.yaml config file or
|
||||
an initialized OmegaConf dictionary. Optional parameters
|
||||
are the torch device type, precision, max_loaded_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
# prevent nasty-looking CLIP log message
|
||||
transformers.logging.set_verbosity_error()
|
||||
if not isinstance(config, DictConfig):
|
||||
config = OmegaConf.load(config)
|
||||
self.config = config
|
||||
self.precision = precision
|
||||
self.device = torch.device(device_type)
|
||||
@ -557,7 +561,7 @@ class ModelManager(object):
|
||||
"""
|
||||
model_name = model_name or Path(repo_or_path).stem
|
||||
model_description = (
|
||||
model_description or f"Imported diffusers model {model_name}"
|
||||
description or f"Imported diffusers model {model_name}"
|
||||
)
|
||||
new_config = dict(
|
||||
description=model_description,
|
||||
|
Reference in New Issue
Block a user