feat(mm): move HF token helper to route

This commit is contained in:
psychedelicious 2024-03-20 09:21:48 +11:00
parent dea9142cb8
commit 9a5575b46b
3 changed files with 52 additions and 50 deletions

View File

@ -1,13 +1,16 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
import contextlib
import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional
import huggingface_hub
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
@ -22,6 +25,7 @@ from invokeai.app.services.model_records import (
UnknownModelException,
)
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@ -794,3 +798,51 @@ async def get_starter_models() -> list[StarterModel]:
model.is_installed = True
return starter_models
class HFTokenStatus(str, Enum):
VALID = "valid"
INVALID = "invalid"
UNKNOWN = "unknown"
class HFTokenHelper:
@classmethod
def get_status(cls) -> HFTokenStatus:
try:
if huggingface_hub.get_token_permission(huggingface_hub.get_token()):
# Valid token!
return HFTokenStatus.VALID
# No token set
return HFTokenStatus.INVALID
except Exception:
return HFTokenStatus.UNKNOWN
@classmethod
def set_token(cls, token: str) -> HFTokenStatus:
with SuppressOutput(), contextlib.suppress(Exception):
huggingface_hub.login(token=token, add_to_git_credential=False)
return cls.get_status()
@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
async def get_hf_login_status() -> HFTokenStatus:
token_status = HFTokenHelper.get_status()
if token_status is HFTokenStatus.UNKNOWN:
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
return token_status
@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus)
async def do_hf_login(
token: str = Body(description="Hugging Face token to use for login", embed=True),
) -> HFTokenStatus:
HFTokenHelper.set_token(token)
token_status = HFTokenHelper.get_status()
if token_status is HFTokenStatus.UNKNOWN:
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
return token_status

View File

@ -16,7 +16,6 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
import invokeai.configs as model_configs
from invokeai.app.util.hf_login import hf_login
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
@ -449,9 +448,6 @@ def get_config() -> InvokeAIAppConfig:
]
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
# Log in to HF
hf_login()
# Copy all legacy configs - We know `__path__[0]` is correct here
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)

View File

@ -1,46 +0,0 @@
import huggingface_hub
from pwinput import pwinput
from invokeai.app.util.suppress_output import SuppressOutput
def hf_login() -> None:
"""Prompts the user for their HuggingFace token. If a valid token is already saved, this function will do nothing.
Returns:
bool: True if the login was successful, False if the user canceled.
Raises:
RuntimeError: If the user cancels the login prompt.
"""
current_token = huggingface_hub.get_token()
try:
if huggingface_hub.get_token_permission(current_token):
# We have a valid token already
return
except ConnectionError:
print("Unable to reach HF to verify token. Skipping...")
# No internet connection, so we can't check the token
pass
# InvokeAILogger depends on the config, and this class is used within the config, so we can't use the app logger here
print("Enter your HuggingFace token. This is required to convert checkpoint/safetensors models to diffusers.")
print("For more information, see https://huggingface.co/docs/hub/security-tokens#how-to-manage-user-access-tokens")
print("Press Ctrl+C to skip.")
while True:
try:
access_token = pwinput(prompt="HF token: ")
# The login function prints to stdout
with SuppressOutput():
huggingface_hub.login(token=access_token, add_to_git_credential=False)
print("Token verified.")
break
except ValueError:
print("Invalid token!")
continue
except KeyboardInterrupt:
print("\nToken verification canceled.")
break