Merge pull request #110 from sajattack/half-precision-embeddings

Support full-precision embeddings in half precision inference mode
This commit is contained in:
Lincoln Stein 2022-08-28 15:36:26 -04:00 committed by GitHub
commit a7ac93a899
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 5 deletions

View File

@@ -215,11 +215,13 @@ class EmbeddingManager(nn.Module):
ckpt_path,
)
def load(self, ckpt_path, full=True):
    """Load textual-inversion embeddings from a checkpoint file.

    Args:
        ckpt_path: Path to a torch checkpoint containing
            'string_to_token' and 'string_to_param' entries.
        full: When True (default), keep the embeddings at the precision
            they were saved with (backward-compatible behavior). When
            False, convert each embedding to float16 so full-precision
            embeddings can be used with a model running half-precision
            inference.
    """
    # map_location='cpu' lets checkpoints saved on GPU load anywhere.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    self.string_to_token_dict = ckpt["string_to_token"]
    self.string_to_param_dict = ckpt["string_to_param"]
    if not full:
        # Re-wrap as Parameter: .half() returns a plain tensor.
        # Assigning to existing keys while iterating .items() is safe
        # (no keys are added or removed).
        for key, value in self.string_to_param_dict.items():
            self.string_to_param_dict[key] = torch.nn.Parameter(value.half())
def get_embedding_norms_squared(self):
all_params = torch.cat(

View File

@@ -498,7 +498,7 @@ class T2I:
self.device = self._get_device()
model = self._load_model_from_config(config, self.weights)
if self.embedding_path is not None:
model.embedding_manager.load(self.embedding_path)
model.embedding_manager.load(self.embedding_path, self.full_precision)
self.model = model.to(self.device)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
self.model.cond_stage_model.device = self.device