Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
commit 3043af4620 (parent: afd19ab61a)

    implement vae passthru
@@ -15,7 +15,9 @@ import warnings
 
 from dataclasses import dataclass
 from pathlib import Path
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, DictConfig
+from typing import Union
+
 from diffusers import StableDiffusionPipeline, AutoencoderKL
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from transformers import (
@@ -104,6 +106,9 @@ class MigrateTo3(object):
         '''
         copy a single file with logging
         '''
+        if dest.exists():
+            logger.info(f'Skipping existing {str(dest)}')
+            return
         logger.info(f'Copying {str(src)} to {str(dest)}')
         try:
             shutil.copy(src, dest)
@@ -115,6 +120,10 @@ class MigrateTo3(object):
         '''
         Recursively copy a directory with logging
         '''
+        if dest.exists():
+            logger.info(f'Skipping existing {str(dest)}')
+            return
+
         logger.info(f'Copying {str(src)} to {str(dest)}')
         try:
             shutil.copytree(src, dest)
@@ -127,7 +136,6 @@ class MigrateTo3(object):
         that looks like a model, and copy the model into the
         appropriate location within the destination models directory.
         '''
-        dest_dir = self.dest_models
         for root, dirs, files in os.walk(src_dir):
             for f in files:
                 # hack - don't copy raw learned_embeds.bin, let them
@@ -139,7 +147,7 @@ class MigrateTo3(object):
                     info = ModelProbe().heuristic_probe(model)
                     if not info:
                         continue
-                    dest = Path(dest_dir, info.base_type.value, info.model_type.value, f)
+                    dest = self._model_probe_to_path(info) / f
                     self.copy_file(model, dest)
                 except KeyboardInterrupt:
                     raise
@@ -151,14 +159,13 @@ class MigrateTo3(object):
                     info = ModelProbe().heuristic_probe(model)
                     if not info:
                         continue
-                    dest = Path(dest_dir, info.base_type.value, info.model_type.value, model.name)
+                    dest = self._model_probe_to_path(info) / model.name
                     self.copy_dir(model, dest)
                 except KeyboardInterrupt:
                     raise
                 except Exception as e:
                     logger.error(str(e))
 
-    # TO DO: Rewrite this to support alternate locations for esrgan and gfpgan in init file
     def migrate_support_models(self):
         '''
         Copy the clipseg, upscaler, and restoration models to their new
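
Example (illustrative, not part of the commit): the hunks above swap the repeated Path(dest_dir, info.base_type.value, info.model_type.value, ...) construction for the new _model_probe_to_path() helper added later in this diff. A minimal sketch of the destination layout it computes, with assumed example values standing in for a real ModelProbe result:

from pathlib import Path

# Assumed example values; the real ones come from ModelProbe().heuristic_probe()
dest_models = Path('/opt/invokeai3/models')   # hypothetical destination models root
base_type, model_type = 'sd-1', 'main'        # hypothetical probe results

dest = Path(dest_models, base_type, model_type) / 'some-model.ckpt'
print(dest)   # /opt/invokeai3/models/sd-1/main/some-model.ckpt
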
@@ -203,39 +210,53 @@ class MigrateTo3(object):
             logger.info('Migrating core tokenizers and text encoders')
             target_dir = dest_directory / 'core' / 'convert'
 
-            # bert
-            bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
-            bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
+            self._migrate_pretrained(BertTokenizerFast,
+                                     repo_id='bert-base-uncased',
+                                     dest = target_dir / 'bert-base-uncased',
+                                     **kwargs)
 
             # sd-1
             repo_id = 'openai/clip-vit-large-patch14'
-            pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
-            pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14' / 'tokenizer', safe_serialization=True)
-            pipeline = CLIPTextModel.from_pretrained(repo_id, **kwargs)
-            pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14' / 'text_encoder', safe_serialization=True)
+            self._migrate_pretrained(CLIPTokenizer,
+                                     repo_id= repo_id,
+                                     dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer',
+                                     **kwargs)
+            self._migrate_pretrained(CLIPTextModel,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder',
+                                     **kwargs)
 
             # sd-2
             repo_id = "stabilityai/stable-diffusion-2"
-            pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
-            pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
-            pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
-            pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
+            self._migrate_pretrained(CLIPTokenizer,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'stable-diffusion-2-clip' / 'tokenizer',
+                                     **{'subfolder':'tokenizer',**kwargs}
+                                     )
+            self._migrate_pretrained(CLIPTextModel,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'stable-diffusion-2-clip' / 'text_encoder',
+                                     **{'subfolder':'text_encoder',**kwargs}
+                                     )
 
             # VAE
             logger.info('Migrating stable diffusion VAE')
-            vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
-            vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
+            self._migrate_pretrained(AutoencoderKL,
+                                     repo_id = 'stabilityai/sd-vae-ft-mse',
+                                     dest = target_dir / 'sd-vae-ft-mse',
+                                     **kwargs)
 
             # safety checking
             logger.info('Migrating safety checker')
             repo_id = "CompVis/stable-diffusion-safety-checker"
-            pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
-            pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
-            pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
-            pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
+            self._migrate_pretrained(AutoFeatureExtractor,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'stable-diffusion-safety-checker',
+                                     **kwargs)
+            self._migrate_pretrained(StableDiffusionSafetyChecker,
+                                     repo_id = repo_id,
+                                     dest = target_dir / 'stable-diffusion-safety-checker',
+                                     **kwargs)
         except KeyboardInterrupt:
             raise
         except Exception as e:
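
Example (illustrative, not part of the commit): each from_pretrained()/save_pretrained() pair above collapses into one call to the new _migrate_pretrained() helper defined in the next hunk. A minimal standalone sketch of that pattern with an assumed cache directory and target path; it omits the atomic '.downloading' rename that _save_pretrained() performs:

from pathlib import Path
from transformers import BertTokenizerFast

def migrate_pretrained(model_class, repo_id: str, dest: Path, **kwargs):
    # download once, save as safetensors, and skip work that is already done
    if dest.exists():
        print(f'Skipping existing {dest}')
        return
    model = model_class.from_pretrained(repo_id, **kwargs)
    model.save_pretrained(dest, safe_serialization=True)

# Hypothetical paths; the migrator derives these from its root directories.
target_dir = Path('/opt/invokeai3/models/core/convert')
migrate_pretrained(BertTokenizerFast,
                   repo_id='bert-base-uncased',
                   dest=target_dir / 'bert-base-uncased',
                   cache_dir='/opt/invokeai2/models/hub')
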
@@ -262,8 +283,72 @@ class MigrateTo3(object):
         }
         self.dest_yaml.write(yaml.dump(stanza))
         self.dest_yaml.flush()
 
-    def migrate_repo_id(self, repo_id: str, model_name :str=None):
+    def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
+        return Path(self.dest_models, info.base_type.value, info.model_type.value)
+
+    def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs):
+        if dest.exists():
+            logger.info(f'Skipping existing {dest}')
+            return
+        model = model_class.from_pretrained(repo_id, **kwargs)
+        self._save_pretrained(model, dest)
+
+    def _save_pretrained(self, model, dest: Path):
+        if dest.exists():
+            logger.info(f'Skipping existing {dest}')
+            return
+        model_name = dest.name
+        download_path = dest.with_name(f'{model_name}.downloading')
+        model.save_pretrained(download_path, safe_serialization=True)
+        download_path.replace(dest)
+
+    def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
+        vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
+        info = ModelProbe().heuristic_probe(vae)
+        _, model_name = repo_id.split('/')
+        dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
+        vae.save_pretrained(dest, safe_serialization=True)
+        return dest
+
+    def _vae_path(self, vae: Union[str,dict])->Path:
+        '''
+        Convert 2.3 VAE stanza to a straight path.
+        '''
+        vae_path = None
+
+        # First get a path
+        if isinstance(vae,str):
+            vae_path = vae
+
+        elif isinstance(vae,DictConfig):
+            if p := vae.get('path'):
+                vae_path = p
+            elif repo_id := vae.get('repo_id'):
+                if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded
+                    vae_path = 'models/core/convert/se-vae-ft-mse'
+                else:
+                    vae_path = self._download_vae(repo_id, vae.get('subfolder'))
+
+        assert vae_path is not None, "Couldn't find VAE for this model"
+
+        # if the VAE is in the old models directory, then we must move it into the new
+        # one. VAEs outside of this directory can stay where they are.
+        vae_path = Path(vae_path)
+        if vae_path.is_relative_to(self.src_paths.models):
+            info = ModelProbe().heuristic_probe(vae_path)
+            dest = self._model_probe_to_path(info) / vae_path.name
+            if not dest.exists():
+                self.copy_dir(vae_path,dest)
+            vae_path = dest
+
+        if vae_path.is_relative_to(self.dest_models):
+            rel_path = vae_path.relative_to(self.dest_models)
+            return Path('models',rel_path)
+        else:
+            return vae_path
+
+    def migrate_repo_id(self, repo_id: str, model_name :str=None, **extra_config):
         '''
         Migrate a locally-cached diffusers pipeline identified with a repo_id
         '''
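
Example (illustrative, not part of the commit): _vae_path() is the heart of the VAE passthru. It accepts whatever a 2.3 stanza's vae: field held (a bare path string, a {path: ...} block, or a {repo_id: ...} block) and normalizes it to a path usable by the 3.0 models.yaml. A condensed, dependency-free rendering of that decision tree, using plain dicts where the real code receives an OmegaConf DictConfig and skipping the copy/download steps:

from pathlib import Path
from typing import Union

dest_models = Path('/opt/invokeai3/models')   # assumed destination models root

def vae_stanza_to_path(vae: Union[str, dict]) -> Path:
    if isinstance(vae, str):                        # 2.3 allowed a bare path string
        vae_path = Path(vae)
    elif p := vae.get('path'):                      # {'path': ...}
        vae_path = Path(p)
    elif repo_id := vae.get('repo_id'):             # {'repo_id': ...}
        if repo_id == 'stabilityai/sd-vae-ft-mse':  # already migrated as a core conversion model
            vae_path = Path('models/core/convert/sd-vae-ft-mse')
        else:
            raise NotImplementedError('the real code downloads it via _download_vae()')
    else:
        raise ValueError("Couldn't find VAE for this model")
    # VAEs that live under the new models root are rewritten relative to it so the
    # 3.0 stanza stays portable; anything else is returned untouched.
    if vae_path.is_relative_to(dest_models):
        return Path('models', vae_path.relative_to(dest_models))
    return vae_path

print(vae_stanza_to_path({'repo_id': 'stabilityai/sd-vae-ft-mse'}))  # models/core/convert/sd-vae-ft-mse
print(vae_stanza_to_path('/some/external/vae'))                      # /some/external/vae
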
@@ -295,10 +380,11 @@ class MigrateTo3(object):
         if not info:
             return
 
-        dest = Path(dest_dir, info.base_type.value, info.model_type.value, f'{repo_name}')
-        pipeline.save_pretrained(dest, safe_serialization=True)
+        dest = self._model_probe_to_path(info) / repo_name
+        self._save_pretrained(pipeline, dest)
 
         rel_path = Path('models',dest.relative_to(dest_dir))
-        self.write_yaml(model_name, path=rel_path, info=info)
+        self.write_yaml(model_name, path=rel_path, info=info, **extra_config)
 
     def migrate_path(self, location: Path, model_name: str=None, **extra_config):
         '''
@@ -332,16 +418,29 @@ class MigrateTo3(object):
         for model_name, stanza in conf.items():
 
             try:
+                passthru_args = {}
+
+                if vae := stanza.get('vae'):
+                    try:
+                        passthru_args['vae'] = str(self._vae_path(vae))
+                    except Exception as e:
+                        logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
+                        logger.warning(str(e))
+
+                if config := stanza.get('config'):
+                    passthru_args['config'] = config
+
                 if repo_id := stanza.get('repo_id'):
                     logger.info(f'Migrating diffusers model {model_name}')
-                    self.migrate_repo_id(repo_id, model_name)
+                    self.migrate_repo_id(repo_id, model_name, **passthru_args)
 
                 elif location := stanza.get('weights'):
                     logger.info(f'Migrating checkpoint model {model_name}')
-                    self.migrate_path(Path(location), model_name, config=stanza.get('config'))
+                    self.migrate_path(Path(location), model_name, **passthru_args)
 
                 elif location := stanza.get('path'):
                     logger.info(f'Migrating diffusers model {model_name}')
-                    self.migrate_path(Path(location), model_name, config=stanza.get('config'))
+                    self.migrate_path(Path(location), model_name, **passthru_args)
 
             except KeyboardInterrupt:
                 raise
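
Example (illustrative, not part of the commit): the migrate loop now builds a passthru_args dict from each 2.3 stanza and forwards it to migrate_repo_id()/migrate_path(), which hand it on to write_yaml() so the 3.0 entry keeps its custom VAE and config. A hypothetical 2.3 stanza and the dict that would be carried through (plain dicts stand in for the DictConfig the migrator really sees, and a fixed string stands in for the _vae_path() result):

# Hypothetical 2.3 models.yaml entry, shown as a plain dict
stanza = {
    'weights': '/opt/invokeai2/models/ldm/stable-diffusion-v1/my-model.ckpt',
    'config': 'configs/stable-diffusion/v1-inference.yaml',
    'vae': {'repo_id': 'stabilityai/sd-vae-ft-mse'},
}

passthru_args = {}
if vae := stanza.get('vae'):
    # the migrator resolves this via self._vae_path(vae); assumed result shown
    passthru_args['vae'] = 'models/core/convert/sd-vae-ft-mse'
if config := stanza.get('config'):
    passthru_args['config'] = config

print(passthru_args)
# {'vae': 'models/core/convert/sd-vae-ft-mse', 'config': 'configs/stable-diffusion/v1-inference.yaml'}
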
@@ -424,6 +523,7 @@ def do_migrate(src_directory: Path, dest_directory: Path):
     )
     migrator.migrate()
 
+    shutil.rmtree(dest_directory / 'models.orig', ignore_errors=True)
     (dest_directory / 'models').replace(dest_directory / 'models.orig')
     dest_models.replace(dest_directory / 'models')
 
@@ -456,6 +556,7 @@ script, which will perform a full upgrade in place."""
                         required=True,
                         help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")'
                         )
+    # TO DO: Implement full directory scanning
     # parser.add_argument('--all-models',
     #                     action="store_true",
     #                     help='Migrate all models found in `models` directory, not just those mentioned in models.yaml',
@@ -476,3 +577,5 @@ script, which will perform a full upgrade in place."""
 if __name__ == '__main__':
     main()
 
+
+