Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Merge branch 'development' into fix-prompts
ldm/generate.py (121 changed lines)
@@ -56,23 +56,8 @@ torch.randint_like = fix_func(torch.randint_like)
 torch.bernoulli = fix_func(torch.bernoulli)
 torch.multinomial = fix_func(torch.multinomial)
 
-def fix_func(orig):
-    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
-        def new_func(*args, **kw):
-            device = kw.get("device", "mps")
-            kw["device"]="cpu"
-            return orig(*args, **kw).to(device)
-        return new_func
-    return orig
-
-torch.rand = fix_func(torch.rand)
-torch.rand_like = fix_func(torch.rand_like)
-torch.randn = fix_func(torch.randn)
-torch.randn_like = fix_func(torch.randn_like)
-torch.randint = fix_func(torch.randint)
-torch.randint_like = fix_func(torch.randint_like)
-torch.bernoulli = fix_func(torch.bernoulli)
-torch.multinomial = fix_func(torch.multinomial)
+# this is fallback model in case no default is defined
+FALLBACK_MODEL_NAME='stable-diffusion-1.4'
 
 """Simplified text to image API for stable diffusion/latent diffusion
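The deleted block above is the fix_func workaround for Apple's MPS backend: torch's random-sampling ops are patched to run on the CPU and then move their result back to the requested device, sidestepping ops that are unsupported or nondeterministic on MPS. The wrapper itself survives the merge (the hunk's context line shows it now sits earlier in the file); a self-contained sketch of the pattern, with names taken from the diff:

    import torch

    def fix_func(orig):
        # Patch only when an Apple-silicon MPS device is actually available.
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            def new_func(*args, **kw):
                # Remember the device the caller wanted (default mps),
                # run the op on the CPU, then move the result back.
                device = kw.get('device', 'mps')
                kw['device'] = 'cpu'
                return orig(*args, **kw).to(device)
            return new_func
        return orig

    torch.randn = fix_func(torch.randn)   # likewise for rand, randint, bernoulli, ...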
@@ -126,12 +111,13 @@ still work.
 The full list of arguments to Generate() are:
 gr = Generate(
           # these values are set once and shouldn't be changed
-          conf = path to configuration file ('configs/models.yaml')
-          model = symbolic name of the model in the configuration file
-          precision = float precision to be used
+          conf:str = path to configuration file ('configs/models.yaml')
+          model:str = symbolic name of the model in the configuration file
+          precision:float = float precision to be used
+          safety_checker:bool = activate safety checker [False]
 
           # this value is sticky and maintained between generation calls
-          sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms
+          sampler_name:str = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms
 
           # these are deprecated - use conf and model instead
           weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
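The docstring hunk annotates the documented constructor arguments with types and adds the new safety_checker flag. A hedged usage sketch, assuming Generate accepts the keywords exactly as documented:

    from ldm.generate import Generate

    gr = Generate(
        conf='configs/models.yaml',   # model registry; its default entry is used when model=None
        sampler_name='k_lms',         # sticky: kept across generation calls
        safety_checker=False,         # documented default
    )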
@@ -148,7 +134,7 @@ class Generate:
 
     def __init__(
             self,
-            model = 'stable-diffusion-1.4',
+            model = None,
             conf = 'configs/models.yaml',
             embedding_path = None,
             sampler_name = 'k_lms',
@@ -164,7 +150,6 @@ class Generate:
             free_gpu_mem=False,
     ):
         mconfig = OmegaConf.load(conf)
-        self.model_name = model
         self.height = None
         self.width = None
         self.model_cache = None
@@ -211,6 +196,7 @@ class Generate:
 
         # model caching system for fast switching
         self.model_cache = ModelCache(mconfig,self.device,self.precision)
+        self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
 
         # for VRAM usage statistics
         self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
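Taken together, the three __init__ hunks replace the hard-coded 'stable-diffusion-1.4' default with a resolution chain: the caller's explicit model argument wins, then the default flagged in models.yaml (via ModelCache.default_model()), then FALLBACK_MODEL_NAME. The one-liner relies on `or` skipping None and empty values; as an illustrative helper (resolve_model_name is not part of the diff):

    FALLBACK_MODEL_NAME = 'stable-diffusion-1.4'

    def resolve_model_name(requested, model_cache):
        # 1. explicit argument, 2. default from models.yaml,
        # 3. hard-coded fallback; `or` passes over None/empty names.
        return requested or model_cache.default_model() or FALLBACK_MODEL_NAME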
@@ -287,6 +273,8 @@ class Generate:
             upscale = None,
             # this is specific to inpainting and causes more extreme inpainting
             inpaint_replace = 0.0,
+            # This will help match inpainted areas to the original image more smoothly
+            mask_blur_radius: int = 8,
             # Set this True to handle KeyboardInterrupt internally
             catch_interrupts = False,
             hires_fix = False,
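mask_blur_radius feathers the boundary between inpainted and original pixels so the patch does not end at a hard edge. This hunk only declares the parameter (it is threaded through to the generator below); a sketch of the underlying idea using Pillow's GaussianBlur filter:

    from PIL import Image, ImageFilter

    def soften_mask(mask: Image.Image, radius: int = 8) -> Image.Image:
        # Blurring the binary mask produces a soft alpha ramp, letting the
        # inpainted region fade into the surrounding image.
        return mask.filter(ImageFilter.GaussianBlur(radius=radius))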
@@ -407,7 +395,7 @@ class Generate:
                 log_tokens =self.log_tokenization
             )
 
-            init_image,mask_image = self._make_images(
+            init_image, mask_image = self._make_images(
                 init_img,
                 init_mask,
                 width,
@@ -454,6 +442,7 @@ class Generate:
                 embiggen=embiggen,
                 embiggen_tiles=embiggen_tiles,
                 inpaint_replace=inpaint_replace,
+                mask_blur_radius=mask_blur_radius
             )
 
         if init_color:
@@ -572,16 +561,19 @@ class Generate:
             from ldm.invoke.restoration.outcrop import Outcrop
             extend_instructions = {}
             for direction,pixels in _pairwise(opt.outcrop):
-                extend_instructions[direction]=int(pixels)
-
-            restorer = Outcrop(image,self,)
-            return restorer.process (
-                extend_instructions,
-                opt = opt,
-                orig_opt = args,
-                image_callback = callback,
-                prefix = prefix,
-            )
+                try:
+                    extend_instructions[direction]=int(pixels)
+                except ValueError:
+                    print(f'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"')
+            if len(extend_instructions)>0:
+                restorer = Outcrop(image,self,)
+                return restorer.process (
+                    extend_instructions,
+                    opt = opt,
+                    orig_opt = args,
+                    image_callback = callback,
+                    prefix = prefix,
+                )
 
         elif tool == 'embiggen':
             # fetch the metadata from the image
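The outcrop hunk hardens parsing: a non-numeric pixel count now prints a usage hint instead of raising ValueError, and Outcrop only runs when at least one valid instruction was collected. _pairwise walks the flat opt.outcrop list two items at a time; a common implementation of such a helper (assumed, since its definition lies outside this diff):

    def _pairwise(iterable):
        # ['top', '64', 'left', '128'] -> ('top', '64'), ('left', '128')
        it = iter(iterable)
        return zip(it, it)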
@@ -645,23 +637,22 @@ class Generate:
         # if image has a transparent area and no mask was provided, then try to generate mask
         if self._has_transparency(image):
             self._transparency_check_and_warning(image, mask)
-            # this returns a torch tensor
             init_mask = self._create_init_mask(image, width, height, fit=fit)
 
         if (image.width * image.height) > (self.width * self.height) and self.size_matters:
             print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
             self.size_matters = False
 
-        init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
+        init_image = self._create_init_image(image,width,height,fit=fit)
 
         if mask:
-            mask_image = self._load_img(mask) # this returns an Image
+            mask_image = self._load_img(mask)
             init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
 
         elif text_mask:
             init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
 
-        return init_image, init_mask
+        return init_image,init_mask
 
     def _make_base(self):
         if not self.generators.get('base'):
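The auto-masking branch above fires only when _has_transparency() finds transparency in the init image, in which case the alpha channel doubles as the inpainting mask. A generic check of that kind (illustrative; the project's own test on image.info appears in the final hunk below):

    from PIL import Image

    def has_transparency(image: Image.Image) -> bool:
        # Palette images record transparency in image.info;
        # RGBA/LA images carry a real alpha band.
        if image.info.get('transparency') is not None:
            return True
        return image.mode in ('RGBA', 'LA')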
@@ -717,8 +708,7 @@ class Generate:
 
         model_data = self.model_cache.get_model(model_name)
         if model_data is None or len(model_data) == 0:
             print(f'** Model switch failed **')
-            return self.model
+            return None
 
         self.model = model_data['model']
         self.width = model_data['width']
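Returning None on a failed switch, instead of silently handing back the previously loaded model, lets callers tell the two outcomes apart. A hedged caller-side sketch (assuming this hunk sits in the model-switching method, here called set_model):

    model = gr.set_model('some-other-model')
    if model is None:
        # The switch failed; the previously loaded model is still active.
        print('>> keeping current model')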
@@ -879,46 +869,31 @@ class Generate:
 
     def _create_init_image(self, image, width, height, fit=True):
         image = image.convert('RGB')
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        image = 2.0 * image - 1.0
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image
 
     def _create_init_mask(self, image, width, height, fit=True):
         # convert into a black/white mask
         image = self._image_to_mask(image)
         image = image.convert('RGB')
-
-        # now we adjust the size
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = image.resize((image.width//downsampling, image.height //
-                              downsampling), resample=Image.Resampling.NEAREST)
-        image = np.array(image)
-        image = image.astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image
 
     # The mask is expected to have the region to be inpainted
     # with alpha transparency. It converts it into a black/white
     # image with the transparent part black.
-    def _image_to_mask(self, mask_image, invert=False) -> Image:
-        # Obtain the mask from the transparency channel
-        mask = Image.new(mode="L", size=mask_image.size, color=255)
-        mask.putdata(mask_image.getdata(band=3))
+    def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image:
+        if mask_image.mode == 'L':
+            mask = mask_image
+        else:
+            # Obtain the mask from the transparency channel
+            mask = Image.new(mode="L", size=mask_image.size, color=255)
+            mask.putdata(mask_image.getdata(band=3))
         if invert:
             mask = ImageOps.invert(mask)
         return mask
 
     # TODO: The latter part of this method repeats code from _create_init_mask()
     def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
         prompt = text_mask[0]
         confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
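After this refactor, _create_init_image() and _create_init_mask() return plain PIL images; the tensor conversion the deleted lines performed, scaling to [0, 1], reshaping to NCHW, then mapping to [-1, 1], now happens further down the pipeline. For reference, that conversion in isolation (pil_to_latent_input is an illustrative name, not a function in the diff):

    import numpy as np
    import torch
    from PIL import Image

    def pil_to_latent_input(image: Image.Image, device: str = 'cpu') -> torch.Tensor:
        arr = np.array(image).astype(np.float32) / 255.0   # HWC uint8 -> float in [0, 1]
        arr = arr[None].transpose(0, 3, 1, 2)              # HWC -> NCHW with batch dim
        tensor = torch.from_numpy(arr)
        return (2.0 * tensor - 1.0).to(device)             # [0, 1] -> [-1, 1]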
@@ -928,18 +903,8 @@ class Generate:
         segmented = self.txt2mask.segment(image, prompt)
         mask = segmented.to_mask(float(confidence_level))
         mask = mask.convert('RGB')
-        # now we adjust the size
-        if fit:
-            mask = self._fit_image(mask, (width, height))
-        else:
-            mask = self._squeeze_image(mask)
-        mask = mask.resize((mask.width//downsampling, mask.height //
-                            downsampling), resample=Image.Resampling.NEAREST)
-        mask = np.array(mask)
-        mask = mask.astype(np.float32) / 255.0
-        mask = mask[None].transpose(0, 3, 1, 2)
-        mask = torch.from_numpy(mask)
-        return mask.to(self.device)
+        mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask)
+        return mask
 
     def _has_transparency(self, image):
         if image.info.get("transparency", None) is not None:
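_txt2mask() delegates the hard work: segment() scores every pixel against the text prompt and to_mask(confidence_level) thresholds those scores into a black/white mask, which this hunk now returns as a PIL image rather than a tensor. A generic sketch of the thresholding step (a numpy stand-in, not the project's segmenter API):

    import numpy as np
    from PIL import Image

    def scores_to_mask(scores: np.ndarray, confidence_level: float = 0.5) -> Image.Image:
        # scores: 2-D float array in [0, 1]; pixels at or above the
        # threshold become white (selected), the rest black.
        binary = (scores >= confidence_level).astype(np.uint8) * 255
        return Image.fromarray(binary, mode='L')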