From 9b274bd57c5e4711729532a8d9ac9aac691fc4fd Mon Sep 17 00:00:00 2001
From: Kevin Turner <83819+keturn@users.noreply.github.com>
Date: Wed, 9 Nov 2022 15:25:56 -0800
Subject: [PATCH] refactor(model_cache): factor out load_ckpt

---
 ldm/invoke/model_cache.py | 92 +++++++++++++++++++++------------------
 1 file changed, 49 insertions(+), 43 deletions(-)

diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index 7d1654718a..c730c7b775 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -1,5 +1,5 @@
 '''
-Manage a cache of Stable Diffusion model files for fast switching. 
+Manage a cache of Stable Diffusion model files for fast switching.
 They are moved between GPU and CPU as necessary. If CPU memory falls
 below a preset minimum, the least recently used
 model will be cleared and loaded from disk when next needed.
@@ -51,7 +51,7 @@ class ModelCache(object):
         identifier.
         '''
         return model_name in self.config
-        
+
     def get_model(self, model_name:str):
         '''
         Given a model named identified in models.yaml, return
@@ -66,7 +66,7 @@ class ModelCache(object):
         if model_name not in self.models: # make room for a new one
             self._make_cache_room()
             self.offload_model(self.current_model)
-        
+
         if model_name in self.models:
             requested_model = self.models[model_name]['model']
             print(f'>> Retrieving model {model_name} from system RAM cache')
@@ -92,7 +92,7 @@ class ModelCache(object):
             print(f'** restoring {self.current_model}')
             self.get_model(self.current_model)
             return
-        
+
         self.current_model = model_name
         self._push_newest_model(model_name)
         return {
@@ -102,7 +102,7 @@ class ModelCache(object):
             'hash': hash
         }
 
-    def default_model(self) -> str:
+    def default_model(self) -> str | None:
         '''
         Returns the name of the default model, or None
         if none is defined.
@@ -191,25 +191,13 @@ class ModelCache(object):
         omega[model_name] = config
         if clobber:
             self._invalidate_cached_model(model_name)
-        
+
     def _load_model(self, model_name:str):
         """Load and initialize the model from configuration variables passed at object creation time"""
         if model_name not in self.config:
             print(f'"{model_name}" is not a known model name. Please check your models.yaml file')
 
         mconfig = self.config[model_name]
-        config = mconfig.config
-        weights = mconfig.weights
-        vae = mconfig.get('vae')
-        width = mconfig.width
-        height = mconfig.height
-
-        if not os.path.isabs(weights):
-            weights = os.path.normpath(os.path.join(Globals.root,weights))
-        # scan model
-        self.scan_model(model_name, weights)
-
-        print(f'>> Loading {model_name} from {weights}')
 
         # for usage statistics
         if self._has_cuda():
@@ -219,12 +207,39 @@ class ModelCache(object):
         tic = time.time()
 
         # this does the work
-        if not os.path.isabs(config):
-            config = os.path.join(Globals.root,config)
-        omega_config = OmegaConf.load(config)
-        with open(weights,'rb') as f:
+        model_format = mconfig.get('format', 'ckpt')
+        if model_format == 'ckpt':
+            weights = mconfig.weights
+            print(f'>> Loading {model_name} from {weights}')
+            model, width, height, model_hash = self._load_ckpt_model(mconfig)
+        elif model_format == 'diffusers':
+            model, width, height, model_hash = self._load_diffusers_model(mconfig)
+        else:
+            raise NotImplementedError(f"Unknown model format {model_name}: {model_format}")
+
+        # usage statistics
+        toc = time.time()
+        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
+        if self._has_cuda():
+            print(
+                '>> Max VRAM used to load the model:',
+                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
+                '\n>> Current VRAM usage:'
+                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+            )
+        return model, width, height, model_hash
+
+    def _load_ckpt_model(self, mconfig):
+        config = mconfig.config
+        weights = mconfig.weights
+        vae = mconfig.get('vae', None)
+        width = mconfig.width
+        height = mconfig.height
+
+        c = OmegaConf.load(config)
+        with open(weights, 'rb') as f:
             weight_bytes = f.read()
-        model_hash = self._cached_sha256(weights,weight_bytes)
+        model_hash = self._cached_sha256(weights, weight_bytes)
         sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
         del weight_bytes
         sd = sd['state_dict']
@@ -252,28 +267,19 @@ class ModelCache(object):
         model.to(self.device)
         # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
         model.cond_stage_model.device = self.device
-        
+
 
         model.eval()
 
         for module in model.modules():
             if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                 module._orig_padding_mode = module.padding_mode
 
-        # usage statistics
-        toc = time.time()
-        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
-
-        if self._has_cuda():
-            print(
-                '>> Max VRAM used to load the model:',
-                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
-                '\n>> Current VRAM usage:'
-                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
-            )
-        return model, width, height, model_hash
-
-    def offload_model(self, model_name:str) -> None:
+
+    def _load_diffusers_model(self, mconfig):
+        raise NotImplementedError() # return pipeline, width, height, model_hash
+
+    def offload_model(self, model_name:str):
         '''
         Offload the indicated model to CPU. Will call
         _make_cache_room() to free space if needed.
@@ -288,7 +294,7 @@ class ModelCache(object):
             gc.collect()
             if self._has_cuda():
                 torch.cuda.empty_cache()
-        
+
     def scan_model(self, model_name, checkpoint):
         # scan model
         print(f'>> Scanning Model: {model_name}')
@@ -318,7 +324,7 @@ class ModelCache(object):
         if least_recent_model is not None:
             del self.models[least_recent_model]
             gc.collect()
-        
+
     def print_vram_usage(self) -> None:
         if self._has_cuda:
             print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
@@ -353,12 +359,12 @@ class ModelCache(object):
         if model_name in self.stack:
             self.stack.remove(model_name)
         self.models.pop(model_name,None)
-        
+
     def _model_to_cpu(self,model):
         if self.device != 'cpu':
             model.cond_stage_model.device = 'cpu'
             model.first_stage_model.to('cpu')
-            model.cond_stage_model.to('cpu') 
+            model.cond_stage_model.to('cpu')
             model.model.to('cpu')
             return model.to('cpu')
         else:
@@ -388,7 +394,7 @@ class ModelCache(object):
         with contextlib.suppress(ValueError):
             self.stack.remove(model_name)
         self.stack.append(model_name)
-        
+
     def _has_cuda(self) -> bool:
         return self.device.type == 'cuda'
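
Editor's note (illustration, not part of the patch): the heart of this refactor is the new dispatch in _load_model. Each entry in models.yaml may now carry a 'format' key, defaulting to 'ckpt', and loading is delegated to a format-specific helper: _load_ckpt_model for legacy checkpoints, or the still-stubbed _load_diffusers_model. The sketch below mirrors that pattern in a standalone, runnable form; ModelConfig and MiniModelCache are hypothetical stand-ins for the OmegaConf-backed mconfig and the real ModelCache, not InvokeAI code.

# Illustrative sketch only -- mirrors the format-based dispatch that
# _load_model gains in this patch. ModelConfig is a hypothetical stand-in
# for the OmegaConf mconfig mapping read from models.yaml.
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class ModelConfig:
    config: str                   # path to the model's .yaml config
    weights: str                  # path to the .ckpt weights
    width: int = 512
    height: int = 512
    vae: Optional[str] = None
    format: str = 'ckpt'          # 'ckpt' (default) or 'diffusers'


class MiniModelCache:
    def load_model(self, name: str, mconfig: ModelConfig) -> Tuple[object, int, int, str]:
        # Dispatch on the declared format, defaulting to legacy checkpoints,
        # as the patched _load_model does.
        if mconfig.format == 'ckpt':
            print(f'>> Loading {name} from {mconfig.weights}')
            return self._load_ckpt_model(mconfig)
        elif mconfig.format == 'diffusers':
            return self._load_diffusers_model(mconfig)
        raise NotImplementedError(f'Unknown model format {name}: {mconfig.format}')

    def _load_ckpt_model(self, mconfig: ModelConfig) -> Tuple[object, int, int, str]:
        # Stand-in for the checkpoint loader; the real method reads and hashes
        # the weights and builds the LDM model from the OmegaConf config.
        model = object()
        model_hash = 'dummy-sha256'
        return model, mconfig.width, mconfig.height, model_hash

    def _load_diffusers_model(self, mconfig: ModelConfig) -> Tuple[object, int, int, str]:
        # Still a stub in the patch itself.
        raise NotImplementedError()


if __name__ == '__main__':
    cfg = ModelConfig(config='v1-inference.yaml', weights='model.ckpt')
    model, width, height, model_hash = MiniModelCache().load_model('stable-diffusion-1.4', cfg)
    print(width, height, model_hash)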