new TUI is fully functional; needs some polishing

This commit is contained in:
Lincoln Stein
2023-06-02 17:20:50 -04:00
parent 41f7758977
commit 1390b65a9c
5 changed files with 186 additions and 71 deletions

View File

@ -95,9 +95,12 @@ def install_requested_models(
model_manager.install_controlnet_models(controlnet.install_models, access_token=access_token)
model_manager.delete_controlnet_models(controlnet.remove_models)
model_manager.install_lora_models(lora.install_models)
model_manager.install_lora_models(lora.install_models, access_token=access_token)
model_manager.delete_lora_models(lora.remove_models)
model_manager.install_ti_models(ti.install_models, access_token=access_token)
model_manager.delete_ti_models(ti.remove_models)
# TODO: Replace next three paragraphs with calls into new model manager
if diffusers.remove_models and len(diffusers.remove_models) > 0:
logger.info("DELETING UNCHECKED STARTER MODELS")
@ -109,7 +112,7 @@ def install_requested_models(
if diffusers.install_models and len(diffusers.install_models) > 0:
logger.info("INSTALLING SELECTED STARTER MODELS")
successfully_downloaded = download_weight_datasets(
models=diffusers.install_initial_models,
models=diffusers.install_models,
access_token=None,
precision=precision,
) # FIX: for historical reasons, we don't use model manager here

View File

@ -1334,24 +1334,54 @@ class ModelManager(object):
installed_models = dict()
for root, dirs, files in os.walk(dir):
for name in files:
if Path(name).suffix in ['.safetensors','.ckpt','.pt']:
installed_models.update({name: True})
if Path(name).suffix not in ['.safetensors','.ckpt','.pt','.bin']:
continue
if name == 'pytorch_lora_weights.bin':
name = Path(root,name).parent.stem #Path(root,name).stem
else:
name = Path(name).stem
installed_models.update({name: True})
return installed_models
def install_lora_models(self, model_names: list[str]):
def install_lora_models(self, model_names: list[str], access_token:str=None):
'''Download list of LoRA/LyCORIS models'''
short_names = OmegaConf.load(Dataset_path).get('lora') or {}
for name in model_names:
url = short_names.get(name) or name
download_with_resume(url, self.globals.lora_path)
print(name)
name = short_names.get(name) or name
# HuggingFace style LoRA
if re.match(r"^[\w.+-]+/([\w.+-]+)$", name):
self.logger.info(f'Downloading LoRA/LyCORIS model {name}')
_,dest_dir = name.split("/")
hf_download_with_resume(
repo_id = name,
model_dir = self.globals.lora_path / dest_dir,
model_name = 'pytorch_lora_weights.bin',
access_token = access_token,
)
elif name.startswith(("http:", "https:", "ftp:")):
download_with_resume(name, self.globals.lora_path)
else:
self.logger.error(f"Unknown repo_id or URL: {name}")
def delete_lora_models(self, model_names: List[str]):
'''Remove the list of lora models'''
for name in model_names:
path = self.globals.lora_path / name
if path.exists():
self.logger.info(f'Purging lora model {name}')
path.unlink()
file_or_directory = self.globals.lora_path / name
if file_or_directory.is_dir():
self.logger.info(f'Purging LoRA/LyCORIS {name}')
shutil.rmtree(str(file_or_directory))
else:
for path in self.globals.lora_path.glob(f'{name}.*'):
self.logger.info(f'Purging LoRA/LyCORIS {name}')
path.unlink()
def list_ti_models(self)->Dict[str,bool]:
'''Return a dict of installed textual models; key is either the shortname
@ -1362,21 +1392,50 @@ class ModelManager(object):
installed_models = {x: False for x in models.keys()}
dir = self.globals.embedding_path
installed_models = dict()
for root, dirs, files in os.walk(dir):
for name in files:
if not Path(name).suffix in ['.bin','.pt','.ckpt','.safetensors']:
continue
if name == 'learned_embeds.bin':
name = str(Path(root,name).parent)
name = Path(root,name).parent.stem #Path(root,name).stem
else:
name = Path(name).stem
installed_models.update({name: True})
return installed_models
def install_ti_models(self, model_names: list[str]):
def install_ti_models(self, model_names: list[str], access_token: str=None):
'''Download list of textual inversion embeddings'''
short_names = OmegaConf.load(Dataset_path).get('textual_inversion') or {}
for name in model_names:
url = short_names.get(name) or name
download_with_resume(url, self.globals.embedding_path)
name = short_names.get(name) or name
if re.match(r"^[\w.+-]+/([\w.+-]+)$", name):
self.logger.info(f'Downloading Textual Inversion embedding {name}')
_,dest_dir = name.split("/")
hf_download_with_resume(
repo_id = name,
model_dir = self.globals.embedding_path / dest_dir,
model_name = 'learned_embeds.bin',
access_token = access_token
)
elif name.startswith(('http:','https:','ftp:')):
download_with_resume(name, self.globals.embedding_path)
else:
self.logger.error(f'{name} does not look like either a HuggingFace repo_id or a downloadable URL')
def delete_ti_models(self, model_names: list[str]):
'''Remove TI embeddings from disk'''
for name in model_names:
file_or_directory = self.globals.embedding_path / name
if file_or_directory.is_dir():
self.logger.info(f'Purging textual inversion embedding {name}')
shutil.rmtree(str(file_or_directory))
else:
for path in self.globals.embedding_path.glob(f'{name}.*'):
self.logger.info(f'Purging textual inversion embedding {name}')
path.unlink()
def list_controlnet_models(self)->Dict[str,bool]:
'''Return a dict of installed controlnet models; key is repo_id or short name
of model (defined in INITIAL_MODELS), and value is True if installed'''

View File

@ -322,8 +322,8 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
logger.warning("corrupt existing file found. re-downloading")
os.remove(dest)
exist_size = 0
if resp.status_code == 416 or exist_size == content_length:
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
logger.warning(f"{dest}: complete file found. Skipping.")
return dest
elif resp.status_code == 206 or exist_size > 0: