Merge branch 'development' into inpaint-model

This commit is contained in:
Lincoln Stein
2022-10-25 11:50:08 -04:00
committed by GitHub
24 changed files with 902 additions and 796 deletions

View File

@ -133,20 +133,21 @@ class Generate:
def __init__(
self,
model = None,
conf = 'configs/models.yaml',
embedding_path = None,
sampler_name = 'k_lms',
ddim_eta = 0.0, # deterministic
full_precision = False,
precision = 'auto',
# these are deprecated; if present they override values in the conf file
weights = None,
config = None,
model = None,
conf = 'configs/models.yaml',
embedding_path = None,
sampler_name = 'k_lms',
ddim_eta = 0.0, # deterministic
full_precision = False,
precision = 'auto',
gfpgan=None,
codeformer=None,
esrgan=None,
free_gpu_mem=False,
safety_checker:bool=False,
# these are deprecated; if present they override values in the conf file
weights = None,
config = None,
):
mconfig = OmegaConf.load(conf)
self.height = None
@ -177,6 +178,7 @@ class Generate:
self.free_gpu_mem = free_gpu_mem
self.size_matters = True # used to warn once about large image sizes and VRAM
self.txt2mask = None
self.safety_checker = None
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
@ -204,6 +206,19 @@ class Generate:
# gets rid of annoying messages about random seed
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
# load safety checker if requested
if safety_checker:
try:
print('>> Initializing safety checker')
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
safety_model_id = "CompVis/stable-diffusion-safety-checker"
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
except Exception:
print('** An error was encountered while installing the safety checker:')
print(traceback.format_exc())
def prompt2png(self, prompt, outdir, **kwargs):
"""
Takes a prompt and an output directory, writes out the requested number
@ -277,6 +292,7 @@ class Generate:
# Set this True to handle KeyboardInterrupt internally
catch_interrupts = False,
hires_fix = False,
use_mps_noise = False,
**args,
): # eat up additional cruft
"""
@ -421,6 +437,12 @@ class Generate:
generator.set_variation(
self.seed, variation_amount, with_variations
)
generator.use_mps_noise = use_mps_noise
checker = {
'checker':self.safety_checker,
'extractor':self.safety_feature_extractor
} if self.safety_checker else None
results = generator.generate(
prompt,
@ -432,10 +454,10 @@ class Generate:
conditioning=(uc, c),
ddim_eta=ddim_eta,
image_callback=image_callback, # called after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
step_callback=step_callback, # called after each intermediate image is generated
width=width,
height=height,
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
init_image=init_image, # notice that init_image is different from init_img
mask_image=mask_image,
strength=strength,
@ -444,7 +466,8 @@ class Generate:
embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
inpaint_replace=inpaint_replace,
mask_blur_radius=mask_blur_radius
mask_blur_radius=mask_blur_radius,
safety_checker=checker
)
if init_color: