mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
support full-precision embeddings in half precision mode
This commit is contained in:
parent
8b682ac83b
commit
9b5101cd8d
@ -135,11 +135,13 @@ class EmbeddingManager(nn.Module):
|
|||||||
torch.save({"string_to_token": self.string_to_token_dict,
|
torch.save({"string_to_token": self.string_to_token_dict,
|
||||||
"string_to_param": self.string_to_param_dict}, ckpt_path)
|
"string_to_param": self.string_to_param_dict}, ckpt_path)
|
||||||
|
|
||||||
def load(self, ckpt_path):
|
def load(self, ckpt_path, full=True):
|
||||||
ckpt = torch.load(ckpt_path, map_location='cpu')
|
ckpt = torch.load(ckpt_path, map_location='cpu')
|
||||||
|
|
||||||
self.string_to_token_dict = ckpt["string_to_token"]
|
self.string_to_token_dict = ckpt["string_to_token"]
|
||||||
self.string_to_param_dict = ckpt["string_to_param"]
|
self.string_to_param_dict = ckpt["string_to_param"]
|
||||||
|
if not full:
|
||||||
|
for key, value in self.string_to_param_dict.items():
|
||||||
|
self.string_to_param_dict[key] = torch.nn.Parameter(value.half())
|
||||||
|
|
||||||
def get_embedding_norms_squared(self):
|
def get_embedding_norms_squared(self):
|
||||||
all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim
|
all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim
|
||||||
@ -161,4 +163,4 @@ class EmbeddingManager(nn.Module):
|
|||||||
|
|
||||||
loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings
|
loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
@ -454,7 +454,7 @@ The vast majority of these arguments default to reasonable values.
|
|||||||
self.device = torch.device(self.device) if torch.cuda.is_available() else torch.device("cpu")
|
self.device = torch.device(self.device) if torch.cuda.is_available() else torch.device("cpu")
|
||||||
model = self._load_model_from_config(config,self.weights)
|
model = self._load_model_from_config(config,self.weights)
|
||||||
if self.embedding_path is not None:
|
if self.embedding_path is not None:
|
||||||
model.embedding_manager.load(self.embedding_path)
|
model.embedding_manager.load(self.embedding_path, self.full_precision)
|
||||||
self.model = model.to(self.device)
|
self.model = model.to(self.device)
|
||||||
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
|
||||||
self.model.cond_stage_model.device = self.device
|
self.model.cond_stage_model.device = self.device
|
||||||
|
Loading…
x
Reference in New Issue
Block a user