Merge branch 'main' into bugfix/dreambooth_ema

This commit is contained in:
Lincoln Stein
2023-03-23 23:24:15 -04:00
committed by GitHub
72 changed files with 1060 additions and 790 deletions

View File

@ -490,7 +490,7 @@ class Args(object):
"-z",
type=int,
default=6,
choices=range(0, 9),
choices=range(0, 10),
dest="png_compression",
help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.",
)
@ -943,7 +943,6 @@ class Args(object):
"--png_compression",
"-z",
type=int,
default=6,
choices=range(0, 10),
dest="png_compression",
help="level of PNG compression, from 0 (none) to 9 (maximum). [6]",

View File

@ -58,7 +58,7 @@ class InvokeAIGeneratorOutput:
'''
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
operation, including the image, its seed, the model name used to generate the image
and the model hash, as well as all the generate() parameters that went into
and the model hash, as well as all the generate() parameters that went into
generating the image (in .params, also available as attributes)
'''
image: Image
@ -116,7 +116,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
outputs = txt2img.generate(prompt='banana sushi', iterations=None)
for o in outputs:
print(o.image, o.seed)
'''
generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args)
@ -154,6 +154,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
for i in iteration_count:
results = generator.generate(prompt,
conditioning=(uc, c, extra_conditioning_info),
step_callback=step_callback,
sampler=scheduler,
**generator_args,
)
@ -167,7 +168,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
if callback:
callback(output)
yield output
@classmethod
def schedulers(self)->List[str]:
'''
@ -177,7 +178,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
return generator_class(model, self.params.precision)
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
scheduler = scheduler_class.from_config(model.scheduler.config)
@ -267,12 +268,12 @@ class Embiggen(Txt2Img):
embiggen_tiles=embiggen_tiles,
strength=strength,
**kwargs)
@classmethod
def _generator_class(cls):
from .embiggen import Embiggen
return Embiggen
class Generator:
downsampling_factor: int
@ -347,7 +348,6 @@ class Generator:
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct,
attention_maps_callback=attention_maps_callback,
seed=seed,
**kwargs,
)
results = []
@ -375,7 +375,8 @@ class Generator:
print("** An error occurred while getting initial noise **")
print(traceback.format_exc())
image = make_image(x_T)
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
image = make_image(x_T, seed)
if self.safety_checker is not None:
image = self.safety_checker.check(image)
@ -497,7 +498,8 @@ class Generator:
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
return matched_result
def sample_to_lowres_estimated_image(self, samples):
@staticmethod
def sample_to_lowres_estimated_image(samples):
# origingally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7

View File

@ -37,7 +37,6 @@ class Img2Img(Generator):
h_symmetry_time_pct=None,
v_symmetry_time_pct=None,
attention_maps_callback=None,
seed=None,
**kwargs,
):
"""
@ -64,7 +63,7 @@ class Img2Img(Generator):
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T):
def make_image(x_T: torch.Tensor, seed: int):
# FIXME: use x_T for initial seeded noise
# We're not at the moment because the pipeline automatically resizes init_image if
# necessary, which the x_T input might not match.
@ -77,7 +76,7 @@ class Img2Img(Generator):
conditioning_data,
noise_func=self.get_noise_like,
callback=step_callback,
seed=seed
seed=seed,
)
if (
pipeline_output.attention_map_saver is not None
@ -88,9 +87,7 @@ class Img2Img(Generator):
return make_image
def get_noise_like(self, like: torch.Tensor, seed: Optional[int]):
if seed is not None:
set_seed(seed)
def get_noise_like(self, like: torch.Tensor):
device = like.device
if device.type == "mps":
x = torch.randn_like(like, device="cpu").to(device)

View File

@ -159,6 +159,7 @@ class Inpaint(Img2Img):
seam_size: int,
seam_blur: int,
prompt,
seed,
sampler,
steps,
cfg_scale,
@ -192,7 +193,7 @@ class Inpaint(Img2Img):
seam_noise = self.get_noise(im.width, im.height)
result = make_image(seam_noise)
result = make_image(seam_noise, seed)
return result
@ -223,7 +224,6 @@ class Inpaint(Img2Img):
inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
attention_maps_callback=None,
seed=None,
**kwargs,
):
"""
@ -311,7 +311,7 @@ class Inpaint(Img2Img):
uc, c, cfg_scale
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T):
def make_image(x_T: torch.Tensor, seed: int):
pipeline_output = pipeline.inpaint_from_embeddings(
init_image=init_image,
mask=1 - mask, # expects white means "paint here."
@ -320,7 +320,7 @@ class Inpaint(Img2Img):
conditioning_data=conditioning_data,
noise_func=self.get_noise_like,
callback=step_callback,
seed=seed
seed=seed,
)
if (
@ -343,6 +343,7 @@ class Inpaint(Img2Img):
seam_size,
seam_blur,
prompt,
seed,
sampler,
seam_steps,
cfg_scale,

View File

@ -61,7 +61,7 @@ class Txt2Img(Generator):
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T) -> PIL.Image.Image:
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
pipeline_output = pipeline.image_from_embeddings(
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
noise=x_T,

View File

@ -64,7 +64,7 @@ class Txt2Img2Img(Generator):
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T):
def make_image(x_T: torch.Tensor, _: int):
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
latents=torch.zeros_like(x_T),
num_inference_steps=steps,

