mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
install of repo_ids records author, tags and license
This commit is contained in:
@ -106,8 +106,8 @@ class ModelConfigBase(BaseModel):
|
||||
id: Optional[str] = Field(None) # this may get added by the store
|
||||
description: Optional[str] = Field(None)
|
||||
author: Optional[str] = Field(description="Model author")
|
||||
license: Optional[str] = Field(description="License string")
|
||||
thumbnail_url: Optional[str] = Field(description="URL of thumbnail image")
|
||||
license_url: Optional[str] = Field(description="URL of license")
|
||||
source_url: Optional[str] = Field(description="Model download source")
|
||||
tags: Optional[List[str]] = Field(description="Descriptive tags") # Set would be better, but not JSON serializable
|
||||
|
||||
|
@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Callable
|
||||
from typing import List, Optional, Callable, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@ -50,6 +50,7 @@ class DownloadJobBase(BaseModel):
|
||||
job_sequence: Optional[int] = Field(
|
||||
description="Counter that records order in which this job was dequeued (for debugging)"
|
||||
)
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Model metadata (source-specific)")
|
||||
error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
|
||||
|
||||
class Config:
|
||||
|
@ -395,8 +395,9 @@ class DownloadQueue(DownloadQueueBase):
|
||||
try:
|
||||
repo_id = job.source
|
||||
variant = job.variant
|
||||
urls_to_download = self._get_repo_urls(repo_id, variant)
|
||||
urls_to_download, metadata = self._get_repo_info(repo_id, variant)
|
||||
job.destination = job.destination / Path(repo_id).name
|
||||
job.metadata = metadata
|
||||
bytes_downloaded = dict()
|
||||
|
||||
for url, subdir, file, size in urls_to_download:
|
||||
@ -418,7 +419,10 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
subqueue.release() # get rid of the subqueue
|
||||
|
||||
def _get_repo_urls(self, repo_id: str, variant: Optional[str] = None) -> List[Tuple[AnyHttpUrl, Path, Path]]:
|
||||
def _get_repo_info(self,
|
||||
repo_id: str,
|
||||
variant: Optional[str] = None,
|
||||
) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], Dict[str, str]]:
|
||||
"""Given a repo_id and an optional variant, return list of URLs to download to get the model."""
|
||||
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
|
||||
sibs = model_info.siblings
|
||||
@ -431,10 +435,11 @@ class DownloadQueue(DownloadQueueBase):
|
||||
submodels = resp.json()
|
||||
paths = [x for x in paths if Path(x).parent.as_posix() in submodels]
|
||||
paths.insert(0, "model_index.json")
|
||||
return [
|
||||
urls = [
|
||||
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()])
|
||||
for x in self._select_variants(paths, variant)
|
||||
]
|
||||
return (urls, {'cardData': model_info.cardData, 'tags': model_info.tags, 'author': model_info.author})
|
||||
|
||||
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
|
||||
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
|
||||
|
@ -320,6 +320,12 @@ class ModelInstall(ModelInstallBase):
|
||||
info = self._store.get_model(id)
|
||||
info.description = f"Downloaded model {info.name}"
|
||||
info.source_url = str(job.source)
|
||||
if card_data := job.metadata.get('cardData'):
|
||||
info.license = card_data.get('license')
|
||||
if author := job.metadata.get('author'):
|
||||
info.author = author
|
||||
if tags := job.metadata.get('tags'):
|
||||
info.tags = tags
|
||||
self._store.update_model(id, info)
|
||||
self._async_installs[job.source] = id
|
||||
jobs = queue.list_jobs()
|
||||
|
Reference in New Issue
Block a user