chore: bump reva

This commit is contained in:
Ralf Haferkamp
2026-05-11 15:47:58 +02:00
committed by Ralf Haferkamp
parent 393926bd73
commit 59bd11d02a
102 changed files with 10788 additions and 4246 deletions
+12 -12
View File
@@ -55,17 +55,17 @@ require (
github.com/mitchellh/mapstructure v1.5.0
github.com/mna/pigeon v1.3.0
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/nats-io/nats-server/v2 v2.12.6
github.com/nats-io/nats-server/v2 v2.14.0
github.com/nats-io/nats.go v1.51.0
github.com/oklog/run v1.2.0
github.com/olekukonko/tablewriter v1.1.4
github.com/onsi/ginkgo v1.16.5
github.com/onsi/ginkgo/v2 v2.28.1
github.com/onsi/gomega v1.39.1
github.com/onsi/gomega v1.40.0
github.com/open-policy-agent/opa v1.15.2
github.com/opencloud-eu/icap-client v0.0.0-20250930132611-28a2afe62d89
github.com/opencloud-eu/libre-graph-api-go v1.0.8-0.20260310090739-853d972b282d
github.com/opencloud-eu/reva/v2 v2.43.1-0.20260424125411-c5db28365753
github.com/opencloud-eu/reva/v2 v2.43.1-0.20260512061040-cd4be86c66b0
github.com/opensearch-project/opensearch-go/v4 v4.6.0
github.com/orcaman/concurrent-map v1.0.0
github.com/pkg/errors v0.9.1
@@ -103,14 +103,14 @@ require (
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0
go.opentelemetry.io/otel/sdk v1.43.0
go.opentelemetry.io/otel/trace v1.43.0
golang.org/x/crypto v0.49.0
golang.org/x/crypto v0.50.0
golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac
golang.org/x/image v0.38.0
golang.org/x/net v0.52.0
golang.org/x/oauth2 v0.36.0
golang.org/x/sync v0.20.0
golang.org/x/term v0.41.0
golang.org/x/text v0.35.0
golang.org/x/term v0.42.0
golang.org/x/text v0.36.0
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9
google.golang.org/grpc v1.80.0
google.golang.org/protobuf v1.36.11
@@ -122,7 +122,7 @@ require (
require (
contrib.go.opencensus.io/exporter/prometheus v0.4.2 // indirect
filippo.io/edwards25519 v1.1.1 // indirect
filippo.io/edwards25519 v1.2.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect
github.com/Azure/go-ntlmssp v0.1.1 // indirect
github.com/BurntSushi/toml v1.6.0 // indirect
@@ -136,7 +136,7 @@ require (
github.com/ajg/form v1.5.1 // indirect
github.com/alexedwards/argon2id v1.0.0 // indirect
github.com/amoghe/go-crypt v0.0.0-20220222110647-20eada5f5964 // indirect
github.com/antithesishq/antithesis-sdk-go v0.6.0-default-no-op // indirect
github.com/antithesishq/antithesis-sdk-go v0.7.0-default-no-op // indirect
github.com/armon/go-radix v1.0.0 // indirect
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect
github.com/beorn7/perks v1.0.1 // indirect
@@ -220,7 +220,7 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/go-sql-driver/mysql v1.10.0 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/go-test/deep v1.1.0 // indirect
@@ -286,7 +286,7 @@ require (
github.com/miekg/dns v1.1.68 // indirect
github.com/mileusna/useragent v1.3.5 // indirect
github.com/minio/crc64nvme v1.1.1 // indirect
github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76 // indirect
github.com/minio/highwayhash v1.0.4 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/minio/minio-go/v7 v7.0.99 // indirect
github.com/mitchellh/copystructure v1.2.0 // indirect
@@ -388,10 +388,10 @@ require (
go.uber.org/zap v1.27.0 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/mod v0.33.0 // indirect
golang.org/x/mod v0.34.0 // indirect
golang.org/x/sys v0.43.0 // indirect
golang.org/x/time v0.15.0 // indirect
golang.org/x/tools v0.42.0 // indirect
golang.org/x/tools v0.43.0 // indirect
google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d // indirect
gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect
+24 -24
View File
@@ -39,8 +39,8 @@ contrib.go.opencensus.io/exporter/prometheus v0.4.2/go.mod h1:dvEHbiKmgvbr5pjaF9
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
github.com/Acconut/go-httptest-recorder v1.0.0 h1:TAv2dfnqp/l+SUvIaMAUK4GeN4+wqb6KZsFFFTGhoJg=
github.com/Acconut/go-httptest-recorder v1.0.0/go.mod h1:CwQyhTH1kq/gLyWiRieo7c0uokpu3PXeyF/nZjUNtmM=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
@@ -117,8 +117,8 @@ github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNg
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/antithesishq/antithesis-sdk-go v0.6.0-default-no-op h1:kpBdlEPbRvff0mDD1gk7o9BhI16b9p5yYAXRlidpqJE=
github.com/antithesishq/antithesis-sdk-go v0.6.0-default-no-op/go.mod h1:IUpT2DPAKh6i/YhSbt6Gl3v2yvUZjmKncl7U91fup7E=
github.com/antithesishq/antithesis-sdk-go v0.7.0-default-no-op h1:Z/MZK75wC/NSrkgqeNIa7jexam9uWzhLmFTSCPI/kn0=
github.com/antithesishq/antithesis-sdk-go v0.7.0-default-no-op/go.mod h1:FQyySiasQQM8735Ddel3MRojmy4dA1IqCeyJ5jmPMbI=
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
@@ -456,8 +456,8 @@ github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq
github.com/go-resty/resty/v2 v2.1.1-0.20191201195748-d7b97669fe48/go.mod h1:dZGr0i9PLlaaTD4H/hoZIDjQ+r6xq8mgbRzHZf7f2J8=
github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk=
github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/go-sql-driver/mysql v1.10.0 h1:Q+1LV8DkHJvSYAdR83XzuhDaTykuDx0l6fkXxoWCWfw=
github.com/go-sql-driver/mysql v1.10.0/go.mod h1:M+cqaI7+xxXGG9swrdeUIoPG3Y3KCkF0pZej+SK+nWk=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
@@ -839,8 +839,8 @@ github.com/mileusna/useragent v1.3.5 h1:SJM5NzBmh/hO+4LGeATKpaEX9+b4vcGg2qXGLiNG
github.com/mileusna/useragent v1.3.5/go.mod h1:3d8TOmwL/5I8pJjyVDteHtgDGcefrFUX4ccGOMKNYYc=
github.com/minio/crc64nvme v1.1.1 h1:8dwx/Pz49suywbO+auHCBpCtlW1OfpcLN7wYgVR6wAI=
github.com/minio/crc64nvme v1.1.1/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg=
github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76 h1:KGuD/pM2JpL9FAYvBrnBBeENKZNh6eNtjqytV6TYjnk=
github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76/go.mod h1:GGYsuwP/fPD6Y9hMiXuapVvlIUEhFhMTh0rxU3ik1LQ=
github.com/minio/highwayhash v1.0.4 h1:asJizugGgchQod2ja9NJlGOWq4s7KsAWr5XUc9Clgl4=
github.com/minio/highwayhash v1.0.4/go.mod h1:GGYsuwP/fPD6Y9hMiXuapVvlIUEhFhMTh0rxU3ik1LQ=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.99 h1:2vH/byrwUkIpFQFOilvTfaUpvAX3fEFhEzO+DR3DlCE=
@@ -900,8 +900,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW
github.com/namedotcom/go v0.0.0-20180403034216-08470befbe04/go.mod h1:5sN+Lt1CaY4wsPvgQH/jsuJi4XO2ssZbdsIizr4CVC8=
github.com/nats-io/jwt/v2 v2.8.1 h1:V0xpGuD/N8Mi+fQNDynXohVvp7ZztevW5io8CUWlPmU=
github.com/nats-io/jwt/v2 v2.8.1/go.mod h1:nWnOEEiVMiKHQpnAy4eXlizVEtSfzacZ1Q43LIRavZg=
github.com/nats-io/nats-server/v2 v2.12.6 h1:Egbx9Vl7Ch8wTtpXPGqbehkZ+IncKqShUxvrt1+Enc8=
github.com/nats-io/nats-server/v2 v2.12.6/go.mod h1:4HPlrvtmSO3yd7KcElDNMx9kv5EBJBnJJzQPptXlheo=
github.com/nats-io/nats-server/v2 v2.14.0 h1:+8q0HrDFotwLLcGH/legOEOnowunhK+aZ4GYBIWpQlM=
github.com/nats-io/nats-server/v2 v2.14.0/go.mod h1:ImVUUDvfClJbb6cuJQRc1VmgDCXKM5ds0OoiG9MVOKo=
github.com/nats-io/nats.go v1.51.0 h1:ByW84XTz6W03GSSsygsZcA+xgKK8vPGaa/FCAAEHnAI=
github.com/nats-io/nats.go v1.51.0/go.mod h1:26HypzazeOkyO3/mqd1zZd53STJN0EjCYF9Uy2ZOBno=
github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4=
@@ -940,8 +940,8 @@ github.com/onsi/ginkgo/v2 v2.28.1/go.mod h1:CLtbVInNckU3/+gC8LzkGUb9oF+e8W8TdUsx
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.39.1 h1:1IJLAad4zjPn2PsnhH70V4DKRFlrCzGBNrNaru+Vf28=
github.com/onsi/gomega v1.39.1/go.mod h1:hL6yVALoTOxeWudERyfppUcZXjMwIMLnuSfruD2lcfg=
github.com/onsi/gomega v1.40.0 h1:Vtol0e1MghCD2ZVIilPDIg44XSL9l2QAn8ZNaljWcJc=
github.com/onsi/gomega v1.40.0/go.mod h1:M/Uqpu/8qTjtzCLUA2zJHX9Iilrau25x1PdoSRbWh5A=
github.com/open-policy-agent/opa v1.15.2 h1:dS9q+0Yvruq/VNvWJc5qCvCchn715OWc3HLHXn/UCCc=
github.com/open-policy-agent/opa v1.15.2/go.mod h1:c6SN+7jSsUcKJLQc5P4yhwx8YYDRbjpAiGkBOTqxaa4=
github.com/opencloud-eu/go-micro-plugins/v4/store/nats-js-kv v0.0.0-20250512152754-23325793059a h1:Sakl76blJAaM6NxylVkgSzktjo2dS504iDotEFJsh3M=
@@ -952,8 +952,8 @@ github.com/opencloud-eu/inotifywaitgo v0.0.0-20251111171128-a390bae3c5e9 h1:dIft
github.com/opencloud-eu/inotifywaitgo v0.0.0-20251111171128-a390bae3c5e9/go.mod h1:JWyDC6H+5oZRdUJUgKuaye+8Ph5hEs6HVzVoPKzWSGI=
github.com/opencloud-eu/libre-graph-api-go v1.0.8-0.20260310090739-853d972b282d h1:JcqGDiyrcaQwVyV861TUyQgO7uEmsjkhfm7aQd84dOw=
github.com/opencloud-eu/libre-graph-api-go v1.0.8-0.20260310090739-853d972b282d/go.mod h1:pzatilMEHZFT3qV7C/X3MqOa3NlRQuYhlRhZTL+hN6Q=
github.com/opencloud-eu/reva/v2 v2.43.1-0.20260424125411-c5db28365753 h1:/FpQdybaNb3OAISHmHRrh/4aWYQep3nVSYKzYt2F+jE=
github.com/opencloud-eu/reva/v2 v2.43.1-0.20260424125411-c5db28365753/go.mod h1:msu4TkFw7Jxog0QRbGPxyQOJG9sago5nc+f//y+bbpI=
github.com/opencloud-eu/reva/v2 v2.43.1-0.20260512061040-cd4be86c66b0 h1:e4w34sW1gXixTKi9z+odF6IKGyvisvu97xfYEXOvRGE=
github.com/opencloud-eu/reva/v2 v2.43.1-0.20260512061040-cd4be86c66b0/go.mod h1:SoRYtNJ9ha83YdUUep5wYF7F5/OIhgED7ZSgqudhpNo=
github.com/opencloud-eu/secure v0.0.0-20260312082735-b6f5cb2244e4 h1:l2oB/RctH+t8r7QBj5p8thfEHCM/jF35aAY3WQ3hADI=
github.com/opencloud-eu/secure v0.0.0-20260312082735-b6f5cb2244e4/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
@@ -1358,8 +1358,8 @@ golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -1397,8 +1397,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -1566,8 +1566,8 @@ golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -1580,8 +1580,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -1642,8 +1642,8 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f
golang.org/x/tools v0.0.0-20210112230658-8b4aab62c064/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/tools/godoc v0.1.0-deprecated h1:o+aZ1BOj6Hsx/GBdJO/s815sqftjSnrZZwyYTHODvtk=
golang.org/x/tools/godoc v0.1.0-deprecated/go.mod h1:qM63CriJ961IHWmnWa9CjZnBndniPt4a3CK0PVB9bIg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+4 -2
View File
@@ -7,8 +7,10 @@ import "filippo.io/edwards25519"
This library implements the edwards25519 elliptic curve, exposing the necessary APIs to build a wide array of higher-level primitives.
Read the docs at [pkg.go.dev/filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519).
The code is originally derived from Adam Langley's internal implementation in the Go standard library, and includes George Tankersley's [performance improvements](https://golang.org/cl/71950). It was then further developed by Henry de Valence for use in ristretto255, and was finally [merged back into the Go standard library](https://golang.org/cl/276272) as of Go 1.17. It now tracks the upstream codebase and extends it with additional functionality.
The package tracks the upstream standard library package `crypto/internal/fips140/edwards25519` and extends it with additional functionality.
Most users don't need this package, and should instead use `crypto/ed25519` for signatures, `golang.org/x/crypto/curve25519` for Diffie-Hellman, or `github.com/gtank/ristretto255` for prime order group logic. However, for anyone currently using a fork of `crypto/internal/edwards25519`/`crypto/ed25519/internal/edwards25519` or `github.com/agl/edwards25519`, this package should be a safer, faster, and more powerful alternative.
The code is originally derived from Adam Langley's internal implementation in the Go standard library, and includes George Tankersley's [performance improvements](https://golang.org/cl/71950). It was then further developed by Henry de Valence for use in ristretto255, and was finally [merged back into the Go standard library](https://golang.org/cl/276272) as of Go 1.17.
Most users don't need this package, and should instead use `crypto/ed25519` for signatures, `crypto/ecdh` for Diffie-Hellman, or `github.com/gtank/ristretto255` for prime order group logic. However, for anyone currently using a fork of the internal `edwards25519` package or of `github.com/agl/edwards25519`, this package should be a safer, faster, and more powerful alternative.
Since this package is meant to curb proliferation of edwards25519 implementations in the Go ecosystem, it welcomes requests for new APIs or reviewable performance improvements.
+3 -3
View File
@@ -10,11 +10,11 @@
// the curve used by the Ed25519 signature scheme.
//
// Most users don't need this package, and should instead use crypto/ed25519 for
// signatures, golang.org/x/crypto/curve25519 for Diffie-Hellman, or
// github.com/gtank/ristretto255 for prime order group logic.
// signatures, crypto/ecdh for Diffie-Hellman, or github.com/gtank/ristretto255
// for prime order group logic.
//
// However, developers who do need to interact with low-level edwards25519
// operations can use this package, which is an extended version of
// crypto/internal/edwards25519 from the standard library repackaged as
// crypto/internal/fips140/edwards25519 from the standard library repackaged as
// an importable module.
package edwards25519
+59 -8
View File
@@ -9,6 +9,7 @@ package edwards25519
import (
"errors"
"slices"
"filippo.io/edwards25519/field"
)
@@ -100,13 +101,15 @@ func (v *Point) bytesMontgomery(buf *[32]byte) []byte {
//
// u = (1 + y) / (1 - y)
//
// where y = Y / Z.
// where y = Y / Z and therefore
//
// u = (Z + Y) / (Z - Y)
var y, recip, u field.Element
var n, r, u field.Element
y.Multiply(&v.y, y.Invert(&v.z)) // y = Y / Z
recip.Invert(recip.Subtract(feOne, &y)) // r = 1/(1 - y)
u.Multiply(u.Add(feOne, &y), &recip) // u = (1 + y)*r
n.Add(&v.z, &v.y) // n = Z + Y
r.Invert(r.Subtract(&v.z, &v.y)) // r = 1 / (Z - Y)
u.Multiply(&n, &r) // u = n * r
return copyFieldElement(buf, &u)
}
@@ -124,7 +127,7 @@ func (v *Point) MultByCofactor(p *Point) *Point {
return v.fromP1xP1(&result)
}
// Given k > 0, set s = s**(2*i).
// Given k > 0, set s = s**(2**k), i.e. square s k times.
func (s *Scalar) pow2k(k int) {
for i := 0; i < k; i++ {
s.Multiply(s, s)
@@ -250,12 +253,14 @@ func (v *Point) MultiScalarMult(scalars []*Scalar, points []*Point) *Point {
// between each point in the multiscalar equation.
// Build lookup tables for each point
tables := make([]projLookupTable, len(points))
tables := make([]projLookupTable, 0, 2) // avoid allocation for small sizes
tables = slices.Grow(tables, len(points))[:len(points)]
for i := range tables {
tables[i].FromP3(points[i])
}
// Compute signed radix-16 digits for each scalar
digits := make([][64]int8, len(scalars))
digits := make([][64]int8, 0, 2) // avoid allocation for small sizes
digits = slices.Grow(digits, len(scalars))[:len(scalars)]
for i := range digits {
digits[i] = scalars[i].signedRadix16()
}
@@ -348,3 +353,49 @@ func (v *Point) VarTimeMultiScalarMult(scalars []*Scalar, points []*Point) *Poin
v.fromP2(tmp2)
return v
}
// Select sets v to a if cond == 1 and to b if cond == 0, and returns v.
//
// The choice is made coordinate by coordinate: each of v's x, y, z and t
// field elements is set through the field element's own Select method with
// the same cond. Both a and b are passed to checkInitialized first.
//
// NOTE(review): behavior for cond values other than 0 and 1 is determined by
// the field element Select implementation, which is not visible here.
func (v *Point) Select(a, b *Point, cond int) *Point {
checkInitialized(a, b)
v.x.Select(&a.x, &b.x, cond)
v.y.Select(&a.y, &b.y, cond)
v.z.Select(&a.z, &b.z, cond)
v.t.Select(&a.t, &b.t, cond)
return v
}
// Double sets v = p + p, and returns v.
//
// p is passed to checkInitialized, converted to a projP2 value, doubled into
// a projP1xP1 value, and the result is converted back into v via fromP1xP1.
func (v *Point) Double(p *Point) *Point {
checkInitialized(p)
// Doubling is performed on the projP2 representation of p.
pp := new(projP2).FromP3(p)
p1 := new(projP1xP1).Double(pp)
return v.fromP1xP1(p1)
}
// addCached sets v to the result of projP1xP1.Add applied to p and qCached
// (presumably the sum p + q, where qCached is the cached form of q — confirm
// against projP1xP1.Add), and returns v.
//
// Unlike the exported methods, it does not call checkInitialized on its
// arguments.
func (v *Point) addCached(p *Point, qCached *projCached) *Point {
result := new(projP1xP1).Add(p, qCached)
return v.fromP1xP1(result)
}
// ScalarMultSlow sets v = x * q, and returns v. It doesn't precompute a large
// table, so it is considerably slower, but requires less memory.
//
// The scalar multiplication is done in constant time.
//
// The scalar is processed one bit at a time, from bit 255 down to bit 0. Every
// iteration performs one doubling, one addition of q and one Select, whatever
// the value of the bit, so the sequence of operations does not depend on x.
func (v *Point) ScalarMultSlow(x *Scalar, q *Point) *Point {
checkInitialized(q)
// Bit i of the scalar is read below as (s[i/8] >> (i%8)) & 1, i.e. s is
// indexed as a little-endian byte string.
s := x.Bytes()
// The cached form of q is built before v is overwritten.
// NOTE(review): presumably this makes v == q aliasing safe — confirm that
// projCached.FromP3 copies rather than references q.
qCached := new(projCached).FromP3(q)
// Start the accumulator from the point returned by NewIdentityPoint.
v.Set(NewIdentityPoint())
// t is scratch space for the "accumulator plus q" candidate.
t := new(Point)
for i := 255; i >= 0; i-- {
v.Double(v)
// Always compute the candidate t from v and qCached, even when the
// current bit is zero.
t.addCached(v, qCached)
cond := (s[i/8] >> (i % 8)) & 1
// Keep t when the bit is 1, otherwise keep v unchanged.
v.Select(t, v, int(cond))
}
return v
}
+17 -17
View File
@@ -90,11 +90,7 @@ func (v *Element) Add(a, b *Element) *Element {
v.l2 = a.l2 + b.l2
v.l3 = a.l3 + b.l3
v.l4 = a.l4 + b.l4
// Using the generic implementation here is actually faster than the
// assembly. Probably because the body of this function is so simple that
// the compiler can figure out better optimizations by inlining the carry
// propagation.
return v.carryPropagateGeneric()
return v.carryPropagate()
}
// Subtract sets v = a - b, and returns v.
@@ -232,18 +228,22 @@ func (v *Element) bytes(out *[32]byte) []byte {
t := *v
t.reduce()
var buf [8]byte
for i, l := range [5]uint64{t.l0, t.l1, t.l2, t.l3, t.l4} {
bitsOffset := i * 51
binary.LittleEndian.PutUint64(buf[:], l<<uint(bitsOffset%8))
for i, bb := range buf {
off := bitsOffset/8 + i
if off >= len(out) {
break
}
out[off] |= bb
}
}
// Pack five 51-bit limbs into four 64-bit words:
//
//   255    204    153    102    51      0
//    ├──l4──┼──l3──┼──l2──┼──l1──┼──l0──┤
//   ├───u3───┼───u2───┼───u1───┼───u0───┤
//  256      192      128      64        0
u0 := t.l1<<51 | t.l0
u1 := t.l2<<(102-64) | t.l1>>(64-51)
u2 := t.l3<<(153-128) | t.l2>>(128-102)
u3 := t.l4<<(204-192) | t.l3>>(192-153)
binary.LittleEndian.PutUint64(out[0*8:], u0)
binary.LittleEndian.PutUint64(out[1*8:], u1)
binary.LittleEndian.PutUint64(out[2*8:], u2)
binary.LittleEndian.PutUint64(out[3*8:], u3)
return out[:]
}
+1 -2
View File
@@ -1,7 +1,6 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
//go:build !purego
package field
+111 -92
View File
@@ -1,7 +1,6 @@
// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT.
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
//go:build !purego
#include "textflag.h"
@@ -17,32 +16,36 @@ TEXT ·feMul(SB), NOSPLIT, $0-24
MOVQ DX, SI
// r0 += 19×a1×b4
MOVQ 8(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, DI
ADCQ DX, SI
MOVQ 8(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
MULQ 32(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a2×b3
MOVQ 16(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, DI
ADCQ DX, SI
MOVQ 16(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
MULQ 24(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a3×b2
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 16(BX)
ADDQ AX, DI
ADCQ DX, SI
MOVQ 24(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
MULQ 16(BX)
ADDQ AX, DI
ADCQ DX, SI
// r0 += 19×a4×b1
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 8(BX)
ADDQ AX, DI
ADCQ DX, SI
MOVQ 32(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
MULQ 8(BX)
ADDQ AX, DI
ADCQ DX, SI
// r1 = a0×b1
MOVQ (CX), AX
@@ -57,25 +60,28 @@ TEXT ·feMul(SB), NOSPLIT, $0-24
ADCQ DX, R8
// r1 += 19×a2×b4
MOVQ 16(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R9
ADCQ DX, R8
MOVQ 16(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
MULQ 32(BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a3×b3
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, R9
ADCQ DX, R8
MOVQ 24(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
MULQ 24(BX)
ADDQ AX, R9
ADCQ DX, R8
// r1 += 19×a4×b2
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 16(BX)
ADDQ AX, R9
ADCQ DX, R8
MOVQ 32(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
MULQ 16(BX)
ADDQ AX, R9
ADCQ DX, R8
// r2 = a0×b2
MOVQ (CX), AX
@@ -96,18 +102,20 @@ TEXT ·feMul(SB), NOSPLIT, $0-24
ADCQ DX, R10
// r2 += 19×a3×b4
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R11
ADCQ DX, R10
MOVQ 24(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
MULQ 32(BX)
ADDQ AX, R11
ADCQ DX, R10
// r2 += 19×a4×b3
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(BX)
ADDQ AX, R11
ADCQ DX, R10
MOVQ 32(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
MULQ 24(BX)
ADDQ AX, R11
ADCQ DX, R10
// r3 = a0×b3
MOVQ (CX), AX
@@ -134,11 +142,12 @@ TEXT ·feMul(SB), NOSPLIT, $0-24
ADCQ DX, R12
// r3 += 19×a4×b4
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(BX)
ADDQ AX, R13
ADCQ DX, R12
MOVQ 32(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
MULQ 32(BX)
ADDQ AX, R13
ADCQ DX, R12
// r4 = a0×b4
MOVQ (CX), AX
@@ -232,18 +241,22 @@ TEXT ·feSquare(SB), NOSPLIT, $0-16
MOVQ DX, BX
// r0 += 38×l1×l4
MOVQ 8(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, SI
ADCQ DX, BX
MOVQ 8(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
SHLQ $0x01, AX
MULQ 32(CX)
ADDQ AX, SI
ADCQ DX, BX
// r0 += 38×l2×l3
MOVQ 16(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 24(CX)
ADDQ AX, SI
ADCQ DX, BX
MOVQ 16(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
SHLQ $0x01, AX
MULQ 24(CX)
ADDQ AX, SI
ADCQ DX, BX
// r1 = 2×l0×l1
MOVQ (CX), AX
@@ -253,18 +266,21 @@ TEXT ·feSquare(SB), NOSPLIT, $0-16
MOVQ DX, DI
// r1 += 38×l2×l4
MOVQ 16(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, R8
ADCQ DX, DI
MOVQ 16(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
SHLQ $0x01, AX
MULQ 32(CX)
ADDQ AX, R8
ADCQ DX, DI
// r1 += 19×l3×l3
MOVQ 24(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 24(CX)
ADDQ AX, R8
ADCQ DX, DI
MOVQ 24(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
MULQ 24(CX)
ADDQ AX, R8
ADCQ DX, DI
// r2 = 2×l0×l2
MOVQ (CX), AX
@@ -280,11 +296,13 @@ TEXT ·feSquare(SB), NOSPLIT, $0-16
ADCQ DX, R9
// r2 += 38×l3×l4
MOVQ 24(CX), AX
IMUL3Q $0x26, AX, AX
MULQ 32(CX)
ADDQ AX, R10
ADCQ DX, R9
MOVQ 24(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
SHLQ $0x01, AX
MULQ 32(CX)
ADDQ AX, R10
ADCQ DX, R9
// r3 = 2×l0×l3
MOVQ (CX), AX
@@ -294,18 +312,19 @@ TEXT ·feSquare(SB), NOSPLIT, $0-16
MOVQ DX, R11
// r3 += 2×l1×l2
MOVQ 8(CX), AX
IMUL3Q $0x02, AX, AX
MULQ 16(CX)
ADDQ AX, R12
ADCQ DX, R11
MOVQ 8(CX), AX
SHLQ $0x01, AX
MULQ 16(CX)
ADDQ AX, R12
ADCQ DX, R11
// r3 += 19×l4×l4
MOVQ 32(CX), AX
IMUL3Q $0x13, AX, AX
MULQ 32(CX)
ADDQ AX, R12
ADCQ DX, R11
MOVQ 32(CX), DX
LEAQ (DX)(DX*8), AX
LEAQ (DX)(AX*2), AX
MULQ 32(CX)
ADDQ AX, R12
ADCQ DX, R11
// r4 = 2×l0×l4
MOVQ (CX), AX
@@ -315,11 +334,11 @@ TEXT ·feSquare(SB), NOSPLIT, $0-16
MOVQ DX, R13
// r4 += 2×l1×l3
MOVQ 8(CX), AX
IMUL3Q $0x02, AX, AX
MULQ 24(CX)
ADDQ AX, R14
ADCQ DX, R13
MOVQ 8(CX), AX
SHLQ $0x01, AX
MULQ 24(CX)
ADDQ AX, R14
ADCQ DX, R13
// r4 += l2×l2
MOVQ 16(CX), AX
+1 -2
View File
@@ -2,8 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !amd64 || !gc || purego
// +build !amd64 !gc purego
//go:build !amd64 || purego
package field
-16
View File
@@ -1,16 +0,0 @@
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
// +build arm64,gc,!purego
package field
//go:noescape
func carryPropagate(v *Element)
func (v *Element) carryPropagate() *Element {
carryPropagate(v)
return v
}
-42
View File
@@ -1,42 +0,0 @@
// Copyright (c) 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build arm64 && gc && !purego
#include "textflag.h"
// carryPropagate works exactly like carryPropagateGeneric and uses the
// same AND, ADD, and LSR+MADD instructions emitted by the compiler, but
// avoids loading R0-R4 twice and uses LDP and STP.
//
// See https://golang.org/issues/43145 for the main compiler issue.
//
// func carryPropagate(v *Element)
TEXT ·carryPropagate(SB),NOFRAME|NOSPLIT,$0-8
MOVD v+0(FP), R20
LDP 0(R20), (R0, R1)
LDP 16(R20), (R2, R3)
MOVD 32(R20), R4
AND $0x7ffffffffffff, R0, R10
AND $0x7ffffffffffff, R1, R11
AND $0x7ffffffffffff, R2, R12
AND $0x7ffffffffffff, R3, R13
AND $0x7ffffffffffff, R4, R14
ADD R0>>51, R11, R11
ADD R1>>51, R12, R12
ADD R2>>51, R13, R13
ADD R3>>51, R14, R14
// R4>>51 * 19 + R10 -> R10
LSR $51, R4, R21
MOVD $19, R22
MADD R22, R10, R21, R10
STP (R10, R11), 0(R20)
STP (R12, R13), 16(R20)
MOVD R14, 32(R20)
RET
-12
View File
@@ -1,12 +0,0 @@
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !arm64 || !gc || purego
// +build !arm64 !gc purego
package field
func (v *Element) carryPropagate() *Element {
return v.carryPropagateGeneric()
}
+88 -82
View File
@@ -12,20 +12,42 @@ type uint128 struct {
lo, hi uint64
}
// mul64 returns a * b.
func mul64(a, b uint64) uint128 {
// mul returns a * b.
func mul(a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
return uint128{lo, hi}
}
// addMul64 returns v + a * b.
func addMul64(v uint128, a, b uint64) uint128 {
// addMul returns v + a * b.
func addMul(v uint128, a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
lo, c := bits.Add64(lo, v.lo, 0)
hi, _ = bits.Add64(hi, v.hi, c)
return uint128{lo, hi}
}
// mul19 multiplies v by 19, wrapping modulo 2⁶⁴ on overflow.
//
// The product is assembled from shifts and adds, 19·v = v + 2·(v + 8·v),
// instead of being written as v*19, because this form seems to yield better
// optimizations.
func mul19(v uint64) uint64 {
	times9 := v + v<<3
	return v + times9<<1
}
// addMul19 returns v + 19 * a * b, where a and b are at most 52 bits.
//
// a is scaled by 19 first (with a at most 52 bits, 19·a still fits in 64
// bits), then the full 128-bit product with b is added to v. Any carry out of
// the high word is discarded.
func addMul19(v uint128, a, b uint64) uint128 {
	prodHi, prodLo := bits.Mul64(mul19(a), b)
	sumLo, carry := bits.Add64(prodLo, v.lo, 0)
	sumHi, _ := bits.Add64(prodHi, v.hi, carry)
	return uint128{lo: sumLo, hi: sumHi}
}
// addMul38 returns v + 38 * a * b, where a and b are at most 52 bits.
//
// The factor 38 is split across the operands: a is scaled by 19 and b is
// doubled, and the full 128-bit product of the two is added to v. Any carry
// out of the high word is discarded.
func addMul38(v uint128, a, b uint64) uint128 {
	prodHi, prodLo := bits.Mul64(mul19(a), b*2)
	sumLo, carry := bits.Add64(prodLo, v.lo, 0)
	sumHi, _ := bits.Add64(prodHi, v.hi, carry)
	return uint128{lo: sumLo, hi: sumHi}
}
// shiftRightBy51 returns a >> 51. a is assumed to be at most 115 bits.
func shiftRightBy51(a uint128) uint64 {
return (a.hi << (64 - 51)) | (a.lo >> 51)
@@ -76,45 +98,40 @@ func feMulGeneric(v, a, b *Element) {
//
// Finally we add up the columns into wide, overlapping limbs.
a1_19 := a1 * 19
a2_19 := a2 * 19
a3_19 := a3 * 19
a4_19 := a4 * 19
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
r0 := mul64(a0, b0)
r0 = addMul64(r0, a1_19, b4)
r0 = addMul64(r0, a2_19, b3)
r0 = addMul64(r0, a3_19, b2)
r0 = addMul64(r0, a4_19, b1)
r0 := mul(a0, b0)
r0 = addMul19(r0, a1, b4)
r0 = addMul19(r0, a2, b3)
r0 = addMul19(r0, a3, b2)
r0 = addMul19(r0, a4, b1)
// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
r1 := mul64(a0, b1)
r1 = addMul64(r1, a1, b0)
r1 = addMul64(r1, a2_19, b4)
r1 = addMul64(r1, a3_19, b3)
r1 = addMul64(r1, a4_19, b2)
r1 := mul(a0, b1)
r1 = addMul(r1, a1, b0)
r1 = addMul19(r1, a2, b4)
r1 = addMul19(r1, a3, b3)
r1 = addMul19(r1, a4, b2)
// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
r2 := mul64(a0, b2)
r2 = addMul64(r2, a1, b1)
r2 = addMul64(r2, a2, b0)
r2 = addMul64(r2, a3_19, b4)
r2 = addMul64(r2, a4_19, b3)
r2 := mul(a0, b2)
r2 = addMul(r2, a1, b1)
r2 = addMul(r2, a2, b0)
r2 = addMul19(r2, a3, b4)
r2 = addMul19(r2, a4, b3)
// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
r3 := mul64(a0, b3)
r3 = addMul64(r3, a1, b2)
r3 = addMul64(r3, a2, b1)
r3 = addMul64(r3, a3, b0)
r3 = addMul64(r3, a4_19, b4)
r3 := mul(a0, b3)
r3 = addMul(r3, a1, b2)
r3 = addMul(r3, a2, b1)
r3 = addMul(r3, a3, b0)
r3 = addMul19(r3, a4, b4)
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
r4 := mul64(a0, b4)
r4 = addMul64(r4, a1, b3)
r4 = addMul64(r4, a2, b2)
r4 = addMul64(r4, a3, b1)
r4 = addMul64(r4, a4, b0)
r4 := mul(a0, b4)
r4 = addMul(r4, a1, b3)
r4 = addMul(r4, a2, b2)
r4 = addMul(r4, a3, b1)
r4 = addMul(r4, a4, b0)
// After the multiplication, we need to reduce (carry) the five coefficients
// to obtain a result with limbs that are at most slightly larger than 2⁵¹,
@@ -149,7 +166,7 @@ func feMulGeneric(v, a, b *Element) {
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr0 := r0.lo&maskLow51Bits + mul19(c4)
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
@@ -158,8 +175,12 @@ func feMulGeneric(v, a, b *Element) {
// Now all coefficients fit into 64-bit registers but are still too large to
// be passed around as an Element. We therefore do one last carry chain,
// where the carries will be small enough to fit in the wiggle room above 2⁵¹.
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
v.l0 = rr0&maskLow51Bits + mul19(rr4>>51)
v.l1 = rr1&maskLow51Bits + rr0>>51
v.l2 = rr2&maskLow51Bits + rr1>>51
v.l3 = rr3&maskLow51Bits + rr2>>51
v.l4 = rr4&maskLow51Bits + rr3>>51
}
func feSquareGeneric(v, a *Element) {
@@ -190,44 +211,31 @@ func feSquareGeneric(v, a *Element) {
// l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with
// only three Mul64 and four Add64, instead of five and eight.
l0_2 := l0 * 2
l1_2 := l1 * 2
l1_38 := l1 * 38
l2_38 := l2 * 38
l3_38 := l3 * 38
l3_19 := l3 * 19
l4_19 := l4 * 19
// r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3)
r0 := mul64(l0, l0)
r0 = addMul64(r0, l1_38, l4)
r0 = addMul64(r0, l2_38, l3)
r0 := mul(l0, l0)
r0 = addMul38(r0, l1, l4)
r0 = addMul38(r0, l2, l3)
// r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
r1 := mul64(l0_2, l1)
r1 = addMul64(r1, l2_38, l4)
r1 = addMul64(r1, l3_19, l3)
r1 := mul(l0*2, l1)
r1 = addMul38(r1, l2, l4)
r1 = addMul19(r1, l3, l3)
// r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4
r2 := mul64(l0_2, l2)
r2 = addMul64(r2, l1, l1)
r2 = addMul64(r2, l3_38, l4)
r2 := mul(l0*2, l2)
r2 = addMul(r2, l1, l1)
r2 = addMul38(r2, l3, l4)
// r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
r3 := mul64(l0_2, l3)
r3 = addMul64(r3, l1_2, l2)
r3 = addMul64(r3, l4_19, l4)
r3 := mul(l0*2, l3)
r3 = addMul(r3, l1*2, l2)
r3 = addMul19(r3, l4, l4)
// r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2
r4 := mul64(l0_2, l4)
r4 = addMul64(r4, l1_2, l3)
r4 = addMul64(r4, l2, l2)
r4 := mul(l0*2, l4)
r4 = addMul(r4, l1*2, l3)
r4 = addMul(r4, l2, l2)
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
@@ -235,32 +243,30 @@ func feSquareGeneric(v, a *Element) {
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr0 := r0.lo&maskLow51Bits + mul19(c4)
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
v.l0 = rr0&maskLow51Bits + mul19(rr4>>51)
v.l1 = rr1&maskLow51Bits + rr0>>51
v.l2 = rr2&maskLow51Bits + rr1>>51
v.l3 = rr3&maskLow51Bits + rr2>>51
v.l4 = rr4&maskLow51Bits + rr3>>51
}
// carryPropagateGeneric brings the limbs below 52 bits by applying the reduction
// carryPropagate brings the limbs below 52 bits by applying the reduction
// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry.
func (v *Element) carryPropagateGeneric() *Element {
c0 := v.l0 >> 51
c1 := v.l1 >> 51
c2 := v.l2 >> 51
c3 := v.l3 >> 51
c4 := v.l4 >> 51
// c4 is at most 64 - 51 = 13 bits, so c4*19 is at most 18 bits, and
func (v *Element) carryPropagate() *Element {
// (l4>>51) is at most 64 - 51 = 13 bits, so (l4>>51)*19 is at most 18 bits, and
// the final l0 will be at most 52 bits. Similarly for the rest.
v.l0 = v.l0&maskLow51Bits + c4*19
v.l1 = v.l1&maskLow51Bits + c0
v.l2 = v.l2&maskLow51Bits + c1
v.l3 = v.l3&maskLow51Bits + c2
v.l4 = v.l4&maskLow51Bits + c3
l0 := v.l0
v.l0 = v.l0&maskLow51Bits + mul19(v.l4>>51)
v.l4 = v.l4&maskLow51Bits + v.l3>>51
v.l3 = v.l3&maskLow51Bits + v.l2>>51
v.l2 = v.l2&maskLow51Bits + v.l1>>51
v.l1 = v.l1&maskLow51Bits + l0>>51
return v
}
+53
View File
@@ -0,0 +1,53 @@
#!/usr/bin/env bash
#
# Imports the upstream Go history of the edwards25519 packages at a given tag.
#
# The script clones https://go.googlesource.com/go.git at <tag> into a scratch
# directory, rewrites that clone with git-filter-repo so only the
# edwards25519 paths listed below remain, fetches the result into the current
# repository, and points the branch upstream/<tag> at it.
#
# Usage:    <this script> <tag>
# Requires: git and git-filter-repo on PATH.
set -euo pipefail

if [ "$#" -ne 1 ]; then
    # Diagnostics belong on stderr.
    echo "Usage: $0 <tag>" >&2
    exit 1
fi

# Fail early with an explanation instead of letting `set -e` exit silently
# when a required tool is missing.
for tool in git git-filter-repo; do
    if ! command -v "$tool" >/dev/null; then
        echo "error: required tool '$tool' not found in PATH" >&2
        exit 1
    fi
done

TAG="$1"

# Deliberately not called TMPDIR: that name is the standard environment
# variable naming the directory for temporary files, and reassigning it (when
# it is already exported) would change it for every child process below.
WORKDIR="$(mktemp -d)"
cleanup() {
    rm -rf "$WORKDIR"
}
trap cleanup EXIT

# Borrow objects from a local Go checkout, if one exists, to speed up the clone.
if [ -d "$HOME/go/.git" ]; then
    REFERENCE=(--reference "$HOME/go" --dissociate)
else
    REFERENCE=()
fi

# ${REFERENCE[@]+...} expands to nothing when the array is empty. A plain
# "${REFERENCE[@]}" is reported as an unbound variable under `set -u` by bash
# releases older than 4.4.
git -c advice.detachedHead=false clone --no-checkout \
    ${REFERENCE[@]+"${REFERENCE[@]}"} \
    -b "$TAG" https://go.googlesource.com/go.git "$WORKDIR"

# Simplify the history graph by removing the dev.boringcrypto branches, whose
# merges end up empty after grafting anyway. This also fixes a weird quirk
# (maybe a git-filter-repo bug?) where only one file from an old path,
# src/crypto/ed25519/internal/edwards25519/const.go, would still exist in the
# filtered repo.
git -C "$WORKDIR" replace --graft f771edd7f9 99f1bf54eb
git -C "$WORKDIR" replace --graft 109c13b64f c2f96e686f
git -C "$WORKDIR" replace --graft aa4da4f189 912f075047

# Keep only the edwards25519 paths; the path list is read from the heredoc.
git -C "$WORKDIR" filter-repo --force \
    --paths-from-file /dev/stdin \
    --prune-empty always \
    --prune-degenerate always \
    --tag-callback 'tag.skip()' <<'EOF'
src/crypto/internal/fips140/edwards25519
src/crypto/internal/edwards25519
src/crypto/ed25519/internal/edwards25519
EOF

# Bring the filtered history into the current repository as upstream/<tag>.
git fetch "$WORKDIR"
git update-ref "refs/heads/upstream/$TAG" FETCH_HEAD

echo
echo "Fetched upstream history up to $TAG. Merge with:"
echo -e "\tgit merge --no-ff --no-commit --allow-unrelated-histories upstream/$TAG"
+18 -9
View File
@@ -7,6 +7,7 @@ package edwards25519
import (
"encoding/binary"
"errors"
"math/bits"
)
// A Scalar is an integer modulo
@@ -179,15 +180,23 @@ func isReduced(s []byte) bool {
return false
}
for i := len(s) - 1; i >= 0; i-- {
switch {
case s[i] > scalarMinusOneBytes[i]:
return false
case s[i] < scalarMinusOneBytes[i]:
return true
}
}
return true
s0 := binary.LittleEndian.Uint64(s[:8])
s1 := binary.LittleEndian.Uint64(s[8:16])
s2 := binary.LittleEndian.Uint64(s[16:24])
s3 := binary.LittleEndian.Uint64(s[24:])
l0 := binary.LittleEndian.Uint64(scalarMinusOneBytes[:8])
l1 := binary.LittleEndian.Uint64(scalarMinusOneBytes[8:16])
l2 := binary.LittleEndian.Uint64(scalarMinusOneBytes[16:24])
l3 := binary.LittleEndian.Uint64(scalarMinusOneBytes[24:])
// Do a constant time subtraction chain scalarMinusOneBytes - s. If there is
// a borrow at the end, then s > scalarMinusOneBytes.
_, b := bits.Sub64(l0, s0, 0)
_, b = bits.Sub64(l1, s1, b)
_, b = bits.Sub64(l2, s2, b)
_, b = bits.Sub64(l3, s3, b)
return b == 0
}
// SetBytesWithClamping applies the buffer pruning described in RFC 8032,
+1 -3
View File
@@ -4,9 +4,7 @@
package edwards25519
import (
"crypto/subtle"
)
import "crypto/subtle"
// A dynamic lookup table for variable-base, constant-time scalar muls.
type projLookupTable struct {
+2 -2
View File
@@ -14,8 +14,8 @@
//
// [Antithesis Go SDK]: https://antithesis.com/docs/using_antithesis/sdk/go/
// [Antithesis platform]: https://antithesis.com
// [test properties]: https://antithesis.com/docs/using_antithesis/properties/
// [workload]: https://antithesis.com/docs/getting_started/first_test/
// [test properties]: https://antithesis.com/docs/properties_assertions/properties/
// [workload]: https://antithesis.com/docs/test_templates/first_test/
// [antithesis-go-generator]: https://antithesis.com/docs/using_antithesis/sdk/go/instrumentor/
// [triage report]: https://antithesis.com/docs/reports/
// [here]: https://antithesis.com/docs/using_antithesis/sdk/fallback/
+1 -1
View File
@@ -3,7 +3,7 @@ package internal
// --------------------------------------------------------------------------------
// Versions
// --------------------------------------------------------------------------------
const SDK_Version = "0.6.0"
const SDK_Version = "0.7.0"
const Protocol_Version = "1.1.0"
// --------------------------------------------------------------------------------
+7 -3
View File
@@ -18,8 +18,8 @@ Alex Snast <alexsn at fb.com>
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
Andrew Reid <andrew.reid at tixtrack.com>
Animesh Ray <mail.rayanimesh at gmail.com>
Arne Hormann <arnehormann at gmail.com>
Ariel Mashraki <ariel at mashraki.co.il>
Arne Hormann <arnehormann at gmail.com>
Artur Melanchyk <artur.melanchyk@gmail.com>
Asta Xie <xiemengjun at gmail.com>
B Lamarche <blam413 at gmail.com>
@@ -38,6 +38,7 @@ Daniel Montoya <dsmontoyam at gmail.com>
Daniel Nichter <nil at codenode.com>
Daniël van Eeden <git at myname.nl>
Dave Protasowski <dprotaso at gmail.com>
Demouth <yuya at demouth.net>
Diego Dupin <diego.dupin at gmail.com>
Dirkjan Bussink <d.bussink at gmail.com>
DisposaBoy <disposaboy at dby.me>
@@ -66,6 +67,7 @@ Jeff Hodges <jeff at somethingsimilar.com>
Jeffrey Charles <jeffreycharles at gmail.com>
Jennifer Purevsuren <jennifer at dolthub.com>
Jerome Meyer <jxmeyer at gmail.com>
Jiabin Zhang <jiabin.z at qq.com>
Jiajia Zhong <zhong2plus at gmail.com>
Jian Zhen <zhenjl at gmail.com>
Joe Mann <contact at joemann.co.uk>
@@ -85,10 +87,12 @@ Linh Tran Tuan <linhduonggnu at gmail.com>
Lion Yang <lion at aosc.xyz>
Luca Looz <luca.looz92 at gmail.com>
Lucas Liu <extrafliu at gmail.com>
Lunny Xiao <xiaolunwen at gmail.com>
Luke Scott <luke at webconnex.com>
Lunny Xiao <xiaolunwen at gmail.com>
Maciej Zimnoch <maciej.zimnoch at codilime.com>
Michael Woolnough <michael.woolnough at gmail.com>
Minh Quang <minhquang4334 at gmail.com>
Morgan Tocker <tocker at gmail.com>
Nao Yokotsuka <yokotukanao at gmail.com>
Nathanial Murphy <nathanial.murphy at gmail.com>
Nicola Peduzzi <thenikso at gmail.com>
@@ -99,7 +103,6 @@ Paul Bonser <misterpib at gmail.com>
Paulius Lozys <pauliuslozys at gmail.com>
Peter Schultz <peter.schultz at classmarkets.com>
Phil Porada <philporada at gmail.com>
Minh Quang <minhquang4334 at gmail.com>
Rebecca Chin <rchin at pivotal.io>
Reed Allman <rdallman10 at gmail.com>
Richard Wilkes <wilkes at me.com>
@@ -134,6 +137,7 @@ Ziheng Lyu <zihenglv at gmail.com>
# Organizations
Barracuda Networks, Inc.
Block, Inc.
Counting Ltd.
Defined Networking Inc.
DigitalOcean Inc.
+16 -3
View File
@@ -1,13 +1,26 @@
# Changelog
## v1.10.0 (2026-04-28)
* Fix `getSystemVar("max_allowed_packet")` potentially returning a wrong value. (#1754)
This only affects connections configured with `maxAllowedPacket=0`.
* Bump filippo.io/edwards25519 from 1.1.1 to 1.2.0. (#1756)
While older versions have reported CVEs, they do not affect go-mysql.
* Update Go versions to 1.24-1.26. (#1763)
* Enhance interpolateParams to correctly handle placeholders. (#1732)
The question mark (?) within strings and comments will no longer be treated as a placeholder.
## v1.9.3 (2025-06-13)
* `tx.Commit()` and `tx.Rollback()` returned `ErrInvalidConn` always.
Now they return cached real error if present. (#1690)
* Optimize reading small resultsets to fix performance regression
introduced by compression protocol support. (#1707)
* Optimize reading small result sets to fix a performance regression
introduced by compression protocol support. (`#1707`)
* Fix `db.Ping()` on compressed connection. (#1723)
+5 -2
View File
@@ -1,5 +1,8 @@
# Go-MySQL-Driver
[![DeepWiki](https://img.shields.io/badge/DeepWiki-go--sql--driver%2Fmysql-blue.svg?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACwAAAAyCAYAAAAnWDnqAAAAAXNSR0IArs4c6QAAA05JREFUaEPtmUtyEzEQhtWTQyQLHNak2AB7ZnyXZMEjXMGeK/AIi+QuHrMnbChYY7MIh8g01fJoopFb0uhhEqqcbWTp06/uv1saEDv4O3n3dV60RfP947Mm9/SQc0ICFQgzfc4CYZoTPAswgSJCCUJUnAAoRHOAUOcATwbmVLWdGoH//PB8mnKqScAhsD0kYP3j/Yt5LPQe2KvcXmGvRHcDnpxfL2zOYJ1mFwrryWTz0advv1Ut4CJgf5uhDuDj5eUcAUoahrdY/56ebRWeraTjMt/00Sh3UDtjgHtQNHwcRGOC98BJEAEymycmYcWwOprTgcB6VZ5JK5TAJ+fXGLBm3FDAmn6oPPjR4rKCAoJCal2eAiQp2x0vxTPB3ALO2CRkwmDy5WohzBDwSEFKRwPbknEggCPB/imwrycgxX2NzoMCHhPkDwqYMr9tRcP5qNrMZHkVnOjRMWwLCcr8ohBVb1OMjxLwGCvjTikrsBOiA6fNyCrm8V1rP93iVPpwaE+gO0SsWmPiXB+jikdf6SizrT5qKasx5j8ABbHpFTx+vFXp9EnYQmLx02h1QTTrl6eDqxLnGjporxl3NL3agEvXdT0WmEost648sQOYAeJS9Q7bfUVoMGnjo4AZdUMQku50McDcMWcBPvr0SzbTAFDfvJqwLzgxwATnCgnp4wDl6Aa+Ax283gghmj+vj7feE2KBBRMW3FzOpLOADl0Isb5587h/U4gGvkt5v60Z1VLG8BhYjbzRwyQZemwAd6cCR5/XFWLYZRIMpX39AR0tjaGGiGzLVyhse5C9RKC6ai42ppWPKiBagOvaYk8lO7DajerabOZP46Lby5wKjw1HCRx7p9sVMOWGzb/vA1hwiWc6jm3MvQDTogQkiqIhJV0nBQBTU+3okKCFDy9WwferkHjtxib7t3xIUQtHxnIwtx4mpg26/HfwVNVDb4oI9RHmx5WGelRVlrtiw43zboCLaxv46AZeB3IlTkwouebTr1y2NjSpHz68WNFjHvupy3q8TFn3Hos2IAk4Ju5dCo8B3wP7VPr/FGaKiG+T+v+TQqIrOqMTL1VdWV1DdmcbO8KXBz6esmYWYKPwDL5b5FA1a0hwapHiom0r/cKaoqr+27/XcrS5UwSMbQAAAABJRU5ErkJggg==)](https://deepwiki.com/go-sql-driver/mysql)
A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) package
![Go-MySQL-Driver logo](https://raw.github.com/wiki/go-sql-driver/mysql/gomysql_m.png "Golang Gopher holding the MySQL Dolphin")
@@ -42,8 +45,8 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac
## Requirements
* Go 1.21 or higher. We aim to support the 3 latest versions of Go.
* MySQL (5.7+) and MariaDB (10.5+) are supported.
* Go 1.24 or higher. We aim to support the 3 latest versions of Go.
* MySQL (5.7+) and MariaDB (10.5+) are supported by maintainers.
* [TiDB](https://github.com/pingcap/tidb) is supported by PingCAP.
* Do not ask questions about TiDB in our issue tracker or forum.
* [Document](https://docs.pingcap.com/tidb/v6.1/dev-guide-sample-application-golang)
+1 -1
View File
@@ -305,7 +305,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
if !mc.cfg.AllowNativePasswords {
return nil, ErrNativePassword
}
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
// https://dev.mysql.com/doc/dev/mysql-server/8.4.5/page_protocol_connection_phase_authentication_methods_native_password_authentication.html
// Native password authentication needs, and will only ever need, a 20-byte challenge.
authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
return authResp, nil
-1
View File
@@ -7,7 +7,6 @@
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || illumos
// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos
package mysql
-1
View File
@@ -7,7 +7,6 @@
// You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !linux && !darwin && !dragonfly && !freebsd && !netbsd && !openbsd && !solaris && !illumos
// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos
package mysql
+185 -90
View File
@@ -33,7 +33,8 @@ type mysqlConn struct {
connector *connector
maxAllowedPacket int
maxWriteSize int
flags clientFlag
capabilities capabilityFlag
extCapabilities extendedCapabilityFlag
status statusFlag
sequence uint8
compressSequence uint8
@@ -171,7 +172,7 @@ func (mc *mysqlConn) close() {
}
// Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function
// function after successful authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
@@ -223,13 +224,21 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
columnCount, err := stmt.readPrepareResultPacket()
if err == nil {
if stmt.paramCount > 0 {
if err = mc.readUntilEOF(); err != nil {
if err = mc.skipColumns(stmt.paramCount); err != nil {
return nil, err
}
}
if columnCount > 0 {
err = mc.readUntilEOF()
if mc.extCapabilities&clientCacheMetadata != 0 {
if stmt.columns, err = mc.readColumns(int(columnCount), nil); err != nil {
return nil, err
}
} else {
if err = mc.skipColumns(int(columnCount)); err != nil {
return nil, err
}
}
}
}
@@ -237,100 +246,184 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
}
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
// Number of ? should be same to len(args)
if strings.Count(query, "?") != len(args) {
return "", driver.ErrSkip
}
noBackslashEscapes := (mc.status & statusNoBackslashEscapes) != 0
const (
stateNormal = iota
stateString
stateEscape
stateEOLComment
stateSlashStarComment
stateBacktick
)
const (
QUOTE_BYTE = byte('\'')
DBL_QUOTE_BYTE = byte('"')
BACKSLASH_BYTE = byte('\\')
QUESTION_MARK_BYTE = byte('?')
SLASH_BYTE = byte('/')
STAR_BYTE = byte('*')
HASH_BYTE = byte('#')
MINUS_BYTE = byte('-')
LINE_FEED_BYTE = byte('\n')
BACKTICK_BYTE = byte('`')
)
buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
// can not take the buffer. Something must be wrong with the connection
mc.cleanup()
// interpolateParams is called before any query is sent,
// so it's safe to retry.
return "", driver.ErrBadConn
}
buf = buf[:0]
state := stateNormal
singleQuotes := false
lastChar := byte(0)
argPos := 0
lenQuery := len(query)
lastIdx := 0
for i := 0; i < len(query); i++ {
q := strings.IndexByte(query[i:], '?')
if q == -1 {
buf = append(buf, query[i:]...)
break
}
buf = append(buf, query[i:i+q]...)
i += q
arg := args[argPos]
argPos++
if arg == nil {
buf = append(buf, "NULL"...)
for i := range lenQuery {
currentChar := query[i]
if state == stateEscape && !((currentChar == QUOTE_BYTE && singleQuotes) || (currentChar == DBL_QUOTE_BYTE && !singleQuotes)) {
state = stateString
lastChar = currentChar
continue
}
switch v := arg.(type) {
case int64:
buf = strconv.AppendInt(buf, v, 10)
case uint64:
// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
buf = strconv.AppendUint(buf, v, 10)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
if v {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
switch currentChar {
case STAR_BYTE:
if state == stateNormal && lastChar == SLASH_BYTE {
state = stateSlashStarComment
}
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
buf = append(buf, '\'')
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return "", err
}
buf = append(buf, '\'')
case SLASH_BYTE:
if state == stateSlashStarComment && lastChar == STAR_BYTE {
state = stateNormal
// Clear lastChar so the '/' that closed the comment isn't
// reused to start a new comment with a following '*'.
lastChar = 0
continue
}
case json.RawMessage:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash(buf, v)
} else {
buf = escapeBytesQuotes(buf, v)
case HASH_BYTE:
if state == stateNormal {
state = stateEOLComment
}
buf = append(buf, '\'')
case []byte:
if v == nil {
buf = append(buf, "NULL"...)
} else {
buf = append(buf, "_binary'"...)
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash(buf, v)
case MINUS_BYTE:
if state == stateNormal && lastChar == MINUS_BYTE {
// -- only starts a comment if followed by whitespace or control char
if i+1 < lenQuery {
nextChar := query[i+1]
if nextChar == ' ' || nextChar == '\t' || nextChar == '\n' || nextChar == '\r' {
state = stateEOLComment
}
} else {
buf = escapeBytesQuotes(buf, v)
state = stateEOLComment
}
buf = append(buf, '\'')
}
case string:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeStringBackslash(buf, v)
} else {
buf = escapeStringQuotes(buf, v)
case LINE_FEED_BYTE:
if state == stateEOLComment {
state = stateNormal
}
buf = append(buf, '\'')
default:
return "", driver.ErrSkip
}
case DBL_QUOTE_BYTE:
if state == stateNormal {
state = stateString
singleQuotes = false
} else if state == stateString && !singleQuotes {
state = stateNormal
} else if state == stateEscape {
state = stateString
}
case QUOTE_BYTE:
if state == stateNormal {
state = stateString
singleQuotes = true
} else if state == stateString && singleQuotes {
state = stateNormal
} else if state == stateEscape {
state = stateString
}
case BACKSLASH_BYTE:
if state == stateString && !noBackslashEscapes {
state = stateEscape
}
case QUESTION_MARK_BYTE:
if state == stateNormal {
if argPos >= len(args) {
return "", driver.ErrSkip
}
buf = append(buf, query[lastIdx:i]...)
arg := args[argPos]
argPos++
if len(buf)+4 > mc.maxAllowedPacket {
return "", driver.ErrSkip
if arg == nil {
buf = append(buf, "NULL"...)
lastIdx = i + 1
break
}
switch v := arg.(type) {
case int64:
buf = strconv.AppendInt(buf, v, 10)
case uint64:
buf = strconv.AppendUint(buf, v, 10)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
if v {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
}
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
buf = append(buf, '\'')
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return "", err
}
buf = append(buf, '\'')
}
case json.RawMessage:
if noBackslashEscapes {
buf = escapeBytesQuotes(buf, v, false)
} else {
buf = escapeBytesBackslash(buf, v, false)
}
case []byte:
if v == nil {
buf = append(buf, "NULL"...)
} else {
if noBackslashEscapes {
buf = escapeBytesQuotes(buf, v, true)
} else {
buf = escapeBytesBackslash(buf, v, true)
}
}
case string:
if noBackslashEscapes {
buf = escapeStringQuotes(buf, v)
} else {
buf = escapeStringBackslash(buf, v)
}
default:
return "", driver.ErrSkip
}
if len(buf)+4 > mc.maxAllowedPacket {
return "", driver.ErrSkip
}
lastIdx = i + 1
}
case BACKTICK_BYTE:
if state == stateBacktick {
state = stateNormal
} else if state == stateNormal {
state = stateBacktick
}
}
lastChar = currentChar
}
buf = append(buf, query[lastIdx:]...)
if argPos != len(args) {
return "", driver.ErrSkip
}
@@ -370,19 +463,19 @@ func (mc *mysqlConn) exec(query string) error {
}
// Read Result
resLen, err := handleOk.readResultSetHeaderPacket()
resLen, _, err := handleOk.readResultSetHeaderPacket()
if err != nil {
return err
}
if resLen > 0 {
// columns
if err := mc.readUntilEOF(); err != nil {
if err := mc.skipColumns(resLen); err != nil {
return err
}
// rows
if err := mc.readUntilEOF(); err != nil {
if err := mc.skipRows(); err != nil {
return err
}
}
@@ -419,7 +512,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
// Read Result
var resLen int
resLen, err = handleOk.readResultSetHeaderPacket()
resLen, _, err = handleOk.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
@@ -439,21 +532,20 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
}
// Columns
rows.rs.columns, err = mc.readColumns(resLen)
rows.rs.columns, err = mc.readColumns(resLen, nil)
return rows, err
}
// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
func (mc *mysqlConn) getSystemVar(name string) (string, error) {
// Send command
handleOk := mc.clearResult()
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err
return "", err
}
// Read Result
resLen, err := handleOk.readResultSetHeaderPacket()
resLen, _, err := handleOk.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
@@ -461,17 +553,20 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
if resLen > 0 {
// Columns
if err := mc.readUntilEOF(); err != nil {
return nil, err
if err := mc.skipColumns(resLen); err != nil {
return "", err
}
}
dest := make([]driver.Value, resLen)
if err = rows.readRow(dest); err == nil {
return dest[0].([]byte), mc.readUntilEOF()
// Convert to string before skipRows, which may
// overwrite the read buffer that dest[0] points into.
val := string(dest[0].([]byte))
return val, mc.skipRows()
}
}
return nil, err
return "", err
}
// cancel is called when the query has canceled.
+7 -5
View File
@@ -42,7 +42,7 @@ func encodeConnectionAttributes(cfg *Config) string {
}
// user-defined connection attributes
for _, connAttr := range strings.Split(cfg.ConnectionAttributes, ",") {
for connAttr := range strings.SplitSeq(cfg.ConnectionAttributes, ",") {
k, v, found := strings.Cut(connAttr, ":")
if !found {
continue
@@ -131,7 +131,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
mc.buf = newBuffer()
// Reading Handshake Initialization Packet
authData, plugin, err := mc.readHandshakePacket()
authData, serverCapabilities, serverExtCapabilities, plugin, err := mc.readHandshakePacket()
if err != nil {
mc.cleanup()
return nil, err
@@ -153,6 +153,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
return nil, err
}
}
mc.initCapabilities(serverCapabilities, serverExtCapabilities, mc.cfg)
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
mc.cleanup()
return nil, err
@@ -161,13 +162,14 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
// Handle response to auth packet, switch methods if possible
if err = mc.handleAuthResult(authData, plugin); err != nil {
// Authentication failed and MySQL has already closed the connection
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
// (https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html#sect_protocol_connection_phase_fast_path_fails).
// Do not send COM_QUIT, just cleanup and return the error.
mc.cleanup()
return nil, err
}
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
// compression is enabled after auth, not right after sending handshake response.
if mc.capabilities&clientCompress > 0 {
mc.compress = true
mc.compIO = newCompIO(mc)
}
@@ -180,7 +182,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
mc.Close()
return nil, err
}
n, err := strconv.Atoi(string(maxap))
n, err := strconv.Atoi(maxap)
if err != nil {
mc.Close()
return nil, fmt.Errorf("invalid max_allowed_packet value (%q): %w", maxap, err)
+17 -4
View File
@@ -32,7 +32,7 @@ const (
)
// MySQL constants documentation:
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
// https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html
const (
iOK byte = 0x00
@@ -42,11 +42,12 @@ const (
iERR byte = 0xff
)
// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
type clientFlag uint32
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html
// https://mariadb.com/kb/en/connection/#capabilities
type capabilityFlag uint32
const (
clientLongPassword clientFlag = 1 << iota
clientMySQL capabilityFlag = 1 << iota
clientFoundRows
clientLongFlag
clientConnectWithDB
@@ -73,6 +74,18 @@ const (
clientDeprecateEOF
)
// https://mariadb.com/kb/en/connection/#capabilities
//
// extendedCapabilityFlag is a 32-bit set of extended capability bits, kept
// separate from capabilityFlag. See the link above for the meaning of each bit.
type extendedCapabilityFlag uint32
// Each constant occupies a single bit, assigned in declaration order by iota.
// Reordering or inserting entries therefore changes every later value.
const (
progressIndicator extendedCapabilityFlag = 1 << iota // bit 0
clientComMulti // bit 1
clientStmtBulkOperations // bit 2
clientExtendedMetadata // bit 3
clientCacheMetadata // bit 4
clientUnitBulkResult // bit 5
)
const (
comQuit byte = iota + 1
comInitDB
+4 -5
View File
@@ -15,6 +15,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"maps"
"math/big"
"net"
"net/url"
@@ -157,9 +158,7 @@ func (cfg *Config) Clone() *Config {
}
if len(cp.Params) > 0 {
cp.Params = make(map[string]string, len(cfg.Params))
for k, v := range cfg.Params {
cp.Params[k] = v
}
maps.Copy(cp.Params, cfg.Params)
}
if cfg.pubKey != nil {
cp.pubKey = &rsa.PublicKey{
@@ -414,7 +413,7 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
if dsn[j] == '@' {
// username[:password]
// Find the first ':' in dsn[:j]
for k = 0; k < j; k++ {
for k = 0; k < j; k++ { // We cannot use k = range j here, because we use dsn[:k] below
if dsn[k] == ':' {
cfg.Passwd = dsn[k+1 : j]
break
@@ -477,7 +476,7 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
// parseDSNParams parses the DSN "query string"
// Values must be url.QueryEscape'ed
func parseDSNParams(cfg *Config, params string) (err error) {
for _, v := range strings.Split(params, "&") {
for v := range strings.SplitSeq(params, "&") {
key, value, found := strings.Cut(v, "=")
if !found {
continue
+21 -17
View File
@@ -120,23 +120,24 @@ func (mf *mysqlField) typeDatabaseName() string {
}
var (
scanTypeFloat32 = reflect.TypeOf(float32(0))
scanTypeFloat64 = reflect.TypeOf(float64(0))
scanTypeInt8 = reflect.TypeOf(int8(0))
scanTypeInt16 = reflect.TypeOf(int16(0))
scanTypeInt32 = reflect.TypeOf(int32(0))
scanTypeInt64 = reflect.TypeOf(int64(0))
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
scanTypeNullTime = reflect.TypeOf(sql.NullTime{})
scanTypeUint8 = reflect.TypeOf(uint8(0))
scanTypeUint16 = reflect.TypeOf(uint16(0))
scanTypeUint32 = reflect.TypeOf(uint32(0))
scanTypeUint64 = reflect.TypeOf(uint64(0))
scanTypeString = reflect.TypeOf("")
scanTypeNullString = reflect.TypeOf(sql.NullString{})
scanTypeBytes = reflect.TypeOf([]byte{})
scanTypeUnknown = reflect.TypeOf(new(any))
scanTypeFloat32 = reflect.TypeFor[float32]()
scanTypeFloat64 = reflect.TypeFor[float64]()
scanTypeInt8 = reflect.TypeFor[int8]()
scanTypeInt16 = reflect.TypeFor[int16]()
scanTypeInt32 = reflect.TypeFor[int32]()
scanTypeInt64 = reflect.TypeFor[int64]()
scanTypeNullFloat = reflect.TypeFor[sql.NullFloat64]()
scanTypeNullInt = reflect.TypeFor[sql.NullInt64]()
scanTypeNullUint = reflect.TypeFor[sql.Null[uint64]]()
scanTypeNullTime = reflect.TypeFor[sql.NullTime]()
scanTypeUint8 = reflect.TypeFor[uint8]()
scanTypeUint16 = reflect.TypeFor[uint16]()
scanTypeUint32 = reflect.TypeFor[uint32]()
scanTypeUint64 = reflect.TypeFor[uint64]()
scanTypeString = reflect.TypeFor[string]()
scanTypeNullString = reflect.TypeFor[sql.NullString]()
scanTypeBytes = reflect.TypeFor[[]byte]()
scanTypeUnknown = reflect.TypeFor[*any]()
)
type mysqlField struct {
@@ -185,6 +186,9 @@ func (mf *mysqlField) scanType() reflect.Type {
}
return scanTypeInt64
}
if mf.flags&flagUnsigned != 0 {
return scanTypeNullUint
}
return scanTypeNullInt
case fieldTypeFloat:
+1 -4
View File
@@ -95,10 +95,7 @@ const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead a
func (mc *okHandler) handleInFileRequest(name string) (err error) {
var rdr io.Reader
packetSize := defaultPacketSize
if mc.maxWriteSize < packetSize {
packetSize = mc.maxWriteSize
}
packetSize := min(mc.maxWriteSize, defaultPacketSize)
if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader
// The server might return an absolute path. See issue #355.
+205 -138
View File
@@ -179,20 +179,22 @@ func (mc *mysqlConn) writePacket(data []byte) error {
******************************************************************************/
// Handshake Initialization Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
// https://mariadb.com/kb/en/connection/#initial-handshake-packet
func (mc *mysqlConn) readHandshakePacket() (data []byte, capabilities capabilityFlag, extendedCapabilities extendedCapabilityFlag, plugin string, err error) {
data, err = mc.readPacket()
if err != nil {
return
}
if data[0] == iERR {
return nil, "", mc.handleErrorPacket(data)
err = mc.handleErrorPacket(data)
return
}
// protocol version [1 byte]
if data[0] < minProtocolVersion {
return nil, "", fmt.Errorf(
return nil, 0, 0, "", fmt.Errorf(
"unsupported protocol version %d. Version %d or higher is required",
data[0],
minProtocolVersion,
@@ -210,15 +212,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
pos += 8 + 1
// capability flags (lower 2 bytes) [2 bytes]
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
if mc.flags&clientProtocol41 == 0 {
return nil, "", ErrOldProtocol
capabilities = capabilityFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
if capabilities&clientProtocol41 == 0 {
return nil, capabilities, 0, "", ErrOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil {
if capabilities&clientSSL == 0 && mc.cfg.TLS != nil {
if mc.cfg.AllowFallbackToPlaintext {
mc.cfg.TLS = nil
} else {
return nil, "", ErrNoTLS
return nil, capabilities, 0, "", ErrNoTLS
}
}
pos += 2
@@ -228,11 +230,16 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
// status flags [2 bytes]
pos += 3
// capability flags (upper 2 bytes) [2 bytes]
mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
capabilities |= capabilityFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
pos += 2
// length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes]
pos += 11
// reserved (all [00]) [6 bytes]
pos += 7
if capabilities&clientMySQL == 0 {
// MariaDB server extended flag
extendedCapabilities = extendedCapabilityFlag(binary.LittleEndian.Uint32(data[pos : pos+4]))
}
pos += 4
// second part of the password cipher [minimum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
@@ -260,82 +267,72 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
// make a memory safe copy of the cipher slice
var b [20]byte
copy(b[:], authData)
return b[:], plugin, nil
return b[:], capabilities, extendedCapabilities, plugin, nil
}
// make a memory safe copy of the cipher slice
var b [8]byte
copy(b[:], authData)
return b[:], plugin, nil
return b[:], capabilities, 0, plugin, nil
}
// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
// Adjust client flags based on server support
clientFlags := clientProtocol41 |
clientSecureConn |
clientLongPassword |
clientTransactions |
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
mc.flags&clientConnectAttrs |
mc.flags&clientLongFlag
// initCapabilities initializes the capabilities based on server support and configuration
func (mc *mysqlConn) initCapabilities(serverCapabilities capabilityFlag, serverExtCapabilities extendedCapabilityFlag, cfg *Config) {
clientCapabilities :=
clientMySQL |
clientLongFlag |
clientProtocol41 |
clientSecureConn |
clientTransactions |
clientPluginAuthLenEncClientData |
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
clientConnectAttrs |
clientDeprecateEOF
sendConnectAttrs := mc.flags&clientConnectAttrs != 0
if mc.cfg.ClientFoundRows {
clientFlags |= clientFoundRows
if cfg.ClientFoundRows {
clientCapabilities |= clientFoundRows
}
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
clientFlags |= clientCompress
if cfg.compress {
clientCapabilities |= clientCompress
}
// To enable TLS / SSL
if mc.cfg.TLS != nil {
clientFlags |= clientSSL
clientCapabilities |= clientSSL
}
if mc.cfg.MultiStatements {
clientFlags |= clientMultiStatements
clientCapabilities |= clientMultiStatements
}
if n := len(cfg.DBName); n > 0 {
clientCapabilities |= clientConnectWithDB
}
// encode length of the auth plugin data
var authRespLEIBuf [9]byte
authRespLen := len(authResp)
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
if len(authRespLEI) > 1 {
// if the length can not be written in 1 byte, it must be written as a
// length encoded integer
clientFlags |= clientPluginAuthLenEncClientData
}
// only keep the client capabilities that the server also has
mc.capabilities = clientCapabilities & serverCapabilities
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
// set the MariaDB extended clientCacheMetadata capability if the server supports it
mc.extCapabilities = clientCacheMetadata & serverExtCapabilities
}
// To specify a db name
if n := len(mc.cfg.DBName); n > 0 {
clientFlags |= clientConnectWithDB
pktLen += n + 1
}
// encode length of the connection attributes
var connAttrsLEI []byte
if sendConnectAttrs {
var connAttrsLEIBuf [9]byte
connAttrsLen := len(mc.connector.encodedAttributes)
connAttrsLEI = appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)
}
// Calculate packet length and get buffer with that size
data, err := mc.buf.takeBuffer(pktLen + 4)
// Client Authentication Packet
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
// packet header 4
// capabilities 4
// maxPacketSize 4
// collation id 1
// filler 23
data, err := mc.buf.takeSmallBuffer(4*3 + 24)
if err != nil {
mc.cleanup()
return err
}
_ = data[4*3+23] // boundery check
// ClientFlags [32 bit]
binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags))
// clientCapabilities [32 bit]
binary.LittleEndian.PutUint32(data[4:], uint32(mc.capabilities))
// MaxPacketSize [32 bit] (none)
binary.LittleEndian.PutUint32(data[8:], 0)
@@ -353,16 +350,26 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
}
// Filler [23 bytes] (all 0x00)
// or filler 19bytes + mariadb extCapabilities
pos := 13
for ; pos < 13+23; pos++ {
data[pos] = 0
if mc.capabilities&clientMySQL == 0 {
for ; pos < 13+19; pos++ {
data[pos] = 0
}
// MariaDB Extended Capabilities
binary.LittleEndian.PutUint32(data[13+19:], uint32(mc.extCapabilities))
} else {
for ; pos < 13+23; pos++ {
data[pos] = 0
}
}
// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_ssl_request.html
// https://mariadb.com/kb/en/connection/#sslrequest-packet
if mc.cfg.TLS != nil {
// Send TLS / SSL request packet
if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
if err := mc.writePacket(data); err != nil {
return err
}
@@ -379,37 +386,35 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// User [null terminated string]
if len(mc.cfg.User) > 0 {
pos += copy(data[pos:], mc.cfg.User)
data = append(data, mc.cfg.User...)
}
data[pos] = 0x00
pos++
data = append(data, 0)
// Auth Data [length encoded integer]
pos += copy(data[pos:], authRespLEI)
pos += copy(data[pos:], authResp)
data = appendLengthEncodedInteger(data, uint64(len(authResp)))
data = append(data, authResp...)
// Databasename [null terminated string]
if len(mc.cfg.DBName) > 0 {
pos += copy(data[pos:], mc.cfg.DBName)
data[pos] = 0x00
pos++
// Database name [null terminated string]
if mc.capabilities&clientConnectWithDB != 0 {
data = append(data, mc.cfg.DBName...)
data = append(data, 0)
}
pos += copy(data[pos:], plugin)
data[pos] = 0x00
pos++
data = append(data, plugin...)
data = append(data, 0)
// Connection Attributes
if sendConnectAttrs {
pos += copy(data[pos:], connAttrsLEI)
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
if mc.capabilities&clientConnectAttrs != 0 {
connAttrsLen := len(mc.connector.encodedAttributes)
data = appendLengthEncodedInteger(data, uint64(connAttrsLen))
data = append(data, mc.connector.encodedAttributes...)
}
// Send Auth packet
return mc.writePacket(data[:pos])
return mc.writePacket(data)
}
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_response.html
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
pktLen := 4 + len(authData)
data, err := mc.buf.takeBuffer(pktLen)
@@ -511,7 +516,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
case iEOF:
if len(data) == 1 {
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_old_auth_switch_request.html
return nil, "mysql_old_password", nil
}
pluginEndIndex := bytes.IndexByte(data, 0x00)
@@ -545,36 +550,41 @@ func (mc *okHandler) readResultOK() error {
// Result Set Header Packet
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html
func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
func (mc *okHandler) readResultSetHeaderPacket() (int, bool, error) {
// handleOkPacket replaces both values; other cases leave the values unchanged.
mc.result.affectedRows = append(mc.result.affectedRows, 0)
mc.result.insertIds = append(mc.result.insertIds, 0)
data, err := mc.conn().readPacket()
if err != nil {
return 0, err
return 0, false, err
}
switch data[0] {
case iOK:
return 0, mc.handleOkPacket(data)
return 0, false, mc.handleOkPacket(data)
case iERR:
return 0, mc.conn().handleErrorPacket(data)
return 0, false, mc.conn().handleErrorPacket(data)
case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:]))
return 0, false, mc.handleInFileRequest(string(data[1:]))
}
// column count
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html
num, _, _ := readLengthEncodedInteger(data)
// https://mariadb.com/kb/en/result-set-packets/#column-count-packet
num, _, len := readLengthEncodedInteger(data)
if mc.extCapabilities&clientCacheMetadata != 0 {
return int(num), data[len] == 0x01, nil
}
// ignore remaining data in the packet. see #1478.
return int(num), nil
return int(num), true, nil
}
// Error Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_err_packet.html
func (mc *mysqlConn) handleErrorPacket(data []byte) error {
if data[0] != iERR {
return ErrMalformPkt
@@ -656,7 +666,7 @@ func (mc *mysqlConn) clearResult() *okHandler {
}
// Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_ok_packet.html
func (mc *okHandler) handleOkPacket(data []byte) error {
var n, m int
var affectedRows, insertId uint64
@@ -690,24 +700,19 @@ func (mc *okHandler) handleOkPacket(data []byte) error {
}
// Read Packets as Field Packets until EOF-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset_column_definition.html#sect_protocol_com_query_response_text_resultset_column_definition_41
func (mc *mysqlConn) readColumns(count int, old []mysqlField) ([]mysqlField, error) {
columns := make([]mysqlField, count)
if len(old) != count {
old = nil
}
for i := 0; ; i++ {
for i := range count {
data, err := mc.readPacket()
if err != nil {
return nil, err
}
// EOF Packet
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
if i == count {
return columns, nil
}
return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
}
// Catalog
pos, err := skipLengthEncodedString(data)
if err != nil {
@@ -728,7 +733,12 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
return nil, err
}
pos += n
columns[i].tableName = string(tableName)
if old != nil && old[i].tableName == string(tableName) {
// avoid allocating new string
columns[i].tableName = old[i].tableName
} else {
columns[i].tableName = string(tableName)
}
} else {
n, err = skipLengthEncodedString(data[pos:])
if err != nil {
@@ -749,7 +759,12 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
if err != nil {
return nil, err
}
columns[i].name = string(name)
if old != nil && old[i].name == string(name) {
// avoid allocating new string
columns[i].name = old[i].name
} else {
columns[i].name = string(name)
}
pos += n
// Original name [len coded string]
@@ -780,17 +795,17 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
// Decimals [uint8]
columns[i].decimals = data[pos]
//pos++
// Default value [len coded binary]
//if pos < len(data) {
// defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
//}
}
// skip EOF packet if client does not support deprecateEOF
if err := mc.skipEof(); err != nil {
return nil, err
}
return columns, nil
}
// Read Packets as Field Packets until EOF-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset_row.html
func (rows *textRows) readRow(dest []driver.Value) error {
mc := rows.mc
@@ -804,9 +819,20 @@ func (rows *textRows) readRow(dest []driver.Value) error {
}
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
// server_status [2 bytes]
rows.mc.status = readStatus(data[3:])
// text row packets may start with a LengthEncodedString.
// In such case, 0xFE can mean string larger than 0xffffff.
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_integers.html#sect_protocol_basic_dt_int_le
if data[0] == iEOF && len(data) <= 0xffffff {
if mc.capabilities&clientDeprecateEOF == 0 {
// Deprecated EOF packet
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_eof_packet.html
mc.status = readStatus(data[3:])
} else {
// Ok Packet with an 0xFE header
_, _, n := readLengthEncodedInteger(data[1:]) // affected_rows
_, _, m := readLengthEncodedInteger(data[1+n:]) // last_insert_id
mc.status = readStatus(data[1+n+m:])
}
rows.rs.done = true
if !rows.HasNextResultSet() {
rows.mc = nil
@@ -880,8 +906,34 @@ func (rows *textRows) readRow(dest []driver.Value) error {
return nil
}
// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
func (mc *mysqlConn) readUntilEOF() error {
// skipPackets reads n packets from the connection and throws their payloads
// away. It stops at the first read error and returns it; a non-positive n
// reads nothing and returns nil.
func (mc *mysqlConn) skipPackets(n int) error {
	for remaining := n; remaining > 0; remaining-- {
		_, err := mc.readPacket()
		if err != nil {
			return err
		}
	}
	return nil
}
// skipEof consumes the EOF packet that terminates a run of ColumnDefinition
// packets. When clientDeprecateEOF is set in mc.capabilities no such packet
// is expected and nothing is read. The packet's contents are not inspected;
// only a read error is reported.
func (mc *mysqlConn) skipEof() error {
	if mc.capabilities&clientDeprecateEOF != 0 {
		return nil
	}
	_, err := mc.readPacket()
	return err
}
// skipColumns discards n column definition packets via skipPackets and then
// lets skipEof consume the trailing EOF packet, if one is expected. The first
// error encountered is returned.
func (mc *mysqlConn) skipColumns(n int) error {
	err := mc.skipPackets(n)
	if err == nil {
		err = mc.skipEof()
	}
	return err
}
// Reads Packets until EOF-Packet or an Error appears.
func (mc *mysqlConn) skipRows() error {
for {
data, err := mc.readPacket()
if err != nil {
@@ -892,10 +944,20 @@ func (mc *mysqlConn) readUntilEOF() error {
case iERR:
return mc.handleErrorPacket(data)
case iEOF:
if len(data) == 5 {
mc.status = readStatus(data[3:])
// text row packets may start with a LengthEncodedString.
// In such case, 0xFE can mean string larger than 0xffffff.
if len(data) <= 0xffffff {
if mc.capabilities&clientDeprecateEOF == 0 {
// EOF packet
mc.status = readStatus(data[3:])
} else {
// OK packet with an 0xFE header
_, _, n := readLengthEncodedInteger(data[1:]) // affected_rows
_, _, m := readLengthEncodedInteger(data[1+n:]) // last_insert_id
mc.status = readStatus(data[1+n+m:])
}
return nil
}
return nil
}
}
}
@@ -905,7 +967,7 @@ func (mc *mysqlConn) readUntilEOF() error {
******************************************************************************/
// Prepare Result Packets
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
data, err := stmt.mc.readPacket()
if err == nil {
@@ -932,7 +994,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
return 0, err
}
// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_send_long_data.html
func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
maxLen := stmt.mc.maxAllowedPacket - 1
pktLen := maxLen
@@ -979,7 +1041,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
}
// Execute Prepared Statement
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if len(args) != stmt.paramCount {
return fmt.Errorf(
@@ -993,10 +1055,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
mc := stmt.mc
// Determine threshold dynamically to avoid packet size shortage.
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
if longDataSize < 64 {
longDataSize = 64
}
longDataSize := max(mc.maxAllowedPacket/(stmt.paramCount+1), 64)
// Reset packet-sequence
mc.resetSequence()
@@ -1185,17 +1244,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// mc.affectedRows and mc.insertIds.
func (mc *okHandler) discardResults() error {
for mc.status&statusMoreResultsExists != 0 {
resLen, err := mc.readResultSetHeaderPacket()
resLen, _, err := mc.readResultSetHeaderPacket()
if err != nil {
return err
}
if resLen > 0 {
// columns
if err := mc.conn().readUntilEOF(); err != nil {
if err := mc.conn().skipColumns(resLen); err != nil {
return err
}
// rows
if err := mc.conn().readUntilEOF(); err != nil {
if err := mc.conn().skipRows(); err != nil {
return err
}
}
@@ -1203,7 +1262,7 @@ func (mc *okHandler) discardResults() error {
return nil
}
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_binary_resultset.html#sect_protocol_binary_resultset_row
func (rows *binaryRows) readRow(dest []driver.Value) error {
data, err := rows.mc.readPacket()
if err != nil {
@@ -1212,9 +1271,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
// packet indicator [1 byte]
if data[0] != iOK {
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
rows.mc.status = readStatus(data[3:])
// EOF/OK Packet
if data[0] == iEOF {
if rows.mc.capabilities&clientDeprecateEOF == 0 {
// EOF packet
rows.mc.status = readStatus(data[3:])
} else {
// OK Packet with an 0xFE header
_, _, n := readLengthEncodedInteger(data[1:])
_, _, m := readLengthEncodedInteger(data[1+n:])
rows.mc.status = readStatus(data[1+n+m:])
}
rows.rs.done = true
if !rows.HasNextResultSet() {
rows.mc = nil
+4 -2
View File
@@ -8,6 +8,8 @@
package mysql
import "slices"
import "database/sql/driver"
// Result exposes data not available through *connection.Result.
@@ -42,9 +44,9 @@ func (res *mysqlResult) RowsAffected() (int64, error) {
}
func (res *mysqlResult) AllLastInsertIds() []int64 {
return append([]int64{}, res.insertIds...) // defensive copy
return slices.Clone(res.insertIds) // defensive copy
}
func (res *mysqlResult) AllRowsAffected() []int64 {
return append([]int64{}, res.affectedRows...) // defensive copy
return slices.Clone(res.affectedRows) // defensive copy
}
+5 -5
View File
@@ -113,7 +113,7 @@ func (rows *mysqlRows) Close() (err error) {
// Remove unread packets from stream
if !rows.rs.done {
err = mc.readUntilEOF()
err = mc.skipRows()
}
if err == nil {
handleOk := mc.clearResult()
@@ -143,7 +143,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) {
// Remove unread packets from stream
if !rows.rs.done {
if err := rows.mc.readUntilEOF(); err != nil {
if err := rows.mc.skipRows(); err != nil {
return 0, err
}
rows.rs.done = true
@@ -156,7 +156,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) {
rows.rs = resultSet{}
// rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to
// nextResultSet.
resLen, err := rows.mc.resultUnchanged().readResultSetHeaderPacket()
resLen, _, err := rows.mc.resultUnchanged().readResultSetHeaderPacket()
if err != nil {
// Clean up about multi-results flag
rows.rs.done = true
@@ -186,7 +186,7 @@ func (rows *binaryRows) NextResultSet() error {
return err
}
rows.rs.columns, err = rows.mc.readColumns(resLen)
rows.rs.columns, err = rows.mc.readColumns(resLen, nil)
return err
}
@@ -208,7 +208,7 @@ func (rows *textRows) NextResultSet() (err error) {
return err
}
rows.rs.columns, err = rows.mc.readColumns(resLen)
rows.rs.columns, err = rows.mc.readColumns(resLen, nil)
return err
}
+26 -8
View File
@@ -20,6 +20,7 @@ type mysqlStmt struct {
mc *mysqlConn
id uint32
paramCount int
columns []mysqlField
}
func (stmt *mysqlStmt) Close() error {
@@ -64,19 +65,26 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
handleOk := stmt.mc.clearResult()
// Read Result
resLen, err := handleOk.readResultSetHeaderPacket()
resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
if resLen > 0 {
// Columns
if err = mc.readUntilEOF(); err != nil {
return nil, err
if metadataFollows && stmt.mc.extCapabilities&clientCacheMetadata != 0 {
// we can not skip column metadata because next stmt.Query() may use it.
if stmt.columns, err = mc.readColumns(resLen, stmt.columns); err != nil {
return nil, err
}
} else {
if err = mc.skipColumns(resLen); err != nil {
return nil, err
}
}
// Rows
if err := mc.readUntilEOF(); err != nil {
if err = mc.skipRows(); err != nil {
return nil, err
}
}
@@ -107,7 +115,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
// Read Result
handleOk := stmt.mc.clearResult()
resLen, err := handleOk.readResultSetHeaderPacket()
resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
@@ -116,7 +124,17 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
if resLen > 0 {
rows.mc = mc
rows.rs.columns, err = mc.readColumns(resLen)
if metadataFollows {
if rows.rs.columns, err = mc.readColumns(resLen, stmt.columns); err != nil {
return nil, err
}
stmt.columns = rows.rs.columns
} else {
if err = mc.skipEof(); err != nil {
return nil, err
}
rows.rs.columns = stmt.columns
}
} else {
rows.rs.done = true
@@ -131,7 +149,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
return rows, err
}
var jsonType = reflect.TypeOf(json.RawMessage{})
var jsonType = reflect.TypeFor[json.RawMessage]()
type converter struct{}
@@ -193,7 +211,7 @@ func (c converter) ConvertValue(v any) (driver.Value, error) {
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
var valuerReflectType = reflect.TypeFor[driver.Valuer]()
// callValuerValue returns vr.Value(), with one exception:
// If vr.Value is an auto-generated method on a pointer type and the
+65 -90
View File
@@ -182,7 +182,7 @@ func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
func parseByteYear(b []byte) (int, error) {
year, n := 0, 1000
for i := 0; i < 4; i++ {
for i := range 4 {
v, err := bToi(b[i])
if err != nil {
return 0, err
@@ -207,7 +207,7 @@ func parseByte2Digits(b1, b2 byte) (int, error) {
func parseByteNanoSec(b []byte) (int, error) {
ns, digit := 0, 100000 // max is 6-digits
for i := 0; i < len(b); i++ {
for i := range b {
v, err := bToi(b[i])
if err != nil {
return 0, err
@@ -625,108 +625,80 @@ func reserveBuffer(buf []byte, appendSize int) []byte {
return buf[:newSize]
}
// escapeBytesBackslash escapes []byte with backslashes (\)
// This escapes the contents of a string (provided as []byte) by adding backslashes before special
// characters, and turning others into specific escape sequences, such as
// turning newlines into \n and null bytes into \0.
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
func escapeBytesBackslash(buf, v []byte) []byte {
pos := len(buf)
buf = reserveBuffer(buf, len(v)*2)
// Lookup table for backslash escapes (used for both string and bytes)
var backslashEscapeTable [256]byte
for _, c := range v {
switch c {
case '\x00':
buf[pos+1] = '0'
buf[pos] = '\\'
pos += 2
case '\n':
buf[pos+1] = 'n'
buf[pos] = '\\'
pos += 2
case '\r':
buf[pos+1] = 'r'
buf[pos] = '\\'
pos += 2
case '\x1a':
buf[pos+1] = 'Z'
buf[pos] = '\\'
pos += 2
case '\'':
buf[pos+1] = '\''
buf[pos] = '\\'
pos += 2
case '"':
buf[pos+1] = '"'
buf[pos] = '\\'
pos += 2
case '\\':
buf[pos+1] = '\\'
buf[pos] = '\\'
pos += 2
default:
buf[pos] = c
pos++
}
}
return buf[:pos]
// init populates backslashEscapeTable. The table is indexed by the raw byte;
// a non-zero entry is the character written after the backslash when that
// byte is escaped, and a zero entry means the byte is copied unchanged.
func init() {
	// Each pair is {raw byte, character emitted after the backslash}.
	pairs := [...][2]byte{
		{'\x00', '0'},
		{'\n', 'n'},
		{'\r', 'r'},
		{'\x1a', 'Z'},
		{'\'', '\''},
		{'"', '"'},
		{'\\', '\\'},
	}
	for _, p := range pairs {
		backslashEscapeTable[p[0]] = p[1]
	}
}
// escapeStringBackslash is similar to escapeBytesBackslash but for string.
func escapeStringBackslash(buf []byte, v string) []byte {
pos := len(buf)
buf = reserveBuffer(buf, len(v)*2)
buf = reserveBuffer(buf, len(v)*2+2)
buf[pos] = '\''
pos++
for i := 0; i < len(v); i++ {
c := v[i]
switch c {
case '\x00':
buf[pos+1] = '0'
if esc := backslashEscapeTable[c]; esc != 0 {
buf[pos+1] = esc
buf[pos] = '\\'
pos += 2
case '\n':
buf[pos+1] = 'n'
buf[pos] = '\\'
pos += 2
case '\r':
buf[pos+1] = 'r'
buf[pos] = '\\'
pos += 2
case '\x1a':
buf[pos+1] = 'Z'
buf[pos] = '\\'
pos += 2
case '\'':
buf[pos+1] = '\''
buf[pos] = '\\'
pos += 2
case '"':
buf[pos+1] = '"'
buf[pos] = '\\'
pos += 2
case '\\':
buf[pos+1] = '\\'
buf[pos] = '\\'
pos += 2
default:
} else {
buf[pos] = c
pos++
}
}
buf[pos] = '\''
pos++
return buf[:pos]
}
// escapeBytesQuotes escapes apostrophes in []byte by doubling them up.
// This escapes the contents of a string by doubling up any apostrophes that
// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
// effect on the server.
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
func escapeBytesQuotes(buf, v []byte) []byte {
// escapeBytesBackslash appends v to buf as a quoted literal and returns the
// extended slice. When binary is true the literal has the form _binary'...';
// otherwise it is a plain '...'. Every byte of v with a non-zero entry in
// backslashEscapeTable is written as a backslash followed by that entry; all
// other bytes are copied unchanged.
func escapeBytesBackslash(buf, v []byte, binary bool) []byte {
// pos is the write cursor; it starts at the current end of buf.
pos := len(buf)
if binary {
// Worst case every byte of v becomes two bytes, plus 8 bytes for the
// _binary' prefix and 1 for the closing quote.
// NOTE(review): presumably reserveBuffer extends len(buf) by the given size — confirm.
buf = reserveBuffer(buf, len(v)*2+9)
copy(buf[pos:], []byte("_binary'"))
pos += 8
} else {
// Worst case every byte of v becomes two bytes, plus the opening and
// closing quotes.
buf = reserveBuffer(buf, len(v)*2+2)
buf[pos] = '\''
pos++
}
for _, c := range v {
if esc := backslashEscapeTable[c]; esc != 0 {
// The higher index is stored first.
// NOTE(review): presumably so one bounds check covers both stores — confirm.
buf[pos+1] = esc
buf[pos] = '\\'
pos += 2
} else {
buf[pos] = c
pos++
}
}
// closing quote
buf[pos] = '\''
pos++
// Trim the unused part of the reserved space.
return buf[:pos]
}
}
// escapeBytesQuotes appends _binary'...' or '...' with single-quote escaping for bytes.
func escapeBytesQuotes(buf, v []byte, binary bool) []byte {
pos := len(buf)
if binary {
buf = reserveBuffer(buf, len(v)*2+9)
copy(buf[pos:], []byte("_binary'"))
pos += 8
} else {
buf = reserveBuffer(buf, len(v)*2+2)
buf[pos] = '\''
pos++
}
for _, c := range v {
if c == '\'' {
buf[pos+1] = '\''
@@ -737,16 +709,18 @@ func escapeBytesQuotes(buf, v []byte) []byte {
pos++
}
}
buf[pos] = '\''
pos++
return buf[:pos]
}
// escapeStringQuotes is similar to escapeBytesQuotes but for string.
func escapeStringQuotes(buf []byte, v string) []byte {
pos := len(buf)
buf = reserveBuffer(buf, len(v)*2)
for i := 0; i < len(v); i++ {
buf = reserveBuffer(buf, len(v)*2+2)
buf[pos] = '\''
pos++
for i := range len(v) {
c := v[i]
if c == '\'' {
buf[pos+1] = '\''
@@ -757,7 +731,8 @@ func escapeStringQuotes(buf []byte, v string) []byte {
pos++
}
}
buf[pos] = '\''
pos++
return buf[:pos]
}
+33 -15
View File
@@ -696,7 +696,7 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) (au
// If we are here we have an auth callout defined and we have failed auth so far
// so we will callout to our auth backend for processing.
if !skip {
authorized, reason = s.processClientOrLeafCallout(c, opts, proxyRequired, trustedProxy)
authorized, reason = s.processClientOrLeafCallout(c, opts, proxyRequired, trustedProxy, ujwt)
}
// Check if we are authorized and in the auth callout account, and if so add in deny publish permissions for the auth subject.
if authorized {
@@ -797,26 +797,42 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) (au
token = opts.Authorization
}
// Check if we have trustedKeys defined in the server. If so we require a user jwt.
if s.trustedKeys != nil {
ujwt = c.opts.JWT
if ujwt == _EMPTY_ && c.isMqtt() {
// For MQTT, we pass the password as the JWT too, but do so here so it's not
// publicly exposed in the client options if it isn't a JWT.
// MQTT can carry JWTs in the password field. Reconstruct it here for auth
// processing and auth callout, but do not populate c.opts.JWT yet or it would
// be exposed through monitoring and advisory paths even when the password is
// not actually a JWT.
if ujwt == _EMPTY_ && c.isMqtt() && c.opts.JWT == _EMPTY_ {
// Don't set juc here, leave that to the next s.trustedKeys != nil block,
// so that we don't try to trust a JWT when we aren't in operator mode. We
// will allow it to be passed through auth callout though.
if _, err := jwt.DecodeUserClaims(c.opts.Password); err == nil {
ujwt = c.opts.Password
}
if ujwt == _EMPTY_ && opts.DefaultSentinel != _EMPTY_ {
c.opts.JWT = opts.DefaultSentinel
ujwt = c.opts.JWT
}
// Check if we have trustedKeys defined in the server. If so we require a user jwt.
if s.trustedKeys != nil {
if ujwt == _EMPTY_ {
// Need to be sure that it's a NATS JWT, otherwise we will not correctly
// attempt the default sentinel below.
if _, err = jwt.DecodeUserClaims(c.opts.JWT); err == nil {
ujwt = c.opts.JWT
}
}
if ujwt == _EMPTY_ {
// Didn't fall through with a valid NATS JWT, so try the default sentinel
// if configured.
if opts.DefaultSentinel != _EMPTY_ {
c.opts.JWT = opts.DefaultSentinel
ujwt = c.opts.JWT
}
}
if ujwt == _EMPTY_ {
s.mu.Unlock()
c.Debugf("Authentication requires a user JWT")
return false
}
// So we have a valid user jwt here.
juc, err = jwt.DecodeUserClaims(ujwt)
if err != nil {
if juc, err = jwt.DecodeUserClaims(ujwt); err != nil {
s.mu.Unlock()
c.Debugf("User JWT not valid: %v", err)
return false
@@ -1015,8 +1031,10 @@ func (s *Server) processClientOrLeafAuthentication(c *client, opts *Options) (au
c.Debugf("Connection type not allowed")
return false
}
// skip validation of nonce when presented with a bearer token
// FIXME: if BearerToken is only for WSS, need check for server with that port enabled
// Skip validation of nonce when presented with a bearer token.
// While support for bearer tokens was added for WebSockets, there is no
// security benefit in restricting their use to that client protocol: the
// client can just go use the other protocol.
if !juc.BearerToken {
// Verify the signature against the nonce.
if c.opts.Sig == _EMPTY_ {
+10 -4
View File
@@ -41,7 +41,7 @@ func titleCase(m string) string {
}
// Process a callout on this client's behalf.
func (s *Server) processClientOrLeafCallout(c *client, opts *Options, proxyRequired, trustedProxy bool) (authorized bool, errStr string) {
func (s *Server) processClientOrLeafCallout(c *client, opts *Options, proxyRequired, trustedProxy bool, ujwt string) (authorized bool, errStr string) {
isOperatorMode := len(opts.TrustedKeys) > 0
// this is the account the user connected in, or the one running the callout
@@ -374,7 +374,7 @@ func (s *Server) processClientOrLeafCallout(c *client, opts *Options, proxyRequi
// Grab client info for the request.
c.mu.Lock()
c.fillClientInfo(&claim.ClientInformation)
c.fillConnectOpts(&claim.ConnectOptions)
c.fillConnectOpts(&claim.ConnectOptions, ujwt)
// If we have a sig in the client opts, fill in nonce.
if claim.ConnectOptions.SignedNonce != _EMPTY_ {
claim.ClientInformation.Nonce = string(c.nonce)
@@ -474,16 +474,22 @@ func (c *client) fillClientInfo(ci *jwt.ClientInformation) {
// Fill in client options.
// Lock should be held.
func (c *client) fillConnectOpts(opts *jwt.ConnectOptions) {
func (c *client) fillConnectOpts(opts *jwt.ConnectOptions, ujwt string) {
if c == nil || (c.kind != CLIENT && c.kind != LEAF && c.kind != JETSTREAM && c.kind != ACCOUNT) {
return
}
o := c.opts
if ujwt == _EMPTY_ {
// The caller may supply a reconstructed JWT that should be sent to auth
// callout without storing it in c.opts.JWT. If not, fall back to the client
// option as before.
ujwt = o.JWT
}
// Do it this way to fail to compile if fields are added to jwt.ClientInformation.
*opts = jwt.ConnectOptions{
JWT: o.JWT,
JWT: ujwt,
Nkey: o.Nkey,
SignedNonce: o.Sig,
Token: o.Token,
+2 -2
View File
@@ -239,7 +239,7 @@ func (ss SequenceSet) EncodeLen() int {
return minLen + (ss.Nodes() * ((numBuckets+1)*8 + 2))
}
func (ss SequenceSet) Encode(buf []byte) ([]byte, error) {
func (ss SequenceSet) Encode(buf []byte) []byte {
nn, encLen := ss.Nodes(), ss.EncodeLen()
if cap(buf) < encLen {
@@ -268,7 +268,7 @@ func (ss SequenceSet) Encode(buf []byte) ([]byte, error) {
le.PutUint16(buf[i:], uint16(n.h))
i += 2
})
return buf[:i], nil
return buf[:i]
}
// ErrBadEncoding is returned when we can not decode properly.
+191 -122
View File
@@ -1061,18 +1061,19 @@ func (c *client) setPermissions(perms *Permissions) {
return
}
c.perms = &permissions{}
slcache := c.srv != nil && !c.srv.getOpts().NoSublistCache
// Loop over publish permissions
if perms.Publish != nil {
if perms.Publish.Allow != nil {
c.perms.pub.allow = NewSublistWithCache()
c.perms.pub.allow = NewSublist(slcache)
}
for _, pubSubject := range perms.Publish.Allow {
sub := &subscription{subject: []byte(pubSubject)}
c.perms.pub.allow.Insert(sub)
}
if len(perms.Publish.Deny) > 0 {
c.perms.pub.deny = NewSublistWithCache()
c.perms.pub.deny = NewSublist(slcache)
}
for _, pubSubject := range perms.Publish.Deny {
sub := &subscription{subject: []byte(pubSubject)}
@@ -1091,7 +1092,7 @@ func (c *client) setPermissions(perms *Permissions) {
if perms.Subscribe != nil {
var err error
if len(perms.Subscribe.Allow) > 0 {
c.perms.sub.allow = NewSublistWithCache()
c.perms.sub.allow = NewSublist(slcache)
}
for _, subSubject := range perms.Subscribe.Allow {
sub := &subscription{}
@@ -1103,7 +1104,7 @@ func (c *client) setPermissions(perms *Permissions) {
c.perms.sub.allow.Insert(sub)
}
if len(perms.Subscribe.Deny) > 0 {
c.perms.sub.deny = NewSublistWithCache()
c.perms.sub.deny = NewSublist(slcache)
// Also hold onto this array for later.
c.darray = perms.Subscribe.Deny
}
@@ -1200,6 +1201,7 @@ func (c *client) mergeDenyPermissions(what denyType, denyPubs []string) {
if c.perms == nil {
c.perms = &permissions{}
}
slcache := c.srv != nil && !c.srv.getOpts().NoSublistCache
var perms []*perm
switch what {
case pub:
@@ -1211,7 +1213,7 @@ func (c *client) mergeDenyPermissions(what denyType, denyPubs []string) {
}
for _, p := range perms {
if p.deny == nil {
p.deny = NewSublistWithCache()
p.deny = NewSublist(slcache)
}
FOR_DENY:
for _, subj := range denyPubs {
@@ -2254,12 +2256,20 @@ func (c *client) processConnect(arg []byte) error {
// least ClientProtoInfo, we need to increment the following counter.
// This is decremented when client is removed from the server's
// clients map.
if kind == CLIENT && proto >= ClientProtoInfo {
if kind == CLIENT && proto >= ClientProtoInfo && firstConnect {
srv.mu.Lock()
srv.cproto++
srv.mu.Unlock()
}
// A second CONNECT may move the client into a different account via
// checkAuthentication. Drop any previously-registered subscriptions
// from the current account first so they don't leak in that account's
// sublist after the client switches.
if !firstConnect {
c.clearAccountSubs(false)
}
// Check for Auth
if ok := srv.checkAuthentication(c); !ok {
// We may fail here because we reached max limits on an account.
@@ -3273,19 +3283,20 @@ func (c *client) canSubscribe(subject string, optQueue ...string) bool {
r := c.perms.sub.deny.Match(subject)
allowed = len(r.psubs) == 0
if queue != _EMPTY_ && len(r.qsubs) > 0 {
if allowed && queue != _EMPTY_ && len(r.qsubs) > 0 {
// If the queue appears in the deny list, then DO NOT allow.
allowed = !queueMatches(queue, r.qsubs)
}
// We use the actual subscription to signal us to spin up the deny mperms
// and cache. We check if the subject is a wildcard that contains any of
// and cache. We check if the subject is a wildcard that intersects any of
// the deny clauses.
// FIXME(dlc) - We could be smarter and track when these go away and remove.
if allowed && c.mperms == nil && subjectHasWildcard(subject) {
// Whip through the deny array and check if this wildcard subject is within scope.
// Whip through the deny array and check if this wildcard subject can
// overlap with any denied deliveries.
for _, sub := range c.darray {
if subjectIsSubsetMatch(sub, subject) {
if SubjectsCollide(sub, subject) {
c.loadMsgDenyFilter()
break
}
@@ -3658,14 +3669,7 @@ func (c *client) deliverMsg(prodIsMQTT bool, sub *subscription, acc *Account, su
// Check if we are a leafnode and have perms to check.
if client.kind == LEAF && client.perms != nil {
var subjectToCheck []byte
if subject[0] == '_' && bytes.HasPrefix(subject, []byte(gwReplyPrefix)) {
subjectToCheck = subject[gwSubjectOffset:]
} else if subject[0] == '$' && bytes.HasPrefix(subject, []byte(oldGWReplyPrefix)) {
subjectToCheck = subject[oldGWReplyStart:]
} else {
subjectToCheck = subject
}
subjectToCheck, _ := getGWRoutedSubjectOrSelf(subject)
if !client.pubAllowedFullCheck(string(subjectToCheck), true, true) {
mt.addEgressEvent(client, sub, errMsgTracePubViolation)
client.mu.Unlock()
@@ -4068,7 +4072,7 @@ func (c *client) allowedMsgTraceDest(hdr []byte, hasLock bool) (string, bool) {
return _EMPTY_, true
}
td := sliceHeader(MsgTraceDest, hdr)
if len(td) == 0 {
if len(td) == 0 || bytes.Equal(td, traceDestDisabledAsBytes) {
return _EMPTY_, true
}
dest := bytesToString(td)
@@ -4131,17 +4135,7 @@ func (c *client) pubAllowedFullCheck(subject string, fullCheck, hasLock bool) bo
if !hasLock {
c.mu.Lock()
}
if resp := c.replies[subject]; resp != nil {
resp.n++
// Check if we have sent too many responses.
if c.perms.resp.MaxMsgs > 0 && resp.n > c.perms.resp.MaxMsgs {
delete(c.replies, subject)
} else if c.perms.resp.Expires > 0 && time.Since(resp.t) > c.perms.resp.Expires {
delete(c.replies, subject)
} else {
allowed = true
}
}
allowed = c.responseAllowed(subject)
if !hasLock {
c.mu.Unlock()
}
@@ -4155,6 +4149,25 @@ func (c *client) pubAllowedFullCheck(subject string, fullCheck, hasLock bool) bo
return allowed
}
// responseAllowed reports whether subject is currently covered by a tracked
// dynamic reply permission. Each call counts as one use of the tracked reply:
// the use counter is incremented before the limits are evaluated, and the
// entry is dropped from c.replies once the message limit or the expiration
// configured in c.perms.resp has been exceeded.
// Lock must be held.
func (c *client) responseAllowed(subject string) bool {
	if c.perms == nil || c.perms.resp == nil {
		return false
	}
	tracked := c.replies[subject]
	if tracked == nil {
		// Not a reply subject we are tracking.
		return false
	}
	tracked.n++
	limits := c.perms.resp
	switch {
	case limits.MaxMsgs > 0 && tracked.n > limits.MaxMsgs:
		// Too many responses sent, fall through to removal.
	case limits.Expires > 0 && time.Since(tracked.t) > limits.Expires:
		// Permission has expired, fall through to removal.
	default:
		return true
	}
	delete(c.replies, subject)
	return false
}
// Test whether a reply subject is a service import reply.
func isServiceReply(reply []byte) bool {
// This function is inlined and checking this way is actually faster
@@ -4162,16 +4175,51 @@ func isServiceReply(reply []byte) bool {
return len(reply) > 3 && bytesToString(reply[:4]) == replyPrefix
}
// isJSAckSubject reports whether subject starts with the JetStream ACK prefix
// and carries at least one more byte after it.
func isJSAckSubject(subject []byte) bool {
	if len(subject) <= jsAckPreLen {
		return false
	}
	return bytesToString(subject[:jsAckPreLen]) == jsAckPre
}
// jsAckDeliverIdx locates the `@` that separates an encoded JetStream ACK reply
// from its deliver subject and returns its byte offset, or -1 when reply is not
// of that form.
//
// The encoded layout is:
//
//	$JS.ACK.<stream>.<consumer>.<delivered>.<sseq>.<cseq>.<tm>.<pending>@<deliver>
//
// Because stream, consumer, and subject tokens may themselves contain `@`, only
// an `@` seen after at least eight dots is treated as the separator.
func jsAckDeliverIdx(reply []byte) int {
	if !isJSAckSubject(reply) {
		return -1
	}
	for i, seen := 0, 0; i < len(reply); i++ {
		if c := reply[i]; c == '.' {
			seen++
		} else if c == '@' && seen >= 8 {
			return i
		}
	}
	return -1
}
// replyHasJSAckSuffix reports whether reply already has the
// `$JS.ACK....@<deliver>` shape. Callers use it to avoid appending the deliver
// suffix a second time on a re-entrant pass (service-import or chained JS
// push).
func replyHasJSAckSuffix(reply []byte) bool {
	const notFound = -1
	return jsAckDeliverIdx(reply) != notFound
}
// Test whether a reply subject is a service import or a gateway routed reply.
func isReservedReply(reply []byte) bool {
if isServiceReply(reply) {
return true
}
rLen := len(reply)
// Faster to check with string([:]) than byte-by-byte
if rLen > jsAckPreLen && bytesToString(reply[:jsAckPreLen]) == jsAckPre {
if isJSAckSubject(reply) {
return true
} else if rLen > gwReplyPrefixLen && bytesToString(reply[:gwReplyPrefixLen]) == gwReplyPrefix {
} else if len(reply) > gwReplyPrefixLen && bytesToString(reply[:gwReplyPrefixLen]) == gwReplyPrefix {
return true
}
return false
@@ -4370,7 +4418,7 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {
// Now deal with gateways
if c.srv.gateway.enabled {
reply := c.pa.reply
if len(c.pa.deliver) > 0 && c.kind == JETSTREAM && len(c.pa.reply) > 0 {
if len(c.pa.deliver) > 0 && c.kind == JETSTREAM && len(reply) > 0 && !replyHasJSAckSuffix(reply) {
reply = append(reply, '@')
reply = append(reply, c.pa.deliver...)
}
@@ -4418,7 +4466,7 @@ func (c *client) handleGWReplyMap(msg []byte) bool {
}
if c.srv.gateway.enabled {
reply := c.pa.reply
if len(c.pa.deliver) > 0 && c.kind == JETSTREAM && len(c.pa.reply) > 0 {
if len(c.pa.deliver) > 0 && c.kind == JETSTREAM && len(reply) > 0 && !replyHasJSAckSuffix(reply) {
reply = append(reply, '@')
reply = append(reply, c.pa.deliver...)
}
@@ -4531,7 +4579,8 @@ func (c *client) setHeader(key, value string, msg []byte) []byte {
// Write original header if present.
if c.pa.hdr > LEN_CR_LF {
omi = c.pa.hdr
hdr := removeHeaderIfPresent(msg[:c.pa.hdr-LEN_CR_LF], key)
// Need to copy since we're removing the header in place.
hdr := removeHeaderIfPresent(copyBytes(msg[:c.pa.hdr-LEN_CR_LF]), key)
if len(hdr) == 0 {
bb.WriteString(hdrLine)
} else {
@@ -4825,6 +4874,12 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
// but the local server must replace it with the identity of the
// authenticated leaf connection instead of trusting forwarded values.
ci = c.getClientInfo(share)
if hadPrevSi && cis != nil && cis.Reply != _EMPTY_ {
ci.Reply = cis.Reply
} else if bytes.HasSuffix(c.pa.reply, []byte(FastBatchSuffix)) {
// Fast batch requires knowledge of the original reply subject.
ci.Reply = bytesToString(c.pa.reply)
}
if hadPrevSi {
ci.Service = acc.Name
if !share && (si.share || isSysImport) {
@@ -4843,6 +4898,10 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
}
} else if c.kind != LEAF || c.pa.hdr < 0 || len(sliceHeader(ClientInfoHdr, msg[:c.pa.hdr])) == 0 {
ci = c.getClientInfo(share)
// Fast batch requires knowledge of the original reply subject.
if bytes.HasSuffix(c.pa.reply, []byte(FastBatchSuffix)) {
ci.Reply = bytesToString(c.pa.reply)
}
// If we did not share but the imports destination is the system account add in the server and cluster info.
if !share && isSysImport {
c.addServerAndClusterInfo(ci)
@@ -4902,8 +4961,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
// We also need to disable the message trace headers so that
// if the message is routed, it does not initialize tracing in the
// remote.
positions := disableTraceHeaders(c, msg)
defer enableTraceHeaders(msg, positions)
msg = c.setHeader(MsgTraceDest, MsgTraceDestDisabled, msg)
}
}
}
@@ -5037,21 +5095,9 @@ func (c *client) processMsgResults(acc *Account, r *SublistResult, msg, deliver,
// Check for JetStream encoded reply subjects.
// For now these will only be on $JS.ACK prefixed reply subjects.
var remapped bool
if len(creply) > 0 && c.kind != CLIENT && !isInternalClient(c.kind) && bytes.HasPrefix(creply, []byte(jsAckPre)) {
if len(creply) > 0 && c.kind != CLIENT && !isInternalClient(c.kind) {
// We need to rewrite the subject and the reply.
// But, we must be careful that the stream name, consumer name, and subject can contain '@' characters.
// JS ACK contains at least 8 dots, find the first @ after this prefix.
// - $JS.ACK.<stream>.<consumer>.<delivered>.<sseq>.<cseq>.<tm>.<pending>
counter := 0
li := bytes.IndexFunc(creply, func(rn rune) bool {
if rn == '.' {
counter++
} else if rn == '@' {
return counter >= 8
}
return false
})
if li != -1 && li < len(creply)-1 {
if li := jsAckDeliverIdx(creply); li != -1 && li < len(creply)-1 {
remapped = true
subj, creply = creply[li+1:], creply[:li]
}
@@ -5471,7 +5517,7 @@ sendToRoutesOrLeafs:
// at the end of the reply subject if it exists. But only if this wasn't
// already performed, otherwise we'd end up with a duplicate '@' suffix
// resulting in a protocol error.
if len(deliver) > 0 && len(reply) > 0 && !remapped {
if len(deliver) > 0 && len(reply) > 0 && !remapped && !replyHasJSAckSuffix(reply) {
reply = append(reply, '@')
reply = append(reply, deliver...)
}
@@ -5754,11 +5800,12 @@ func (c *client) clearAuthTimer() bool {
return stopped
}
// We may reuse atmr for expiring user jwts,
// so check connectReceived.
// Track whether the parser should still enforce pre-CONNECT rules.
// This is handshake state, not timer state, since some handshakes
// use a different timer while still expecting CONNECT.
// Lock assumed held on entry.
func (c *client) awaitingAuth() bool {
return !c.flags.isSet(connectReceived) && c.atmr != nil
return c.flags.isSet(expectConnect) && !c.flags.isSet(connectReceived)
}
// This will set the atmr for the JWT expiration time.
@@ -5987,37 +6034,12 @@ func (c *client) closeConnection(reason ClosedState) {
srv = c.srv
noReconnect = c.flags.isSet(noReconnect)
acc = c.acc
spoke bool
)
// Snapshot for use if we are a client connection.
// FIXME(dlc) - we can just stub in a new one for client
// and reference existing one.
var subs []*subscription
if kind == CLIENT || kind == LEAF || kind == JETSTREAM {
var _subs [32]*subscription
subs = _subs[:0]
// Do not set c.subs to nil or delete the sub from c.subs here because
// it will be needed in saveClosedClient (which has been started as a
// go routine in markConnAsClosed). Cleanup will be done there.
for _, sub := range c.subs {
// Auto-unsubscribe subscriptions must be unsubscribed forcibly.
sub.max = 0
sub.close()
subs = append(subs, sub)
}
spoke = c.isSpokeLeafNode()
}
c.mu.Unlock()
// Remove client's or leaf node or jetstream subscriptions.
if acc != nil && (kind == CLIENT || kind == LEAF || kind == JETSTREAM) {
acc.sl.RemoveBatch(subs)
} else if kind == ROUTER {
if kind == ROUTER {
c.removeRemoteSubs()
}
if srv != nil {
// Unregister
srv.removeClient(c)
@@ -6025,45 +6047,11 @@ func (c *client) closeConnection(reason ClosedState) {
if acc != nil {
// Update remote subscriptions.
if kind == CLIENT || kind == LEAF || kind == JETSTREAM {
qsubs := map[string]*qsub{}
for _, sub := range subs {
// Call unsubscribe here to cleanup shadow subscriptions and such.
c.unsubscribe(acc, sub, true, false)
// Update route as normal for a normal subscriber.
if sub.queue == nil {
if !spoke {
srv.updateRouteSubscriptionMap(acc, sub, -1)
if srv.gateway.enabled {
srv.gatewayUpdateSubInterest(acc.Name, sub, -1)
}
}
acc.updateLeafNodes(sub, -1)
} else {
// We handle queue subscribers special in case we
// have a bunch we can just send one update to the
// connected routes.
num := int32(1)
if kind == LEAF {
num = sub.qw
}
key := keyFromSub(sub)
if esub, ok := qsubs[key]; ok {
esub.n += num
} else {
qsubs[key] = &qsub{sub, num}
}
}
}
// Process any qsubs here.
for _, esub := range qsubs {
if !spoke {
srv.updateRouteSubscriptionMap(acc, esub.sub, -(esub.n))
if srv.gateway.enabled {
srv.gatewayUpdateSubInterest(acc.Name, esub.sub, -(esub.n))
}
}
acc.updateLeafNodes(esub.sub, -(esub.n))
}
// Remove client's subscriptions from the account and unregister
// client from that account. Keep c.subs populated because
// saveClosedClient (started as a goroutine in markConnAsClosed)
// still needs to read it.
c.clearAccountSubs(true)
}
// Always remove from the account, otherwise we can leak clients.
// Note that SYSTEM and ACCOUNT types from above cleanup their own subs.
@@ -6090,6 +6078,87 @@ func (c *client) closeConnection(reason ClosedState) {
c.reconnect()
}
// clearAccountSubs removes the client's subscriptions from its current
// account's sublist and, when the client has a server, propagates the lost
// interest to routes, gateways and leafnodes.
//
// If close is true, c.subs is left populated because saveClosedClient (started
// as a goroutine in markConnAsClosed) still needs to read it. If close is
// false, c.subs is emptied and the client is registered back with the server's
// global account, mimicking the state after client initialization.
//
// Only CLIENT, LEAF and JETSTREAM connections with an account are processed;
// anything else is a no-op. When the client has no server, only the sublist
// removal is performed.
// Client lock MUST NOT be held on entry.
func (c *client) clearAccountSubs(close bool) {
	c.mu.Lock()
	kind := c.kind
	srv := c.srv
	acc := c.acc
	if acc == nil || (kind != CLIENT && kind != LEAF && kind != JETSTREAM) {
		c.mu.Unlock()
		return
	}
	var _subs [32]*subscription
	subs := _subs[:0]
	// When closing, entries must stay in c.subs because saveClosedClient
	// still reads them; cleanup is done there. Otherwise each entry is
	// removed from c.subs as it is snapshotted.
	for _, sub := range c.subs {
		// Auto-unsubscribe subscriptions must be unsubscribed forcibly.
		sub.max = 0
		sub.close()
		subs = append(subs, sub)
		if !close {
			delete(c.subs, string(sub.sid))
		}
	}
	spoke := c.isSpokeLeafNode()
	c.mu.Unlock()

	acc.sl.RemoveBatch(subs)

	// Without a server there is nothing to propagate to and no global
	// account to fall back to.
	if srv == nil {
		return
	}

	qsubs := map[string]*qsub{}
	for _, sub := range subs {
		// Call unsubscribe here to cleanup shadow subscriptions and such.
		c.unsubscribe(acc, sub, true, false)
		// Update route as normal for a normal subscriber.
		if sub.queue == nil {
			if !spoke {
				srv.updateRouteSubscriptionMap(acc, sub, -1)
				if srv.gateway.enabled {
					srv.gatewayUpdateSubInterest(acc.Name, sub, -1)
				}
			}
			acc.updateLeafNodes(sub, -1)
		} else {
			// Queue subscribers are handled specially: they are grouped
			// by key so that a single update per group can be sent to
			// the connected routes.
			num := int32(1)
			if kind == LEAF {
				num = sub.qw
			}
			key := keyFromSub(sub)
			if esub, ok := qsubs[key]; ok {
				esub.n += num
			} else {
				qsubs[key] = &qsub{sub, num}
			}
		}
	}
	// Process any qsubs here.
	for _, esub := range qsubs {
		if !spoke {
			srv.updateRouteSubscriptionMap(acc, esub.sub, -(esub.n))
			if srv.gateway.enabled {
				srv.gatewayUpdateSubInterest(acc.Name, esub.sub, -(esub.n))
			}
		}
		acc.updateLeafNodes(esub.sub, -(esub.n))
	}

	if !close {
		// Register back to global account, mimicking the state after client initialization.
		c.registerWithAccount(srv.globalAccount())
	}
}
// Depending on the kind of connections, this may attempt to recreate a connection.
// The actual reconnect attempt will be started in a go routine.
func (c *client) reconnect() {
@@ -6180,7 +6249,7 @@ func (c *client) reconnect() {
srv.Debugf("Gateway %q not in configuration, not attempting reconnect", gwName)
}
} else if leafCfg != nil {
// Check if this is a solicited leaf node. Start up a reconnect.
// This is a solicited leaf node. Start up a reconnect.
srv.startGoRoutine(func() { srv.reConnectToRemoteLeafNode(leafCfg) })
}
}
+1 -1
View File
@@ -66,7 +66,7 @@ func init() {
const (
// VERSION is the current version for the server.
VERSION = "2.12.6"
VERSION = "2.14.0"
// PROTO is the currently supported protocol.
// 0 was the original
File diff suppressed because it is too large Load Diff
+327
View File
@@ -0,0 +1,327 @@
// Copyright 2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Based on code from https://github.com/robfig/cron
// Copyright (C) 2012 Rob Figueiredo
// All Rights Reserved.
//
// MIT LICENSE
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
package server
import (
"errors"
"fmt"
"math"
"strconv"
"strings"
"time"
)
// parseCron parses the given cron pattern and returns the next time it will fire based on the provided ts.
//
// The pattern must consist of exactly six whitespace-separated fields, in this
// order: second, minute, hour, day of month, month, day of week. ts is
// interpreted as nanoseconds since the Unix epoch; the search starts at the
// first whole second strictly after ts and is evaluated in loc (UTC when loc
// is nil).
//
// An error is returned when the pattern does not have six fields, when any
// field fails to parse, or when the search moves past the year of the starting
// point plus five without finding a match.
func parseCron(pattern string, loc *time.Location, ts int64) (time.Time, error) {
fields := strings.Fields(pattern)
if len(fields) != 6 {
return time.Time{}, fmt.Errorf("pattern requires 6 fields, got %d", len(fields))
}
// If no time zone is passed, default to UTC.
if loc == nil {
loc = time.UTC
}
// Parse each field.
// Each field becomes a bit set in which bit N is set when value N is allowed.
var err error
var second, minute, hour, dayOfMonth, month, dayOfWeek uint64
if second, err = getField(fields[0], seconds); err != nil {
return time.Time{}, err
}
if minute, err = getField(fields[1], minutes); err != nil {
return time.Time{}, err
}
if hour, err = getField(fields[2], hours); err != nil {
return time.Time{}, err
}
if dayOfMonth, err = getField(fields[3], dom); err != nil {
return time.Time{}, err
}
if month, err = getField(fields[4], months); err != nil {
return time.Time{}, err
}
if dayOfWeek, err = getField(fields[5], dow); err != nil {
return time.Time{}, err
}
// General approach
//
// For Month, Day, Hour, Minute, Second:
// Check if the time value matches. If yes, continue to the next field.
// If the field doesn't match the schedule, then increment the field until it matches.
// While incrementing the field, a wrap-around brings it back to the beginning
// of the field list (since it is necessary to re-verify previous field values)
next := time.Unix(0, ts).In(loc)
// Start at the earliest possible time (the upcoming second).
next = next.Truncate(time.Second).Add(time.Second)
// This flag indicates whether a field has been truncated at one point.
// Once set, the smaller units of next have already been reset, so later
// loops only need to increment.
truncated := false
// If no time is found within five years, return error.
yearLimit := next.Year() + 5
WRAP:
if next.Year() > yearLimit {
return time.Time{}, errors.New("pattern exceeds maximum range")
}
// Advance month by month until the month is allowed; rolling over into
// January restarts the checks from the top.
for 1<<uint(next.Month())&month == 0 {
if !truncated {
truncated = true
next = time.Date(next.Year(), next.Month(), 1, 0, 0, 0, 0, loc)
}
if next = next.AddDate(0, 1, 0); next.Month() == time.January {
goto WRAP
}
}
// Advance day by day until the day restrictions are satisfied; reaching
// the first of a month restarts the checks so the month is re-verified.
for !dayMatches(dayOfMonth, dayOfWeek, next) {
if !truncated {
truncated = true
next = time.Date(next.Year(), next.Month(), next.Day(), 0, 0, 0, 0, loc)
}
if next = next.AddDate(0, 0, 1); next.Day() == 1 {
goto WRAP
}
}
// Advance hour by hour; reaching hour 0 means a new day, so restart.
for 1<<uint(next.Hour())&hour == 0 {
if !truncated {
truncated = true
next = time.Date(next.Year(), next.Month(), next.Day(), next.Hour(), 0, 0, 0, loc)
}
if next = next.Add(time.Hour); next.Hour() == 0 {
goto WRAP
}
}
// Advance minute by minute; reaching minute 0 means a new hour, so restart.
for 1<<uint(next.Minute())&minute == 0 {
if !truncated {
truncated = true
next = next.Truncate(time.Minute)
}
if next = next.Add(time.Minute); next.Minute() == 0 {
goto WRAP
}
}
// Advance second by second; reaching second 0 means a new minute, so restart.
for 1<<uint(next.Second())&second == 0 {
if !truncated {
truncated = true
next = next.Truncate(time.Second)
}
if next = next.Add(time.Second); next.Second() == 0 {
goto WRAP
}
}
return next, nil
}
// getField parses one cron field and returns a bit set in which every value
// the field allows has its bit set. A "field" is a comma-separated list of
// "ranges", each of which is parsed by getRange against the bounds r; the
// results are OR-ed together.
//
// An error is returned when any range fails to parse, or when the field
// contains no range at all (for example ","). Such a field can never match
// anything and would otherwise send the caller on a long, fruitless search.
// The returned bit set is 0 whenever an error is returned.
func getField(field string, r bounds) (uint64, error) {
	var bits uint64
	isComma := func(c rune) bool { return c == ',' }
	for expr := range strings.FieldsFuncSeq(field, isComma) {
		bit, err := getRange(expr, r)
		if err != nil {
			return 0, err
		}
		bits |= bit
	}
	// Every successfully parsed range sets at least one bit, so an empty
	// result means the field held nothing but separators.
	if bits == 0 {
		return 0, fmt.Errorf("field contains no values: %q", field)
	}
	return bits, nil
}
// getRange parses a single range expression of the form
//
//	number | number [ "-" number ] [ "/" number ]
//
// (where "*" or "?" may stand in for the first number) and returns the bits it
// selects within the bounds r, or an error if the expression is malformed or
// out of bounds. A star expression additionally sets starBit unless a step
// greater than one is given.
func getRange(expr string, r bounds) (uint64, error) {
	stepParts := strings.Split(expr, "/")
	rangeParts := strings.Split(stepParts[0], "-")

	var (
		lo, hi uint
		star   uint64
		err    error
	)
	if head := rangeParts[0]; head == "*" || head == "?" {
		lo, hi, star = r.min, r.max, starBit
	} else {
		if lo, err = parseIntOrName(head, r.names); err != nil {
			return 0, err
		}
		switch len(rangeParts) {
		case 1:
			hi = lo
		case 2:
			if hi, err = parseIntOrName(rangeParts[1], r.names); err != nil {
				return 0, err
			}
		default:
			return 0, fmt.Errorf("too many hyphens: %s", expr)
		}
	}

	if len(stepParts) > 2 {
		return 0, fmt.Errorf("too many slashes: %s", expr)
	}
	stride := uint(1)
	if len(stepParts) == 2 {
		if stride, err = mustParseInt(stepParts[1]); err != nil {
			return 0, err
		}
		// Special handling: "N/step" means "N-max/step".
		if len(rangeParts) == 1 {
			hi = r.max
		}
		// A stepped star no longer covers every value.
		if stride > 1 {
			star = 0
		}
	}

	switch {
	case lo < r.min:
		return 0, fmt.Errorf("beginning of range (%d) below minimum (%d): %s", lo, r.min, expr)
	case hi > r.max:
		return 0, fmt.Errorf("end of range (%d) above maximum (%d): %s", hi, r.max, expr)
	case lo > hi:
		return 0, fmt.Errorf("beginning of range (%d) beyond end of range (%d): %s", lo, hi, expr)
	case stride == 0:
		return 0, fmt.Errorf("step of range should be a positive number: %s", expr)
	}
	return getBits(lo, hi, stride) | star, nil
}
// parseIntOrName resolves expr to an unsigned integer. When a names table is
// supplied, expr is first looked up in it case-insensitively (the table keys
// are expected in lower case); anything not found there is parsed as a plain
// number.
func parseIntOrName(expr string, names map[string]uint) (uint, error) {
	if names == nil {
		return mustParseInt(expr)
	}
	if value, found := names[strings.ToLower(expr)]; found {
		return value, nil
	}
	return mustParseInt(expr)
}
// mustParseInt converts expr to a non-negative integer. Despite the name it
// does not panic: an unparsable or negative value is reported as an error.
func mustParseInt(expr string) (uint, error) {
	num, err := strconv.Atoi(expr)
	switch {
	case err != nil:
		return 0, fmt.Errorf("failed to parse int from %s: %s", expr, err)
	case num < 0:
		return 0, fmt.Errorf("negative number (%d) not allowed: %s", num, expr)
	}
	return uint(num), nil
}
// getBits returns a bit set with every step-th bit set, starting at min and
// not going past max (both inclusive). step must be non-zero.
func getBits(min, max, step uint) uint64 {
	if step != 1 {
		// General case: walk the range and set each selected bit.
		var mask uint64
		for pos := min; pos <= max; pos += step {
			mask |= 1 << pos
		}
		return mask
	}
	// Contiguous range: intersect "everything up to max" with
	// "everything from min upwards".
	var (
		upToMax uint64 = ^(math.MaxUint64 << (max + 1))
		fromMin uint64 = math.MaxUint64 << min
	)
	return upToMax & fromMin
}
// bounds describes the values a cron field accepts: an inclusive numeric range
// and, optionally, a table of names that may be used in place of numbers.
type bounds struct {
	min   uint            // smallest accepted value
	max   uint            // largest accepted value
	names map[string]uint // optional name-to-value aliases; nil when the field has none
}
// Accepted values for each of the six cron fields. Months and days of the week
// may also be written using their three-letter English abbreviations.
var (
	seconds = bounds{min: 0, max: 59}
	minutes = bounds{min: 0, max: 59}
	hours   = bounds{min: 0, max: 23}
	dom     = bounds{min: 1, max: 31}
	months  = bounds{
		min: 1,
		max: 12,
		names: map[string]uint{
			"jan": 1, "feb": 2, "mar": 3, "apr": 4,
			"may": 5, "jun": 6, "jul": 7, "aug": 8,
			"sep": 9, "oct": 10, "nov": 11, "dec": 12,
		},
	}
	dow = bounds{
		min: 0,
		max: 6,
		names: map[string]uint{
			"sun": 0, "mon": 1, "tue": 2, "wed": 3,
			"thu": 4, "fri": 5, "sat": 6,
		},
	}
)
// starBit is the top bit of a field's bit set. It is set when the field was
// written with a star, and sits above every value bit a field can use.
const starBit = 1 << 63
// dayMatches reports whether t satisfies the schedule's day-of-month and
// day-of-week restrictions. When either field was given as a star, both
// restrictions must hold; when both are explicit, matching either one is
// enough.
func dayMatches(dayOfMonth, dayOfWeek uint64, t time.Time) bool {
	domHit := dayOfMonth&(1<<uint(t.Day())) != 0
	dowHit := dayOfWeek&(1<<uint(t.Weekday())) != 0
	if (dayOfMonth|dayOfWeek)&starBit != 0 {
		return domHit && dowHit
	}
	return domHit || dowHit
}
+211 -1
View File
@@ -2008,5 +2008,215 @@
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSMessageSchedulesSourceInvalidErr",
"code": 400,
"error_code": 10203,
"description": "message schedules source is invalid",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSConsumerInvalidResetErr",
"code": 400,
"error_code": 10204,
"description": "invalid reset: {err}",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSBatchPublishDisabledErr",
"code": 400,
"error_code": 10205,
"description": "batch publish is disabled",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSBatchPublishInvalidPatternErr",
"code": 400,
"error_code": 10206,
"description": "batch publish pattern is invalid",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSBatchPublishInvalidBatchIDErr",
"code": 400,
"error_code": 10207,
"description": "batch publish ID is invalid",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSBatchPublishUnknownBatchIDErr",
"code": 400,
"error_code": 10208,
"description": "batch publish ID unknown",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSMirrorWithBatchPublishErr",
"code": 400,
"error_code": 10209,
"description": "stream mirrors can not also use batch publishing",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSAtomicPublishTooManyInflight",
"code": 429,
"error_code": 10210,
"description": "atomic publish too many inflight",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSBatchPublishTooManyInflight",
"code": 429,
"error_code": 10211,
"description": "batch publish too many inflight",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSMessageSchedulesSchedulerInvalidErr",
"code": 400,
"error_code": 10212,
"description": "message schedules invalid scheduler",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSMirrorDurableConsumerCfgInvalid",
"code": 400,
"error_code": 10213,
"description": "stream mirror consumer config is invalid",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSMirrorConsumerRequiresAckFCErr",
"code": 400,
"error_code": 10214,
"description": "stream mirror consumer requires flow control ack policy",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSSourceDurableConsumerCfgInvalid",
"code": 400,
"error_code": 10215,
"description": "stream source consumer config is invalid",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSSourceDurableConsumerDuplicateDetected",
"code": 400,
"error_code": 10216,
"description": "duplicate stream source consumer detected",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSSourceConsumerRequiresAckFCErr",
"code": 400,
"error_code": 10217,
"description": "stream source consumer requires flow control ack policy",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSConsumerAckFCRequiresPushErr",
"code": 400,
"error_code": 10218,
"description": "flow control ack policy requires a push based consumer",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSConsumerAckFCRequiresFCErr",
"code": 400,
"error_code": 10219,
"description": "flow control ack policy requires flow control",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSConsumerAckFCRequiresMaxAckPendingErr",
"code": 400,
"error_code": 10220,
"description": "flow control ack policy requires max ack pending",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSConsumerAckFCRequiresNoAckWaitErr",
"code": 400,
"error_code": 10221,
"description": "flow control ack policy requires unset ack wait",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSConsumerAckFCRequiresNoMaxDeliverErr",
"code": 400,
"error_code": 10222,
"description": "flow control ack policy requires unset max deliver",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSMessageSchedulesTimeZoneInvalidErr",
"code": 400,
"error_code": 10223,
"description": "message schedules time zone is invalid",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
}
]
]
+21 -11
View File
@@ -247,14 +247,15 @@ type ServerCapability uint64
// ServerInfo identifies remote servers.
type ServerInfo struct {
Name string `json:"name"`
Host string `json:"host"`
ID string `json:"id"`
Cluster string `json:"cluster,omitempty"`
Domain string `json:"domain,omitempty"`
Version string `json:"ver"`
Tags []string `json:"tags,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
Name string `json:"name"`
Host string `json:"host"`
ID string `json:"id"`
Cluster string `json:"cluster,omitempty"`
Domain string `json:"domain,omitempty"`
Version string `json:"ver"`
Tags []string `json:"tags,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
FeatureFlags map[string]bool `json:"feature_flags,omitempty"`
// Whether JetStream is enabled (deprecated in favor of the `ServerCapability`).
JetStream bool `json:"jetstream"`
// Generic capability flags
@@ -328,6 +329,7 @@ type ClientInfo struct {
ClientType string `json:"client_type,omitempty"`
MQTTClient string `json:"client_id,omitempty"` // This is the MQTT client ID
Nonce string `json:"nonce,omitempty"`
Reply string `json:"reply,omitempty"` // Original reply subject after a service import (only when needed).
}
// forAssignmentSnap returns the minimum amount of ClientInfo we need for assignment snapshots.
@@ -518,7 +520,7 @@ RESET:
// Grab tags and metadata.
opts := s.getOpts()
tags, metadata := opts.Tags, opts.Metadata
tags, metadata, featureFlags := opts.Tags, opts.Metadata, opts.getMergedFeatureFlags()
for s.eventsRunning() {
select {
@@ -536,6 +538,7 @@ RESET:
si.Time = time.Now().UTC()
si.Tags = tags
si.Metadata = metadata
si.FeatureFlags = featureFlags
si.Flags = 0
if js {
// New capability based flags.
@@ -1052,8 +1055,15 @@ func (s *Server) sendStatsz(subj string) {
Size: mg.ClusterSize(),
}
}
if ipq := s.jsAPIRoutedReqs; ipq != nil && jStat.Meta != nil {
jStat.Meta.Pending = ipq.len()
if jStat.Meta != nil {
if ipq := s.jsAPIRoutedReqs; ipq != nil {
jStat.Meta.PendingRequests = ipq.len()
}
if ipq := s.jsAPIRoutedInfoReqs; ipq != nil {
jStat.Meta.PendingInfos = ipq.len()
}
jStat.Meta.Pending = jStat.Meta.PendingRequests + jStat.Meta.PendingInfos
jStat.Meta.Snapshot = s.metaClusterSnapshotStats(js, mg)
}
}
jStat.Limits = &s.getOpts().JetStreamLimits
+130
View File
@@ -0,0 +1,130 @@
// Copyright 2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"maps"
"slices"
"strings"
)
const (
FeatureFlagJsAckFormatV2 = "js_ack_fc_v2"
FeatureFlagJsRaftDeleteRange = "js_raft_delete_range"
)
var featureFlags = map[string]bool{
// Use v2 format for `$JS.ACK.>` and `$JS.FC.>`.
// - Introduced: 2.14.0, both v1 and v2 supported, only using v1.
// - Enabled: TBD, both supported, v2 becomes the default.
//
// - v1: $JS.ACK.<stream name>.<consumer name>.<num delivered>.<stream sequence>.<consumer sequence>.<timestamp>.<num pending>
// - v2: $JS.ACK.<domain>.<account hash>.<stream name>.<consumer name>.<num delivered>.<stream sequence>.<consumer sequence>.<timestamp>.<num pending>
// See also: https://github.com/nats-io/nats-architecture-and-design/blob/main/adr/ADR-15.md#jsack
FeatureFlagJsAckFormatV2: false,
// Propose delete range gaps as a single `deleteRangeOp` Raft append entry
// instead of one entry per deleted sequence. Dramatically reduces Raft cost
// on mirrors whose origin has a large number of interior deletes.
// - Introduced: 2.14.0, apply-side always supports receiving `deleteRangeOp`.
// - Enabled: TBD, once all supported versions carry the apply-side.
//
// WARNING: Only enable once every peer in the cluster is on a version that
// supports receiving `deleteRangeOp`. Older peers panic on apply of an
// unknown stream entry operation.
FeatureFlagJsRaftDeleteRange: false,
}
// getFeatureFlag reports whether the feature flag k is enabled for these options.
// A flag that is missing from the built-in featureFlags table is unsupported and
// always reports false, even if the user configured it. For supported flags the
// user's value in o.FeatureFlags wins over the built-in default.
// The *Options returned by Server.getOpts() is treated as immutable (mutations go
// through setOpts), so the map is read here without taking a lock.
func (o *Options) getFeatureFlag(k string) bool {
	fallback, supported := featureFlags[k]
	if !supported {
		// Unknown flags can never be switched on.
		return false
	}
	override, overridden := o.FeatureFlags[k]
	if !overridden {
		return fallback
	}
	return override
}
// getMergedFeatureFlags returns a fresh map holding every supported feature flag.
// Each entry carries the user's value from o.FeatureFlags when one is configured,
// and the built-in default otherwise. User entries for flags that are not in the
// featureFlags table are left out.
func (o *Options) getMergedFeatureFlags() map[string]bool {
	merged := make(map[string]bool, len(featureFlags))
	for name, def := range featureFlags {
		if override, set := o.FeatureFlags[name]; set {
			merged[name] = override
		} else {
			merged[name] = def
		}
	}
	return merged
}
// printFeatureFlags logs the feature flags the user configured, for use on server startup.
// Nothing is logged when o.FeatureFlags is empty. Otherwise the flags are visited in
// sorted order and split into two comma-separated lists:
//   - supported flags, each annotated with how the user's value relates to the
//     built-in default ("enabled", "opt-out", "opt-in" or "disabled");
//   - unsupported flags, i.e. names missing from the featureFlags table.
func (s *Server) printFeatureFlags(o *Options) {
	if len(o.FeatureFlags) == 0 {
		return
	}

	var configured, unsupported strings.Builder
	// appendItem adds item to b, separating it from earlier content with ", ".
	appendItem := func(b *strings.Builder, item string) {
		if b.Len() > 0 {
			b.WriteString(", ")
		}
		b.WriteString(item)
	}

	for _, name := range slices.Sorted(maps.Keys(o.FeatureFlags)) {
		def, known := featureFlags[name]
		if !known {
			appendItem(&unsupported, name)
			continue
		}
		// Describe the user's choice relative to the default.
		var state string
		switch val := o.FeatureFlags[name]; {
		case def && val:
			state = "enabled"
		case def:
			state = "opt-out"
		case val:
			state = "opt-in"
		default:
			state = "disabled"
		}
		appendItem(&configured, name+" ("+state+")")
	}

	if configured.Len() == 0 {
		configured.WriteString("none")
	}
	s.Noticef(" Feature flags:")
	s.Noticef(" Configured: %s", configured.String())
	if unsupported.Len() > 0 {
		s.Noticef(" Unsupported: %s", unsupported.String())
	}
}
File diff suppressed because it is too large Load Diff
+14 -6
View File
@@ -1156,9 +1156,7 @@ func (c *client) processGatewayInfo(info *Info) {
// defensive code above that if we did not register this connection
// because we already have an outbound for this name, then
// close this connection (and make sure it does not try to reconnect)
c.mu.Lock()
c.flags.set(noReconnect)
c.mu.Unlock()
c.setNoReconnect()
c.closeConnection(WrongGateway)
return
}
@@ -1981,7 +1979,7 @@ func (c *client) processGatewayRUnsub(arg []byte) error {
return nil
} else {
// Plain sub, assume optimistic sends, create entry.
e = &outsie{ni: make(map[string]struct{}), sl: NewSublistWithCache()}
e = &outsie{ni: make(map[string]struct{}), sl: NewSublistForServer(c.srv)}
newe = true
}
// This is when a sub or queue sub is supposed to be in
@@ -2090,7 +2088,7 @@ func (c *client) processGatewayRSub(arg []byte) error {
} else if queue == nil {
return nil
} else {
e = &outsie{ni: make(map[string]struct{}), sl: NewSublistWithCache()}
e = &outsie{ni: make(map[string]struct{}), sl: NewSublistForServer(c.srv)}
newe = true
useSl = true
}
@@ -2952,6 +2950,16 @@ func getSubjectFromGWRoutedReply(reply []byte, isOldPrefix bool) []byte {
return reply[gwSubjectOffset:]
}
// getGWRoutedSubjectOrSelf unwraps a gateway routed reply subject.
// When subject carries a gateway routed prefix (old or new style), it returns
// the embedded subject and true. Otherwise it returns subject unchanged and false.
func getGWRoutedSubjectOrSelf(subject []byte) ([]byte, bool) {
	routed, oldPrefix := isGWRoutedSubjectAndIsOldPrefix(subject)
	if !routed {
		return subject, false
	}
	return getSubjectFromGWRoutedReply(subject, oldPrefix), true
}
// This should be invoked only from processInboundGatewayMsg() or
// processInboundRoutedMsg() and is checking if the subject
// (c.pa.subject) has the _GR_ prefix. If so, this is processed
@@ -3201,7 +3209,7 @@ func (c *client) gatewayAllSubsReceiveStart(info *Info) {
e.mode = Transitioning
e.Unlock()
} else {
e := &outsie{sl: NewSublistWithCache()}
e := &outsie{sl: NewSublistForServer(c.srv)}
e.mode = Transitioning
c.mu.Lock()
c.gw.outsim.Store(account, e)
+60
View File
@@ -170,6 +170,66 @@ func (s *GenericSublist[T]) NumInterest(subject string) (np int) {
return
}
// MatchesFullWildcard reports whether the sublist holds interest on the
// top-level full wildcard ">". A nil sublist has no interest.
func (s *GenericSublist[T]) MatchesFullWildcard() bool {
	if s != nil {
		s.RLock()
		defer s.RUnlock()
		return s.root.fwc != nil
	}
	return false
}
// MatchesSingleFilter returns the filter subject and true when the sublist
// contains exactly one unique subject. It returns false for a nil sublist or
// when no single subject can be determined.
func (s *GenericSublist[T]) MatchesSingleFilter() (string, bool) {
	if s != nil {
		s.RLock()
		defer s.RUnlock()
		return singleFilter(s.root, _EMPTY_)
	}
	return _EMPTY_, false
}
// singleFilter walks down from level l looking for the one subject stored below it.
// It succeeds only while each visited level has exactly one branch (a literal node,
// the partial wildcard or the full wildcard). The first subject found on a node is
// taken as the filter; the walk then succeeds only if nothing further hangs below
// that node. filter is the subject found so far and is returned as-is when l is nil.
func singleFilter[T comparable](l *level[T], filter string) (string, bool) {
	for l != nil {
		if len(l.nodes) > 1 {
			return _EMPTY_, false
		}

		// Find the single branch leaving this level, counting all candidates.
		var only *node[T]
		count := 0
		for _, n := range l.nodes {
			only = n
			count++
		}
		if l.fwc != nil {
			only = l.fwc
			count++
		}
		if l.pwc != nil {
			only = l.pwc
			count++
		}
		if count != 1 {
			return _EMPTY_, false
		}

		// Pick up a subject registered on this node, if there is one.
		for _, subj := range only.subs {
			filter = subj
			break
		}

		switch {
		case only.next == nil:
			// Leaf reached: success only if a subject was found.
			return filter, filter != _EMPTY_
		case filter != _EMPTY_:
			// A subject was found here; anything deeper means more than one subject.
			if only.next.numNodes() > 0 {
				return _EMPTY_, false
			}
			return filter, true
		}
		// No subject yet, keep descending.
		l = only.next
	}
	return filter, filter != _EMPTY_
}
func (s *GenericSublist[T]) match(subject string, cb func(T), doLock bool) {
tsa := [32]string{}
tokens := tsa[:0]
+63 -423
View File
@@ -25,13 +25,13 @@ import (
"os"
"path/filepath"
"runtime/debug"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/minio/highwayhash"
"github.com/nats-io/nats-server/v2/server/gsl"
"github.com/nats-io/nats-server/v2/server/sysmem"
"github.com/nats-io/nats-server/v2/server/tpm"
"github.com/nats-io/nkeys"
@@ -103,22 +103,26 @@ type JetStreamAPIStats struct {
// This is for internal accounting for JetStream for this server.
type jetStream struct {
// These are here first because of atomics on 32bit systems.
apiInflight int64
apiTotal int64
apiErrors int64
memReserved int64
storeReserved int64
memUsed int64
storeUsed int64
queueLimit int64
clustered int32
mu sync.RWMutex
srv *Server
config JetStreamConfig
cluster *jetStreamCluster
accounts map[string]*jsAccount
apiSubs *Sublist
started time.Time
apiInflight int64
apiTotal int64
apiErrors int64
memMax int64
memReserved int64 // Requires JS lock to be held.
memUsed int64
storeMax int64
storeReserved int64 // Requires JS lock to be held.
storeUsed int64
queueLimit int64
infoQueueLimit int64
clustered int32
mu sync.RWMutex
srv *Server
config JetStreamConfig
cluster *jetStreamCluster
accounts map[string]*jsAccount
apiSubs *Sublist
infoSubs *gsl.SimpleSublist // Subjects for info-specific queue.
started time.Time
// System level request to purge a stream move
accountPurge *subscription
@@ -150,14 +154,12 @@ type jsaStorage struct {
// an internal sub for a stream, so we will direct link to the stream
// and walk backwards as needed vs multiple hash lookups and locks, etc.
type jsAccount struct {
mu sync.RWMutex
js *jetStream
account *Account
storeDir string
inflight sync.Map
streams map[string]*stream
templates map[string]*streamTemplate // Deprecated: stream templates are deprecated and will be removed in a future version.
store TemplateStore // Deprecated: stream templates are deprecated and will be removed in a future version.
mu sync.RWMutex
js *jetStream
account *Account
storeDir string
inflight sync.Map
streams map[string]*stream
// From server
sendq *ipQueue[*pubMsg]
@@ -415,15 +417,19 @@ func (s *Server) initJetStreamEncryption() (err error) {
// enableJetStream will start up the JetStream subsystem.
func (s *Server) enableJetStream(cfg JetStreamConfig) error {
js := &jetStream{srv: s, config: cfg, accounts: make(map[string]*jsAccount), apiSubs: NewSublistNoCache()}
js := &jetStream{srv: s, config: cfg, accounts: make(map[string]*jsAccount), apiSubs: NewSublistNoCache(), infoSubs: gsl.NewSimpleSublist()}
s.gcbMu.Lock()
if s.gcbOutMax = s.getOpts().JetStreamMaxCatchup; s.gcbOutMax == 0 {
s.gcbOutMax = defaultMaxTotalCatchupOutBytes
}
s.gcbMu.Unlock()
atomic.StoreInt64(&js.memMax, cfg.MaxMemory)
atomic.StoreInt64(&js.storeMax, cfg.MaxStore)
// TODO: Not currently reloadable.
atomic.StoreInt64(&js.queueLimit, s.getOpts().JetStreamRequestQueueLimit)
atomic.StoreInt64(&js.infoQueueLimit, s.getOpts().JetStreamInfoQueueLimit)
s.js.Store(js)
@@ -1058,8 +1064,10 @@ func (s *Server) shutdownJetStream() {
// JetStreamConfig returns a pointer to a copy of the current JetStream
// configuration, or nil when getJetStream reports no JetStream instance.
// The copy is taken while holding the JetStream read lock.
func (s *Server) JetStreamConfig() *JetStreamConfig {
	js := s.getJetStream()
	if js == nil {
		return nil
	}
	js.mu.RLock()
	cfg := js.config
	js.mu.RUnlock()
	return &cfg
}
@@ -1219,54 +1227,6 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits, tq c
s.Debugf("Recovering JetStream state for account %q", a.Name)
}
// Check templates first since message sets will need proper ownership.
// FIXME(dlc) - Make this consistent.
tdir := filepath.Join(jsa.storeDir, tmplsDir)
if stat, err := os.Stat(tdir); err == nil && stat.IsDir() {
key := sha256.Sum256([]byte("templates"))
hh, err := highwayhash.NewDigest64(key[:])
if err != nil {
return err
}
fis, _ := os.ReadDir(tdir)
for _, fi := range fis {
metafile := filepath.Join(tdir, fi.Name(), JetStreamMetaFile)
metasum := filepath.Join(tdir, fi.Name(), JetStreamMetaFileSum)
buf, err := os.ReadFile(metafile)
if err != nil {
s.Warnf(" Error reading StreamTemplate metafile %q: %v", metasum, err)
continue
}
if _, err := os.Stat(metasum); os.IsNotExist(err) {
s.Warnf(" Missing StreamTemplate checksum for %q", metasum)
continue
}
sum, err := os.ReadFile(metasum)
if err != nil {
s.Warnf(" Error reading StreamTemplate checksum %q: %v", metasum, err)
continue
}
hh.Reset()
hh.Write(buf)
var hb [highwayhash.Size64]byte
checksum := hex.EncodeToString(hh.Sum(hb[:0]))
if checksum != string(sum) {
s.Warnf(" StreamTemplate checksums do not match %q vs %q", sum, checksum)
continue
}
var cfg StreamTemplateConfig
if err := json.Unmarshal(buf, &cfg); err != nil {
s.Warnf(" Error unmarshalling StreamTemplate metafile: %v", err)
continue
}
cfg.Config.Name = _EMPTY_
if _, err := a.addStreamTemplate(&cfg); err != nil {
s.Warnf(" Error recreating StreamTemplate %q: %v", cfg.Name, err)
continue
}
}
}
// Remember if we should be encrypted and what cipher we think we should use.
encrypted := s.getOpts().JetStreamKey != _EMPTY_
sc := s.getOpts().JetStreamCipher
@@ -1510,15 +1470,6 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits, tq c
return nil
}
if cfg.Template != _EMPTY_ {
jsa.mu.Lock()
err := jsa.addStreamNameToTemplate(cfg.Template, cfg.Name)
jsa.mu.Unlock()
if err != nil {
s.Warnf(" Error adding stream %q to template %q: %v", cfg.Name, cfg.Template, err)
}
}
// We had a bug that set a default de dupe window on mirror, despite that being not a valid config
fixCfgMirrorWithDedupWindow(&cfg.StreamConfig)
@@ -1587,6 +1538,7 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits, tq c
batchId string
batchSeq uint64
commit bool
commitEob bool
batchStoreDir string
store StreamStore
state StreamState
@@ -1604,19 +1556,30 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits, tq c
}
// We've observed a partial batch write. Write the remainder of the batch.
batchSeq++
_, batchStoreDir = getBatchStoreDir(mset, batchId)
_, batchStoreDir = getBatchStoreDir(jsa.storeDir, cfg.Name, batchId)
if _, err = os.Stat(batchStoreDir); err != nil {
s.Errorf(" Failed restoring partial batch write for stream '%s > %s' at sequence %d: %v",
mset.accName(), mset.name(), batchSeq, err)
goto SKIP
}
store, err = newBatchStore(mset, batchId)
store, err = newBatchStore(mset, batchId, cfg.Replicas, cfg.Storage, jsa.storeDir, cfg.Name)
if err != nil {
s.Errorf(" Failed restoring partial batch write for stream '%s > %s' at sequence %d: %v",
mset.accName(), mset.name(), batchSeq, err)
goto SKIP
}
store.FastState(&state)
sm, err = store.LoadMsg(state.LastSeq, &smv)
if err != nil || sm == nil {
s.Errorf(" Failed restoring partial batch write for stream '%s > %s' at sequence %d: last msg not found %d",
mset.accName(), mset.name(), batchSeq, state.LastSeq)
goto SKIP
}
commitEob = bytes.Equal(sliceHeader(JSBatchCommit, sm.hdr), []byte("eob"))
// If the commit ends with an "End Of Batch" message, we don't store this.
if commitEob {
state.LastSeq--
}
s.Noticef(" Restoring partial batch write for stream '%s > %s' (seq %d to %d)",
mset.accName(), mset.name(), batchSeq, state.LastSeq)
// Loop through items that weren't persisted yet.
@@ -1627,7 +1590,12 @@ func (a *Account) EnableJetStream(limits map[string]JetStreamAccountLimits, tq c
mset.accName(), mset.name(), seq, err)
break
}
mset.processJetStreamMsg(sm.subj, _EMPTY_, sm.hdr, sm.msg, 0, 0, nil, false, true)
hdr := sm.hdr
// If committed by EOB, the last message must get the normal commit header.
if commitEob && seq == state.LastSeq {
hdr = genHeader(hdr, JSBatchCommit, "1")
}
mset.processJetStreamMsg(sm.subj, _EMPTY_, hdr, sm.msg, 0, 0, nil, false, true)
}
store.Delete(true)
SKIP:
@@ -2342,14 +2310,14 @@ func (jsa *jsAccount) sendClusterUsageUpdate() {
func (js *jetStream) wouldExceedLimits(storeType StorageType, sz int) bool {
var (
total *int64
max int64
max *int64
)
if storeType == MemoryStorage {
total, max = &js.memUsed, js.config.MaxMemory
total, max = &js.memUsed, &js.memMax
} else {
total, max = &js.storeUsed, js.config.MaxStore
total, max = &js.storeUsed, &js.storeMax
}
return (atomic.LoadInt64(total) + int64(sz)) > max
return (atomic.LoadInt64(total) + int64(sz)) > atomic.LoadInt64(max)
}
func (js *jetStream) limitsExceeded(storeType StorageType) bool {
@@ -2519,7 +2487,6 @@ func (jsa *jsAccount) acc() *Account {
// Delete the JetStream resources.
func (jsa *jsAccount) delete() {
var streams []*stream
var ts []string
jsa.mu.Lock()
// The update timer and subs need to be protected by usageMu lock
@@ -2538,20 +2505,11 @@ func (jsa *jsAccount) delete() {
for _, ms := range jsa.streams {
streams = append(streams, ms)
}
acc := jsa.account
for _, t := range jsa.templates {
ts = append(ts, t.Name)
}
jsa.templates = nil
jsa.mu.Unlock()
for _, mset := range streams {
mset.stop(false, false)
}
for _, t := range ts {
acc.deleteStreamTemplate(t)
}
}
// Lookup the jetstream account for a given account.
@@ -2763,325 +2721,6 @@ func (a *Account) checkForJetStream() (*Server, *jsAccount, error) {
return s, jsa, nil
}
// StreamTemplateConfig allows a configuration to auto-create streams based on this template when a message
// is received that matches. Each new stream will use the config as the template config to create them.
// Deprecated: stream templates are deprecated and will be removed in a future version.
type StreamTemplateConfig struct {
	// Name identifies the template; addStreamTemplate rejects names longer than JSMaxNameLen.
	Name string `json:"name"`
	// Config is the stream configuration copied for each created stream.
	// Its Name must be empty when the template is added.
	Config *StreamConfig `json:"config"`
	// MaxStreams caps how many streams the template may create.
	MaxStreams uint32 `json:"max_streams"`
}
// StreamTemplateInfo pairs a template configuration with a list of streams
// (presumably the names of the streams created from it; confirm against the API handlers).
// Deprecated: stream templates are deprecated and will be removed in a future version.
type StreamTemplateInfo struct {
	Config  *StreamTemplateConfig `json:"config"`
	Streams []string              `json:"streams"`
}
// streamTemplate is the runtime state of a stream template registered on a JetStream account.
// Deprecated: stream templates are deprecated and will be removed in a future version.
type streamTemplate struct {
	// mu guards tc and streams.
	mu sync.Mutex
	// tc is the internal JetStream client that holds the template's subscriptions;
	// it is set to nil once the template is deleted.
	tc *client
	// jsa is the JetStream account this template belongs to.
	jsa *jsAccount
	*StreamTemplateConfig
	// streams lists the names of the streams attributed to this template.
	streams []string
}
// deepCopy returns a copy of the template config that owns its own StreamConfig
// value, so changes to the copy's Config fields do not write through to t.Config.
// The StreamConfig itself is copied by plain assignment.
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (t *StreamTemplateConfig) deepCopy() *StreamTemplateConfig {
	clone := *t
	streamCfg := *t.Config
	clone.Config = &streamCfg
	return &clone
}
// addStreamTemplate will add a stream template to this account that allows auto-creation of streams.
// It validates the template (config present, empty stream name, name length, stream config check),
// registers it on the JetStream account, creates the internal subscriptions that trap matching
// messages, and finally persists it in the account's template store.
// It returns an error when JetStream is not available for the account, when validation fails,
// when a template with the same name already exists, or when subscription setup or persisting fails.
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (a *Account) addStreamTemplate(tc *StreamTemplateConfig) (*streamTemplate, error) {
	s, jsa, err := a.checkForJetStream()
	if err != nil {
		return nil, err
	}
	// A missing template or stream config would otherwise panic on the dereferences below.
	if tc == nil || tc.Config == nil {
		return nil, fmt.Errorf("template config is required")
	}
	if tc.Config.Name != "" {
		return nil, fmt.Errorf("template config name should be empty")
	}
	if len(tc.Name) > JSMaxNameLen {
		return nil, fmt.Errorf("template name is too long, maximum allowed is %d", JSMaxNameLen)
	}
	// FIXME(dlc) - Hacky
	// Validate against a copy with a placeholder name so the caller's config is left untouched.
	tcopy := tc.deepCopy()
	tcopy.Config.Name = "_"
	cfg, apiErr := s.checkStreamCfg(tcopy.Config, a, false)
	if apiErr != nil {
		return nil, apiErr
	}
	tcopy.Config = &cfg
	t := &streamTemplate{
		StreamTemplateConfig: tcopy,
		tc:                   s.createInternalJetStreamClient(),
		jsa:                  jsa,
	}
	t.tc.registerWithAccount(a)
	jsa.mu.Lock()
	if jsa.templates == nil {
		jsa.templates = make(map[string]*streamTemplate)
		// Create the appropriate store
		if cfg.Storage == FileStorage {
			jsa.store = newTemplateFileStore(jsa.storeDir)
		} else {
			jsa.store = newTemplateMemStore()
		}
	} else if _, ok := jsa.templates[tcopy.Name]; ok {
		jsa.mu.Unlock()
		return nil, fmt.Errorf("template with name %q already exists", tcopy.Name)
	}
	jsa.templates[tcopy.Name] = t
	jsa.mu.Unlock()
	// FIXME(dlc) - we can not overlap subjects between templates. Need to have test.
	// Setup the internal subscriptions to trap the messages.
	if err := t.createTemplateSubscriptions(); err != nil {
		// Undo the registration and close the internal client, as done below when
		// persisting fails. This is a harmless no-op if the template was already removed.
		t.delete()
		return nil, err
	}
	if err := jsa.store.Store(t); err != nil {
		t.delete()
		return nil, err
	}
	return t, nil
}
// createTemplateSubscriptions subscribes the template's internal client to every
// subject in the template config, routing matches to processInboundTemplateMsg.
// Subscription ids are assigned sequentially starting at 1. If a subscription
// fails, the template is deleted from the account and the error is returned.
// It fails up front when the template is nil, has no client, or the server has
// no events (system account) enabled.
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (t *streamTemplate) createTemplateSubscriptions() error {
	switch {
	case t == nil:
		return fmt.Errorf("no template")
	case t.tc == nil:
		return fmt.Errorf("template not enabled")
	}
	ic := t.tc
	if !ic.srv.EventsEnabled() {
		return ErrNoSysAccount
	}
	for i, subj := range t.Config.Subjects {
		sid := []byte(strconv.Itoa(i + 1))
		_, err := ic.processSub([]byte(subj), nil, sid, t.processInboundTemplateMsg, false)
		if err != nil {
			ic.acc.deleteStreamTemplate(t.Name)
			return err
		}
	}
	return nil
}
// processInboundTemplateMsg handles a message received on one of the template's
// subjects. If no stream named after the subject (with its canonical name) is
// registered yet, and the template is below MaxStreams, it creates a stream from
// the template config bound to that literal subject and hands the message to it.
// Failures are reported through rate-limited warnings.
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (t *streamTemplate) processInboundTemplateMsg(_ *subscription, pc *client, acc *Account, subject, reply string, msg []byte) {
	if t == nil || t.jsa == nil {
		return
	}
	owner := t.jsa
	streamName := canonicalName(subject)

	// Nothing to do when a stream with this name is already registered.
	owner.mu.Lock()
	_, exists := owner.streams[streamName]
	owner.mu.Unlock()
	if exists {
		return
	}

	// Snapshot what we need from the template and reserve a slot if there is room.
	t.mu.Lock()
	tc := t.tc
	streamCfg := *t.Config
	streamCfg.Template = t.Name
	full := len(t.streams) >= int(t.MaxStreams)
	if !full {
		t.streams = append(t.streams, streamName)
	}
	t.mu.Unlock()

	if full {
		tc.RateLimitWarnf("JetStream could not create stream for account %q on subject %q, at limit", acc.Name, subject)
		return
	}

	// Derive the stream config from the template, using only the literal subject.
	streamCfg.Name = streamName
	streamCfg.Subjects = []string{subject}
	mset, err := acc.addStream(&streamCfg)
	if err != nil {
		// Drop the reserved name again by re-validating the template's stream list.
		acc.validateStreams(t)
		tc.RateLimitWarnf("JetStream could not create stream for account %q on subject %q: %v", acc.Name, subject, err)
		return
	}

	// Deliver the triggering message straight to the new stream.
	mset.processInboundJetStreamMsg(nil, pc, acc, subject, reply, msg)
}
// lookupStreamTemplate returns the stream template registered under name.
// It returns the checkForJetStream error when JetStream is not usable for the
// account, and a "template not found" error when no such template exists.
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (a *Account) lookupStreamTemplate(name string) (*streamTemplate, error) {
	_, jsa, err := a.checkForJetStream()
	if err != nil {
		return nil, err
	}
	jsa.mu.Lock()
	defer jsa.mu.Unlock()
	// Indexing a nil map is safe and simply reports a miss.
	if t, found := jsa.templates[name]; found {
		return t, nil
	}
	return nil, fmt.Errorf("template not found")
}
// validateStreams prunes the template's stream list, keeping only the names
// that lookupStream can still resolve on this account. The template lock is
// held for the whole pass.
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (a *Account) validateStreams(t *streamTemplate) {
	t.mu.Lock()
	defer t.mu.Unlock()

	var valid []string
	for _, name := range t.streams {
		if _, err := a.lookupStream(name); err != nil {
			continue
		}
		valid = append(valid, name)
	}
	t.streams = valid
}
// delete removes this template from its JetStream account, deletes it from the
// account's template store, and deletes every stream listed on the template that
// can still be looked up. The template's internal client is closed when the
// function returns, on success and on error paths alike.
// It returns an error for a nil template, when the template has no account, when
// the template is not registered on the account, or when the store deletion
// fails. Otherwise it returns the last error seen while deleting streams, or nil.
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (t *streamTemplate) delete() error {
	if t == nil {
		return fmt.Errorf("nil stream template")
	}
	t.mu.Lock()
	jsa := t.jsa
	c := t.tc
	// Detach the client from the template; it is closed by the deferred call below.
	t.tc = nil
	// Deferred so the client is closed only after the cleanup that follows,
	// and without holding the template lock.
	defer func() {
		if c != nil {
			c.closeConnection(ClientClosed)
		}
	}()
	t.mu.Unlock()
	if jsa == nil {
		return NewJSNotEnabledForAccountError()
	}
	jsa.mu.Lock()
	if jsa.templates == nil {
		jsa.mu.Unlock()
		return fmt.Errorf("template not found")
	}
	if _, ok := jsa.templates[t.Name]; !ok {
		jsa.mu.Unlock()
		return fmt.Errorf("template not found")
	}
	delete(jsa.templates, t.Name)
	acc := jsa.account
	jsa.mu.Unlock()
	// Remove streams associated with this template.
	// Collect them under the template lock first; they are deleted after it is released.
	var streams []*stream
	t.mu.Lock()
	for _, name := range t.streams {
		if mset, err := acc.lookupStream(name); err == nil {
			streams = append(streams, mset)
		}
	}
	t.mu.Unlock()
	if jsa.store != nil {
		if err := jsa.store.Delete(t); err != nil {
			return fmt.Errorf("error deleting template from store: %v", err)
		}
	}
	// Best effort: keep deleting the remaining streams and report the last failure.
	var lastErr error
	for _, mset := range streams {
		if err := mset.delete(); err != nil {
			lastErr = err
		}
	}
	return lastErr
}
// deleteStreamTemplate deletes the stream template registered under name.
// Any lookup failure is reported as a stream-template-not-found API error;
// otherwise the result of the template's delete is returned.
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (a *Account) deleteStreamTemplate(name string) error {
	t, err := a.lookupStreamTemplate(name)
	if err == nil {
		return t.delete()
	}
	return NewJSStreamTemplateNotFoundError()
}
// templates returns the stream templates registered on this account, in no
// particular order. It returns nil when JetStream is not usable for the account
// or when there are no templates. The returned pointers are the live templates,
// not copies.
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (a *Account) templates() []*streamTemplate {
	_, jsa, err := a.checkForJetStream()
	if err != nil {
		return nil
	}
	jsa.mu.Lock()
	defer jsa.mu.Unlock()

	var all []*streamTemplate
	for _, tmpl := range jsa.templates {
		// FIXME(dlc) - Copy?
		all = append(all, tmpl)
	}
	return all
}
// addStreamNameToTemplate records the stream name mname on the template tname.
// This is used during recovery. It returns a "template not found" error when
// the account has no template with that name.
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (jsa *jsAccount) addStreamNameToTemplate(tname, mname string) error {
	// Indexing a nil map is safe and simply reports a miss.
	tmpl, found := jsa.templates[tname]
	if !found {
		return fmt.Errorf("template not found")
	}
	tmpl.mu.Lock()
	tmpl.streams = append(tmpl.streams, mname)
	tmpl.mu.Unlock()
	return nil
}
// checkTemplateOwnership reports whether the template tname exists on this
// account and lists sname among its streams.
// jsAccount lock should be held
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (jsa *jsAccount) checkTemplateOwnership(tname, sname string) bool {
	// Indexing a nil map is safe and simply reports a miss.
	tmpl, found := jsa.templates[tname]
	if !found {
		return false
	}
	for i := range tmpl.streams {
		if tmpl.streams[i] == sname {
			return true
		}
	}
	return false
}
type Number interface {
int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 | float32 | float64
}
@@ -3107,10 +2746,11 @@ func isValidName(name string) bool {
return !strings.ContainsAny(name, " \t\r\n\f.*>")
}
// CanonicalName will replace all token separators '.' with '_'.
// This can be used when naming streams or consumers with multi-token subjects.
func canonicalName(name string) string {
return strings.ReplaceAll(name, ".", "_")
// isValidAssetName reports whether name is non-empty and free of whitespace
// (space, tab, CR, LF, form feed), the subject characters '.', '*' and '>',
// and the path separators '\' and '/'.
func isValidAssetName(name string) bool {
	return name != _EMPTY_ && !strings.ContainsAny(name, " \t\r\n\f.*>\\/")
}
// To throttle the out of resources errors.
+175 -331
View File
@@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"io"
"maps"
"os"
"path/filepath"
"runtime"
@@ -46,29 +47,6 @@ const (
// Will return JSON response.
JSApiAccountInfo = "$JS.API.INFO"
// JSApiTemplateCreate is the endpoint to create new stream templates.
// Will return JSON response.
// Deprecated: stream templates are deprecated and will be removed in a future version.
JSApiTemplateCreate = "$JS.API.STREAM.TEMPLATE.CREATE.*"
JSApiTemplateCreateT = "$JS.API.STREAM.TEMPLATE.CREATE.%s"
// JSApiTemplates is the endpoint to list all stream template names for this account.
// Will return JSON response.
// Deprecated: stream templates are deprecated and will be removed in a future version.
JSApiTemplates = "$JS.API.STREAM.TEMPLATE.NAMES"
// JSApiTemplateInfo is for obtaining general information about a named stream template.
// Will return JSON response.
// Deprecated: stream templates are deprecated and will be removed in a future version.
JSApiTemplateInfo = "$JS.API.STREAM.TEMPLATE.INFO.*"
JSApiTemplateInfoT = "$JS.API.STREAM.TEMPLATE.INFO.%s"
// JSApiTemplateDelete is the endpoint to delete stream templates.
// Will return JSON response.
// Deprecated: stream templates are deprecated and will be removed in a future version.
JSApiTemplateDelete = "$JS.API.STREAM.TEMPLATE.DELETE.*"
JSApiTemplateDeleteT = "$JS.API.STREAM.TEMPLATE.DELETE.%s"
// JSApiStreamCreate is the endpoint to create new streams.
// Will return JSON response.
JSApiStreamCreate = "$JS.API.STREAM.CREATE.*"
@@ -177,6 +155,9 @@ const (
// JSApiRequestNextT is the prefix for the request next message(s) for a consumer in worker/pull mode.
JSApiRequestNextT = "$JS.API.CONSUMER.MSG.NEXT.%s.%s"
// JSApiConsumerResetT is the prefix for resetting a given consumer to a new starting sequence.
JSApiConsumerResetT = "$JS.API.CONSUMER.RESET.%s.%s"
// JSApiConsumerUnpinT is the prefix for unpinning subscription for a given consumer.
JSApiConsumerUnpin = "$JS.API.CONSUMER.UNPIN.*.*"
JSApiConsumerUnpinT = "$JS.API.CONSUMER.UNPIN.%s.%s"
@@ -237,13 +218,15 @@ const (
// jsAckT is the template for the ack message stream coming back from a consumer
// when they ACK/NAK, etc a message.
jsAckT = "$JS.ACK.%s.%s"
jsAckTv2 = "$JS.ACK.%s.%s.%s.%s"
jsAckPre = "$JS.ACK."
jsAckPreLen = len(jsAckPre)
// jsFlowControlPre is the prefix for flow control subjects.
jsFlowControlPre = "$JS.FC."
// jsFlowControl is for FC responses.
jsFlowControl = "$JS.FC.%s.%s.*"
jsFlowControl = "$JS.FC.%s.%s.*"
jsFlowControlV2 = "$JS.FC.%s.%s.%s.%s.*"
// JSAdvisoryPrefix is a prefix for all JetStream advisories.
JSAdvisoryPrefix = "$JS.EVENT.ADVISORY"
@@ -787,50 +770,19 @@ type JSApiConsumerGetNextRequest struct {
PriorityGroup
}
// JSApiStreamTemplateCreateResponse for creating templates.
// Deprecated: stream templates are deprecated and will be removed in a future version.
type JSApiStreamTemplateCreateResponse struct {
// JSApiConsumerResetRequest is for resetting a consumer to a specific sequence.
type JSApiConsumerResetRequest struct {
Seq uint64 `json:"seq,omitempty"`
}
// JSApiConsumerResetResponse is a superset of JSApiConsumerCreateResponse, but including an explicit ResetSeq.
type JSApiConsumerResetResponse struct {
ApiResponse
*StreamTemplateInfo
*ConsumerInfo
ResetSeq uint64 `json:"reset_seq"`
}
// Deprecated: stream templates are deprecated and will be removed in a future version.
const JSApiStreamTemplateCreateResponseType = "io.nats.jetstream.api.v1.stream_template_create_response"
// Deprecated: stream templates are deprecated and will be removed in a future version.
type JSApiStreamTemplateDeleteResponse struct {
ApiResponse
Success bool `json:"success,omitempty"`
}
// Deprecated: stream templates are deprecated and will be removed in a future version.
const JSApiStreamTemplateDeleteResponseType = "io.nats.jetstream.api.v1.stream_template_delete_response"
// JSApiStreamTemplateInfoResponse for information about stream templates.
// Deprecated: stream templates are deprecated and will be removed in a future version.
type JSApiStreamTemplateInfoResponse struct {
ApiResponse
*StreamTemplateInfo
}
// Deprecated: stream templates are deprecated and will be removed in a future version.
const JSApiStreamTemplateInfoResponseType = "io.nats.jetstream.api.v1.stream_template_info_response"
// Deprecated: stream templates are deprecated and will be removed in a future version.
type JSApiStreamTemplatesRequest struct {
ApiPagedRequest
}
// JSApiStreamTemplateNamesResponse list of templates
// Note that Templates is serialized under the JSON key "streams", not "templates".
//
// Deprecated: stream templates are deprecated and will be removed in a future version.
type JSApiStreamTemplateNamesResponse struct {
ApiResponse
ApiPaged
Templates []string `json:"streams"`
}
// Deprecated: stream templates are deprecated and will be removed in a future version.
const JSApiStreamTemplateNamesResponseType = "io.nats.jetstream.api.v1.stream_template_names_response"
const JSApiConsumerResetResponseType = "io.nats.jetstream.api.v1.consumer_reset_response"
// Structure that holds state for a JetStream API request that is processed
// in a separate long-lived go routine. This is to avoid blocking connections.
@@ -911,11 +863,19 @@ func (js *jetStream) apiDispatch(sub *subscription, c *client, acc *Account, sub
// Copy the state. Note the JSAPI only uses the hdr index to piece apart the
// header from the msg body. No other references are needed.
// Check pending and warn if getting backed up.
pending, _ := s.jsAPIRoutedReqs.push(&jsAPIRoutedReq{jsub, sub, acc, subject, reply, copyBytes(rmsg), c.pa})
limit := atomic.LoadInt64(&js.queueLimit)
var queue *ipQueue[*jsAPIRoutedReq]
var limit int64
if js.infoSubs.HasInterest(subject) {
queue = s.jsAPIRoutedInfoReqs
limit = atomic.LoadInt64(&js.infoQueueLimit)
} else {
queue = s.jsAPIRoutedReqs
limit = atomic.LoadInt64(&js.queueLimit)
}
pending, _ := queue.push(&jsAPIRoutedReq{jsub, sub, acc, subject, reply, copyBytes(rmsg), c.pa})
if pending >= int(limit) {
s.rateLimitFormatWarnf("JetStream API queue limit reached, dropping %d requests", pending)
drained := int64(s.jsAPIRoutedReqs.drain())
s.rateLimitFormatWarnf("%s limit reached, dropping %d requests", queue.name, pending)
drained := int64(queue.drain())
atomic.AddInt64(&js.apiInflight, -drained)
s.publishAdvisory(nil, JSAdvisoryAPILimitReached, JSAPILimitReachedAdvisory{
@@ -935,29 +895,45 @@ func (s *Server) processJSAPIRoutedRequests() {
defer s.grWG.Done()
s.mu.RLock()
queue := s.jsAPIRoutedReqs
queue, infoqueue := s.jsAPIRoutedReqs, s.jsAPIRoutedInfoReqs
client := &client{srv: s, kind: JETSTREAM}
s.mu.RUnlock()
js := s.getJetStream()
processFromQueue := func(ipq *ipQueue[*jsAPIRoutedReq]) {
// Only pop one item at a time here, otherwise if the system is recovering
// from queue buildup, then one worker will pull off all the tasks and the
// others will be starved of work.
if r, ok := ipq.popOne(); ok && r != nil {
client.pa = r.pa
start := time.Now()
r.jsub.icb(r.sub, client, r.acc, r.subject, r.reply, r.msg)
if dur := time.Since(start); dur >= readLoopReportThreshold {
s.Warnf("Internal subscription on %q took too long: %v", r.subject, dur)
}
atomic.AddInt64(&js.apiInflight, -1)
}
}
for {
// First select case is prioritizing queue, we will only fall through
// to the second select case that considers infoqueue if queue is empty.
// This effectively means infos are deprioritized.
select {
case <-queue.ch:
// Only pop one item at a time here, otherwise if the system is recovering
// from queue buildup, then one worker will pull off all the tasks and the
// others will be starved of work.
for r, ok := queue.popOne(); ok && r != nil; r, ok = queue.popOne() {
client.pa = r.pa
start := time.Now()
r.jsub.icb(r.sub, client, r.acc, r.subject, r.reply, r.msg)
if dur := time.Since(start); dur >= readLoopReportThreshold {
s.Warnf("Internal subscription on %q took too long: %v", r.subject, dur)
}
atomic.AddInt64(&js.apiInflight, -1)
}
processFromQueue(queue)
case <-s.quitCh:
return
default:
select {
case <-infoqueue.ch:
processFromQueue(infoqueue)
case <-queue.ch:
processFromQueue(queue)
case <-s.quitCh:
return
}
}
}
}
@@ -976,7 +952,8 @@ func (s *Server) setJetStreamExportSubs() error {
if mp > maxProcs {
mp = maxProcs
}
s.jsAPIRoutedReqs = newIPQueue[*jsAPIRoutedReq](s, "Routed JS API Requests")
s.jsAPIRoutedReqs = newIPQueue[*jsAPIRoutedReq](s, "JetStream API queue")
s.jsAPIRoutedInfoReqs = newIPQueue[*jsAPIRoutedReq](s, "JetStream API info queue")
for i := 0; i < mp; i++ {
s.startGoRoutine(s.processJSAPIRoutedRequests)
}
@@ -992,20 +969,13 @@ func (s *Server) setJetStreamExportSubs() error {
}
// API handles themselves.
// infopairs are deprioritized compared to pairs in processJSAPIRoutedRequests.
pairs := []struct {
subject string
handler msgHandler
}{
{JSApiAccountInfo, s.jsAccountInfoRequest},
{JSApiTemplateCreate, s.jsTemplateCreateRequest},
{JSApiTemplates, s.jsTemplateNamesRequest},
{JSApiTemplateInfo, s.jsTemplateInfoRequest},
{JSApiTemplateDelete, s.jsTemplateDeleteRequest},
{JSApiStreamCreate, s.jsStreamCreateRequest},
{JSApiStreamUpdate, s.jsStreamUpdateRequest},
{JSApiStreams, s.jsStreamNamesRequest},
{JSApiStreamList, s.jsStreamListRequest},
{JSApiStreamInfo, s.jsStreamInfoRequest},
{JSApiStreamDelete, s.jsStreamDeleteRequest},
{JSApiStreamPurge, s.jsStreamPurgeRequest},
{JSApiStreamSnapshot, s.jsStreamSnapshotRequest},
@@ -1018,23 +988,40 @@ func (s *Server) setJetStreamExportSubs() error {
{JSApiConsumerCreateEx, s.jsConsumerCreateRequest},
{JSApiConsumerCreate, s.jsConsumerCreateRequest},
{JSApiDurableCreate, s.jsConsumerCreateRequest},
{JSApiConsumers, s.jsConsumerNamesRequest},
{JSApiConsumerList, s.jsConsumerListRequest},
{JSApiConsumerInfo, s.jsConsumerInfoRequest},
{JSApiConsumerDelete, s.jsConsumerDeleteRequest},
{JSApiConsumerPause, s.jsConsumerPauseRequest},
{JSApiConsumerUnpin, s.jsConsumerUnpinRequest},
}
infopairs := []struct {
subject string
handler msgHandler
}{
{JSApiAccountInfo, s.jsAccountInfoRequest},
{JSApiStreams, s.jsStreamNamesRequest},
{JSApiStreamList, s.jsStreamListRequest},
{JSApiStreamInfo, s.jsStreamInfoRequest},
{JSApiConsumers, s.jsConsumerNamesRequest},
{JSApiConsumerList, s.jsConsumerListRequest},
{JSApiConsumerInfo, s.jsConsumerInfoRequest},
}
js.mu.Lock()
defer js.mu.Unlock()
for _, p := range pairs {
// As well as populating js.apiSubs for the dispatch function to use, we
// will also populate js.infoSubs, so that the dispatch function can
// decide quickly whether or not the request is an info request or not.
for _, p := range append(infopairs, pairs...) {
sub := &subscription{subject: []byte(p.subject), icb: p.handler}
if err := js.apiSubs.Insert(sub); err != nil {
return err
}
}
for _, p := range infopairs {
if err := js.infoSubs.Insert(p.subject, struct{}{}); err != nil {
return err
}
}
return nil
}
@@ -1239,7 +1226,7 @@ func (s *Server) unmarshalRequest(c *client, acc *Account, subject string, msg [
c.RateLimitWarnf("Invalid JetStream request '%s > %s': %s", acc, subject, err)
if s.JetStreamConfig().Strict {
if js := s.getJetStream(); js != nil && js.config.Strict {
return err
}
@@ -1345,10 +1332,6 @@ func (s *Server) jsAccountInfoRequest(sub *subscription, c *client, _ *Account,
}
// Helpers for token extraction.
// templateNameFromSubject returns token 6 of a template API subject as resolved by tokenAt.
// NOTE(review): whether token positions are 0- or 1-based is defined by tokenAt — confirm there.
func templateNameFromSubject(subject string) string {
return tokenAt(subject, 6)
}
// streamNameFromSubject returns token 5 of a stream API subject as resolved by tokenAt.
// NOTE(review): whether token positions are 0- or 1-based is defined by tokenAt — confirm there.
func streamNameFromSubject(subject string) string {
return tokenAt(subject, 5)
}
@@ -1357,223 +1340,6 @@ func consumerNameFromSubject(subject string) string {
return tokenAt(subject, 6)
}
// Request to create a new template.
//
// The template name is taken from the request subject and must equal the name
// in the decoded StreamTemplateConfig. An error reply is sent when the required
// API level check fails, JetStream is not enabled for the account, the server
// is clustered, the body cannot be decoded, the names differ, or
// addStreamTemplate fails. On success the reply carries a deep copy of the
// template config together with the template's streams.
//
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (s *Server) jsTemplateCreateRequest(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) {
// Nothing to do without a client.
if c == nil {
return
}
ci, acc, hdr, msg, err := s.getRequestInfo(c, rmsg)
if err != nil {
s.Warnf(badAPIRequestT, msg)
return
}
var resp = JSApiStreamTemplateCreateResponse{ApiResponse: ApiResponse{Type: JSApiStreamTemplateCreateResponseType}}
if errorOnRequiredApiLevel(hdr) {
resp.Error = NewJSRequiredApiLevelError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
if !acc.JetStreamEnabled() {
resp.Error = NewJSNotEnabledForAccountError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
// Not supported for now.
if s.JetStreamIsClustered() {
resp.Error = NewJSClusterUnSupportFeatureError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
var cfg StreamTemplateConfig
if err := s.unmarshalRequest(c, acc, subject, msg, &cfg); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
// The name in the subject is authoritative; reject a body that disagrees with it.
templateName := templateNameFromSubject(subject)
if templateName != cfg.Name {
resp.Error = NewJSTemplateNameNotMatchSubjectError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
t, err := acc.addStreamTemplate(&cfg)
if err != nil {
resp.Error = NewJSStreamTemplateCreateError(err, Unless(err))
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
// Snapshot the template state under its lock.
// NOTE(review): streams is the template's own slice, not a copy, and is read after unlock — confirm this is safe.
t.mu.Lock()
tcfg := t.StreamTemplateConfig.deepCopy()
streams := t.streams
// Use an empty, non-nil slice so the JSON reply carries [] rather than null.
if streams == nil {
streams = []string{}
}
t.mu.Unlock()
resp.StreamTemplateInfo = &StreamTemplateInfo{Config: tcfg, Streams: streams}
s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp))
}
// Request for the list of all template names.
//
// The optional request body is a JSApiStreamTemplatesRequest whose Offset
// selects the first template (sorted by name) to return. At most
// JSApiNamesLimit names are returned per reply, and Total/Limit/Offset are
// reported back for paging. The offset is clamped to [0, total] so that a
// negative or oversized client-supplied value can never slice outside the
// template list.
//
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (s *Server) jsTemplateNamesRequest(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) {
	if c == nil {
		return
	}
	ci, acc, hdr, msg, err := s.getRequestInfo(c, rmsg)
	if err != nil {
		s.Warnf(badAPIRequestT, msg)
		return
	}

	var resp = JSApiStreamTemplateNamesResponse{ApiResponse: ApiResponse{Type: JSApiStreamTemplateNamesResponseType}}
	if errorOnRequiredApiLevel(hdr) {
		resp.Error = NewJSRequiredApiLevelError()
		s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
		return
	}
	if !acc.JetStreamEnabled() {
		resp.Error = NewJSNotEnabledForAccountError()
		s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
		return
	}

	// Not supported for now.
	if s.JetStreamIsClustered() {
		resp.Error = NewJSClusterUnSupportFeatureError()
		s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
		return
	}

	// The paging request is optional; only decode when a JSON body was sent.
	var offset int
	if isJSONObjectOrArray(msg) {
		var req JSApiStreamTemplatesRequest
		if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
			resp.Error = NewJSInvalidJSONError(err)
			s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
			return
		}
		offset = req.Offset
	}

	// Sort by name so that paging is stable across requests.
	ts := acc.templates()
	slices.SortFunc(ts, func(i, j *streamTemplate) int {
		return cmp.Compare(i.StreamTemplateConfig.Name, j.StreamTemplateConfig.Name)
	})

	// Clamp the client-controlled offset: a negative value would otherwise
	// panic in the slice expression below.
	tcnt := len(ts)
	if offset < 0 {
		offset = 0
	}
	if offset > tcnt {
		offset = tcnt
	}

	for _, t := range ts[offset:] {
		t.mu.Lock()
		name := t.Name
		t.mu.Unlock()
		resp.Templates = append(resp.Templates, name)
		if len(resp.Templates) >= JSApiNamesLimit {
			break
		}
	}
	resp.Total = tcnt
	resp.Limit = JSApiNamesLimit
	resp.Offset = offset
	// Use an empty, non-nil slice so the JSON reply carries [] rather than null.
	if resp.Templates == nil {
		resp.Templates = []string{}
	}
	s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp))
}
// Request for information about a stream template.
//
// The request body must be empty; the template name is taken from the subject.
// An error reply is sent when the required API level check fails, JetStream is
// not enabled for the account, the body is not empty, or the template cannot
// be looked up. On success the reply carries a deep copy of the template
// config together with the template's streams.
//
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (s *Server) jsTemplateInfoRequest(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) {
// Nothing to do without a client.
if c == nil {
return
}
ci, acc, hdr, msg, err := s.getRequestInfo(c, rmsg)
if err != nil {
s.Warnf(badAPIRequestT, msg)
return
}
var resp = JSApiStreamTemplateInfoResponse{ApiResponse: ApiResponse{Type: JSApiStreamTemplateInfoResponseType}}
if errorOnRequiredApiLevel(hdr) {
resp.Error = NewJSRequiredApiLevelError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
if !acc.JetStreamEnabled() {
resp.Error = NewJSNotEnabledForAccountError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
if !isEmptyRequest(msg) {
resp.Error = NewJSNotEmptyRequestError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
name := templateNameFromSubject(subject)
t, err := acc.lookupStreamTemplate(name)
if err != nil {
// Any lookup error is reported to the caller as "template not found".
resp.Error = NewJSStreamTemplateNotFoundError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
// Snapshot the template state under its lock.
// NOTE(review): streams is the template's own slice, not a copy, and is read after unlock — confirm this is safe.
t.mu.Lock()
cfg := t.StreamTemplateConfig.deepCopy()
streams := t.streams
// Use an empty, non-nil slice so the JSON reply carries [] rather than null.
if streams == nil {
streams = []string{}
}
t.mu.Unlock()
resp.StreamTemplateInfo = &StreamTemplateInfo{Config: cfg, Streams: streams}
s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp))
}
// Request to delete a stream template.
//
// The request body must be empty; the template name is taken from the subject.
// An error reply is sent when the required API level check fails, JetStream is
// not enabled for the account, the body is not empty, or deleteStreamTemplate
// fails. On success the reply has Success set to true.
//
// Deprecated: stream templates are deprecated and will be removed in a future version.
func (s *Server) jsTemplateDeleteRequest(sub *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) {
// Nothing to do without a client.
if c == nil {
return
}
ci, acc, hdr, msg, err := s.getRequestInfo(c, rmsg)
if err != nil {
s.Warnf(badAPIRequestT, msg)
return
}
var resp = JSApiStreamTemplateDeleteResponse{ApiResponse: ApiResponse{Type: JSApiStreamTemplateDeleteResponseType}}
if errorOnRequiredApiLevel(hdr) {
resp.Error = NewJSRequiredApiLevelError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
if !acc.JetStreamEnabled() {
resp.Error = NewJSNotEnabledForAccountError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
if !isEmptyRequest(msg) {
resp.Error = NewJSNotEmptyRequestError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
name := templateNameFromSubject(subject)
err = acc.deleteStreamTemplate(name)
if err != nil {
resp.Error = NewJSStreamTemplateDeleteError(err, Unless(err))
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
resp.Success = true
s.sendAPIResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(resp))
}
func (s *Server) jsonResponse(v any) string {
b, err := json.Marshal(v)
if err != nil {
@@ -2109,7 +1875,7 @@ func (s *Server) jsStreamInfoRequest(sub *subscription, c *client, a *Account, s
if cc != nil {
// Check to make sure the stream is assigned.
js.mu.RLock()
isLeader, sa := cc.isLeader(), js.streamAssignment(acc.Name, streamName)
isLeader, sa := cc.isLeader(), js.streamAssignmentOrInflight(acc.Name, streamName)
var offline bool
if sa != nil {
clusterWideConsCount = len(sa.consumers)
@@ -2338,7 +2104,7 @@ func (s *Server) jsStreamLeaderStepDownRequest(sub *subscription, c *client, _ *
}
js.mu.RLock()
isLeader, sa := cc.isLeader(), js.streamAssignment(acc.Name, name)
isLeader, sa := cc.isLeader(), js.streamAssignmentOrInflight(acc.Name, name)
js.mu.RUnlock()
if isLeader && sa == nil {
@@ -2455,7 +2221,7 @@ func (s *Server) jsConsumerLeaderStepDownRequest(sub *subscription, c *client, _
consumer := tokenAt(subject, 7)
js.mu.RLock()
isLeader, sa := cc.isLeader(), js.streamAssignment(acc.Name, stream)
isLeader, sa := cc.isLeader(), js.streamAssignmentOrInflight(acc.Name, stream)
js.mu.RUnlock()
if isLeader && sa == nil {
@@ -3456,7 +3222,7 @@ func (s *Server) jsMsgDeleteRequest(sub *subscription, c *client, _ *Account, su
}
js.mu.RLock()
isLeader, sa := cc.isLeader(), js.streamAssignment(acc.Name, stream)
isLeader, sa := cc.isLeader(), js.streamAssignmentOrInflight(acc.Name, stream)
js.mu.RUnlock()
if isLeader && sa == nil {
@@ -3581,7 +3347,7 @@ func (s *Server) jsMsgGetRequest(sub *subscription, c *client, _ *Account, subje
}
js.mu.RLock()
isLeader, sa := cc.isLeader(), js.streamAssignment(acc.Name, stream)
isLeader, sa := cc.isLeader(), js.streamAssignmentOrInflight(acc.Name, stream)
js.mu.RUnlock()
if isLeader && sa == nil {
@@ -3876,7 +3642,7 @@ func (s *Server) jsStreamPurgeRequest(sub *subscription, c *client, _ *Account,
}
js.mu.RLock()
isLeader, sa := cc.isLeader(), js.streamAssignment(acc.Name, stream)
isLeader, sa := cc.isLeader(), js.streamAssignmentOrInflight(acc.Name, stream)
js.mu.RUnlock()
if isLeader && sa == nil {
@@ -4048,6 +3814,13 @@ func (s *Server) jsStreamRestoreRequest(sub *subscription, c *client, _ *Account
return
}
// Check for path like separators in the name.
if strings.ContainsAny(stream, `\/`) {
resp.Error = NewJSStreamNameContainsPathSeparatorsError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
if s.JetStreamIsClustered() {
s.jsClusteredStreamRestoreRequest(ci, acc, &req, subject, reply, rmsg)
return
@@ -4597,11 +4370,49 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun
isClustered := s.JetStreamIsClustered()
// Determine if we should proceed here when we are in clustered mode.
direct := req.Config.Direct
if isClustered {
if req.Config.Direct {
// Check to see if we have this stream and are the stream leader.
if !acc.JetStreamIsStreamLeader(streamNameFromSubject(subject)) {
return
if direct {
// If it's just a direct consumer, check for stream leader.
if !req.Config.Sourcing {
// Check to see if we have this stream and are the stream leader.
if !acc.JetStreamIsStreamLeader(streamNameFromSubject(subject)) {
return
}
} else {
// Otherwise, we either need this to be answered by the stream or meta leader.
var cc *jetStreamCluster
js, cc = s.getJetStreamCluster()
if js == nil || cc == nil {
return
}
js.mu.RLock()
sa := js.streamAssignmentOrInflight(acc.Name, streamNameFromSubject(subject))
if sa == nil {
js.mu.RUnlock()
return
}
// If the stream is WQ or Interest, we need the meta leader to answer.
if sa.Config.Retention != LimitsPolicy {
direct = false
}
js.mu.RUnlock()
if direct {
// Check to see if we have this stream and are the stream leader.
if !acc.JetStreamIsStreamLeader(streamNameFromSubject(subject)) {
return
}
} else {
if js.isLeaderless() {
resp.Error = NewJSClusterNotAvailError()
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
}
// Make sure we are meta leader.
if !s.JetStreamIsLeader() {
return
}
}
}
} else {
var cc *jetStreamCluster
@@ -4645,6 +4456,7 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun
// Legacy ephemeral.
rt = ccLegacyEphemeral
streamName = streamNameFromSubject(subject)
consumerName = req.Config.Name
} else {
// New style and durable legacy.
if tokenAt(subject, 4) == "DURABLE" {
@@ -4736,7 +4548,7 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun
return
}
if isClustered && !req.Config.Direct {
if isClustered && !direct {
s.jsClusteredConsumerRequest(ci, acc, subject, reply, rmsg, req.Stream, &req.Config, req.Action, req.Pedantic)
return
}
@@ -4760,6 +4572,23 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun
return
}
// If the consumer is a direct sourcing consumer, we need to "upgrade"
// it to be durable without AckNone if not a Limits-based stream.
if req.Config.Direct && req.Config.Sourcing && req.Config.Name != _EMPTY_ {
if !isClustered && stream.isInterestRetention() {
req.Config.Direct = false
req.Config.Durable = req.Config.Name
req.Config.AckPolicy = AckFlowControl
req.Config.AckWait = 0
req.Config.MaxDeliver = 0
req.Config.InactiveThreshold = 0
} else {
// Otherwise, need to append a randomized suffix since the source uses a stable name.
req.Config.Name = fmt.Sprintf("%s-%s", req.Config.Name, createConsumerName())
consumerName = req.Config.Name
}
}
if o := stream.lookupConsumer(consumerName); o != nil {
if o.offlineReason != _EMPTY_ {
resp.Error = NewJSConsumerOfflineReasonError(errors.New(o.offlineReason))
@@ -4770,6 +4599,12 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun
// it back to whatever the current configured value is.
o.mu.RLock()
req.Config.PauseUntil = o.cfg.PauseUntil
// If a durable sourcing consumer is used, we need to reset the deliver policy.
if req.Config.Sourcing && req.Config.Durable != _EMPTY_ {
req.Config.DeliverPolicy = o.cfg.DeliverPolicy
req.Config.OptStartSeq = o.cfg.OptStartSeq
req.Config.OptStartTime = o.cfg.OptStartTime
}
o.mu.RUnlock()
}
@@ -5079,7 +4914,7 @@ func (s *Server) jsConsumerInfoRequest(sub *subscription, c *client, _ *Account,
groupCreated := meta.Created()
js.mu.RLock()
isLeader, sa, ca := cc.isLeader(), js.streamAssignment(acc.Name, streamName), js.consumerAssignment(acc.Name, streamName, consumerName)
isLeader, sa, ca := cc.isLeader(), js.streamAssignmentOrInflight(acc.Name, streamName), js.consumerAssignmentOrInflight(acc.Name, streamName, consumerName)
var rg *raftGroup
var offline, isMember bool
if ca != nil {
@@ -5404,8 +5239,11 @@ func (s *Server) jsConsumerPauseRequest(sub *subscription, c *client, _ *Account
return
}
nca := *ca
nca := ca.clone()
// We're only holding the read lock and release below,
// we need a copy to prevent concurrent reads/writes.
ncfg := *ca.Config
ncfg.Metadata = maps.Clone(ncfg.Metadata)
nca.Config = &ncfg
meta := cc.meta
js.mu.RUnlock()
@@ -5420,7 +5258,7 @@ func (s *Server) jsConsumerPauseRequest(sub *subscription, c *client, _ *Account
// Only PauseUntil is updated above, so reuse config for both.
setStaticConsumerMetadata(nca.Config)
eca := encodeAddConsumerAssignment(&nca)
eca := encodeAddConsumerAssignment(nca)
meta.Propose(eca)
resp.PauseUntil = pauseUTC
@@ -5453,7 +5291,13 @@ func (s *Server) jsConsumerPauseRequest(sub *subscription, c *client, _ *Account
return
}
// We're only holding the read lock and release below,
// we need a copy to prevent concurrent reads/writes.
obs.mu.RLock()
ncfg := obs.cfg
ncfg.Metadata = maps.Clone(ncfg.Metadata)
obs.mu.RUnlock()
pauseUTC := req.PauseUntil.UTC()
if !pauseUTC.IsZero() {
ncfg.PauseUntil = &pauseUTC
+455 -48
View File
@@ -21,6 +21,7 @@ import (
"math/big"
"path/filepath"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
@@ -29,60 +30,105 @@ import (
var (
// Tracks the total inflight batches, across all streams and accounts that enable batching.
globalInflightBatches atomic.Int32
globalInflightAtomicBatches atomic.Int64
globalInflightFastBatches atomic.Int64
)
type batching struct {
mu sync.Mutex
group map[string]*batchGroup
mu sync.Mutex
atomic map[string]*atomicBatch
fast map[string]*fastBatch
}
type batchGroup struct {
lseq uint64
store StreamStore
timer *time.Timer
// atomicBatch holds the state of one atomic batch publish: its inactivity
// timer, the highest batch sequence and the store the batch is staged in.
// A nil timer marks the batch as already cleaned up or stopped, which is what
// cleanupLocked and stopLocked check to avoid releasing it twice.
type atomicBatch struct {
timer *time.Timer // Inactivity timer for the batch.
lseq uint64 // The highest sequence for this batch.
store StreamStore // Where the batch is staged before committing.
}
// fastBatch holds the state of one fast batch publish. It has no staging
// store of its own; it tracks how far the batch has been received (lseq),
// persisted (pseq/sseq) and acknowledged via flow control (fseq), along with
// the current and maximum flow control interval.
type fastBatch struct {
timer *time.Timer // Inactivity timer for the batch.
lseq uint64 // The highest sequence for this batch.
sseq uint64 // Last persisted stream sequence.
pseq uint64 // Last persisted batch sequence (is always lower or equal to lseq).
fseq uint64 // Sequence of when we last sent a flow message (is always lower or equal to pseq).
pending uint32 // Number of pending messages in the batch waiting to be persisted.
ackMessages uint16 // Ack will be sent every N messages.
maxAckMessages uint16 // Maximum ackMessages value the client allows.
reply string // The last reply subject seen when persisting a message.
gapOk bool // Whether a gap is okay, if not, the batch would be rejected.
commit bool // If the batch is committed.
}
// newAtomicBatch creates an atomic batch publish object.
// Lock should be held.
func (batches *batching) newBatchGroup(mset *stream, batchId string) (*batchGroup, error) {
store, err := newBatchStore(mset, batchId)
func (batches *batching) newAtomicBatch(mset *stream, batchId string, replicas int, storage StorageType, storeDir, streamName string) (*atomicBatch, error) {
store, err := newBatchStore(mset, batchId, replicas, storage, storeDir, streamName)
if err != nil {
return nil, err
}
b := &batchGroup{store: store}
b := &atomicBatch{store: store}
b.setupCleanupTimer(mset, batchId, batches)
return b, nil
}
// setupCleanupTimer sets up a timer to clean up the batch after a timeout.
func (b *atomicBatch) setupCleanupTimer(mset *stream, batchId string, batches *batching) {
// Create a timer to clean up after timeout.
timeout := streamMaxBatchTimeout
if maxBatchTimeout := mset.srv.getOpts().JetStreamLimits.MaxBatchTimeout; maxBatchTimeout > 0 {
timeout = maxBatchTimeout
}
timeout := getCleanupTimeout(mset)
b.timer = time.AfterFunc(timeout, func() {
b.cleanup(batchId, batches)
mset.sendStreamBatchAbandonedAdvisory(batchId, BatchTimeout)
})
return b, nil
}
func getBatchStoreDir(mset *stream, batchId string) (string, string) {
mset.mu.RLock()
jsa, name := mset.jsa, mset.cfg.Name
mset.mu.RUnlock()
// resetCleanupTimer re-arms the inactivity timer with the current cleanup
// timeout, extending the lifetime of the batch.
// It reports the result of Timer.Reset: true if the timer was still active,
// false if it had already expired or been stopped.
func (b *atomicBatch) resetCleanupTimer(mset *stream) bool {
	return b.timer.Reset(getCleanupTimeout(mset))
}
jsa.mu.RLock()
sd := jsa.storeDir
jsa.mu.RUnlock()
// cleanup deletes underlying resources associated with the batch and unregisters it from the stream's batches.
// It takes batches.mu itself and delegates to cleanupLocked, so it must not be
// called while that (non-reentrant) mutex is already held; use cleanupLocked then.
func (b *atomicBatch) cleanup(batchId string, batches *batching) {
batches.mu.Lock()
defer batches.mu.Unlock()
b.cleanupLocked(batchId, batches)
}
// cleanupLocked releases the batch: it gives back its slot in the global
// inflight atomic batch counter, stops the inactivity timer, deletes the
// staging store and removes the batch from batches.atomic.
// Calling it again on an already released batch is a no-op.
// Lock should be held.
func (b *atomicBatch) cleanupLocked(batchId string, batches *batching) {
	tmr := b.timer
	if tmr == nil {
		// Already cleaned up or stopped.
		return
	}
	globalInflightAtomicBatches.Add(-1)
	tmr.Stop()
	b.store.Delete(true)
	delete(batches.atomic, batchId)
	// A nil timer is the marker that prevents double-accounting on a later call.
	b.timer = nil
}
// stopLocked halts the batch without deleting its data: it gives back its slot
// in the global inflight atomic batch counter, stops the inactivity timer and
// stops (rather than deletes) the staging store. The batch is not removed
// from batches.atomic here.
// Calling it again on an already released batch is a no-op.
// Lock should be held.
func (b *atomicBatch) stopLocked() {
	tmr := b.timer
	if tmr == nil {
		// Already cleaned up or stopped.
		return
	}
	globalInflightAtomicBatches.Add(-1)
	tmr.Stop()
	b.store.Stop()
	// A nil timer is the marker that prevents double-accounting on a later call.
	b.timer = nil
}
func getBatchStoreDir(storeDir, streamName, batchId string) (string, string) {
bname := getHash(batchId)
return bname, filepath.Join(sd, streamsDir, name, batchesDir, bname)
return bname, filepath.Join(storeDir, streamsDir, streamName, batchesDir, bname)
}
func newBatchStore(mset *stream, batchId string) (StreamStore, error) {
mset.mu.RLock()
replicas, storage := mset.cfg.Replicas, mset.cfg.Storage
mset.mu.RUnlock()
func newBatchStore(mset *stream, batchId string, replicas int, storage StorageType, storeDir, streamName string) (StreamStore, error) {
if replicas == 1 && storage == FileStorage {
bname, storeDir := getBatchStoreDir(mset, batchId)
bname, storeDir := getBatchStoreDir(storeDir, streamName, batchId)
fcfg := FileStoreConfig{AsyncFlush: true, BlockSize: defaultLargeBlockSize, StoreDir: storeDir}
s := mset.srv
prf := s.jsKeyGen(s.getOpts().JetStreamKey, mset.acc.Name)
@@ -101,34 +147,264 @@ func newBatchStore(mset *stream, batchId string) (StreamStore, error) {
// If the timer has already cleaned up the batch, we can't commit.
// Otherwise, we ensure the timer does not clean up the batch in the meantime.
// Lock should be held.
func (b *batchGroup) readyForCommit() bool {
func (b *atomicBatch) readyForCommit() *BatchAbandonReason {
if !b.timer.Stop() {
return &BatchTimeout
}
if b.store.FlushAllPending() != nil {
return &BatchIncomplete
}
return nil
}
// newFastBatch builds a fast batch publish object, registers it in
// batches.fast under batchId, initializes its flow control interval and arms
// its cleanup timer.
// Lock should be held.
func (batches *batching) newFastBatch(mset *stream, batchId string, gapOk bool, maxAckMessages uint16) *fastBatch {
	// Lazily allocate the registry on first use.
	if batches.fast == nil {
		batches.fast = make(map[string]*fastBatch, 1)
	}
	fb := &fastBatch{maxAckMessages: maxAckMessages, gapOk: gapOk}
	// Register before initializing: fastBatchInit looks at how many batches exist.
	batches.fast[batchId] = fb
	batches.fastBatchInit(fb)
	fb.setupCleanupTimer(mset, batchId, batches)
	return fb
}
// fastBatchInit (re)initializes the ackMessages field for a fast batch.
// A lone batch gets what the client allows, capped at 500; when other fast
// batches exist the publisher starts at 1 and is ramped up later.
// The batch must already be registered in batches.fast.
// Lock should be held.
func (batches *batching) fastBatchInit(b *fastBatch) {
	// TODO(mvv): fast ingest's initial flow value improvements?
	if len(batches.fast) > 1 {
		// Sharing with other publishers: start slow and coordinate from there.
		b.ackMessages = 1
		return
	}
	// Only batch: just allow what the client wants.
	b.ackMessages = min(500, b.maxAckMessages)
}
// fastBatchReset returns the fast batch to an empty state and tells the
// publisher where to continue by sending a flow control message for the last
// persisted batch sequence.
// If the batch has no timer anymore, or it is not committing and its timer
// could not be stopped, the batch is cleaned up instead.
// Lock should be held.
func (batches *batching) fastBatchReset(mset *stream, batchId string, b *fastBatch) {
	expired := b.timer == nil || (!b.commit && !b.timer.Stop())
	if expired {
		b.cleanupLocked(batchId, batches)
		return
	}

	// Recompute the flow control interval and re-arm the inactivity timer.
	batches.fastBatchInit(b)
	b.timer.Reset(getCleanupTimeout(mset))

	// Forget everything that wasn't persisted: all sequences restart from pseq.
	b.commit, b.pending = false, 0
	b.lseq = b.pseq
	b.fseq = b.pseq
	b.sendFlowControl(b.pseq, mset, b.reply)
}
// fastBatchRegisterSequences registers the highest stored batch and stream sequence and returns
// whether a PubAck should be sent if the batch has been committed.
// If this is called on a follower, it only registers the highest stream and persisted batch sequences.
//
// The method has two paths:
//   - unknown batch or not leader: a committed batch is cleaned up, otherwise the batch is
//     (re)created if needed and its sequences and reply are recorded; always returns false.
//   - known batch on the leader: pending is decremented, the persisted sequences advance, and
//     either the commit is finalized (cleanup plus PubAck) or flow control is evaluated.
//
// Lock should be held.
func (batches *batching) fastBatchRegisterSequences(mset *stream, reply string, streamSeq uint64, isLeader bool, batch *FastBatch) bool {
b, ok := batches.fast[batch.id]
if !ok || !isLeader {
// If this batch has committed, we can clean it up.
if batch.commit {
if b != nil {
b.cleanupLocked(batch.id, batches)
}
return false
}
// Otherwise, even as a follower, we record the latest state of this batch.
if b == nil || !b.resetCleanupTimer(mset) {
if b != nil {
// The timer couldn't be reset, this means the timer already runs and is likely
// waiting to acquire the lock. We reset the timer here so it doesn't clean up
// this batch that we're about to overwrite.
b.timer = nil
} else {
// If this is a new batch for us, even though we're a follower, we still need
// to account toward the global inflight limit.
globalInflightFastBatches.Add(1)
}
// We'll need a copy as we'll use it as a key and later for cleanup.
batchId := copyString(batch.id)
b = batches.newFastBatch(mset, batchId, batch.gapOk, batch.flow)
}
b.sseq = streamSeq
b.pseq, b.lseq = batch.seq, batch.seq
b.reply = reply
return false
}
// NOTE(review): the fastBatch struct declares no store field, so the next line looks like
// stale residue from the previous readyForCommit implementation — confirm against the real file.
b.store.FlushAllPending()
b.reply = reply
// One more message of this batch has been persisted.
if b.pending > 0 {
b.pending--
}
b.sseq = streamSeq
// Store last persisted batch sequence.
// If we have no remaining pending writes, we might have had duplicate messages
// and need to send additional flow control messages.
var skipped bool
if b.pending == 0 {
skipped = true
b.pseq = b.lseq
} else {
b.pseq = batch.seq
}
// If the PubAck needs to be sent now as a result of a commit.
if b.lseq == b.pseq && b.commit {
b.cleanupLocked(batch.id, batches)
// If we skipped ahead due to duplicate messages, send the PubAck with the highest sequence.
if skipped {
// Build the PubAck here: stream sequence, batch id and message count appended to mset.pubAck.
var buf [256]byte
pubAck := append(buf[:0], mset.pubAck...)
response := append(pubAck, strconv.FormatUint(b.sseq, 10)...)
response = append(response, fmt.Sprintf(",\"batch\":%q,\"count\":%d}", batch.id, b.lseq)...)
if len(reply) > 0 {
mset.outq.sendMsg(reply, response)
}
// The PubAck was already sent above, so the caller must not send another one.
return false
}
return true
}
b.checkFlowControl(mset, reply, batches)
return false
}
// checkFlowControl checks whether a flow control message should be sent.
// If so, it updates the flow values to speed up or slow down the publisher if needed.
// Returns whether a flow control message was sent.
//
// A message is due once at least ackMessages batch messages were persisted
// since the last flow control message (pseq >= fseq+ackMessages). When several
// intervals have elapsed only one message, for the latest boundary, is sent.
// An ackMessages of zero (for example a client maximum of zero) never becomes
// due; it is rejected up front because it would otherwise divide by zero.
// Lock should be held.
func (b *fastBatch) checkFlowControl(mset *stream, reply string, batches *batching) bool {
	am := uint64(b.ackMessages)
	if am == 0 || b.pseq < b.fseq+am {
		return false
	}

	// Instead of sending multiple flow control messages, skip ahead to only send the last.
	steps := (b.pseq - b.fseq) / am
	b.fseq += steps * am

	// TODO(mvv): fast ingest's dynamic flow value improvements?
	// This is currently just a simple value to have a working version. Should take average
	// message sizes into account and compare how much this client is contributing to the
	// ingest IPQ total size and messages and have publishers share based on that.
	//
	// Share a budget of 500 between all fast batches. The divisor is floored at
	// one so an empty registry cannot cause a division by zero.
	maxAckMessages := uint16(500 / max(len(batches.fast), 1))
	if maxAckMessages < 1 {
		maxAckMessages = 1
	}
	// Limit to the client's allowed maximum.
	if maxAckMessages > b.maxAckMessages {
		maxAckMessages = b.maxAckMessages
	}

	switch {
	case b.ackMessages < maxAckMessages:
		// Ramp up, doubling but never past the target.
		b.ackMessages = min(b.ackMessages*2, maxAckMessages)
	case b.ackMessages > maxAckMessages:
		// Slow down, halving but never below the target.
		b.ackMessages = max(b.ackMessages/2, maxAckMessages)
	}

	// Finally, send the flow control message.
	b.sendFlowControl(b.fseq, mset, reply)
	return true
}
// sendFlowControl publishes a fast batch flow control message on the reply subject. The
// message carries batchSeq and the batch's current ack interval (b.ackMessages).
// Nothing is sent when there is no reply subject.
// Lock should be held.
func (b *fastBatch) sendFlowControl(batchSeq uint64, mset *stream, reply string) {
	if reply == "" {
		return
	}
	ack := BatchFlowAck{Sequence: batchSeq, Messages: b.ackMessages}
	// The marshal error is discarded here.
	// NOTE(review): confirm MarshalJSON cannot fail for this type.
	payload, _ := ack.MarshalJSON()
	mset.outq.sendMsg(reply, payload)
}
// fastBatchCommit marks the batch as committing. When every message of the batch has
// already been persisted (b.lseq == b.pseq) the batch is cleaned up and the PubAck is sent
// right away; otherwise the PubAck is left to be sent once the last message is persisted.
//
// Returns true when the batch is finished from the caller's point of view (either the
// PubAck was handled here, or the batch's timer was already gone), and false when the
// commit is still pending on persistence.
// Lock should be held.
func (batches *batching) fastBatchCommit(b *fastBatch, batchId string, mset *stream, reply string) bool {
	// Either we commit now, or we clean up later, so the timer must be stopped.
	// It shouldn't be possible for the timer to already be stopped if we haven't committed
	// yet, since being able to reset the timer is pre-checked. Guard against it anyhow.
	if b.timer == nil || (!b.commit && !b.timer.Stop()) {
		return true
	}

	// From here on this batch commits.
	b.commit = true

	// Not everything is persisted yet, the PubAck follows with the last persisted message.
	if b.lseq != b.pseq {
		return false
	}

	// The whole batch has been persisted, respond with the PubAck now.
	b.cleanupLocked(batchId, batches)

	var scratch [256]byte
	ack := append(scratch[:0], mset.pubAck...)
	ack = strconv.AppendUint(ack, b.sseq, 10)
	ack = append(ack, fmt.Sprintf(",\"batch\":%q,\"count\":%d}", batchId, b.lseq)...)
	if reply != "" {
		mset.outq.sendMsg(reply, ack)
	}
	return true
}
// setupCleanupTimer arms b.timer so that the batch is cleaned up (via b.cleanup) once the
// cleanup timeout for this stream elapses.
func (b *fastBatch) setupCleanupTimer(mset *stream, batchId string, batches *batching) {
	onExpiry := func() {
		b.cleanup(batchId, batches)
	}
	b.timer = time.AfterFunc(getCleanupTimeout(mset), onExpiry)
}
// resetCleanupTimer pushes the cleanup deadline out again, extending the batch's lifetime.
// A committed batch reports true without touching the timer, a batch without a timer
// reports false. Otherwise the result of the timer's Reset is returned, i.e. whether the
// timer was still pending when it was reset.
func (b *fastBatch) resetCleanupTimer(mset *stream) bool {
	switch {
	case b.commit:
		return true
	case b.timer == nil:
		return false
	default:
		return b.timer.Reset(getCleanupTimeout(mset))
	}
}
// cleanup deletes underlying resources associated with the batch and unregisters it from the stream's batches.
func (b *batchGroup) cleanup(batchId string, batches *batching) {
func (b *fastBatch) cleanup(batchId string, batches *batching) {
batches.mu.Lock()
defer batches.mu.Unlock()
b.cleanupLocked(batchId, batches)
}
// Lock should be held.
func (b *batchGroup) cleanupLocked(batchId string, batches *batching) {
globalInflightBatches.Add(-1)
func (b *fastBatch) cleanupLocked(batchId string, batches *batching) {
// If the timer is nil, it means this batch has been replaced with a new one.
// This can happen on a follower depending on timing.
if b.timer == nil {
return
}
globalInflightFastBatches.Add(-1)
b.timer.Stop()
b.store.Delete(true)
delete(batches.group, batchId)
delete(batches.fast, batchId)
// Reset so that another invocation doesn't double-account.
b.timer = nil
}
// Lock should be held.
func (b *batchGroup) stopLocked() {
globalInflightBatches.Add(-1)
b.timer.Stop()
b.store.Stop()
// getCleanupTimeout returns how long a batch may stay idle before it is cleaned up.
// The server's JetStreamLimits.MaxBatchTimeout wins when it is set to a positive value,
// otherwise the streamMaxBatchTimeout default applies.
func getCleanupTimeout(mset *stream) time.Duration {
	if limit := mset.srv.getOpts().JetStreamLimits.MaxBatchTimeout; limit > 0 {
		return limit
	}
	return streamMaxBatchTimeout
}
// batchStagedDiff stages all changes for consistency checks until commit.
@@ -136,6 +412,7 @@ type batchStagedDiff struct {
msgIds map[string]struct{}
counter map[string]*msgCounterRunningTotal
inflight map[string]*inflightSubjectRunningTotal
inflightTransform map[uint64]string
expectedPerSubject map[string]*batchExpectedPerSubject
}
@@ -180,6 +457,16 @@ func (diff *batchStagedDiff) commit(mset *stream) {
}
}
// Track inflight subject transforms.
if len(diff.inflightTransform) > 0 {
if mset.inflightTransform == nil {
mset.inflightTransform = make(map[uint64]string, len(diff.inflightTransform))
}
for clseq, subj := range diff.inflightTransform {
mset.inflightTransform[clseq] = subj
}
}
// Track sequence and subject.
if len(diff.expectedPerSubject) > 0 {
if mset.expectedPerSubjectSequence == nil {
@@ -238,7 +525,7 @@ func (batch *batchApply) rejectBatchState(mset *stream) {
// mset.mu lock must NOT be held or used.
// mset.clMu lock must be held.
func checkMsgHeadersPreClusteredProposal(
diff *batchStagedDiff, mset *stream, subject string, hdr []byte, msg []byte, sourced bool, name string,
diff *batchStagedDiff, mset *stream, subject, rsubject string, hdr []byte, msg []byte, sourced bool, name string,
jsa *jsAccount, allowRollup, denyPurge, allowTTL, allowMsgCounter, allowMsgSchedules bool,
discard DiscardPolicy, discardNewPer bool, maxMsgSize int, maxMsgs int64, maxMsgsPer int64, maxBytes int64,
) ([]byte, []byte, uint64, *ApiError, error) {
@@ -515,8 +802,9 @@ func checkMsgHeadersPreClusteredProposal(
}
// Message scheduling.
if schedule, ok := getMessageSchedule(hdr); !ok {
apiErr := NewJSMessageSchedulesPatternInvalidError()
if sourced {
// noop, sourced messages were already validated by the origin stream.
} else if schedule, apiErr := getMessageSchedule(hdr); apiErr != nil {
if !allowMsgSchedules {
apiErr = NewJSMessageSchedulesDisabledError()
}
@@ -528,22 +816,40 @@ func checkMsgHeadersPreClusteredProposal(
} else if scheduleTtl, ok := getMessageScheduleTTL(hdr); !ok {
apiErr := NewJSMessageSchedulesTTLInvalidError()
return hdr, msg, 0, apiErr, apiErr
} else if scheduleRollup := getMessageScheduleRollup(hdr); scheduleRollup != _EMPTY_ && scheduleRollup != JSMsgRollupSubject {
apiErr := NewJSMessageSchedulesRollupInvalidError()
return hdr, msg, 0, apiErr, apiErr
} else if scheduleTtl != _EMPTY_ && !allowTTL {
return hdr, msg, 0, NewJSMessageTTLDisabledError(), errMsgTTLDisabled
} else if scheduleTarget := getMessageScheduleTarget(hdr); scheduleTarget == _EMPTY_ ||
!IsValidPublishSubject(scheduleTarget) || SubjectsCollide(scheduleTarget, subject) {
!IsValidPublishSubject(scheduleTarget) || scheduleTarget == subject {
apiErr := NewJSMessageSchedulesTargetInvalidError()
return hdr, msg, 0, apiErr, apiErr
} else if scheduleSource := getMessageScheduleSource(hdr); scheduleSource != _EMPTY_ &&
(scheduleSource == scheduleTarget || scheduleSource == subject || !IsValidPublishSubject(scheduleSource)) {
apiErr := NewJSMessageSchedulesSourceInvalidError()
return hdr, msg, 0, apiErr, apiErr
} else {
mset.cfgMu.RLock()
match := slices.ContainsFunc(mset.cfg.Subjects, func(subj string) bool {
return SubjectsCollide(subj, scheduleTarget)
})
mset.cfgMu.RUnlock()
if !match {
mset.cfgMu.RUnlock()
apiErr := NewJSMessageSchedulesTargetInvalidError()
return hdr, msg, 0, apiErr, apiErr
}
if scheduleSource != _EMPTY_ {
match = slices.ContainsFunc(mset.cfg.Subjects, func(subj string) bool {
return SubjectsCollide(subj, scheduleSource)
})
if !match {
mset.cfgMu.RUnlock()
apiErr := NewJSMessageSchedulesSourceInvalidError()
return hdr, msg, 0, apiErr, apiErr
}
}
mset.cfgMu.RUnlock()
// Add a rollup sub header if it doesn't already exist.
// Otherwise, it must exist already as a rollup on the subject.
@@ -555,10 +861,32 @@ func checkMsgHeadersPreClusteredProposal(
}
}
}
if scheduleNext := sliceHeader(JSScheduleNext, hdr); len(scheduleNext) > 0 && !sourced {
// Clients may only use Nats-Schedule-Next to purge a schedule.
if bytesToString(scheduleNext) != JSScheduleNextPurge {
apiErr := NewJSMessageSchedulesSchedulerInvalidError()
return hdr, msg, 0, apiErr, apiErr
}
// Nats-Scheduler must accompany the purge and:
// - it must NOT be empty.
// - it must NOT match the publish subject.
if scheduler := sliceHeader(JSScheduler, hdr); len(scheduler) == 0 ||
bytesToString(scheduler) == subject || !IsValidPublishSubject(bytesToString(scheduler)) {
apiErr := NewJSMessageSchedulesSchedulerInvalidError()
return hdr, msg, 0, apiErr, apiErr
} else if !allowMsgSchedules {
apiErr := NewJSMessageSchedulesDisabledError()
return hdr, msg, 0, apiErr, apiErr
}
} else if !sourced && len(sliceHeader(JSScheduler, hdr)) > 0 {
// Clients may only use Nats-Scheduler alongside Nats-Schedule-Next.
apiErr := NewJSMessageSchedulesSchedulerInvalidError()
return hdr, msg, 0, apiErr, apiErr
}
// Check for any rollups.
if rollup := getRollup(hdr); rollup != _EMPTY_ {
if !allowRollup || denyPurge {
if (!allowRollup || denyPurge) && !sourced {
err := errors.New("rollup not permitted")
return hdr, msg, 0, NewJSStreamRollupFailedError(err), err
}
@@ -607,6 +935,19 @@ func checkMsgHeadersPreClusteredProposal(
diff.inflight[subject] = i
}
// Subject transform.
if subject != rsubject {
// The 'subject' is a transformed subject used for consistency checks.
// But since we propose the original (raw) subject to our peers, we need
// to store the transformed subject separately for when we apply.
// TODO(mvv): since subject transforms are handled by each replica individually, this has a
// potential for desync given out-of-order stream subject transform updates.
if diff.inflightTransform == nil {
diff.inflightTransform = make(map[uint64]string, 1)
}
diff.inflightTransform[mset.clseq] = subject
}
// Check if we have discard new with max msgs or bytes.
// We need to deny here otherwise we'd need to bump CLFS, and it could succeed on some
// peers and not others depending on consumer ack state (if interest policy).
@@ -639,7 +980,8 @@ func checkMsgHeadersPreClusteredProposal(
}
// Similarly, check DiscardNew per-subject threshold to not need to bump CLFS.
if discardNewPer && maxMsgsPer > 0 {
// Allow rollup messages through since they will purge after storing.
if discardNewPer && maxMsgsPer > 0 && len(sliceHeader(JSMsgRollup, hdr)) == 0 {
// Get the current total for this subject.
totalMsgsForSubject := mset.store.SubjectsTotals(subject)[subject]
// Add inflight count in this batch and for this stream.
@@ -656,3 +998,68 @@ func checkMsgHeadersPreClusteredProposal(
return hdr, msg, 0, nil, nil
}
// recalculateClusteredSeq initializes or updates mset.clseq, for example after a leader change.
// This is reused for normal clustered publishing into a stream, and for atomic and fast batch publishing.
//
// mset.clseq is set to mset.lseq + mset.clfs + the message count of the in-progress
// mset.batchApply (if any). The returned value is the mset.lseq that was captured while
// computing it.
//
// needStreamLock selects whether mset.mu is read-locked by this function.
// NOTE(review): presumably callers passing false already hold mset.mu - confirm at the call sites.
//
// mset.clMu lock must be held. It is released and re-acquired inside this function so that
// the locks can be taken in the proper order, and it is held again on return. Anything the
// caller observed under mset.clMu before the call may therefore have changed.
func recalculateClusteredSeq(mset *stream, needStreamLock bool) (lseq uint64) {
// Need to unlock and re-acquire the locks in the proper order.
mset.clMu.Unlock()
// Locking order is stream -> batchMu -> clMu
if needStreamLock {
mset.mu.RLock()
}
// Capture the in-progress batch apply state (if any) under its own lock.
batch := mset.batchApply
var batchCount uint64
if batch != nil {
batch.mu.Lock()
batchCount = batch.count
}
mset.clMu.Lock()
// Re-capture
lseq = mset.lseq
mset.clseq = lseq + mset.clfs + batchCount
// Keep hold of the mset.clMu, but unlock the others.
if batch != nil {
batch.mu.Unlock()
}
if needStreamLock {
mset.mu.RUnlock()
}
return lseq
}
// commitSingleMsg encodes a single message at the current clustered sequence (mset.clseq)
// and proposes it to the node. On a successful proposal it registers the optional message
// trace under that sequence, commits the staged diff into the stream, advances mset.clseq
// and accounts the replication traffic. A failed proposal is returned as-is and leaves all
// of that state untouched.
// This is reused both for normal publishing into a stream, and for fast batch publishing.
// mset.clMu lock must be held.
func commitSingleMsg(
	diff *batchStagedDiff, mset *stream, subject string, reply string, hdr []byte, msg []byte, name string,
	jsa *jsAccount, mt *msgTrace, node RaftNode, replicas int, lseq uint64,
) error {
	// Do proposal.
	encoded := encodeStreamMsgAllowCompress(subject, reply, hdr, msg, mset.clseq, time.Now().UnixNano(), false)
	err := node.Propose(encoded)
	if err != nil {
		return err
	}

	// Remember the trace keyed by the sequence this message was proposed at.
	if mt != nil {
		if mset.mt == nil {
			mset.mt = make(map[uint64]*msgTrace)
		}
		mset.mt[mset.clseq] = mt
	}

	diff.commit(mset)
	mset.clseq++
	mset.trackReplicationTraffic(node, len(encoded), replicas)

	// Check to see if we are being overrun.
	// TODO(dlc) - Make this a limit where we drop messages to protect ourselves, but allow to be configured.
	if lag := mset.clseq - (lseq + mset.clfs); lag > streamLagWarnThreshold {
		warning := fmt.Errorf("JetStream stream '%s > %s' has high message lag", jsa.acc().Name, name)
		mset.srv.RateLimitWarnf("%s", warning.Error())
	}
	return nil
}
File diff suppressed because it is too large Load Diff
@@ -29,12 +29,30 @@ const (
// JSAtomicPublishTooLargeBatchErrF atomic publish batch is too large: {size}
JSAtomicPublishTooLargeBatchErrF ErrorIdentifier = 10199
// JSAtomicPublishTooManyInflight atomic publish too many inflight
JSAtomicPublishTooManyInflight ErrorIdentifier = 10210
// JSAtomicPublishUnsupportedHeaderBatchErr atomic publish unsupported header used: {header}
JSAtomicPublishUnsupportedHeaderBatchErr ErrorIdentifier = 10177
// JSBadRequestErr bad request
JSBadRequestErr ErrorIdentifier = 10003
// JSBatchPublishDisabledErr batch publish is disabled
JSBatchPublishDisabledErr ErrorIdentifier = 10205
// JSBatchPublishInvalidBatchIDErr batch publish ID is invalid
JSBatchPublishInvalidBatchIDErr ErrorIdentifier = 10207
// JSBatchPublishInvalidPatternErr batch publish pattern is invalid
JSBatchPublishInvalidPatternErr ErrorIdentifier = 10206
// JSBatchPublishTooManyInflight batch publish too many inflight
JSBatchPublishTooManyInflight ErrorIdentifier = 10211
// JSBatchPublishUnknownBatchIDErr batch publish ID unknown
JSBatchPublishUnknownBatchIDErr ErrorIdentifier = 10208
// JSClusterIncompleteErr incomplete results
JSClusterIncompleteErr ErrorIdentifier = 10004
@@ -71,6 +89,21 @@ const (
// JSClusterUnSupportFeatureErr not currently supported in clustered mode
JSClusterUnSupportFeatureErr ErrorIdentifier = 10036
// JSConsumerAckFCRequiresFCErr flow control ack policy requires flow control
JSConsumerAckFCRequiresFCErr ErrorIdentifier = 10219
// JSConsumerAckFCRequiresMaxAckPendingErr flow control ack policy requires max ack pending
JSConsumerAckFCRequiresMaxAckPendingErr ErrorIdentifier = 10220
// JSConsumerAckFCRequiresNoAckWaitErr flow control ack policy requires unset ack wait
JSConsumerAckFCRequiresNoAckWaitErr ErrorIdentifier = 10221
// JSConsumerAckFCRequiresNoMaxDeliverErr flow control ack policy requires unset max deliver
JSConsumerAckFCRequiresNoMaxDeliverErr ErrorIdentifier = 10222
// JSConsumerAckFCRequiresPushErr flow control ack policy requires a push based consumer
JSConsumerAckFCRequiresPushErr ErrorIdentifier = 10218
// JSConsumerAckPolicyInvalidErr consumer ack policy invalid
JSConsumerAckPolicyInvalidErr ErrorIdentifier = 10181
@@ -167,6 +200,9 @@ const (
// JSConsumerInvalidPriorityGroupErr Provided priority group does not exist for this consumer
JSConsumerInvalidPriorityGroupErr ErrorIdentifier = 10160
// JSConsumerInvalidResetErr invalid reset: {err}
JSConsumerInvalidResetErr ErrorIdentifier = 10204
// JSConsumerInvalidSamplingErrF failed to parse consumer sampling configuration: {err}
JSConsumerInvalidSamplingErrF ErrorIdentifier = 10095
@@ -317,21 +353,36 @@ const (
// JSMessageSchedulesRollupInvalidErr message schedules invalid rollup
JSMessageSchedulesRollupInvalidErr ErrorIdentifier = 10192
// JSMessageSchedulesSchedulerInvalidErr message schedules invalid scheduler
JSMessageSchedulesSchedulerInvalidErr ErrorIdentifier = 10212
// JSMessageSchedulesSourceInvalidErr message schedules source is invalid
JSMessageSchedulesSourceInvalidErr ErrorIdentifier = 10203
// JSMessageSchedulesTTLInvalidErr message schedules invalid per-message TTL
JSMessageSchedulesTTLInvalidErr ErrorIdentifier = 10191
// JSMessageSchedulesTargetInvalidErr message schedules target is invalid
JSMessageSchedulesTargetInvalidErr ErrorIdentifier = 10190
// JSMessageSchedulesTimeZoneInvalidErr message schedules time zone is invalid
JSMessageSchedulesTimeZoneInvalidErr ErrorIdentifier = 10223
// JSMessageTTLDisabledErr per-message TTL is disabled
JSMessageTTLDisabledErr ErrorIdentifier = 10166
// JSMessageTTLInvalidErr invalid per-message TTL
JSMessageTTLInvalidErr ErrorIdentifier = 10165
// JSMirrorConsumerRequiresAckFCErr stream mirror consumer requires flow control ack policy
JSMirrorConsumerRequiresAckFCErr ErrorIdentifier = 10214
// JSMirrorConsumerSetupFailedErrF generic mirror consumer setup failure string ({err})
JSMirrorConsumerSetupFailedErrF ErrorIdentifier = 10029
// JSMirrorDurableConsumerCfgInvalid stream mirror consumer config is invalid
JSMirrorDurableConsumerCfgInvalid ErrorIdentifier = 10213
// JSMirrorInvalidStreamName mirrored stream name is invalid
JSMirrorInvalidStreamName ErrorIdentifier = 10142
@@ -353,6 +404,9 @@ const (
// JSMirrorWithAtomicPublishErr stream mirrors can not also use atomic publishing
JSMirrorWithAtomicPublishErr ErrorIdentifier = 10198
// JSMirrorWithBatchPublishErr stream mirrors can not also use batch publishing
JSMirrorWithBatchPublishErr ErrorIdentifier = 10209
// JSMirrorWithCountersErr stream mirrors can not also calculate counters
JSMirrorWithCountersErr ErrorIdentifier = 10173
@@ -416,12 +470,21 @@ const (
// JSSnapshotDeliverSubjectInvalidErr deliver subject not valid
JSSnapshotDeliverSubjectInvalidErr ErrorIdentifier = 10015
// JSSourceConsumerRequiresAckFCErr stream source consumer requires flow control ack policy
JSSourceConsumerRequiresAckFCErr ErrorIdentifier = 10217
// JSSourceConsumerSetupFailedErrF General source consumer setup failure string ({err})
JSSourceConsumerSetupFailedErrF ErrorIdentifier = 10045
// JSSourceDuplicateDetected source stream, filter and transform (plus external if present) must form a unique combination (duplicate source configuration detected)
JSSourceDuplicateDetected ErrorIdentifier = 10140
// JSSourceDurableConsumerCfgInvalid stream source consumer config is invalid
JSSourceDurableConsumerCfgInvalid ErrorIdentifier = 10215
// JSSourceDurableConsumerDuplicateDetected duplicate stream source consumer detected
JSSourceDurableConsumerDuplicateDetected ErrorIdentifier = 10216
// JSSourceInvalidStreamName sourced stream name is invalid
JSSourceInvalidStreamName ErrorIdentifier = 10141
@@ -619,8 +682,14 @@ var (
JSAtomicPublishInvalidBatchIDErr: {Code: 400, ErrCode: 10179, Description: "atomic publish batch ID is invalid"},
JSAtomicPublishMissingSeqErr: {Code: 400, ErrCode: 10175, Description: "atomic publish sequence is missing"},
JSAtomicPublishTooLargeBatchErrF: {Code: 400, ErrCode: 10199, Description: "atomic publish batch is too large: {size}"},
JSAtomicPublishTooManyInflight: {Code: 429, ErrCode: 10210, Description: "atomic publish too many inflight"},
JSAtomicPublishUnsupportedHeaderBatchErr: {Code: 400, ErrCode: 10177, Description: "atomic publish unsupported header used: {header}"},
JSBadRequestErr: {Code: 400, ErrCode: 10003, Description: "bad request"},
JSBatchPublishDisabledErr: {Code: 400, ErrCode: 10205, Description: "batch publish is disabled"},
JSBatchPublishInvalidBatchIDErr: {Code: 400, ErrCode: 10207, Description: "batch publish ID is invalid"},
JSBatchPublishInvalidPatternErr: {Code: 400, ErrCode: 10206, Description: "batch publish pattern is invalid"},
JSBatchPublishTooManyInflight: {Code: 429, ErrCode: 10211, Description: "batch publish too many inflight"},
JSBatchPublishUnknownBatchIDErr: {Code: 400, ErrCode: 10208, Description: "batch publish ID unknown"},
JSClusterIncompleteErr: {Code: 503, ErrCode: 10004, Description: "incomplete results"},
JSClusterNoPeersErrF: {Code: 400, ErrCode: 10005, Description: "{err}"},
JSClusterNotActiveErr: {Code: 500, ErrCode: 10006, Description: "JetStream not in clustered mode"},
@@ -633,6 +702,11 @@ var (
JSClusterServerNotMemberErr: {Code: 400, ErrCode: 10044, Description: "server is not a member of the cluster"},
JSClusterTagsErr: {Code: 400, ErrCode: 10011, Description: "tags placement not supported for operation"},
JSClusterUnSupportFeatureErr: {Code: 503, ErrCode: 10036, Description: "not currently supported in clustered mode"},
JSConsumerAckFCRequiresFCErr: {Code: 400, ErrCode: 10219, Description: "flow control ack policy requires flow control"},
JSConsumerAckFCRequiresMaxAckPendingErr: {Code: 400, ErrCode: 10220, Description: "flow control ack policy requires max ack pending"},
JSConsumerAckFCRequiresNoAckWaitErr: {Code: 400, ErrCode: 10221, Description: "flow control ack policy requires unset ack wait"},
JSConsumerAckFCRequiresNoMaxDeliverErr: {Code: 400, ErrCode: 10222, Description: "flow control ack policy requires unset max deliver"},
JSConsumerAckFCRequiresPushErr: {Code: 400, ErrCode: 10218, Description: "flow control ack policy requires a push based consumer"},
JSConsumerAckPolicyInvalidErr: {Code: 400, ErrCode: 10181, Description: "consumer ack policy invalid"},
JSConsumerAckWaitNegativeErr: {Code: 400, ErrCode: 10183, Description: "consumer ack wait needs to be positive"},
JSConsumerAlreadyExists: {Code: 400, ErrCode: 10148, Description: "consumer already exists"},
@@ -665,6 +739,7 @@ var (
JSConsumerInvalidGroupNameErr: {Code: 400, ErrCode: 10162, Description: "Valid priority group name must match A-Z, a-z, 0-9, -_/=)+ and may not exceed 16 characters"},
JSConsumerInvalidPolicyErrF: {Code: 400, ErrCode: 10094, Description: "{err}"},
JSConsumerInvalidPriorityGroupErr: {Code: 400, ErrCode: 10160, Description: "Provided priority group does not exist for this consumer"},
JSConsumerInvalidResetErr: {Code: 400, ErrCode: 10204, Description: "invalid reset: {err}"},
JSConsumerInvalidSamplingErrF: {Code: 400, ErrCode: 10095, Description: "failed to parse consumer sampling configuration: {err}"},
JSConsumerMaxDeliverBackoffErr: {Code: 400, ErrCode: 10116, Description: "max deliver is required to be > length of backoff values"},
JSConsumerMaxPendingAckExcessErrF: {Code: 400, ErrCode: 10121, Description: "consumer max ack pending exceeds system limit of {limit}"},
@@ -715,11 +790,16 @@ var (
JSMessageSchedulesDisabledErr: {Code: 400, ErrCode: 10188, Description: "message schedules is disabled"},
JSMessageSchedulesPatternInvalidErr: {Code: 400, ErrCode: 10189, Description: "message schedules pattern is invalid"},
JSMessageSchedulesRollupInvalidErr: {Code: 400, ErrCode: 10192, Description: "message schedules invalid rollup"},
JSMessageSchedulesSchedulerInvalidErr: {Code: 400, ErrCode: 10212, Description: "message schedules invalid scheduler"},
JSMessageSchedulesSourceInvalidErr: {Code: 400, ErrCode: 10203, Description: "message schedules source is invalid"},
JSMessageSchedulesTTLInvalidErr: {Code: 400, ErrCode: 10191, Description: "message schedules invalid per-message TTL"},
JSMessageSchedulesTargetInvalidErr: {Code: 400, ErrCode: 10190, Description: "message schedules target is invalid"},
JSMessageSchedulesTimeZoneInvalidErr: {Code: 400, ErrCode: 10223, Description: "message schedules time zone is invalid"},
JSMessageTTLDisabledErr: {Code: 400, ErrCode: 10166, Description: "per-message TTL is disabled"},
JSMessageTTLInvalidErr: {Code: 400, ErrCode: 10165, Description: "invalid per-message TTL"},
JSMirrorConsumerRequiresAckFCErr: {Code: 400, ErrCode: 10214, Description: "stream mirror consumer requires flow control ack policy"},
JSMirrorConsumerSetupFailedErrF: {Code: 500, ErrCode: 10029, Description: "{err}"},
JSMirrorDurableConsumerCfgInvalid: {Code: 400, ErrCode: 10213, Description: "stream mirror consumer config is invalid"},
JSMirrorInvalidStreamName: {Code: 400, ErrCode: 10142, Description: "mirrored stream name is invalid"},
JSMirrorInvalidSubjectFilter: {Code: 400, ErrCode: 10151, Description: "mirror transform source: {err}"},
JSMirrorInvalidTransformDestination: {Code: 400, ErrCode: 10154, Description: "mirror transform: {err}"},
@@ -727,6 +807,7 @@ var (
JSMirrorMultipleFiltersNotAllowed: {Code: 400, ErrCode: 10150, Description: "mirror with multiple subject transforms cannot also have a single subject filter"},
JSMirrorOverlappingSubjectFilters: {Code: 400, ErrCode: 10152, Description: "mirror subject filters can not overlap"},
JSMirrorWithAtomicPublishErr: {Code: 400, ErrCode: 10198, Description: "stream mirrors can not also use atomic publishing"},
JSMirrorWithBatchPublishErr: {Code: 400, ErrCode: 10209, Description: "stream mirrors can not also use batch publishing"},
JSMirrorWithCountersErr: {Code: 400, ErrCode: 10173, Description: "stream mirrors can not also calculate counters"},
JSMirrorWithFirstSeqErr: {Code: 400, ErrCode: 10143, Description: "stream mirrors can not have first sequence configured"},
JSMirrorWithMsgSchedulesErr: {Code: 400, ErrCode: 10186, Description: "stream mirrors can not also schedule messages"},
@@ -748,8 +829,11 @@ var (
JSRestoreSubscribeFailedErrF: {Code: 500, ErrCode: 10042, Description: "JetStream unable to subscribe to restore snapshot {subject}: {err}"},
JSSequenceNotFoundErrF: {Code: 400, ErrCode: 10043, Description: "sequence {seq} not found"},
JSSnapshotDeliverSubjectInvalidErr: {Code: 400, ErrCode: 10015, Description: "deliver subject not valid"},
JSSourceConsumerRequiresAckFCErr: {Code: 400, ErrCode: 10217, Description: "stream source consumer requires flow control ack policy"},
JSSourceConsumerSetupFailedErrF: {Code: 500, ErrCode: 10045, Description: "{err}"},
JSSourceDuplicateDetected: {Code: 400, ErrCode: 10140, Description: "duplicate source configuration detected"},
JSSourceDurableConsumerCfgInvalid: {Code: 400, ErrCode: 10215, Description: "stream source consumer config is invalid"},
JSSourceDurableConsumerDuplicateDetected: {Code: 400, ErrCode: 10216, Description: "duplicate stream source consumer detected"},
JSSourceInvalidStreamName: {Code: 400, ErrCode: 10141, Description: "sourced stream name is invalid"},
JSSourceInvalidSubjectFilter: {Code: 400, ErrCode: 10145, Description: "source transform source: {err}"},
JSSourceInvalidTransformDestination: {Code: 400, ErrCode: 10146, Description: "source transform: {err}"},
@@ -923,6 +1007,16 @@ func NewJSAtomicPublishTooLargeBatchError(size interface{}, opts ...ErrorOption)
}
}
// NewJSAtomicPublishTooManyInflightError creates a new JSAtomicPublishTooManyInflight error: "atomic publish too many inflight"
//
// If the supplied options carry an error that already is an *ApiError, that error is
// returned instead. Otherwise the shared entry from ApiErrors is returned.
func NewJSAtomicPublishTooManyInflightError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSAtomicPublishTooManyInflight]
}
// NewJSAtomicPublishUnsupportedHeaderBatchError creates a new JSAtomicPublishUnsupportedHeaderBatchErr error: "atomic publish unsupported header used: {header}"
func NewJSAtomicPublishUnsupportedHeaderBatchError(header interface{}, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@@ -949,6 +1043,56 @@ func NewJSBadRequestError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSBadRequestErr]
}
// NewJSBatchPublishDisabledError creates a new JSBatchPublishDisabledErr error: "batch publish is disabled"
//
// If the supplied options carry an error that already is an *ApiError, that error is
// returned instead. Otherwise the shared entry from ApiErrors is returned.
func NewJSBatchPublishDisabledError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSBatchPublishDisabledErr]
}
// NewJSBatchPublishInvalidBatchIDError creates a new JSBatchPublishInvalidBatchIDErr error: "batch publish ID is invalid"
//
// If the supplied options carry an error that already is an *ApiError, that error is
// returned instead. Otherwise the shared entry from ApiErrors is returned.
func NewJSBatchPublishInvalidBatchIDError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSBatchPublishInvalidBatchIDErr]
}
// NewJSBatchPublishInvalidPatternError creates a new JSBatchPublishInvalidPatternErr error: "batch publish pattern is invalid"
//
// If the supplied options carry an error that already is an *ApiError, that error is
// returned instead. Otherwise the shared entry from ApiErrors is returned.
func NewJSBatchPublishInvalidPatternError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSBatchPublishInvalidPatternErr]
}
// NewJSBatchPublishTooManyInflightError creates a new JSBatchPublishTooManyInflight error: "batch publish too many inflight"
//
// If the supplied options carry an error that already is an *ApiError, that error is
// returned instead. Otherwise the shared entry from ApiErrors is returned.
func NewJSBatchPublishTooManyInflightError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSBatchPublishTooManyInflight]
}
// NewJSBatchPublishUnknownBatchIDError creates a new JSBatchPublishUnknownBatchIDErr error: "batch publish ID unknown"
//
// If the supplied options carry an error that already is an *ApiError, that error is
// returned instead. Otherwise the shared entry from ApiErrors is returned.
func NewJSBatchPublishUnknownBatchIDError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSBatchPublishUnknownBatchIDErr]
}
// NewJSClusterIncompleteError creates a new JSClusterIncompleteErr error: "incomplete results"
func NewJSClusterIncompleteError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@@ -1075,6 +1219,56 @@ func NewJSClusterUnSupportFeatureError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSClusterUnSupportFeatureErr]
}
// NewJSConsumerAckFCRequiresFCError creates a new JSConsumerAckFCRequiresFCErr error: "flow control ack policy requires flow control"
//
// If the supplied options carry an error that already is an *ApiError, that error is
// returned instead. Otherwise the shared entry from ApiErrors is returned.
func NewJSConsumerAckFCRequiresFCError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSConsumerAckFCRequiresFCErr]
}
// NewJSConsumerAckFCRequiresMaxAckPendingError creates a new JSConsumerAckFCRequiresMaxAckPendingErr error: "flow control ack policy requires max ack pending"
//
// If the supplied options carry an error that already is an *ApiError, that error is
// returned instead. Otherwise the shared entry from ApiErrors is returned.
func NewJSConsumerAckFCRequiresMaxAckPendingError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSConsumerAckFCRequiresMaxAckPendingErr]
}
// NewJSConsumerAckFCRequiresNoAckWaitError creates a new JSConsumerAckFCRequiresNoAckWaitErr error: "flow control ack policy requires unset ack wait"
//
// If the supplied options carry an error that already is an *ApiError, that error is
// returned instead. Otherwise the shared entry from ApiErrors is returned.
func NewJSConsumerAckFCRequiresNoAckWaitError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSConsumerAckFCRequiresNoAckWaitErr]
}
// NewJSConsumerAckFCRequiresNoMaxDeliverError creates a new JSConsumerAckFCRequiresNoMaxDeliverErr error: "flow control ack policy requires unset max deliver"
//
// If the supplied options carry an error that already is an *ApiError, that error is
// returned instead. Otherwise the shared entry from ApiErrors is returned.
func NewJSConsumerAckFCRequiresNoMaxDeliverError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSConsumerAckFCRequiresNoMaxDeliverErr]
}
// NewJSConsumerAckFCRequiresPushError creates a new JSConsumerAckFCRequiresPushErr error: "flow control ack policy requires a push based consumer"
//
// If the supplied options carry an error that already is an *ApiError, that error is
// returned instead. Otherwise the shared entry from ApiErrors is returned.
func NewJSConsumerAckFCRequiresPushError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSConsumerAckFCRequiresPushErr]
}
// NewJSConsumerAckPolicyInvalidError creates a new JSConsumerAckPolicyInvalidErr error: "consumer ack policy invalid"
func NewJSConsumerAckPolicyInvalidError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@@ -1419,6 +1613,22 @@ func NewJSConsumerInvalidPriorityGroupError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSConsumerInvalidPriorityGroupErr]
}
// NewJSConsumerInvalidResetError creates a new JSConsumerInvalidResetErr error: "invalid reset: {err}"
//
// If the supplied options carry an error that already is an *ApiError, that error is
// returned instead. Otherwise a fresh *ApiError is built from the ApiErrors template, with
// the "{err}" placeholder in its description substituted using err.
func NewJSConsumerInvalidResetError(err error, opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}

	tmpl := ApiErrors[JSConsumerInvalidResetErr]
	replacer := strings.NewReplacer(tmpl.toReplacerArgs([]interface{}{"{err}", err})...)
	return &ApiError{
		Code:        tmpl.Code,
		ErrCode:     tmpl.ErrCode,
		Description: replacer.Replace(tmpl.Description),
	}
}
// NewJSConsumerInvalidSamplingError creates a new JSConsumerInvalidSamplingErrF error: "failed to parse consumer sampling configuration: {err}"
func NewJSConsumerInvalidSamplingError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@@ -1967,6 +2177,26 @@ func NewJSMessageSchedulesRollupInvalidError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSMessageSchedulesRollupInvalidErr]
}
// NewJSMessageSchedulesSchedulerInvalidError creates a new JSMessageSchedulesSchedulerInvalidErr error: "message schedules invalid scheduler"
//
// If the supplied options carry an error that already is an *ApiError, that error is
// returned instead. Otherwise the shared entry from ApiErrors is returned.
func NewJSMessageSchedulesSchedulerInvalidError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSMessageSchedulesSchedulerInvalidErr]
}
// NewJSMessageSchedulesSourceInvalidError creates a new JSMessageSchedulesSourceInvalidErr error: "message schedules source is invalid"
//
// If the supplied options carry an error that already is an *ApiError, that error is
// returned instead. Otherwise the shared entry from ApiErrors is returned.
func NewJSMessageSchedulesSourceInvalidError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSMessageSchedulesSourceInvalidErr]
}
// NewJSMessageSchedulesTTLInvalidError creates a new JSMessageSchedulesTTLInvalidErr error: "message schedules invalid per-message TTL"
func NewJSMessageSchedulesTTLInvalidError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@@ -1987,6 +2217,16 @@ func NewJSMessageSchedulesTargetInvalidError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSMessageSchedulesTargetInvalidErr]
}
// NewJSMessageSchedulesTimeZoneInvalidError creates a new JSMessageSchedulesTimeZoneInvalidErr error: "message schedules time zone is invalid"
//
// If opts carries an error that is already an *ApiError, that error is
// returned unchanged; otherwise the ApiErrors entry for
// JSMessageSchedulesTimeZoneInvalidErr is returned.
func NewJSMessageSchedulesTimeZoneInvalidError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSMessageSchedulesTimeZoneInvalidErr]
}
// NewJSMessageTTLDisabledError creates a new JSMessageTTLDisabledErr error: "per-message TTL is disabled"
func NewJSMessageTTLDisabledError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@@ -2007,6 +2247,16 @@ func NewJSMessageTTLInvalidError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSMessageTTLInvalidErr]
}
// NewJSMirrorConsumerRequiresAckFCError creates a new JSMirrorConsumerRequiresAckFCErr error: "stream mirror consumer requires flow control ack policy"
//
// If opts carries an error that is already an *ApiError, that error is
// returned unchanged; otherwise the ApiErrors entry for
// JSMirrorConsumerRequiresAckFCErr is returned.
func NewJSMirrorConsumerRequiresAckFCError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSMirrorConsumerRequiresAckFCErr]
}
// NewJSMirrorConsumerSetupFailedError creates a new JSMirrorConsumerSetupFailedErrF error: "{err}"
func NewJSMirrorConsumerSetupFailedError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@@ -2023,6 +2273,16 @@ func NewJSMirrorConsumerSetupFailedError(err error, opts ...ErrorOption) *ApiErr
}
}
// NewJSMirrorDurableConsumerCfgInvalidError creates a new JSMirrorDurableConsumerCfgInvalid error: "stream mirror consumer config is invalid"
//
// If opts carries an error that is already an *ApiError, that error is
// returned unchanged; otherwise the ApiErrors entry for
// JSMirrorDurableConsumerCfgInvalid is returned.
func NewJSMirrorDurableConsumerCfgInvalidError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSMirrorDurableConsumerCfgInvalid]
}
// NewJSMirrorInvalidStreamNameError creates a new JSMirrorInvalidStreamName error: "mirrored stream name is invalid"
func NewJSMirrorInvalidStreamNameError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@@ -2105,6 +2365,16 @@ func NewJSMirrorWithAtomicPublishError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSMirrorWithAtomicPublishErr]
}
// NewJSMirrorWithBatchPublishError creates a new JSMirrorWithBatchPublishErr error: "stream mirrors can not also use batch publishing"
//
// If opts carries an error that is already an *ApiError, that error is
// returned unchanged; otherwise the ApiErrors entry for
// JSMirrorWithBatchPublishErr is returned.
func NewJSMirrorWithBatchPublishError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSMirrorWithBatchPublishErr]
}
// NewJSMirrorWithCountersError creates a new JSMirrorWithCountersErr error: "stream mirrors can not also calculate counters"
func NewJSMirrorWithCountersError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@@ -2339,6 +2609,16 @@ func NewJSSnapshotDeliverSubjectInvalidError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSSnapshotDeliverSubjectInvalidErr]
}
// NewJSSourceConsumerRequiresAckFCError creates a new JSSourceConsumerRequiresAckFCErr error: "stream source consumer requires flow control ack policy"
//
// If opts carries an error that is already an *ApiError, that error is
// returned unchanged; otherwise the ApiErrors entry for
// JSSourceConsumerRequiresAckFCErr is returned.
func NewJSSourceConsumerRequiresAckFCError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSSourceConsumerRequiresAckFCErr]
}
// NewJSSourceConsumerSetupFailedError creates a new JSSourceConsumerSetupFailedErrF error: "{err}"
func NewJSSourceConsumerSetupFailedError(err error, opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
@@ -2365,6 +2645,26 @@ func NewJSSourceDuplicateDetectedError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSSourceDuplicateDetected]
}
// NewJSSourceDurableConsumerCfgInvalidError creates a new JSSourceDurableConsumerCfgInvalid error: "stream source consumer config is invalid"
//
// If opts carries an error that is already an *ApiError, that error is
// returned unchanged; otherwise the ApiErrors entry for
// JSSourceDurableConsumerCfgInvalid is returned.
func NewJSSourceDurableConsumerCfgInvalidError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSSourceDurableConsumerCfgInvalid]
}
// NewJSSourceDurableConsumerDuplicateDetectedError creates a new JSSourceDurableConsumerDuplicateDetected error: "duplicate stream source consumer detected"
//
// If opts carries an error that is already an *ApiError, that error is
// returned unchanged; otherwise the ApiErrors entry for
// JSSourceDurableConsumerDuplicateDetected is returned.
func NewJSSourceDurableConsumerDuplicateDetectedError(opts ...ErrorOption) *ApiError {
	if override, isAPIErr := parseOpts(opts).err.(*ApiError); isAPIErr {
		return override
	}
	return ApiErrors[JSSourceDurableConsumerDuplicateDetected]
}
// NewJSSourceInvalidStreamNameError creates a new JSSourceInvalidStreamName error: "sourced stream name is invalid"
func NewJSSourceInvalidStreamNameError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
+7 -7
View File
@@ -71,10 +71,9 @@ const (
// JSStreamActionAdvisory indicates that a stream was created, edited or deleted
type JSStreamActionAdvisory struct {
TypedEvent
Stream string `json:"stream"`
Action ActionAdvisoryType `json:"action"`
Template string `json:"template,omitempty"` // Deprecated: stream templates are deprecated and will be removed in a future version.
Domain string `json:"domain,omitempty"`
Stream string `json:"stream"`
Action ActionAdvisoryType `json:"action"`
Domain string `json:"domain,omitempty"`
}
const JSStreamActionAdvisoryType = "io.nats.jetstream.advisory.v1.stream_action"
@@ -269,9 +268,10 @@ type JSStreamBatchAbandonedAdvisory struct {
type BatchAbandonReason string
var (
BatchTimeout BatchAbandonReason = "timeout"
BatchLarge BatchAbandonReason = "large"
BatchIncomplete BatchAbandonReason = "incomplete"
BatchTimeout BatchAbandonReason = "timeout"
BatchLarge BatchAbandonReason = "large"
BatchIncomplete BatchAbandonReason = "incomplete"
BatchRequirementsNotMet BatchAbandonReason = "unsupported"
)
// JSConsumerLeaderElectedAdvisoryType is sent when the system elects a leader for a consumer.
+11 -1
View File
@@ -17,7 +17,7 @@ import "strconv"
const (
// JSApiLevel is the maximum supported JetStream API level for this server.
JSApiLevel int = 3
JSApiLevel int = 4
JSRequiredLevelMetadataKey = "_nats.req.level"
JSServerVersionMetadataKey = "_nats.ver"
@@ -82,6 +82,11 @@ func setStaticStreamMetadata(cfg *StreamConfig) {
requires(2)
}
// Fast batch publishing was added in v2.14 and requires API level 4.
if cfg.AllowBatchPublish {
requires(4)
}
cfg.Metadata[JSRequiredLevelMetadataKey] = strconv.Itoa(requiredApiLevel)
}
@@ -158,6 +163,11 @@ func setStaticConsumerMetadata(cfg *ConsumerConfig) {
requires(1)
}
// Added in 2.14
if cfg.AckPolicy == AckFlowControl {
requires(4)
}
cfg.Metadata[JSRequiredLevelMetadataKey] = strconv.Itoa(requiredApiLevel)
}
+44 -14
View File
@@ -202,12 +202,15 @@ func validateSrc(claims *jwt.UserClaims, host string) bool {
}
// validateTimes evaluates the claims against the current wall-clock time.
// It simply delegates to validateTimesAt with time.Now() as the reference
// instant and returns that function's result unchanged.
func validateTimes(claims *jwt.UserClaims) (bool, time.Duration) {
	now := time.Now()
	return validateTimesAt(claims, now)
}
func validateTimesAt(claims *jwt.UserClaims, now time.Time) (bool, time.Duration) {
if claims == nil {
return false, time.Duration(0)
} else if len(claims.Times) == 0 {
return true, time.Duration(0)
}
now := time.Now()
loc := time.Local
if claims.Locale != "" {
var err error
@@ -216,10 +219,11 @@ func validateTimes(claims *jwt.UserClaims) (bool, time.Duration) {
}
now = now.In(loc)
}
var ok bool
var validFor time.Duration
for _, timeRange := range claims.Times {
y, m, d := now.Date()
m = m - 1
d = d - 1
start, err := time.ParseInLocation("15:04:05", timeRange.Start, loc)
if err != nil {
return false, time.Duration(0) // parsing not expected to fail at this point
@@ -228,17 +232,43 @@ func validateTimes(claims *jwt.UserClaims) (bool, time.Duration) {
if err != nil {
return false, time.Duration(0) // parsing not expected to fail at this point
}
if start.After(end) {
start = start.AddDate(y, int(m), d)
d++ // the intent is to be the next day
} else {
start = start.AddDate(y, int(m), d)
y, m, d := now.Date()
start = time.Date(y, m, d, start.Hour(), start.Minute(), start.Second(), 0, loc)
end = time.Date(y, m, d, end.Hour(), end.Minute(), end.Second(), 0, loc)
inRange, expires := validateTimeRangeAt(start, end, now)
if inRange && (!ok || expires > validFor) {
ok = true
validFor = expires
}
if start.Before(now) {
end = end.AddDate(y, int(m), d)
if end.After(now) {
return true, end.Sub(now)
}
}
return ok, validFor
}
// validateTimeRangeAt reports whether `now` lies strictly between `start`
// and `end`, and if so how much time is left until `end`.
// It returns false and a zero duration when `now` is not within range.
// Both bounds are exclusive: a `now` equal to `start` or to `end` is out of range.
// A `start` that is later than `end` is treated as a range crossing midnight.
// NOTE(review): the midnight handling assumes `start` and `end` carry the same
// calendar date as `now`, which is how the caller in this file builds them.
func validateTimeRangeAt(start, end, now time.Time) (bool, time.Duration) {
// Now falls within range.
// For example 11:00-22:00 at 13:00
if start.Before(now) && end.After(now) {
return true, end.Sub(now)
}
// Range crosses midnight.
if start.After(end) {
// Now is after midnight, i.e. still ahead of today's `end`.
// For example 22:00-06:00 at 05:00.
if end.After(now) {
return true, end.Sub(now)
}
// Now is before midnight: the range ends tomorrow, so move `end`
// one day forward before comparing.
// For example 22:00-06:00 at 23:30.
end = end.AddDate(0, 0, 1)
if start.Before(now) && end.After(now) {
return true, end.Sub(now)
}
}
return false, time.Duration(0)
+317 -91
View File
@@ -1,4 +1,4 @@
// Copyright 2019-2025 The NATS Authors
// Copyright 2019-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@@ -27,7 +27,6 @@ import (
"net/url"
"os"
"path"
"reflect"
"regexp"
"runtime"
"strconv"
@@ -115,21 +114,24 @@ type leafNodeCfg struct {
perms *Permissions
connDelay time.Duration // Delay before a connect, could be used while detecting loop condition, etc..
jsMigrateTimer *time.Timer
quitCh chan struct{}
removed bool
connInProgress bool
}
// Check to see if this is a solicited leafnode. We do special processing for solicited.
func (c *client) isSolicitedLeafNode() bool {
return c.kind == LEAF && c.leaf.remote != nil
return c.kind == LEAF && c.leaf != nil && c.leaf.remote != nil
}
// Returns true if this is a solicited leafnode and is not configured to be treated as a hub or a receiving
// connection leafnode where the other side has declared itself to be the hub.
func (c *client) isSpokeLeafNode() bool {
return c.kind == LEAF && c.leaf.isSpoke
return c.kind == LEAF && c.leaf != nil && c.leaf.isSpoke
}
func (c *client) isHubLeafNode() bool {
return c.kind == LEAF && !c.leaf.isSpoke
return c.kind == LEAF && c.leaf != nil && !c.leaf.isSpoke
}
func (c *client) isIsolatedLeafNode() bool {
@@ -137,7 +139,7 @@ func (c *client) isIsolatedLeafNode() bool {
// group name here, which the hub and/or leaf could provide, so that we
// can isolate away certain LNs but not others on an opt-in basis. For
// now we will just isolate all LN interest until then.
return c.kind == LEAF && c.leaf.isolated
return c.kind == LEAF && c.leaf != nil && c.leaf.isolated
}
// This will spin up go routines to solicit the remote leaf node connections.
@@ -152,7 +154,10 @@ func (s *Server) solicitLeafNodeRemotes(remotes []*RemoteLeafOpts) {
remote := newLeafNodeCfg(r)
creds := remote.Credentials
accName := remote.LocalAccount
s.leafRemoteCfgs = append(s.leafRemoteCfgs, remote)
if s.leafRemoteCfgs == nil {
s.leafRemoteCfgs = make(map[*leafNodeCfg]struct{})
}
s.leafRemoteCfgs[remote] = struct{}{}
// Print notice if
if isSysAccRemote {
if len(remote.DenyExports) > 0 {
@@ -192,34 +197,30 @@ func (s *Server) solicitLeafNodeRemotes(remotes []*RemoteLeafOpts) {
// configuration required for configuration reload.
remote := addRemote(r, r.LocalAccount == sysAccName)
if !r.Disabled {
s.startGoRoutine(func() { s.connectToRemoteLeafNode(remote, true) })
s.connectToRemoteLeafNodeAsynchronously(remote, true)
}
}
}
// remoteLeafNodeStillValid reports whether the given remote may still be
// connected to: it must not be disabled, and the server's current options
// must contain a leafnode remote whose URL list is deeply equal to the
// remote's own.
func (s *Server) remoteLeafNodeStillValid(remote *leafNodeCfg) bool {
	if remote.Disabled {
		return false
	}
	configured := s.getOpts().LeafNode.Remotes
	for i := range configured {
		// FIXME(dlc) - What about auth changes?
		if reflect.DeepEqual(configured[i].URLs, remote.URLs) {
			return true
		}
	}
	return false
}
// Ensure that leafnode is properly configured.
func validateLeafNode(o *Options) error {
if err := validateLeafNodeAuthOptions(o); err != nil {
return err
}
// Users can bind to any local account, if its empty we will assume the $G account.
for _, r := range o.LeafNode.Remotes {
if r.LocalAccount == _EMPTY_ {
r.LocalAccount = globalAccountName
if len(o.LeafNode.Remotes) > 0 {
names := make(map[string]struct{})
// Check for duplicate remotes. Also, users can bind to any local account;
// if it's empty we will assume the $G account.
for _, r := range o.LeafNode.Remotes {
if r.LocalAccount == _EMPTY_ {
r.LocalAccount = globalAccountName
}
rn := r.name()
if _, dup := names[rn]; dup {
return fmt.Errorf("duplicate remote %s", r.safeName())
}
names[rn] = struct{}{}
}
}
@@ -428,42 +429,25 @@ func validateLeafNodeProxyOptions(remote *RemoteLeafOpts) ([]string, error) {
return warnings, nil
}
// updateRemoteLeafNodesTLSConfig refreshes the TLS settings of the remote
// leafnode configurations after a config reload. Remotes in opts are paired
// by position with s.leafRemoteCfgs; for every pair whose option has a TLS
// config, the stored config receives a clone of it along with the
// TLSHandshakeFirst setting. The server read lock is held throughout.
func (s *Server) updateRemoteLeafNodesTLSConfig(opts *Options) {
	remotes := opts.LeafNode.Remotes
	if len(remotes) == 0 {
		return
	}
	s.mu.RLock()
	defer s.mu.RUnlock()

	for i, ro := range remotes {
		// Changes in the list of remote leaf nodes is not supported.
		// However, make sure that we don't go over the arrays.
		if i >= len(s.leafRemoteCfgs) {
			break
		}
		if ro.TLSConfig == nil {
			continue
		}
		cfg := s.leafRemoteCfgs[i]
		cfg.Lock()
		cfg.TLSConfig = ro.TLSConfig.Clone()
		cfg.TLSHandshakeFirst = ro.TLSHandshakeFirst
		cfg.Unlock()
	}
}
// Wait for the configured reconnect interval before attempting to connect
// again to the remote leafnode.
func (s *Server) reConnectToRemoteLeafNode(remote *leafNodeCfg) {
clearInProgress := true
defer func() {
s.grWG.Done()
if clearInProgress {
remote.setConnectInProgress(false)
}
}()
delay := s.getOpts().LeafNode.ReconnectInterval
select {
case <-time.After(delay):
case <-remote.quitCh:
return
case <-s.quitCh:
s.grWG.Done()
return
}
s.connectToRemoteLeafNode(remote, false)
clearInProgress = !connectToRemoteLeafNode(s, remote, false)
}
// Creates a leafNodeCfg object that wraps the RemoteLeafOpts.
@@ -471,6 +455,7 @@ func newLeafNodeCfg(remote *RemoteLeafOpts) *leafNodeCfg {
cfg := &leafNodeCfg{
RemoteLeafOpts: remote,
urls: make([]*url.URL, 0, len(remote.URLs)),
quitCh: make(chan struct{}, 1),
}
if len(remote.DenyExports) > 0 || len(remote.DenyImports) > 0 {
perms := &Permissions{}
@@ -506,6 +491,53 @@ func newLeafNodeCfg(remote *RemoteLeafOpts) *leafNodeCfg {
return cfg
}
// notifyQuitChannel posts a signal on cfg.quitCh when the channel can take
// it immediately and silently drops the signal otherwise, so the caller
// never blocks.
// No lock is needed to invoke this function.
func (cfg *leafNodeCfg) notifyQuitChannel() {
	select {
	case cfg.quitCh <- struct{}{}:
		// Signal handed to the channel.
	default:
		// Channel cannot accept the signal right now; nothing to do.
	}
}
// setConnectInProgress records, under the configuration lock, whether a
// connect attempt is under way for this remote. Whatever the new value is,
// a signal waiting on the quit channel is discarded first.
func (cfg *leafNodeCfg) setConnectInProgress(inProgress bool) {
	cfg.Lock()
	// Non-blocking receive: drop a pending quit signal if there is one.
	select {
	case <-cfg.quitCh:
	default:
	}
	cfg.connInProgress = inProgress
	cfg.Unlock()
}
// isConnectInProgress reads, under the configuration read lock, the flag
// maintained by setConnectInProgress: `true` while this remote is in the
// middle of a connect, `false` otherwise.
func (cfg *leafNodeCfg) isConnectInProgress() bool {
	cfg.RLock()
	inProgress := cfg.connInProgress
	cfg.RUnlock()
	return inProgress
}
// markAsRemoved flags this remote as removed from the configuration and
// signals its quit channel. Only the first call has an effect; later calls
// find the flag already set and do nothing.
func (cfg *leafNodeCfg) markAsRemoved() {
	cfg.Lock()
	defer cfg.Unlock()
	// This function should be invoked only once, but protect.
	if !cfg.removed {
		cfg.removed = true
		cfg.notifyQuitChannel()
	}
}
// stillValid reports, under the configuration read lock, whether this
// remote is neither disabled nor marked as removed.
func (cfg *leafNodeCfg) stillValid() bool {
	cfg.RLock()
	defer cfg.RUnlock()
	invalid := cfg.Disabled || cfg.removed
	return !invalid
}
// Will pick an URL from the list of available URLs.
func (cfg *leafNodeCfg) pickNextURL() *url.URL {
cfg.Lock()
@@ -622,12 +654,26 @@ func establishHTTPProxyTunnel(proxyURL, targetHost string, timeout time.Duration
return conn, nil
}
func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool) {
defer s.grWG.Done()
// connectToRemoteLeafNodeAsynchronously marks the remote as having a connect
// in progress and then hands the actual connect to a go routine started
// through s.startGoRoutine. When connectToRemoteLeafNode returns false the
// in-progress flag is cleared again by that go routine.
func (s *Server) connectToRemoteLeafNodeAsynchronously(remote *leafNodeCfg, firstConnect bool) {
	connect := func() {
		defer s.grWG.Done()
		if created := connectToRemoteLeafNode(s, remote, firstConnect); created {
			return
		}
		remote.setConnectInProgress(false)
	}

	remote.setConnectInProgress(true)
	s.startGoRoutine(connect)
}
// Connect to a remote leaf node. Should only be invoked from
// `s.connectToRemoteLeafNodeAsynchronously()` or `s.reConnectToRemoteLeafNode()`.
// Returns `true` if this function invoked `s.createLeafNode()`, false otherwise.
func connectToRemoteLeafNode(s *Server, remote *leafNodeCfg, firstConnect bool) bool {
if remote == nil || len(remote.URLs) == 0 {
s.Debugf("Empty remote leafnode definition, nothing to connect")
return
return false
}
opts := s.getOpts()
@@ -651,8 +697,10 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool)
if connDelay := remote.getConnectDelay(); connDelay > 0 {
select {
case <-time.After(connDelay):
case <-remote.quitCh:
return false
case <-s.quitCh:
return
return false
}
remote.setConnectDelay(0)
}
@@ -676,7 +724,14 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool)
attempts := 0
for s.isRunning() && s.remoteLeafNodeStillValid(remote) {
// In case the migrate timer was created but not canceled, do it when
// this function exits. Note that the timer would not be created if
// `jetstreamMigrateDelay == 0`.
if jetstreamMigrateDelay > 0 {
defer remote.cancelMigrateTimer()
}
for s.isRunning() && remote.stillValid() {
rURL := remote.pickNextURL()
url, err := s.getRandomIP(resolver, rURL.Host, nil)
if err == nil {
@@ -729,8 +784,9 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool)
remote.Unlock()
select {
case <-s.quitCh:
remote.cancelMigrateTimer()
return
return false
case <-remote.quitCh:
return false
case <-time.After(delay):
// Check if we should migrate any JetStream assets immediately while this remote is down.
// This will be used if JetStreamClusterMigrateDelay was not set
@@ -741,9 +797,11 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool)
}
}
remote.cancelMigrateTimer()
if !s.remoteLeafNodeStillValid(remote) {
// We can check here, but really we will have to check again when the server
// is about to add to the `s.leafs` map later in the process.
if !remote.stillValid() {
conn.Close()
return
return false
}
// We have a connection here to a remote server.
@@ -753,8 +811,10 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool)
// Clear any observer states if we had them.
s.clearObserverState(remote)
return
return true
}
return false
}
func (cfg *leafNodeCfg) cancelMigrateTimer() {
@@ -854,6 +914,8 @@ func (s *Server) isLeafConnectDisabled() bool {
// their remote connections did not have a tls{} block).
// We now save the host name regardless in case the remote returns an INFO indicating
// that TLS is required.
//
// Lock held on entry.
func (cfg *leafNodeCfg) saveTLSHostname(u *url.URL) {
if cfg.tlsName == _EMPTY_ && net.ParseIP(u.Hostname()) == nil {
cfg.tlsName = u.Hostname()
@@ -862,6 +924,8 @@ func (cfg *leafNodeCfg) saveTLSHostname(u *url.URL) {
// Save off the username/password for when we connect using a bare URL
// that we get from the INFO protocol.
//
// Lock held on entry.
func (cfg *leafNodeCfg) saveUserPassword(u *url.URL) {
if cfg.username == _EMPTY_ && u.User != nil {
cfg.username = u.User.Username()
@@ -1459,6 +1523,15 @@ func (c *client) processLeafnodeInfo(info *Info) {
// Check for compression, unless already done.
if firstINFO && !c.flags.isSet(compressionNegotiated) {
// A solicited leafnode connection must first receive a leafnode INFO.
// Classify wrong-port connections before any leaf-specific negotiation.
if didSolicit && (info.CID == 0 || info.LeafNodeURLs == nil) {
c.mu.Unlock()
c.Errorf(ErrConnectedToWrongPort.Error())
c.closeConnection(WrongPort)
return
}
// Prevent from getting back here.
c.flags.set(compressionNegotiated)
@@ -1536,15 +1609,6 @@ func (c *client) processLeafnodeInfo(info *Info) {
// ** Not if "no advertise" is enabled.
// *** Not if leafnode's "no advertise" is enabled.
//
// As seen from above, a solicited LeafNode connection should receive
// from the remote server an INFO with CID and LeafNodeURLs. Anything
// else should be considered an attempt to connect to a wrong port.
if didSolicit && (info.CID == 0 || info.LeafNodeURLs == nil) {
c.mu.Unlock()
c.Errorf(ErrConnectedToWrongPort.Error())
c.closeConnection(WrongPort)
return
}
// Reject a cluster that contains spaces.
if info.Cluster != _EMPTY_ && strings.Contains(info.Cluster, " ") {
c.mu.Unlock()
@@ -1552,8 +1616,12 @@ func (c *client) processLeafnodeInfo(info *Info) {
c.closeConnection(ProtocolViolation)
return
}
// Capture a nonce here.
c.nonce = []byte(info.Nonce)
// For solicited outbound leaf connections, capture the remote's nonce.
// For inbound leaf connections, keep using the server-issued nonce that
// was sent in our initial INFO and must be signed in CONNECT.
if didSolicit {
c.nonce = []byte(info.Nonce)
}
if info.TLSRequired && didSolicit {
remote.TLS = true
}
@@ -1578,15 +1646,17 @@ func (c *client) processLeafnodeInfo(info *Info) {
}
// For both initial INFO and async INFO protocols, Possibly
// update our list of remote leafnode URLs we can connect to.
if didSolicit && (len(info.LeafNodeURLs) > 0 || len(info.WSConnectURLs) > 0) {
// update our list of remote leafnode URLs we can connect to,
// unless we are instructed not to.
if didSolicit && !remote.IgnoreDiscoveredServers &&
(len(info.LeafNodeURLs) > 0 || len(info.WSConnectURLs) > 0) {
// Consider the incoming array as the most up-to-date
// representation of the remote cluster's list of URLs.
c.updateLeafNodeURLs(info)
}
// Check to see if we have permissions updates here.
if info.Import != nil || info.Export != nil {
// Only solicited leafnode connections trust permission updates from INFO.
if didSolicit && (info.Import != nil || info.Export != nil) {
perms := &Permissions{
Publish: info.Export,
Subscribe: info.Import,
@@ -1623,6 +1693,12 @@ func (c *client) processLeafnodeInfo(info *Info) {
// Check if we have the remote account information and if so make sure it's stored.
if info.RemoteAccount != _EMPTY_ {
if c.acc == nil {
c.mu.Unlock()
c.sendErr("Authorization Violation")
c.closeConnection(ProtocolViolation)
return
}
s.leafRemoteAccounts.Store(c.acc.Name, info.RemoteAccount)
}
c.mu.Unlock()
@@ -1807,7 +1883,7 @@ func (s *Server) setLeafNodeInfoHostPortAndIP() error {
// (this solves the stale connection situation). An error is returned to help the
// remote detect the misconfiguration when the duplicate is the result of that
// misconfiguration.
func (s *Server) addLeafNodeConnection(c *client, srvName, clusterName string, checkForDup bool) {
func (s *Server) addLeafNodeConnection(c *client, srvName, clusterName string, checkForDup bool) bool {
var accName string
c.mu.Lock()
cid := c.cid
@@ -1819,7 +1895,8 @@ func (s *Server) addLeafNodeConnection(c *client, srvName, clusterName string, c
mySrvName := c.leaf.remoteServer
remoteAccName := c.leaf.remoteAccName
myClustName := c.leaf.remoteCluster
solicited := c.leaf.remote != nil
remote := c.leaf.remote
solicited := remote != nil
c.mu.Unlock()
var old *client
@@ -1843,6 +1920,23 @@ func (s *Server) addLeafNodeConnection(c *client, srvName, clusterName string, c
}
}
}
// Now that we are under the server lock and before adding it to the map,
// for a solicited leaf, we need to make sure that it has not been removed
// from the config or disabled.
if solicited {
// If no longer valid, do not add to the server map. The connection
// should have been marked so that it can't reconnect. When the caller
// calls closeConnection(), cleanup (including clearing the connect-
// in-progress flag) will occur at the appropriate time.
if !remote.stillValid() {
// Prevent reconnect in case it was not yet done.
c.setNoReconnect()
s.mu.Unlock()
s.removeFromTempClients(cid)
return false
}
remote.setConnectInProgress(false)
}
// Store new connection in the map
s.leafs[cid] = c
s.mu.Unlock()
@@ -1891,7 +1985,7 @@ func (s *Server) addLeafNodeConnection(c *client, srvName, clusterName string, c
} else if domain, ok := opts.JsAccDefaultDomain[accName]; ok && domain == _EMPTY_ {
// for backwards compatibility with old setups that do not have a domain name set
c.Debugf("Skipping deny %q for account %q due to default domain", jsAllAPI, accName)
return
return true
}
}
@@ -1969,9 +2063,11 @@ func (s *Server) addLeafNodeConnection(c *client, srvName, clusterName string, c
c.Debugf("Adding deny %q for outgoing messages to account %q", src, accName)
}
}
return true
}
func (s *Server) removeLeafNodeConnection(c *client) {
s.mu.Lock()
c.mu.Lock()
cid := c.cid
if c.leaf != nil {
@@ -1984,10 +2080,18 @@ func (s *Server) removeLeafNodeConnection(c *client) {
// We need to set this to nil for GC to release the connection
c.leaf.gwSub = nil
}
if remote := c.leaf.remote; remote != nil {
// If "noReconnect" is true, then we won't attempt to reconnect, so
// we will clear the "connect-in-progress" flag. However, if we can
// reconnect, then we should set "connect-in-progress" to true while
// we are under the server/client lock. The go routine that performs
// the reconnect will be started later and there would be a gap with
// the wrong flag value otherwise.
remote.setConnectInProgress(!c.flags.isSet(noReconnect))
}
}
proxyKey := c.proxyKey
c.mu.Unlock()
s.mu.Lock()
delete(s.leafs, cid)
if proxyKey != _EMPTY_ {
s.removeProxiedConn(proxyKey, cid)
@@ -2154,6 +2258,13 @@ func (c *client) processLeafNodeConnect(s *Server, arg []byte, lang string) erro
acc := c.acc
c.mu.Unlock()
// If the account is not set (e.g. connection was closed due to auth
// timeout while still being processed), bail out to avoid a panic.
if acc == nil {
c.closeConnection(MissingAccount)
return ErrMissingAccount
}
// Register the cluster, even if empty, as long as we are acting as a hub.
if !proto.Hub {
acc.registerLeafNodeCluster(proto.Cluster)
@@ -2999,6 +3110,11 @@ func (c *client) processLeafHeaderMsgArgs(arg []byte) error {
if c.pa.hdr > c.pa.size {
return fmt.Errorf("processLeafHeaderMsgArgs Header Size larger then TotalSize: '%s'", arg)
}
maxPayload := atomic.LoadInt32(&c.mpay)
if maxPayload != jwt.NoLimit && int64(c.pa.size) > int64(maxPayload) {
c.maxPayloadViolation(c.pa.size, maxPayload)
return ErrMaxPayload
}
// Common ones processed after check for arg length
c.pa.subject = args[0]
@@ -3068,6 +3184,11 @@ func (c *client) processLeafMsgArgs(arg []byte) error {
if c.pa.size < 0 {
return fmt.Errorf("processLeafMsgArgs Bad or Missing Size: '%s'", args)
}
maxPayload := atomic.LoadInt32(&c.mpay)
if maxPayload != jwt.NoLimit && int64(c.pa.size) > int64(maxPayload) {
c.maxPayloadViolation(c.pa.size, maxPayload)
return ErrMaxPayload
}
// Common ones processed after check for arg length
c.pa.subject = args[0]
@@ -3089,6 +3210,12 @@ func (c *client) processInboundLeafMsg(msg []byte) {
return
}
// Check that leaf messages respect the subject permissions.
if c.perms != nil && !c.leafMsgAllowed() {
c.leafPubPermViolation(c.pa.subject)
return
}
// Match the subscriptions. We will use our own L1 map if
// it's still valid, avoiding contention on the shared sublist.
var r *SublistResult
@@ -3150,12 +3277,102 @@ func (c *client) processInboundLeafMsg(msg []byte) {
}
}
// leafMsgAllowed reports whether the inbound leaf message currently being
// parsed is permitted by this connection's permissions. On the hub side
// this enforces what the remote leaf may publish; on the spoke side it
// enforces import restrictions such as deny_imports.
// The client lock is acquired inside this function.
func (c *client) leafMsgAllowed() bool {
	// Mappings rewrite c.pa.subject to the internal destination, so when a
	// mapping was applied the leaf ACLs must be checked against the original
	// wire subject from the remote side instead.
	onWire := c.pa.mapped
	if len(onWire) == 0 {
		onWire = c.pa.subject
	}

	// The permission check is done without any gateway routing prefix.
	stripped, gwRouted := getGWRoutedSubjectOrSelf(onWire)

	// Service-import replies (_R_) and JS ack subjects ($JS.ACK.) are
	// internal routing subjects forwarded via LS+ without permission checks.
	if isServiceReply(stripped) || isJSAckSubject(stripped) {
		return true
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	var permitted bool
	if c.isSpokeLeafNode() {
		// Gateway routed replies are forwarded without permission checks.
		permitted = gwRouted || c.leafReceiveAllowed(stripped)
	} else {
		permitted = c.leafSendAllowed(stripped)
	}

	// Fall back to tracked reply permissions (allow_responses). The
	// pre-strip subject is used because deliverMsg tracks replies under the
	// original form, which includes the GW routing prefix for routed requests.
	return permitted || c.responseAllowed(bytesToString(onWire))
}
// leafReceiveAllowed reports whether the leaf side ACLs allow importing the
// given subject, based on the permissions received over INFO and any local
// deny_imports. The decision is delegated to c.canSubscribe.
// Lock must be held.
func (c *client) leafReceiveAllowed(subject []byte) bool {
	subj := bytesToString(subject)
	return c.canSubscribe(subj)
}
// leafSendAllowed reports whether the hub side ACLs allow the remote leaf
// to send the given subject. With no export ACL at all everything is
// allowed. Otherwise the subject must match an entry of the allow list
// (when one is set and the subject does not carry the MQTT prefix) and must
// not match any entry of the deny list.
// Lock must be held.
func (c *client) leafSendAllowed(bsubject []byte) bool {
	// Use the original export ACL captured for this accepted leaf.
	// The live perms also contain additional JetStream denies used by
	// the normal forwarding path, and applying them here would reject
	// legitimate inbound JS API requests.
	subject := bytesToString(bsubject)
	export := c.opts.Export
	if export == nil || (export.Allow == nil && export.Deny == nil) {
		return true
	}

	// Allow list: skipped entirely for subjects with the MQTT prefix.
	if export.Allow != nil && !strings.HasPrefix(subject, mqttPrefix) {
		matched := false
		for _, pattern := range export.Allow {
			if matchLiteral(subject, pattern) {
				matched = true
				break
			}
		}
		if !matched {
			return false
		}
	}

	// Deny list: a single match rejects the subject.
	for _, pattern := range export.Deny {
		if matchLiteral(subject, pattern) {
			return false
		}
	}
	return true
}
// leafSubPermViolation reports a subscription permission violation for subj
// by calling leafPermViolation with its first argument set to false.
// See leafPermViolation() for details.
func (c *client) leafSubPermViolation(subj []byte) {
	const pub = false
	c.leafPermViolation(pub, subj)
}
// leafPubPermViolation reports a publish permission violation for subj by
// calling leafPermViolation with its first argument set to true.
// See leafPermViolation() for details.
func (c *client) leafPubPermViolation(subj []byte) {
	const pub = true
	c.leafPermViolation(pub, subj)
}
// Common function to process publish or subscribe leafnode permission violation.
// Sends the permission violation error to the remote, logs it and closes the connection.
// If this is from a server soliciting, the reconnection will be delayed.
@@ -3433,6 +3650,12 @@ func (s *Server) leafNodeFinishConnectProcess(c *client) {
return
}
remote := c.leaf.remote
if remote == nil || c.acc == nil {
c.mu.Unlock()
c.sendErr("Authorization Violation")
c.closeConnection(ProtocolViolation)
return
}
// Check if we will need to send the system connect event.
remote.RLock()
sendSysConnectEvent := remote.Hub
@@ -3457,19 +3680,22 @@ func (s *Server) leafNodeFinishConnectProcess(c *client) {
c.closeConnection(ProtocolViolation)
return
}
s.addLeafNodeConnection(c, _EMPTY_, _EMPTY_, false)
if !s.addLeafNodeConnection(c, _EMPTY_, _EMPTY_, false) {
// Was not added, could be because the remote configuration has been removed.
c.closeConnection(ClientClosed)
return
}
s.initLeafNodeSmapAndSendSubs(c)
if sendSysConnectEvent {
s.sendLeafNodeConnect(acc)
}
s.accountConnectEvent(c)
// The above functions are not atomically under the client
// lock doing those operations. It is possible - since we
// have started the read/write loops - that the connection
// is closed before or in between. This would leave the
// closed LN connection possible registered with the account
// and/or the server's leafs map. So check if connection
// is closed, and if so, manually cleanup.
// The above functions are not running under the client lock, so it is
// possible that between the time we have started the read/write loops
// and now, that the connection was closed. This would leave the closed
// LN connection possibly registered with the account and/or the server's
// leafs map. So check if connection is closed, and if so, manually cleanup.
c.mu.Lock()
closed := c.isClosed()
if !closed {
+8
View File
@@ -227,6 +227,14 @@ func (s *Server) rateLimitFormatWarnf(format string, v ...any) {
s.Warnf("%s", statement)
}
// RateLimitErrorf formats the statement and logs it at error level, but only
// the first time that exact formatted statement is recorded in the server's
// rate limit logging map. Repeats of an already recorded statement are dropped.
func (s *Server) RateLimitErrorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	// LoadOrStore reports whether the statement was already present; only a
	// newly stored statement gets logged.
	_, seen := s.rateLimitLogging.LoadOrStore(msg, time.Now())
	if !seen {
		s.Errorf("%s", msg)
	}
}
func (s *Server) RateLimitWarnf(format string, v ...any) {
statement := fmt.Sprintf(format, v...)
if _, loaded := s.rateLimitLogging.LoadOrStore(statement, time.Now()); loaded {
+137 -67
View File
@@ -184,7 +184,7 @@ func (ms *memStore) recoverMsgSchedulingState() {
if len(sm.hdr) == 0 {
continue
}
if schedule, ok := getMessageSchedule(sm.hdr); ok && !schedule.IsZero() {
if schedule, apiErr := nextMessageSchedule(sm.hdr, sm.ts); apiErr == nil && !schedule.IsZero() {
ms.scheduling.init(seq, sm.subj, schedule.UnixNano())
}
}
@@ -192,7 +192,7 @@ func (ms *memStore) recoverMsgSchedulingState() {
// Stores a raw message with expected sequence number and timestamp.
// Lock should be held.
func (ms *memStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts, ttl int64) error {
func (ms *memStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts, ttl int64, discardNewCheck bool) error {
if ms.msgs == nil {
return ErrStoreClosed
}
@@ -208,31 +208,31 @@ func (ms *memStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts, tt
}
// Check if we are discarding new messages when we reach the limit.
if ms.cfg.Discard == DiscardNew {
if asl && ms.cfg.DiscardNewPer {
// If we are clustered, we do the enforcement above and should not disqualify
// the message here since it could cause replicas to drift.
if discardNewCheck && ms.cfg.Discard == DiscardNew {
// Allow rollup messages through since they will purge old
// messages for the subject after storing, restoring the limit.
if asl && ms.cfg.DiscardNewPer && len(sliceHeader(JSMsgRollup, hdr)) == 0 {
return ErrMaxMsgsPerSubject
}
// If we are discard new and limits policy and clustered, we do the enforcement
// above and should not disqualify the message here since it could cause replicas to drift.
if ms.cfg.Retention == LimitsPolicy || ms.cfg.Replicas == 1 {
if ms.cfg.MaxMsgs > 0 && ms.state.Msgs >= uint64(ms.cfg.MaxMsgs) {
// If we are tracking max messages per subject and are at the limit we will replace, so this is ok.
if !asl {
return ErrMaxMsgs
}
if ms.cfg.MaxMsgs > 0 && ms.state.Msgs >= uint64(ms.cfg.MaxMsgs) {
// If we are tracking max messages per subject and are at the limit we will replace, so this is ok.
if !asl {
return ErrMaxMsgs
}
if ms.cfg.MaxBytes > 0 && ms.state.Bytes+memStoreMsgSize(subj, hdr, msg) >= uint64(ms.cfg.MaxBytes) {
if !asl {
return ErrMaxBytes
}
// If we are here we are at a subject maximum, need to determine if dropping last message gives us enough room.
if ss.firstNeedsUpdate || ss.lastNeedsUpdate {
ms.recalculateForSubj(subj, ss)
}
sm, ok := ms.msgs[ss.First]
if !ok || memStoreMsgSize(sm.subj, sm.hdr, sm.msg) < memStoreMsgSize(subj, hdr, msg) {
return ErrMaxBytes
}
}
if ms.cfg.MaxBytes > 0 && ms.state.Bytes+memStoreMsgSize(subj, hdr, msg) > uint64(ms.cfg.MaxBytes) {
if !asl {
return ErrMaxBytes
}
// If we are here we are at a subject maximum, need to determine if dropping last message gives us enough room.
if ss.firstNeedsUpdate || ss.lastNeedsUpdate {
ms.recalculateForSubj(subj, ss)
}
sm, ok := ms.msgs[ss.First]
if !ok || memStoreMsgSize(sm.subj, sm.hdr, sm.msg) < memStoreMsgSize(subj, hdr, msg) {
return ErrMaxBytes
}
}
}
@@ -309,20 +309,28 @@ func (ms *memStore) storeRawMsg(subj string, hdr, msg []byte, seq uint64, ts, tt
// Message scheduling.
if ms.scheduling != nil {
if schedule, ok := getMessageSchedule(hdr); ok && !schedule.IsZero() {
if schedule, apiErr := nextMessageSchedule(hdr, ts); apiErr == nil && !schedule.IsZero() {
ms.scheduling.add(seq, subj, schedule.UnixNano())
} else {
} else if getMessageScheduler(hdr) == _EMPTY_ {
ms.scheduling.removeSubject(subj)
}
// Check for a repeating schedule and update such that it triggers again.
if scheduleNext := bytesToString(sliceHeader(JSScheduleNext, hdr)); scheduleNext != _EMPTY_ && scheduleNext != JSScheduleNextPurge {
scheduler := getMessageScheduler(hdr)
if next, err := time.Parse(time.RFC3339Nano, scheduleNext); err == nil && scheduler != _EMPTY_ {
ms.scheduling.update(scheduler, next.UnixNano())
}
}
}
return nil
}
// StoreRawMsg stores a raw message with expected sequence number and timestamp.
func (ms *memStore) StoreRawMsg(subj string, hdr, msg []byte, seq uint64, ts, ttl int64) error {
func (ms *memStore) StoreRawMsg(subj string, hdr, msg []byte, seq uint64, ts, ttl int64, discardNewCheck bool) error {
ms.mu.Lock()
err := ms.storeRawMsg(subj, hdr, msg, seq, ts, ttl)
err := ms.storeRawMsg(subj, hdr, msg, seq, ts, ttl, discardNewCheck)
cb := ms.scb
// Check if first message timestamp requires expiry
// sooner than initial replica expiry timer set to MaxAge when initializing.
@@ -344,7 +352,8 @@ func (ms *memStore) StoreRawMsg(subj string, hdr, msg []byte, seq uint64, ts, tt
func (ms *memStore) StoreMsg(subj string, hdr, msg []byte, ttl int64) (uint64, int64, error) {
ms.mu.Lock()
seq, ts := ms.state.LastSeq+1, time.Now().UnixNano()
err := ms.storeRawMsg(subj, hdr, msg, seq, ts, ttl)
// This is called for a R1 with no expected sequence number, so perform DiscardNew checks on the store-level.
err := ms.storeRawMsg(subj, hdr, msg, seq, ts, ttl, true)
cb := ms.scb
ms.mu.Unlock()
@@ -414,8 +423,9 @@ func (ms *memStore) SkipMsgs(seq uint64, num uint64) error {
}
// FlushAllPending flushes all data that was still pending to be written.
func (ms *memStore) FlushAllPending() {
func (ms *memStore) FlushAllPending() error {
// Noop, in-memory store doesn't use async applying.
return nil
}
// RegisterStorageUpdates registers a callback for updates to storage changes.
@@ -521,13 +531,13 @@ loop:
}
// FilteredState will return the SimpleState associated with the filtered subject and a proposed starting sequence.
func (ms *memStore) FilteredState(sseq uint64, subj string) SimpleState {
func (ms *memStore) FilteredState(sseq uint64, subj string) (SimpleState, error) {
// This needs to be a write lock, as filteredStateLocked can
// mutate the per-subject state.
ms.mu.Lock()
defer ms.mu.Unlock()
return ms.filteredStateLocked(sseq, subj, false)
return ms.filteredStateLocked(sseq, subj, false), nil
}
func (ms *memStore) filteredStateLocked(sseq uint64, filter string, lastPerSubject bool) SimpleState {
@@ -1391,10 +1401,16 @@ func (ms *memStore) runMsgScheduling() {
}
ms.scheduling.running = true
scheduledMsgs := ms.scheduling.getScheduledMessages(func(seq uint64, smv *StoreMsg) *StoreMsg {
sm, _ := ms.loadMsgLocked(seq, smv, false)
return sm
})
scheduledMsgs := ms.scheduling.getScheduledMessages(
func(seq uint64, smv *StoreMsg) *StoreMsg {
sm, _ := ms.loadMsgLocked(seq, smv, false)
return sm
},
func(subj string, smv *StoreMsg) *StoreMsg {
sm, _ := ms.loadLastLocked(subj, smv)
return sm
},
)
if len(scheduledMsgs) > 0 {
ms.mu.Unlock()
for _, msg := range scheduledMsgs {
@@ -1429,7 +1445,7 @@ func (ms *memStore) PurgeEx(subject string, sequence, keep uint64) (purged uint6
}
eq := compareFn(subject)
if ss := ms.FilteredState(1, subject); ss.Msgs > 0 {
if ss, _ := ms.FilteredState(1, subject); ss.Msgs > 0 {
if keep > 0 {
if keep >= ss.Msgs {
return 0, nil
@@ -1712,13 +1728,17 @@ func (ms *memStore) loadMsgLocked(seq uint64, smp *StoreMsg, needMSLock bool) (*
// LoadLastMsg will return the last message we have that matches a given subject.
// The subject can be a wildcard.
func (ms *memStore) LoadLastMsg(subject string, smp *StoreMsg) (*StoreMsg, error) {
var sm *StoreMsg
var ok bool
// This needs to be a write lock, as filteredStateLocked can
// mutate the per-subject state.
ms.mu.Lock()
defer ms.mu.Unlock()
return ms.loadLastLocked(subject, smp)
}
// Lock should be held.
func (ms *memStore) loadLastLocked(subject string, smp *StoreMsg) (*StoreMsg, error) {
var sm *StoreMsg
var ok bool
if subject == _EMPTY_ || subject == fwcs {
sm, ok = ms.msgs[ms.state.LastSeq]
@@ -1907,31 +1927,41 @@ func (ms *memStore) loadNextMsgLocked(filter string, wc bool, start uint64, smp
return nil, ms.state.LastSeq, ErrStoreEOF
}
// Will load the next non-deleted msg starting at the start sequence and walking backwards.
func (ms *memStore) LoadPrevMsg(start uint64, smp *StoreMsg) (sm *StoreMsg, err error) {
// Will load the previous message matching the filter subject, starting at the start sequence and walking backwards.
func (ms *memStore) LoadPrevMsg(filter string, wc bool, start uint64, smp *StoreMsg) (sm *StoreMsg, skip uint64, err error) {
ms.mu.RLock()
defer ms.mu.RUnlock()
if ms.msgs == nil {
return nil, ErrStoreClosed
return nil, 0, ErrStoreClosed
}
if ms.state.Msgs == 0 || start < ms.state.FirstSeq {
return nil, ErrStoreEOF
return nil, ms.state.FirstSeq, ErrStoreEOF
}
if start > ms.state.LastSeq {
start = ms.state.LastSeq
}
if filter == _EMPTY_ {
filter = fwcs
wc = true
}
isAll := filter == fwcs
eq := subjectsEqual
if wc {
eq = matchLiteral
}
for seq := start; seq >= ms.state.FirstSeq; seq-- {
if sm, ok := ms.msgs[seq]; ok {
if sm, ok := ms.msgs[seq]; ok && (isAll || eq(sm.subj, filter)) {
if smp == nil {
smp = new(StoreMsg)
}
sm.copy(smp)
return smp, nil
return smp, seq, nil
}
}
return nil, ErrStoreEOF
return nil, ms.state.FirstSeq, ErrStoreEOF
}
// LoadPrevMsgMulti will find the previous message matching any entry in the sublist.
@@ -1965,7 +1995,7 @@ func (ms *memStore) LoadPrevMsgMulti(sl *gsl.SimpleSublist, start uint64, smp *S
return smp, nseq, nil
}
}
return nil, ms.state.LastSeq, ErrStoreEOF
return nil, ms.state.FirstSeq, ErrStoreEOF
}
// RemoveMsg will remove the message from this store.
@@ -2329,10 +2359,7 @@ func (ms *memStore) EncodedStreamState(failed uint64) ([]byte, error) {
b := buf[0:n]
if numDeleted > 0 {
buf, err := ms.dmap.Encode(nil)
if err != nil {
return nil, err
}
buf := ms.dmap.Encode(nil)
b = append(b, buf...)
}
@@ -2340,7 +2367,11 @@ func (ms *memStore) EncodedStreamState(failed uint64) ([]byte, error) {
}
// SyncDeleted will make sure this stream has same deleted state as dbs.
func (ms *memStore) SyncDeleted(dbs DeleteBlocks) {
func (ms *memStore) SyncDeleted(dbs DeleteBlocks) error {
if len(dbs) == 0 {
return nil
}
ms.mu.Lock()
defer ms.mu.Unlock()
@@ -2349,7 +2380,7 @@ func (ms *memStore) SyncDeleted(dbs DeleteBlocks) {
if len(dbs) == 1 {
min, max, num := ms.dmap.State()
if pmin, pmax, pnum := dbs[0].State(); pmin == min && pmax == max && pnum == num {
return
return nil
}
}
lseq := ms.state.LastSeq
@@ -2363,6 +2394,7 @@ func (ms *memStore) SyncDeleted(dbs DeleteBlocks) {
return true
})
}
return nil
}
func (o *consumerMemStore) Update(state *ConsumerState) error {
@@ -2410,10 +2442,51 @@ func (o *consumerMemStore) Update(state *ConsumerState) error {
return nil
}
// ForceUpdate replaces the delivered, ack floor, pending and redelivered
// parts of the store's state with copies taken from the given state.
//
// It returns an error, leaving the store untouched, when an ack floor is ahead
// of the matching delivered sequence or when a pending entry's stream sequence
// is not within (AckFloor.Stream, Delivered.Stream]. Empty pending or
// redelivered maps in the given state result in nil maps in the store.
func (o *consumerMemStore) ForceUpdate(state *ConsumerState) error {
	delivered, floor := state.Delivered, state.AckFloor

	// Sanity checks: an ack floor can never be ahead of what was delivered.
	switch {
	case floor.Consumer > delivered.Consumer:
		return fmt.Errorf("bad ack floor for consumer")
	case floor.Stream > delivered.Stream:
		return fmt.Errorf("bad ack floor for stream")
	}

	// Build private copies so the store does not share maps or entries
	// with the caller.
	var (
		pending     map[uint64]*Pending
		redelivered map[uint64]uint64
	)
	if n := len(state.Pending); n > 0 {
		pending = make(map[uint64]*Pending, n)
		for seq, p := range state.Pending {
			pending[seq] = &Pending{Sequence: p.Sequence, Timestamp: p.Timestamp}
			if seq <= floor.Stream || seq > delivered.Stream {
				return fmt.Errorf("bad pending entry, sequence [%d] out of range", seq)
			}
		}
	}
	if n := len(state.Redelivered); n > 0 {
		redelivered = make(map[uint64]uint64, n)
		for seq, count := range state.Redelivered {
			redelivered[seq] = count
		}
	}

	// Everything validated, swap in the new state under the lock.
	o.mu.Lock()
	defer o.mu.Unlock()

	o.state.Delivered = delivered
	o.state.AckFloor = floor
	o.state.Pending = pending
	o.state.Redelivered = redelivered

	return nil
}
// SetStarting sets our starting stream sequence: both the delivered and the
// ack floor stream sequences are moved to sseq. It always returns nil.
func (o *consumerMemStore) SetStarting(sseq uint64) error {
	o.mu.Lock()
	defer o.mu.Unlock()

	o.state.Delivered.Stream, o.state.AckFloor.Stream = sseq, sseq
	return nil
}
@@ -2432,6 +2505,14 @@ func (o *consumerMemStore) UpdateStarting(sseq uint64) {
}
}
// Reset clears all values in the store and then resets the starting sequence
// to sseq through SetStarting, whose result is returned.
// NOTE(review): the state is zeroed and the starting sequence is applied in
// two separate critical sections (the lock is released before SetStarting
// acquires it again), so a concurrent reader could briefly observe the zeroed
// state. Confirm that callers tolerate this.
func (o *consumerMemStore) Reset(sseq uint64) error {
o.mu.Lock()
o.state = ConsumerState{}
o.mu.Unlock()
return o.SetStarting(sseq)
}
// HasState returns if this store has a recorded state.
func (o *consumerMemStore) HasState() bool {
o.mu.Lock()
@@ -2524,8 +2605,8 @@ func (o *consumerMemStore) UpdateAcks(dseq, sseq uint64) error {
return ErrStoreMsgNotFound
}
// Check for AckAll here.
if o.cfg.AckPolicy == AckAll {
// Check for AckAll here (or AckFlowControl which functions like AckAll).
if o.cfg.AckPolicy == AckAll || o.cfg.AckPolicy == AckFlowControl {
sgap := sseq - o.state.AckFloor.Stream
o.state.AckFloor.Consumer = dseq
o.state.AckFloor.Stream = sseq
@@ -2675,14 +2756,3 @@ func (o *consumerMemStore) copyRedelivered() map[uint64]uint64 {
// Type returns the type of the underlying store.
func (o *consumerMemStore) Type() StorageType { return MemoryStorage }
// Templates
type templateMemStore struct{}
func newTemplateMemStore() *templateMemStore {
return &templateMemStore{}
}
// No-ops for memstore.
func (ts *templateMemStore) Store(t *streamTemplate) error { return nil }
func (ts *templateMemStore) Delete(t *streamTemplate) error { return nil }
+101 -61
View File
@@ -189,6 +189,17 @@ func newSubsList(client *client) []string {
return subs
}
// redactBearerJWT returns the empty string when the given user JWT decodes
// successfully into user claims flagged as a bearer token. Any other JWT,
// including one that fails to decode, is returned unchanged. An empty input
// yields an empty result.
func redactBearerJWT(userJWT string) string {
	if userJWT != _EMPTY_ {
		// Only a cleanly decoded bearer token gets redacted; everything
		// else is passed through as-is.
		claims, err := jwt.DecodeUserClaims(userJWT)
		if err != nil || claims == nil || !claims.BearerToken {
			return userJWT
		}
	}
	return _EMPTY_
}
// Connz returns a Connz struct containing information about connections.
func (s *Server) Connz(opts *ConnzOptions) (*Connz, error) {
var (
@@ -441,6 +452,7 @@ func (s *Server) Connz(opts *ConnzOptions) (*Connz, error) {
ci.NameTag = client.acc.getNameTag()
}
client.mu.Unlock()
ci.JWT = redactBearerJWT(ci.JWT)
pconns[i] = ci
i++
}
@@ -487,6 +499,7 @@ func (s *Server) Connz(opts *ConnzOptions) (*Connz, error) {
cc.NameTag = acc.getNameTag()
}
}
cc.JWT = redactBearerJWT(cc.JWT)
}
pconns[i] = &cc.ConnInfo
i++
@@ -1271,6 +1284,7 @@ type Varz struct {
ConfigDigest string `json:"config_digest"` // ConfigDigest is a calculated hash of the current configuration
Tags jwt.TagList `json:"tags,omitempty"` // Tags are the tags assigned to the server in configuration
Metadata map[string]string `json:"metadata,omitempty"` // Metadata is the metadata assigned to the server in configuration
FeatureFlags map[string]bool `json:"feature_flags,omitempty"` // FeatureFlags is the feature flags enabled/disabled in configuration
TrustedOperatorsJwt []string `json:"trusted_operators_jwt,omitempty"` // TrustedOperatorsJwt is the JWTs for all trusted operators
TrustedOperatorsClaim []*jwt.OperatorClaims `json:"trusted_operators_claim,omitempty"` // TrustedOperatorsClaim is the decoded claims for each trusted operator
SystemAccount string `json:"system_account,omitempty"` // SystemAccount is the name of the System account
@@ -1570,8 +1584,13 @@ func (s *Server) updateJszVarz(js *jetStream, v *JetStreamVarz, doConfig bool) {
v.Meta.Replicas = ci.Replicas
}
if ipq := s.jsAPIRoutedReqs; ipq != nil {
v.Meta.Pending = ipq.len()
v.Meta.PendingRequests = ipq.len()
}
if ipq := s.jsAPIRoutedInfoReqs; ipq != nil {
v.Meta.PendingInfos = ipq.len()
}
v.Meta.Pending = v.Meta.PendingRequests + v.Meta.PendingInfos
v.Meta.Snapshot = s.metaClusterSnapshotStats(js, mg)
}
}
}
@@ -1788,6 +1807,7 @@ func (s *Server) updateVarzConfigReloadableFields(v *Varz) {
v.ConfigDigest = opts.configDigest
v.Tags = opts.Tags
v.Metadata = opts.Metadata
v.FeatureFlags = opts.getMergedFeatureFlags()
// Update route URLs if applicable
if s.varzUpdateRouteURLs {
v.Cluster.URLs = urlsToStrings(opts.Routes)
@@ -3008,15 +3028,43 @@ type MetaSnapshotStats struct {
LastDuration time.Duration `json:"last_duration,omitempty"` // LastDuration is how long the last meta snapshot took
}
// metaClusterSnapshotStats builds the snapshot statistics for the meta group.
// The pending entries and size come from mg.Size(). When the JetStream cluster
// is set, the last snapshot time and duration are read atomically from it and
// filled in only if they are positive.
func (s *Server) metaClusterSnapshotStats(js *jetStream, mg RaftNode) *MetaSnapshotStats {
	stats := &MetaSnapshotStats{}
	stats.PendingEntries, stats.PendingSize = mg.Size()

	// Grab the cluster reference under the read lock, the counters
	// themselves are accessed atomically.
	js.mu.RLock()
	cc := js.cluster
	js.mu.RUnlock()
	if cc == nil {
		return stats
	}

	if last := atomic.LoadInt64(&cc.lastMetaSnapTime); last > 0 {
		stats.LastTime = time.Unix(0, last).UTC()
	}
	if took := atomic.LoadInt64(&cc.lastMetaSnapDuration); took > 0 {
		stats.LastDuration = time.Duration(took)
	}
	return stats
}
// MetaClusterInfo shows information about the meta group.
type MetaClusterInfo struct {
Name string `json:"name,omitempty"` // Name is the name of the cluster
Leader string `json:"leader,omitempty"` // Leader is the server name of the cluster leader
Peer string `json:"peer,omitempty"` // Peer is unique ID of the leader
Replicas []*PeerInfo `json:"replicas,omitempty"` // Replicas is a list of known peers
Size int `json:"cluster_size"` // Size is the known size of the cluster
Pending int `json:"pending"` // Pending is how many RAFT messages are not yet processed
Snapshot *MetaSnapshotStats `json:"snapshot"` // Snapshot contains meta snapshot statistics
Name string `json:"name,omitempty"` // Name is the name of the cluster
Leader string `json:"leader,omitempty"` // Leader is the server name of the cluster leader
Peer string `json:"peer,omitempty"` // Peer is unique ID of the leader
Replicas []*PeerInfo `json:"replicas,omitempty"` // Replicas is a list of known peers
Size int `json:"cluster_size"` // Size is the known size of the cluster
Pending int `json:"pending"` // Pending is how many RAFT messages are not yet processed
PendingRequests int `json:"pending_requests"` // PendingRequests is how many CRUD operations are queued for processing
PendingInfos int `json:"pending_infos"` // PendingInfos is how many info operations are queued for processing
Snapshot *MetaSnapshotStats `json:"snapshot"` // Snapshot contains meta snapshot statistics
}
// JSInfo has detailed information on JetStream.
@@ -3233,32 +3281,18 @@ func (s *Server) Jsz(opts *JSzOptions) (*JSInfo, error) {
if mg := js.getMetaGroup(); mg != nil {
if ci := s.raftNodeToClusterInfo(mg); ci != nil {
entries, bytes := mg.Size()
jsi.Meta = &MetaClusterInfo{Name: ci.Name, Leader: ci.Leader, Peer: getHash(ci.Leader), Size: mg.ClusterSize()}
if isLeader {
jsi.Meta.Replicas = ci.Replicas
}
if ipq := s.jsAPIRoutedReqs; ipq != nil {
jsi.Meta.Pending = ipq.len()
jsi.Meta.PendingRequests = ipq.len()
}
// Add meta snapshot stats
jsi.Meta.Snapshot = &MetaSnapshotStats{
PendingEntries: entries,
PendingSize: bytes,
}
js.mu.RLock()
cluster := js.cluster
js.mu.RUnlock()
if cluster != nil {
timeNanos := atomic.LoadInt64(&cluster.lastMetaSnapTime)
durationNanos := atomic.LoadInt64(&cluster.lastMetaSnapDuration)
if timeNanos > 0 {
jsi.Meta.Snapshot.LastTime = time.Unix(0, timeNanos).UTC()
}
if durationNanos > 0 {
jsi.Meta.Snapshot.LastDuration = time.Duration(durationNanos)
}
if ipq := s.jsAPIRoutedInfoReqs; ipq != nil {
jsi.Meta.PendingInfos = ipq.len()
}
jsi.Meta.Pending = jsi.Meta.PendingRequests + jsi.Meta.PendingInfos
jsi.Meta.Snapshot = s.metaClusterSnapshotStats(js, mg)
}
}
@@ -3695,6 +3729,20 @@ func (s *Server) healthz(opts *HealthzOptions) *HealthStatus {
})
continue
}
if streamWerr := s.getWriteErr(); streamWerr != nil {
if !details {
health.Status = na
health.Error = fmt.Sprintf("JetStream stream '%s > %s' write error: %v", acc, stream, streamWerr)
return health
}
health.Errors = append(health.Errors, HealthzError{
Type: HealthzErrorStream,
Account: acc.Name,
Stream: stream,
Error: fmt.Sprintf("JetStream stream '%s > %s' write error: %v", acc, stream, streamWerr),
})
continue
}
if streamFound {
// if consumer option is passed, verify that the consumer exists on stream
if opts.Consumer != _EMPTY_ {
@@ -3771,49 +3819,37 @@ func (s *Server) healthz(opts *HealthzOptions) *HealthStatus {
meta = cc.meta
js.mu.RUnlock()
// If no meta leader.
if meta == nil || meta.GroupLeader() == _EMPTY_ {
if !details {
health.Status = na
health.Error = "JetStream has not established contact with a meta leader"
} else {
health.Errors = []HealthzError{
{
Type: HealthzErrorJetStream,
Error: "JetStream has not established contact with a meta leader",
},
}
}
return health
// Check meta layer health.
var metaNoLeader, metaClosed, metaUnhealthy bool
var metaWerr error
if meta != nil {
metaNoLeader = meta.GroupLeader() == _EMPTY_
metaClosed = meta.State() == Closed
metaUnhealthy = !meta.Healthy()
metaWerr = meta.GetWriteErr()
}
// If we are not current with the meta leader.
if !meta.Healthy() {
if !details {
health.Status = na
health.Error = "JetStream is not current with the meta leader"
metaRecovering := js.isMetaRecovering()
if meta == nil || metaNoLeader || metaClosed || metaUnhealthy || metaWerr != nil || metaRecovering {
var desc string
if metaWerr != nil {
desc = fmt.Sprintf("JetStream meta layer write error: %v", metaWerr)
} else if metaClosed {
desc = "JetStream meta layer is not running"
} else if meta != nil && metaRecovering {
desc = "JetStream is still recovering meta layer"
} else if meta == nil || metaNoLeader {
desc = "JetStream has not established contact with a meta leader"
} else {
health.Errors = []HealthzError{
{
Type: HealthzErrorJetStream,
Error: "JetStream is not current with the meta leader",
},
}
desc = "JetStream is not current with the meta leader"
}
return health
}
// Are we still recovering meta layer?
if js.isMetaRecovering() {
if !details {
health.Status = na
health.Error = "JetStream is still recovering meta layer"
health.Error = desc
} else {
health.Errors = []HealthzError{
{
Type: HealthzErrorJetStream,
Error: "JetStream is still recovering meta layer",
Error: desc,
},
}
}
@@ -4090,6 +4126,8 @@ type RaftzGroup struct {
QuorumNeeded int `json:"quorum_needed"`
Observer bool `json:"observer,omitempty"`
Paused bool `json:"paused,omitempty"`
Overrun bool `json:"overrun,omitempty"`
OverrunCount uint64 `json:"overrun_count,omitempty"`
Committed uint64 `json:"committed"`
Applied uint64 `json:"applied"`
CatchingUp bool `json:"catching_up,omitempty"`
@@ -4198,6 +4236,8 @@ func (s *Server) Raftz(opts *RaftzOptions) *RaftzStatus {
QuorumNeeded: n.qn,
Observer: n.observer,
Paused: n.paused,
Overrun: n.quorumPaused || n.isLeaderOverrun(),
OverrunCount: n.overrunCount,
Committed: n.commit,
Applied: n.applied,
CatchingUp: n.catchup != nil,
+78 -85
View File
@@ -15,7 +15,6 @@ package server
import (
"bytes"
"cmp"
"crypto/tls"
"encoding/binary"
"encoding/json"
@@ -33,6 +32,8 @@ import (
"unicode/utf8"
"github.com/nats-io/jwt/v2"
"github.com/nats-io/nats-server/v2/server/gsl"
"github.com/nats-io/nats-server/v2/server/stree"
"github.com/nats-io/nuid"
)
@@ -258,14 +259,13 @@ type mqttSessionManager struct {
type mqttAccountSessionManager struct {
mu sync.RWMutex
sessions map[string]*mqttSession // key is MQTT client ID
sessByHash map[string]*mqttSession // key is MQTT client ID hash
sessLocked map[string]struct{} // key is MQTT client ID and indicate that a session can not be taken by a new client at this time
flappers map[string]time.Time // When connection connects with client ID already in use
flapTimer *time.Timer // Timer to perform some cleanup of the flappers map
sl *Sublist // sublist allowing to find retained messages for given subscription
retmsgs map[string]*mqttRetainedMsgRef // retained messages
rmsCache *sync.Map // map[subject]mqttRetainedMsg
sessions map[string]*mqttSession // key is MQTT client ID
sessByHash map[string]*mqttSession // key is MQTT client ID hash
sessLocked map[string]struct{} // key is MQTT client ID and indicate that a session can not be taken by a new client at this time
flappers map[string]time.Time // When connection connects with client ID already in use
flapTimer *time.Timer // Timer to perform some cleanup of the flappers map
retmsgs *stree.SubjectTree[mqttRetainedMsgRef] // retained message metadata
rmsCache *sync.Map // map[subject]mqttRetainedMsg
jsa mqttJSA
domainTk string // Domain (with trailing "."), or possibly empty. This is added to session subject.
}
@@ -366,7 +366,6 @@ type mqttRetainedMsg struct {
type mqttRetainedMsgRef struct {
sseq uint64
sub *subscription
}
// mqttSub contains fields associated with a MQTT subscription, and is added to
@@ -2022,10 +2021,14 @@ func (as *mqttAccountSessionManager) processRetainedMsg(_ *subscription, c *clie
if err != nil {
return
}
if strings.IndexByte(rm.Subject, 0x7f) >= 0 {
c.Warnf("Skipping retained message for subject %q: unsupported character 0x7f", rm.Subject)
return
}
// The as.jsa.id is immutable, so no need to have a rlock here.
local := rm.Origin == as.jsa.id
// Get the stream sequence for this message.
seq, _, _ := ackReplyInfo(reply)
seq, _, _, _, _ := ackReplyInfo(reply)
if len(m) == 0 {
// An empty payload means that we need to remove the retained message.
rmSeq := as.removeRetainedMsg(rm.Subject, 0)
@@ -2042,7 +2045,7 @@ func (as *mqttAccountSessionManager) processRetainedMsg(_ *subscription, c *clie
// Add this retained message. The `rm.Msg` references some buffer that we
// don't own. But addRetainedMsg() will take care of making a copy of
// `rm.Msg` if `rm` ends up being stored in the cache.
as.addRetainedMsg(rm.Subject, &mqttRetainedMsgRef{sseq: seq}, rm)
as.addRetainedMsg(rm.Subject, seq, rm)
}
}
@@ -2310,17 +2313,16 @@ func (as *mqttAccountSessionManager) sendJSAPIrequests(s *Server, c *client, acc
// If a message for this topic already existed, the existing record is updated
// with the provided information.
// Lock not held on entry.
func (as *mqttAccountSessionManager) addRetainedMsg(key string, rf *mqttRetainedMsgRef, rm *mqttRetainedMsg) {
func (as *mqttAccountSessionManager) addRetainedMsg(key string, sseq uint64, rm *mqttRetainedMsg) {
as.mu.Lock()
defer as.mu.Unlock()
if as.retmsgs == nil {
as.retmsgs = make(map[string]*mqttRetainedMsgRef)
as.sl = NewSublistWithCache()
as.retmsgs = stree.NewSubjectTree[mqttRetainedMsgRef]()
} else {
// Check if we already had one retained message. If so, update the existing one.
if erf, exists := as.retmsgs[key]; exists {
if erf, exists := as.retmsgs.Find(stringToBytes(key)); exists {
// Update the stream sequence with the new value.
erf.sseq = rf.sseq
erf.sseq = sseq
// Update the in-memory retained message cache but only for messages
// that are already in the cache, i.e. have been (recently) used.
// If that is the case, we ask setCachedRetainedMsg() to make a copy
@@ -2329,9 +2331,7 @@ func (as *mqttAccountSessionManager) addRetainedMsg(key string, rf *mqttRetained
return
}
}
rf.sub = &subscription{subject: []byte(key)}
as.retmsgs[key] = rf
as.sl.Insert(rf.sub)
as.retmsgs.Insert([]byte(key), mqttRetainedMsgRef{sseq: sseq})
}
// Remove the retained message stored with the `subject` key from the map/cache.
@@ -2348,15 +2348,13 @@ func (as *mqttAccountSessionManager) addRetainedMsg(key string, rf *mqttRetained
func (as *mqttAccountSessionManager) removeRetainedMsg(subject string, seq uint64) uint64 {
as.mu.Lock()
defer as.mu.Unlock()
rm, ok := as.retmsgs[subject]
rm, ok := as.retmsgs.Find(stringToBytes(subject))
if !ok || (seq > 0 && rm.sseq != seq) {
return 0
}
seq = rm.sseq
rm, _ = as.retmsgs.Delete(stringToBytes(subject))
as.rmsCache.Delete(subject)
delete(as.retmsgs, subject)
as.sl.Remove(rm.sub)
return seq
return rm.sseq
}
// First check if this session's client ID is already in the "locked" map,
@@ -2684,27 +2682,22 @@ func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, c *client,
// Account session manager lock held on entry.
// Session lock held on entry.
func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(rms map[string]*mqttRetainedMsg, sess *mqttSession, c *client, sub *subscription, trace bool) {
if len(as.retmsgs) == 0 || len(rms) == 0 {
return
}
result := as.sl.ReverseMatch(string(sub.subject))
if len(result.psubs) == 0 {
if as.retmsgs.Size() == 0 || len(rms) == 0 {
return
}
toTrace := []mqttPublish{}
for _, psub := range result.psubs {
rm := rms[string(psub.subject)]
as.retmsgs.Match(sub.subject, func(subj []byte, _ *mqttRetainedMsgRef) {
rm := rms[string(subj)]
if rm == nil {
// This should not happen since we pre-load messages into rms before
// calling serialize.
continue
return
}
var pi uint16
qos := min(mqttGetQoS(rm.Flags), sub.mqtt.qos)
if c.mqtt.rejectQoS2Pub && qos == 2 {
c.Warnf("Rejecting retained message with QoS2 for subscription %q, as configured", sub.subject)
continue
return
}
if qos > 0 {
pi = sess.trackPublishRetained()
@@ -2731,7 +2724,7 @@ func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(rms map[string]
sz: len(rm.Msg),
})
}
}
})
for _, pp := range toTrace {
c.traceOutOp("PUBLISH", []byte(mqttPubTrace(&pp)))
}
@@ -2743,27 +2736,21 @@ func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(rms map[string]
// Account session manager NOT lock held on entry.
func (as *mqttAccountSessionManager) addRetainedSubjectsForSubject(list map[string]uint64, topSubject string) {
as.mu.RLock()
if len(as.retmsgs) == 0 {
as.mu.RUnlock()
defer as.mu.RUnlock()
if as.retmsgs.Size() == 0 {
return
}
result := as.sl.ReverseMatch(topSubject)
as.mu.RUnlock()
for _, sub := range result.psubs {
if _, ok := list[string(sub.subject)]; ok {
continue
as.retmsgs.Match(stringToBytes(topSubject), func(subj []byte, ret *mqttRetainedMsgRef) {
subject := string(subj)
if _, ok := list[subject]; ok {
return
}
var seq uint64
as.mu.RLock()
if rm, ok := as.retmsgs[string(sub.subject)]; ok {
seq = rm.sseq
if seq := ret.sseq; seq > 0 {
list[subject] = seq
}
as.mu.RUnlock()
if seq > 0 {
list[string(sub.subject)] = seq
}
}
})
}
type warner interface {
@@ -3457,7 +3444,7 @@ func (sess *mqttSession) trackPublish(jsDur, jsAckSubject string) (uint16, bool)
}
// Get the stream sequence and duplicate flag from the ack reply subject.
sseq, _, dcount := ackReplyInfo(jsAckSubject)
sseq, _, dcount, _, _ := ackReplyInfo(jsAckSubject)
if dcount > 1 {
dup = true
}
@@ -3561,7 +3548,7 @@ func (sess *mqttSession) trackAsPubRel(pi uint16, jsAckSubject string) {
return
}
sseq, _, _ := ackReplyInfo(jsAckSubject)
sseq, _, _, _, _ := ackReplyInfo(jsAckSubject)
if sess.cpending == nil {
sess.cpending = make(map[string]map[uint64]uint16)
@@ -4568,13 +4555,8 @@ func (s *Server) mqttCheckPubRetainedPerms() {
}
sm.mu.RUnlock()
type retainedMsg struct {
subj string
rmsg *mqttRetainedMsgRef
}
// For each session we will obtain a list of retained messages.
var _rms [128]retainedMsg
var _rms [128]uint64
rms := _rms[:0]
for _, asm := range asms {
// Get all of the retained messages. Then we will sort them so
@@ -4582,19 +4564,20 @@ func (s *Server) mqttCheckPubRetainedPerms() {
// store to not have to load out-of-order blocks so often.
asm.mu.RLock()
rms = rms[:0] // reuse slice
for subj, rf := range asm.retmsgs {
rms = append(rms, retainedMsg{
subj: subj,
rmsg: rf,
})
}
// Copy the sequence out of the tree. The tree entry itself can be
// updated concurrently by addRetainedMsg() after we release the lock,
// so keeping a pointer here would race with the later sort.
asm.retmsgs.IterOrdered(func(_ []byte, rm *mqttRetainedMsgRef) bool {
rms = append(rms, rm.sseq)
return true
})
jsaID := asm.jsa.id
asm.mu.RUnlock()
slices.SortFunc(rms, func(i, j retainedMsg) int { return cmp.Compare(i.rmsg.sseq, j.rmsg.sseq) })
slices.Sort(rms)
perms := map[string]*perm{}
perms := map[string]*mqttPerm{}
for _, rf := range rms {
jsm, err := asm.jsa.loadMsg(mqttRetainedMsgsStreamName, rf.rmsg.sseq)
jsm, err := asm.jsa.loadMsg(mqttRetainedMsgsStreamName, rf)
if err != nil || jsm == nil {
continue
}
@@ -4617,7 +4600,7 @@ func (s *Server) mqttCheckPubRetainedPerms() {
}
// If there is permission and no longer allowed to publish in
// the subject, remove the publish retained message from the map.
if p != nil && !pubAllowed(p, rf.subj) {
if p != nil && !pubAllowed(p, rm.Subject) {
u = nil
}
}
@@ -4636,7 +4619,12 @@ func (s *Server) mqttCheckPubRetainedPerms() {
}
// Helper to generate only pub permissions from a Permissions object
func generatePubPerms(perms *Permissions) *perm {
// mqttPerm holds the publish allow/deny lists that pubAllowed() consults.
// Either list may be nil: pubAllowed() treats a nil allow list as "everything
// is allowed" and a nil deny list as "nothing is denied".
type mqttPerm struct {
// Subjects on which publishing is permitted.
allow *gsl.SimpleSublist
// Subjects on which publishing is refused, checked only when the allow
// list did not already reject the subject.
deny *gsl.SimpleSublist
}
func generatePubPerms(perms *Permissions) *mqttPerm {
// If given permissions is `nil`, then it means that permissions block
// has been removed (so the user is now allowed to publish on everything)
// or was never there in the first place. Returning `nil` will let the
@@ -4644,39 +4632,38 @@ func generatePubPerms(perms *Permissions) *perm {
if perms == nil {
return nil
}
var p *perm
var p *mqttPerm
if perms.Publish.Allow != nil {
p = &perm{}
p.allow = NewSublistWithCache()
p = &mqttPerm{}
p.allow = gsl.NewSimpleSublist()
for _, pubSubject := range perms.Publish.Allow {
sub := &subscription{subject: []byte(pubSubject)}
p.allow.Insert(sub)
_ = p.allow.Insert(pubSubject, struct{}{})
}
}
if len(perms.Publish.Deny) > 0 {
if p == nil {
p = &perm{}
p = &mqttPerm{}
}
p.deny = NewSublistWithCache()
p.deny = gsl.NewSimpleSublist()
for _, pubSubject := range perms.Publish.Deny {
sub := &subscription{subject: []byte(pubSubject)}
p.deny.Insert(sub)
_ = p.deny.Insert(pubSubject, struct{}{})
}
}
return p
}
// Helper that checks if given `perms` allow to publish on the given `subject`
func pubAllowed(perms *perm, subject string) bool {
func pubAllowed(perms *mqttPerm, subject string) bool {
if perms == nil {
return true
}
allowed := true
if perms.allow != nil {
np, _ := perms.allow.NumInterest(subject)
allowed = np != 0
allowed = perms.allow.HasInterest(subject)
}
// If we have a deny list and are currently allowed, check that as well.
if allowed && perms.deny != nil {
np, _ := perms.deny.NumInterest(subject)
allowed = np == 0
allowed = !perms.deny.HasInterest(subject)
}
return allowed
}
@@ -5729,6 +5716,12 @@ func mqttToNATSSubjectConversion(mt []byte, wcOk bool) ([]byte, error) {
case ' ':
// As of now, we cannot support ' ' in the MQTT topic/filter.
return nil, errMQTTUnsupportedCharacters
case 0x7f:
// SubjectTree uses DEL as an internal pivot marker, so retained
// subjects containing it cannot be indexed safely, including
// legacy retained messages recovered from the retained-message
// stream.
return nil, errMQTTUnsupportedCharacters
case btsep:
if !cp {
makeCopy(i)
+43 -90
View File
@@ -1,4 +1,4 @@
// Copyright 2024-2025 The NATS Authors
// Copyright 2024-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@@ -26,6 +26,7 @@ import (
const (
MsgTraceDest = "Nats-Trace-Dest"
MsgTraceDestDisabled = "trace disabled" // This must be an invalid NATS subject
MsgTraceHop = "Nats-Trace-Hop"
MsgTraceOriginAccount = "Nats-Trace-Origin-Account"
MsgTraceOnly = "Nats-Trace-Only"
@@ -33,10 +34,19 @@ const (
// External trace header. Note that this header is normally in lower
// case (https://www.w3.org/TR/trace-context/#header-name). Vendors
// MUST expect the header in any case (upper, lower, mixed), and
// SHOULD send the header name in lowercase.
// SHOULD send the header name in lowercase. We used to change it
// to lower case, but no longer do that in 2.14.
traceParentHdr = "traceparent"
)
// Byte-slice forms of the trace header names, of the "trace disabled" marker
// value and of the separators used while scanning message headers. They are
// built once at package level so the header scan does not redo the conversion.
// NOTE(review): stringToBytes presumably returns a view over the string's
// bytes without copying; if so, these slices must never be written to - confirm.
var (
traceDestHdrAsBytes = stringToBytes(MsgTraceDest)
traceDestDisabledAsBytes = stringToBytes(MsgTraceDestDisabled)
traceParentHdrAsBytes = stringToBytes(traceParentHdr)
crLFAsBytes = stringToBytes(CR_LF)
dashAsBytes = stringToBytes("-")
)
type MsgTraceType string
// Type of message trace events in the MsgTraceEvents list.
@@ -352,7 +362,6 @@ func (c *client) initMsgTrace() *msgTrace {
}
return vv[0]
}
ct := getCompressionType(getHdrVal(acceptEncodingHeader))
var (
dest string
traceOnly bool
@@ -454,9 +463,9 @@ func (c *client) initMsgTrace() *msgTrace {
}
// Check sampling, but only from origin server.
if c.kind == CLIENT && !sample(sampling) {
// Need to desactivate the traceParentHdr so that if the message
// is routed, it does possibly trigger a trace there.
disableTraceHeaders(c, hdr)
// Need to disable tracing so that if the message is routed, it won't
// trigger a trace there.
c.msgBuf = c.setHeader(MsgTraceDest, MsgTraceDestDisabled, c.msgBuf)
return nil
}
}
@@ -465,7 +474,7 @@ func (c *client) initMsgTrace() *msgTrace {
acc: acc,
oan: oan,
dest: dest,
ct: ct,
ct: getCompressionType(getHdrVal(acceptEncodingHeader)),
hop: hop,
event: &MsgTraceEvent{
Request: MsgTraceRequest{
@@ -503,9 +512,7 @@ func sample(sampling int) bool {
// the headers have been lifted due to the presence of the external trace header
// only.
// Note that because of the traceParentHdr, the search is done in a case
// insensitive way, but if the header is found, it is rewritten in lower case
// as suggested by the spec, but also to make it easier to disable the header
// when needed.
// insensitive way. We used to rewrite it in lower case but no longer do since v2.14.
func genHeaderMapIfTraceHeadersPresent(hdr []byte) (map[string][]string, bool) {
var (
@@ -520,11 +527,6 @@ func genHeaderMapIfTraceHeadersPresent(hdr []byte) (map[string][]string, bool) {
return nil, false
}
traceDestHdrAsBytes := stringToBytes(MsgTraceDest)
traceParentHdrAsBytes := stringToBytes(traceParentHdr)
crLFAsBytes := stringToBytes(CR_LF)
dashAsBytes := stringToBytes("-")
keys := _keys[:0]
vals := _vals[:0]
@@ -537,46 +539,50 @@ func genHeaderMapIfTraceHeadersPresent(hdr []byte) (map[string][]string, bool) {
keyStart := i
key := hdr[keyStart : keyStart+del]
i += del + 1
for i < len(hdr) && (hdr[i] == ' ' || hdr[i] == '\t') {
i++
}
valStart := i
nl := bytes.Index(hdr[valStart:], crLFAsBytes)
if nl < 0 {
break
}
if len(key) > 0 {
val := bytes.Trim(hdr[valStart:valStart+nl], " \t")
valEnd := valStart + nl
for valEnd > valStart && (hdr[valEnd-1] == ' ' || hdr[valEnd-1] == '\t') {
valEnd--
}
val := hdr[valStart:valEnd]
if len(key) > 0 && len(val) > 0 {
vals = append(vals, val)
// We search for our special keys only if not already found.
// Check for the external trace header.
if bytes.EqualFold(key, traceParentHdrAsBytes) {
// Rewrite the header using lower case if needed.
if !bytes.Equal(key, traceParentHdrAsBytes) {
copy(hdr[keyStart:], traceParentHdrAsBytes)
}
// Search needs to be case insensitive.
if !traceParentHdrFound && bytes.EqualFold(key, traceParentHdrAsBytes) {
// We will now check if the value has sampling or not.
// TODO(ik): Not sure if this header can have multiple values
// or not, and if so, what would be the rule to check for
// sampling. What is done here is to check them all until we
// found one with sampling.
if !traceParentHdrFound {
tk := bytes.Split(val, dashAsBytes)
if len(tk) == 4 && len([]byte(tk[3])) == 2 {
if hexVal, err := strconv.ParseInt(bytesToString(tk[3]), 16, 8); err == nil {
if hexVal&0x1 == 0x1 {
traceParentHdrFound = true
}
tk := bytes.Split(val, dashAsBytes)
if len(tk) == 4 && len([]byte(tk[3])) == 2 {
if hexVal, err := strconv.ParseInt(bytesToString(tk[3]), 16, 8); err == nil {
if hexVal&0x1 == 0x1 {
traceParentHdrFound = true
}
}
}
// Add to the keys with the external trace header in lower case.
keys = append(keys, traceParentHdrAsBytes)
} else {
// Is the key the Nats-Trace-Dest header?
if bytes.EqualFold(key, traceDestHdrAsBytes) {
traceDestHdrFound = true
} else if !traceDestHdrFound && bytes.Equal(key, traceDestHdrAsBytes) {
// This is the Nats-Trace-Dest header, check the value to see
// if it indicates that the trace was disabled.
if bytes.Equal(val, traceDestDisabledAsBytes) {
return nil, false
}
// Add to the keys and preserve the key's case
keys = append(keys, key)
traceDestHdrFound = true
}
// Add to the keys and preserve the key's case
keys = append(keys, key)
}
i += nl + 2
}
@@ -655,59 +661,6 @@ func (t *msgTrace) setHopHeader(c *client, msg []byte) []byte {
return c.setHeader(MsgTraceHop, t.nhop, msg)
}
// Will look for the MsgTraceSendTo and traceParentHdr headers and change the first
// character to an 'X' so that if this message is sent to a remote, the remote
// will not initialize tracing since it won't find the actual trace headers.
// The function returns the position of the headers so it can efficiently be
// re-enabled by calling enableTraceHeaders.
// Note that if `msg` can be either the header alone or the full message
// (header and payload). This function will use c.pa.hdr to limit the
// search to the header section alone.
func disableTraceHeaders(c *client, msg []byte) []int {
// Code largely copied from getHeader(), except that we don't need the value
if c.pa.hdr <= 0 {
return []int{-1, -1}
}
hdr := msg[:c.pa.hdr]
headers := [2]string{MsgTraceDest, traceParentHdr}
positions := [2]int{-1, -1}
for i := 0; i < 2; i++ {
key := stringToBytes(headers[i])
pos := bytes.Index(hdr, key)
if pos < 0 {
continue
}
// Make sure this key does not have additional prefix.
if pos < 2 || hdr[pos-1] != '\n' || hdr[pos-2] != '\r' {
continue
}
index := pos + len(key)
if index >= len(hdr) {
continue
}
if hdr[index] != ':' {
continue
}
// Disable the trace by altering the first character of the header
hdr[pos] = 'X'
positions[i] = pos
}
// Return the positions of those characters so we can re-enable the headers.
return positions[:2]
}
// Changes back the character at the given position `pos` in the `msg`
// byte slice to the first character of the MsgTraceSendTo header.
func enableTraceHeaders(msg []byte, positions []int) {
firstChar := [2]byte{MsgTraceDest[0], traceParentHdr[0]}
for i, pos := range positions {
if pos == -1 {
continue
}
msg[pos] = firstChar[i]
}
}
func (t *msgTrace) setIngressError(err string) {
if i := t.event.Ingress(); i != nil {
i.Error = err
+165 -7
View File
@@ -27,6 +27,7 @@ import (
"os"
"path"
"path/filepath"
"reflect"
"regexp"
"runtime"
"strconv"
@@ -108,6 +109,34 @@ type CompressionOpts struct {
RTTThresholds []time.Duration
}
// equals reports whether c1 and c2 describe the same compression settings.
//
// Identical pointers (which includes both being nil) are equal, while a nil
// value never equals a non-nil one. Otherwise the modes must match and, for
// the s2_auto mode only, so must the RTT thresholds.
func (c1 *CompressionOpts) equals(c2 *CompressionOpts) bool {
	switch {
	case c1 == c2:
		return true
	case c1 == nil || c2 == nil:
		// The pointers differ, so exactly one of them is nil.
		return false
	case c1.Mode != c2.Mode:
		return false
	case c1.Mode != CompressionS2Auto:
		// RTT thresholds are only compared for s2_auto.
		return true
	}
	// For s2_auto, an empty RTTThresholds is equivalent to the
	// defaultCompressionS2AutoRTTThresholds array, so substitute it
	// before comparing.
	effective := func(rtts []time.Duration) []time.Duration {
		if len(rtts) == 0 {
			return defaultCompressionS2AutoRTTThresholds
		}
		return rtts
	}
	return reflect.DeepEqual(effective(c1.RTTThresholds), effective(c2.RTTThresholds))
}
// GatewayOpts are options for gateways.
// NOTE: This structure is no longer used for monitoring endpoints
// and json tags are deprecated and may be removed in the future.
@@ -283,6 +312,48 @@ type RemoteLeafOpts struct {
// existing connection will be closed and not solicited again (until it is changed
// to `false` again.
Disabled bool `json:"-"`
// If this is set to true, this remote will ignore any server leafnode URLs
// returned by the hub, allowing the user to fully manage the servers this
// remote can connect to.
IgnoreDiscoveredServers bool `json:"-"`
}
// name returns a string representation of this `RemoteLeafOpts` object, containing
// the URLs (unredacted), the account (or "$G" if none is specified) and, if present,
// the credentials filename or, when no credentials are set, the nkey (unredacted).
// Because nothing is redacted, use safeName() for anything that ends up in a log.
func (r *RemoteLeafOpts) name() string {
return generateRemoteLeafOptsName(r, false)
}
// safeName is the same as RemoteLeafOpts.name() but uses redacted URLs and
// replaces any nkey with "[REDACTED]". This is to be used for logging.
func (r *RemoteLeafOpts) safeName() string {
return generateRemoteLeafOptsName(r, true)
}
// generateRemoteLeafOptsName builds the identifying string for a remote leaf
// configuration in the form `urls=..., account=...` followed by an optional
// `, credentials=...` or `, nkey=...` suffix.
//
// When `redacted` is true the URLs are passed through redactURLList() and an
// nkey is replaced with "[REDACTED]"; a credentials filename is printed as is
// in both modes.
func generateRemoteLeafOptsName(r *RemoteLeafOpts, redacted bool) string {
	// An unset local account is reported as the global account.
	account := r.LocalAccount
	if account == _EMPTY_ {
		account = globalAccountName
	}

	// There could be Credentials or NKey, not both (would be caught as a
	// misconfig), so credentials are looked at first.
	var auth string
	switch {
	case r.Credentials != _EMPTY_:
		auth = fmt.Sprintf(", credentials=%q", r.Credentials)
	case r.Nkey == _EMPTY_:
		// Neither is configured, nothing to append.
	case redacted:
		auth = ", nkey=\"[REDACTED]\""
	default:
		auth = fmt.Sprintf(", nkey=%q", r.Nkey)
	}

	var urls []*url.URL
	if !redacted {
		urls = r.URLs
	} else {
		urls = redactURLList(r.URLs)
	}
	return fmt.Sprintf("urls=%q, account=%q%s", urls, account, auth)
}
// JSLimitOpts are active limits for the meta cluster
@@ -387,6 +458,7 @@ type Options struct {
JetStreamTpm JSTpmOpts
JetStreamMaxCatchup int64
JetStreamRequestQueueLimit int64
JetStreamInfoQueueLimit int64
JetStreamMetaCompact uint64
JetStreamMetaCompactSize uint64
JetStreamMetaCompactSync bool
@@ -478,6 +550,9 @@ type Options struct {
// Metadata describing the server. They will be included in 'Z' responses.
Metadata map[string]string `json:"-"`
// FeatureFlags the server opts-in to (or opts-out of). They will be included in 'Z' responses.
FeatureFlags map[string]bool `json:"-"`
// OCSPConfig enables OCSP Stapling in the server.
OCSPConfig *OCSPConfig
tlsConfigOpts *TLSConfigOpts
@@ -1748,6 +1823,29 @@ func (o *Options) processConfigFileLine(k string, v any, errors *[]error, warnin
*errors = append(*errors, err)
return
}
case "feature_flags":
var err error
switch v := v.(type) {
case map[string]any:
for mk, mv := range v {
tk, mv = unwrapValue(mv, &lt)
b, ok := mv.(bool)
if !ok {
err = &configErr{tk, fmt.Sprintf("error parsing feature flag %q: expected bool, got %T", mk, mv)}
break
}
if o.FeatureFlags == nil {
o.FeatureFlags = make(map[string]bool)
}
o.FeatureFlags[mk] = b
}
default:
err = &configErr{tk, fmt.Sprintf("error parsing feature flags: unsupported type %T", v)}
}
if err != nil {
*errors = append(*errors, err)
return
}
case "default_js_domain":
vv, ok := v.(map[string]any)
if !ok {
@@ -2641,6 +2739,12 @@ func parseJetStream(v any, opts *Options, errors *[]error, warnings *[]error) er
return &configErr{tk, fmt.Sprintf("Expected a parseable size for %q, got %v", mk, mv)}
}
opts.JetStreamRequestQueueLimit = lim
case "info_queue_limit":
lim, ok := mv.(int64)
if !ok {
return &configErr{tk, fmt.Sprintf("Expected a parseable size for %q, got %v", mk, mv)}
}
opts.JetStreamInfoQueueLimit = lim
case "meta_compact":
thres, ok := mv.(int64)
if !ok || thres < 0 {
@@ -2936,6 +3040,7 @@ func parseRemoteLeafNodes(v any, errors *[]error, warnings *[]error) ([]*RemoteL
if !ok {
return nil, &configErr{tk, fmt.Sprintf("Expected remotes field to be an array, got %T", v)}
}
names := make(map[string]struct{})
remotes := make([]*RemoteLeafOpts, 0, len(ra))
for _, r := range ra {
tk, r = unwrapValue(r, &lt)
@@ -3105,6 +3210,8 @@ func parseRemoteLeafNodes(v any, errors *[]error, warnings *[]error) ([]*RemoteL
}
}
}
case "ignore_discovered_servers":
remote.IgnoreDiscoveredServers = v.(bool)
default:
if !tk.IsUsedVariable() {
err := &unknownConfigFieldErr{
@@ -3128,6 +3235,12 @@ func parseRemoteLeafNodes(v any, errors *[]error, warnings *[]error) ([]*RemoteL
*warnings = append(*warnings, &configErr{proxyToken, warn})
}
}
rn := remote.name()
if _, dup := names[rn]; dup {
*errors = append(*errors, &configErr{tk, fmt.Sprintf("duplicate remote %s", remote.safeName())})
continue
}
names[rn] = struct{}{}
remotes = append(remotes, remote)
}
return remotes, nil
@@ -6006,6 +6119,9 @@ func setBaselineOptions(opts *Options) {
if opts.JetStreamRequestQueueLimit <= 0 {
opts.JetStreamRequestQueueLimit = JSDefaultRequestQueueLimit
}
if opts.JetStreamInfoQueueLimit <= 0 {
opts.JetStreamInfoQueueLimit = opts.JetStreamRequestQueueLimit
}
}
func getDefaultAuthTimeout(tls *tls.Config, tlsTimeout float64) float64 {
@@ -6439,14 +6555,56 @@ func expandPath(p string) (string, error) {
// RedactArgs redacts sensitive arguments from the command line.
// For example, turns '--pass=secret' into '--pass=[REDACTED]'.
func RedactArgs(args []string) {
secret := regexp.MustCompile("^-{1,2}(user|pass|auth)(=.*)?$")
secretArg := regexp.MustCompile("^-{1,2}(user|pass|auth)(=.*)?$")
routeURLArg := regexp.MustCompile("^-{1,2}(routes)(=.*)?$")
singleURLArg := regexp.MustCompile("^-{1,2}(cluster|cluster_listen)(=.*)?$")
for i, arg := range args {
if secret.MatchString(arg) {
if idx := strings.Index(arg, "="); idx != -1 {
args[i] = arg[:idx] + "=[REDACTED]"
} else if i+1 < len(args) {
args[i+1] = "[REDACTED]"
}
switch {
case secretArg.MatchString(arg):
redactArgValue(args, i, func(_ string) string { return "[REDACTED]" })
case routeURLArg.MatchString(arg):
redactArgValue(args, i, redactURLListUser)
case singleURLArg.MatchString(arg):
redactArgValue(args, i, redactURLUser)
}
}
}
// redactArgValue rewrites, in place, the value that belongs to the flag at
// args[i] using the supplied `redact` function.
//
// For the `--flag=value` form, everything after the first '=' is replaced
// within args[i]. For the `--flag value` form, the following argument is
// replaced, if there is one. A trailing flag without a value is left alone.
func redactArgValue(args []string, i int, redact func(string) string) {
	current := args[i]
	if eq := strings.IndexByte(current, '='); eq >= 0 {
		args[i] = current[:eq] + "=" + redact(current[eq+1:])
		return
	}
	if next := i + 1; next < len(args) {
		args[next] = redact(args[next])
	}
}
// redactURLUser returns `raw` with the userinfo section of the URL replaced
// by "[REDACTED]".
//
// The input is returned unchanged when it contains no '@', when it cannot be
// parsed as a URL, or when the parsed URL carries no userinfo. A trailing
// ":-1" port is swapped for ":0" while parsing and restored afterwards. When
// a redaction happens, the result is the re-serialized form of the
// space-trimmed input.
func redactURLUser(raw string) string {
	// Without an '@' there is no userinfo to hide.
	if strings.IndexByte(raw, '@') < 0 {
		return raw
	}

	candidate := strings.TrimSpace(raw)
	hadRandomPort := strings.HasSuffix(candidate, ":-1")
	if hadRandomPort {
		candidate = strings.TrimSuffix(candidate, ":-1") + ":0"
	}

	parsed, err := url.Parse(candidate)
	if err != nil {
		return raw
	}
	if parsed.User == nil {
		return raw
	}

	// url.String escapes brackets in userinfo, so use
	// a placeholder here and rewrite it afterward.
	parsed.User = url.User("_REDACTED_")
	if hadRandomPort {
		parsed.Host = strings.TrimSuffix(parsed.Host, ":0") + ":-1"
	}
	serialized := parsed.String()
	return strings.Replace(serialized, "_REDACTED_@", "[REDACTED]@", 1)
}
// redactURLListUser applies redactURLUser to every element of a
// comma-separated list of URLs and returns the re-joined list.
func redactURLListUser(raw string) string {
	entries := strings.Split(raw, ",")
	for idx := range entries {
		entries[idx] = redactURLUser(entries[idx])
	}
	return strings.Join(entries, ",")
}
+37 -22
View File
@@ -166,32 +166,40 @@ func (c *client) parse(buf []byte) error {
goto authErr
}
var ok bool
// Check here for NoAuthUser. If this is set allow non CONNECT protos as our first.
// E.g. telnet proto demos.
if noAuthUser := s.getOpts().NoAuthUser; noAuthUser != _EMPTY_ {
s.mu.Lock()
user, exists := s.users[noAuthUser]
s.mu.Unlock()
if exists {
c.RegisterUser(user)
c.mu.Lock()
c.clearAuthTimer()
c.flags.set(connectReceived)
c.mu.Unlock()
authSet, ok = false, true
switch c.kind {
case CLIENT:
// Check here for NoAuthUser. If this is set allow non CONNECT protos as our first.
// E.g. telnet proto demos.
opts := s.getOpts()
noAuthUser := opts.NoAuthUser
if c.ws != nil {
if noAuthUserWS := opts.Websocket.NoAuthUser; noAuthUserWS != _EMPTY_ {
noAuthUser = noAuthUserWS
}
}
if noAuthUser != _EMPTY_ {
s.mu.Lock()
user, exists := s.users[noAuthUser]
s.mu.Unlock()
if exists {
c.RegisterUser(user)
c.mu.Lock()
c.clearAuthTimer()
c.flags.set(connectReceived)
c.mu.Unlock()
authSet, ok = false, true
}
}
case LEAF:
// Compressed inbound leaf-node negotiation may require INFO
// before CONNECT. Without compression, leaf connections must
// still start with CONNECT.
ok = (b == 'I' || b == 'i') && needsCompression(s.getOpts().LeafNode.Compression.Mode)
}
if !ok {
goto authErr
}
}
// If the connection is a gateway connection, make sure that
// if this is an inbound, it starts with a CONNECT.
if c.kind == GATEWAY && !c.gw.outbound && !c.gw.connected {
// Use auth violation since no CONNECT was sent.
// It could be a parseErr too.
goto authErr
}
}
switch b {
case 'P', 'p':
@@ -1250,10 +1258,17 @@ func protoSnippet(start, max int, buf []byte) string {
// If so, an error is sent to the client and the connection is closed.
// The error ErrMaxControlLine is returned.
func (c *client) overMaxControlLineLimit(arg []byte, mcl int32) error {
// Widen to int64 so mcl*16 cannot overflow for large configured values.
effective := int64(mcl)
if c.kind != CLIENT {
return nil
// This is the upper bound on argBuf length for LEAF, ROUTER, and GATEWAY connections.
// These kinds need longer arg lines than CLIENT (which is capped at mcl=4096 by default)
// because cluster/leaf frames encode origin, account, reply, and queue groups.
// By default, this is 64 KB, which matches maxBufSize so a single oversized read
// is caught on the very next parse call.
effective *= 16
}
if len(arg) > int(mcl) {
if int64(len(arg)) > effective {
err := NewErrorCtx(ErrMaxControlLine, "State %d, max_control_line %d, Buffer len %d (snip: %s...)",
c.state, int(mcl), len(c.argBuf), protoSnippet(0, MAX_CONTROL_LINE_SNIPPET_SIZE, arg))
c.sendErr(err.Error())
+198 -55
View File
@@ -89,6 +89,7 @@ type RaftNode interface {
RecreateInternalSubs() error
IsSystemAccount() bool
GetTrafficAccountName() string
GetWriteErr() error
}
// RaftNodeCheckpoint is used as an alternative to a direct InstallSnapshot.
@@ -248,6 +249,9 @@ type raft struct {
scaleUp bool // The node is part of a scale up, puts us in observer mode until the log contains data.
deleted bool // If the node was deleted.
snapshotting bool // Snapshot is in progress.
quorumPaused bool // Pause replication and quorum participation to prevent log growth during slow applies.
overrunCount uint64 // Counter of how many times we were overrun, either as follower or as leader.
}
type proposedEntry struct {
@@ -487,9 +491,12 @@ func (s *Server) initRaftNode(accName string, cfg *RaftConfig, labels pprofLabel
n.papplied = 0
if _, ok := n.wal.(*memStore); ok {
_ = os.RemoveAll(filepath.Join(n.sd, snapshotsDir))
} else {
// See if we have any snapshots and if so load and process on startup.
n.setupLastSnapshot()
} else if err := n.setupLastSnapshot(); err != nil && err != errNoSnapAvailable {
// If we failed to recover from the snapshot, then we should surface
// the error upwards, otherwise we can complete recovery but have only
// a partial view of the world.
n.shutdown()
return nil, err
}
// We may have restored the peer state from the
@@ -503,6 +510,7 @@ func (s *Server) initRaftNode(accName string, cfg *RaftConfig, labels pprofLabel
// Make sure that the snapshots directory exists.
if err := os.MkdirAll(filepath.Join(n.sd, snapshotsDir), defaultDirPerms); err != nil {
n.shutdown()
return nil, fmt.Errorf("could not create snapshots directory - %v", err)
}
@@ -521,12 +529,6 @@ func (s *Server) initRaftNode(accName string, cfg *RaftConfig, labels pprofLabel
if state.Msgs > 0 {
n.debug("Replaying state of %d entries", state.Msgs)
if first, err := n.loadFirstEntry(); err == nil {
n.pterm, n.pindex = first.pterm, first.pindex
if first.commit > 0 && first.commit > n.commit {
n.commit = first.commit
}
}
// This process will queue up entries on our applied queue but prior to the upper
// state machine running. So we will monitor how much we have queued and if we
@@ -537,6 +539,25 @@ func (s *Server) initRaftNode(accName string, cfg *RaftConfig, labels pprofLabel
// yet. Replay them.
for index, qsz := state.FirstSeq, 0; index <= state.LastSeq; index++ {
ae, err := n.loadEntry(index)
// The first entry in our WAL initializes state but must align with our snapshot if we had one.
// Importantly, check this first, as we might need to truncate the WAL further than the index.
if index == state.FirstSeq {
// If the entry is missing, corrupt, or doesn't align with the snapshot, truncate the WAL.
if err != nil || ae == nil || ae.pindex != index-1 || n.pindex != ae.pindex {
if err != nil {
n.warn("Could not load %d from WAL [%+v]: %v", index, state, err)
} else {
n.warn("Misaligned WAL, will truncate")
}
// Truncate to the snapshot or beginning if there is none.
truncateAndErr(n.pindex)
break
}
n.pterm, n.pindex = ae.pterm, ae.pindex
if ae.commit > 0 && ae.commit > n.commit {
n.commit = ae.commit
}
}
if err != nil {
n.warn("Could not load %d from WAL [%+v]: %v", index, state, err)
// Truncate to the previous correct entry.
@@ -902,6 +923,16 @@ func (n *raft) Propose(data []byte) error {
if werr := n.werr; werr != nil {
return werr
}
if n.isLeaderOverrun() {
var state StreamState
n.wal.FastState(&state)
n.warn("Leader falling behind, stepping down: pindex %d, commit %d, applied %d, WAL size %s", n.pindex, n.commit, n.applied, friendlyBytes(state.Bytes))
// Stepdown without leader transfer, likely all replicas will be overrun, and we need time to recover.
n.stepdownLocked(noLeader)
n.overrunCount++
return errNotLeader
}
n.prop.push(newProposedEntry(newEntry(EntryNormal, data), _EMPTY_))
return nil
}
@@ -921,12 +952,39 @@ func (n *raft) ProposeMulti(entries []*Entry) error {
if werr := n.werr; werr != nil {
return werr
}
if n.isLeaderOverrun() {
var state StreamState
n.wal.FastState(&state)
n.warn("Leader falling behind, stepping down: pindex %d, commit %d, applied %d, WAL size %s", n.pindex, n.commit, n.applied, friendlyBytes(state.Bytes))
// Stepdown without leader transfer, likely all replicas will be overrun, and we need time to recover.
n.stepdownLocked(noLeader)
n.overrunCount++
return errNotLeader
}
for _, e := range entries {
n.prop.push(newProposedEntry(e, _EMPTY_))
}
return nil
}
// isLeaderOverrun reports whether the backlog of uncommitted or unapplied
// entries has grown past pauseQuorumThreshold, in which case the callers step
// down. Being this far behind means we are either severely overrun by incoming
// proposals or the system is too degraded to process them, and stepping down
// lets the system "breathe" until a new leader can be elected.
// Lock should be held.
func (n *raft) isLeaderOverrun() bool {
	// Both indices are floored at papplied before measuring the gaps.
	applied := max(n.applied, n.papplied)
	commit := max(n.commit, n.papplied)

	// We only do this past a high threshold to protect ourselves. Worst-case
	// we'll have 2x the threshold, once in uncommitted and once in unapplied
	// entries. Each subtraction is guarded by an ordering check first.

	// Too many uncommitted entries: we're not getting quorum from our followers.
	if n.pindex > commit && n.pindex-commit > pauseQuorumThreshold {
		return true
	}
	// Too many in-memory committed but not yet applied entries: we're slow to apply.
	return commit > applied && commit-applied > pauseQuorumThreshold
}
// ForwardProposal will forward the proposal to the leader if known.
// If we are the leader this is the same as calling propose.
func (n *raft) ForwardProposal(entry []byte) error {
@@ -1075,8 +1133,8 @@ func (n *raft) PauseApply() error {
}
func (n *raft) pauseApplyLocked() {
// If we are currently a candidate make sure we step down.
if n.State() == Candidate {
// If we are currently not a follower, make sure we step down.
if n.State() != Follower {
n.stepdownLocked(noLeader)
}
@@ -1558,11 +1616,14 @@ func termAndIndexFromSnapFile(sn string) (term, index uint64, err error) {
// setupLastSnapshot is called at startup to try and recover the last snapshot from
// the disk if possible. We will try to recover the term, index and commit/applied
// indices and then notify the upper layer what we found. Compacts the WAL if needed.
func (n *raft) setupLastSnapshot() {
func (n *raft) setupLastSnapshot() error {
snapDir := filepath.Join(n.sd, snapshotsDir)
psnaps, err := os.ReadDir(snapDir)
if err != nil {
return
if os.IsNotExist(err) {
return errNoSnapAvailable
}
return err
}
var lterm, lindex uint64
@@ -1586,18 +1647,8 @@ func (n *raft) setupLastSnapshot() {
os.Remove(sfile)
}
}
// Now cleanup any old entries
for _, sf := range psnaps {
sfile := filepath.Join(snapDir, sf.Name())
if sfile != latest {
n.debug("Removing old snapshot: %q", sfile)
os.Remove(sfile)
}
}
if latest == _EMPTY_ {
return
return nil
}
// Set latest snapshot we have.
@@ -1607,13 +1658,7 @@ func (n *raft) setupLastSnapshot() {
n.snapfile = latest
snap, err := n.loadLastSnapshot()
if err != nil {
// We failed to recover the last snapshot for some reason, so we will
// assume it has been corrupted and will try to delete it.
if n.snapfile != _EMPTY_ {
os.Remove(n.snapfile)
n.snapfile = _EMPTY_
}
return
return err
}
// We successfully recovered the last snapshot from the disk.
@@ -1627,8 +1672,8 @@ func (n *raft) setupLastSnapshot() {
n.papplied = snap.lastIndex
// Restore the peerState
ps, err := decodePeerState(snap.peerstate)
if err == nil {
n.processPeerState(ps)
if err != nil {
return err
}
n.processPeerState(ps)
n.extSt = ps.domainExt
@@ -1636,7 +1681,19 @@ func (n *raft) setupLastSnapshot() {
n.apply.push(newCommittedEntry(n.commit, []*Entry{{EntrySnapshot, snap.data}}))
if _, err := n.wal.Compact(snap.lastIndex + 1); err != nil {
n.setWriteErrLocked(err)
return err
}
// Now cleanup any old entries. We only do this once we know that the
// latest snapshot was OK.
for _, sf := range psnaps {
if sfile := filepath.Join(snapDir, sf.Name()); sfile != latest {
n.debug("Removing old snapshot: %q", sfile)
os.Remove(sfile)
}
}
return nil
}
// loadLastSnapshot will load and return our last snapshot.
@@ -1652,14 +1709,10 @@ func (n *raft) loadLastSnapshot() (*snapshot, error) {
if err != nil {
n.warn("Error reading snapshot: %v", err)
os.Remove(n.snapfile)
n.snapfile = _EMPTY_
return nil, err
}
if len(buf) < minSnapshotLen {
n.warn("Snapshot corrupt, too short")
os.Remove(n.snapfile)
n.snapfile = _EMPTY_
return nil, errSnapshotCorrupt
}
@@ -1671,8 +1724,6 @@ func (n *raft) loadLastSnapshot() (*snapshot, error) {
var hb [highwayhash.Size64]byte
if !bytes.Equal(lchk[:], n.hh.Sum(hb[:0])) {
n.warn("Snapshot corrupt, checksums did not match")
os.Remove(n.snapfile)
n.snapfile = _EMPTY_
return nil, errSnapshotCorrupt
}
@@ -1686,12 +1737,12 @@ func (n *raft) loadLastSnapshot() (*snapshot, error) {
}
// We had a bug in 2.9.12 that would allow snapshots on last index of 0.
// Detect that here and return err.
// Detect that and continue anyway, nothing else we can do about it.
if snap.lastIndex == 0 {
n.warn("Snapshot with last index 0 is invalid, cleaning up")
os.Remove(n.snapfile)
n.snapfile = _EMPTY_
return nil, errSnapshotCorrupt
return nil, errNoSnapAvailable
}
return snap, nil
@@ -2733,9 +2784,9 @@ func decodeAppendEntry(msg []byte, sub *subscription, reply string) (*appendEntr
ae.reply, ae.sub = reply, sub
// Decode Entries.
ne, ri := int(le.Uint16(msg[40:])), uint64(42)
ne, ri := int(le.Uint16(msg[40:])), uint64(appendEntryBaseLen)
for i, max := 0, uint64(len(msg)); i < ne; i++ {
if ri >= max-1 {
if max-ri < 4 {
return nil, errBadAppendEntry
}
ml := uint64(le.Uint32(msg[ri:]))
@@ -2867,20 +2918,44 @@ func (n *raft) handleForwardedProposal(sub *subscription, c *client, _ *Account,
msg = copyBytes(msg)
n.RLock()
prop := n.prop
// Check state under lock, we might not be leader anymore.
if n.State() != Leader || !n.leaderState.Load() {
n.debug("Ignoring forwarded proposal, not leader")
n.RUnlock()
return
}
prop, werr := n.prop, n.werr
n.RUnlock()
// Ignore if we have had a write error previous.
if werr != nil {
if n.werr != nil {
n.RUnlock()
return
}
if n.isLeaderOverrun() {
n.RUnlock()
n.Lock()
defer n.Unlock()
// Now that we've reacquired as write lock, we need to make sure that everything we
// believed before is still true. Otherwise we've either stepped down already from
// another goroutine or we've stopped being overrun and shouldn't drop the entry.
if n.State() != Leader || !n.leaderState.Load() {
return
} else if !n.isLeaderOverrun() {
prop.push(newProposedEntry(newEntry(EntryNormal, msg), reply))
return
}
var state StreamState
n.wal.FastState(&state)
n.warn("Leader falling behind, stepping down: pindex %d, commit %d, applied %d, WAL size %s", n.pindex, n.commit, n.applied, friendlyBytes(state.Bytes))
// Stepdown without leader transfer, likely all replicas will be overrun, and we need time to recover.
n.stepdownLocked(noLeader)
n.overrunCount++
return
}
// Possible that we could fall through to here from multiple connections but if
// one does end up stepping down then the proposal queue gets drained anyway.
n.RUnlock()
prop.push(newProposedEntry(newEntry(EntryNormal, msg), reply))
}
@@ -3252,8 +3327,6 @@ func (n *raft) sendSnapshotToFollower(subject string) (uint64, error) {
if err != nil {
// We need to stepdown here when this happens.
n.stepdownLocked(noLeader)
// We need to reset our state here as well.
n.resetWAL()
return 0, err
}
// Go ahead and send the snapshot and peerstate here as first append entry to the catchup follower.
@@ -4041,6 +4114,42 @@ func (n *raft) processAppendEntry(ae *appendEntry, sub *subscription) {
}
}
// If commits are outpacing our applies, temporarily stop accepting new entries to avoid falling further behind.
// This encourages the leader to sync us via a snapshot instead. We use max(applied, papplied) to avoid
// incorrectly triggering this pause immediately after receiving a snapshot.
applied := max(n.applied, n.papplied)
commit := max(n.commit, n.papplied)
if sub != nil && (commit > applied || n.quorumPaused) {
diff := commit - applied
if n.quorumPaused {
if diff > paeWarnThreshold {
if catchingUp {
n.cancelCatchup()
}
n.Unlock()
return
}
// Once we're sufficiently below the threshold, we continue again. We'll likely receive a snapshot
// from the leader.
n.quorumPaused = false
var state StreamState
n.wal.FastState(&state)
n.warn("Quorum resumed: commit %d, applied %d, WAL size %s", commit, applied, friendlyBytes(state.Bytes))
} else if diff > pauseQuorumThreshold {
// It takes a while until we reach the pause threshold, but once we do we enter a "cooldown period".
n.quorumPaused = true
n.overrunCount++
var state StreamState
n.wal.FastState(&state)
n.warn("Quorum paused, falling behind: commit %d != applied %d, WAL size %s", commit, applied, friendlyBytes(state.Bytes))
if catchingUp {
n.cancelCatchup()
}
n.Unlock()
return
}
}
if ae.pterm != n.pterm || ae.pindex != n.pindex {
// Check if this is a lower or equal index than what we were expecting.
if ae.pindex <= n.pindex {
@@ -4086,9 +4195,26 @@ func (n *raft) processAppendEntry(ae *appendEntry, sub *subscription) {
}
} else {
// If terms mismatched, delete that entry and all others past it.
// Make sure to cancel any catchups in progress.
// Truncate will reset our pterm and pindex. Only do so if we have an entry.
n.truncateWAL(eae.pterm, eae.pindex)
// But only if we haven't already committed past this point.
if eae.pindex < n.commit {
success = true
assert.Unreachable("Truncate to earlier entry would lose commits", map[string]any{
"n.accName": n.accName,
"n.group": n.group,
"n.id": n.id,
"n.term": n.term,
"n.pindex": n.pindex,
"n.commit": n.commit,
"n.applied": n.applied,
"ae.pindex": ae.pindex,
"ae.pterm": ae.pterm,
"ae.commit": ae.commit,
"eae.pterm": eae.pterm,
"eae.pindex": eae.pindex,
})
} else {
n.truncateWAL(eae.pterm, eae.pindex)
}
}
// Cancel regardless if unsuccessful.
if !success {
@@ -4420,9 +4546,10 @@ func (n *raft) storeToWAL(ae *appendEntry) error {
}
const (
paeDropThreshold = 20_000
paeWarnThreshold = 10_000
paeWarnModulo = 5_000
pauseQuorumThreshold = 100_000
paeDropThreshold = 20_000
paeWarnThreshold = 10_000
paeWarnModulo = 5_000
)
func (n *raft) sendAppendEntry(entries []*Entry) {
@@ -4699,11 +4826,18 @@ func (n *raft) setWriteErrLocked(err error) {
}
// If this is a not found report but do not disable.
if os.IsNotExist(err) {
n.error("Resource not found: %v", err)
n.warn("Resource not found: %v", err)
return
}
n.error("Critical write error: %v", err)
n.werr = err
n.shutdown()
assert.Unreachable("Raft encountered write error", map[string]any{
"n.accName": n.accName,
"n.group": n.group,
"n.id": n.id,
"err": err,
})
if isPermissionError(err) {
go n.s.handleWritePermissionError()
@@ -4720,6 +4854,13 @@ func (n *raft) isClosed() bool {
return n.State() == Closed
}
// GetWriteErr reports the write error currently recorded on this raft node,
// or nil when none has been recorded. The value is read under the node's
// read lock.
func (n *raft) GetWriteErr() error {
	n.RLock()
	werr := n.werr
	n.RUnlock()
	return werr
}
// Capture our write error if any and hold.
func (n *raft) setWriteErr(err error) {
n.Lock()
@@ -5033,6 +5174,8 @@ func (n *raft) switchToCandidate() {
// Increment the term.
n.term++
n.vote = noVote
// Reset quorum paused. If it was previously set, we checked above that we've applied all committed entries.
n.quorumPaused = false
// Clear current Leader.
n.updateLeader(noLeader)
n.switchState(Candidate)
+480 -272
View File
@@ -1,4 +1,4 @@
// Copyright 2017-2025 The NATS Authors
// Copyright 2017-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@@ -740,6 +740,35 @@ func (jso jetStreamOption) IsStatszChange() bool {
return true
}
// jetStreamLimitsOption is the reload option carrying new JetStream limits.
// A field left at zero means that the corresponding limit is not updated when
// the option is applied.
type jetStreamLimitsOption struct {
noopOption
// New maximum memory storage, applied only when greater than zero.
newMaxMemory int64
// New maximum file storage, applied only when greater than zero.
newMaxStore int64
}
// Apply installs the new JetStream limits on the running server. It is a no-op
// when JetStream is not available. Each limit is only touched when its new
// value is greater than zero; the config value and the matching atomic limit
// are updated together under the JetStream lock, and a notice is logged for
// every limit that was changed.
func (jso *jetStreamLimitsOption) Apply(s *Server) {
	js := s.getJetStream()
	if js == nil {
		return
	}
	js.mu.Lock()
	defer js.mu.Unlock()

	if mem := jso.newMaxMemory; mem > 0 {
		js.config.MaxMemory = mem
		atomic.StoreInt64(&js.memMax, js.config.MaxMemory)
		s.Noticef("Reloaded: JetStream max_mem_store = %s", friendlyBytes(jso.newMaxMemory))
	}
	if store := jso.newMaxStore; store > 0 {
		js.config.MaxStore = store
		atomic.StoreInt64(&js.storeMax, js.config.MaxStore)
		s.Noticef("Reloaded: JetStream max_file_store = %s", friendlyBytes(jso.newMaxStore))
	}
}
// IsStatszChange always returns true for a JetStream limits change.
// NOTE(review): presumably this makes the reload send a statsz update once the
// options have been applied - confirm against applyOptions.
func (jso *jetStreamLimitsOption) IsStatszChange() bool {
return true
}
type defaultSentinelOption struct {
noopOption
newValue string
@@ -872,9 +901,207 @@ func (o *profBlockRateReload) Apply(s *Server) {
// leafNodeOption is the reload option describing the LeafNode configuration
// changes to apply. It is built by getLeafNodeOptionsChanges.
type leafNodeOption struct {
noopOption
// Set when TLSHandshakeFirst or TLSHandshakeFirstFallback of the main
// leafnodes{} block differ between the old and new options.
tlsFirstChanged bool
// Set when the Compression options of the main leafnodes{} block differ.
compressionChanged bool
// These are for the remotes
// Remotes from the new options that did not match any existing remote
// configuration (matched by name).
added []*RemoteLeafOpts
// Existing remote configurations that are still present in the new options,
// mapped to the per-remote changes that were detected.
changed map[*leafNodeCfg]*remoteLeafOption
}
// remoteLeafOption records what changed for a single existing remote leafnode
// configuration, along with the new options for that remote.
type remoteLeafOption struct {
// Set when the remote's TLSHandshakeFirst value differs.
tlsFirstChanged bool
// Set when the remote's Compression options differ.
compressionChanged bool
// Set when the remote's Disabled value differs.
disabledChanged bool
// The matching remote options from the new configuration.
opts *RemoteLeafOpts
}
// Given `old` and `new` Leafnode options, this function will return the structure
// used for applying the configuration, or an error if there are changes that
// are not supported.
//
// Returns (nil, nil) when nothing was detected that requires invoking
// leafNodeOption.Apply().
//
// The remotes are examined under the server's read lock. If a remote that is
// being re-enabled still has a connect in progress, or an added remote is still
// present in `s.rmLeafRemoteCfgs`, the whole examination is retried (up to
// `maxAttempts` attempts, waiting 50ms before each retry) before giving up
// with an error. ErrServerNotRunning is returned if `s.quitCh` fires while
// waiting.
func getLeafNodeOptionsChanges(s *Server, old, new *LeafNodeOpts) (*leafNodeOption, error) {
// We can't use DeepEqual for `Users` field, so do custom check.
if usersHaveChanged(old.Users, new.Users) {
return nil, fmt.Errorf("field \"Users\": old=%v, new=%v", old.Users, new.Users)
}
// Check the main leafnodes{} block to see if there are any changes that are
// not supported. We provide a list of fields to ignore (we already checked,
// allow them to be modified or will check later).
if err := checkConfigsEqual(old, new, []string{
"Compression",
"Remotes",
"TLSHandshakeFirst",
"TLSHandshakeFirstFallback",
"TLSConfig",
"Users",
}); err != nil {
return nil, err
}
const (
// Format of the errors reported for a given remote: its name, then the reason.
remoteErrFormat = "remote %s: %s"
// Maximum number of times the remotes are examined before giving up.
maxAttempts = 20
)
var (
nlo *leafNodeOption
// Track whether any existing remote was not found (i.e. removed).
removed bool
)
forLoop:
for failed := range maxAttempts {
// Each attempt starts from scratch: `removed` and `nlo` are rebuilt below.
removed = false
if failed > 0 {
// If we failed once, we will wait a bit before trying again the remotes.
// This could give enough time for connections that were in progress to complete.
select {
case <-time.After(50 * time.Millisecond):
case <-s.quitCh:
return nil, ErrServerNotRunning
}
}
nlo = &leafNodeOption{
tlsFirstChanged: (old.TLSHandshakeFirst != new.TLSHandshakeFirst || old.TLSHandshakeFirstFallback != new.TLSHandshakeFirstFallback),
compressionChanged: !old.Compression.equals(&new.Compression),
// Start with all, will update when processing existing ones.
// Since the list will be modified, we need to clone it.
added: slices.Clone(new.Remotes),
}
s.mu.RLock()
// Go through the list of existing remote configurations.
for lrc := range s.leafRemoteCfgs {
var rlo *RemoteLeafOpts
// Look for the corresponding `*RemoteLeafOpts` in the `nlo.added`
// list. If it is found, that function returns an updated list
// with the element removed from it.
lrc.RLock()
rlo, nlo.added = getRemoteLeafOpts(lrc.name(), nlo.added)
if rlo == nil {
// Not found, will be removed in leafNodeOption.Apply().
removed = true
lrc.RUnlock()
continue
}
// Now we need to make sure that there are no changes that we don't
// support for a RemoteLeafOpts.
err := checkConfigsEqual(lrc.RemoteLeafOpts, rlo, []string{
"Compression",
"Disabled",
"TLS",
"TLSHandshakeFirst",
"TLSConfig",
})
if err != nil {
// Both locks must be released before returning.
lrc.RUnlock()
s.mu.RUnlock()
return nil, fmt.Errorf(remoteErrFormat, rlo.safeName(), err)
}
disabledChanged := lrc.Disabled != rlo.Disabled
// If this remote was disabled and is now enabled, we need to make sure
// that there is no connect in progress. If that is the case, either
// try again (while attempts remain) or return an error.
if disabledChanged && lrc.Disabled && lrc.connInProgress {
// Both locks must be released before retrying or returning.
lrc.RUnlock()
s.mu.RUnlock()
if failed < maxAttempts-1 {
continue forLoop
}
return nil, fmt.Errorf(remoteErrFormat, rlo.safeName(),
"cannot be enabled at the moment, try again")
}
// Since we will use the new `rlo.TLSConfig` later on, consider all
// existing remote configs as "changed" and store them in the
// `nlo.changed` map.
if nlo.changed == nil {
nlo.changed = make(map[*leafNodeCfg]*remoteLeafOption)
}
lnro := &remoteLeafOption{
tlsFirstChanged: lrc.TLSHandshakeFirst != rlo.TLSHandshakeFirst,
compressionChanged: !lrc.Compression.equals(&rlo.Compression),
disabledChanged: disabledChanged,
opts: rlo,
}
lrc.RUnlock()
nlo.changed[lrc] = lnro
}
if len(nlo.added) > 0 {
// Go through the added list and check if an added was recently removed and,
// if that is the case, is it still in the `s.rmLeafRemoteCfgs` map, which
// may mean that there was a connect-in-progress that did not complete yet.
// Either try again (while attempts remain) or return an error.
for _, rlo := range nlo.added {
if _, cip := s.rmLeafRemoteCfgs[rlo.name()]; cip {
s.mu.RUnlock()
if failed < maxAttempts-1 {
continue forLoop
}
return nil, fmt.Errorf(remoteErrFormat, rlo.safeName(),
"cannot be added at the moment, try again")
}
}
}
s.mu.RUnlock()
// This attempt completed without conflict, no need to retry.
break
}
// Now we want to make sure that there were actual changes, so that we don't
// cause a reload of leafnodes for nothing. However, if one has (or all have)
// been removed we still need to invoke leafNodeOption.Apply().
if !nlo.tlsFirstChanged && !nlo.compressionChanged && !removed && len(nlo.added) == 0 && len(nlo.changed) == 0 {
return nil, nil
}
return nlo, nil
}
// usersHaveChanged reports whether the old and new lists of users differ.
// Lists of different lengths are always considered changed. Otherwise users
// are matched by username, regardless of their position in the lists, and a
// change is reported when an old username has no counterpart in the new list
// or when the password or account name of a matched pair differs.
func usersHaveChanged(ousers, nusers []*User) bool {
	if len(ousers) != len(nusers) {
		return true
	}
	// Order has never been part of the comparison, keep it that way to avoid
	// possible breaking changes: index both lists by username.
	index := func(users []*User) map[string]*User {
		byName := make(map[string]*User, len(users))
		for _, usr := range users {
			byName[usr.Username] = usr
		}
		return byName
	}
	previous, current := index(ousers), index(nusers)
	for name, prev := range previous {
		cur, found := current[name]
		switch {
		case !found:
			// No new user with the same name.
			return true
		case prev.Password != cur.Password:
			return true
		case prev.Account.GetName() != cur.Account.GetName():
			return true
		}
	}
	return false
}
// getRemoteLeafOpts looks in `list` for the remote leafnode options whose name
// is `search`. On a match, it returns the matching `*RemoteLeafOpts` and the
// list with that element taken out (nil if it was the only element). Removal
// is done in place by moving the last element into the freed slot, so the
// order of `list` is not preserved and its backing array is modified.
// Without a match, it returns `nil` and the unmodified list.
func getRemoteLeafOpts(search string, list []*RemoteLeafOpts) (*RemoteLeafOpts, []*RemoteLeafOpts) {
	pos := -1
	for i := range list {
		if list[i].name() == search {
			pos = i
			break
		}
	}
	if pos < 0 {
		return nil, list
	}
	match, tail := list[pos], len(list)-1
	if tail == 0 {
		// It was the only element.
		return match, nil
	}
	if pos != tail {
		// Fill the hole with the last element before shrinking.
		list[pos] = list[tail]
	}
	return match, list[:tail]
}
func (l *leafNodeOption) Apply(s *Server) {
@@ -882,101 +1109,181 @@ func (l *leafNodeOption) Apply(s *Server) {
if l.tlsFirstChanged {
s.Noticef("Reloaded: LeafNode TLS HandshakeFirst value is: %v", opts.LeafNode.TLSHandshakeFirst)
s.Noticef("Reloaded: LeafNode TLS HandshakeFirstFallback value is: %v", opts.LeafNode.TLSHandshakeFirstFallback)
for _, r := range opts.LeafNode.Remotes {
s.Noticef("Reloaded: LeafNode Remote to %v TLS HandshakeFirst value is: %v", r.URLs, r.TLSHandshakeFirst)
}
}
if l.compressionChanged || l.disabledChanged {
var leafs []*client
var solicit []*leafNodeCfg
acceptSideCompOpts := &opts.LeafNode.Compression
if l.compressionChanged {
s.Noticef("Reloaded: LeafNode Compression value is: %v", opts.LeafNode.Compression)
}
s.mu.RLock()
// First, update our internal leaf remote configurations with the new
// compress options.
// Since changing the remotes (as in adding/removing) is currently not
// supported, we know that we should have the same number in Options
// than in leafRemoteCfgs, but to be sure, use the max size.
max := len(opts.LeafNode.Remotes)
if l := len(s.leafRemoteCfgs); l < max {
max = l
}
for i := range max {
lr := s.leafRemoteCfgs[i]
or := opts.LeafNode.Remotes[i]
lr.Lock()
lr.Compression = or.Compression
if lr.Disabled && !or.Disabled {
solicit = append(solicit, lr)
var close []*client
var enable []*leafNodeCfg
var removed bool
s.mu.Lock()
acceptSideCompOpts := &opts.LeafNode.Compression
// First go over existing leafnode remote configurations and
// either remove if no longer present, or update the config.
for lrc := range s.leafRemoteCfgs {
rlo := l.changed[lrc]
if rlo == nil {
delete(s.leafRemoteCfgs, lrc)
removed = true
if s.rmLeafRemoteCfgs == nil {
s.rmLeafRemoteCfgs = make(map[string]*leafNodeCfg)
}
lr.Disabled = or.Disabled
lr.Unlock()
s.rmLeafRemoteCfgs[lrc.name()] = lrc
lrc.markAsRemoved()
s.Noticef("Reloaded: LeafNode Remote %s removed", lrc.RemoteLeafOpts.safeName())
// We will close the existing connection in the next for-loop.
continue
}
for _, l := range s.leafs {
var co *CompressionOpts
l.mu.Lock()
if r := l.leaf.remote; r != nil {
// If newly marked as disabled, collect and ignore the rest.
if r.Disabled {
l.flags.set(noReconnect)
leafs = append(leafs, l)
l.mu.Unlock()
continue
}
co = &r.Compression
lrc.Lock()
// TLSConfig is always applied.
lrc.TLSConfig = rlo.opts.TLSConfig.Clone()
// Now update what has been detected has changed.
if rlo.tlsFirstChanged {
lrc.TLSHandshakeFirst = rlo.opts.TLSHandshakeFirst
s.Noticef("Reloaded: LeafNode Remote %s TLS HandshakeFirst value is: %v",
lrc.RemoteLeafOpts.safeName(), rlo.opts.TLSHandshakeFirst)
}
if rlo.compressionChanged {
lrc.Compression = rlo.opts.Compression
s.Noticef("Reloaded: LeafNode Remote %s Compression value is: %v",
lrc.RemoteLeafOpts.safeName(), rlo.opts.Compression)
}
if rlo.disabledChanged {
// Change to new value.
lrc.Disabled = rlo.opts.Disabled
if lrc.Disabled {
lrc.notifyQuitChannel()
} else {
co = acceptSideCompOpts
enable = append(enable, lrc)
}
newMode := co.Mode
// Skip leaf connections that are "not supported" (because they
// will never do compression) or the ones that have already the
// new compression mode.
if l.leaf.compression == CompressionNotSupported || l.leaf.compression == newMode {
l.mu.Unlock()
s.Noticef("Reloaded: LeafNode Remote %s Disabled value is: %v",
lrc.RemoteLeafOpts.safeName(), rlo.opts.Disabled)
}
lrc.Unlock()
}
// Second, go over existing leaf connections and apply compression
// changes (if applicable) and collect connections that need to be
// closed and/or disabled.
for _, c := range s.leafs {
var co *CompressionOpts
c.mu.Lock()
if r := c.leaf.remote; r != nil {
rlo := l.changed[r]
// If the config is not in the `changed` map, or the new config says that
// the connection is disabled, collect so we can close it after the server
// lock is released.
if rlo == nil || (rlo.disabledChanged && rlo.opts.Disabled) {
c.flags.set(noReconnect)
close = append(close, c)
c.mu.Unlock()
continue
}
// We need to close the connections if it had compression "off" or the new
// mode is compression "off", or if the new mode is "accept", because
// these require negotiation.
if l.leaf.compression == CompressionOff || newMode == CompressionOff || newMode == CompressionAccept {
leafs = append(leafs, l)
} else if newMode == CompressionS2Auto {
// If the mode is "s2_auto", we need to check if there is really
// need to change, and at any rate, we want to save the actual
// compression level here, not s2_auto.
l.updateS2AutoCompressionLevel(co, &l.leaf.compression)
} else {
// Simply change the compression writer
l.out.cw = s2.NewWriter(nil, s2WriterOptions(newMode)...)
l.leaf.compression = newMode
if rlo.compressionChanged {
co = &r.Compression
}
l.mu.Unlock()
} else if l.compressionChanged {
co = acceptSideCompOpts
}
s.mu.RUnlock()
// Close the connections for which negotiation is required, or that
// have been disabled.
for _, l := range leafs {
l.closeConnection(ClientClosed)
if co != nil && applyCompressionChanges(c, co) {
close = append(close, c)
}
if l.compressionChanged {
s.Noticef("Reloaded: LeafNode compression settings")
}
if l.disabledChanged {
if len(leafs) > 0 {
s.Noticef("Reloaded: LeafNode(s) disabled")
}
if len(solicit) > 0 {
for _, remote := range solicit {
s.startGoRoutine(func() { s.connectToRemoteLeafNode(remote, true) })
}
s.Noticef("Reloaded: LeafNode(s) enabled")
}
c.mu.Unlock()
}
s.mu.Unlock()
// Close the connections for which negotiation is required, have been disabled
// or simply removed.
for _, c := range close {
c.closeConnection(ClientClosed)
}
// Start the ones that have been enabled.
for _, r := range enable {
s.connectToRemoteLeafNodeAsynchronously(r, true)
}
// Finally, deal with the ones that have been added.
if len(l.added) > 0 {
s.solicitLeafNodeRemotes(l.added)
}
// Deal with removed configs. Make sure there are no connect-in-progress.
// If there are still some, have a go routine to check in the background.
if removed {
if checkAgain := checkRemovedLeafNodeCfgs(s); checkAgain {
checkRemovedLeafNodeCfgsAsync(s)
}
}
}
// checkRemovedLeafNodeCfgs walks the server's map of removed remote leafnode
// configs under the server lock. Entries that no longer report a connect in
// progress are deleted from the map; the others are left in place.
// Returns `true` when at least one entry is still in progress, meaning that
// the check needs to be performed again later.
func checkRemovedLeafNodeCfgs(s *Server) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	stillPending := false
	for name, cfg := range s.rmLeafRemoteCfgs {
		if cfg.isConnectInProgress() {
			stillPending = true
			continue
		}
		delete(s.rmLeafRemoteCfgs, name)
	}
	return stillPending
}
// checkRemovedLeafNodeCfgsAsync starts a go routine that invokes
// `checkRemovedLeafNodeCfgs` every 50ms. The go routine ends as soon as that
// function reports that nothing is in progress anymore, or when the server's
// quit channel fires. Invoking this function several times, and therefore
// having several instances running concurrently, is fine.
func checkRemovedLeafNodeCfgsAsync(s *Server) {
	s.startGoRoutine(func() {
		defer s.grWG.Done()

		ticker := time.NewTicker(50 * time.Millisecond)
		defer ticker.Stop()

		for pending := true; pending; {
			select {
			case <-s.quitCh:
				return
			case <-ticker.C:
				pending = checkRemovedLeafNodeCfgs(s)
			}
		}
	})
}
// applyCompressionChanges applies the `co` compression options to the leaf
// connection `c`. It returns true when the connection needs to be "restarted"
// for the new mode to take effect, and false when the change was applied in
// place or when nothing had to be done.
// NOTE(review): this reads and writes `c.leaf` and `c.out` without locking, so
// it presumably expects the caller to hold the connection lock - confirm.
func applyCompressionChanges(c *client, co *CompressionOpts) bool {
	current, wanted := c.leaf.compression, co.Mode
	switch {
	case current == CompressionNotSupported || current == wanted:
		// Connections that are "not supported" will never do compression,
		// and the ones already using the new mode have nothing to change.
		return false
	case current == CompressionOff || wanted == CompressionOff || wanted == CompressionAccept:
		// Going from or to "off", or switching to "accept", requires a
		// negotiation, so the connection has to be closed.
		return true
	case wanted == CompressionS2Auto:
		// With "s2_auto", check if there is really a need to change. At any
		// rate, what gets saved is the actual compression level, not s2_auto.
		c.updateS2AutoCompressionLevel(co, &c.leaf.compression)
	default:
		// A change of compression writer is all that is needed.
		c.out.cw = s2.NewWriter(nil, s2WriterOptions(wanted)...)
		c.leaf.compression = wanted
	}
	return false
}
type noFastProdStallReload struct {
noopOption
noStall bool
@@ -1031,7 +1338,7 @@ func (s *Server) recheckPinnedCerts(curOpts *Options, newOpts *Options) {
}
})
}
if s.gateway.enabled && reflect.DeepEqual(newOpts.Gateway.TLSPinnedCerts, curOpts.Gateway.TLSPinnedCerts) {
if s.gateway.enabled && !reflect.DeepEqual(newOpts.Gateway.TLSPinnedCerts, curOpts.Gateway.TLSPinnedCerts) {
gw := s.gateway
gw.RLock()
for _, c := range gw.out {
@@ -1115,11 +1422,6 @@ func (s *Server) ReloadOptions(newOpts *Options) error {
curOpts := s.getOpts()
// Wipe trusted keys if needed when we have an operator.
if len(curOpts.TrustedOperators) > 0 && len(curOpts.TrustedKeys) > 0 {
curOpts.TrustedKeys = nil
}
clientOrgPort := curOpts.Port
clusterOrgPort := curOpts.Cluster.Port
gatewayOrgPort := curOpts.Gateway.Port
@@ -1215,15 +1517,18 @@ func (s *Server) reloadOptions(curOpts, newOpts *Options) error {
newOpts.CustomClientAuthentication = curOpts.CustomClientAuthentication
newOpts.CustomRouterAuthentication = curOpts.CustomRouterAuthentication
changed, err := s.diffOptions(newOpts)
if err != nil {
// Do the validation before checking for differences. We need to ensure
// that the new options are valid. Note that there are possible side
// effects of calling validateOptions(), in that some default values
// may be set, etc... but that should be ok since the current options
// went through the same process on startup/previous reload.
if err := validateOptions(newOpts); err != nil {
return err
}
if len(changed) != 0 {
if err := validateOptions(newOpts); err != nil {
return err
}
changed, err := s.diffOptions(newOpts)
if err != nil {
return err
}
// Create a context that is used to pass special info that we may need
@@ -1258,7 +1563,7 @@ func imposeOrder(value any) error {
slices.Sort(value.AllowedOrigins)
case string, bool, uint8, uint16, uint64, int, int32, int64, time.Duration, float64, nil, LeafNodeOpts, ClusterOpts, *tls.Config, PinnedCertSet,
*URLAccResolver, *MemAccResolver, *DirAccResolver, *CacheDirAccResolver, Authentication, MQTTOpts, jwt.TagList,
*OCSPConfig, map[string]string, JSLimitOpts, StoreCipher, *OCSPResponseCacheConfig, *ProxiesConfig, WriteTimeoutPolicy:
*OCSPConfig, map[string]string, map[string]bool, JSLimitOpts, StoreCipher, *OCSPResponseCacheConfig, *ProxiesConfig, WriteTimeoutPolicy:
// explicitly skipped types
case *AuthCallout:
case JSTpmOpts:
@@ -1275,9 +1580,11 @@ func imposeOrder(value any) error {
// error.
func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
var (
oldConfig = reflect.ValueOf(s.getOpts()).Elem()
oldOpts = s.getOpts()
oldConfig = reflect.ValueOf(oldOpts).Elem()
newConfig = reflect.ValueOf(newOpts).Elem()
diffOpts = []option{}
skipTKeys = len(oldOpts.TrustedOperators) > 0 && len(oldOpts.TrustedKeys) > 0
// Need to keep track of whether JS is being disabled
// to prevent changing limits at runtime.
@@ -1286,6 +1593,7 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
jsMemLimitsChanged bool
jsFileLimitsChanged bool
jsStoreDirChanged bool
jsLimitsUpdate *jetStreamLimitsOption
)
for i := 0; i < oldConfig.NumField(); i++ {
field := oldConfig.Type().Field(i)
@@ -1294,6 +1602,17 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
if field.PkgPath != _EMPTY_ {
continue
}
optName := strings.ToLower(field.Name)
if skipTKeys && optName == "trustedkeys" {
// TrustedOperators and TrustedKeys change is not supported. During options
// validation, if they are both specified, a conflict error is returned.
// If only TrustedOperators is specified, the TrustedKeys is filled with
// the operators' signing keys. So here, if we detect that the current
// options have operators, we don't do the trusted keys comparison, so
// we can fail with the "not supported for TrustedOperators" config reload
// error instead of TrustedKeys (that the user would not have set).
continue
}
var (
oldValue = oldConfig.Field(i).Interface()
newValue = newConfig.Field(i).Interface()
@@ -1305,7 +1624,6 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
return nil, err
}
optName := strings.ToLower(field.Name)
// accounts and users (referencing accounts) will always differ as accounts
// contain internal state, say locks etc..., so we don't bother here.
// This also avoids races with atomic stats counters
@@ -1374,7 +1692,7 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
co := &clusterOption{
newValue: newClusterOpts,
permsChanged: !reflect.DeepEqual(newClusterOpts.Permissions, oldClusterOpts.Permissions),
compressChanged: !compressOptsEqual(&oldClusterOpts.Compression, &newClusterOpts.Compression),
compressChanged: !oldClusterOpts.Compression.equals(&newClusterOpts.Compression),
}
co.diffPoolAndAccounts(&oldClusterOpts)
// If there are added accounts, first make sure that we can look them up.
@@ -1445,6 +1763,11 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
tmpOld.tlsConfigOpts = nil
tmpNew.tlsConfigOpts = nil
// Allow TLSPinnedCerts through reload, existing connections
// are checked in recheckPinnedCerts
tmpOld.TLSPinnedCerts = nil
tmpNew.TLSPinnedCerts = nil
// Need to do the same for remote gateways' TLS configs.
// But we can't just set remotes' TLSConfig to nil otherwise this
// would lose the real TLS configuration.
@@ -1458,149 +1781,19 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
field.Name, oldValue, newValue)
}
case "leafnode":
// Similar to gateways
tmpOld := oldValue.(LeafNodeOpts)
tmpNew := newValue.(LeafNodeOpts)
tmpOld.TLSConfig = nil
tmpNew.TLSConfig = nil
tmpOld.tlsConfigOpts = nil
tmpNew.tlsConfigOpts = nil
// We will allow TLSHandshakeFirst to be config reloaded. First,
// we just want to detect if there was a change in the leafnodes{}
// block, and if not, we will check the remotes.
handshakeFirstChanged := tmpOld.TLSHandshakeFirst != tmpNew.TLSHandshakeFirst ||
tmpOld.TLSHandshakeFirstFallback != tmpNew.TLSHandshakeFirstFallback
// If changed, set them (in the temporary variables) to false so that the
// rest of the comparison does not fail.
if handshakeFirstChanged {
tmpOld.TLSHandshakeFirst, tmpNew.TLSHandshakeFirst = false, false
tmpOld.TLSHandshakeFirstFallback, tmpNew.TLSHandshakeFirstFallback = 0, 0
} else if len(tmpOld.Remotes) == len(tmpNew.Remotes) {
// Since we don't support changes in the remotes, we will do a
// simple pass to see if there was a change of this field.
for i := 0; i < len(tmpOld.Remotes); i++ {
if tmpOld.Remotes[i].TLSHandshakeFirst != tmpNew.Remotes[i].TLSHandshakeFirst {
handshakeFirstChanged = true
break
}
}
lno, err := getLeafNodeOptionsChanges(s, &tmpOld, &tmpNew)
// If there was an unsupported change, we will get an error with the name
// of the (first) field and its old and new value.
if err != nil {
return nil, fmt.Errorf("config reload not supported for %s: %v", field.Name, err)
}
// We also support config reload for compression. Check if it changed before
// blanking them out for the deep-equal check at the end.
compressionChanged := !compressOptsEqual(&tmpOld.Compression, &tmpNew.Compression)
if compressionChanged {
tmpOld.Compression, tmpNew.Compression = CompressionOpts{}, CompressionOpts{}
} else if len(tmpOld.Remotes) == len(tmpNew.Remotes) {
// Same that for tls first check, do the remotes now.
for i := range len(tmpOld.Remotes) {
if !compressOptsEqual(&tmpOld.Remotes[i].Compression, &tmpNew.Remotes[i].Compression) {
compressionChanged = true
break
}
}
// If there was an actual change...
if lno != nil {
diffOpts = append(diffOpts, lno)
}
// Check if the "disabled" option of each remote has changed.
var disabledChanged bool
for i := range len(tmpOld.Remotes) {
if tmpOld.Remotes[i].Disabled != tmpNew.Remotes[i].Disabled {
disabledChanged = true
break
}
}
// Need to do the same for remote leafnodes' TLS configs.
// But we can't just set remotes' TLSConfig to nil otherwise this
// would lose the real TLS configuration.
tmpOld.Remotes = copyRemoteLNConfigForReloadCompare(tmpOld.Remotes)
tmpNew.Remotes = copyRemoteLNConfigForReloadCompare(tmpNew.Remotes)
// Special check for leafnode remotes changes which are not supported right now.
leafRemotesChanged := func(a, b LeafNodeOpts) bool {
if len(a.Remotes) != len(b.Remotes) {
return true
}
// Check whether all remotes URLs are still the same.
for _, oldRemote := range a.Remotes {
var found bool
if oldRemote.LocalAccount == _EMPTY_ {
oldRemote.LocalAccount = globalAccountName
}
for _, newRemote := range b.Remotes {
// Bind to global account in case not defined.
if newRemote.LocalAccount == _EMPTY_ {
newRemote.LocalAccount = globalAccountName
}
if reflect.DeepEqual(oldRemote, newRemote) {
found = true
break
}
}
if !found {
return true
}
}
return false
}
// First check whether remotes changed at all. If they did not,
// skip them in the complete equal check.
if !leafRemotesChanged(tmpOld, tmpNew) {
tmpOld.Remotes = nil
tmpNew.Remotes = nil
}
// Special check for auth users to detect changes.
// If anything is off will fall through and fail below.
// If we detect they are semantically the same we nil them out
// to pass the check below.
if tmpOld.Users != nil || tmpNew.Users != nil {
if len(tmpOld.Users) == len(tmpNew.Users) {
oua := make(map[string]*User, len(tmpOld.Users))
nua := make(map[string]*User, len(tmpOld.Users))
for _, u := range tmpOld.Users {
oua[u.Username] = u
}
for _, u := range tmpNew.Users {
nua[u.Username] = u
}
same := true
for uname, u := range oua {
// If we can not find new one with same name, drop through to fail.
nu, ok := nua[uname]
if !ok {
same = false
break
}
// If username or password or account different break.
if u.Username != nu.Username || u.Password != nu.Password || u.Account.GetName() != nu.Account.GetName() {
same = false
break
}
}
// We can nil out here.
if same {
tmpOld.Users, tmpNew.Users = nil, nil
}
}
}
// If there is really a change prevents reload.
if !reflect.DeepEqual(tmpOld, tmpNew) {
// See TODO(ik) note below about printing old/new values.
return nil, fmt.Errorf("config reload not supported for %s: old=%v, new=%v",
field.Name, oldValue, newValue)
}
diffOpts = append(diffOpts, &leafNodeOption{
tlsFirstChanged: handshakeFirstChanged,
compressionChanged: compressionChanged,
disabledChanged: disabledChanged,
})
case "jetstream":
new := newValue.(bool)
old := oldValue.(bool)
@@ -1636,26 +1829,35 @@ func (s *Server) diffOptions(newOpts *Options) ([]option, error) {
fromSet = !fromUnset
toUnset = new == -1
toSet = !toUnset
increased = fromSet && toSet && new > old
)
if jsEnabled && modified {
// Cannot change limits from dynamic storage at runtime.
switch {
case increased:
// Allowed to increase, but not decrease.
if jsLimitsUpdate == nil {
jsLimitsUpdate = &jetStreamLimitsOption{}
diffOpts = append(diffOpts, jsLimitsUpdate)
}
if optName == "jetstreammaxmemory" {
jsLimitsUpdate.newMaxMemory = new
} else {
jsLimitsUpdate.newMaxStore = new
}
case fromSet && toUnset:
// Limits changed but it may mean that JS is being disabled,
// keep track of the change and error in case it is not.
switch optName {
case "jetstreammaxmemory":
if optName == "jetstreammaxmemory" {
jsMemLimitsChanged = true
case "jetstreammaxstore":
} else {
jsFileLimitsChanged = true
default:
return nil, fmt.Errorf("config reload not supported for jetstream max memory and store")
}
case fromUnset && toSet:
// Prevent changing from dynamic max memory / file at runtime.
return nil, fmt.Errorf("config reload not supported for jetstream dynamic max memory and store")
default:
return nil, fmt.Errorf("config reload not supported for jetstream max memory and store")
return nil, fmt.Errorf("config reload not supported for decreasing jetstream max memory and store")
}
}
case "jetstreammetacompact", "jetstreammetacompactsize", "jetstreammetacompactsync":
@@ -1801,32 +2003,6 @@ func copyRemoteGWConfigsWithoutTLSConfig(current []*RemoteGatewayOpts) []*Remote
return rgws
}
// copyRemoteLNConfigForReloadCompare returns shallow copies of the given remote
// leafnode options in which the fields that must not take part in the reload
// comparison have been reset to their zero value. The originals are left
// untouched. Returns nil for an empty list.
func copyRemoteLNConfigForReloadCompare(current []*RemoteLeafOpts) []*RemoteLeafOpts {
	if len(current) == 0 {
		return nil
	}
	copies := make([]*RemoteLeafOpts, len(current))
	for i, remote := range current {
		dup := *remote
		// TLS related settings.
		dup.TLSConfig, dup.tlsConfigOpts = nil, nil
		dup.TLSHandshakeFirst = false
		// TLS is set only when processing a CONNECT, reset it so that it does
		// not fail the DeepEqual comparison.
		dup.TLS = false
		// DenyImports/Exports get modified at runtime to add JS APIs, so they
		// are removed for now.
		dup.DenyImports, dup.DenyExports = nil, nil
		// Compression mode and disabled status are not compared either.
		dup.Compression = CompressionOpts{}
		dup.Disabled = false
		copies[i] = &dup
	}
	return copies
}
func (s *Server) applyOptions(ctx *reloadContext, opts []option) {
var (
reloadLogging = false
@@ -1896,15 +2072,12 @@ func (s *Server) applyOptions(ctx *reloadContext, opts []option) {
s.sendStatszUpdate()
}
// For remote gateways and leafnodes, make sure that their TLS configuration
// For remote gateways, make sure that their TLS configuration
// is updated (since the config is "captured" early and changes would otherwise
// not be visible).
if s.gateway.enabled {
s.gateway.updateRemotesTLSConfig(newOpts)
}
if len(newOpts.LeafNode.Remotes) > 0 {
s.updateRemoteLeafNodesTLSConfig(newOpts)
}
// Always restart OCSP monitoring on reload.
if err := s.reloadOCSP(); err != nil {
@@ -2650,3 +2823,38 @@ func diffProxiesTrustedKeys(old, new []*ProxyConfig) ([]string, []string) {
}
return add, del
}
// checkConfigsEqual calls `reflect.DeepEqual` on all exported fields of the
// two given configs that are not part of the `ignoreFields` list.
//
// c1 and c2 are expected to be structs (or pointers to structs) of the same
// type. If all compared fields are equal, it returns nil, otherwise it returns
// an error that contains the name of the first field that fails the
// comparison, along with its old and new values.
//
// Instead of panicking inside the reflect package, an error is also returned
// when exactly one of the configs is nil, when the configs are not structs,
// or when they are of different types. Two nil configs are considered equal.
func checkConfigsEqual(c1, c2 any, ignoreFields []string) error {
	// Indirect dereferences pointers and yields the zero Value for nil input.
	oldConfig := reflect.Indirect(reflect.ValueOf(c1))
	newConfig := reflect.Indirect(reflect.ValueOf(c2))
	if !oldConfig.IsValid() || !newConfig.IsValid() {
		if oldConfig.IsValid() == newConfig.IsValid() {
			// Both are nil, there is nothing that could differ.
			return nil
		}
		return fmt.Errorf("cannot compare with a nil config: old=%v, new=%v", c1, c2)
	}
	// Field indexes are only meaningful if both values share the same struct type.
	if oldConfig.Kind() != reflect.Struct || oldConfig.Type() != newConfig.Type() {
		return fmt.Errorf("cannot compare configs of types %T and %T", c1, c2)
	}
	cfgType := oldConfig.Type()
	for i := 0; i < cfgType.NumField(); i++ {
		field := cfgType.Field(i)
		// Unexported fields are skipped (their values can't be retrieved
		// through reflection anyway).
		if !field.IsExported() {
			continue
		}
		// If it is in the set of fields to ignore, move to the next.
		// We expect the number of ignore fields to be small.
		var ignored bool
		for _, f := range ignoreFields {
			if f == field.Name {
				ignored = true
				break
			}
		}
		if ignored {
			continue
		}
		oldValue := oldConfig.Field(i).Interface()
		newValue := newConfig.Field(i).Interface()
		if !reflect.DeepEqual(oldValue, newValue) {
			return fmt.Errorf("field %q: old=%v, new=%v", field.Name, oldValue, newValue)
		}
	}
	return nil
}
+126 -13
View File
@@ -19,6 +19,7 @@ import (
"io"
"math"
"slices"
"strings"
"time"
"github.com/nats-io/nats-server/v2/server/thw"
@@ -77,6 +78,18 @@ func (ms *MsgScheduling) init(seq uint64, subj string, ts int64) {
delete(ms.inflight, subj)
}
// update re-arms the schedule registered for subj so that it fires at ts.
// It clears the inflight marker for the subject and resets the timer.
// Subjects without a registered schedule are left untouched.
func (ms *MsgScheduling) update(subj string, ts int64) {
	sched, exists := ms.schedules[subj]
	if !exists {
		return
	}
	// Remove and add separately, it's for the same sequence, but if replicated
	// this server could not know the previous timestamp yet.
	seq := sched.seq
	ms.ttls.Remove(seq, sched.ts)
	ms.ttls.Add(seq, ts)
	sched.ts = ts
	delete(ms.inflight, subj)
	ms.resetTimer()
}
func (ms *MsgScheduling) markInflight(subj string) {
if _, ok := ms.schedules[subj]; ok {
ms.inflight[subj] = struct{}{}
@@ -90,8 +103,7 @@ func (ms *MsgScheduling) isInflight(subj string) bool {
func (ms *MsgScheduling) remove(seq uint64) {
if subj, ok := ms.seqToSubj[seq]; ok {
delete(ms.seqToSubj, seq)
delete(ms.schedules, subj)
ms.removeSubject(subj)
}
}
@@ -100,6 +112,7 @@ func (ms *MsgScheduling) removeSubject(subj string) {
ms.ttls.Remove(sched.seq, sched.ts)
delete(ms.schedules, subj)
delete(ms.seqToSubj, sched.seq)
delete(ms.inflight, subj)
}
}
@@ -142,11 +155,12 @@ func (ms *MsgScheduling) resetTimer() {
}
}
func (ms *MsgScheduling) getScheduledMessages(loadMsg func(seq uint64, smv *StoreMsg) *StoreMsg) []*inMsg {
func (ms *MsgScheduling) getScheduledMessages(loadMsg func(seq uint64, smv *StoreMsg) *StoreMsg, loadLast func(subj string, smv *StoreMsg) *StoreMsg) []*inMsg {
var (
smv StoreMsg
sm *StoreMsg
msgs []*inMsg
smv StoreMsg
srcSmv StoreMsg
sm *StoreMsg
msgs []*inMsg
)
ms.ttls.ExpireTasks(func(seq uint64, ts int64) bool {
// Need to grab the message for the specified sequence, and check
@@ -155,10 +169,26 @@ func (ms *MsgScheduling) getScheduledMessages(loadMsg func(seq uint64, smv *Stor
if sm != nil {
// If already inflight, don't duplicate a scheduled message. The stream could
// be replicated and the scheduled message could take some time to propagate.
if ms.isInflight(sm.subj) {
subj := sm.subj
if ms.isInflight(subj) {
return false
}
// Validate the contents are correct if not, we just remove it from THW.
pattern := bytesToString(sliceHeader(JSSchedulePattern, sm.hdr))
if pattern == _EMPTY_ {
ms.remove(seq)
return true
}
loc, apiErr := loadMessageScheduleLocation(sm.hdr)
if apiErr != nil {
ms.remove(seq)
return true
}
next, repeat, ok := parseMsgSchedule(pattern, loc, ts)
if !ok {
ms.remove(seq)
return true
}
ttl, ok := getMessageScheduleTTL(sm.hdr)
if !ok {
ms.remove(seq)
@@ -169,27 +199,43 @@ func (ms *MsgScheduling) getScheduledMessages(loadMsg func(seq uint64, smv *Stor
ms.remove(seq)
return true
}
rollup := getMessageScheduleRollup(sm.hdr)
source := getMessageScheduleSource(sm.hdr)
if source != _EMPTY_ {
// Fall back to the scheduled message's own content if the source has no last message.
if srcSm := loadLast(source, &srcSmv); srcSm != nil {
sm = srcSm
}
}
// Copy, as this is retrieved directly from storage, and we'll need to keep hold of this for some time.
// And in the case of headers, we'll copy all of them, but make changes.
hdr, msg := copyBytes(sm.hdr), copyBytes(sm.msg)
// Strip headers specific to the schedule.
hdr = removeHeaderIfPresent(hdr, JSSchedulePattern)
hdr = removeHeaderIfPrefixPresent(hdr, "Nats-Schedule-")
// Strip headers specific to message scheduling.
// Covers Nats-Schedule, Nats-Schedule-*, and Nats-Scheduler.
hdr = removeHeaderIfPrefixPresent(hdr, "Nats-Schedule")
// Strip headers that could prevent persisting this scheduled message.
hdr = removeHeaderIfPrefixPresent(hdr, "Nats-Expected-")
hdr = removeHeaderIfPresent(hdr, JSMsgId)
hdr = removeHeaderIfPresent(hdr, JSMessageTTL)
hdr = removeHeaderIfPresent(hdr, JSMsgRollup)
// Add headers for the scheduled message.
hdr = genHeader(hdr, JSScheduler, sm.subj)
hdr = genHeader(hdr, JSScheduleNext, JSScheduleNextPurge) // Purge the schedule message itself.
hdr = genHeader(hdr, JSScheduler, subj)
if !repeat {
hdr = genHeader(hdr, JSScheduleNext, JSScheduleNextPurge) // Purge the schedule message itself.
} else {
hdr = genHeader(hdr, JSScheduleNext, next.Format(time.RFC3339)) // Next time the schedule fires.
}
if ttl != _EMPTY_ {
hdr = genHeader(hdr, JSMessageTTL, ttl)
}
if rollup != _EMPTY_ {
hdr = genHeader(hdr, JSMsgRollup, rollup)
}
msgs = append(msgs, &inMsg{seq: seq, subj: target, hdr: hdr, msg: msg})
ms.markInflight(sm.subj)
ms.markInflight(subj)
return false
}
ms.remove(seq)
@@ -261,3 +307,70 @@ func (ms *MsgScheduling) decode(b []byte) (uint64, error) {
}
return stamp, nil
}
// parseMsgSchedule interprets a message schedule pattern relative to the
// timestamp ts (in nanoseconds). It returns, in order: the time the schedule
// should fire, whether the schedule repeats, and whether the pattern is valid.
//
// Supported forms are "@at <RFC3339 time>", "@every <duration>" (at least one
// second), the predefined "@yearly"-style names, and anything else is handed
// to parseCron. A non-nil loc is rejected for "@at" and "@every".
// An empty pattern is reported as valid with a zero time.
func parseMsgSchedule(pattern string, loc *time.Location, ts int64) (time.Time, bool, bool) {
	var none time.Time
	if pattern == _EMPTY_ {
		return none, false, true
	}

	// Exact time.
	if at, found := strings.CutPrefix(pattern, "@at "); found {
		// Time zone is not supported for @at.
		if loc != nil {
			return none, false, false
		}
		when, err := time.Parse(time.RFC3339, at)
		return when, false, err == nil
	}

	// Repeating on a simple interval.
	if every, found := strings.CutPrefix(pattern, "@every "); found {
		// Time zone is not supported for @every.
		if loc != nil {
			return none, false, false
		}
		// Must be a valid duration of at least a second.
		interval, err := time.ParseDuration(every)
		if err != nil || interval.Seconds() < 1 {
			return none, false, false
		}
		// If this schedule would trigger multiple times, for example after a
		// restart, skip ahead and only fire once.
		fire := time.Unix(0, ts).UTC().Round(time.Second).Add(interval)
		if now := time.Now().UTC(); fire.Before(now) {
			fire = now.Round(time.Second).Add(interval)
		}
		return fire, true, true
	}

	// Translate the predefined schedules into their cron equivalent.
	switch pattern {
	case "@yearly", "@annually":
		pattern = "0 0 0 1 1 *"
	case "@monthly":
		pattern = "0 0 0 1 * *"
	case "@weekly":
		pattern = "0 0 0 * * 0"
	case "@daily", "@midnight":
		pattern = "0 0 0 * * *"
	case "@hourly":
		pattern = "0 0 * * * *"
	}

	// Everything else is treated as a cron pattern.
	fire, err := parseCron(pattern, loc, ts)
	if err != nil {
		return none, false, false
	}
	// If this schedule would trigger multiple times, for example after a
	// restart, skip ahead and only fire once.
	if now := time.Now().UTC(); fire.Before(now) {
		fire, err = parseCron(pattern, loc, now.Round(time.Second).UnixNano())
		if err != nil {
			return none, false, false
		}
	}
	return fire, true, true
}
+7 -36
View File
@@ -31,7 +31,6 @@ import (
"os"
"path"
"path/filepath"
"reflect"
"regexp"
"runtime"
"runtime/pprof"
@@ -234,7 +233,8 @@ type Server struct {
resolver netResolver
dialTimeout time.Duration
}
leafRemoteCfgs []*leafNodeCfg
leafRemoteCfgs map[*leafNodeCfg]struct{}
rmLeafRemoteCfgs map[string]*leafNodeCfg
leafRemoteAccounts sync.Map
leafNodeEnabled bool
leafDisableConnect bool // Used in test only
@@ -367,7 +367,8 @@ type Server struct {
syncOutSem chan struct{}
// Queue to process JS API requests that come from routes (or gateways)
jsAPIRoutedReqs *ipQueue[*jsAPIRoutedReq]
jsAPIRoutedReqs *ipQueue[*jsAPIRoutedReq]
jsAPIRoutedInfoReqs *ipQueue[*jsAPIRoutedReq]
// Delayed API responses.
delayedAPIResponses *ipQueue[*delayedAPIResponse]
@@ -645,32 +646,6 @@ func selectS2AutoModeBasedOnRTT(rtt time.Duration, rttThresholds []time.Duration
return CompressionS2Best
}
// compressOptsEqual reports whether the two compression options are
// equivalent. Two nil (or identical) pointers are equal, a nil and a non-nil
// pointer are not, and otherwise the modes must match.
//
// For the s2_auto mode the RTT thresholds are compared as well. An empty
// RTTThresholds is equivalent to defaultCompressionS2AutoRTTThresholds, so
// both sides are normalized to the defaults before being compared. This makes
// "empty vs. defaults" and "empty vs. empty" compare as equal, as intended.
func compressOptsEqual(c1, c2 *CompressionOpts) bool {
	if c1 == c2 {
		return true
	}
	// The pointers differ, so if either one is nil the other one is not.
	if c1 == nil || c2 == nil {
		return false
	}
	if c1.Mode != c2.Mode {
		return false
	}
	// Thresholds only matter for s2_auto.
	if c1.Mode != CompressionS2Auto {
		return true
	}
	t1, t2 := c1.RTTThresholds, c2.RTTThresholds
	if len(t1) == 0 {
		t1 = defaultCompressionS2AutoRTTThresholds
	}
	if len(t2) == 0 {
		t2 = defaultCompressionS2AutoRTTThresholds
	}
	return reflect.DeepEqual(t1, t2)
}
// Returns an array of s2 WriterOption based on the route compression mode.
// So far we return a single option, but this way we can call s2.NewWriter()
// with a nil []s2.WriterOption, but not with a nil s2.WriterOption, so
@@ -1956,12 +1931,7 @@ func (s *Server) registerAccount(acc *Account) *Account {
// Helper to set the sublist based on preferences.
func (s *Server) setAccountSublist(acc *Account) {
if acc != nil && acc.sl == nil {
opts := s.getOpts()
if opts != nil && opts.NoSublistCache {
acc.sl = NewSublistNoCache()
} else {
acc.sl = NewSublistWithCache()
}
acc.sl = NewSublistForServer(s)
}
}
@@ -2293,6 +2263,7 @@ func (s *Server) Start() {
s.Noticef(" Node: %s", getHash(s.info.Name))
}
s.Noticef(" ID: %s", s.info.ID)
s.printFeatureFlags(opts)
defer s.Noticef("Server is ready")
@@ -3586,7 +3557,7 @@ func (s *Server) saveClosedClient(c *client, nc net.Conn, subs map[string]*subsc
if c.acc != nil && c.acc.Name != globalAccountName {
cc.acc = c.acc.Name
}
cc.JWT = c.opts.JWT
cc.JWT = redactBearerJWT(c.opts.JWT)
cc.IssuerKey = issuerForClient(c)
cc.Tags = c.tags
cc.NameTag = c.nameTag
+20 -19
View File
@@ -92,15 +92,15 @@ type ProcessJetStreamMsgHandler func(*inMsg)
type StreamStore interface {
StoreMsg(subject string, hdr, msg []byte, ttl int64) (uint64, int64, error)
StoreRawMsg(subject string, hdr, msg []byte, seq uint64, ts int64, ttl int64) error
StoreRawMsg(subject string, hdr, msg []byte, seq uint64, ts int64, ttl int64, discardNewCheck bool) error
SkipMsg(seq uint64) (uint64, error)
SkipMsgs(seq uint64, num uint64) error
FlushAllPending()
FlushAllPending() error
LoadMsg(seq uint64, sm *StoreMsg) (*StoreMsg, error)
LoadNextMsg(filter string, wc bool, start uint64, smp *StoreMsg) (sm *StoreMsg, skip uint64, err error)
LoadNextMsgMulti(sl *gsl.SimpleSublist, start uint64, smp *StoreMsg) (sm *StoreMsg, skip uint64, err error)
LoadLastMsg(subject string, sm *StoreMsg) (*StoreMsg, error)
LoadPrevMsg(start uint64, smp *StoreMsg) (sm *StoreMsg, err error)
LoadPrevMsg(filter string, wc bool, start uint64, smp *StoreMsg) (sm *StoreMsg, skip uint64, err error)
LoadPrevMsgMulti(sl *gsl.SimpleSublist, start uint64, smp *StoreMsg) (sm *StoreMsg, skip uint64, err error)
RemoveMsg(seq uint64) (bool, error)
EraseMsg(seq uint64) (bool, error)
@@ -109,7 +109,7 @@ type StreamStore interface {
Compact(seq uint64) (uint64, error)
Truncate(seq uint64) error
GetSeqFromTime(t time.Time) uint64
FilteredState(seq uint64, subject string) SimpleState
FilteredState(seq uint64, subject string) (SimpleState, error)
SubjectsState(filterSubject string) map[string]SimpleState
SubjectsTotals(filterSubject string) map[string]uint64
AllLastSeqs() ([]uint64, error)
@@ -120,7 +120,7 @@ type StreamStore interface {
State() StreamState
FastState(*StreamState)
EncodedStreamState(failed uint64) (enc []byte, err error)
SyncDeleted(dbs DeleteBlocks)
SyncDeleted(dbs DeleteBlocks) error
Type() StorageType
RegisterStorageUpdates(StorageUpdateHandler)
RegisterStorageRemoveMsg(StorageRemoveMsgHandler)
@@ -360,11 +360,13 @@ func (dbs DeleteBlocks) NumDeleted() (total uint64) {
type ConsumerStore interface {
SetStarting(sseq uint64) error
UpdateStarting(sseq uint64)
Reset(sseq uint64) error
HasState() bool
UpdateDelivered(dseq, sseq, dc uint64, ts int64) error
UpdateAcks(dseq, sseq uint64) error
UpdateConfig(cfg *ConsumerConfig) error
Update(*ConsumerState) error
ForceUpdate(*ConsumerState) error
State() (*ConsumerState, error)
BorrowState() (*ConsumerState, error)
EncodedState() ([]byte, error)
@@ -464,13 +466,6 @@ type Pending struct {
Timestamp int64
}
// TemplateStore stores templates.
// Deprecated: stream templates are deprecated and will be removed in a future version.
type TemplateStore interface {
Store(*streamTemplate) error
Delete(*streamTemplate) error
}
const (
limitsPolicyJSONString = `"limits"`
interestPolicyJSONString = `"interest"`
@@ -602,15 +597,17 @@ func (st *StorageType) UnmarshalJSON(data []byte) error {
}
const (
ackNonePolicyJSONString = `"none"`
ackAllPolicyJSONString = `"all"`
ackExplicitPolicyJSONString = `"explicit"`
ackNonePolicyJSONString = `"none"`
ackAllPolicyJSONString = `"all"`
ackExplicitPolicyJSONString = `"explicit"`
ackFlowControlPolicyJSONString = `"flow_control"`
)
var (
ackNonePolicyJSONBytes = []byte(ackNonePolicyJSONString)
ackAllPolicyJSONBytes = []byte(ackAllPolicyJSONString)
ackExplicitPolicyJSONBytes = []byte(ackExplicitPolicyJSONString)
ackNonePolicyJSONBytes = []byte(ackNonePolicyJSONString)
ackAllPolicyJSONBytes = []byte(ackAllPolicyJSONString)
ackExplicitPolicyJSONBytes = []byte(ackExplicitPolicyJSONString)
ackFlowControlPolicyJSONBytes = []byte(ackFlowControlPolicyJSONString)
)
func (ap AckPolicy) MarshalJSON() ([]byte, error) {
@@ -621,6 +618,8 @@ func (ap AckPolicy) MarshalJSON() ([]byte, error) {
return ackAllPolicyJSONBytes, nil
case AckExplicit:
return ackExplicitPolicyJSONBytes, nil
case AckFlowControl:
return ackFlowControlPolicyJSONBytes, nil
default:
return nil, fmt.Errorf("can not marshal %v", ap)
}
@@ -634,6 +633,8 @@ func (ap *AckPolicy) UnmarshalJSON(data []byte) error {
*ap = AckAll
case ackExplicitPolicyJSONString:
*ap = AckExplicit
case ackFlowControlPolicyJSONString:
*ap = AckFlowControl
default:
return fmt.Errorf("can not unmarshal %q", data)
}
@@ -741,7 +742,7 @@ func isOutOfSpaceErr(err error) bool {
var errFirstSequenceMismatch = errors.New("first sequence mismatch")
func isClusterResetErr(err error) bool {
return err == errLastSeqMismatch || err == ErrStoreEOF || err == errFirstSequenceMismatch || errors.Is(err, errCatchupAbortedNoLeader) || err == errCatchupTooManyRetries
return err == errLastSeqMismatch || err == ErrStoreEOF || err == errFirstSequenceMismatch || errors.Is(err, errCatchupAbortedNoLeader) || err == errCatchupTooManyRetries || err == errAlreadyLeader
}
// Copy all fields.
File diff suppressed because it is too large Load Diff
+12
View File
@@ -121,6 +121,18 @@ func NewSublist(enableCache bool) *Sublist {
return &Sublist{root: newLevel()}
}
// NewSublistForServer creates a default sublist whose caching behavior follows
// the server options: the cache is enabled unless NoSublistCache is set.
// Without a server (probably just unit tests) or without options, a sublist
// with no cache is returned.
func NewSublistForServer(srv *Server) *Sublist {
	if srv != nil {
		if opts := srv.getOpts(); opts != nil {
			useCache := !opts.NoSublistCache
			return NewSublist(useCache)
		}
	}
	return NewSublistNoCache()
}
// NewSublistWithCache will create a default sublist with caching enabled.
func NewSublistWithCache() *Sublist {
return NewSublist(true)
+5 -2
View File
@@ -155,8 +155,11 @@ func (hw *HashWheel) expireTasks(ts int64, callback func(seq uint64, expires int
slotLowest := int64(math.MaxInt64)
for seq, expires := range s.entries {
if expires <= ts && callback(seq, expires) {
delete(s.entries, seq)
hw.count--
// Only remove if not done so already by the callback.
if _, ok := s.entries[seq]; ok {
delete(s.entries, seq)
hw.count--
}
continue
}
if expires < slotLowest {
+8
View File
@@ -1,3 +1,11 @@
## 1.40.0
We're adopting a new release strategy to minimize dependency bloat in projects that consume Gomega. It is a limitation of the go mod toolchain that _test_ subdependencies of your project's direct dependencies get pulled in as *indirect* dependencies. In the case of Gomega, this ends up pulling in all of Ginkgo into your `go.mod` even if you are only using Gomega (Gomega uses Ginkgo for its own tests).
Going forward, releases will strip out all tests, tidy up the `go.mod` and then push this stripped down version to a new `master-lite` branch. These stripped-down versions will receive the `vx.y.z` git tag and will be picked up by the go toolchain.
Please open an issue if this new release process causes unexpected changes for your projects.
## 1.39.1
Update all dependencies. This auto-updated the required version of Go to 1.24, consistent with the fact that Go 1.23 has been out of support for almost six months.
+1 -1
View File
@@ -22,7 +22,7 @@ import (
"github.com/onsi/gomega/types"
)
const GOMEGA_VERSION = "1.39.1"
const GOMEGA_VERSION = "1.40.0"
const nilGomegaPanic = `You are trying to make an assertion, but haven't registered Gomega's fail handler.
If you're using Ginkgo then you probably forgot to put your assertion in an It().
@@ -744,16 +744,6 @@ func (s *Service) Delete(ctx context.Context, req *provider.DeleteRequest) (*pro
}
ctx = ctxpkg.ContextSetLockID(ctx, req.LockId)
// check DeleteRequest for any known opaque properties.
// FIXME these should be part of the DeleteRequest object
if req.Opaque != nil {
if _, ok := req.Opaque.Map["deleting_shared_resource"]; ok {
// it is a binary key; its existence signals true. Although, do not assume.
ctx = appctx.WithDeletingSharedResource(ctx)
}
}
md, err := s.Storage.GetMD(ctx, req.Ref, []string{}, []string{"id", "status"})
if err != nil {
return &provider.DeleteResponse{
@@ -182,9 +182,12 @@ func (s *service) GetUser(ctx context.Context, req *userpb.GetUserRequest) (*use
user, err := s.usermgr.GetUser(ctx, req.UserId, req.SkipFetchingUserGroups)
if err != nil {
res := &userpb.GetUserResponse{}
if _, ok := err.(errtypes.NotFound); ok {
switch err.(type) {
case errtypes.NotFound:
res.Status = status.NewNotFound(ctx, "user not found")
} else {
case errtypes.Unavailable:
res.Status = status.NewUnavailable(ctx, "user provider temporarily unavailable")
default:
res.Status = status.NewInternal(ctx, "error getting user")
}
return res, nil
@@ -205,9 +208,12 @@ func (s *service) GetUserByClaim(ctx context.Context, req *userpb.GetUserByClaim
user, err := s.usermgr.GetUserByClaim(ctx, req.Claim, req.Value, tenantID, req.SkipFetchingUserGroups)
if err != nil {
res := &userpb.GetUserByClaimResponse{}
if _, ok := err.(errtypes.NotFound); ok {
switch err.(type) {
case errtypes.NotFound:
res.Status = status.NewNotFound(ctx, fmt.Sprintf("user not found %s %s", req.Claim, req.Value))
} else {
case errtypes.Unavailable:
res.Status = status.NewUnavailable(ctx, "user provider temporarily unavailable")
default:
res.Status = status.NewInternal(ctx, "error getting user by claim")
}
return res, nil
@@ -84,9 +84,9 @@ type service struct {
allowedPathsForShares []*regexp.Regexp
}
func getShareManager(c *config) (share.Manager, error) {
func getShareManager(c *config, logger *zerolog.Logger) (share.Manager, error) {
if f, ok := registry.NewFuncs[c.Driver]; ok {
return f(c.Drivers[c.Driver])
return f(c.Drivers[c.Driver], logger)
}
return nil, errtypes.NotFound("driver not found: " + c.Driver)
}
@@ -114,7 +114,7 @@ func parseConfig(m map[string]interface{}) (*config, error) {
}
// New creates a new user share provider svc initialized from defaults
func NewDefault(m map[string]interface{}, ss *grpc.Server, _ *zerolog.Logger) (rgrpc.Service, error) {
func NewDefault(m map[string]any, ss *grpc.Server, logger *zerolog.Logger) (rgrpc.Service, error) {
c, err := parseConfig(m)
if err != nil {
@@ -123,7 +123,7 @@ func NewDefault(m map[string]interface{}, ss *grpc.Server, _ *zerolog.Logger) (r
c.init()
sm, err := getShareManager(c)
sm, err := getShareManager(c, logger)
if err != nil {
return nil, err
}
@@ -155,6 +155,13 @@ func (s *svc) handleProppatch(ctx context.Context, w http.ResponseWriter, r *htt
}
for j := range patches[i].Props {
propNameXML := patches[i].Props[j].XMLName
// favorites are now managed by the Graph API and can no longer be set using PROPPATCH. To avoid confusion, we return a 403 Forbidden when clients try to set the oc:favorites property
if propNameXML.Local == "favorite" {
w.WriteHeader(http.StatusForbidden)
return nil, nil, false
}
// don't use path.Join. It removes the double slash! concatenate with a /
key := fmt.Sprintf("%s/%s", patches[i].Props[j].XMLName.Space, patches[i].Props[j].XMLName.Local)
value := string(patches[i].Props[j].InnerXML)
-10
View File
@@ -27,16 +27,6 @@ import (
"go.opentelemetry.io/otel/trace"
)
// deletingSharedResource flags to a storage a shared resource is being deleted not by the owner.
type deletingSharedResource struct{}
func WithDeletingSharedResource(ctx context.Context) context.Context {
return context.WithValue(ctx, deletingSharedResource{}, struct{}{})
}
func DeletingSharedResourceFromContext(ctx context.Context) bool {
return ctx.Value(deletingSharedResource{}) != nil
}
// WithLogger returns a context with an associated logger.
func WithLogger(ctx context.Context, l *zerolog.Logger) context.Context {
return l.WithContext(ctx)
+21
View File
@@ -203,6 +203,15 @@ func (e TooEarly) Error() string { return "error: too early: " + string(e) }
// IsTooEarly implements the IsTooEarly interface.
func (e TooEarly) IsTooEarly() {}
// Unavailable is the error to use when a backend service (e.g. LDAP, database)
// cannot be reached temporarily. It signals a transient failure, so callers
// should retry.
type Unavailable string

// Error implements the error interface and prefixes the message with
// "error: unavailable: ".
func (e Unavailable) Error() string {
	const prefix = "error: unavailable: "
	return prefix + string(e)
}

// IsUnavailable implements the IsUnavailable interface.
func (e Unavailable) IsUnavailable() {}
// IsNotFound is the interface to implement
// to specify that a resource is not found.
type IsNotFound interface {
@@ -293,6 +302,12 @@ type IsTooEarly interface {
IsTooEarly()
}
// IsUnavailable is the marker interface to implement to specify that a backend
// service is temporarily unavailable and the caller should retry.
// The Unavailable error type implements it.
type IsUnavailable interface {
IsUnavailable()
}
// NewErrtypeFromStatus maps a rpc status to an errtype
func NewErrtypeFromStatus(status *rpc.Status) error {
switch status.Code {
@@ -329,6 +344,8 @@ func NewErrtypeFromStatus(status *rpc.Status) error {
return BadRequest(status.Message)
case rpc.Code_CODE_TOO_EARLY:
return TooEarly(status.Message)
case rpc.Code_CODE_UNAVAILABLE:
return Unavailable(status.Message)
default:
return InternalError(status.Message)
}
@@ -363,6 +380,8 @@ func NewErrtypeFromHTTPStatusCode(code int, message string) error {
return PartialContent(message)
case http.StatusTooEarly:
return TooEarly(message)
case http.StatusServiceUnavailable:
return Unavailable(message)
case StatusChecksumMismatch:
return ChecksumMismatch(message)
default:
@@ -399,6 +418,8 @@ func NewHTTPStatusCodeFromErrtype(err error) int {
return http.StatusPartialContent
case TooEarly:
return http.StatusTooEarly
case Unavailable:
return http.StatusServiceUnavailable
case ChecksumMismatch:
return StatusChecksumMismatch
default:
+25 -21
View File
@@ -71,9 +71,8 @@ type RawStream struct {
c Config
}
func FromConfig(ctx context.Context, name string, cfg Config) (Stream, error) {
var s Stream
b := backoff.NewExponentialBackOff()
func JetStream(ctx context.Context, name string, cfg Config) (jetstream.JetStream, error) {
var js jetstream.JetStream
connect := func() error {
var tlsConf *tls.Config
@@ -120,27 +119,32 @@ func FromConfig(ctx context.Context, name string, cfg Config) (Stream, error) {
return err
}
jsConn, err := jetstream.New(conn)
if err != nil {
return err
}
js, err := jsConn.Stream(ctx, events.MainQueueName)
if err != nil {
return err
}
s = &RawStream{
js: js,
c: cfg,
}
return nil
js, err = jetstream.New(conn)
return err
}
err := backoff.Retry(connect, b)
err := backoff.Retry(connect, backoff.NewExponentialBackOff())
if err != nil {
return s, errors.Wrap(err, "could not connect to nats jetstream")
return nil, errors.Wrap(err, "could not connect to nats jetstream")
}
return s, nil
return js, nil
}
// FromConfig connects to NATS JetStream using the given connection name and
// config, looks up the main event queue stream, and returns it wrapped in a
// RawStream. Any connection or lookup error is returned unchanged.
func FromConfig(ctx context.Context, name string, cfg Config) (Stream, error) {
	client, err := JetStream(ctx, name, cfg)
	if err != nil {
		return nil, err
	}

	mainQueue, err := client.Stream(ctx, events.MainQueueName)
	if err != nil {
		return nil, err
	}

	rs := &RawStream{js: mainQueue, c: cfg}
	return rs, nil
}
func (s *RawStream) Consume(group string, evs ...events.Unmarshaller) (<-chan Event, error) {
+35 -1
View File
@@ -2,6 +2,7 @@ package stream
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"errors"
@@ -11,7 +12,9 @@ import (
"github.com/cenkalti/backoff"
"github.com/go-micro/plugins/v4/events/natsjs"
"github.com/nats-io/nats.go/jetstream"
"github.com/opencloud-eu/reva/v2/pkg/events"
"github.com/opencloud-eu/reva/v2/pkg/events/raw"
"github.com/opencloud-eu/reva/v2/pkg/logger"
)
@@ -65,7 +68,38 @@ func NatsFromConfig(connName string, disableDurability bool, cfg NatsConfig) (ev
opts = append(opts, natsjs.DisableDurableStreams())
}
return Nats(opts...)
s, err := Nats(opts...)
if err != nil {
return nil, err
}
// apply a MaxAge to the main queue to prevent it from filling up
ctx := context.Background()
jsConn, err := raw.JetStream(ctx, connName, raw.Config{
Endpoint: cfg.Endpoint,
Cluster: cfg.Cluster,
TLSInsecure: cfg.TLSInsecure,
TLSRootCACertificate: cfg.TLSRootCACertificate,
EnableTLS: cfg.EnableTLS,
AuthUsername: cfg.AuthUsername,
AuthPassword: cfg.AuthPassword,
})
if err != nil {
return nil, err
}
streamCfg := jetstream.StreamConfig{
Name: "main-queue",
MaxAge: 7 * 24 * time.Hour,
}
_, err = jsConn.CreateStream(ctx, streamCfg)
if err != nil {
// If the stream already exists, update its configuration
if err == jetstream.ErrStreamNameAlreadyInUse {
_, _ = jsConn.UpdateStream(ctx, streamCfg)
}
}
return s, nil
}
// nats returns a nats streaming client
+13
View File
@@ -68,6 +68,15 @@ func NewInternal(ctx context.Context, msg string) *rpc.Status {
}
}
// NewUnavailable builds a Status with CODE_UNAVAILABLE that carries the given
// message and the trace obtained from ctx.
func NewUnavailable(ctx context.Context, msg string) *rpc.Status {
	st := new(rpc.Status)
	st.Code = rpc.Code_CODE_UNAVAILABLE
	st.Message = msg
	st.Trace = getTrace(ctx)
	return st
}
// NewUnauthenticated returns a Status with CODE_UNAUTHENTICATED.
func NewUnauthenticated(ctx context.Context, err error, msg string) *rpc.Status {
return &rpc.Status{
@@ -191,6 +200,10 @@ func NewStatusFromErrType(ctx context.Context, msg string, err error) *rpc.Statu
return NewUnimplemented(ctx, err, msg+":"+err.Error())
case errtypes.BadRequest:
return NewInvalid(ctx, msg+":"+err.Error())
case errtypes.Unavailable:
return NewUnavailable(ctx, msg+": "+err.Error())
case errtypes.IsUnavailable:
return NewUnavailable(ctx, msg+": "+err.Error())
}
// map GRPC status codes coming from the auth middleware
+170 -60
View File
@@ -36,9 +36,9 @@ import (
"github.com/opencloud-eu/reva/v2/pkg/errtypes"
"github.com/opencloud-eu/reva/v2/pkg/events"
"github.com/opencloud-eu/reva/v2/pkg/events/stream"
"github.com/opencloud-eu/reva/v2/pkg/logger"
"github.com/opencloud-eu/reva/v2/pkg/rgrpc/todo/pool"
"github.com/opencloud-eu/reva/v2/pkg/share"
migration "github.com/opencloud-eu/reva/v2/pkg/share/manager/jsoncs3/migrations"
"github.com/opencloud-eu/reva/v2/pkg/share/manager/jsoncs3/providercache"
"github.com/opencloud-eu/reva/v2/pkg/share/manager/jsoncs3/receivedsharecache"
"github.com/opencloud-eu/reva/v2/pkg/share/manager/jsoncs3/sharecache"
@@ -48,6 +48,7 @@ import (
"github.com/opencloud-eu/reva/v2/pkg/storagespace"
"github.com/opencloud-eu/reva/v2/pkg/utils"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"go.opentelemetry.io/otel/codes"
"golang.org/x/sync/errgroup"
"google.golang.org/genproto/protobuf/field_mask"
@@ -122,14 +123,20 @@ var (
)
type config struct {
GatewayAddr string `mapstructure:"gateway_addr"`
MaxConcurrency int `mapstructure:"max_concurrency"`
ProviderAddr string `mapstructure:"provider_addr"`
ServiceUserID string `mapstructure:"service_user_id"`
ServiceUserIdp string `mapstructure:"service_user_idp"`
MachineAuthAPIKey string `mapstructure:"machine_auth_apikey"`
CacheTTL int `mapstructure:"ttl"`
Events EventOptions `mapstructure:"events"`
GatewayAddr string `mapstructure:"gateway_addr"`
MaxConcurrency int `mapstructure:"max_concurrency"`
ProviderAddr string `mapstructure:"provider_addr"`
SystemUserID string `mapstructure:"system_user_id"`
SystemUserIdp string `mapstructure:"system_user_idp"`
MachineAuthAPIKey string `mapstructure:"machine_auth_apikey"`
ServiceAccountID string `mapstructure:"service_account_id"`
ServiceAccountSecret string `mapstructure:"service_account_secret"`
// ProviderRegistryAddr is the address of the storage registry used during
// migrations. Defaults to GatewayAddr when empty, because in the default
// OpenCloud deployment the registry is co-located with the gateway.
ProviderRegistryAddr string `mapstructure:"provider_registry_addr"`
CacheTTL int `mapstructure:"ttl"`
Events EventOptions `mapstructure:"events"`
}
// EventOptions are the configurable options for events
@@ -145,8 +152,6 @@ type EventOptions struct {
// Manager implements a share manager using a cs3 storage backend with local caching
type Manager struct {
sync.RWMutex
Cache providercache.Cache // holds all shares, sharded by provider id and space id
CreatedCache sharecache.Cache // holds the list of shares a user has created, sharded by user id
GroupReceivedCache sharecache.Cache // holds the list of shares a group has access to, sharded by group id
@@ -155,23 +160,25 @@ type Manager struct {
storage metadata.Storage
SpaceRoot *provider.ResourceId
initialized bool
ready chan struct{} // closed once initialize() has completed successfully
migrationsDone chan struct{} // closed once doMigrations() has returned on this instance
MaxConcurrency int
gatewaySelector pool.Selectable[gatewayv1beta1.GatewayAPIClient]
eventStream events.Stream
logger *zerolog.Logger
}
// NewDefault returns a new manager instance with default dependencies
func NewDefault(m map[string]interface{}) (share.Manager, error) {
func NewDefault(m map[string]interface{}, logger *zerolog.Logger) (share.Manager, error) {
c := &config{}
if err := mapstructure.Decode(m, c); err != nil {
err = errors.Wrap(err, "error creating a new manager")
return nil, err
}
s, err := metadata.NewCS3Storage(c.ProviderAddr, c.ProviderAddr, c.ServiceUserID, c.ServiceUserIdp, c.MachineAuthAPIKey)
s, err := metadata.NewCS3Storage(c.ProviderAddr, c.ProviderAddr, c.SystemUserID, c.SystemUserIdp, c.MachineAuthAPIKey)
if err != nil {
return nil, err
}
@@ -189,11 +196,34 @@ func NewDefault(m map[string]interface{}) (share.Manager, error) {
}
}
return New(s, gatewaySelector, c.CacheTTL, es, c.MaxConcurrency)
mgr, err := New(s, logger, gatewaySelector, c.CacheTTL, es, c.MaxConcurrency)
if err != nil {
return nil, err
}
providerRegistryAddr := c.ProviderRegistryAddr
if providerRegistryAddr == "" {
providerRegistryAddr = c.GatewayAddr
}
mgr.RunMigrations(migration.MigrationConfig{
ServiceAccountID: c.ServiceAccountID,
ServiceAccountSecret: c.ServiceAccountSecret,
ProviderRegistryAddr: providerRegistryAddr,
})
return mgr, nil
}
// New returns a new manager instance.
func New(s metadata.Storage, gatewaySelector pool.Selectable[gatewayv1beta1.GatewayAPIClient], ttlSeconds int, es events.Stream, maxconcurrency int) (*Manager, error) {
func New(s metadata.Storage,
logger *zerolog.Logger,
gatewaySelector pool.Selectable[gatewayv1beta1.GatewayAPIClient],
ttlSeconds int,
es events.Stream,
maxconcurrency int,
) (*Manager, error) {
if logger == nil {
nop := zerolog.Nop()
logger = &nop
}
ttl := time.Duration(ttlSeconds) * time.Second
m := &Manager{
@@ -205,13 +235,38 @@ func New(s metadata.Storage, gatewaySelector pool.Selectable[gatewayv1beta1.Gate
gatewaySelector: gatewaySelector,
eventStream: es,
MaxConcurrency: maxconcurrency,
logger: logger,
ready: make(chan struct{}),
// migrationsDone is open (blocking) by default. It is closed by
// doMigrations when all migrations complete, or by SkipMigrations for
// callers (e.g. tests) that do not run migrations at all.
migrationsDone: make(chan struct{}),
}
// Initialize the metadata storage connection in the background, retrying
// with exponential backoff if the backend is not yet available.
go func() {
backoff := time.Second
for {
if err := m.initialize(context.Background()); err != nil {
logger.Info().Err(err).Dur("backoff", backoff).Msg("share manager: metadata storage initialization failed, retrying")
time.Sleep(backoff)
if backoff < 30*time.Second {
backoff *= 2
}
continue
}
logger.Debug().Msg("share manager: initialization succeeded")
close(m.ready)
return
}
}()
// listen for events
if m.eventStream != nil {
ch, err := events.Consume(m.eventStream, "jsoncs3sharemanager", _registeredEvents...)
if err != nil {
appctx.GetLogger(context.Background()).Error().Err(err).Msg("error consuming events")
logger.Error().Err(err).Msg("error consuming events")
}
go m.ProcessEvents(ch)
}
@@ -219,23 +274,13 @@ func New(s metadata.Storage, gatewaySelector pool.Selectable[gatewayv1beta1.Gate
return m, nil
}
// initialize connects to the metadata storage backend and ensures the required
// directory structure exists. It is called once at startup from a background
// goroutine (see New) and must not be called concurrently.
func (m *Manager) initialize(ctx context.Context) error {
_, span := appctx.GetTracerProvider(ctx).Tracer(tracerName).Start(ctx, "initialize")
defer span.End()
if m.initialized {
span.SetStatus(codes.Ok, "already initialized")
return nil
}
m.Lock()
defer m.Unlock()
if m.initialized { // check if initialization happened while grabbing the lock
span.SetStatus(codes.Ok, "initialized while grabbing lock")
return nil
}
ctx = context.Background()
err := m.storage.Init(ctx, "jsoncs3-share-manager-metadata")
if err != nil {
span.RecordError(err)
@@ -261,21 +306,85 @@ func (m *Manager) initialize(ctx context.Context) error {
span.SetStatus(codes.Error, err.Error())
return err
}
err = m.storage.MakeDirIfNotExist(ctx, "migrations")
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
return err
}
m.initialized = true
span.SetStatus(codes.Ok, "initialized")
return nil
}
// waitForInit parks the caller until m.ready has been closed, which the
// background goroutine started in New does once initialize() returned without
// error. When ctx ends first, its error is returned wrapped with a message
// stating that the manager is not initialized yet.
func (m *Manager) waitForInit(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return errors.Wrap(ctx.Err(), "share manager not yet initialized")
	case <-m.ready:
	}
	return nil
}
// waitForMigrations is the gate for write operations. It returns nil only
// after two things happened on this instance, in order:
//
//  1. the metadata storage was initialized (m.ready closed), and
//  2. migrations finished or were skipped (m.migrationsDone closed).
//
// If ctx ends while waiting for either stage, the context error is returned
// wrapped with a message naming the stage that was still pending.
func (m *Manager) waitForMigrations(ctx context.Context) error {
	// Stage one is exactly the wait (and error) waitForInit provides.
	if err := m.waitForInit(ctx); err != nil {
		return err
	}

	// Stage two: block until the migration goroutine or SkipMigrations has
	// closed migrationsDone.
	select {
	case <-ctx.Done():
		return errors.Wrap(ctx.Err(), "share manager migrations not yet complete")
	case <-m.migrationsDone:
	}
	return nil
}
// RunMigrations kicks off the data migrations asynchronously and returns
// immediately. Production startup code is expected to call it exactly once
// after New(). Callers that do not run migrations must call SkipMigrations
// instead, otherwise write operations stay blocked in waitForMigrations.
func (m *Manager) RunMigrations(cfg migration.MigrationConfig) {
	go func() {
		m.doMigrations(cfg)
	}()
}
// SkipMigrations releases the write-operation gate of this instance without
// running a single migration. Use it wherever RunMigrations is not called,
// e.g. in tests.
//
// It closes migrationsDone, so it must be called at most once and never in
// combination with RunMigrations (whose goroutine closes the same channel):
// closing an already closed channel panics.
func (m *Manager) SkipMigrations() {
	done := m.migrationsDone
	close(done)
}
// doMigrations is the body of the migration goroutine started by
// RunMigrations. It waits for the storage to be initialized, then builds a
// migration runner and executes it.
func (m *Manager) doMigrations(cfg migration.MigrationConfig) {
	// migrationsDone is closed on every exit path (migrations ran, were
	// skipped, or failed) so that write operations on this instance are
	// unblocked. According to the original author's note, instances that do
	// not win the migration lock are held inside the runner until the winner
	// is done, so the close only happens once the storage is fully migrated.
	defer close(m.migrationsDone)

	err := m.waitForInit(context.Background())
	if err != nil {
		m.logger.Error().Err(err).Msg("share manager: aborting migrations, manager did not initialize")
		return
	}

	m.logger.Debug().Msg("migrations start")
	// The manager is passed twice: once as share.Manager and once as the
	// loader that imports the generated shares.
	runner := migration.New(*m.logger, m.gatewaySelector, m.storage, cfg, m, m)
	runner.RunMigrations()
}
func (m *Manager) ProcessEvents(ch <-chan events.Event) {
log := logger.New()
log := m.logger
ctx := context.Background()
if err := m.waitForInit(ctx); err != nil {
log.Error().Err(err).Msg("share manager: error waiting for initialization")
return
}
for event := range ch {
ctx := context.Background()
if err := m.initialize(ctx); err != nil {
log.Error().Err(err).Msg("error initializing manager")
}
if ev, ok := event.Event.(events.SpaceDeleted); ok {
log.Debug().Msgf("space deleted event: %v", ev)
go func() { m.purgeSpace(ctx, ev.ID) }()
@@ -287,7 +396,7 @@ func (m *Manager) ProcessEvents(ch <-chan events.Event) {
func (m *Manager) Share(ctx context.Context, md *provider.ResourceInfo, g *collaboration.ShareGrant) (*collaboration.Share, error) {
ctx, span := appctx.GetTracerProvider(ctx).Tracer(tracerName).Start(ctx, "Share")
defer span.End()
if err := m.initialize(ctx); err != nil {
if err := m.waitForMigrations(ctx); err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
return nil, err
@@ -436,7 +545,7 @@ func (m *Manager) GetShare(ctx context.Context, ref *collaboration.ShareReferenc
ctx, span := appctx.GetTracerProvider(ctx).Tracer(tracerName).Start(ctx, "GetShare")
defer span.End()
sublog := appctx.GetLogger(ctx).With().Str("id", ref.GetId().GetOpaqueId()).Str("key", ref.GetKey().String()).Str("driver", "jsoncs3").Str("handler", "GetShare").Logger()
if err := m.initialize(ctx); err != nil {
if err := m.waitForInit(ctx); err != nil {
return nil, err
}
@@ -494,7 +603,7 @@ func (m *Manager) Unshare(ctx context.Context, ref *collaboration.ShareReference
ctx, span := appctx.GetTracerProvider(ctx).Tracer(tracerName).Start(ctx, "Unshare")
defer span.End()
if err := m.initialize(ctx); err != nil {
if err := m.waitForMigrations(ctx); err != nil {
return err
}
@@ -511,7 +620,7 @@ func (m *Manager) UpdateShare(ctx context.Context, ref *collaboration.ShareRefer
ctx, span := appctx.GetTracerProvider(ctx).Tracer(tracerName).Start(ctx, "UpdateShare")
defer span.End()
if err := m.initialize(ctx); err != nil {
if err := m.waitForMigrations(ctx); err != nil {
return nil, err
}
@@ -599,7 +708,7 @@ func (m *Manager) ListShares(ctx context.Context, filters []*collaboration.Filte
ctx, span := appctx.GetTracerProvider(ctx).Tracer(tracerName).Start(ctx, "ListShares")
defer span.End()
if err := m.initialize(ctx); err != nil {
if err := m.waitForInit(ctx); err != nil {
return nil, err
}
@@ -816,7 +925,7 @@ func (m *Manager) ListReceivedShares(ctx context.Context, filters []*collaborati
defer span.End()
sublog := appctx.GetLogger(ctx).With().Str("driver", "jsoncs3").Str("handler", "ListReceivedShares").Logger()
if err := m.initialize(ctx); err != nil {
if err := m.waitForInit(ctx); err != nil {
return nil, err
}
@@ -1012,7 +1121,7 @@ func (m *Manager) convert(ctx context.Context, userID string, s *collaboration.S
// GetReceivedShare returns the information for a received share.
func (m *Manager) GetReceivedShare(ctx context.Context, ref *collaboration.ShareReference) (*collaboration.ReceivedShare, error) {
if err := m.initialize(ctx); err != nil {
if err := m.waitForInit(ctx); err != nil {
return nil, err
}
@@ -1056,7 +1165,7 @@ func (m *Manager) UpdateReceivedShare(ctx context.Context, receivedShare *collab
ctx, span := appctx.GetTracerProvider(ctx).Tracer(tracerName).Start(ctx, "UpdateReceivedShare")
defer span.End()
if err := m.initialize(ctx); err != nil {
if err := m.waitForMigrations(ctx); err != nil {
return nil, err
}
@@ -1103,8 +1212,8 @@ func updateShareID(share *collaboration.Share) {
// Load imports shares and received shares from channels (e.g. during migration)
func (m *Manager) Load(ctx context.Context, shareChan <-chan *collaboration.Share, receivedShareChan <-chan share.ReceivedShareWithUser) error {
log := appctx.GetLogger(ctx)
if err := m.initialize(ctx); err != nil {
l := m.logger
if err := m.waitForInit(ctx); err != nil {
return err
}
@@ -1119,14 +1228,14 @@ func (m *Manager) Load(ctx context.Context, shareChan <-chan *collaboration.Shar
updateShareID(s)
}
if err := m.Cache.Add(context.Background(), s.GetResourceId().GetStorageId(), s.GetResourceId().GetSpaceId(), s.Id.OpaqueId, s); err != nil {
log.Error().Err(err).Interface("share", s).Msg("error persisting share")
l.Error().Err(err).Interface("share", s).Msg("error persisting share")
} else {
log.Debug().Str("storageid", s.GetResourceId().GetStorageId()).Str("spaceid", s.GetResourceId().GetSpaceId()).Str("shareid", s.Id.OpaqueId).Msg("imported share")
l.Debug().Str("storageid", s.GetResourceId().GetStorageId()).Str("spaceid", s.GetResourceId().GetSpaceId()).Str("shareid", s.Id.OpaqueId).Msg("imported share")
}
if err := m.CreatedCache.Add(ctx, s.GetCreator().GetOpaqueId(), s.Id.OpaqueId); err != nil {
log.Error().Err(err).Interface("share", s).Msg("error persisting created cache")
l.Error().Err(err).Interface("share", s).Msg("error persisting created cache")
} else {
log.Debug().Str("creatorid", s.GetCreator().GetOpaqueId()).Str("shareid", s.Id.OpaqueId).Msg("updated created cache")
l.Debug().Str("creatorid", s.GetCreator().GetOpaqueId()).Str("shareid", s.Id.OpaqueId).Msg("updated created cache")
}
}
wg.Done()
@@ -1137,18 +1246,19 @@ func (m *Manager) Load(ctx context.Context, shareChan <-chan *collaboration.Shar
if !shareIsRoutable(s.ReceivedShare.GetShare()) {
updateShareID(s.ReceivedShare.GetShare())
}
switch s.ReceivedShare.Share.Grantee.Type {
case provider.GranteeType_GRANTEE_TYPE_USER:
if err := m.UserReceivedStates.Add(context.Background(), s.ReceivedShare.GetShare().GetGrantee().GetUserId().GetOpaqueId(), s.ReceivedShare.GetShare().GetResourceId().GetSpaceId(), s.ReceivedShare); err != nil {
log.Error().Err(err).Interface("received share", s).Msg("error persisting received share for user")
if s.UserID != nil {
spaceid := s.ReceivedShare.GetShare().GetResourceId().GetStorageId() + shareid.IDDelimiter + s.ReceivedShare.GetShare().GetResourceId().GetSpaceId()
if err := m.UserReceivedStates.Add(context.Background(), s.UserID.GetOpaqueId(), spaceid, s.ReceivedShare); err != nil {
l.Error().Err(err).Interface("received share", s).Msg("error persisting received share for user")
} else {
log.Debug().Str("userid", s.ReceivedShare.GetShare().GetGrantee().GetUserId().GetOpaqueId()).Str("spaceid", s.ReceivedShare.GetShare().GetResourceId().GetSpaceId()).Str("shareid", s.ReceivedShare.GetShare().Id.OpaqueId).Msg("updated received share userdata")
l.Debug().Str("userid", s.UserID.GetOpaqueId()).Str("spaceid", spaceid).Str("shareid", s.ReceivedShare.GetShare().Id.OpaqueId).Msg("updated received share userdata")
}
case provider.GranteeType_GRANTEE_TYPE_GROUP:
}
if s.ReceivedShare.Share.Grantee.Type == provider.GranteeType_GRANTEE_TYPE_GROUP && s.UserID == nil {
if err := m.GroupReceivedCache.Add(context.Background(), s.ReceivedShare.GetShare().GetGrantee().GetGroupId().GetOpaqueId(), s.ReceivedShare.GetShare().GetId().GetOpaqueId()); err != nil {
log.Error().Err(err).Interface("received share", s).Msg("error persisting received share to group cache")
l.Error().Err(err).Interface("received share", s).Msg("error persisting received share to group cache")
} else {
log.Debug().Str("groupid", s.ReceivedShare.GetShare().GetGrantee().GetGroupId().GetOpaqueId()).Str("shareid", s.ReceivedShare.GetShare().Id.OpaqueId).Msg("updated received share group cache")
l.Debug().Str("groupid", s.ReceivedShare.GetShare().GetGrantee().GetGroupId().GetOpaqueId()).Str("shareid", s.ReceivedShare.GetShare().Id.OpaqueId).Msg("updated received share group cache")
}
}
}
@@ -1220,7 +1330,7 @@ func (m *Manager) removeShare(ctx context.Context, s *collaboration.Share, skipS
func (m *Manager) CleanupStaleShares(ctx context.Context) {
log := appctx.GetLogger(ctx)
if err := m.initialize(ctx); err != nil {
if err := m.waitForMigrations(ctx); err != nil {
return
}
@@ -0,0 +1,435 @@
// Copyright 2026 OpenCloud GmbH
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// In applying this license, CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.
package migration
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/cenkalti/backoff"
grouppb "github.com/cs3org/go-cs3apis/cs3/identity/group/v1beta1"
userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
rpc "github.com/cs3org/go-cs3apis/cs3/rpc/v1beta1"
collaboration "github.com/cs3org/go-cs3apis/cs3/sharing/collaboration/v1beta1"
provider "github.com/cs3org/go-cs3apis/cs3/storage/provider/v1beta1"
registry "github.com/cs3org/go-cs3apis/cs3/storage/registry/v1beta1"
typesv1beta1 "github.com/cs3org/go-cs3apis/cs3/types/v1beta1"
"github.com/google/uuid"
ctxpkg "github.com/opencloud-eu/reva/v2/pkg/ctx"
"github.com/opencloud-eu/reva/v2/pkg/errtypes"
"github.com/opencloud-eu/reva/v2/pkg/rgrpc/todo/pool"
"github.com/opencloud-eu/reva/v2/pkg/share"
"github.com/opencloud-eu/reva/v2/pkg/share/manager/jsoncs3/shareid"
"github.com/opencloud-eu/reva/v2/pkg/utils"
"github.com/rs/zerolog"
"google.golang.org/grpc"
)
// storageProvider lists the only provider.ProviderAPIClient method this
// migration calls. Depending on this one-method interface instead of the full
// client keeps test doubles small.
type storageProvider interface {
	// ListGrants returns the grants set on the resource referenced by req.
	ListGrants(ctx context.Context, req *provider.ListGrantsRequest, opts ...grpc.CallOption) (*provider.ListGrantsResponse, error)
}
// ImportSpaceMembersMigration is the "import_space_members" migration. It
// turns the grants found on the root of every project space into shares and
// received shares and hands them to the configured loader.
type ImportSpaceMembersMigration struct {
// cfg is the shared migration configuration received in Initialize.
cfg config
// sharesChan carries the shares produced by migrateSpace to the loader.
sharesChan chan *collaboration.Share
// receivedChan carries the received shares produced by migrateSpace to the
// loader.
receivedChan chan share.ReceivedShareWithUser
// userCache memoizes resolveUserID results, keyed by the opaque user id.
userCache map[string]*userpb.UserId
// groupCache memoizes resolveGroupID results, keyed by the opaque group id.
groupCache map[string]*grouppb.GroupId
// providerResolver returns the storage provider client responsible for a
// space. Initialize points it at storageProviderForSpace; keeping it a field
// allows substituting a stub.
providerResolver func(context.Context, *provider.StorageSpace) (storageProvider, error)
}
func init() {
registerMigration(&ImportSpaceMembersMigration{})
}
// Initialize stores the migration config and creates the per-run state: the
// two unbuffered channels feeding the loader, empty identity caches, and the
// default provider resolver backed by storageProviderForSpace.
func (m *ImportSpaceMembersMigration) Initialize(cfg config) {
	m.cfg = cfg

	m.userCache = map[string]*userpb.UserId{}
	m.groupCache = map[string]*grouppb.GroupId{}

	m.sharesChan = make(chan *collaboration.Share)
	m.receivedChan = make(chan share.ReceivedShareWithUser)

	m.providerResolver = func(ctx context.Context, space *provider.StorageSpace) (storageProvider, error) {
		client, err := m.storageProviderForSpace(ctx, space)
		return client, err
	}
}
// Name returns the identifier of this migration.
func (m *ImportSpaceMembersMigration) Name() string {
	const name = "import_space_members"
	return name
}
// Version returns the version number of this migration; the registered
// migrations are sorted by this value.
func (m *ImportSpaceMembersMigration) Version() int {
	const version = 1
	return version
}
// Migrate imports the members of all project spaces as shares.
//
// It lists every storage space of type "project" through the gateway using a
// service-account context, converts the grants on each space root into shares
// and received shares (see migrateSpace) and streams them over sharesChan and
// receivedChan into the configured loader, which runs concurrently.
//
// A space that fails to migrate is logged and skipped; it does not fail the
// migration. The returned error is either a setup error (gateway selection,
// service-account authentication, listing spaces) or the error returned by
// the loader.
func (m *ImportSpaceMembersMigration) Migrate() error {
	gwc, err := m.cfg.gatewaySelector.Next()
	if err != nil {
		return err
	}
	svcCtx, err := utils.GetServiceUserContextWithContext(context.Background(), gwc, m.cfg.serviceAccountID, m.cfg.serviceAccountSecret)
	if err != nil {
		m.cfg.logger.Error().Err(err).Msg("failed to get service user context for migration")
		return err
	}

	// List all project spaces.
	listRes, err := gwc.ListStorageSpaces(svcCtx, &provider.ListStorageSpacesRequest{
		Opaque: utils.AppendPlainToOpaque(nil, "unrestricted", "true"),
		Filters: []*provider.ListStorageSpacesRequest_Filter{
			{
				Type: provider.ListStorageSpacesRequest_Filter_TYPE_SPACE_TYPE,
				Term: &provider.ListStorageSpacesRequest_Filter_SpaceType{SpaceType: "project"},
			},
		},
	})
	if err != nil {
		m.cfg.logger.Error().Err(err).Msg("space-membership migration: failed to list storage spaces")
		return err
	}
	if listRes.GetStatus().GetCode() != rpc.Code_CODE_OK {
		m.cfg.logger.Error().Str("status", listRes.GetStatus().GetMessage()).Msg("space-membership migration: ListStorageSpaces returned non-OK status")
		return errtypes.InternalError("ListStorageSpaces")
	}

	spaces := listRes.GetStorageSpaces()
	m.cfg.logger.Info().Int("spaces", len(spaces)).Msg("Starting migration")

	// loadCtx couples producer and loader: it is cancelled when this function
	// returns, so a loader blocked on the channels is released, and when the
	// loader goroutine exits, so a producer blocked on a send is released.
	loadCtx, cancelLoad := context.WithCancel(svcCtx)
	defer cancelLoad()

	var wg sync.WaitGroup
	var loaderError error
	wg.Go(func() {
		// If the loader returns before the channels are closed, nobody
		// receives from the unbuffered channels anymore. Cancelling loadCtx
		// makes the pending and future sends in migrateSpace give up instead
		// of blocking forever.
		defer cancelLoad()
		loaderError = m.cfg.loader.Load(loadCtx, m.sharesChan, m.receivedChan)
	})

	migrated := 0
	for _, space := range spaces {
		// At this point loadCtx can only have been cancelled by the loader
		// goroutine exiting; processing further spaces would be pointless.
		if loadCtx.Err() != nil {
			m.cfg.logger.Error().Msg("share loader stopped early; skipping remaining spaces")
			break
		}
		sharesCreated, err := m.migrateSpace(loadCtx, space)
		if err != nil {
			m.cfg.logger.Error().Err(err).Str("space", space.GetId().GetOpaqueId()).Msg("failed to migrate space; continuing with remaining spaces")
			continue
		}
		migrated++
		m.cfg.logger.Debug().
			Str("space", space.GetId().GetOpaqueId()).
			Int("shares_created", sharesCreated).
			Msg("space migrated")
		if migrated%10 == 0 {
			m.cfg.logger.Info().
				Int("migrated", migrated).
				Int("total", len(spaces)).
				Msg("migration progress")
		}
	}

	// Closing the channels tells the loader that no more items will arrive;
	// wait for it so loaderError is safe to read.
	close(m.receivedChan)
	close(m.sharesChan)
	wg.Wait()

	m.cfg.logger.Info().Err(loaderError).Int("migrated", migrated).Int("total", len(spaces)).Msg("Migration finished")
	return loaderError
}
// migrateSpace converts the grants on the root of space into shares and sends
// them, together with their received shares, to the loader channels.
//
// It returns the number of shares handed to the loader. Grants that cannot be
// converted are logged and skipped; grants for which a share already exists
// are skipped silently. An error is returned when the storage provider cannot
// be resolved, listing the grants fails, or ctx is cancelled while sending.
func (m *ImportSpaceMembersMigration) migrateSpace(ctx context.Context, space *provider.StorageSpace) (int, error) {
	client, err := m.providerResolver(ctx, space)
	if err != nil {
		return 0, err
	}

	req := &provider.ListGrantsRequest{
		Ref: &provider.Reference{ResourceId: space.GetRoot()},
	}
	res, err := client.ListGrants(ctx, req)
	if err != nil {
		return 0, err
	}
	if status := res.GetStatus(); status.GetCode() != rpc.Code_CODE_OK {
		return 0, errtypes.NewErrtypeFromStatus(status)
	}

	imported := 0
	for _, grant := range res.GetGrants() {
		newShare, received, convErr := m.spaceGrantToShares(ctx, grant, space)
		switch {
		case convErr != nil:
			m.cfg.logger.Error().Err(convErr).
				Interface("grant", grant).
				Msg("Failed to convert grant to shares")
			continue
		case newShare == nil:
			// A share for this grant already exists; nothing to import.
			continue
		}

		// The share goes first, followed by all of its received shares.
		select {
		case <-ctx.Done():
			return imported, ctx.Err()
		case m.sharesChan <- newShare:
		}
		for i := range received {
			select {
			case <-ctx.Done():
				return imported, ctx.Err()
			case m.receivedChan <- received[i]:
			}
		}
		imported++
	}
	return imported, nil
}
// resolveRetries is the retry cap passed to backoff.WithMaxRetries in
// retryOnUnavailable. It bounds how often resolveUserID / resolveGroupID retry
// a lookup that failed with errtypes.Unavailable (e.g. because LDAP is down).
const resolveRetries = 10
// retryOnUnavailable runs op and repeats it with exponential backoff for as
// long as op fails with an errtypes.Unavailable error. A nil result ends the
// loop successfully; every other error is treated as permanent and returned
// right away. The number of retries is limited by resolveRetries and the
// backoff is bound to ctx, so cancelling ctx stops the retries. Each retry is
// announced with a warning on log.
func retryOnUnavailable(ctx context.Context, log zerolog.Logger, op func() error) error {
	attempt := func() error {
		err := op()
		switch err.(type) {
		case nil:
			return nil
		case errtypes.Unavailable:
			// Transient: hand the error back unchanged so backoff retries.
			return err
		default:
			// Anything else will not get better by retrying.
			return backoff.Permanent(err)
		}
	}
	onRetry := func(err error, wait time.Duration) {
		log.Warn().Err(err).Dur("retry_in", wait).Msg("identity provider temporarily unavailable, retrying")
	}

	policy := backoff.WithMaxRetries(backoff.NewExponentialBackOff(), resolveRetries)
	return backoff.RetryNotify(attempt, backoff.WithContext(policy, ctx), onRetry)
}
// resolveUserID looks up the full user id for opaqueID through the gateway.
// Results are memoized in userCache, so every opaque id is resolved at most
// once per migration run. Lookups answered with an unavailable status are
// retried by retryOnUnavailable; any other failure is returned to the caller.
func (m *ImportSpaceMembersMigration) resolveUserID(ctx context.Context, opaqueID string) (*userpb.UserId, error) {
	if cached, hit := m.userCache[opaqueID]; hit {
		return cached, nil
	}

	var resolved *userpb.UserId
	lookup := func() error {
		client, err := m.cfg.gatewaySelector.Next()
		if err != nil {
			return err
		}
		req := &userpb.GetUserRequest{
			UserId:                 &userpb.UserId{OpaqueId: opaqueID},
			SkipFetchingUserGroups: true,
		}
		res, err := client.GetUser(ctx, req)
		if err != nil {
			return err
		}
		if status := res.GetStatus(); status.GetCode() != rpc.Code_CODE_OK {
			// NewErrtypeFromStatus turns CODE_UNAVAILABLE into
			// errtypes.Unavailable, the one error retryOnUnavailable retries;
			// every other code ends the lookup.
			return errtypes.NewErrtypeFromStatus(status)
		}
		resolved = res.GetUser().GetId()
		return nil
	}
	if err := retryOnUnavailable(ctx, m.cfg.logger, lookup); err != nil {
		return nil, err
	}

	m.userCache[opaqueID] = resolved
	return resolved, nil
}
// resolveGroupID looks up the full group id for opaqueID through the gateway.
// Results are memoized in groupCache, so every opaque id is resolved at most
// once per migration run. Lookups answered with an unavailable status are
// retried by retryOnUnavailable; any other failure is returned to the caller.
func (m *ImportSpaceMembersMigration) resolveGroupID(ctx context.Context, opaqueID string) (*grouppb.GroupId, error) {
	if cached, hit := m.groupCache[opaqueID]; hit {
		return cached, nil
	}

	var resolved *grouppb.GroupId
	lookup := func() error {
		client, err := m.cfg.gatewaySelector.Next()
		if err != nil {
			return err
		}
		req := &grouppb.GetGroupRequest{
			GroupId:             &grouppb.GroupId{OpaqueId: opaqueID},
			SkipFetchingMembers: true,
		}
		res, err := client.GetGroup(ctx, req)
		if err != nil {
			return err
		}
		if status := res.GetStatus(); status.GetCode() != rpc.Code_CODE_OK {
			return errtypes.NewErrtypeFromStatus(status)
		}
		resolved = res.GetGroup().GetId()
		return nil
	}
	if err := retryOnUnavailable(ctx, m.cfg.logger, lookup); err != nil {
		return nil, err
	}

	m.groupCache[opaqueID] = resolved
	return resolved, nil
}
// spaceGrantToShares converts a single grant on the root of space into a share
// plus the received shares that belong to it.
//
// Return values:
//   - (share, received, nil) when a new share has to be imported,
//   - (nil, nil, nil) when the share manager already knows a share for this
//     resource/grantee combination,
//   - (nil, nil, err) when the grantee cannot be resolved or a group cannot be
//     expanded.
//
// For user grantees one received share is produced. For group grantees one
// received share per group member is produced plus one entry with a nil
// UserID. All received shares are created in the ACCEPTED state. The grantee
// of the passed grant is updated in place with the resolved id.
func (m *ImportSpaceMembersMigration) spaceGrantToShares(ctx context.Context, grant *provider.Grant, space *provider.StorageSpace) (*collaboration.Share, []share.ReceivedShareWithUser, error) {
	// The grantee ids as persisted on disk do not have an IDP or type stored as
	// part of the userid/groupid. Resolve them via the gateway so we get the
	// full id.
	switch grant.GetGrantee().GetType() {
	case provider.GranteeType_GRANTEE_TYPE_GROUP:
		groupID, err := m.resolveGroupID(ctx, grant.GetGrantee().GetGroupId().GetOpaqueId())
		if err != nil {
			return nil, nil, fmt.Errorf("resolve group %s: %w", grant.GetGrantee().GetGroupId().GetOpaqueId(), err)
		}
		grant.Grantee.Id = &provider.Grantee_GroupId{GroupId: groupID}
	case provider.GranteeType_GRANTEE_TYPE_USER:
		userID, err := m.resolveUserID(ctx, grant.GetGrantee().GetUserId().GetOpaqueId())
		if err != nil {
			return nil, nil, fmt.Errorf("resolve user %s: %w", grant.GetGrantee().GetUserId().GetOpaqueId(), err)
		}
		grant.Grantee.Id = &provider.Grantee_UserId{UserId: userID}
	}

	// Look the share up by resource and grantee, acting as the grant creator.
	ref := &collaboration.ShareReference{
		Spec: &collaboration.ShareReference_Key{
			Key: &collaboration.ShareKey{
				ResourceId: space.GetRoot(),
				Grantee:    grant.GetGrantee(),
			},
		},
	}
	ctx = ctxpkg.ContextSetUser(ctx, &userpb.User{Id: grant.GetCreator()})
	if s, err := m.cfg.manager.GetShare(ctx, ref); err == nil {
		// FIXME: Verify the actual grants?
		m.cfg.logger.Debug().Interface("share", s).Msg("share already exists")
		return nil, nil, nil
	}

	ts := utils.TSNow()
	shareID := shareid.Encode(space.GetRoot().GetStorageId(), space.GetRoot().GetSpaceId(), uuid.NewString())

	// A grant may carry no creator at all. GetType is nil-safe and reports
	// USER_TYPE_INVALID for a nil id, whereas reading the Type field directly
	// would panic. Creators without a valid type are not recorded.
	creator := grant.GetCreator()
	if creator.GetType() == userpb.UserType_USER_TYPE_INVALID {
		creator = nil
	}

	newShare := &collaboration.Share{
		Id:          &collaboration.ShareId{OpaqueId: shareID},
		ResourceId:  space.GetRoot(),
		Permissions: &collaboration.SharePermissions{Permissions: grant.GetPermissions()},
		Grantee:     grant.GetGrantee(),
		Expiration:  grant.GetExpiration(),
		Owner:       creator,
		Creator:     creator,
		Ctime:       ts,
		Mtime:       ts,
	}

	// accepted wraps newShare into a received share for the given user; a nil
	// user marks the group-level entry.
	accepted := func(user *userpb.UserId) share.ReceivedShareWithUser {
		return share.ReceivedShareWithUser{
			UserID: user,
			ReceivedShare: &collaboration.ReceivedShare{
				Share: newShare,
				State: collaboration.ShareState_SHARE_STATE_ACCEPTED,
			},
		}
	}

	var newReceivedShares []share.ReceivedShareWithUser
	switch grant.GetGrantee().GetType() {
	case provider.GranteeType_GRANTEE_TYPE_GROUP:
		gwc, err := m.cfg.gatewaySelector.Next()
		if err != nil {
			m.cfg.logger.Error().Err(err).Msg("Failed to get gateway client")
			return nil, nil, err
		}
		gr, err := gwc.GetMembers(ctx, &grouppb.GetMembersRequest{
			GroupId: grant.GetGrantee().GetGroupId(),
		})
		if err != nil {
			m.cfg.logger.Error().Err(err).Msg("Failed to expand group membership")
			return nil, nil, err
		}
		if gr.GetStatus().GetCode() != rpc.Code_CODE_OK {
			m.cfg.logger.Error().Str("Status", gr.GetStatus().GetMessage()).Msg("Failed to expand group membership")
			return nil, nil, errtypes.NewErrtypeFromStatus(gr.GetStatus())
		}
		for _, u := range gr.GetMembers() {
			newReceivedShares = append(newReceivedShares, accepted(u))
		}
		// Also add a group-level entry (UserID == nil) so the group cache is populated.
		newReceivedShares = append(newReceivedShares, accepted(nil))
	case provider.GranteeType_GRANTEE_TYPE_USER:
		newReceivedShares = append(newReceivedShares, accepted(grant.GetGrantee().GetUserId()))
	}

	return newShare, newReceivedShares, nil
}
// storageProviderForSpace resolves the storage provider responsible for the
// given storage space and returns a client for it.
//
// The storage registry at cfg.providerRegistryAddr is asked for the providers
// of the space (the space is passed JSON-encoded in the "space" opaque entry)
// and the first returned provider is used. An error is returned when the
// registry cannot be reached, reports a non-OK status without providers,
// returns no provider at all, or the provider client cannot be created.
func (m *ImportSpaceMembersMigration) storageProviderForSpace(ctx context.Context, space *provider.StorageSpace) (provider.ProviderAPIClient, error) {
	srClient, err := pool.GetStorageRegistryClient(m.cfg.providerRegistryAddr)
	if err != nil {
		return nil, fmt.Errorf("get storage registry client: %w", err)
	}

	spaceJSON, err := json.Marshal(space)
	if err != nil {
		return nil, fmt.Errorf("marshal space: %w", err)
	}

	res, err := srClient.GetStorageProviders(ctx, &registry.GetStorageProvidersRequest{
		Opaque: &typesv1beta1.Opaque{
			Map: map[string]*typesv1beta1.OpaqueEntry{
				"space": {
					Decoder: "json",
					Value:   spaceJSON,
				},
			},
		},
	})
	if err != nil {
		return nil, fmt.Errorf("GetStorageProviders: %w", err)
	}

	providers := res.GetProviders()
	if len(providers) == 0 {
		// Surface the registry's own status when it reported a failure,
		// instead of hiding it behind a generic "not found" message.
		if st := res.GetStatus(); st != nil && st.GetCode() != rpc.Code_CODE_OK {
			return nil, fmt.Errorf("GetStorageProviders for space %s: %w", space.GetId().GetOpaqueId(), errtypes.NewErrtypeFromStatus(st))
		}
		return nil, fmt.Errorf("no storage provider found for space %s", space.GetId().GetOpaqueId())
	}

	c, err := pool.GetStorageProviderServiceClient(providers[0].GetAddress())
	if err != nil {
		return nil, fmt.Errorf("dial storage provider: %w", err)
	}
	return c, nil
}
@@ -0,0 +1,353 @@
// Copyright 2026 OpenCloud GmbH
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// In applying this license, CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.
package migration
import (
"cmp"
"context"
"crypto/rand"
"encoding/json"
"fmt"
"os"
"os/signal"
"slices"
"syscall"
"time"
gatewayv1beta1 "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1"
"github.com/opencloud-eu/reva/v2/pkg/errtypes"
"github.com/opencloud-eu/reva/v2/pkg/rgrpc/todo/pool"
"github.com/opencloud-eu/reva/v2/pkg/share"
"github.com/opencloud-eu/reva/v2/pkg/storage/utils/metadata"
"github.com/rs/zerolog"
)
// stateFile is the path of the JSON document holding the persisted migration
// state (presumably relative to the metadata storage root; confirm against the
// code that reads it).
const stateFile = "migrations/state.json"
const (
// lockFile is the path of the lock file acquireLock creates to claim the
// right to run migrations.
lockFile = "migrations/lock.json"
// lockTTL is the age beyond which acquireLock treats the timestamp of an
// existing lock as stale and attempts to take the lock over.
lockTTL = time.Minute
// lockHeartbeatInterval is presumably how often the lock holder refreshes
// the lock timestamp; the code using it is not part of this excerpt.
lockHeartbeatInterval = 20 * time.Second
)
// lockPollInterval is how long acquireLock sleeps between retries when the
// lock is held by another instance. It is a variable rather than a constant
// so that tests can shorten the wait.
var lockPollInterval = 5 * time.Second
// lockData is the content written to the lock file.
type lockData struct {
// Timestamp is the time the lock content was written; acquireLock compares
// its age against lockTTL to detect stale locks.
Timestamp time.Time `json:"timestamp"`
// InstanceID is the random id of the instance that wrote the lock.
InstanceID string `json:"instance_id"`
}
// migration is the contract every registered data migration implements.
type migration interface {
// Name returns the identifier of the migration.
Name() string
// Version returns the number New uses to sort the registered migrations in
// ascending order.
Version() int
// Initialize hands the shared migration config to the migration. Presumably
// called before Migrate; the call site is not part of this excerpt.
Initialize(config)
// Migrate performs the migration and reports failure through its error.
Migrate() error
}
// persistedState is the on-disk (JSON) representation of the migration state.
type persistedState struct {
// Version is stored under the "version" key; presumably the version of the
// last migration that was applied (confirm against the code writing it).
Version int `json:"version"`
}
// state is the in-memory migration state kept by Migrations.
type state struct {
// version presumably mirrors persistedState.Version; the code that fills it
// is not part of this excerpt.
version int
}
// MigrationConfig holds all caller-supplied options for a migration run.
// It is intentionally a plain struct so that new fields can be added without
// changing function signatures throughout the call chain.
type MigrationConfig struct {
// ServiceAccountID and ServiceAccountSecret are the credentials migrations
// use to obtain a service-user context from the gateway.
ServiceAccountID string
ServiceAccountSecret string
// ProviderRegistryAddr is the address of the storage registry that is asked
// which storage provider is responsible for a space.
ProviderRegistryAddr string
}
// config bundles everything a migration needs at runtime. New fills it from
// its arguments and the caller-supplied MigrationConfig.
type config struct {
// logger is the migration logger; New tags it with jsoncs3=migrations.
logger zerolog.Logger
// gatewaySelector yields gateway clients for calls made by migrations.
gatewaySelector pool.Selectable[gatewayv1beta1.GatewayAPIClient]
// storage is the metadata storage that holds the migration lock file.
storage metadata.Storage
// serviceAccountID and serviceAccountSecret are copied from MigrationConfig.
serviceAccountID string
serviceAccountSecret string
// providerRegistryAddr is copied from MigrationConfig.ProviderRegistryAddr.
providerRegistryAddr string
// manager is used to query existing shares.
manager share.Manager
// loader receives the shares generated by a migration.
loader share.LoadableManager
}
// Migrations is the migration runner returned by New. It embeds the shared
// config and carries the state of one run.
type Migrations struct {
config
// state is the in-memory migration state of this runner.
state state
// instanceID is a random hex string generated in New; it is written into the
// lock file and used in log messages to tell instances apart.
instanceID string
}
// migrations is the package-level list of all registered migrations.
var migrations []migration

// registerMigration adds m to the list of known migrations. It is only
// supposed to be called from init() functions, which run sequentially, so the
// list does not need to be protected with a lock.
func registerMigration(m migration) {
	migrations = append(migrations, m)
}
// New builds a migration runner. It sorts the registered migrations by
// ascending version, tags the logger with jsoncs3=migrations, copies the
// caller-supplied options into the runner's config and assigns the runner a
// random 8-byte instance id, rendered as hex.
func New(logger zerolog.Logger,
	gatewaySelector pool.Selectable[gatewayv1beta1.GatewayAPIClient],
	storage metadata.Storage,
	cfg MigrationConfig,
	manager share.Manager,
	loader share.LoadableManager,
) Migrations {
	slices.SortFunc(migrations, func(a, b migration) int {
		return cmp.Compare(a.Version(), b.Version())
	})

	// The error of rand.Read is deliberately ignored, as in the original.
	idBytes := make([]byte, 8)
	_, _ = rand.Read(idBytes)

	runnerCfg := config{
		logger:               logger.With().Str("jsoncs3", "migrations").Logger(),
		gatewaySelector:      gatewaySelector,
		storage:              storage,
		serviceAccountID:     cfg.ServiceAccountID,
		serviceAccountSecret: cfg.ServiceAccountSecret,
		providerRegistryAddr: cfg.ProviderRegistryAddr,
		manager:              manager,
		loader:               loader,
	}
	return Migrations{
		config:     runnerCfg,
		state:      state{},
		instanceID: fmt.Sprintf("%x", idBytes),
	}
}
// acquireLock tries to atomically create the lock file, blocking until the lock
// is obtained. It returns the etag of the lock file on success.
//
// While the lock is held by another instance it keeps polling every
// lockPollInterval until ctx is cancelled. A lock whose timestamp is older
// than lockTTL (or whose content cannot be parsed) is considered stale and
// will be taken over with a conditional upload on the etag that was read.
//
// Storage errors that are neither a conditional-upload conflict nor a vanished
// lock file are returned to the caller instead of being retried, both on the
// create path and on the stale-takeover path. Retrying those immediately would
// hammer the storage without any delay.
func (m *Migrations) acquireLock(ctx context.Context) (string, error) {
	m.logger.Debug().Str("instance", m.instanceID).Msg("acquiring migration lock")
	for {
		// Fast path: create the lock file only if it does not exist yet.
		data, err := json.Marshal(lockData{Timestamp: time.Now(), InstanceID: m.instanceID})
		if err != nil {
			return "", err
		}
		res, err := m.storage.Upload(ctx, metadata.UploadRequest{
			Path:        lockFile,
			Content:     data,
			IfNoneMatch: []string{"*"},
		})
		if err == nil {
			m.logger.Debug().Str("instance", m.instanceID).Msg("migration lock acquired")
			return res.Etag, nil
		}
		// Propagate context cancellation immediately.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return "", ctxErr
		}
		// Any error other than a conflict means something unexpected happened.
		if !isConflict(err) {
			return "", err
		}

		// Lock file already exists — read it to decide whether it is stale.
		dl, err := m.storage.Download(ctx, metadata.DownloadRequest{Path: lockFile})
		if err != nil {
			if _, ok := err.(errtypes.IsNotFound); ok {
				// Lock was released between our upload attempt and the download;
				// retry acquiring it immediately.
				m.logger.Debug().Str("instance", m.instanceID).Msg("migration lock vanished during read; retrying")
				continue
			}
			return "", err
		}

		// An unparsable lock file is treated as stale so that a corrupted
		// lock cannot block migrations forever.
		var existing lockData
		stale := true
		if err := json.Unmarshal(dl.Content, &existing); err == nil {
			stale = time.Since(existing.Timestamp) > lockTTL
		}

		if stale {
			m.logger.Debug().
				Str("instance", m.instanceID).
				Str("held_by", existing.InstanceID).
				Time("lock_timestamp", existing.Timestamp).
				Msg("migration lock is stale; attempting takeover")
			// Atomically take over the stale lock using the etag we just read.
			newData, err := json.Marshal(lockData{Timestamp: time.Now(), InstanceID: m.instanceID})
			if err != nil {
				return "", err
			}
			res, err := m.storage.Upload(ctx, metadata.UploadRequest{
				Path:        lockFile,
				Content:     newData,
				IfMatchEtag: dl.Etag,
			})
			if err == nil {
				m.logger.Debug().Str("instance", m.instanceID).Msg("migration lock acquired via stale takeover")
				return res.Etag, nil
			}
			if ctxErr := ctx.Err(); ctxErr != nil {
				return "", ctxErr
			}
			// Only a lost race is worth an immediate retry: either the etag no
			// longer matches (conflict) or the lock file was removed in the
			// meantime. Everything else is an unexpected storage failure.
			_, vanished := err.(errtypes.IsNotFound)
			if !vanished && !isConflict(err) {
				return "", err
			}
			// Another instance took the stale lock before us; loop and retry.
			m.logger.Debug().Str("instance", m.instanceID).Err(err).Msg("stale lock takeover lost race; retrying")
			continue
		}

		m.logger.Debug().
			Str("instance", m.instanceID).
			Str("held_by", existing.InstanceID).
			Time("lock_timestamp", existing.Timestamp).
			Dur("poll_interval", lockPollInterval).
			Msg("migration lock held by another instance; waiting")
		// Lock is fresh and held by another instance; wait before retrying.
		select {
		case <-ctx.Done():
			return "", ctx.Err()
		case <-time.After(lockPollInterval):
		}
	}
}
// startHeartbeat launches a background goroutine that rewrites the lock file
// every lockHeartbeatInterval with a fresh timestamp, so that the lock does not
// look stale while a long migration is running. Each renewal is conditional on
// the etag of the previous write, starting with the given etag.
//
// The goroutine ends when the returned cancel function is called, when ctx is
// cancelled, or after the first failed renewal (which is logged as a warning).
func (m *Migrations) startHeartbeat(ctx context.Context, etag string) context.CancelFunc {
	hbCtx, cancel := context.WithCancel(ctx)

	// renew performs a single heartbeat write and reports whether the
	// heartbeat should keep running.
	renew := func() bool {
		payload, err := json.Marshal(lockData{Timestamp: time.Now(), InstanceID: m.instanceID})
		if err != nil {
			m.logger.Warn().Err(err).Msg("failed to marshal heartbeat data for migration lock")
			return false
		}
		uploaded, err := m.storage.Upload(hbCtx, metadata.UploadRequest{
			Path:        lockFile,
			Content:     payload,
			IfMatchEtag: etag,
		})
		if err != nil {
			m.logger.Warn().Err(err).Msg("failed to renew migration lock; another instance may take over")
			return false
		}
		// Remember the new etag for the next conditional write.
		etag = uploaded.Etag
		return true
	}

	go func() {
		tick := time.NewTicker(lockHeartbeatInterval)
		defer tick.Stop()
		for {
			select {
			case <-hbCtx.Done():
				return
			case <-tick.C:
				if !renew() {
					return
				}
			}
		}
	}()

	return cancel
}
// releaseLock removes the lock file without checking who currently holds it.
// A failed delete is only logged as a warning; nothing is returned to the
// caller.
func (m *Migrations) releaseLock(ctx context.Context) {
	err := m.storage.Delete(ctx, lockFile)
	if err == nil {
		return
	}
	m.logger.Warn().Err(err).Msg("failed to release migration lock")
}
// isConflict reports whether err signals a conditional-upload conflict, i.e.
// the lock file already exists or the etag did not match. It matches errors
// implementing errtypes.IsAlreadyExists, errtypes.IsAborted or
// errtypes.IsPreconditionFailed; a nil error is not a conflict.
func isConflict(err error) bool {
	if _, exists := err.(errtypes.IsAlreadyExists); exists {
		return true
	}
	if _, aborted := err.(errtypes.IsAborted); aborted {
		return true
	}
	_, precondition := err.(errtypes.IsPreconditionFailed)
	return precondition
}
// loadState fetches the persisted migration version from storage and stores it
// in m.state. A missing state file (fresh deployment) is not an error: the
// version is set to 0. Any other download failure, or a state file that cannot
// be decoded, is returned and leaves m.state untouched.
func (m *Migrations) loadState(ctx context.Context) error {
	raw, err := m.storage.SimpleDownload(ctx, stateFile)
	if err != nil {
		_, missing := err.(errtypes.IsNotFound)
		if !missing {
			return err
		}
		m.state = state{version: 0}
		return nil
	}

	var stored persistedState
	err = json.Unmarshal(raw, &stored)
	if err != nil {
		return err
	}
	m.state = state{version: stored.Version}
	return nil
}
// saveState persists the current migration version to the state file so that
// migrations which were already applied are not run again on the next server
// start. It returns the marshalling or upload error, if any.
func (m *Migrations) saveState(ctx context.Context) error {
	payload, err := json.Marshal(persistedState{Version: m.state.version})
	if err == nil {
		err = m.storage.SimpleUpload(ctx, stateFile, payload)
	}
	return err
}
// RunMigrations applies all registered migrations whose version is higher than
// the persisted one, in the order of the package-level migrations list.
//
// The run is serialized across instances through the lock file: the lock is
// acquired first, kept alive by a heartbeat while migrations run, and released
// at the end. After each successful migration the new version is persisted.
// The first failure (lock, state or migration) is logged and stops the run;
// nothing is returned to the caller.
//
// An interrupt or SIGTERM cancels the context used for lock acquisition, the
// heartbeat and the state file accesses.
func (m *Migrations) RunMigrations() {
	// Upper bound for deleting the lock file, so that an unresponsive storage
	// cannot block shutdown indefinitely.
	const lockReleaseTimeout = 10 * time.Second

	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	etag, err := m.acquireLock(ctx)
	if err != nil {
		m.logger.Error().Err(err).Msg("failed to acquire migration lock; skipping migrations")
		return
	}

	cancelHB := m.startHeartbeat(ctx, etag)
	defer func() {
		// Stop the heartbeat before deleting the lock file; otherwise a
		// renewal could still be issued against the lock being removed.
		cancelHB()
		// ctx may already be cancelled by a signal. Release the lock on a
		// context that ignores that cancellation so the lock file is not left
		// behind until it expires.
		releaseCtx, cancelRelease := context.WithTimeout(context.WithoutCancel(ctx), lockReleaseTimeout)
		defer cancelRelease()
		m.releaseLock(releaseCtx)
	}()

	if err := m.loadState(ctx); err != nil {
		m.logger.Error().Err(err).Msg("failed to load migration state; skipping migrations")
		return
	}

	m.logger.Info().Int("current state", m.state.version).Msg("checking migrations")
	for _, mig := range migrations {
		if mig.Version() <= m.state.version {
			m.logger.Info().Str("migration", mig.Name()).Int("version", mig.Version()).Msg("skipping migration")
			continue
		}

		m.logger.Info().Str("migration", mig.Name()).Int("version", mig.Version()).Msg("running migration")
		mig.Initialize(m.config)
		if err := mig.Migrate(); err != nil {
			m.logger.Error().Err(err).Str("migration", mig.Name()).Msg("migration failed; stopping")
			return
		}

		// Persist progress after every migration so a later failure does not
		// cause already applied migrations to run again.
		m.state.version = mig.Version()
		if err := m.saveState(ctx); err != nil {
			m.logger.Error().Err(err).Msg("failed to save migration state; stopping")
			return
		}
	}
}
+2 -1
View File
@@ -28,6 +28,7 @@ import (
ctxpkg "github.com/opencloud-eu/reva/v2/pkg/ctx"
"github.com/opencloud-eu/reva/v2/pkg/share"
"github.com/rs/zerolog"
"google.golang.org/genproto/protobuf/field_mask"
userv1beta1 "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
@@ -46,7 +47,7 @@ func init() {
}
// New returns a new manager.
func New(c map[string]interface{}) (share.Manager, error) {
func New(c map[string]any, _ *zerolog.Logger) (share.Manager, error) {
state := map[string]map[*collaboration.ShareId]collaboration.ShareState{}
mp := map[string]map[*collaboration.ShareId]*provider.Reference{}
return &manager{
@@ -18,11 +18,14 @@
package registry
import "github.com/opencloud-eu/reva/v2/pkg/share"
import (
"github.com/opencloud-eu/reva/v2/pkg/share"
"github.com/rs/zerolog"
)
// NewFunc is the function that share managers
// should register at init time.
type NewFunc func(map[string]interface{}) (share.Manager, error)
type NewFunc func(map[string]any, *zerolog.Logger) (share.Manager, error)
// NewFuncs is a map containing all the registered share managers.
var NewFuncs = map[string]NewFunc{}
+1 -1
View File
@@ -72,7 +72,7 @@ func NewNatsKeyValueFromJetStream(c Config, js jetstream.JetStream) (jetstream.K
if err != nil {
kvConfig := jetstream.KeyValueConfig{
Bucket: c.Database,
TTL: 0, // we don't do TTLs for this store
TTL: c.TTL,
}
if c.DisablePersistence {
kvConfig.Storage = jetstream.MemoryStorage
@@ -65,16 +65,16 @@ func (c *IDCache) DeleteByPath(ctx context.Context, path string) error {
} else {
err := c.kv.Purge(ctx, baseKey)
if err != nil && err != nats.ErrKeyNotFound {
appctx.GetLogger(ctx).Error().Err(err).Str("record", path).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not get spaceID and nodeID from cache")
appctx.GetLogger(ctx).Error().Err(err).Str("record", baseKey).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not purge from cache")
}
err = c.kv.Purge(ctx, cacheKey(spaceID, nodeID))
if err != nil && err != nats.ErrKeyNotFound {
appctx.GetLogger(ctx).Error().Err(err).Str("record", path).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not get spaceID and nodeID from cache")
appctx.GetLogger(ctx).Error().Err(err).Str("record", cacheKey(spaceID, nodeID)).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not purge from cache")
}
}
watcher, err := c.kv.Watch(ctx, baseKey+".*")
watcher, err := c.kv.Watch(ctx, baseKey+".>")
if err != nil {
return err
}
@@ -85,7 +85,6 @@ func (c *IDCache) DeleteByPath(ctx context.Context, path string) error {
break
}
key := update.Key()
spaceID, nodeID, ok := c.getByReverseCacheKey(ctx, key)
if !ok {
appctx.GetLogger(ctx).Error().Str("record", key).Msg("could not get spaceID and nodeID from cache")
@@ -94,12 +93,12 @@ func (c *IDCache) DeleteByPath(ctx context.Context, path string) error {
err := c.kv.Purge(ctx, key)
if err != nil && err != nats.ErrKeyNotFound {
appctx.GetLogger(ctx).Error().Err(err).Str("record", key).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not get spaceID and nodeID from cache")
appctx.GetLogger(ctx).Error().Err(err).Str("record", key).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not purge from cache")
}
err = c.kv.Purge(ctx, cacheKey(spaceID, nodeID))
if err != nil && err != nats.ErrKeyNotFound {
appctx.GetLogger(ctx).Error().Err(err).Str("record", key).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not get spaceID and nodeID from cache")
appctx.GetLogger(ctx).Error().Err(err).Str("record", cacheKey(spaceID, nodeID)).Str("spaceID", spaceID).Str("nodeID", nodeID).Msg("could not purge from cache")
}
}
return nil
+11 -1
View File
@@ -24,6 +24,7 @@ import (
"os"
"path/filepath"
"syscall"
"time"
"github.com/rs/zerolog"
tusd "github.com/tus/tusd/v2/pkg/handler"
@@ -58,7 +59,8 @@ func init() {
type posixFS struct {
storage.FS
um usermapper.Mapper
tree *tree.Tree
um usermapper.Mapper
}
// New returns an implementation of the storage.FS interface that talks to
@@ -70,6 +72,7 @@ func NewDefault(m map[string]interface{}, stream events.Stream, log *zerolog.Log
}
o.IDCache.Database += "_v2" // Use a versioned bucket name to avoid conflicts with previous implementations
o.IDCache.TTL = 0 // Disable TTL for the ID cache, as the posix driver relies on it for caching file IDs and we don't want them to expire
kv, err := cache.NewNatsKeyValue(o.IDCache)
if err != nil {
return nil, errors.Wrap(err, "could not create nats key value store")
@@ -80,6 +83,7 @@ func NewDefault(m map[string]interface{}, stream events.Stream, log *zerolog.Log
}
o.IDCache.Database += "_history" // Use a versioned bucket name to avoid conflicts with previous implementations
o.IDCache.TTL = 24 * 60 * time.Minute
historyKv, err := cache.NewNatsKeyValue(o.IDCache)
if err != nil {
return nil, errors.Wrap(err, "could not create nats key value store")
@@ -215,11 +219,17 @@ func New(o *options.Options, stream events.Stream, cache, historyCache *idcache.
mw := middleware.NewFS(dfs, hooks...)
fs.FS = mw
fs.tree = tp
fs.um = um
return fs, nil
}
// WarmupIDCache allows triggering a posix fs scan and id cache warmup manually.
// It forwards root, assimilate and onlyDirty unchanged to the underlying
// tree's WarmupIDCache and returns its error.
// NOTE(review): the exact meaning of assimilate and onlyDirty is defined by
// tree.WarmupIDCache, which is not visible here — confirm before relying on it.
func (fs *posixFS) WarmupIDCache(root string, assimilate, onlyDirty bool) error {
return fs.tree.WarmupIDCache(root, assimilate, onlyDirty)
}
// ListUploadSessions returns the upload sessions matching the given filter
func (fs *posixFS) ListUploadSessions(ctx context.Context, filter storage.UploadSessionFilter) ([]storage.UploadSession, error) {
return fs.FS.(storage.UploadSessionLister).ListUploadSessions(ctx, filter)
@@ -37,10 +37,8 @@ import (
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
user "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
provider "github.com/cs3org/go-cs3apis/cs3/storage/provider/v1beta1"
"github.com/opencloud-eu/reva/v2/pkg/appctx"
"github.com/opencloud-eu/reva/v2/pkg/errtypes"
"github.com/opencloud-eu/reva/v2/pkg/events"
"github.com/opencloud-eu/reva/v2/pkg/storage/fs/posix/blobstore"
@@ -590,11 +588,6 @@ func (t *Tree) Delete(ctx context.Context, n *node.Node) error {
}
}()
if appctx.DeletingSharedResourceFromContext(ctx) {
src := filepath.Join(n.ParentPath(), n.Name)
return os.RemoveAll(src)
}
var sizeDiff int64
if n.IsDir(ctx) {
treesize, err := n.GetTreeSize(ctx)
@@ -819,11 +812,3 @@ func isLockFile(path string) bool {
// isTrash reports whether path refers to trash data: it ends in ".trashinfo"
// or ".trashitem", or contains ".Trash" anywhere. All comparisons are
// case-sensitive.
func isTrash(path string) bool {
	for _, suffix := range []string{".trashinfo", ".trashitem"} {
		if strings.HasSuffix(path, suffix) {
			return true
		}
	}
	return strings.Contains(path, ".Trash")
}
func (t *Tree) AddLabel(ctx context.Context, ref *provider.Reference, userID *user.UserId, label string) error {
return errtypes.NotSupported("AddLabel not implemented")
}
func (t *Tree) RemoveLabel(ctx context.Context, ref *provider.Reference, userID *user.UserId, label string) error {
return errtypes.NotSupported("RemoveLabel not implemented")
}
@@ -101,6 +101,7 @@ func ServiceAccountPermissions() *provider.ResourcePermissions {
Delete: true, // for cli restore command with replace option
CreateContainer: true, // for space provisioning
AddGrant: true, // for initial project space member assignment
ListGrants: true, // for initial project space member assignment
}
}
@@ -429,11 +429,6 @@ func (t *Tree) Delete(ctx context.Context, n *node.Node) (err error) {
// remove entry from cache immediately to avoid inconsistencies
defer func() { _ = t.idCache.Delete(path) }()
if appctx.DeletingSharedResourceFromContext(ctx) {
src := filepath.Join(n.ParentPath(), n.Name)
return os.Remove(src)
}
// get the original path
origin, err := t.lookup.Path(ctx, n, node.NoCheck)
if err != nil {
@@ -445,11 +445,6 @@ func (t *Tree) Delete(ctx context.Context, n *node.Node) (err error) {
// remove entry from cache immediately to avoid inconsistencies
defer func() { _ = t.idCache.Delete(path) }()
if appctx.DeletingSharedResourceFromContext(ctx) {
src := filepath.Join(n.ParentPath(), n.Name)
return os.Remove(src)
}
// get the original path
origin, err := t.lookup.Path(ctx, n, node.NoCheck)
if err != nil {
+38 -1
View File
@@ -93,6 +93,40 @@ func (disk *Disk) SimpleUpload(ctx context.Context, uploadpath string, content [
// Upload stores a file on disk
func (disk *Disk) Upload(_ context.Context, req UploadRequest) (*UploadResponse, error) {
p := disk.targetPath(req.Path)
// IfNoneMatch: ["*"] means create the file only if it does not already
// exist. Use O_EXCL so the check and the create are atomic on the local
// filesystem.
for _, tag := range req.IfNoneMatch {
if tag != "*" {
continue
}
f, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
if err != nil {
if errors.Is(err, os.ErrExist) {
return nil, errtypes.AlreadyExists(p)
}
return nil, err
}
if _, err := f.Write(req.Content); err != nil {
_ = f.Close()
return nil, err
}
if err := f.Close(); err != nil {
return nil, err
}
info, err := os.Stat(p)
if err != nil {
return nil, err
}
res := &UploadResponse{}
res.Etag, err = calcEtag(info.ModTime(), info.Size())
if err != nil {
return nil, err
}
return res, nil
}
if req.IfMatchEtag != "" {
info, err := os.Stat(p)
if err != nil && !errors.Is(err, os.ErrNotExist) {
@@ -170,7 +204,10 @@ func (disk *Disk) Download(_ context.Context, req DownloadRequest) (*DownloadRes
// SimpleDownload reads a file from disk
func (disk *Disk) SimpleDownload(ctx context.Context, downloadpath string) ([]byte, error) {
res, err := disk.Download(ctx, DownloadRequest{Path: downloadpath})
return res.Content, err
if err != nil {
return nil, err
}
return res.Content, nil
}
// Delete deletes a path
+48 -41
View File
@@ -266,15 +266,10 @@ func (i *Identity) GetLDAPUserByFilter(ctx context.Context, lc ldap.Client, filt
res, err := lc.Search(searchRequest)
if err != nil {
log.Debug().Str("backend", "ldap").Err(err).Str("userfilter", filter).Msg("Error looking up user by filter")
var errmsg string
if lerr, ok := err.(*ldap.Error); ok {
if lerr.ResultCode == ldap.LDAPResultSizeLimitExceeded {
errmsg = fmt.Sprintf("too many results searching for user '%s'", filter)
}
}
span.SetAttributes(attribute.String("ldap.error", errmsg))
span.SetStatus(codes.Error, errmsg)
return nil, errtypes.NotFound(errmsg)
classified := classifySearchError(err, fmt.Sprintf("too many results searching for user '%s'", filter))
span.SetAttributes(attribute.String("ldap.error", classified.Error()))
span.SetStatus(codes.Error, classified.Error())
return nil, classified
}
if len(res.Entries) == 0 {
return nil, errtypes.NotFound(filter)
@@ -306,9 +301,10 @@ func (i *Identity) GetLDAPUserByDN(ctx context.Context, lc ldap.Client, dn strin
res, err := lc.Search(searchRequest)
if err != nil {
log.Debug().Str("backend", "ldap").Err(err).Str("dn", dn).Msg("Error looking up user by DN")
span.SetAttributes(attribute.String("ldap.error", err.Error()))
span.SetStatus(codes.Error, "")
return nil, errtypes.NotFound(dn)
classified := classifySearchError(err, "")
span.SetAttributes(attribute.String("ldap.error", classified.Error()))
span.SetStatus(codes.Error, classified.Error())
return nil, classified
}
span.SetStatus(codes.Ok, "")
if len(res.Entries) == 0 {
@@ -337,9 +333,10 @@ func (i *Identity) GetLDAPUsers(ctx context.Context, lc ldap.Client, query, tena
sr, err := lc.Search(searchRequest)
if err != nil {
log.Debug().Str("backend", "ldap").Err(err).Str("filter", filter).Msg("Error searching users")
span.SetAttributes(attribute.String("ldap.error", err.Error()))
span.SetStatus(codes.Error, "")
return nil, errtypes.NotFound(query)
classified := classifySearchError(err, "")
span.SetAttributes(attribute.String("ldap.error", classified.Error()))
span.SetStatus(codes.Error, classified.Error())
return nil, classified
}
span.SetAttributes(attribute.Int("ldap.result_count", len(sr.Entries)))
@@ -376,7 +373,8 @@ func (i *Identity) IsLDAPUserInDisabledGroup(ctx context.Context, lc ldap.Client
sr, err := lc.Search(searchRequest)
if err != nil {
log.Error().Str("backend", "ldap").Err(err).Str("filter", filter).Msg("Error looking up error group")
// Err on the side of caution.
// Err on the side of caution: treat search failures (including network
// errors) as if the user is in the disabled group.
span.SetAttributes(attribute.String("ldap.error", err.Error()))
span.SetStatus(codes.Error, "")
return true
@@ -423,10 +421,10 @@ func (i *Identity) GetLDAPUserGroups(ctx context.Context, lc ldap.Client, userEn
// not having any groups in LDAP
return []string{}, nil
}
span.SetAttributes(attribute.String("ldap.error", err.Error()))
span.SetStatus(codes.Error, "")
return []string{}, err
classified := classifySearchError(err, "")
span.SetAttributes(attribute.String("ldap.error", classified.Error()))
span.SetStatus(codes.Error, classified.Error())
return nil, classified
}
span.SetStatus(codes.Ok, "")
span.SetAttributes(attribute.Int("ldap.result_count", len(sr.Entries)))
@@ -504,15 +502,10 @@ func (i *Identity) GetLDAPGroupByFilter(ctx context.Context, lc ldap.Client, fil
res, err := lc.Search(searchRequest)
if err != nil {
log.Debug().Str("backend", "ldap").Err(err).Str("filter", filter).Msg("Error looking up group by filter")
var errmsg string
if lerr, ok := err.(*ldap.Error); ok {
if lerr.ResultCode == ldap.LDAPResultSizeLimitExceeded {
errmsg = fmt.Sprintf("too many results searching for group '%s'", filter)
}
}
span.SetAttributes(attribute.String("ldap.error", errmsg))
span.SetStatus(codes.Error, "")
return nil, errtypes.NotFound(errmsg)
classified := classifySearchError(err, fmt.Sprintf("too many results searching for group '%s'", filter))
span.SetAttributes(attribute.String("ldap.error", classified.Error()))
span.SetStatus(codes.Error, classified.Error())
return nil, classified
}
if len(res.Entries) == 0 {
return nil, errtypes.NotFound(filter)
@@ -543,10 +536,11 @@ func (i *Identity) GetLDAPGroups(ctx context.Context, lc ldap.Client, query stri
setLDAPSearchSpanAttributes(span, searchRequest)
sr, err := lc.Search(searchRequest)
if err != nil {
span.SetAttributes(attribute.String("ldap.error", err.Error()))
span.SetStatus(codes.Error, "")
log.Debug().Str("backend", "ldap").Err(err).Str("query", query).Msg("Error search for groups")
return nil, errtypes.NotFound(query)
classified := classifySearchError(err, "")
span.SetAttributes(attribute.String("ldap.error", classified.Error()))
span.SetStatus(codes.Error, classified.Error())
return nil, classified
}
span.SetStatus(codes.Ok, "")
return sr.Entries, nil
@@ -919,15 +913,10 @@ func (i *Identity) GetLDAPTenantByFilter(ctx context.Context, lc ldap.Client, fi
res, err := lc.Search(searchRequest)
if err != nil {
log.Debug().Str("backend", "ldap").Err(err).Str("tenantfilter", filter).Msg("Error looking up tenant by filter")
var errmsg string
if lerr, ok := err.(*ldap.Error); ok {
if lerr.ResultCode == ldap.LDAPResultSizeLimitExceeded {
errmsg = fmt.Sprintf("too many results searching for tenant '%s'", filter)
}
}
span.SetAttributes(attribute.String("ldap.error", errmsg))
span.SetStatus(codes.Error, errmsg)
return nil, errtypes.NotFound(errmsg)
classified := classifySearchError(err, fmt.Sprintf("too many results searching for tenant '%s'", filter))
span.SetAttributes(attribute.String("ldap.error", classified.Error()))
span.SetStatus(codes.Error, classified.Error())
return nil, classified
}
if len(res.Entries) == 0 {
return nil, errtypes.NotFound(filter)
@@ -980,6 +969,24 @@ func (i *Identity) getTenantAttributeFilter(attribute, value string) (string, er
), nil
}
// classifySearchError translates a raw error returned by lc.Search into an
// errtypes value. The cases are checked in this order:
//   - a network error (ldap.ErrorNetwork) becomes errtypes.Unavailable, with
//     the original error text appended, marking the failure as transient;
//   - ldap.LDAPResultSizeLimitExceeded becomes
//     errtypes.NotFound(sizeExceededMsg), but only if sizeExceededMsg is not
//     empty;
//   - everything else becomes errtypes.NotFound("").
//
// Callers that do not need a dedicated message for the size-limit case pass an
// empty sizeExceededMsg; that case then falls through to the generic NotFound.
func classifySearchError(err error, sizeExceededMsg string) error {
	switch {
	case ldap.IsErrorWithCode(err, ldap.ErrorNetwork):
		return errtypes.Unavailable("ldap server unreachable: " + err.Error())
	case sizeExceededMsg != "" && ldap.IsErrorWithCode(err, ldap.LDAPResultSizeLimitExceeded):
		return errtypes.NotFound(sizeExceededMsg)
	default:
		return errtypes.NotFound("")
	}
}
func setLDAPSearchSpanAttributes(span trace.Span, request *ldap.SearchRequest) {
span.SetAttributes(
attribute.String("ldap.basedn", request.BaseDN),
+1 -1
View File
@@ -586,7 +586,7 @@ func (c *cbcCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader
// Length of encrypted portion of the packet (header, payload, padding).
// Enforce minimum padding and packet size.
encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize)
encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPacketSize)
// Enforce block size.
encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize
+7 -3
View File
@@ -274,10 +274,14 @@ func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiA
}
// Filter algorithms based on those supported by MultiAlgorithmSigner.
// Iterate over the signer's algorithms first to preserve its preference order.
supportedKeyAlgos := algorithmsForKeyFormat(keyFormat)
var keyAlgos []string
for _, algo := range algorithmsForKeyFormat(keyFormat) {
if slices.Contains(as.Algorithms(), underlyingAlgo(algo)) {
keyAlgos = append(keyAlgos, algo)
for _, signerAlgo := range as.Algorithms() {
if idx := slices.IndexFunc(supportedKeyAlgos, func(algo string) bool {
return underlyingAlgo(algo) == signerAlgo
}); idx >= 0 {
keyAlgos = append(keyAlgos, supportedKeyAlgos[idx])
}
}
+32 -1
View File
@@ -61,13 +61,42 @@ func (r *responseDeduper) addAll(dr *DriverResponse) {
}
func (r *responseDeduper) addPackage(p *Package) {
if r.seenPackages[p.ID] != nil {
if prev := r.seenPackages[p.ID]; prev != nil {
// Package already seen in a previous response. Merge the file lists,
// removing duplicates. This can happen when the same package appears
// in multiple driver responses that are being merged together.
prev.GoFiles = appendUniqueStrings(prev.GoFiles, p.GoFiles)
prev.CompiledGoFiles = appendUniqueStrings(prev.CompiledGoFiles, p.CompiledGoFiles)
prev.OtherFiles = appendUniqueStrings(prev.OtherFiles, p.OtherFiles)
prev.IgnoredFiles = appendUniqueStrings(prev.IgnoredFiles, p.IgnoredFiles)
prev.EmbedFiles = appendUniqueStrings(prev.EmbedFiles, p.EmbedFiles)
prev.EmbedPatterns = appendUniqueStrings(prev.EmbedPatterns, p.EmbedPatterns)
return
}
r.seenPackages[p.ID] = p
r.dr.Packages = append(r.dr.Packages, p)
}
// appendUniqueStrings appends to dst every element of src that is not already
// present in dst, and returns the extended slice. Elements repeated within src
// are appended only once. The relative order of first occurrences is kept.
//
// Duplicates that already exist inside dst are left as they are, and an empty
// src returns dst unchanged. Like the built-in append, the result may share
// its backing array with dst.
func appendUniqueStrings(dst, src []string) []string {
	if len(src) == 0 {
		return dst
	}

	seen := make(map[string]bool, len(dst)+len(src))
	for _, s := range dst {
		seen[s] = true
	}

	for _, s := range src {
		if seen[s] {
			continue
		}
		// Record the element so that a later repetition in src is skipped too.
		seen[s] = true
		dst = append(dst, s)
	}
	return dst
}
func (r *responseDeduper) addRoot(id string) {
if r.seenRoots[id] {
return
@@ -832,6 +861,8 @@ func golistargs(cfg *Config, words []string, goVersion int) []string {
// go list doesn't let you pass -test and -find together,
// probably because you'd just get the TestMain.
fmt.Sprintf("-find=%t", !cfg.Tests && cfg.Mode&findFlags == 0 && !usesExportData(cfg)),
// VCS information is not needed when not printing Stale or StaleReason fields
"-buildvcs=false",
}
// golang/go#60456: with go1.21 and later, go list serves pgo variants, which

Some files were not shown because too many files have changed in this diff Show More