From b537e927896560c43caa2de48dd9fd0cf5041545 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Wed, 12 Oct 2022 03:03:29 -0400
Subject: [PATCH] move tokenizer into cpu cache as well

---
 ldm/generate.py           |  2 --
 ldm/invoke/model_cache.py | 10 +++++++++-
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/ldm/generate.py b/ldm/generate.py
index 9c015d1250..3856bfbe2c 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -678,8 +678,6 @@ class Generate:
                 self.embedding_path,
                 self.precision == 'float32' or self.precision == 'autocast'
             )
-            # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
-            self.model.cond_stage_model.device = self.device
             self._set_sampler()
 
             for m in self.model.modules():
diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index f7c4bbe7eb..96997d83d0 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -114,7 +114,11 @@ class ModelCache(object):
         '''
         models = self.list_models()
         for name in models:
-            print(f'{name:20s} {models[name]["status"]:>10s} {models[name]["description"]}')
+            line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["description"]}'
+            if models[name]['status'] == 'active':
+                print(f'\033[1m{line}\033[0m')
+            else:
+                print(line)
 
     def _check_memory(self):
         avail_memory = psutil.virtual_memory()[1]
@@ -164,6 +168,8 @@ class ModelCache(object):
                 print('>> Using more accurate float32 precision')
 
         model.to(self.device)
+        # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
+        model.cond_stage_model.device = self.device
         model.eval()
 
         # usage statistics
@@ -190,6 +196,7 @@ class ModelCache(object):
 
     def _model_to_cpu(self,model):
         if self._has_cuda():
+            model.cond_stage_model.device = 'cpu'
             model.first_stage_model.to('cpu')
             model.cond_stage_model.to('cpu')
             model.model.to('cpu')
@@ -200,6 +207,7 @@ class ModelCache(object):
             model.to(self.device)
             model.first_stage_model.to(self.device)
             model.cond_stage_model.to(self.device)
+            model.cond_stage_model.device = self.device
         return model
 
     def _pop_oldest_model(self):