From 3043af46204907dd05f0057f423c6f19afffdbc7 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Fri, 23 Jun 2023 13:56:30 -0400
Subject: [PATCH] implement vae passthru

---
 invokeai/backend/install/migrate_to_3.py | 169 ++++++++++++++++++-----
 1 file changed, 136 insertions(+), 33 deletions(-)

diff --git a/invokeai/backend/install/migrate_to_3.py b/invokeai/backend/install/migrate_to_3.py
index 171c86f7e0..5e9a194125 100644
--- a/invokeai/backend/install/migrate_to_3.py
+++ b/invokeai/backend/install/migrate_to_3.py
@@ -15,7 +15,9 @@ import warnings
 
 from dataclasses import dataclass
 from pathlib import Path
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, DictConfig
+from typing import Union
+
 from diffusers import StableDiffusionPipeline, AutoencoderKL
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from transformers import (
@@ -104,6 +106,9 @@ class MigrateTo3(object):
         '''
         copy a single file with logging
         '''
+        if dest.exists():
+            logger.info(f'Skipping existing {str(dest)}')
+            return
         logger.info(f'Copying {str(src)} to {str(dest)}')
         try:
             shutil.copy(src, dest)
@@ -115,6 +120,10 @@ class MigrateTo3(object):
         '''
         Recursively copy a directory with logging
         '''
+        if dest.exists():
+            logger.info(f'Skipping existing {str(dest)}')
+            return
+
         logger.info(f'Copying {str(src)} to {str(dest)}')
         try:
             shutil.copytree(src, dest)
@@ -127,7 +136,6 @@ class MigrateTo3(object):
         that looks like a model, and copy the model into the
         appropriate location within the destination models directory.
         '''
-        dest_dir = self.dest_models
         for root, dirs, files in os.walk(src_dir):
             for f in files:
                 # hack - don't copy raw learned_embeds.bin, let them
@@ -139,7 +147,7 @@ class MigrateTo3(object):
                     info = ModelProbe().heuristic_probe(model)
                     if not info:
                         continue
-                    dest = Path(dest_dir, info.base_type.value, info.model_type.value, f)
+                    dest = self._model_probe_to_path(info) / f
                     self.copy_file(model, dest)
                 except KeyboardInterrupt:
                     raise
@@ -151,14 +159,13 @@ class MigrateTo3(object):
                     info = ModelProbe().heuristic_probe(model)
                     if not info:
                         continue
-                    dest = Path(dest_dir, info.base_type.value, info.model_type.value, model.name)
+                    dest = self._model_probe_to_path(info) / model.name
                     self.copy_dir(model, dest)
                 except KeyboardInterrupt:
                     raise
                 except Exception as e:
                     logger.error(str(e))
 
-    # TO DO: Rewrite this to support alternate locations for esrgan and gfpgan in init file
     def migrate_support_models(self):
         '''
         Copy the clipseg, upscaler, and restoration models to their new
@@ -203,39 +210,53 @@ class MigrateTo3(object):
             logger.info('Migrating core tokenizers and text encoders')
             target_dir = dest_directory / 'core' / 'convert'
 
-            # bert
-            bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
-            bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
+            self._migrate_pretrained(BertTokenizerFast,
+                                     repo_id='bert-base-uncased',
+                                     dest = target_dir / 'bert-base-uncased',
+                                     **kwargs)
 
             # sd-1
             repo_id = 'openai/clip-vit-large-patch14'
-            pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
-            pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14' / 'tokenizer', safe_serialization=True)
-
-            pipeline = CLIPTextModel.from_pretrained(repo_id, **kwargs)
-            pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14' / 'text_encoder', safe_serialization=True)
+            self._migrate_pretrained(CLIPTokenizer,
+                                     repo_id= repo_id,
+                                     dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer',
+                                     **kwargs)
+            self._migrate_pretrained(CLIPTextModel,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder',
+                                     **kwargs)
 
             # sd-2
             repo_id = "stabilityai/stable-diffusion-2"
-            pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
-            pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
-
-            pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
-            pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
+            self._migrate_pretrained(CLIPTokenizer,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'stable-diffusion-2-clip' / 'tokenizer',
+                                     **{'subfolder':'tokenizer',**kwargs}
+                                     )
+            self._migrate_pretrained(CLIPTextModel,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'stable-diffusion-2-clip' / 'text_encoder',
+                                     **{'subfolder':'text_encoder',**kwargs}
+                                     )
 
             # VAE
             logger.info('Migrating stable diffusion VAE')
-            vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
-            vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
-
+            self._migrate_pretrained(AutoencoderKL,
+                                     repo_id = 'stabilityai/sd-vae-ft-mse',
+                                     dest = target_dir / 'sd-vae-ft-mse',
+                                     **kwargs)
+
             # safety checking
             logger.info('Migrating safety checker')
             repo_id = "CompVis/stable-diffusion-safety-checker"
-            pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
-            pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
-
-            pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
-            pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
+            self._migrate_pretrained(AutoFeatureExtractor,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'stable-diffusion-safety-checker',
+                                     **kwargs)
+            self._migrate_pretrained(StableDiffusionSafetyChecker,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'stable-diffusion-safety-checker',
+                                     **kwargs)
         except KeyboardInterrupt:
             raise
         except Exception as e:
@@ -262,8 +283,72 @@ class MigrateTo3(object):
         }
         self.dest_yaml.write(yaml.dump(stanza))
         self.dest_yaml.flush()
