InvokeAI/scripts/orig_scripts/inpaint.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

93 lines
3.3 KiB
Python
Raw Normal View History

2021-12-21 02:23:41 +00:00
import argparse, os, sys, glob
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.invoke.devices import choose_torch_device
2021-12-21 02:23:41 +00:00
2023-07-28 13:46:44 +00:00
2021-12-21 02:23:41 +00:00
def make_batch(image, mask, device):
    """Load an image/mask pair and build the conditioning batch for inpainting.

    Args:
        image: Path (or anything ``PIL.Image.open`` accepts) of the source
            image; it is converted to RGB.
        mask: Path (or anything ``PIL.Image.open`` accepts) of the mask; it is
            converted to single-channel grayscale and thresholded at 0.5 after
            scaling to [0, 1], so 1 marks the region to be inpainted.
        device: Torch device the returned tensors are moved to.

    Returns:
        dict with keys ``"image"`` (1x3xHxW), ``"mask"`` (1x1xHxW) and
        ``"masked_image"`` (1x3xHxW, the image with masked pixels zeroed),
        all float32 tensors rescaled from [0, 1] to [-1, 1].

    Raises:
        ValueError: if the image and the mask do not have the same size.
    """
    # Context managers close the underlying files as soon as the pixel data
    # has been copied out.
    with Image.open(image) as img:
        image_arr = np.array(img.convert("RGB"))
    with Image.open(mask) as msk:
        mask_arr = np.array(msk.convert("L"))

    # Fail early with a clear message instead of a broadcasting error (or a
    # silent broadcast when one side is 1 pixel high/wide).
    if image_arr.shape[:2] != mask_arr.shape:
        raise ValueError(
            f"image size {image_arr.shape[1]}x{image_arr.shape[0]} does not match "
            f"mask size {mask_arr.shape[1]}x{mask_arr.shape[0]}"
        )

    # HxWx3 uint8 -> 1x3xHxW float32 in [0, 1].
    image_np = image_arr.astype(np.float32) / 255.0
    image_np = image_np[None].transpose(0, 3, 1, 2)
    image_t = torch.from_numpy(image_np)

    # HxW uint8 -> 1x1xHxW float32, binarized to {0, 1}.
    mask_np = mask_arr.astype(np.float32) / 255.0
    mask_np = (mask_np[None, None] >= 0.5).astype(np.float32)
    mask_t = torch.from_numpy(mask_np)

    # Zero out the pixels that will be inpainted.
    masked_image = (1 - mask_t) * image_t

    batch = {"image": image_t, "mask": mask_t, "masked_image": masked_image}
    for k in batch:
        # Move to the target device and rescale [0, 1] -> [-1, 1].
        batch[k] = batch[k].to(device=device) * 2.0 - 1.0
    return batch
_MASK_SUFFIX = "_mask.png"


def _parse_args():
    """Parse the command-line options of the inpainting script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--indir",
        type=str,
        required=True,
        help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="dir to write results to",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="models/ldm/inpainting_big/config.yaml",
        help="path to the model config (OmegaConf YAML)",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/ldm/inpainting_big/last.ckpt",
        help="path to the model checkpoint",
    )
    return parser.parse_args()


def _image_path_for_mask(mask_path):
    """Return the image path paired with *mask_path* (``x_mask.png`` -> ``x.png``)."""
    # Strip only the trailing suffix; str.replace would also rewrite any
    # "_mask.png" that happens to appear in a directory component.
    return mask_path[: -len(_MASK_SUFFIX)] + ".png"


def main():
    """Inpaint every ``<name>_mask.png`` / ``<name>.png`` pair from --indir into --outdir."""
    opt = _parse_args()

    # glob.escape so that metacharacters in the directory name match literally.
    masks = sorted(glob.glob(os.path.join(glob.escape(opt.indir), "*" + _MASK_SUFFIX)))
    images = [_image_path_for_mask(m) for m in masks]
    print(f"Found {len(masks)} inputs.")

    config = OmegaConf.load(opt.config)
    model = instantiate_from_config(config.model)
    # map_location="cpu": the checkpoint may have been saved from a CUDA
    # device; loading it onto the CPU works on any machine, and the model is
    # moved to the selected device right after.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state = torch.load(opt.ckpt, map_location="cpu")
    model.load_state_dict(state["state_dict"], strict=False)
    del state  # drop the extra CPU copy of the weights before sampling

    device = choose_torch_device()
    model = model.to(device)
    sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    with torch.no_grad():
        with model.ema_scope():
            for image_path, mask_path in tqdm(zip(images, masks), total=len(masks)):
                outpath = os.path.join(opt.outdir, os.path.basename(image_path))
                batch = make_batch(image_path, mask_path, device=device)

                # encode masked image and concat downsampled mask
                c = model.cond_stage_model.encode(batch["masked_image"])
                cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
                c = torch.cat((c, cc), dim=1)

                # Sample with the channel count of the encoded masked image,
                # i.e. the conditioning minus the single mask channel.
                shape = (c.shape[1] - 1,) + c.shape[2:]
                samples_ddim, _ = sampler.sample(
                    S=opt.steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False
                )
                x_samples_ddim = model.decode_first_stage(samples_ddim)

                # Map the input image, mask and prediction from [-1, 1] back to [0, 1].
                image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0)
                mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
                predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                # Keep original pixels outside the mask, use the prediction inside it.
                inpainted = (1 - mask) * image + mask * predicted_image
                inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
                Image.fromarray(inpainted.astype(np.uint8)).save(outpath)


if __name__ == "__main__":
    main()