From 9b5101cd8d2437dbbf122f56e2f57e5371be7a55 Mon Sep 17 00:00:00 2001 From: Paul Sajna Date: Fri, 26 Aug 2022 08:30:58 -0700 Subject: [PATCH] support full-precision embeddings in half precision mode --- ldm/modules/embedding_manager.py | 8 +++++--- ldm/simplet2i.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py index 7020a27b9a..6b6df1dafe 100644 --- a/ldm/modules/embedding_manager.py +++ b/ldm/modules/embedding_manager.py @@ -135,11 +135,13 @@ class EmbeddingManager(nn.Module): torch.save({"string_to_token": self.string_to_token_dict, "string_to_param": self.string_to_param_dict}, ckpt_path) - def load(self, ckpt_path): + def load(self, ckpt_path, full=True): ckpt = torch.load(ckpt_path, map_location='cpu') - self.string_to_token_dict = ckpt["string_to_token"] self.string_to_param_dict = ckpt["string_to_param"] + if not full: + for key, value in self.string_to_param_dict.items(): + self.string_to_param_dict[key] = torch.nn.Parameter(value.half()) def get_embedding_norms_squared(self): all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim @@ -161,4 +163,4 @@ class EmbeddingManager(nn.Module): loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings - return loss \ No newline at end of file + return loss diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 4737d90ba7..d0a8798d3b 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -454,7 +454,7 @@ The vast majority of these arguments default to reasonable values. self.device = torch.device(self.device) if torch.cuda.is_available() else torch.device("cpu") model = self._load_model_from_config(config,self.weights) if self.embedding_path is not None: - model.embedding_manager.load(self.embedding_path) + model.embedding_manager.load(self.embedding_path, self.full_precision) self.model = model.to(self.device) # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here self.model.cond_stage_model.device = self.device