mirror of
https://github.com/TecharoHQ/anubis.git
synced 2026-02-08 21:09:41 -06:00
fix(ogtags): respect target host/SNI/insecure flags in OG passthrough (#1283)
This commit is contained in:
@@ -439,26 +439,29 @@ func main() {
|
||||
}
|
||||
|
||||
s, err := libanubis.New(libanubis.Options{
|
||||
BasePrefix: *basePrefix,
|
||||
StripBasePrefix: *stripBasePrefix,
|
||||
Next: rp,
|
||||
Policy: policy,
|
||||
ServeRobotsTXT: *robotsTxt,
|
||||
ED25519PrivateKey: ed25519Priv,
|
||||
HS512Secret: []byte(*hs512Secret),
|
||||
CookieDomain: *cookieDomain,
|
||||
CookieDynamicDomain: *cookieDynamicDomain,
|
||||
CookieExpiration: *cookieExpiration,
|
||||
CookiePartitioned: *cookiePartitioned,
|
||||
RedirectDomains: redirectDomainsList,
|
||||
Target: *target,
|
||||
WebmasterEmail: *webmasterEmail,
|
||||
OpenGraph: policy.OpenGraph,
|
||||
CookieSecure: *cookieSecure,
|
||||
CookieSameSite: parseSameSite(*cookieSameSite),
|
||||
PublicUrl: *publicUrl,
|
||||
JWTRestrictionHeader: *jwtRestrictionHeader,
|
||||
DifficultyInJWT: *difficultyInJWT,
|
||||
BasePrefix: *basePrefix,
|
||||
StripBasePrefix: *stripBasePrefix,
|
||||
Next: rp,
|
||||
Policy: policy,
|
||||
TargetHost: *targetHost,
|
||||
TargetSNI: *targetSNI,
|
||||
TargetInsecureSkipVerify: *targetInsecureSkipVerify,
|
||||
ServeRobotsTXT: *robotsTxt,
|
||||
ED25519PrivateKey: ed25519Priv,
|
||||
HS512Secret: []byte(*hs512Secret),
|
||||
CookieDomain: *cookieDomain,
|
||||
CookieDynamicDomain: *cookieDynamicDomain,
|
||||
CookieExpiration: *cookieExpiration,
|
||||
CookiePartitioned: *cookiePartitioned,
|
||||
RedirectDomains: redirectDomainsList,
|
||||
Target: *target,
|
||||
WebmasterEmail: *webmasterEmail,
|
||||
OpenGraph: policy.OpenGraph,
|
||||
CookieSecure: *cookieSecure,
|
||||
CookieSameSite: parseSameSite(*cookieSameSite),
|
||||
PublicUrl: *publicUrl,
|
||||
JWTRestrictionHeader: *jwtRestrictionHeader,
|
||||
DifficultyInJWT: *difficultyInJWT,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("can't construct libanubis.Server: %v", err)
|
||||
|
||||
@@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Allow Renovate as an OCI registry client.
|
||||
- Properly handle 4in6 addresses so that IP matching works with those addresses.
|
||||
- Add support to simple Valkey/Redis cluster mode
|
||||
- Open Graph passthrough now reuses the configured target Host/SNI/TLS settings, so metadata fetches succeed when the upstream certificate differs from the public domain. ([1283](https://github.com/TecharoHQ/anubis/pull/1283))
|
||||
- Stabilize the CVE-2025-24369 regression test by always submitting an invalid proof instead of relying on random POW failures.
|
||||
|
||||
## v1.23.1: Lyse Hext - Echo 1
|
||||
|
||||
@@ -24,7 +24,7 @@ func TestCacheReturnsDefault(t *testing.T) {
|
||||
TimeToLive: time.Minute,
|
||||
ConsiderHost: false,
|
||||
Override: want,
|
||||
}, memory.New(t.Context()))
|
||||
}, memory.New(t.Context()), TargetOptions{})
|
||||
|
||||
u, err := url.Parse("https://anubis.techaro.lol")
|
||||
if err != nil {
|
||||
@@ -52,7 +52,7 @@ func TestCheckCache(t *testing.T) {
|
||||
Enabled: true,
|
||||
TimeToLive: time.Minute,
|
||||
ConsiderHost: false,
|
||||
}, memory.New(t.Context()))
|
||||
}, memory.New(t.Context()), TargetOptions{})
|
||||
|
||||
// Set up test data
|
||||
urlStr := "http://example.com/page"
|
||||
@@ -115,7 +115,7 @@ func TestGetOGTags(t *testing.T) {
|
||||
Enabled: true,
|
||||
TimeToLive: time.Minute,
|
||||
ConsiderHost: false,
|
||||
}, memory.New(t.Context()))
|
||||
}, memory.New(t.Context()), TargetOptions{})
|
||||
|
||||
// Parse the test server URL
|
||||
parsedURL, err := url.Parse(ts.URL)
|
||||
@@ -271,7 +271,7 @@ func TestGetOGTagsWithHostConsideration(t *testing.T) {
|
||||
Enabled: true,
|
||||
TimeToLive: time.Minute,
|
||||
ConsiderHost: tc.ogCacheConsiderHost,
|
||||
}, memory.New(t.Context()))
|
||||
}, memory.New(t.Context()), TargetOptions{})
|
||||
|
||||
for i, req := range tc.requests {
|
||||
ogTags, err := cache.GetOGTags(t.Context(), parsedURL, req.host)
|
||||
|
||||
@@ -27,16 +27,29 @@ func (c *OGTagCache) fetchHTMLDocumentWithCache(ctx context.Context, urlStr stri
|
||||
}
|
||||
|
||||
// Set the Host header to the original host
|
||||
if originalHost != "" {
|
||||
req.Host = originalHost
|
||||
var hostForRequest string
|
||||
switch {
|
||||
case c.targetHost != "":
|
||||
hostForRequest = c.targetHost
|
||||
case originalHost != "":
|
||||
hostForRequest = originalHost
|
||||
}
|
||||
if hostForRequest != "" {
|
||||
req.Host = hostForRequest
|
||||
}
|
||||
|
||||
// Add proxy headers
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("User-Agent", "Anubis-OGTag-Fetcher/1.0") // For tracking purposes
|
||||
|
||||
serverName := hostForRequest
|
||||
if serverName == "" {
|
||||
serverName = req.URL.Hostname()
|
||||
}
|
||||
client := c.clientForSNI(serverName)
|
||||
|
||||
// Send the request
|
||||
resp, err := c.client.Do(req)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
|
||||
@@ -87,7 +87,7 @@ func TestFetchHTMLDocument(t *testing.T) {
|
||||
Enabled: true,
|
||||
TimeToLive: time.Minute,
|
||||
ConsiderHost: false,
|
||||
}, memory.New(t.Context()))
|
||||
}, memory.New(t.Context()), TargetOptions{})
|
||||
doc, err := cache.fetchHTMLDocument(t.Context(), ts.URL, "anything")
|
||||
|
||||
if tt.expectError {
|
||||
@@ -118,7 +118,7 @@ func TestFetchHTMLDocumentInvalidURL(t *testing.T) {
|
||||
Enabled: true,
|
||||
TimeToLive: time.Minute,
|
||||
ConsiderHost: false,
|
||||
}, memory.New(t.Context()))
|
||||
}, memory.New(t.Context()), TargetOptions{})
|
||||
|
||||
doc, err := cache.fetchHTMLDocument(t.Context(), "http://invalid.url.that.doesnt.exist.example", "anything")
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ func TestIntegrationGetOGTags(t *testing.T) {
|
||||
Enabled: true,
|
||||
TimeToLive: time.Minute,
|
||||
ConsiderHost: false,
|
||||
}, memory.New(t.Context()))
|
||||
}, memory.New(t.Context()), TargetOptions{})
|
||||
|
||||
// Create URL for test
|
||||
testURL, _ := url.Parse(ts.URL)
|
||||
|
||||
@@ -31,7 +31,7 @@ func BenchmarkGetTarget(b *testing.B) {
|
||||
|
||||
for _, tt := range tests {
|
||||
b.Run(tt.name, func(b *testing.B) {
|
||||
cache := NewOGTagCache(tt.target, config.OpenGraph{}, memory.New(b.Context()))
|
||||
cache := NewOGTagCache(tt.target, config.OpenGraph{}, memory.New(b.Context()), TargetOptions{})
|
||||
urls := make([]*url.URL, len(tt.paths))
|
||||
for i, path := range tt.paths {
|
||||
u, _ := url.Parse(path)
|
||||
@@ -67,7 +67,7 @@ func BenchmarkExtractOGTags(b *testing.B) {
|
||||
</head><body><div><p>Content</p></div></body></html>`,
|
||||
}
|
||||
|
||||
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(b.Context()))
|
||||
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(b.Context()), TargetOptions{})
|
||||
docs := make([]*html.Node, len(htmlSamples))
|
||||
|
||||
for i, sample := range htmlSamples {
|
||||
@@ -85,7 +85,7 @@ func BenchmarkExtractOGTags(b *testing.B) {
|
||||
|
||||
// Memory usage test
|
||||
func TestMemoryUsage(t *testing.T) {
|
||||
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(t.Context()))
|
||||
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(t.Context()), TargetOptions{})
|
||||
|
||||
// Force GC and wait for it to complete
|
||||
runtime.GC()
|
||||
|
||||
@@ -2,11 +2,13 @@ package ogtags
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
||||
@@ -22,21 +24,34 @@ const (
|
||||
)
|
||||
|
||||
type OGTagCache struct {
|
||||
ogOverride map[string]string
|
||||
targetURL *url.URL
|
||||
client *http.Client
|
||||
ogOverride map[string]string
|
||||
transport *http.Transport
|
||||
cache store.JSON[map[string]string]
|
||||
|
||||
// Pre-built strings for optimization
|
||||
unixPrefix string // "http://unix"
|
||||
approvedTags []string
|
||||
targetSNI string
|
||||
targetHost string
|
||||
approvedPrefixes []string
|
||||
approvedTags []string
|
||||
ogTimeToLive time.Duration
|
||||
ogCacheConsiderHost bool
|
||||
ogPassthrough bool
|
||||
ogCacheConsiderHost bool
|
||||
targetSNIAuto bool
|
||||
insecureSkipVerify bool
|
||||
sniClients map[string]*http.Client
|
||||
transportMu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface) *OGTagCache {
|
||||
type TargetOptions struct {
|
||||
Host string
|
||||
SNI string
|
||||
InsecureSkipVerify bool
|
||||
}
|
||||
|
||||
func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface, targetOpts TargetOptions) *OGTagCache {
|
||||
// Predefined approved tags and prefixes
|
||||
defaultApprovedTags := []string{"description", "keywords", "author"}
|
||||
defaultApprovedPrefixes := []string{"og:", "twitter:", "fediverse:"}
|
||||
@@ -62,20 +77,37 @@ func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: httpTimeout,
|
||||
}
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
|
||||
// Configure custom transport for Unix sockets
|
||||
if parsedTargetURL.Scheme == "unix" {
|
||||
socketPath := parsedTargetURL.Path // For unix scheme, path is the socket path
|
||||
client.Transport = &http.Transport{
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return net.Dial("unix", socketPath)
|
||||
},
|
||||
transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return net.Dial("unix", socketPath)
|
||||
}
|
||||
}
|
||||
|
||||
targetSNIAuto := targetOpts.SNI == "auto"
|
||||
|
||||
if targetOpts.SNI != "" && !targetSNIAuto {
|
||||
if transport.TLSClientConfig == nil {
|
||||
transport.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
transport.TLSClientConfig.ServerName = targetOpts.SNI
|
||||
}
|
||||
|
||||
if targetOpts.InsecureSkipVerify {
|
||||
if transport.TLSClientConfig == nil {
|
||||
transport.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
transport.TLSClientConfig.InsecureSkipVerify = true
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: httpTimeout,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
return &OGTagCache{
|
||||
cache: store.JSON[map[string]string]{
|
||||
Underlying: backend,
|
||||
@@ -89,7 +121,13 @@ func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface
|
||||
approvedTags: defaultApprovedTags,
|
||||
approvedPrefixes: defaultApprovedPrefixes,
|
||||
client: client,
|
||||
transport: transport,
|
||||
unixPrefix: "http://unix",
|
||||
targetHost: targetOpts.Host,
|
||||
targetSNI: targetOpts.SNI,
|
||||
targetSNIAuto: targetSNIAuto,
|
||||
insecureSkipVerify: targetOpts.InsecureSkipVerify,
|
||||
sniClients: make(map[string]*http.Client),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ func FuzzGetTarget(f *testing.F) {
|
||||
}
|
||||
|
||||
// Create cache - should not panic
|
||||
cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background()))
|
||||
cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
|
||||
|
||||
// Create URL
|
||||
u := &url.URL{
|
||||
@@ -132,7 +132,7 @@ func FuzzExtractOGTags(f *testing.F) {
|
||||
return
|
||||
}
|
||||
|
||||
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background()))
|
||||
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
|
||||
|
||||
// Should not panic
|
||||
tags := cache.extractOGTags(doc)
|
||||
@@ -188,7 +188,7 @@ func FuzzGetTargetRoundTrip(f *testing.F) {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background()))
|
||||
cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
|
||||
u := &url.URL{Path: path, RawQuery: query}
|
||||
|
||||
result := cache.getTarget(u)
|
||||
@@ -245,7 +245,7 @@ func FuzzExtractMetaTagInfo(f *testing.F) {
|
||||
},
|
||||
}
|
||||
|
||||
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background()))
|
||||
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
|
||||
|
||||
// Should not panic
|
||||
property, content := cache.extractMetaTagInfo(node)
|
||||
@@ -298,7 +298,7 @@ func BenchmarkFuzzedGetTarget(b *testing.B) {
|
||||
|
||||
for _, input := range inputs {
|
||||
b.Run(input.name, func(b *testing.B) {
|
||||
cache := NewOGTagCache(input.target, config.OpenGraph{}, memory.New(context.Background()))
|
||||
cache := NewOGTagCache(input.target, config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
|
||||
u := &url.URL{Path: input.path, RawQuery: input.query}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
@@ -2,15 +2,23 @@ package ogtags
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -45,7 +53,7 @@ func TestNewOGTagCache(t *testing.T) {
|
||||
Enabled: tt.ogPassthrough,
|
||||
TimeToLive: tt.ogTimeToLive,
|
||||
ConsiderHost: false,
|
||||
}, memory.New(t.Context()))
|
||||
}, memory.New(t.Context()), TargetOptions{})
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("expected non-nil cache, got nil")
|
||||
@@ -85,7 +93,7 @@ func TestNewOGTagCache_UnixSocket(t *testing.T) {
|
||||
Enabled: true,
|
||||
TimeToLive: 5 * time.Minute,
|
||||
ConsiderHost: false,
|
||||
}, memory.New(t.Context()))
|
||||
}, memory.New(t.Context()), TargetOptions{})
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("expected non-nil cache, got nil")
|
||||
@@ -170,7 +178,7 @@ func TestGetTarget(t *testing.T) {
|
||||
Enabled: true,
|
||||
TimeToLive: time.Minute,
|
||||
ConsiderHost: false,
|
||||
}, memory.New(t.Context()))
|
||||
}, memory.New(t.Context()), TargetOptions{})
|
||||
|
||||
u := &url.URL{
|
||||
Path: tt.path,
|
||||
@@ -243,7 +251,7 @@ func TestIntegrationGetOGTags_UnixSocket(t *testing.T) {
|
||||
Enabled: true,
|
||||
TimeToLive: time.Minute,
|
||||
ConsiderHost: false,
|
||||
}, memory.New(t.Context()))
|
||||
}, memory.New(t.Context()), TargetOptions{})
|
||||
|
||||
// Create a dummy URL for the request (path and query matter)
|
||||
testReqURL, _ := url.Parse("/some/page?query=1")
|
||||
@@ -274,3 +282,244 @@ func TestIntegrationGetOGTags_UnixSocket(t *testing.T) {
|
||||
t.Errorf("Expected cached OG tags %v, got %v", expectedTags, cachedTags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOGTagsWithTargetHostOverride(t *testing.T) {
|
||||
originalHost := "example.test"
|
||||
overrideHost := "backend.internal"
|
||||
seenHosts := make(chan string, 10)
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
seenHosts <- r.Host
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprintln(w, `<!DOCTYPE html><html><head><meta property="og:title" content="HostOverride" /></head><body>ok</body></html>`)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
targetURL, err := url.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse server URL: %v", err)
|
||||
}
|
||||
|
||||
conf := config.OpenGraph{
|
||||
Enabled: true,
|
||||
TimeToLive: time.Minute,
|
||||
ConsiderHost: false,
|
||||
}
|
||||
|
||||
t.Run("default host uses original", func(t *testing.T) {
|
||||
cache := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{})
|
||||
if _, err := cache.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
|
||||
t.Fatalf("GetOGTags failed: %v", err)
|
||||
}
|
||||
select {
|
||||
case host := <-seenHosts:
|
||||
if host != originalHost {
|
||||
t.Fatalf("expected host %q, got %q", originalHost, host)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("server did not receive request")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("override host respected", func(t *testing.T) {
|
||||
cache := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
|
||||
Host: overrideHost,
|
||||
})
|
||||
if _, err := cache.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
|
||||
t.Fatalf("GetOGTags failed: %v", err)
|
||||
}
|
||||
select {
|
||||
case host := <-seenHosts:
|
||||
if host != overrideHost {
|
||||
t.Fatalf("expected host %q, got %q", overrideHost, host)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("server did not receive request")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetOGTagsWithInsecureSkipVerify(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprintln(w, `<!DOCTYPE html><html><head><meta property="og:title" content="Self-Signed" /></head><body>hello</body></html>`)
|
||||
})
|
||||
ts := httptest.NewTLSServer(handler)
|
||||
defer ts.Close()
|
||||
|
||||
parsedURL, err := url.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse server URL: %v", err)
|
||||
}
|
||||
|
||||
conf := config.OpenGraph{
|
||||
Enabled: true,
|
||||
TimeToLive: time.Minute,
|
||||
ConsiderHost: false,
|
||||
}
|
||||
|
||||
// Without skip verify we should get a TLS error
|
||||
cacheStrict := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{})
|
||||
if _, err := cacheStrict.GetOGTags(t.Context(), parsedURL, parsedURL.Host); err == nil {
|
||||
t.Fatal("expected TLS verification error without InsecureSkipVerify")
|
||||
}
|
||||
|
||||
cacheSkip := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
|
||||
tags, err := cacheSkip.GetOGTags(t.Context(), parsedURL, parsedURL.Host)
|
||||
if err != nil {
|
||||
t.Fatalf("expected successful fetch with InsecureSkipVerify, got: %v", err)
|
||||
}
|
||||
if tags["og:title"] != "Self-Signed" {
|
||||
t.Fatalf("expected og:title to be %q, got %q", "Self-Signed", tags["og:title"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOGTagsWithTargetSNI(t *testing.T) {
|
||||
originalHost := "hecate.test"
|
||||
conf := config.OpenGraph{
|
||||
Enabled: true,
|
||||
TimeToLive: time.Minute,
|
||||
ConsiderHost: false,
|
||||
}
|
||||
|
||||
t.Run("explicit SNI override", func(t *testing.T) {
|
||||
expectedSNI := "backend.internal"
|
||||
ts, recorder := newSNIServer(t, `<!DOCTYPE html><html><head><meta property="og:title" content="SNI Works" /></head><body>ok</body></html>`)
|
||||
defer ts.Close()
|
||||
|
||||
targetURL, err := url.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse server URL: %v", err)
|
||||
}
|
||||
|
||||
cacheExplicit := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
|
||||
SNI: expectedSNI,
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if _, err := cacheExplicit.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
|
||||
t.Fatalf("expected successful fetch with explicit SNI, got: %v", err)
|
||||
}
|
||||
if got := recorder.last(); got != expectedSNI {
|
||||
t.Fatalf("expected server to see SNI %q, got %q", expectedSNI, got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("auto SNI uses original host", func(t *testing.T) {
|
||||
ts, recorder := newSNIServer(t, `<!DOCTYPE html><html><head><meta property="og:title" content="SNI Auto" /></head><body>ok</body></html>`)
|
||||
defer ts.Close()
|
||||
|
||||
targetURL, err := url.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse server URL: %v", err)
|
||||
}
|
||||
|
||||
cacheAuto := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
|
||||
SNI: "auto",
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if _, err := cacheAuto.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
|
||||
t.Fatalf("expected successful fetch with auto SNI, got: %v", err)
|
||||
}
|
||||
if got := recorder.last(); got != originalHost {
|
||||
t.Fatalf("expected server to see SNI %q with auto, got %q", originalHost, got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("default SNI uses backend host", func(t *testing.T) {
|
||||
ts, recorder := newSNIServer(t, `<!DOCTYPE html><html><head><meta property="og:title" content="SNI Default" /></head><body>ok</body></html>`)
|
||||
defer ts.Close()
|
||||
|
||||
targetURL, err := url.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse server URL: %v", err)
|
||||
}
|
||||
|
||||
cacheDefault := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if _, err := cacheDefault.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
|
||||
t.Fatalf("expected successful fetch without explicit SNI, got: %v", err)
|
||||
}
|
||||
wantSNI := ""
|
||||
if net.ParseIP(targetURL.Hostname()) == nil {
|
||||
wantSNI = targetURL.Hostname()
|
||||
}
|
||||
if got := recorder.last(); got != wantSNI {
|
||||
t.Fatalf("expected default SNI %q, got %q", wantSNI, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func newSNIServer(t *testing.T, body string) (*httptest.Server, *sniRecorder) {
|
||||
t.Helper()
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprint(w, body)
|
||||
})
|
||||
|
||||
recorder := &sniRecorder{}
|
||||
ts := httptest.NewUnstartedServer(handler)
|
||||
cert := mustCertificateForHost(t, "sni.test")
|
||||
ts.TLS = &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
recorder.record(hello.ServerName)
|
||||
return &cert, nil
|
||||
},
|
||||
}
|
||||
ts.StartTLS()
|
||||
return ts, recorder
|
||||
}
|
||||
|
||||
func mustCertificateForHost(t *testing.T, host string) tls.Certificate {
|
||||
t.Helper()
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %v", err)
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
CommonName: host,
|
||||
},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{host},
|
||||
}
|
||||
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create certificate: %v", err)
|
||||
}
|
||||
|
||||
return tls.Certificate{
|
||||
Certificate: [][]byte{der},
|
||||
PrivateKey: priv,
|
||||
}
|
||||
}
|
||||
|
||||
type sniRecorder struct {
|
||||
mu sync.Mutex
|
||||
names []string
|
||||
}
|
||||
|
||||
func (r *sniRecorder) record(name string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.names = append(r.names, name)
|
||||
}
|
||||
|
||||
func (r *sniRecorder) last() string {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if len(r.names) == 0 {
|
||||
return ""
|
||||
}
|
||||
return r.names[len(r.names)-1]
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestExtractOGTags(t *testing.T) {
|
||||
Enabled: false,
|
||||
ConsiderHost: false,
|
||||
TimeToLive: time.Minute,
|
||||
}, memory.New(t.Context()))
|
||||
}, memory.New(t.Context()), TargetOptions{})
|
||||
// Manually set approved tags/prefixes based on the user request for clarity
|
||||
testCache.approvedTags = []string{"description"}
|
||||
testCache.approvedPrefixes = []string{"og:"}
|
||||
@@ -199,7 +199,7 @@ func TestExtractMetaTagInfo(t *testing.T) {
|
||||
Enabled: false,
|
||||
ConsiderHost: false,
|
||||
TimeToLive: time.Minute,
|
||||
}, memory.New(t.Context()))
|
||||
}, memory.New(t.Context()), TargetOptions{})
|
||||
testCache.approvedTags = []string{"description"}
|
||||
testCache.approvedPrefixes = []string{"og:"}
|
||||
|
||||
|
||||
42
internal/ogtags/sni.go
Normal file
42
internal/ogtags/sni.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package ogtags
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// clientForSNI returns a cached client for the given server name, creating one if needed.
|
||||
func (c *OGTagCache) clientForSNI(serverName string) *http.Client {
|
||||
if !c.targetSNIAuto || serverName == "" {
|
||||
return c.client
|
||||
}
|
||||
|
||||
c.transportMu.RLock()
|
||||
cli, ok := c.sniClients[serverName]
|
||||
c.transportMu.RUnlock()
|
||||
if ok {
|
||||
return cli
|
||||
}
|
||||
|
||||
c.transportMu.Lock()
|
||||
defer c.transportMu.Unlock()
|
||||
if cli, ok := c.sniClients[serverName]; ok {
|
||||
return cli
|
||||
}
|
||||
|
||||
tr := c.transport.Clone()
|
||||
if tr.TLSClientConfig == nil {
|
||||
tr.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
tr.TLSClientConfig.ServerName = serverName
|
||||
if c.insecureSkipVerify {
|
||||
tr.TLSClientConfig.InsecureSkipVerify = true
|
||||
}
|
||||
|
||||
cli = &http.Client{
|
||||
Timeout: httpTimeout,
|
||||
Transport: tr,
|
||||
}
|
||||
c.sniClients[serverName] = cli
|
||||
return cli
|
||||
}
|
||||
@@ -27,27 +27,30 @@ import (
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Next http.Handler
|
||||
Policy *policy.ParsedConfig
|
||||
Logger *slog.Logger
|
||||
OpenGraph config.OpenGraph
|
||||
PublicUrl string
|
||||
CookieDomain string
|
||||
JWTRestrictionHeader string
|
||||
BasePrefix string
|
||||
WebmasterEmail string
|
||||
Target string
|
||||
RedirectDomains []string
|
||||
ED25519PrivateKey ed25519.PrivateKey
|
||||
HS512Secret []byte
|
||||
CookieExpiration time.Duration
|
||||
CookieSameSite http.SameSite
|
||||
ServeRobotsTXT bool
|
||||
CookieSecure bool
|
||||
StripBasePrefix bool
|
||||
CookiePartitioned bool
|
||||
CookieDynamicDomain bool
|
||||
DifficultyInJWT bool
|
||||
Next http.Handler
|
||||
Policy *policy.ParsedConfig
|
||||
Target string
|
||||
TargetHost string
|
||||
TargetSNI string
|
||||
TargetInsecureSkipVerify bool
|
||||
CookieDynamicDomain bool
|
||||
CookieDomain string
|
||||
CookieExpiration time.Duration
|
||||
CookiePartitioned bool
|
||||
BasePrefix string
|
||||
WebmasterEmail string
|
||||
RedirectDomains []string
|
||||
ED25519PrivateKey ed25519.PrivateKey
|
||||
HS512Secret []byte
|
||||
StripBasePrefix bool
|
||||
OpenGraph config.OpenGraph
|
||||
ServeRobotsTXT bool
|
||||
CookieSecure bool
|
||||
CookieSameSite http.SameSite
|
||||
Logger *slog.Logger
|
||||
PublicUrl string
|
||||
JWTRestrictionHeader string
|
||||
DifficultyInJWT bool
|
||||
}
|
||||
|
||||
func LoadPoliciesOrDefault(ctx context.Context, fname string, defaultDifficulty int) (*policy.ParsedConfig, error) {
|
||||
@@ -116,9 +119,13 @@ func New(opts Options) (*Server, error) {
|
||||
hs512Secret: opts.HS512Secret,
|
||||
policy: opts.Policy,
|
||||
opts: opts,
|
||||
OGTags: ogtags.NewOGTagCache(opts.Target, opts.Policy.OpenGraph, opts.Policy.Store),
|
||||
store: opts.Policy.Store,
|
||||
logger: opts.Logger,
|
||||
OGTags: ogtags.NewOGTagCache(opts.Target, opts.Policy.OpenGraph, opts.Policy.Store, ogtags.TargetOptions{
|
||||
Host: opts.TargetHost,
|
||||
SNI: opts.TargetSNI,
|
||||
InsecureSkipVerify: opts.TargetInsecureSkipVerify,
|
||||
}),
|
||||
store: opts.Policy.Store,
|
||||
logger: opts.Logger,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
@@ -62,11 +62,14 @@ type BotConfig struct {
|
||||
Expression *ExpressionOrList `json:"expression,omitempty" yaml:"expression,omitempty"`
|
||||
Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"`
|
||||
Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"`
|
||||
GeoIP *GeoIP `json:"geoip,omitempty"`
|
||||
ASNs *ASNs `json:"asns,omitempty"`
|
||||
Name string `json:"name" yaml:"name"`
|
||||
Action Rule `json:"action" yaml:"action"`
|
||||
RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
|
||||
|
||||
// Thoth features
|
||||
GeoIP *GeoIP `json:"geoip,omitempty"`
|
||||
ASNs *ASNs `json:"asns,omitempty"`
|
||||
|
||||
Name string `json:"name" yaml:"name"`
|
||||
Action Rule `json:"action" yaml:"action"`
|
||||
RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
|
||||
}
|
||||
|
||||
func (b BotConfig) Zero() bool {
|
||||
|
||||
Reference in New Issue
Block a user