mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): move HF token helper to route
This commit is contained in:
parent
dea9142cb8
commit
9a5575b46b
@ -1,13 +1,16 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""FastAPI route for model configuration records."""
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import pathlib
|
||||
import shutil
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import huggingface_hub
|
||||
from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
@ -22,6 +25,7 @@ from invokeai.app.services.model_records import (
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
|
||||
from invokeai.app.util.suppress_output import SuppressOutput
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@ -794,3 +798,51 @@ async def get_starter_models() -> list[StarterModel]:
|
||||
model.is_installed = True
|
||||
|
||||
return starter_models
|
||||
|
||||
|
||||
class HFTokenStatus(str, Enum):
|
||||
VALID = "valid"
|
||||
INVALID = "invalid"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class HFTokenHelper:
|
||||
@classmethod
|
||||
def get_status(cls) -> HFTokenStatus:
|
||||
try:
|
||||
if huggingface_hub.get_token_permission(huggingface_hub.get_token()):
|
||||
# Valid token!
|
||||
return HFTokenStatus.VALID
|
||||
# No token set
|
||||
return HFTokenStatus.INVALID
|
||||
except Exception:
|
||||
return HFTokenStatus.UNKNOWN
|
||||
|
||||
@classmethod
|
||||
def set_token(cls, token: str) -> HFTokenStatus:
|
||||
with SuppressOutput(), contextlib.suppress(Exception):
|
||||
huggingface_hub.login(token=token, add_to_git_credential=False)
|
||||
return cls.get_status()
|
||||
|
||||
|
||||
@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
|
||||
async def get_hf_login_status() -> HFTokenStatus:
|
||||
token_status = HFTokenHelper.get_status()
|
||||
|
||||
if token_status is HFTokenStatus.UNKNOWN:
|
||||
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
|
||||
|
||||
return token_status
|
||||
|
||||
|
||||
@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus)
|
||||
async def do_hf_login(
|
||||
token: str = Body(description="Hugging Face token to use for login", embed=True),
|
||||
) -> HFTokenStatus:
|
||||
HFTokenHelper.set_token(token)
|
||||
token_status = HFTokenHelper.get_status()
|
||||
|
||||
if token_status is HFTokenStatus.UNKNOWN:
|
||||
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
|
||||
|
||||
return token_status
|
||||
|
@ -16,7 +16,6 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
import invokeai.configs as model_configs
|
||||
from invokeai.app.util.hf_login import hf_login
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
@ -449,9 +448,6 @@ def get_config() -> InvokeAIAppConfig:
|
||||
]
|
||||
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
|
||||
|
||||
# Log in to HF
|
||||
hf_login()
|
||||
|
||||
# Copy all legacy configs - We know `__path__[0]` is correct here
|
||||
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
||||
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
|
||||
|
@ -1,46 +0,0 @@
|
||||
import huggingface_hub
|
||||
from pwinput import pwinput
|
||||
|
||||
from invokeai.app.util.suppress_output import SuppressOutput
|
||||
|
||||
|
||||
def hf_login() -> None:
|
||||
"""Prompts the user for their HuggingFace token. If a valid token is already saved, this function will do nothing.
|
||||
|
||||
Returns:
|
||||
bool: True if the login was successful, False if the user canceled.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the user cancels the login prompt.
|
||||
"""
|
||||
|
||||
current_token = huggingface_hub.get_token()
|
||||
|
||||
try:
|
||||
if huggingface_hub.get_token_permission(current_token):
|
||||
# We have a valid token already
|
||||
return
|
||||
except ConnectionError:
|
||||
print("Unable to reach HF to verify token. Skipping...")
|
||||
# No internet connection, so we can't check the token
|
||||
pass
|
||||
|
||||
# InvokeAILogger depends on the config, and this class is used within the config, so we can't use the app logger here
|
||||
print("Enter your HuggingFace token. This is required to convert checkpoint/safetensors models to diffusers.")
|
||||
print("For more information, see https://huggingface.co/docs/hub/security-tokens#how-to-manage-user-access-tokens")
|
||||
print("Press Ctrl+C to skip.")
|
||||
|
||||
while True:
|
||||
try:
|
||||
access_token = pwinput(prompt="HF token: ")
|
||||
# The login function prints to stdout
|
||||
with SuppressOutput():
|
||||
huggingface_hub.login(token=access_token, add_to_git_credential=False)
|
||||
print("Token verified.")
|
||||
break
|
||||
except ValueError:
|
||||
print("Invalid token!")
|
||||
continue
|
||||
except KeyboardInterrupt:
|
||||
print("\nToken verification canceled.")
|
||||
break
|
Loading…
Reference in New Issue
Block a user