Merge branch 'main' into feat/node-cli-autocompleter

This commit is contained in:
Lincoln Stein 2023-03-30 07:51:51 -04:00 committed by GitHub
commit afb66a7884
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 106 additions and 70 deletions

View File

@ -1264,10 +1264,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
cache_dir=cache_dir, cache_dir=cache_dir,
) )
pipe = pipeline_class( pipe = pipeline_class(
vae=vae, vae=vae.to(precision),
text_encoder=text_model, text_encoder=text_model.to(precision),
tokenizer=tokenizer, tokenizer=tokenizer,
unet=unet, unet=unet.to(precision),
scheduler=scheduler, scheduler=scheduler,
safety_checker=None, safety_checker=None,
feature_extractor=None, feature_extractor=None,

View File

@ -18,7 +18,7 @@ import warnings
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from shutil import move, rmtree from shutil import move, rmtree
from typing import Any, Optional, Union from typing import Any, Optional, Union, Callable
import safetensors import safetensors
import safetensors.torch import safetensors.torch
@ -630,14 +630,13 @@ class ModelManager(object):
def heuristic_import( def heuristic_import(
self, self,
path_url_or_repo: str, path_url_or_repo: str,
convert: bool = True,
model_name: str = None, model_name: str = None,
description: str = None, description: str = None,
model_config_file: Path = None, model_config_file: Path = None,
commit_to_conf: Path = None, commit_to_conf: Path = None,
config_file_callback: Callable[[Path], Path] = None,
) -> str: ) -> str:
""" """Accept a string which could be:
Accept a string which could be:
- a HF diffusers repo_id - a HF diffusers repo_id
- a URL pointing to a legacy .ckpt or .safetensors file - a URL pointing to a legacy .ckpt or .safetensors file
- a local path pointing to a legacy .ckpt or .safetensors file - a local path pointing to a legacy .ckpt or .safetensors file
@ -651,16 +650,20 @@ class ModelManager(object):
The model_name and/or description can be provided. If not, they will The model_name and/or description can be provided. If not, they will
be generated automatically. be generated automatically.
If convert is true, legacy models will be converted to diffusers
before importing.
If commit_to_conf is provided, the newly loaded model will be written If commit_to_conf is provided, the newly loaded model will be written
to the `models.yaml` file at the indicated path. Otherwise, the changes to the `models.yaml` file at the indicated path. Otherwise, the changes
will only remain in memory. will only remain in memory.
The (potentially derived) name of the model is returned on success, or None The routine will do its best to figure out the config file
on failure. When multiple models are added from a directory, only the last needed to convert legacy checkpoint file, but if it can't it
imported one is returned. will call the config_file_callback routine, if provided. The
callback accepts a single argument, the Path to the checkpoint
file, and returns a Path to the config file to use.
The (potentially derived) name of the model is returned on
success, or None on failure. When multiple models are added
from a directory, only the last imported one is returned.
""" """
model_path: Path = None model_path: Path = None
thing = path_url_or_repo # to save typing thing = path_url_or_repo # to save typing
@ -707,7 +710,7 @@ class ModelManager(object):
Path(thing).rglob("*.safetensors") Path(thing).rglob("*.safetensors")
): ):
if model_name := self.heuristic_import( if model_name := self.heuristic_import(
str(m), convert, commit_to_conf=commit_to_conf str(m), commit_to_conf=commit_to_conf
): ):
print(f" >> {model_name} successfully imported") print(f" >> {model_name} successfully imported")
return model_name return model_name
@ -735,7 +738,7 @@ class ModelManager(object):
# another round of heuristics to guess the correct config file. # another round of heuristics to guess the correct config file.
checkpoint = None checkpoint = None
if model_path.suffix.endswith((".ckpt",".pt")): if model_path.suffix in [".ckpt",".pt"]:
self.scan_model(model_path,model_path) self.scan_model(model_path,model_path)
checkpoint = torch.load(model_path) checkpoint = torch.load(model_path)
else: else:
@ -743,6 +746,12 @@ class ModelManager(object):
# additional probing needed if no config file provided # additional probing needed if no config file provided
if model_config_file is None: if model_config_file is None:
# look for a like-named .yaml file in same directory
if model_path.with_suffix(".yaml").exists():
model_config_file = model_path.with_suffix(".yaml")
print(f" | Using config file {model_config_file.name}")
else:
model_type = self.probe_model_type(checkpoint) model_type = self.probe_model_type(checkpoint)
if model_type == SDLegacyType.V1: if model_type == SDLegacyType.V1:
print(" | SD-v1 model detected") print(" | SD-v1 model detected")
@ -756,20 +765,18 @@ class ModelManager(object):
) )
elif model_type == SDLegacyType.V2_v: elif model_type == SDLegacyType.V2_v:
print( print(
" | SD-v2-v model detected; model will be converted to diffusers format" " | SD-v2-v model detected"
) )
model_config_file = Path( model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml" Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
) )
convert = True
elif model_type == SDLegacyType.V2_e: elif model_type == SDLegacyType.V2_e:
print( print(
" | SD-v2-e model detected; model will be converted to diffusers format" " | SD-v2-e model detected"
) )
model_config_file = Path( model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference.yaml" Globals.root, "configs/stable-diffusion/v2-inference.yaml"
) )
convert = True
elif model_type == SDLegacyType.V2: elif model_type == SDLegacyType.V2:
print( print(
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path." f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
@ -781,13 +788,29 @@ class ModelManager(object):
) )
return return
if not model_config_file and config_file_callback:
model_config_file = config_file_callback(model_path)
# despite our best efforts, we could not find a model config file, so give up
if not model_config_file:
return
# look for a custom vae, a like-named file ending with .vae in the same directory
vae_path = None
for suffix in ["pt", "ckpt", "safetensors"]:
if (model_path.with_suffix(f".vae.{suffix}")).exists():
vae_path = model_path.with_suffix(f".vae.{suffix}")
print(f" | Using VAE file {vae_path.name}")
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
diffuser_path = Path( diffuser_path = Path(
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
) )
model_name = self.convert_and_import( model_name = self.convert_and_import(
model_path, model_path,
diffusers_path=diffuser_path, diffusers_path=diffuser_path,
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"), vae=vae,
vae_path=str(vae_path),
model_name=model_name, model_name=model_name,
model_description=description, model_description=description,
original_config_file=model_config_file, original_config_file=model_config_file,
@ -829,8 +852,8 @@ class ModelManager(object):
return return
model_name = model_name or diffusers_path.name model_name = model_name or diffusers_path.name
model_description = model_description or f"Optimized version of {model_name}" model_description = model_description or f"Converted version of {model_name}"
print(f">> Optimizing {model_name} (30-60s)") print(f" | Converting {model_name} to diffusers (30-60s)")
try: try:
# By passing the specified VAE to the conversion function, the autoencoder # By passing the specified VAE to the conversion function, the autoencoder
# will be built into the model rather than tacked on afterward via the config file # will be built into the model rather than tacked on afterward via the config file
@ -848,7 +871,7 @@ class ModelManager(object):
scan_needed=scan_needed, scan_needed=scan_needed,
) )
print( print(
f" | Success. Optimized model is now located at {str(diffusers_path)}" f" | Success. Converted model is now located at {str(diffusers_path)}"
) )
print(f" | Writing new config file entry for {model_name}") print(f" | Writing new config file entry for {model_name}")
new_config = dict( new_config = dict(

View File

@ -626,7 +626,7 @@ def set_default_output_dir(opt: Args, completer: Completer):
completer.set_default_dir(opt.outdir) completer.set_default_dir(opt.outdir)
def import_model(model_path: str, gen, opt, completer, convert=False): def import_model(model_path: str, gen, opt, completer):
""" """
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
(3) a huggingface repository id; or (4) a local directory containing a (3) a huggingface repository id; or (4) a local directory containing a
@ -657,7 +657,6 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
model_path, model_path,
model_name=model_name, model_name=model_name,
description=model_desc, description=model_desc,
convert=convert,
) )
if not imported_name: if not imported_name:
@ -666,7 +665,6 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
model_path, model_path,
model_name=model_name, model_name=model_name,
description=model_desc, description=model_desc,
convert=convert,
model_config_file=config_file, model_config_file=config_file,
) )
if not imported_name: if not imported_name:
@ -757,7 +755,6 @@ def _get_model_name_and_desc(
) )
return model_name, model_description return model_name, model_description
def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer): def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
model_name_or_path = model_name_or_path.replace("\\", "/") # windows model_name_or_path = model_name_or_path.replace("\\", "/") # windows
manager = gen.model_manager manager = gen.model_manager
@ -788,7 +785,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
) )
else: else:
try: try:
import_model(model_name_or_path, gen, opt, completer, convert=True) import_model(model_name_or_path, gen, opt, completer)
except KeyboardInterrupt: except KeyboardInterrupt:
return return

View File

@ -1,9 +1,24 @@
import i18n from 'i18next'; import i18n from 'i18next';
import LanguageDetector from 'i18next-browser-languagedetector'; import LanguageDetector from 'i18next-browser-languagedetector';
import Backend from 'i18next-http-backend'; import Backend from 'i18next-http-backend';
import { initReactI18next } from 'react-i18next'; import { initReactI18next } from 'react-i18next';
i18n
import translationEN from '../dist/locales/en.json';
if (import.meta.env.MODE === 'package') {
i18n.use(initReactI18next).init({
lng: 'en',
resources: {
en: { translation: translationEN },
},
debug: false,
interpolation: {
escapeValue: false,
},
returnNull: false,
});
} else {
i18n
.use(Backend) .use(Backend)
.use(LanguageDetector) .use(LanguageDetector)
.use(initReactI18next) .use(initReactI18next)
@ -18,5 +33,6 @@ i18n
}, },
returnNull: false, returnNull: false,
}); });
}
export default i18n; export default i18n;