mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
bug and warning message fixes
- txt2img2img back to using DDIM as img2img sampler; results produced by some k* samplers are just not reliable enough for good user experience - img2img progress message clarifies why img2img steps taken != steps requested - warn of potential problems when user tries to run img2img on a small init image
This commit is contained in:
@ -27,7 +27,7 @@ class Inpaint(Img2Img):
|
|||||||
# klms samplers not supported yet, so ignore previous sampler
|
# klms samplers not supported yet, so ignore previous sampler
|
||||||
if isinstance(sampler,KSampler):
|
if isinstance(sampler,KSampler):
|
||||||
print(
|
print(
|
||||||
f">> sampler '{sampler.__class__.__name__}' is not yet supported for inpainting, using DDIMSampler instead."
|
f">> Using recommended DDIM sampler for inpainting."
|
||||||
)
|
)
|
||||||
sampler = DDIMSampler(self.model, device=self.model.device)
|
sampler = DDIMSampler(self.model, device=self.model.device)
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ class Txt2Img2Img(Generator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height}"
|
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||||
)
|
)
|
||||||
|
|
||||||
# resizing
|
# resizing
|
||||||
@ -75,17 +75,19 @@ class Txt2Img2Img(Generator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
t_enc = int(strength * steps)
|
t_enc = int(strength * steps)
|
||||||
|
ddim_sampler = DDIMSampler(self.model, device=self.model.device)
|
||||||
|
ddim_sampler.make_schedule(
|
||||||
|
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
x = self.get_noise(width,height,False)
|
z_enc = ddim_sampler.stochastic_encode(
|
||||||
|
|
||||||
z_enc = sampler.stochastic_encode(
|
|
||||||
samples,
|
samples,
|
||||||
torch.tensor([t_enc]).to(self.model.device),
|
torch.tensor([t_enc]).to(self.model.device),
|
||||||
noise=x
|
noise=self.get_noise(width,height,False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# decode it
|
# decode it
|
||||||
samples = sampler.decode(
|
samples = ddim_sampler.decode(
|
||||||
z_enc,
|
z_enc,
|
||||||
c,
|
c,
|
||||||
t_enc,
|
t_enc,
|
||||||
|
@ -417,7 +417,8 @@ class Generate:
|
|||||||
generator = self._make_txt2img()
|
generator = self._make_txt2img()
|
||||||
|
|
||||||
generator.set_variation(
|
generator.set_variation(
|
||||||
self.seed, variation_amount, with_variations)
|
self.seed, variation_amount, with_variations
|
||||||
|
)
|
||||||
results = generator.generate(
|
results = generator.generate(
|
||||||
prompt,
|
prompt,
|
||||||
iterations=iterations,
|
iterations=iterations,
|
||||||
@ -626,18 +627,14 @@ class Generate:
|
|||||||
height,
|
height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if image.width < self.width and image.height < self.height:
|
||||||
|
print(f'>> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions')
|
||||||
|
|
||||||
# if image has a transparent area and no mask was provided, then try to generate mask
|
# if image has a transparent area and no mask was provided, then try to generate mask
|
||||||
if self._has_transparency(image) and not mask:
|
if self._has_transparency(image):
|
||||||
print(
|
self._transparency_check_and_warning(image, mask)
|
||||||
'>> Initial image has transparent areas. Will inpaint in these regions.')
|
|
||||||
if self._check_for_erasure(image):
|
|
||||||
print(
|
|
||||||
'>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
|
|
||||||
'>> Inpainting will be suboptimal. Please preserve the colors when making\n',
|
|
||||||
'>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
|
|
||||||
)
|
|
||||||
# this returns a torch tensor
|
# this returns a torch tensor
|
||||||
init_mask = self._create_init_mask(image,width,height,fit=fit)
|
init_mask = self._create_init_mask(image, width, height, fit=fit)
|
||||||
|
|
||||||
if (image.width * image.height) > (self.width * self.height):
|
if (image.width * image.height) > (self.width * self.height):
|
||||||
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
|
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
|
||||||
@ -953,6 +950,17 @@ class Generate:
|
|||||||
colored += 1
|
colored += 1
|
||||||
return colored == 0
|
return colored == 0
|
||||||
|
|
||||||
|
def _transparency_check_and_warning(image, mask):
|
||||||
|
if not mask:
|
||||||
|
print(
|
||||||
|
'>> Initial image has transparent areas. Will inpaint in these regions.')
|
||||||
|
if self._check_for_erasure(image):
|
||||||
|
print(
|
||||||
|
'>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
|
||||||
|
'>> Inpainting will be suboptimal. Please preserve the colors when making\n',
|
||||||
|
'>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
|
||||||
|
)
|
||||||
|
|
||||||
def _squeeze_image(self, image):
|
def _squeeze_image(self, image):
|
||||||
x, y, resize_needed = self._resolution_check(image.width, image.height)
|
x, y, resize_needed = self._resolution_check(image.width, image.height)
|
||||||
if resize_needed:
|
if resize_needed:
|
||||||
|
@ -51,6 +51,7 @@ class KSampler(Sampler):
|
|||||||
schedule,
|
schedule,
|
||||||
steps=model.num_timesteps,
|
steps=model.num_timesteps,
|
||||||
)
|
)
|
||||||
|
self.sigmas = None
|
||||||
self.ds = None
|
self.ds = None
|
||||||
self.s_in = None
|
self.s_in = None
|
||||||
|
|
||||||
@ -140,7 +141,7 @@ class KSampler(Sampler):
|
|||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
'cond_scale': unconditional_guidance_scale,
|
'cond_scale': unconditional_guidance_scale,
|
||||||
}
|
}
|
||||||
print(f'>> Sampling with k_{self.schedule}')
|
print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
|
||||||
return (
|
return (
|
||||||
K.sampling.__dict__[f'sample_{self.schedule}'](
|
K.sampling.__dict__[f'sample_{self.schedule}'](
|
||||||
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
||||||
@ -149,6 +150,8 @@ class KSampler(Sampler):
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# this code will support inpainting if and when ksampler API modified or
|
||||||
|
# a workaround is found.
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample(
|
def p_sample(
|
||||||
self,
|
self,
|
||||||
|
@ -39,6 +39,7 @@ class Sampler(object):
|
|||||||
ddim_eta=0.0,
|
ddim_eta=0.0,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
):
|
):
|
||||||
|
self.total_steps = ddim_num_steps
|
||||||
self.ddim_timesteps = make_ddim_timesteps(
|
self.ddim_timesteps = make_ddim_timesteps(
|
||||||
ddim_discr_method=ddim_discretize,
|
ddim_discr_method=ddim_discretize,
|
||||||
num_ddim_timesteps=ddim_num_steps,
|
num_ddim_timesteps=ddim_num_steps,
|
||||||
@ -211,6 +212,7 @@ class Sampler(object):
|
|||||||
if ddim_use_original_steps
|
if ddim_use_original_steps
|
||||||
else np.flip(timesteps)
|
else np.flip(timesteps)
|
||||||
)
|
)
|
||||||
|
|
||||||
total_steps=steps
|
total_steps=steps
|
||||||
|
|
||||||
iterator = tqdm(
|
iterator = tqdm(
|
||||||
@ -305,7 +307,7 @@ class Sampler(object):
|
|||||||
|
|
||||||
time_range = np.flip(timesteps)
|
time_range = np.flip(timesteps)
|
||||||
total_steps = timesteps.shape[0]
|
total_steps = timesteps.shape[0]
|
||||||
print(f'>> Running {self.__class__.__name__} Sampling with {total_steps} timesteps')
|
print(f'>> Running {self.__class__.__name__} sampling starting at step {self.total_steps - t_start} of {self.total_steps} ({total_steps} new sampling steps)')
|
||||||
|
|
||||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||||
x_dec = x_latent
|
x_dec = x_latent
|
||||||
|
Reference in New Issue
Block a user