+
+    def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
+        return Path(self.dest_models, info.base_type.value, info.model_type.value)
 
-    def migrate_repo_id(self, repo_id: str, model_name :str=None):
+    def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs):
+        if dest.exists():
+            logger.info(f'Skipping existing {dest}')
+            return
+        model = model_class.from_pretrained(repo_id, **kwargs)
+        self._save_pretrained(model, dest)
+
+    def _save_pretrained(self, model, dest: Path):
+        if dest.exists():
+            logger.info(f'Skipping existing {dest}')
+            return
+        model_name = dest.name
+        download_path = dest.with_name(f'{model_name}.downloading')
+        model.save_pretrained(download_path, safe_serialization=True)
+        download_path.replace(dest)
+
+    def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
+        vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
+        info = ModelProbe().heuristic_probe(vae)
+        _, model_name = repo_id.split('/')
+        dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
+        vae.save_pretrained(dest, safe_serialization=True)
+        return dest
+
+    def _vae_path(self, vae: Union[str,dict])->Path:
+        '''
+        Convert 2.3 VAE stanza to a straight path.
+        '''
+        vae_path = None
+
+        # First get a path
+        if isinstance(vae,str):
+            vae_path = vae
+
+        elif isinstance(vae,DictConfig):
+            if p := vae.get('path'):
+                vae_path = p
+            elif repo_id := vae.get('repo_id'):
+                if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded
+                    vae_path = 'models/core/convert/sd-vae-ft-mse'
+                else:
+                    vae_path = self._download_vae(repo_id, vae.get('subfolder'))
+
+        assert vae_path is not None, "Couldn't find VAE for this model"
+
+        # if the VAE is in the old models directory, then we must move it into the new
+        # one. VAEs outside of this directory can stay where they are.
+        vae_path = Path(vae_path)
+        if vae_path.is_relative_to(self.src_paths.models):
+            info = ModelProbe().heuristic_probe(vae_path)
+            dest = self._model_probe_to_path(info) / vae_path.name
+            if not dest.exists():
+                self.copy_dir(vae_path,dest)
+            vae_path = dest
+
+        if vae_path.is_relative_to(self.dest_models):
+            rel_path = vae_path.relative_to(self.dest_models)
+            return Path('models',rel_path)
+        else:
+            return vae_path
+
+    def migrate_repo_id(self, repo_id: str, model_name :str=None, **extra_config):
         '''
         Migrate a locally-cached diffusers pipeline identified with a repo_id
         '''
@@ -295,10 +380,11 @@ class MigrateTo3(object):
         if not info:
             return
 
-        dest = Path(dest_dir, info.base_type.value, info.model_type.value, f'{repo_name}')
-        pipeline.save_pretrained(dest, safe_serialization=True)
+        dest = self._model_probe_to_path(info) / repo_name
+        self._save_pretrained(pipeline, dest)
+
         rel_path = Path('models',dest.relative_to(dest_dir))
-        self.write_yaml(model_name, path=rel_path, info=info)
+        self.write_yaml(model_name, path=rel_path, info=info, **extra_config)
 
     def migrate_path(self, location: Path, model_name: str=None, **extra_config):
         '''
@@ -332,16 +418,29 @@ class MigrateTo3(object):
 
         for model_name, stanza in conf.items():
             try:
+                passthru_args = {}
+
+                if vae := stanza.get('vae'):
+                    try:
+                        passthru_args['vae'] = str(self._vae_path(vae))
+                    except Exception as e:
+                        logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
+                        logger.warning(str(e))
+
+                if config := stanza.get('config'):
+                    passthru_args['config'] = config
+
                 if repo_id := stanza.get('repo_id'):
                     logger.info(f'Migrating diffusers model {model_name}')
-                    self.migrate_repo_id(repo_id, model_name)
+                    self.migrate_repo_id(repo_id, model_name, **passthru_args)
 
                 elif location := stanza.get('weights'):
                     logger.info(f'Migrating checkpoint model {model_name}')
-                    self.migrate_path(Path(location), model_name, config=stanza.get('config'))
+                    self.migrate_path(Path(location), model_name, **passthru_args)
+
                 elif location := stanza.get('path'):
                     logger.info(f'Migrating diffusers model {model_name}')
-                    self.migrate_path(Path(location), model_name, config=stanza.get('config'))
+                    self.migrate_path(Path(location), model_name, **passthru_args)
 
             except KeyboardInterrupt:
                 raise
@@ -424,6 +523,7 @@ def do_migrate(src_directory: Path, dest_directory: Path):
                               )
         migrator.migrate()
 
+    shutil.rmtree(dest_directory / 'models.orig', ignore_errors=True)
     (dest_directory / 'models').replace(dest_directory / 'models.orig')
     dest_models.replace(dest_directory / 'models')
 
@@ -456,6 +556,7 @@ script, which will perform a full upgrade in place."""
                         required=True,
                         help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")'
                         )
+# TO DO: Implement full directory scanning
 #    parser.add_argument('--all-models',
 #                        action="store_true",
 #                        help='Migrate all models found in `models` directory, not just those mentioned in models.yaml',
@@ -476,3 +577,5 @@ script, which will perform a full upgrade in place."""
 
 if __name__ == '__main__':
     main()
+
+