mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
chore: ruff check - fix flake8-comprensions
This commit is contained in:
@ -197,7 +197,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
||||
|
||||
def download_conversion_models():
|
||||
target_dir = config.models_path / "core/convert"
|
||||
kwargs = dict() # for future use
|
||||
kwargs = {} # for future use
|
||||
try:
|
||||
logger.info("Downloading core tokenizers and text encoders")
|
||||
|
||||
@ -252,26 +252,26 @@ def download_conversion_models():
|
||||
def download_realesrgan():
|
||||
logger.info("Installing ESRGAN Upscaling models...")
|
||||
URLs = [
|
||||
dict(
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||
description="RealESRGAN_x4plus.pth",
|
||||
),
|
||||
dict(
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||
description="RealESRGAN_x4plus_anime_6B.pth",
|
||||
),
|
||||
dict(
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
description="ESRGAN_SRx4_DF2KOST_official.pth",
|
||||
),
|
||||
dict(
|
||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||
description="RealESRGAN_x2plus.pth",
|
||||
),
|
||||
{
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||
"description": "RealESRGAN_x4plus.pth",
|
||||
},
|
||||
{
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||
"description": "RealESRGAN_x4plus_anime_6B.pth",
|
||||
},
|
||||
{
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
"dest": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
"description": "ESRGAN_SRx4_DF2KOST_official.pth",
|
||||
},
|
||||
{
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
"dest": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||
"description": "RealESRGAN_x2plus.pth",
|
||||
},
|
||||
]
|
||||
for model in URLs:
|
||||
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
|
||||
@ -680,7 +680,7 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
||||
if program_opts.default_only
|
||||
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
|
||||
if program_opts.yes_to_all
|
||||
else list(),
|
||||
else [],
|
||||
)
|
||||
|
||||
|
||||
|
@ -182,10 +182,10 @@ class MigrateTo3(object):
|
||||
"""
|
||||
|
||||
dest_directory = self.dest_models
|
||||
kwargs = dict(
|
||||
cache_dir=self.root_directory / "models/hub",
|
||||
kwargs = {
|
||||
"cache_dir": self.root_directory / "models/hub",
|
||||
# local_files_only = True
|
||||
)
|
||||
}
|
||||
try:
|
||||
logger.info("Migrating core tokenizers and text encoders")
|
||||
target_dir = dest_directory / "core" / "convert"
|
||||
@ -316,11 +316,11 @@ class MigrateTo3(object):
|
||||
dest_dir = self.dest_models
|
||||
|
||||
cache = self.root_directory / "models/hub"
|
||||
kwargs = dict(
|
||||
cache_dir=cache,
|
||||
safety_checker=None,
|
||||
kwargs = {
|
||||
"cache_dir": cache,
|
||||
"safety_checker": None,
|
||||
# local_files_only = True,
|
||||
)
|
||||
}
|
||||
|
||||
owner, repo_name = repo_id.split("/")
|
||||
model_name = model_name or repo_name
|
||||
|
@ -120,7 +120,7 @@ class ModelInstall(object):
|
||||
be treated uniformly. It also sorts the models alphabetically
|
||||
by their name, to improve the display somewhat.
|
||||
"""
|
||||
model_dict = dict()
|
||||
model_dict = {}
|
||||
|
||||
# first populate with the entries in INITIAL_MODELS.yaml
|
||||
for key, value in self.datasets.items():
|
||||
@ -134,7 +134,7 @@ class ModelInstall(object):
|
||||
model_dict[key] = model_info
|
||||
|
||||
# supplement with entries in models.yaml
|
||||
installed_models = [x for x in self.mgr.list_models()]
|
||||
installed_models = list(self.mgr.list_models())
|
||||
|
||||
for md in installed_models:
|
||||
base = md["base_model"]
|
||||
@ -184,7 +184,7 @@ class ModelInstall(object):
|
||||
|
||||
def recommended_models(self) -> Set[str]:
|
||||
starters = self.starter_models(all_models=True)
|
||||
return set([x for x in starters if self.datasets[x].get("recommended", False)])
|
||||
return {x for x in starters if self.datasets[x].get("recommended", False)}
|
||||
|
||||
def default_model(self) -> str:
|
||||
starters = self.starter_models()
|
||||
@ -234,7 +234,7 @@ class ModelInstall(object):
|
||||
"""
|
||||
|
||||
if not models_installed:
|
||||
models_installed = dict()
|
||||
models_installed = {}
|
||||
|
||||
model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
|
||||
|
||||
@ -252,8 +252,7 @@ class ModelInstall(object):
|
||||
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any(
|
||||
[
|
||||
(path / x).exists()
|
||||
(path / x).exists()
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
@ -261,7 +260,6 @@ class ModelInstall(object):
|
||||
"pytorch_lora_weights.bin",
|
||||
"pytorch_lora_weights.safetensors",
|
||||
}
|
||||
]
|
||||
):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
|
||||
|
||||
@ -433,17 +431,17 @@ class ModelInstall(object):
|
||||
|
||||
rel_path = self.relative_to_root(path, self.config.models_path)
|
||||
|
||||
attributes = dict(
|
||||
path=str(rel_path),
|
||||
description=str(description),
|
||||
model_format=info.format,
|
||||
)
|
||||
attributes = {
|
||||
"path": str(rel_path),
|
||||
"description": str(description),
|
||||
"model_format": info.format,
|
||||
}
|
||||
legacy_conf = None
|
||||
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
|
||||
attributes.update(
|
||||
dict(
|
||||
variant=info.variant_type,
|
||||
)
|
||||
{
|
||||
"variant": info.variant_type,
|
||||
}
|
||||
)
|
||||
if info.format == "checkpoint":
|
||||
try:
|
||||
@ -474,7 +472,7 @@ class ModelInstall(object):
|
||||
)
|
||||
|
||||
if legacy_conf:
|
||||
attributes.update(dict(config=str(legacy_conf)))
|
||||
attributes.update({"config": str(legacy_conf)})
|
||||
return attributes
|
||||
|
||||
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
|
||||
@ -519,7 +517,7 @@ class ModelInstall(object):
|
||||
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
|
||||
_, name = repo_id.split("/")
|
||||
location = staging / name
|
||||
paths = list()
|
||||
paths = []
|
||||
for filename in files:
|
||||
filePath = Path(filename)
|
||||
p = hf_download_with_resume(
|
||||
|
Reference in New Issue
Block a user