Mirror of https://github.com/invoke-ai/InvokeAI
refactor(model_cache): factor out load_ckpt
commit 9b274bd57c
parent dcfdb83513
@@ -1,5 +1,5 @@
 '''
 Manage a cache of Stable Diffusion model files for fast switching.
 They are moved between GPU and CPU as necessary. If CPU memory falls
 below a preset minimum, the least recently used model will be
 cleared and loaded from disk when next needed.
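The docstring above describes the caching policy: keep recently used models in RAM, move them between GPU and CPU, and evict the least recently used model when memory runs low. A minimal, self-contained sketch of that bookkeeping follows; the names are illustrative only, not the real ModelCache implementation.

    # Minimal sketch of the LRU bookkeeping described in the docstring.
    class TinyModelCache:
        def __init__(self, loader, max_loaded=2):
            self.loader = loader      # callable: name -> model object
            self.models = {}          # models currently held in RAM
            self.stack = []           # oldest ... newest usage order
            self.max_loaded = max_loaded

        def get_model(self, name):
            if name not in self.models:
                if len(self.models) >= self.max_loaded:
                    oldest = self.stack.pop(0)        # least recently used
                    del self.models[oldest]           # cleared; reloaded from disk when next needed
                self.models[name] = self.loader(name)
            if name in self.stack:
                self.stack.remove(name)
            self.stack.append(name)                   # mark as most recently used
            return self.models[name]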
@@ -51,7 +51,7 @@ class ModelCache(object):
         identifier.
         '''
         return model_name in self.config

     def get_model(self, model_name:str):
         '''
         Given a model named identified in models.yaml, return
@@ -66,7 +66,7 @@ class ModelCache(object):
         if model_name not in self.models: # make room for a new one
             self._make_cache_room()
             self.offload_model(self.current_model)

         if model_name in self.models:
             requested_model = self.models[model_name]['model']
             print(f'>> Retrieving model {model_name} from system RAM cache')
@@ -92,7 +92,7 @@ class ModelCache(object):
                 print(f'** restoring {self.current_model}')
                 self.get_model(self.current_model)
                 return

         self.current_model = model_name
         self._push_newest_model(model_name)
         return {
@@ -102,7 +102,7 @@ class ModelCache(object):
             'hash': hash
         }

-    def default_model(self) -> str:
+    def default_model(self) -> str | None:
         '''
         Returns the name of the default model, or None
         if none is defined.
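The widened return annotation makes the documented None case explicit at call sites. Note that the `str | None` union syntax is only valid at runtime on Python 3.10+, unless postponed evaluation of annotations (`from __future__ import annotations`) is in effect. A hypothetical caller, with the ModelCache constructor arguments assumed, would check for None before use:

    # Hypothetical usage; constructor arguments are assumed, not taken from this diff.
    cache = ModelCache(config, device, precision)
    name = cache.default_model()
    if name is None:
        print('** no default model defined in models.yaml')
    else:
        model_info = cache.get_model(name)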
@@ -191,25 +191,13 @@ class ModelCache(object):
         omega[model_name] = config
         if clobber:
             self._invalidate_cached_model(model_name)

     def _load_model(self, model_name:str):
         """Load and initialize the model from configuration variables passed at object creation time"""
         if model_name not in self.config:
             print(f'"{model_name}" is not a known model name. Please check your models.yaml file')

         mconfig = self.config[model_name]
-        config = mconfig.config
-        weights = mconfig.weights
-        vae = mconfig.get('vae')
-        width = mconfig.width
-        height = mconfig.height
-
-        if not os.path.isabs(weights):
-            weights = os.path.normpath(os.path.join(Globals.root,weights))
-        # scan model
-        self.scan_model(model_name, weights)
-
-        print(f'>> Loading {model_name} from {weights}')

         # for usage statistics
         if self._has_cuda():
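Each `mconfig` here is the per-model stanza returned by `self.config[model_name]` from models.yaml. An illustrative stanza, expressed as the OmegaConf structure the code reads, is sketched below; the model name, paths, and values are examples only, and the optional `format` key is the one the new dispatch in the next hunk consumes.

    from omegaconf import OmegaConf

    # Illustrative only: key names follow the attributes read in this diff
    # (config, weights, vae, width, height) plus the new optional 'format'.
    config = OmegaConf.create({
        'stable-diffusion-1.5': {
            'format': 'ckpt',      # selects _load_ckpt_model vs _load_diffusers_model
            'config': 'configs/stable-diffusion/v1-inference.yaml',
            'weights': 'models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt',
            'vae': 'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt',
            'width': 512,
            'height': 512,
        }
    })
    mconfig = config['stable-diffusion-1.5']
    print(mconfig.get('format', 'ckpt'))   # -> 'ckpt'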
@@ -219,12 +207,39 @@ class ModelCache(object):
         tic = time.time()

-        # this does the work
-        if not os.path.isabs(config):
-            config = os.path.join(Globals.root,config)
-        omega_config = OmegaConf.load(config)
-        with open(weights,'rb') as f:
+        model_format = mconfig.get('format', 'ckpt')
+        if model_format == 'ckpt':
+            weights = mconfig.weights
+            print(f'>> Loading {model_name} from {weights}')
+            model, width, height, model_hash = self._load_ckpt_model(mconfig)
+        elif model_format == 'diffusers':
+            model, width, height, model_hash = self._load_diffusers_model(mconfig)
+        else:
+            raise NotImplementedError(f"Unknown model format {model_name}: {model_format}")
+
+        # usage statistics
+        toc = time.time()
+        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
+        if self._has_cuda():
+            print(
+                '>> Max VRAM used to load the model:',
+                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
+                '\n>> Current VRAM usage:'
+                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+            )
+        return model, width, height, model_hash
+
+    def _load_ckpt_model(self, mconfig):
+        config = mconfig.config
+        weights = mconfig.weights
+        vae = mconfig.get('vae', None)
+        width = mconfig.width
+        height = mconfig.height
+
+        c = OmegaConf.load(config)
+        with open(weights, 'rb') as f:
             weight_bytes = f.read()
-            model_hash = self._cached_sha256(weights,weight_bytes)
+            model_hash = self._cached_sha256(weights, weight_bytes)
             sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
             del weight_bytes
             sd = sd['state_dict']
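The factored-out `_load_ckpt_model` keeps the original loading pattern: read the weights file once, hash the raw bytes so the hash can be cached per file, then deserialize the same bytes from memory without a second disk read. A standalone sketch of that pattern, with a plain sha256 standing in for the class's `_cached_sha256()`:

    import hashlib
    import io

    import torch

    def load_ckpt_state_dict(weights_path):
        # Read once: the same bytes feed both the hash and torch.load().
        with open(weights_path, 'rb') as f:
            weight_bytes = f.read()
        model_hash = hashlib.sha256(weight_bytes).hexdigest()   # stand-in for _cached_sha256()
        sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
        del weight_bytes                                         # release the raw buffer
        return sd['state_dict'], model_hash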
@@ -252,28 +267,19 @@ class ModelCache(object):
         model.to(self.device)
         # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
         model.cond_stage_model.device = self.device

         model.eval()

         for module in model.modules():
             if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                 module._orig_padding_mode = module.padding_mode

-        # usage statistics
-        toc = time.time()
-        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
-
-        if self._has_cuda():
-            print(
-                '>> Max VRAM used to load the model:',
-                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
-                '\n>> Current VRAM usage:'
-                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
-            )
-
         return model, width, height, model_hash

-    def offload_model(self, model_name:str) -> None:
+    def _load_diffusers_model(self, mconfig):
+        raise NotImplementedError() # return pipeline, width, height, model_hash
+
+    def offload_model(self, model_name:str):
         '''
         Offload the indicated model to CPU. Will call
         _make_cache_room() to free space if needed.
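The new `_load_diffusers_model` is deliberately left as a stub; its trailing comment records the intended contract (return pipeline, width, height, model_hash). Purely as a hypothetical sketch, assuming the Hugging Face diffusers `StableDiffusionPipeline` API and an assumed `repo_id_or_path` key in the model stanza, a body honouring that contract might look like the following; none of this is defined by the commit itself.

    # Hypothetical sketch only; the commit leaves _load_diffusers_model as NotImplementedError.
    from diffusers import StableDiffusionPipeline

    def load_diffusers_model(mconfig, device):
        repo_id_or_path = mconfig.get('repo_id_or_path')      # assumed config key
        pipeline = StableDiffusionPipeline.from_pretrained(repo_id_or_path)
        pipeline.to(device)
        width = mconfig.get('width', 512)                      # SD 1.x default resolution
        height = mconfig.get('height', 512)
        model_hash = repo_id_or_path                           # placeholder until a real hashing scheme exists
        return pipeline, width, height, model_hash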
@@ -288,7 +294,7 @@ class ModelCache(object):
             gc.collect()
             if self._has_cuda():
                 torch.cuda.empty_cache()

     def scan_model(self, model_name, checkpoint):
         # scan model
         print(f'>> Scanning Model: {model_name}')
@@ -318,7 +324,7 @@ class ModelCache(object):
         if least_recent_model is not None:
             del self.models[least_recent_model]
             gc.collect()

     def print_vram_usage(self) -> None:
         if self._has_cuda:
             print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
@@ -353,12 +359,12 @@ class ModelCache(object):
         if model_name in self.stack:
             self.stack.remove(model_name)
         self.models.pop(model_name,None)

     def _model_to_cpu(self,model):
         if self.device != 'cpu':
             model.cond_stage_model.device = 'cpu'
             model.first_stage_model.to('cpu')
             model.cond_stage_model.to('cpu')
             model.model.to('cpu')
             return model.to('cpu')
         else:
@@ -388,7 +394,7 @@ class ModelCache(object):
         with contextlib.suppress(ValueError):
             self.stack.remove(model_name)
         self.stack.append(model_name)

     def _has_cuda(self) -> bool:
         return self.device.type == 'cuda'