diff --git a/pkg/downloader/huggingface.go b/pkg/downloader/huggingface.go index 34ba9bd9a..9d7f1657f 100644 --- a/pkg/downloader/huggingface.go +++ b/pkg/downloader/huggingface.go @@ -23,10 +23,10 @@ var ErrUnsafeFilesFound = errors.New("unsafe files found") func HuggingFaceScan(uri URI) (*HuggingFaceScanResult, error) { cleanParts := strings.Split(uri.ResolveURL(), "/") - if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" { + if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" && cleanParts[2] != HF_ENDPOINT { return nil, ErrNonHuggingFaceFile } - results, err := http.Get(fmt.Sprintf("https://huggingface.co/api/models/%s/%s/scan", cleanParts[3], cleanParts[4])) + results, err := http.Get(fmt.Sprintf("%s/api/models/%s/%s/scan", HF_ENDPOINT, cleanParts[3], cleanParts[4])) if err != nil { return nil, err } diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 8d9b1d936..b6739498d 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -37,6 +37,17 @@ const ( type URI string +// HF_ENDPOINT is the HuggingFace endpoint, can be overridden by setting the HF_ENDPOINT environment variable. +var HF_ENDPOINT string = loadConfig() + +func loadConfig() string { + HF_ENDPOINT := os.Getenv("HF_ENDPOINT") + if HF_ENDPOINT == "" { + HF_ENDPOINT = "https://huggingface.co" + } + return HF_ENDPOINT +} + func (uri URI) DownloadWithCallback(basePath string, f func(url string, i []byte) error) error { return uri.DownloadWithAuthorizationAndCallback(basePath, "", f) } @@ -213,7 +224,7 @@ func (s URI) ResolveURL() string { filepath = strings.Split(filepath, "@")[0] } - return fmt.Sprintf("https://huggingface.co/%s/%s/resolve/%s/%s", owner, repo, branch, filepath) + return fmt.Sprintf("%s/%s/%s/resolve/%s/%s", HF_ENDPOINT, owner, repo, branch, filepath) } return string(s)