implement vae passthru

Lincoln Stein 2023-06-23 13:56:30 -04:00
parent afd19ab61a
commit 3043af4620


@@ -15,7 +15,9 @@ import warnings
 from dataclasses import dataclass
 from pathlib import Path
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, DictConfig
+from typing import Union
 from diffusers import StableDiffusionPipeline, AutoencoderKL
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from transformers import (
@@ -104,6 +106,9 @@ class MigrateTo3(object):
         '''
         copy a single file with logging
         '''
+        if dest.exists():
+            logger.info(f'Skipping existing {str(dest)}')
+            return
         logger.info(f'Copying {str(src)} to {str(dest)}')
         try:
             shutil.copy(src, dest)
@@ -115,6 +120,10 @@ class MigrateTo3(object):
         '''
         Recursively copy a directory with logging
         '''
+        if dest.exists():
+            logger.info(f'Skipping existing {str(dest)}')
+            return
+
         logger.info(f'Copying {str(src)} to {str(dest)}')
         try:
             shutil.copytree(src, dest)
@@ -127,7 +136,6 @@ class MigrateTo3(object):
         that looks like a model, and copy the model into the
         appropriate location within the destination models directory.
         '''
-        dest_dir = self.dest_models
         for root, dirs, files in os.walk(src_dir):
             for f in files:
                 # hack - don't copy raw learned_embeds.bin, let them
@@ -139,7 +147,7 @@ class MigrateTo3(object):
                     info = ModelProbe().heuristic_probe(model)
                     if not info:
                         continue
-                    dest = Path(dest_dir, info.base_type.value, info.model_type.value, f)
+                    dest = self._model_probe_to_path(info) / f
                     self.copy_file(model, dest)
                 except KeyboardInterrupt:
                     raise
@@ -151,14 +159,13 @@ class MigrateTo3(object):
                     info = ModelProbe().heuristic_probe(model)
                     if not info:
                         continue
-                    dest = Path(dest_dir, info.base_type.value, info.model_type.value, model.name)
+                    dest = self._model_probe_to_path(info) / model.name
                     self.copy_dir(model, dest)
                 except KeyboardInterrupt:
                     raise
                 except Exception as e:
                     logger.error(str(e))
 
-    # TO DO: Rewrite this to support alternate locations for esrgan and gfpgan in init file
     def migrate_support_models(self):
         '''
         Copy the clipseg, upscaler, and restoration models to their new
@@ -203,39 +210,53 @@ class MigrateTo3(object):
             logger.info('Migrating core tokenizers and text encoders')
             target_dir = dest_directory / 'core' / 'convert'
 
-            # bert
-            bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
-            bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
+            self._migrate_pretrained(BertTokenizerFast,
+                                     repo_id='bert-base-uncased',
+                                     dest = target_dir / 'bert-base-uncased',
+                                     **kwargs)
 
             # sd-1
             repo_id = 'openai/clip-vit-large-patch14'
-            pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
-            pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14' / 'tokenizer', safe_serialization=True)
-            pipeline = CLIPTextModel.from_pretrained(repo_id, **kwargs)
-            pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14' / 'text_encoder', safe_serialization=True)
+            self._migrate_pretrained(CLIPTokenizer,
+                                     repo_id= repo_id,
+                                     dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer',
+                                     **kwargs)
+            self._migrate_pretrained(CLIPTextModel,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder',
+                                     **kwargs)
 
             # sd-2
             repo_id = "stabilityai/stable-diffusion-2"
-            pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
-            pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
-            pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
-            pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
+            self._migrate_pretrained(CLIPTokenizer,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'stable-diffusion-2-clip' / 'tokenizer',
+                                     **{'subfolder':'tokenizer',**kwargs}
+                                     )
+            self._migrate_pretrained(CLIPTextModel,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'stable-diffusion-2-clip' / 'text_encoder',
+                                     **{'subfolder':'text_encoder',**kwargs}
+                                     )
 
             # VAE
             logger.info('Migrating stable diffusion VAE')
-            vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
-            vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
+            self._migrate_pretrained(AutoencoderKL,
+                                     repo_id = 'stabilityai/sd-vae-ft-mse',
+                                     dest = target_dir / 'sd-vae-ft-mse',
+                                     **kwargs)
 
             # safety checking
             logger.info('Migrating safety checker')
             repo_id = "CompVis/stable-diffusion-safety-checker"
-            pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
-            pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
-            pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
-            pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
+            self._migrate_pretrained(AutoFeatureExtractor,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'stable-diffusion-safety-checker',
+                                     **kwargs)
+            self._migrate_pretrained(StableDiffusionSafetyChecker,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'stable-diffusion-safety-checker',
+                                     **kwargs)
         except KeyboardInterrupt:
             raise
         except Exception as e:
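Note (not part of the commit): the `**{'subfolder':'tokenizer',**kwargs}` calls above merge a per-model `subfolder` keyword into the shared download kwargs before forwarding everything to `from_pretrained()`. A minimal standalone sketch of that merge idiom; `fake_from_pretrained` and the `cache_dir` value are illustrative placeholders, not code from the commit:

# Sketch of the kwargs-merge idiom; names and values are placeholders.
def fake_from_pretrained(repo_id, **kw):
    # stand-in for a HuggingFace from_pretrained() call; just echoes its kwargs
    return repo_id, kw

kwargs = {'cache_dir': 'models/hub'}      # hypothetical shared kwargs
repo, received = fake_from_pretrained(
    'stabilityai/stable-diffusion-2',
    **{'subfolder': 'tokenizer', **kwargs},
)
assert received == {'subfolder': 'tokenizer', 'cache_dir': 'models/hub'}
# Because **kwargs is unpacked last inside the dict literal, a 'subfolder'
# key already present in kwargs would win over the literal one.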
@@ -263,7 +284,71 @@ class MigrateTo3(object):
         self.dest_yaml.write(yaml.dump(stanza))
         self.dest_yaml.flush()
 
+    def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
+        return Path(self.dest_models, info.base_type.value, info.model_type.value)
+
+    def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs):
+        if dest.exists():
+            logger.info(f'Skipping existing {dest}')
+            return
+        model = model_class.from_pretrained(repo_id, **kwargs)
+        self._save_pretrained(model, dest)
+
+    def _save_pretrained(self, model, dest: Path):
+        if dest.exists():
+            logger.info(f'Skipping existing {dest}')
+            return
+        model_name = dest.name
+        download_path = dest.with_name(f'{model_name}.downloading')
+        model.save_pretrained(download_path, safe_serialization=True)
+        download_path.replace(dest)
+
+    def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
+        vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
+        info = ModelProbe().heuristic_probe(vae)
+        _, model_name = repo_id.split('/')
+        dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
+        vae.save_pretrained(dest, safe_serialization=True)
+        return dest
+
+    def _vae_path(self, vae: Union[str,dict])->Path:
+        '''
+        Convert 2.3 VAE stanza to a straight path.
+        '''
+        vae_path = None
+
+        # First get a path
+        if isinstance(vae,str):
+            vae_path = vae
+        elif isinstance(vae,DictConfig):
+            if p := vae.get('path'):
+                vae_path = p
+            elif repo_id := vae.get('repo_id'):
+                if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded
+                    vae_path = 'models/core/convert/sd-vae-ft-mse'
+                else:
+                    vae_path = self._download_vae(repo_id, vae.get('subfolder'))
+
+        assert vae_path is not None, "Couldn't find VAE for this model"
+
+        # if the VAE is in the old models directory, then we must move it into the new
+        # one. VAEs outside of this directory can stay where they are.
+        vae_path = Path(vae_path)
+        if vae_path.is_relative_to(self.src_paths.models):
+            info = ModelProbe().heuristic_probe(vae_path)
+            dest = self._model_probe_to_path(info) / vae_path.name
+            if not dest.exists():
+                self.copy_dir(vae_path,dest)
+            vae_path = dest
+
+        if vae_path.is_relative_to(self.dest_models):
+            rel_path = vae_path.relative_to(self.dest_models)
+            return Path('models',rel_path)
+        else:
+            return vae_path
+
-    def migrate_repo_id(self, repo_id: str, model_name :str=None):
+    def migrate_repo_id(self, repo_id: str, model_name :str=None, **extra_config):
         '''
         Migrate a locally-cached diffusers pipeline identified with a repo_id
         '''
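Note (not part of the commit): `_vae_path()` above accepts the three VAE stanza shapes a 2.3 `models.yaml` can carry: a bare string path, a dict with a `path` key, or a dict with `repo_id` (and optional `subfolder`). The `stabilityai/sd-vae-ft-mse` repo_id is special-cased to the copy already converted under `models/core/convert/sd-vae-ft-mse`; any other repo_id goes through `_download_vae()`. A standalone sketch of that dispatch, with purely illustrative model names and paths:

# Sketch of the three 2.3 VAE stanza shapes handled by _vae_path();
# the names and paths below are illustrative only.
from omegaconf import OmegaConf, DictConfig

stanzas = OmegaConf.create('''
string-form:
  vae: models/ldm/stable-diffusion-v1/my-vae.ckpt
path-form:
  vae:
    path: /absolute/path/to/another-vae
repo-form:
  vae:
    repo_id: stabilityai/sd-vae-ft-mse
    subfolder: null
''')

for name, stanza in stanzas.items():
    vae = stanza.get('vae')
    if isinstance(vae, str):                   # bare string -> used as a path
        print(name, '->', vae)
    elif isinstance(vae, DictConfig):
        if p := vae.get('path'):               # explicit local path
            print(name, '->', p)
        elif repo_id := vae.get('repo_id'):    # repo to resolve or download
            print(name, '->', repo_id, vae.get('subfolder'))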
@@ -295,10 +380,11 @@ class MigrateTo3(object):
         if not info:
             return
 
-        dest = Path(dest_dir, info.base_type.value, info.model_type.value, f'{repo_name}')
-        pipeline.save_pretrained(dest, safe_serialization=True)
+        dest = self._model_probe_to_path(info) / repo_name
+        self._save_pretrained(pipeline, dest)
+
         rel_path = Path('models',dest.relative_to(dest_dir))
-        self.write_yaml(model_name, path=rel_path, info=info)
+        self.write_yaml(model_name, path=rel_path, info=info, **extra_config)
 
     def migrate_path(self, location: Path, model_name: str=None, **extra_config):
         '''
@@ -332,16 +418,29 @@ class MigrateTo3(object):
         for model_name, stanza in conf.items():
             try:
+                passthru_args = {}
+
+                if vae := stanza.get('vae'):
+                    try:
+                        passthru_args['vae'] = str(self._vae_path(vae))
+                    except Exception as e:
+                        logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
+                        logger.warning(str(e))
+
+                if config := stanza.get('config'):
+                    passthru_args['config'] = config
+
                 if repo_id := stanza.get('repo_id'):
                     logger.info(f'Migrating diffusers model {model_name}')
-                    self.migrate_repo_id(repo_id, model_name)
+                    self.migrate_repo_id(repo_id, model_name, **passthru_args)
 
                 elif location := stanza.get('weights'):
                     logger.info(f'Migrating checkpoint model {model_name}')
-                    self.migrate_path(Path(location), model_name, config=stanza.get('config'))
+                    self.migrate_path(Path(location), model_name, **passthru_args)
 
                 elif location := stanza.get('path'):
                     logger.info(f'Migrating diffusers model {model_name}')
-                    self.migrate_path(Path(location), model_name, config=stanza.get('config'))
+                    self.migrate_path(Path(location), model_name, **passthru_args)
 
             except KeyboardInterrupt:
                 raise
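Note (not part of the commit): given a 2.3 stanza that carries `repo_id`, `config`, and `vae` keys, the loop above collects the resolved VAE path and legacy config into `passthru_args` and forwards them unchanged to `migrate_repo_id()` / `migrate_path()`, which hand them on to `write_yaml()`. A sketch with illustrative stanza values; the resolved VAE path shown is the special-cased `sd-vae-ft-mse` location:

# Sketch of the passthru construction; stanza values are illustrative only.
from omegaconf import OmegaConf

stanza = OmegaConf.create('''
repo_id: runwayml/stable-diffusion-v1-5
config: configs/stable-diffusion/v1-inference.yaml
vae:
  repo_id: stabilityai/sd-vae-ft-mse
''')

passthru_args = {}
if vae := stanza.get('vae'):
    # the real code resolves this through self._vae_path(vae)
    passthru_args['vae'] = 'models/core/convert/sd-vae-ft-mse'
if config := stanza.get('config'):
    passthru_args['config'] = config

print(passthru_args)
# {'vae': 'models/core/convert/sd-vae-ft-mse', 'config': 'configs/stable-diffusion/v1-inference.yaml'}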
@@ -424,6 +523,7 @@ def do_migrate(src_directory: Path, dest_directory: Path):
                            )
     migrator.migrate()
 
+    shutil.rmtree(dest_directory / 'models.orig', ignore_errors=True)
     (dest_directory / 'models').replace(dest_directory / 'models.orig')
     dest_models.replace(dest_directory / 'models')
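Note (not part of the commit): the new `shutil.rmtree(..., ignore_errors=True)` call makes the final swap safe to re-run; a `models.orig` backup left over from an earlier migration would otherwise make the following `Path.replace()` rename fail. A sketch of the rotation, reusing the local names from `do_migrate()` (the wrapper function name is illustrative):

# Sketch of the directory rotation at the end of do_migrate(); the function
# name is illustrative, the three statements mirror the diff above.
import shutil
from pathlib import Path

def rotate_models_dirs(dest_directory: Path, dest_models: Path) -> None:
    # remove any backup left by a previous run so the rename below cannot collide
    shutil.rmtree(dest_directory / 'models.orig', ignore_errors=True)
    # keep the current models directory as a backup...
    (dest_directory / 'models').replace(dest_directory / 'models.orig')
    # ...then promote the freshly migrated tree into place
    dest_models.replace(dest_directory / 'models')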
@@ -456,6 +556,7 @@ script, which will perform a full upgrade in place."""
                         required=True,
                         help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")'
                         )
+# TO DO: Implement full directory scanning
 #    parser.add_argument('--all-models',
 #                        action="store_true",
 #                        help='Migrate all models found in `models` directory, not just those mentioned in models.yaml',
@@ -476,3 +577,5 @@ script, which will perform a full upgrade in place."""
 if __name__ == '__main__':
     main()