rewrite of widget display - marshalling needs rewrite

This commit is contained in:
Lincoln Stein
2023-06-15 23:32:33 -04:00
parent 5c740452f6
commit ada7399753
7 changed files with 473 additions and 464 deletions

View File

@ -24,14 +24,32 @@ from transformers import (
)
import invokeai.backend.util.logging as logger
from invokeai.backend.model_management import ModelManager
from invokeai.backend.model_management.model_probe import (
ModelProbe, ModelType, BaseModelType
ModelProbe, ModelType, BaseModelType, SchedulerPredictionType, ModelVariantInfo
)
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
diffusers.logging.set_verbosity_error()
model_names = set()
def unique_name(name,info)->str:
done = False
key = ModelManager.create_key(name,info.base_type,info.model_type)
unique_name = key
counter = 1
while not done:
if unique_name in model_names:
unique_name = f'{key}-{counter:0>2d}'
counter += 1
else:
done = True
model_names.add(unique_name)
name,_,_ = ModelManager.parse_key(unique_name)
return name
def create_directory_structure(dest: Path):
for model_base in [BaseModelType.StableDiffusion1,BaseModelType.StableDiffusion2]:
for model_type in [ModelType.Pipeline, ModelType.Vae, ModelType.Lora,
@ -113,10 +131,10 @@ def migrate_conversion_models(dest_directory: Path):
# sd-1
repo_id = 'openai/clip-vit-large-patch14'
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14', safe_serialization=True)
pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14' / 'tokenizer', safe_serialization=True)
pipeline = CLIPTextModel.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14', safe_serialization=True)
pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14' / 'text_encoder', safe_serialization=True)
# sd-2
repo_id = "stabilityai/stable-diffusion-2"
@ -153,12 +171,48 @@ def migrate_tuning_models(dest: Path):
logger.info(f'Scanning {subdir}')
migrate_models(src, dest)
def write_yaml(model_name: str, path:Path, info:ModelVariantInfo, dest_yaml: io.TextIOBase):
name = unique_name(model_name, info)
stanza = {
f'{info.base_type.value}/{info.model_type.value}/{name}': {
'name': model_name,
'path': str(path),
'description': f'diffusers model {model_name}',
'format': 'diffusers',
'image_size': info.image_size,
'base': info.base_type.value,
'variant': info.variant_type.value,
'prediction_type': info.prediction_type.value,
'upcast_attention': info.prediction_type == SchedulerPredictionType.VPrediction
}
}
dest_yaml.write(yaml.dump(stanza))
dest_yaml.flush()
def migrate_converted(dest_dir: Path, dest_yaml: io.TextIOBase):
for sub_dir in [Path('./models/converted_ckpts'),Path('./models/optimize-ckpts')]:
for model in sub_dir.glob('*'):
if not model.is_dir():
continue
info = ModelProbe().heuristic_probe(model)
if not info:
continue
dest = Path(dest_dir, info.base_type.value, info.model_type.value, model.name)
try:
copy_dir(model,dest)
rel_path = Path('models',dest.relative_to(dest_dir))
write_yaml(model.name,path=rel_path,info=info, dest_yaml=dest_yaml)
except KeyboardInterrupt:
raise
except Exception as e:
logger.warning(f'Could not migrate the converted diffusers {model.name}: {str(e)}. Skipping.')
def migrate_pipelines(dest_dir: Path, dest_yaml: io.TextIOBase):
cache = Path('./models/hub')
kwargs = dict(
cache_dir = cache,
local_files_only = True,
safety_checker = None,
# local_files_only = True,
)
for model in cache.glob('models--*'):
if len(list(model.glob('snapshots/**/model_index.json')))==0:
@ -166,38 +220,26 @@ def migrate_pipelines(dest_dir: Path, dest_yaml: io.TextIOBase):
_,owner,repo_name=model.name.split('--')
repo_id = f'{owner}/{repo_name}'
revisions = [x.name for x in model.glob('refs/*')]
for revision in revisions:
logger.info(f'Migrating {repo_id}, revision {revision}')
try:
pipeline = StableDiffusionPipeline.from_pretrained(
repo_id,
revision=revision,
**kwargs)
info = ModelProbe().heuristic_probe(pipeline)
if not info:
continue
dest = Path(dest_dir, info.base_type.value, info.model_type.value, f'{repo_name}-{revision}')
pipeline.save_pretrained(dest, safe_serialization=True)
rel_path = Path('models',dest.relative_to(dest_dir))
stanza = {
f'{info.base_type.value}/{info.model_type.value}/{repo_name}-{revision}':
{
'name': repo_name,
'path': str(rel_path),
'description': f'diffusers model {repo_id}',
'format': 'diffusers',
'image_size': info.image_size,
'base': info.base_type.value,
'variant': info.variant_type.value,
'prediction_type': info.prediction_type.value,
}
}
print(yaml.dump(stanza),file=dest_yaml,end="")
dest_yaml.flush()
except KeyboardInterrupt:
raise
except Exception as e:
logger.warning(f'Could not load the "{revision}" version of {repo_id}. Skipping.')
# if an fp16 is available we use that
revision = 'fp16' if len(revisions) > 1 and 'fp16' in revisions else revisions[0]
logger.info(f'Migrating {repo_id}, revision {revision}')
try:
pipeline = StableDiffusionPipeline.from_pretrained(
repo_id,
revision=revision,
**kwargs)
info = ModelProbe().heuristic_probe(pipeline)
if not info:
continue
dest = Path(dest_dir, info.base_type.value, info.model_type.value, f'{repo_name}')
pipeline.save_pretrained(dest, safe_serialization=True)
rel_path = Path('models',dest.relative_to(dest_dir))
write_yaml(repo_name, path=rel_path, info=info, dest_yaml=dest_yaml)
except KeyboardInterrupt:
raise
except Exception as e:
logger.warning(f'Could not load the "{revision}" version of {repo_id}. Skipping.')
def migrate_checkpoints(dest_dir: Path, dest_yaml: io.TextIOBase):
# find any checkpoints referred to in old models.yaml
@ -218,6 +260,7 @@ def migrate_checkpoints(dest_dir: Path, dest_yaml: io.TextIOBase):
dest = Path(dest_dir, info.base_type.value, info.model_type.value,weights.name)
copy_file(weights,dest)
weights = Path('models', info.base_type.value, info.model_type.value,weights.name)
model_name = unique_name(model_name, info)
stanza = {
f'{info.base_type.value}/{info.model_type.value}/{model_name}':
{
@ -261,15 +304,16 @@ def main():
os.chdir(root_directory)
with open(dest_yaml,'w') as yaml_file:
print(yaml.dump({'__metadata__':
{'version':'3.0.0'}
}
),file=yaml_file,end=""
)
yaml_file.write(yaml.dump({'__metadata__':
{'version':'3.0.0'}
}
)
)
create_directory_structure(dest_directory)
migrate_support_models(dest_directory)
migrate_conversion_models(dest_directory)
migrate_tuning_models(dest_directory)
migrate_converted(dest_directory,yaml_file)
migrate_pipelines(dest_directory,yaml_file)
migrate_checkpoints(dest_directory,yaml_file)