diff --git a/scripts/dream.py b/scripts/dream.py index 85b2ed5211..fec0475724 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -13,11 +13,13 @@ import ldm.dream.readline from ldm.dream.pngwriter import PngWriter, PromptFormatter from ldm.dream.server import DreamServer, ThreadingDreamServer from ldm.dream.image_util import make_grid +from omegaconf import OmegaConf def main(): """Initialize command-line parsers and the diffusion model""" arg_parser = create_argv_parser() opt = arg_parser.parse_args() + """ if opt.laion400m: # defaults suitable to the older latent diffusion weights width = 256 @@ -33,6 +35,15 @@ def main(): weights = opt.weights else: weights = f'models/ldm/stable-diffusion-v1/{opt.weights}.ckpt' + """ + try: + models = OmegaConf.load('configs/models.yaml') + width = models[opt.model].width + height = models[opt.model].height + config = models[opt.model].config + weights = models[opt.model].weights + except (FileNotFoundError, IOError, KeyError) as e: + print(f'{e}. Aborting.') print('* Initializing, be patient...\n') sys.path.append('.') @@ -426,6 +437,11 @@ def create_argv_parser(): default='model', help='Indicates the Stable Diffusion model to use.', ) + parser.add_argument( + '--model', + default='stable-diffusion-1.4', + help='Indicates which Diffusion model to load.', + ) return parser @@ -540,4 +556,4 @@ def create_cmd_parser(): if __name__ == '__main__': - main() + main() \ No newline at end of file