Merge pull request #4 from xraxra/halfPrecision

Use half precision for reduced memory usage and faster speed.
This commit is contained in:
Lincoln Stein 2022-08-20 09:42:17 -04:00 committed by GitHub
commit 09afcc321c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -256,6 +256,8 @@ class T2I:
model = self.load_model() # will instantiate the model or return it from cache model = self.load_model() # will instantiate the model or return it from cache
precision_scope = autocast if self.precision=="autocast" else nullcontext
# grid and individual are mutually exclusive, with individual taking priority. # grid and individual are mutually exclusive, with individual taking priority.
# not necessary, but needed for compatibility with dream bot # not necessary, but needed for compatibility with dream bot
if (grid is None): if (grid is None):
@ -279,7 +281,8 @@ class T2I:
assert os.path.isfile(init_img) assert os.path.isfile(init_img)
init_image = self._load_img(init_img).to(self.device) init_image = self._load_img(init_img).to(self.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space with precision_scope("cuda"):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False) sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
@ -292,7 +295,6 @@ class T2I:
t_enc = int(strength * steps) t_enc = int(strength * steps)
print(f"target t_enc is {t_enc} steps") print(f"target t_enc is {t_enc} steps")
precision_scope = autocast if self.precision=="autocast" else nullcontext
images = list() images = list()
seeds = list() seeds = list()
@ -401,6 +403,7 @@ class T2I:
m, u = model.load_state_dict(sd, strict=False) m, u = model.load_state_dict(sd, strict=False)
model.cuda() model.cuda()
model.eval() model.eval()
model.half()
return model return model
def _load_img(self,path): def _load_img(self,path):