View File

@ -1085,9 +1085,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
dlogging.set_verbosity_error()
checkpoint = (
load_file(checkpoint_path)
if Path(checkpoint_path).suffix == ".safetensors"
else torch.load(checkpoint_path)
torch.load(checkpoint_path)
if Path(checkpoint_path).suffix == ".ckpt"
else load_file(checkpoint_path)
)
cache_dir = global_cache_dir("hub")
pipeline_class = (

View File

@ -97,7 +97,7 @@ class ModelManager(object):
If on disk, will load from there.
"""
if not model_name:
return self.current_model if self.current_model else self.get_model(self.default_model())
return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
if not self.valid_model(model_name):
print(
@ -362,6 +362,7 @@ class ModelManager(object):
raise NotImplementedError(
f"Unknown model format {model_name}: {model_format}"
)
self._add_embeddings_to_model(model)
# usage statistics
toc = time.time()
@ -436,7 +437,6 @@ class ModelManager(object):
height = width
print(f" | Default image dimensions = {width} x {height}")
self._add_embeddings_to_model(pipeline)
return pipeline, width, height, model_hash
@ -732,9 +732,9 @@ class ModelManager(object):
# another round of heuristics to guess the correct config file.
checkpoint = (
safetensors.torch.load_file(model_path)
if model_path.suffix == ".safetensors"
else torch.load(model_path)
torch.load(model_path)
if model_path.suffix == ".ckpt"
else safetensors.torch.load_file(model_path)
)
# additional probing needed if no config file provided

View File

@ -6,7 +6,6 @@ The interface is through the Concepts() object.
"""
import os
import re
import traceback
from typing import Callable
from urllib import error as ul_error
from urllib import request
@ -15,7 +14,6 @@ from huggingface_hub import (
HfApi,
HfFolder,
ModelFilter,
ModelSearchArguments,
hf_hub_url,
)
@ -84,7 +82,7 @@ class HuggingFaceConceptsLibrary(object):
"""
if not concept_name in self.list_concepts():
print(
f"This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
)
return None
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
@ -236,7 +234,7 @@ class HuggingFaceConceptsLibrary(object):
except ul_error.HTTPError as e:
if e.code == 404:
print(
f"This concept is not known to the Hugging Face library. Generation will continue without the concept."
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
)
else:
print(
@ -246,7 +244,7 @@ class HuggingFaceConceptsLibrary(object):
return False
except ul_error.URLError as e:
print(
f"ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
)
os.rmdir(dest)
return False

View File

@ -9,6 +9,7 @@ from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
import einops
import PIL.Image
from accelerate.utils import set_seed
import psutil
import torch
import torchvision.transforms as T
@ -694,7 +695,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
device=self._model_group.device_for(self.unet),
dtype=self.unet.dtype,
)
noise = noise_func(initial_latents, seed)
if seed is not None:
set_seed(seed)
noise = noise_func(initial_latents)
return self.img2img_from_latents_and_embeddings(
initial_latents,
@ -796,7 +799,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
init_image_latents = self.non_noised_latents_from_image(
init_image, device=device, dtype=latents_dtype
)
noise = noise_func(init_image_latents, seed)
if seed is not None:
set_seed(seed)
noise = noise_func(init_image_latents)
if mask.dim() == 3:
mask = mask.unsqueeze(0)

View File

@ -3,6 +3,9 @@ import math
import multiprocessing as mp
import os
import re
import io
import base64
from collections import abc
from inspect import isfunction
from pathlib import Path
@ -364,3 +367,16 @@ def url_attachment_name(url: str) -> dict:
def download_with_progress_bar(url: str, dest: Path) -> bool:
result = download_with_resume(url, dest, access_token=None)
return result is not None
def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
"""
Converts an image into a base64 image dataURL.
"""
buffered = io.BytesIO()
image.save(buffered, format=image_format)
mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower())
image_base64 = f"data:{mime_type};base64," + base64.b64encode(
buffered.getvalue()
).decode("UTF-8")
return image_base64