install of repo_ids records author, tags and license

This commit is contained in:
Lincoln Stein
2023-09-09 14:02:05 -04:00
parent 598fe8101e
commit 64424c6db0
4 changed files with 17 additions and 5 deletions

View File

@ -106,8 +106,8 @@ class ModelConfigBase(BaseModel):
id: Optional[str] = Field(None) # this may get added by the store
description: Optional[str] = Field(None)
author: Optional[str] = Field(description="Model author")
license: Optional[str] = Field(description="License string")
thumbnail_url: Optional[str] = Field(description="URL of thumbnail image")
license_url: Optional[str] = Field(description="URL of license")
source_url: Optional[str] = Field(description="Model download source")
tags: Optional[List[str]] = Field(description="Descriptive tags") # Set would be better, but not JSON serializable

View File

@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import List, Optional, Callable
from typing import List, Optional, Callable, Dict, Any
from pydantic import BaseModel, Field
@ -50,6 +50,7 @@ class DownloadJobBase(BaseModel):
job_sequence: Optional[int] = Field(
description="Counter that records order in which this job was dequeued (for debugging)"
)
metadata: Dict[str, Any] = Field(default_factory=dict, description="Model metadata (source-specific)")
error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
class Config:

View File

@ -395,8 +395,9 @@ class DownloadQueue(DownloadQueueBase):
try:
repo_id = job.source
variant = job.variant
urls_to_download = self._get_repo_urls(repo_id, variant)
urls_to_download, metadata = self._get_repo_info(repo_id, variant)
job.destination = job.destination / Path(repo_id).name
job.metadata = metadata
bytes_downloaded = dict()
for url, subdir, file, size in urls_to_download:
@ -418,7 +419,10 @@ class DownloadQueue(DownloadQueueBase):
self._update_job_status(job, DownloadJobStatus.COMPLETED)
subqueue.release() # get rid of the subqueue
def _get_repo_urls(self, repo_id: str, variant: Optional[str] = None) -> List[Tuple[AnyHttpUrl, Path, Path]]:
def _get_repo_info(self,
repo_id: str,
variant: Optional[str] = None,
) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], Dict[str, str]]:
"""Given a repo_id and an optional variant, return list of URLs to download to get the model."""
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
sibs = model_info.siblings
@ -431,10 +435,11 @@ class DownloadQueue(DownloadQueueBase):
submodels = resp.json()
paths = [x for x in paths if Path(x).parent.as_posix() in submodels]
paths.insert(0, "model_index.json")
return [
urls = [
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()])
for x in self._select_variants(paths, variant)
]
return (urls, {'cardData': model_info.cardData, 'tags': model_info.tags, 'author': model_info.author})
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""

View File

@ -320,6 +320,12 @@ class ModelInstall(ModelInstallBase):
info = self._store.get_model(id)
info.description = f"Downloaded model {info.name}"
info.source_url = str(job.source)
if card_data := job.metadata.get('cardData'):
info.license = card_data.get('license')
if author := job.metadata.get('author'):
info.author = author
if tags := job.metadata.get('tags'):
info.tags = tags
self._store.update_model(id, info)
self._async_installs[job.source] = id
jobs = queue.list_jobs()