diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index 1716469b72..069fc96aea 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -1288,7 +1288,9 @@ This descends from `ModelMetadataBase` and adds the following fields: | `id` | int | Civitai model id | | `version_name` | str | Name of this version of the model (distinct from model name) | | `version_id` | int | Civitai model version id (distinct from model id) | -| `created` | datetime | Date the model was uploaded to Civitai; no modification date provided | +| `created` | datetime | Date this version of the model was created | +| `updated` | datetime | Date this version of the model was last updated | +| `published` | datetime | Date this version of the model was published to Civitai | | `description` | str | Model description. Quite verbose and contains HTML tags | | `version_description` | str | Model version description, usually describes changes to the model | | `nsfw` | bool | Whether the model tends to generate NSFW content | diff --git a/invokeai/backend/model_manager/metadata/fetch/fetch_civitai.py b/invokeai/backend/model_manager/metadata/fetch/fetch_civitai.py index 491094f55f..5031b0ea3c 100644 --- a/invokeai/backend/model_manager/metadata/fetch/fetch_civitai.py +++ b/invokeai/backend/model_manager/metadata/fetch/fetch_civitai.py @@ -98,49 +98,56 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase): May raise an `UnknownModelException`. """ model_url = CIVITAI_MODEL_ENDPOINT + str(model_id) - model = self._requests.get(model_url).json() - default_version = model["modelVersions"][0]["id"] - return self.from_civitai_versionid(default_version, model) + model_json = self._requests.get(model_url).json() + return self._from_model_json(model_json) - def from_civitai_versionid( - self, version_id: int, model_metadata: Optional[Dict[str, Any]] = None - ) -> CivitaiMetadata: + def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata: + version_id = version_id or model_json["modelVersions"][0]["id"] + + # loop till we find the section containing the version requested + version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id] + if not version_sections: + raise UnknownModelException(f"Version {version_id} not found in model metadata") + + version_json = version_sections[0] + safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"] + return CivitaiMetadata( + id=model_json["id"], + name=model_json["name"], + version_id=version_json["id"], + version_name=version_json["name"], + created=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version_json["createdAt"])), + updated=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version_json["updatedAt"])), + published=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version_json["publishedAt"])), + base_model_trained_on=version_json["baseModel"], # note - need a dictionary to turn into a BaseModelType + download_url=version_json["downloadUrl"], + thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None, + author=model_json["creator"]["username"], + description=model_json["description"], + version_description=version_json["description"] or "", + tags=model_json["tags"], + trained_words=version_json["trainedWords"], + nsfw=model_json["nsfw"], + restrictions=LicenseRestrictions( + AllowNoCredit=model_json["allowNoCredit"], + AllowCommercialUse=CommercialUsage(model_json["allowCommercialUse"]), + AllowDerivatives=model_json["allowDerivatives"], + AllowDifferentLicense=model_json["allowDifferentLicense"], + ), + ) + + def from_civitai_versionid(self, version_id: int) -> CivitaiMetadata: + """ + Return a CivitaiMetadata object given a model version id. + + May raise an `UnknownModelException`. + """ version_url = CIVITAI_VERSION_ENDPOINT + str(version_id) version = self._requests.get(version_url).json() model_url = CIVITAI_MODEL_ENDPOINT + str(version["modelId"]) - model = model_metadata or self._requests.get(model_url).json() - safe_thumbnails = [x["url"] for x in version["images"] if x["nsfw"] == "None"] - - # It would be more elegant to define a Pydantic BaseModel that matches the Civitai metadata JSON. - # However the contents of the JSON does not exactly match the documentation at - # https://github.com/civitai/civitai/wiki/REST-API-Reference, and it feels safer to cherry pick - # a subset of the fields. - # - # In addition, there are some fields that I want to pick up from the model JSON, such as `tags`, - # that are not present in the version JSON. - return CivitaiMetadata( - id=version["modelId"], - name=version["model"]["name"], - version_id=version["id"], - version_name=version["name"], - created=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version["createdAt"])), - base_model_trained_on=version["baseModel"], # note - need a dictionary to turn into a BaseModelType - download_url=version["downloadUrl"], - thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None, - author=model["creator"]["username"], - description=model["description"], - version_description=version["description"] or "", - tags=model["tags"], - trained_words=version["trainedWords"], - nsfw=version["model"]["nsfw"], - restrictions=LicenseRestrictions( - AllowNoCredit=model["allowNoCredit"], - AllowCommercialUse=CommercialUsage(model["allowCommercialUse"]), - AllowDerivatives=model["allowDerivatives"], - AllowDifferentLicense=model["allowDifferentLicense"], - ), - ) + model_json = self._requests.get(model_url).json() + return self._from_model_json(model_json, version_id) @classmethod def from_json(cls, json: str) -> CivitaiMetadata: diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index a19b28ac92..4366ef3d48 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -74,7 +74,9 @@ class CivitaiMetadata(ModelMetadataBase): id: int = Field(description="Civitai model identifier") version_name: str = Field(description="Version identifier, such as 'V2-alpha'") version_id: int = Field(description="Civitai model version identifier") - created: datetime = Field(description="date the model was posted to CivitAI") + created: datetime = Field(description="date the model was created") + updated: datetime = Field(description="date the model was last modified") + published: datetime = Field(description="date the model was published to Civitai") description: str = Field(description="text description of model; may contain HTML") version_description: str = Field( description="text description of the model's reversion; usually change history; may contain HTML" diff --git a/tests/app/services/model_metadata/metadata_examples.py b/tests/backend/model_manager_2/model_metadata/metadata_examples.py similarity index 100% rename from tests/app/services/model_metadata/metadata_examples.py rename to tests/backend/model_manager_2/model_metadata/metadata_examples.py diff --git a/tests/app/services/model_metadata/test_model_metadata.py b/tests/backend/model_manager_2/model_metadata/test_model_metadata.py similarity index 99% rename from tests/app/services/model_metadata/test_model_metadata.py rename to tests/backend/model_manager_2/model_metadata/test_model_metadata.py index a34466c94c..7d4fbb08ca 100644 --- a/tests/app/services/model_metadata/test_model_metadata.py +++ b/tests/backend/model_manager_2/model_metadata/test_model_metadata.py @@ -26,7 +26,7 @@ from invokeai.backend.model_manager.metadata import ( ModelMetadataStore, ) from invokeai.backend.util.logging import InvokeAILogger -from tests.app.services.model_metadata.metadata_examples import ( +from tests.backend.model_manager_2.model_metadata.metadata_examples import ( RepoCivitaiModelMetadata1, RepoCivitaiVersionMetadata1, RepoHFMetadata1,