mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(txt2img): allow from_file to work with len(lines) < batch_size (#349)
This commit is contained in:
parent
720e5cd651
commit
eef788981c
@ -232,7 +232,12 @@ def main():
|
|||||||
print(f"reading prompts from {opt.from_file}")
|
print(f"reading prompts from {opt.from_file}")
|
||||||
with open(opt.from_file, "r") as f:
|
with open(opt.from_file, "r") as f:
|
||||||
data = f.read().splitlines()
|
data = f.read().splitlines()
|
||||||
data = list(chunk(data, batch_size))
|
if (len(data) >= batch_size):
|
||||||
|
data = list(chunk(data, batch_size))
|
||||||
|
else:
|
||||||
|
while (len(data) < batch_size):
|
||||||
|
data.append(data[-1])
|
||||||
|
data = [data]
|
||||||
|
|
||||||
sample_path = os.path.join(outpath, "samples")
|
sample_path = os.path.join(outpath, "samples")
|
||||||
os.makedirs(sample_path, exist_ok=True)
|
os.makedirs(sample_path, exist_ok=True)
|
||||||
@ -264,7 +269,7 @@ def main():
|
|||||||
prompts = list(prompts)
|
prompts = list(prompts)
|
||||||
c = model.get_learned_conditioning(prompts)
|
c = model.get_learned_conditioning(prompts)
|
||||||
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
||||||
|
|
||||||
if not opt.klms:
|
if not opt.klms:
|
||||||
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
|
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
|
||||||
conditioning=c,
|
conditioning=c,
|
||||||
@ -284,7 +289,7 @@ def main():
|
|||||||
model_wrap_cfg = CFGDenoiser(model_wrap)
|
model_wrap_cfg = CFGDenoiser(model_wrap)
|
||||||
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
|
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
|
||||||
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)
|
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)
|
||||||
|
|
||||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user