mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): move HF token helper to route
This commit is contained in:
parent
dea9142cb8
commit
9a5575b46b
@ -1,13 +1,16 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein
|
# Copyright (c) 2023 Lincoln D. Stein
|
||||||
"""FastAPI route for model configuration records."""
|
"""FastAPI route for model configuration records."""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import io
|
import io
|
||||||
import pathlib
|
import pathlib
|
||||||
import shutil
|
import shutil
|
||||||
import traceback
|
import traceback
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import huggingface_hub
|
||||||
from fastapi import Body, Path, Query, Response, UploadFile
|
from fastapi import Body, Path, Query, Response, UploadFile
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
@ -22,6 +25,7 @@ from invokeai.app.services.model_records import (
|
|||||||
UnknownModelException,
|
UnknownModelException,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
|
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
|
||||||
|
from invokeai.app.util.suppress_output import SuppressOutput
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@ -794,3 +798,51 @@ async def get_starter_models() -> list[StarterModel]:
|
|||||||
model.is_installed = True
|
model.is_installed = True
|
||||||
|
|
||||||
return starter_models
|
return starter_models
|
||||||
|
|
||||||
|
|
||||||
|
class HFTokenStatus(str, Enum):
|
||||||
|
VALID = "valid"
|
||||||
|
INVALID = "invalid"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class HFTokenHelper:
|
||||||
|
@classmethod
|
||||||
|
def get_status(cls) -> HFTokenStatus:
|
||||||
|
try:
|
||||||
|
if huggingface_hub.get_token_permission(huggingface_hub.get_token()):
|
||||||
|
# Valid token!
|
||||||
|
return HFTokenStatus.VALID
|
||||||
|
# No token set
|
||||||
|
return HFTokenStatus.INVALID
|
||||||
|
except Exception:
|
||||||
|
return HFTokenStatus.UNKNOWN
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_token(cls, token: str) -> HFTokenStatus:
|
||||||
|
with SuppressOutput(), contextlib.suppress(Exception):
|
||||||
|
huggingface_hub.login(token=token, add_to_git_credential=False)
|
||||||
|
return cls.get_status()
|
||||||
|
|
||||||
|
|
||||||
|
@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
|
||||||
|
async def get_hf_login_status() -> HFTokenStatus:
|
||||||
|
token_status = HFTokenHelper.get_status()
|
||||||
|
|
||||||
|
if token_status is HFTokenStatus.UNKNOWN:
|
||||||
|
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
|
||||||
|
|
||||||
|
return token_status
|
||||||
|
|
||||||
|
|
||||||
|
@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus)
|
||||||
|
async def do_hf_login(
|
||||||
|
token: str = Body(description="Hugging Face token to use for login", embed=True),
|
||||||
|
) -> HFTokenStatus:
|
||||||
|
HFTokenHelper.set_token(token)
|
||||||
|
token_status = HFTokenHelper.get_status()
|
||||||
|
|
||||||
|
if token_status is HFTokenStatus.UNKNOWN:
|
||||||
|
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
|
||||||
|
|
||||||
|
return token_status
|
||||||
|
@ -16,7 +16,6 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
|||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
import invokeai.configs as model_configs
|
import invokeai.configs as model_configs
|
||||||
from invokeai.app.util.hf_login import hf_login
|
|
||||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||||
|
|
||||||
@ -449,9 +448,6 @@ def get_config() -> InvokeAIAppConfig:
|
|||||||
]
|
]
|
||||||
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
|
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
|
||||||
|
|
||||||
# Log in to HF
|
|
||||||
hf_login()
|
|
||||||
|
|
||||||
# Copy all legacy configs - We know `__path__[0]` is correct here
|
# Copy all legacy configs - We know `__path__[0]` is correct here
|
||||||
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
||||||
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
|
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
|
||||||
|
@ -1,46 +0,0 @@
|
|||||||
import huggingface_hub
|
|
||||||
from pwinput import pwinput
|
|
||||||
|
|
||||||
from invokeai.app.util.suppress_output import SuppressOutput
|
|
||||||
|
|
||||||
|
|
||||||
def hf_login() -> None:
|
|
||||||
"""Prompts the user for their HuggingFace token. If a valid token is already saved, this function will do nothing.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if the login was successful, False if the user canceled.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If the user cancels the login prompt.
|
|
||||||
"""
|
|
||||||
|
|
||||||
current_token = huggingface_hub.get_token()
|
|
||||||
|
|
||||||
try:
|
|
||||||
if huggingface_hub.get_token_permission(current_token):
|
|
||||||
# We have a valid token already
|
|
||||||
return
|
|
||||||
except ConnectionError:
|
|
||||||
print("Unable to reach HF to verify token. Skipping...")
|
|
||||||
# No internet connection, so we can't check the token
|
|
||||||
pass
|
|
||||||
|
|
||||||
# InvokeAILogger depends on the config, and this class is used within the config, so we can't use the app logger here
|
|
||||||
print("Enter your HuggingFace token. This is required to convert checkpoint/safetensors models to diffusers.")
|
|
||||||
print("For more information, see https://huggingface.co/docs/hub/security-tokens#how-to-manage-user-access-tokens")
|
|
||||||
print("Press Ctrl+C to skip.")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
access_token = pwinput(prompt="HF token: ")
|
|
||||||
# The login function prints to stdout
|
|
||||||
with SuppressOutput():
|
|
||||||
huggingface_hub.login(token=access_token, add_to_git_credential=False)
|
|
||||||
print("Token verified.")
|
|
||||||
break
|
|
||||||
except ValueError:
|
|
||||||
print("Invalid token!")
|
|
||||||
continue
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\nToken verification canceled.")
|
|
||||||
break
|
|
Loading…
x
Reference in New Issue
Block a user