diff --git a/ldm/dream/args.py b/ldm/dream/args.py index db6d963645..ada8975e96 100644 --- a/ldm/dream/args.py +++ b/ldm/dream/args.py @@ -174,31 +174,37 @@ class Args(object): switches.append(f'-W {a["width"]}') switches.append(f'-H {a["height"]}') switches.append(f'-C {a["cfg_scale"]}') - switches.append(f'-A {a["sampler_name"]}') if a['grid']: switches.append('--grid') if a['seamless']: switches.append('--seamless') + + # img2img generations have parameters relevant only to them and have special handling if a['init_img'] and len(a['init_img'])>0: switches.append(f'-I {a["init_img"]}') - if a['init_mask'] and len(a['init_mask'])>0: - switches.append(f'-M {a["init_mask"]}') - if a['init_color'] and len(a['init_color'])>0: - switches.append(f'--init_color {a["init_color"]}') - if a['fit']: - switches.append(f'--fit') - if a['init_img'] and a['strength'] and a['strength']>0: - switches.append(f'-f {a["strength"]}') + switches.append(f'-A ddim') # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS + if a['fit']: + switches.append(f'--fit') + if a['init_mask'] and len(a['init_mask'])>0: + switches.append(f'-M {a["init_mask"]}') + if a['init_color'] and len(a['init_color'])>0: + switches.append(f'--init_color {a["init_color"]}') + if a['strength'] and a['strength']>0: + switches.append(f'-f {a["strength"]}') + else: + switches.append(f'-A {a["sampler_name"]}') + + # gfpgan-specific parameters if a['gfpgan_strength']: switches.append(f'-G {a["gfpgan_strength"]}') + + # esrgan-specific parameters if a['upscale']: switches.append(f'-U {" ".join([str(u) for u in a["upscale"]])}') if a['embiggen']: switches.append(f'--embiggen {" ".join([str(u) for u in a["embiggen"]])}') if a['embiggen_tiles']: switches.append(f'--embiggen_tiles {" ".join([str(u) for u in a["embiggen_tiles"]])}') - if a['variation_amount'] > 0: - switches.append(f'-v {a["variation_amount"]}') if a['with_variations']: formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["with_variations"])) switches.append(f'-V {formatted_variations}') @@ -618,18 +624,24 @@ def metadata_dumps(opt, postprocessing=postprocessing ) - # TODO: This is just a hack until postprocessing pipeline work completed - image_dict['postprocessing'] = [] - if image_dict['gfpgan_strength'] and image_dict['gfpgan_strength'] > 0: - image_dict['postprocessing'].append('GFPGAN (not RFC compliant)') - if image_dict['upscale'] and image_dict['upscale'][0] > 0: - image_dict['postprocessing'].append('ESRGAN (not RFC compliant)') + # 'postprocessing' is either null or an array of postprocessing metadatal + if postprocessing: + # TODO: This is just a hack until postprocessing pipeline work completed + image_dict['postprocessing'] = [] + + if image_dict['gfpgan_strength'] and image_dict['gfpgan_strength'] > 0: + image_dict['postprocessing'].append('GFPGAN (not RFC compliant)') + if image_dict['upscale'] and image_dict['upscale'][0] > 0: + image_dict['postprocessing'].append('ESRGAN (not RFC compliant)') + else: + image_dict['postprocessing'] = None # remove any image keys not mentioned in RFC #266 rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps', 'cfg_scale','step_number','width','height','extra','strength'] rfc_dict ={} + for item in image_dict.items(): key,value = item if key in rfc266_img_fields: @@ -637,25 +649,24 @@ def metadata_dumps(opt, # semantic drift rfc_dict['sampler'] = image_dict.get('sampler_name',None) - + # display weighted subprompts (liable to change) if opt.prompt: subprompts = split_weighted_subprompts(opt.prompt) subprompts = [{'prompt':x[0],'weight':x[1]} for x in subprompts] rfc_dict['prompt'] = subprompts - # variations - if opt.with_variations: - variations = [{'seed':x[0],'weight':x[1]} for x in opt.with_variations] - rfc_dict['variations'] = variations + # 'variations' should always exist and be an array, empty or consisting of {'seed': seed, 'weight': weight} pairs + rfc_dict['variations'] = [{'seed':x[0],'weight':x[1]} for x in opt.with_variations] if opt.with_variations else [] if opt.init_img: rfc_dict['type'] = 'img2img' rfc_dict['strength_steps'] = rfc_dict.pop('strength') rfc_dict['orig_hash'] = calculate_init_img_hash(opt.init_img) - rfc_dict['sampler'] = 'ddim' # FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS + rfc_dict['sampler'] = 'ddim' # TODO: FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS else: rfc_dict['type'] = 'txt2img' + rfc_dict.pop('strength') if len(seeds)==0 and opt.seed: seeds=[seed]